#' Function to calculate N_total by the algorithm proposed by Servin et al. (2004)
#'
#' For every cross (internal node) of the crossing scheme, the function derives
#' the allele set transmitted by each of the two parents, the probability
#' (`p_1`, `p_2`) that a gamete of that parent carries this allele set, and
#' from these the number of progenies to produce so that the whole scheme
#' succeeds with probability `prob_total`.
#'
#' @param topolo Crossing scheme described by topology of tree. The function
#'   reads `topolo$edge` (two columns: the node produced by a cross, and the id
#'   of one of its parents) and `topolo$Nnode`.
#' @param gene_df1_sel Parental set of crossing. One row per gene; column `k`
#'   is read as the first allele set of node `k` of `topolo`.
#' @param gene_df2_sel Parental set of crossing. Same layout as `gene_df1_sel`,
#'   holding the second allele set. A node whose columns are identical in both
#'   sets is treated as homozygous.
#' @param recom_mat Matrix of recombination rate among genes.
#' @param prob_total Probability of success.
#' @param last_cross Whether or not to conduct the last cross to a cultivar without target alleles.
#' @param last_selfing Whether or not to conduct the last selfing.
#'
#' @importFrom dplyr pull filter
#' @returns
#' `calcCost` function returns the list of  `n_plant_df`, `gene_df1_sel` and `gene_df2_sel`
#' (class `"gpyramid_one"`).
#' `n_plant_df` contains the information about the number of progenies should be produced at each crossing (node in the tree):
#' columns `nodeid`, `n_plant`, `p_1` and `p_2`.
#' `gene_df1_sel` and `gene_df2_sel` are the inputs extended by one column per
#' internal node (plus one for the last cross when `last_cross = TRUE`).
#'
#' @export
calcCost <- function(topolo, gene_df1_sel, gene_df2_sel, recom_mat, prob_total, last_cross = FALSE, last_selfing = FALSE) {

  # Only "both final steps" or "neither" is implemented.
  if (xor(last_cross, last_selfing)) {
    stop("The current version only supports scenarios where last_cross and last_selfing are both TRUE or both FALSE.")
  }

  # Edge list as a data frame: V1 = node produced by a cross, V2 = parent id.
  edgemat <- topolo$edge
  colnames(edgemat) <- c("V1", "V2")
  edgemat <- as.data.frame(edgemat)
  node_vec <- unique(edgemat$V1)
  n_internal <- length(node_vec)
  if (n_internal == 0) {
    stop("`topolo$edge` has no rows, so there is no cross to evaluate.")
  }

  # Append one (still empty) column per internal node, plus one for the
  # progeny of the last cross when requested.
  if (last_cross) {
    n_node <- topolo$Nnode + 1
  } else {
    n_node <- topolo$Nnode
  }
  n_gene <- nrow(gene_df1_sel)
  nodematrix <- matrix(NA, nrow = n_gene, ncol = n_node)
  colnames(nodematrix) <- as.character(ncol(gene_df1_sel) + seq_len(n_node))
  gene_df1_sel <- cbind(gene_df1_sel, nodematrix)
  gene_df2_sel <- cbind(gene_df2_sel, nodematrix)

  # One row per internal node, plus one per optional final step.
  n_row <- n_internal + sum(last_cross, last_selfing)
  n_plant_df <- data.frame(nodeid = numeric(n_row),
                           n_plant = numeric(n_row),
                           p_1 = numeric(n_row),
                           p_2 = numeric(n_row))

  # Crosses are processed in the order their node first appears in
  # `topolo$edge`; the columns of both parents are read as they are at that
  # point.
  for (i in seq_len(n_internal)) {
    node <- node_vec[i]
    p_id <- edgemat |> filter(V1 == node) |> pull("V2")  # parents id

    # First parent -> first allele set of the progeny.
    if (identical(gene_df1_sel[, p_id[1]], gene_df2_sel[, p_id[1]])) {  # Homo
      gene_df1_sel[, node] <- gene_df1_sel[, p_id[1]]
      p_1 <- 1
    } else {
      # Desired gamete carries every allele present in either set of the parent.
      gene_df1_sel[, node] <- 0
      gene_df1_sel[(gene_df1_sel[, p_id[1]] | gene_df2_sel[, p_id[1]]), node] <- 1
      p_1 <- calcProb(gene_df1_sel[, p_id[1]], gene_df2_sel[, p_id[1]], gene_df1_sel[, node], recom_mat)
    }

    # Second parent -> second allele set of the progeny.
    if (identical(gene_df1_sel[, p_id[2]], gene_df2_sel[, p_id[2]])) {  # Homo
      gene_df2_sel[, node] <- gene_df1_sel[, p_id[2]]
      p_2 <- 1
    } else {
      gene_df2_sel[, node] <- 0
      gene_df2_sel[(gene_df1_sel[, p_id[2]] | gene_df2_sel[, p_id[2]]), node] <- 1
      p_2 <- calcProb(gene_df1_sel[, p_id[2]], gene_df2_sel[, p_id[2]], gene_df2_sel[, node], recom_mat)
    }

    n_plant_df[i, "nodeid"] <- node
    n_plant_df[i, "p_1"] <- p_1
    n_plant_df[i, "p_2"] <- p_2
  }

  # Row of `n_plant_df` that the next optional step writes to.
  next_row <- n_internal + 1

  if (last_cross) {
    # The last processed node is crossed to the cultivar (presumably this node
    # is the root of the scheme -- it is simply the last entry of `node_vec`).
    final_node <- node_vec[n_internal]
    if (identical(gene_df1_sel[, final_node], gene_df2_sel[, final_node])) {  # Homo
      p_1 <- 1
    } else {
      p_1 <- calcProb(gene_df1_sel[, final_node], gene_df2_sel[, final_node], rep(1, n_gene), recom_mat)
    }

    # Progeny of the last cross: first set all 1, second set (cultivar) all 0.
    gene_df1_sel[, ncol(gene_df1_sel)] <- 1
    gene_df2_sel[, ncol(gene_df2_sel)] <- 0

    n_plant_df[next_row, "nodeid"] <- ncol(gene_df1_sel)
    n_plant_df[next_row, "p_1"] <- p_1
    n_plant_df[next_row, "p_2"] <- 1  # the cultivar side needs no selection
    next_row <- next_row + 1
  }

  if (last_selfing) {
    # Selfing an individual that differs between its two sets at every gene;
    # both gametes have the same probability.
    p_last_self <- calcProb(rep(0, n_gene), rep(1, n_gene), rep(1, n_gene), recom_mat)

    n_plant_df[next_row, "nodeid"] <- ncol(gene_df1_sel) + 1
    n_plant_df[next_row, "p_1"] <- p_last_self
    n_plant_df[next_row, "p_2"] <- p_last_self
  }

  # A cross needs selection unless both gamete probabilities are (numerically) 1.
  certain <- 0.99999
  needs_selection <- n_plant_df$p_1 < certain | n_plant_df$p_2 < certain

  # The overall success probability is split evenly over the crosses that need
  # selection. When none does, `prob_suc` is not used below.
  n_prob <- sum(needs_selection)
  prob_suc <- prob_total^(1 / n_prob)

  n_plant_df$n_plant <- ifelse(
    needs_selection,
    log(1 - prob_suc) / log(1 - n_plant_df$p_1 * n_plant_df$p_2),
    1
  )

  output <- list(n_plant_df = n_plant_df,
                 gene_df1_sel = gene_df1_sel,
                 gene_df2_sel = gene_df2_sel)

  class(output) <- "gpyramid_one"
  return(output)
}

