#' Estimate Multivariate Regression Association Measure
#'
#' @param y_data An \eqn{n \times d} matrix of responses.
#' @param x_data An \eqn{n \times p} matrix of predictors.
#' @param z_data An \eqn{n \times q} matrix of conditional predictors. The default value is \code{NULL}, in which case the unconditional measure is estimated.
#' @param bootstrap Perform the \eqn{m}-out-of-\eqn{n} bootstrap if \code{TRUE}. The default value is \code{FALSE}.
#' @param B Number of bootstrap replications. The default value is \code{1000}.
#' @param g_vec A vector used to generate a collection of rules for the \eqn{m}-out-of-\eqn{n} bootstrap. The default value is \code{seq(0.4,0.9,by = 0.05)}.
#' @description Estimate the multivariate regression association measure proposed in Shih and Chen (2025). Standard error estimates are obtained by applying the \eqn{m}-out-of-\eqn{n} bootstrap proposed in Dette and Kroll (2024).
#' @details The value \code{T_est} returned by \code{mram} is between \eqn{-1} and \eqn{1}. However, it is between \eqn{0} and \eqn{1} asymptotically. A small value indicates that \code{x_data} has low predictability for \code{y_data} conditional on \code{z_data} in the sense of the considered measure. Similarly, a large value indicates that \code{x_data} has high predictability for \code{y_data} conditional on \code{z_data}. If \code{z_data = NULL}, the returned value indicates the unconditional predictability.
#'
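#' @return A list. If \code{bootstrap = FALSE}, only \code{T_est} is returned; otherwise the list contains all of the following components.
#' \item{T_est}{The estimate of the multivariate regression association measure.}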
#' \item{T_se_cluster}{The standard error estimate based on the cluster rule.}
#' \item{m_vec}{The vector of \eqn{m} generated by \code{g_vec}.}
#' \item{T_se_vec}{The vector of standard error estimates obtained from the \eqn{m}-out-of-\eqn{n} bootstrap, where \eqn{m} is equal to \code{m_vec}.}
#' \item{J_cluster}{The index of the best \code{m_vec} chosen by the cluster rule.}
#'
#' @references Dette and Kroll (2024) A Simple Bootstrap for Chatterjee’s Rank Correlation, Biometrika, asae045.
#' @references Shih and Chen (2025) Measuring multivariate regression association via spatial sign (in revision, Computational Statistics & Data Analysis)
#' @seealso \code{\link{vs_mram}}
#'
#' @importFrom stats ks.test sd
#' @importFrom utils combn
#' @importFrom RANN nn2
#' @export
#'
#' @examples
#' n = 100
#' lambda_para = 3
#' sigma_para = 0.4
#'
#' x_data = matrix(rnorm(n*2),n,2)
#' y_data = matrix(0,n,2)
#' y_data[,1] = x_data[,1]+x_data[,2]+lambda_para*sigma_para*rnorm(n)
#' y_data[,2] = x_data[,1]-x_data[,2]+lambda_para*sigma_para*rnorm(n)
#'
#' library(MRAM)
#' res = mram(y_data,x_data,bootstrap = FALSE)

mram = function(y_data,x_data,z_data = NULL,bootstrap = FALSE,B = 1000,g_vec = seq(0.4,0.9,by = 0.05)) {

  # Select observations b: rows of a matrix (kept as a matrix) or elements of a
  # vector. NULL stays NULL, so an absent z_data passes through cbind() untouched.
  get_index = function(x,b) {

    if (is.matrix(x)) x[b,,drop = FALSE] else x[b]

  }

  # Squared Euclidean norm of every observation (matrix row or vector element).
  row_sums_squared = function(x) {

    if (is.null(dim(x))) x^2 else rowSums(x^2)

  }

  # Spatial signs of the pairwise differences y[s1]-y[s2]. A pair with a zero
  # difference (tie) gets spatial sign 0 instead of the NaN produced by 0/0;
  # missing values are left untouched so that they still propagate.
  pair_spatial_sign = function(y,s1,s2) {

    k_y = get_index(y,s1)-get_index(y,s2)
    k_norm = sqrt(row_sums_squared(k_y))
    y_sign = k_y/k_norm

    tied = which(k_norm == 0)

    if (is.matrix(y_sign)) {

      y_sign[tied,] = 0

    } else {

      y_sign[tied] = 0

    }

    y_sign

  }

  # Index of the nearest neighbor of every observation among the other
  # observations. With duplicated rows the search may rank a duplicate first
  # and the point itself second; in that case the duplicate is used so that an
  # observation is never paired with itself.
  nearest_neighbor = function(pred) {

    nn = RANN::nn2(pred,k = 2)
    index = nn$nn.idx[,2]

    self = which(index == seq_along(index))
    index[self] = nn$nn.idx[self,1]

    index

  }

  # Average inner product between the spatial signs of the pairwise differences
  # of y (precomputed in y_sign) and those of its nearest-neighbor counterpart,
  # where neighbors are searched in pred.
  nn_statistic = function(y,y_sign,pred,s1,s2,n_pairs) {

    y_prime = get_index(y,nearest_neighbor(pred))

    sum(y_sign*pair_spatial_sign(y_prime,s1,s2))/n_pairs

  }

  # data frames are indexed by column, so treat them as matrices

  if (is.data.frame(y_data)) y_data = as.matrix(y_data)
  if (is.data.frame(x_data)) x_data = as.matrix(x_data)
  if (is.data.frame(z_data)) z_data = as.matrix(z_data)

  n = NROW(y_data)

  if (n < 2) {

    stop("at least two observations are required",call. = FALSE)

  }

  if (NROW(x_data) != n || (!is.null(z_data) && NROW(z_data) != n)) {

    stop("y_data, x_data and z_data must have the same number of observations",call. = FALSE)

  }

  m_vec = floor(n^g_vec)

  n_combn = combn(n,2)
  n_choose = choose(n,2)

  s1 = n_combn[1,]
  s2 = n_combn[2,]

  # the spatial signs of y are shared by the xz and the z statistics

  y_spatial_sign = pair_spatial_sign(y_data,s1,s2)

  # nearest neighbor xz

  T_est_xz = nn_statistic(y_data,y_spatial_sign,cbind(x_data,z_data),s1,s2,n_choose)

  if (is.null(z_data)) {

    T_est = T_est_xz

  } else {

    # nearest neighbor z

    T_est_z = nn_statistic(y_data,y_spatial_sign,z_data,s1,s2,n_choose)

    T_est = (T_est_xz-T_est_z)/(1-T_est_z)

  }

  if (!bootstrap) {

    return(list(T_est = T_est))

  }

  ### m-out-of-n Bootstrap ###

  # Each subsample needs at least two points for the nearest neighbor search
  # and for forming pairs.

  if (any(m_vec < 2)) {

    stop("every m = floor(n^g_vec) must be at least 2; increase n or g_vec",call. = FALSE)

  }

  g_L = length(g_vec)
  g_seq = seq_len(g_L)

  T_est_matrix = matrix(0,B,g_L)

  for (i in g_seq) {

    m = m_vec[i]

    # The pair indices are built per subsample size inside the loop: collecting
    # them with sapply() collapses to a matrix when all m are equal (e.g. a
    # single g), and keeping every table alive at once wastes memory.

    m_combn = combn(m,2)
    m_choose = choose(m,2)

    s1_m = m_combn[1,]
    s2_m = m_combn[2,]

    for (b in seq_len(B)) {

      # subsample without replacement

      boot = sample.int(n,m)

      x_boot = get_index(x_data,boot)
      y_boot = get_index(y_data,boot)
      z_boot = get_index(z_data,boot)

      y_boot_spatial_sign = pair_spatial_sign(y_boot,s1_m,s2_m)

      # nearest neighbor xz

      T_boot_xz = nn_statistic(y_boot,y_boot_spatial_sign,cbind(x_boot,z_boot),s1_m,s2_m,m_choose)

      if (is.null(z_data)) {

        T_est_matrix[b,i] = T_boot_xz

      } else {

        # nearest neighbor z

        T_boot_z = nn_statistic(y_boot,y_boot_spatial_sign,z_boot,s1_m,s2_m,m_choose)

        T_est_matrix[b,i] = (T_boot_xz-T_boot_z)/(1-T_boot_z)

      }

    }

  }

  # Cluster rule: pick the m whose bootstrap distribution is closest to all the
  # others, measured by the summed Kolmogorov-Smirnov distances. ks.test()
  # warns about ties, which is irrelevant because only the statistic is used.

  T_temp = vapply(g_seq,function(j) {

    suppressWarnings(sum(vapply(g_seq,function(k) {

      ks.test(T_est_matrix[,j],T_est_matrix[,k])$statistic[[1]]

    },numeric(1))))

  },numeric(1))

  J_cluster = which.min(T_temp)

  # rescale the spread of the m-subsample estimates to the full sample size n

  T_se_vec = vapply(g_seq,function(j) {

    sd(sqrt(m_vec[j])*T_est_matrix[,j])/sqrt(n)

  },numeric(1))

  ### result

  list(T_est = T_est,
       T_se_cluster = T_se_vec[J_cluster],
       J_cluster = J_cluster,
       m_vec = m_vec,
       T_se_vec = T_se_vec)

}