# Probability that a gamete carries the desired allele at every gene where the
# two allele sets `p1` and `p2` of an individual differ.
#
# p1, p2     Allele indicators of the two sets, one entry per gene.
# f1         Accepted for interface compatibility; not used in the computation.
# recom_mat  Matrix of recombination rates; `recom_mat[i, j]` is read for each
#            pair (i, j) of consecutive differing genes.
#
# Returns 1/2 when fewer than two genes differ. Otherwise 1/2 is multiplied,
# for each pair of consecutive differing genes, by `1 - r` when both alleles
# sit on the same set and by `r` when they sit on different sets.
calcProb <- function(p1, p2, f1, recom_mat){
  het_loci <- which(p1 != p2)
  n_het <- length(het_loci)
  gamete_prob <- 0.5

  # Zero or one differing gene: no recombination is involved.
  if (n_het <= 1) {
    return(gamete_prob)
  }

  for (k in seq_len(n_het - 1)) {
    left <- het_loci[k]
    right <- het_loci[k + 1]
    pair <- c(left, right)
    on_p1 <- sum(p1[pair])
    on_p2 <- sum(p2[pair])

    # Both alleles on one set and none on the other.
    same_set <- (on_p1 == 2 && on_p2 == 0) || (on_p1 == 0 && on_p2 == 2)
    # Both sets total 0, or both total 2, over the pair: factor of 1.
    no_factor <- (on_p1 == 0 && on_p2 == 0) || (on_p1 == 2 && on_p2 == 2)

    if (same_set) {
      gamete_prob <- gamete_prob * (1 - recom_mat[left, right])
    } else if (!no_factor) {
      gamete_prob <- gamete_prob * recom_mat[left, right]
    }
  }

  gamete_prob
}
