#' Argument matching
#'
#' @description
#' This function matches function arguments and is a modified version of
#' \code{\link[base]{match.arg}}.
#'
#' @param arg
#' A \code{character} (vector), the function argument.
#' @param choices
#' A \code{character} (vector) of allowed values for \code{arg}.
#' @param several.ok
#' Either \code{TRUE} if \code{arg} is allowed to have more than one element,
#' or \code{FALSE} else.
#' @param none.ok
#' Either \code{TRUE} if \code{arg} is allowed to have zero elements,
#' or \code{FALSE} else.
#'
#' @return
#' The un-abbreviated version of the exact or unique partial match if there is
#' one.
#' When \code{several.ok} is \code{TRUE} and (at least) one element of
#' \code{arg} has a match, all un-abbreviated versions of matches are returned;
#' elements of \code{arg} without a match are dropped.
#' When \code{none.ok} is \code{TRUE} and \code{arg} has zero elements,
#' \code{character(0)} is returned.
#' An error is signaled if \code{arg} has more than one element while
#' \code{several.ok} is \code{FALSE}, if \code{arg} has zero elements while
#' \code{none.ok} is \code{FALSE}, or if no element of \code{arg} matches any
#' of \code{choices} (regardless of the flags).
#'
#' @export

match_arg <- function(arg, choices, several.ok = FALSE, none.ok = FALSE) {
  checkmate::assert_character(arg)
  checkmate::assert_character(choices)
  checkmate::assert_flag(several.ok)
  checkmate::assert_flag(none.ok)
  # deparse() may split long expressions into several strings; collapse them
  # so the error messages below always show a single argument label.
  arg_name <- paste(deparse(substitute(arg)), collapse = " ")
  if (!several.ok && length(arg) > 1L) {
    cli::cli_abort(
      "{.var {arg_name}} must be of length 1.",
      call = NULL
    )
  }
  if (length(arg) == 0L) {
    if (none.ok) {
      return(character(0))
    }
    cli::cli_abort(
      "{.var {arg_name}} must be of length greater or equal 1.",
      call = NULL
    )
  }
  # 0 marks "no (unique) partial match" for the corresponding element of arg.
  i <- pmatch(arg, choices, nomatch = 0L, duplicates.ok = TRUE)
  if (all(i == 0L)) {
    # Scalar flags: plain if/else instead of ifelse(), and a single-line
    # template so no stray newline or double space ends up in the message.
    verb <- if (none.ok) "can" else "must"
    extent <- if (several.ok) "one or more" else "one"
    cli::cli_abort(
      "{.var {arg_name}} {verb} be {extent} of {.val {choices}}.",
      call = NULL
    )
  }
  choices[i[i > 0L]]
}

#' Check if an argument is a covariance matrix
#'
#' @description
#' This function checks whether the input is a symmetric, real matrix that
#' fulfills the covariance matrix properties.
#' Concretely, \code{x} must be a square \code{numeric} matrix without
#' \code{NA} or infinite values. It must be symmetric up to \code{tolerance}
#' and must have no eigenvalue below \code{-tolerance}, i.e., it must be
#' positive semi-definite up to \code{tolerance}.
#'
#' @param x
#' Object to check.
#'
#' @param dim
#' An \code{integer}, the matrix dimension.
#'
#' @param tolerance
#' A non-negative \code{numeric} tolerance value.
#'
#' @return
#' Compare to \code{\link[checkmate]{check_matrix}}.
#'
#' @export

check_covariance_matrix <- function(
    x, dim = NULL, tolerance = sqrt(.Machine$double.eps)) {
  checkmate::assert_number(tolerance, lower = 0)
  res <- checkmate::check_matrix(x, mode = "numeric")
  if (!isTRUE(res)) {
    return(res)
  }
  if (nrow(x) != ncol(x)) {
    return("Must be square")
  }
  if (any(is.na(x))) {
    return("Must not have NA values")
  }
  if (any(!is.finite(x))) {
    return("Must not have infinite values")
  }
  if (any(abs(x - t(x)) > tolerance)) {
    return("Must be symmetric")
  }
  # `x` is only symmetric up to `tolerance`. Without `symmetric = TRUE`,
  # eigen() may use the general algorithm and return complex eigenvalues,
  # which cannot be compared with `<`. eigen() also errors on a 0 x 0 matrix,
  # which has no eigenvalues to check.
  if (nrow(x) > 0L) {
    eigenvalues <- eigen(x, symmetric = TRUE, only.values = TRUE)$values
    if (any(eigenvalues < -tolerance)) {
      return("Must have positive eigenvalues only")
    }
  }
  if (!is.null(dim)) {
    checkmate::assert_count(dim, positive = TRUE)
    if (nrow(x) != dim) {
      return(paste("Must be of dimension", dim))
    }
  }
  TRUE
}

#' @rdname check_covariance_matrix
#' @inheritParams checkmate::assert_matrix
#' @export

# Generated from check_covariance_matrix() by checkmate's assertion factory.
assert_covariance_matrix <- checkmate::makeAssertionFunction(
  check.fun = check_covariance_matrix
)

#' @rdname check_covariance_matrix
#' @inheritParams checkmate::test_matrix
#' @export
# Generated from check_covariance_matrix() by checkmate's test factory.
test_covariance_matrix <- checkmate::makeTestFunction(
  check.fun = check_covariance_matrix
)

#' Check if an argument is a correlation matrix
#'
#' @description
#' This function checks whether the input is a symmetric, real matrix that
#' fulfills the correlation matrix properties.
#' Concretely, \code{x} must be a square \code{numeric} matrix without
#' \code{NA} or infinite values. It must be symmetric, have ones on the
#' diagonal and have entries in \eqn{[-1, 1]}, each up to \code{tolerance}.
#'
#' @param x
#' Object to check.
#'
#' @param dim
#' An \code{integer}, the matrix dimension.
#'
#' @param tolerance
#' A non-negative \code{numeric} tolerance value.
#'
#' @return
#' Compare to \code{\link[checkmate]{check_matrix}}.
#'
#' @export

check_correlation_matrix <- function(
    x, dim = NULL, tolerance = sqrt(.Machine$double.eps)) {
  checkmate::assert_number(tolerance, lower = 0)
  res <- checkmate::check_matrix(x, mode = "numeric")
  if (!isTRUE(res)) {
    return(res)
  }
  if (nrow(x) != ncol(x)) {
    return("Must be square")
  }
  if (any(is.na(x))) {
    return("Must not have NA values")
  }
  if (any(!is.finite(x))) {
    return("Must not have infinite values")
  }
  if (any(abs(x - t(x)) > tolerance)) {
    return("Must be symmetric")
  }
  if (any(abs(diag(x) - 1) > tolerance)) {
    return("Must have ones on the diagonal")
  }
  # Use the same tolerance as the diagonal check. Otherwise a diagonal entry
  # such as 1 + 1e-12 passes the check above but is rejected here.
  if (any(abs(x) > 1 + tolerance)) {
    return("Must have values between -1 and 1")
  }
  # NOTE(review): unlike check_covariance_matrix(), positive
  # semi-definiteness is not checked here.
  if (!is.null(dim)) {
    checkmate::assert_count(dim, positive = TRUE)
    if (nrow(x) != dim) {
      return(paste("Must be of dimension", dim))
    }
  }
  TRUE
}

#' @rdname check_correlation_matrix
#' @inheritParams checkmate::assert_matrix
#' @export

# Generated from check_correlation_matrix() by checkmate's assertion factory.
assert_correlation_matrix <- checkmate::makeAssertionFunction(
  check.fun = check_correlation_matrix
)

#' @rdname check_correlation_matrix
#' @inheritParams checkmate::test_matrix
#' @export
# Generated from check_correlation_matrix() by checkmate's test factory.
test_correlation_matrix <- checkmate::makeTestFunction(
  check.fun = check_correlation_matrix
)

#' Check if an argument is a transition probability matrix
#'
#' @description
#' This function checks whether the input is a square, real matrix with
#' elements between 0 and 1 and row sums equal to 1 (up to \code{tolerance}).
#'
#' @param x
#' Object to check.
#'
#' @param dim
#' An \code{integer}, the matrix dimension.
#'
#' @param tolerance
#' A non-negative \code{numeric} tolerance value.
#'
#' @return
#' Compare to \code{\link[checkmate]{check_matrix}}.
#'
#' @export

check_transition_probability_matrix <- function(
    x, dim = NULL, tolerance = sqrt(.Machine$double.eps)) {
  checkmate::assert_number(tolerance, lower = 0)
  matrix_check <- checkmate::check_matrix(x, mode = "numeric")
  if (!isTRUE(matrix_check)) {
    return(matrix_check)
  }
  # Structural and value checks, in order; the first failure is reported.
  if (ncol(x) != nrow(x)) {
    return("Must be square")
  }
  if (anyNA(x)) {
    return("Must not have NA values")
  }
  if (!all(is.finite(x))) {
    return("Must not have infinite values")
  }
  if (any(x < 0) || any(x > 1)) {
    return("Must have values between 0 and 1")
  }
  row_deviation <- abs(rowSums(x) - 1)
  if (any(row_deviation > tolerance)) {
    return("Must have row sums equal to 1")
  }
  # The dimension is only validated when explicitly requested.
  if (is.null(dim)) {
    return(TRUE)
  }
  checkmate::assert_count(dim, positive = TRUE)
  if (nrow(x) != dim) {
    return(paste("Must be of dimension", dim))
  }
  TRUE
}

#' @rdname check_transition_probability_matrix
#' @inheritParams checkmate::assert_matrix
#' @export

# Generated from check_transition_probability_matrix() by checkmate's
# assertion factory.
assert_transition_probability_matrix <- checkmate::makeAssertionFunction(
  check.fun = check_transition_probability_matrix
)

#' @rdname check_transition_probability_matrix
#' @inheritParams checkmate::test_matrix
#' @export
# Generated from check_transition_probability_matrix() by checkmate's test
# factory.
test_transition_probability_matrix <- checkmate::makeTestFunction(
  check.fun = check_transition_probability_matrix
)

#' Check if an argument is a probability vector
#'
#' @description
#' This function checks whether the input is a real vector without missing
#' values whose entries lie in \eqn{[0, 1]} and add up to one (up to
#' \code{tolerance}).
#'
#' @param x
#' Object to check.
#'
#' @param tolerance
#' A non-negative \code{numeric} tolerance value.
#'
#' @inheritParams checkmate::check_numeric
#'
#' @return
#' Compare to \code{\link[checkmate]{check_numeric}}.
#'
#' @export

check_probability_vector <- function(
    x, len = NULL, tolerance = sqrt(.Machine$double.eps)
) {
  checkmate::assert_number(tolerance, lower = 0)
  vector_check <- check_numeric_vector(
    x, lower = 0, upper = 1, any.missing = FALSE, len = len
  )
  if (!isTRUE(vector_check)) {
    return(vector_check)
  }
  # Entries are finite and non-missing at this point, so the sum is a number.
  if (abs(sum(x) - 1) <= tolerance) TRUE else "Must add up to 1"
}

#' @rdname check_probability_vector
#' @inheritParams checkmate::assert_atomic_vector
#' @inheritParams checkmate::assert_numeric
#' @export

# Generated from check_probability_vector() by checkmate's assertion factory.
assert_probability_vector <- checkmate::makeAssertionFunction(
  check.fun = check_probability_vector
)

#' @rdname check_probability_vector
#' @inheritParams checkmate::test_atomic_vector
#' @inheritParams checkmate::test_numeric
#' @export

# Generated from check_probability_vector() by checkmate's test factory.
test_probability_vector <- checkmate::makeTestFunction(
  check.fun = check_probability_vector
)

#' Check if an argument is a list of lists
#'
#' @description
#' This function checks whether the input is a list whose elements are all
#' lists themselves.
#'
#' @param x
#' Object to check.
#'
#' @inheritParams checkmate::check_list
#'
#' @return
#' Compare to \code{\link[checkmate]{check_list}}.
#' If an element of \code{x} is not a list, the returned message names the
#' position of the first offending element.
#'
#' @export

check_list_of_lists <- function(x, len = NULL) {
  res <- checkmate::check_list(x, len = len)
  if (!isTRUE(res)) {
    return(res)
  }
  # `len` is the expected length of the outer list `x` only; inner elements
  # just have to be lists, with no constraint on their own length.
  for (i in seq_along(x)) {
    res <- checkmate::check_list(x[[i]])
    if (!isTRUE(res)) {
      return(paste("Check for element", i, "failed:", res))
    }
  }
  TRUE
}

#' @rdname check_list_of_lists
#' @inheritParams checkmate::assert_list
#' @export

# Generated from check_list_of_lists() by checkmate's assertion factory.
assert_list_of_lists <- checkmate::makeAssertionFunction(
  check.fun = check_list_of_lists
)

#' @rdname check_list_of_lists
#' @inheritParams checkmate::test_list
#' @export

# Generated from check_list_of_lists() by checkmate's test factory.
test_list_of_lists <- checkmate::makeTestFunction(
  check.fun = check_list_of_lists
)

#' Check if an argument is a numeric vector
#'
#' @description
#' This function checks whether the input is a numeric vector, i.e., it must
#' pass both \code{\link[checkmate]{check_atomic_vector}} and
#' \code{\link[checkmate]{check_numeric}}.
#'
#' @param x
#' Object to check.
#'
#' @inheritParams checkmate::check_numeric
#' @inheritParams checkmate::check_atomic_vector
#'
#' @return
#' Compare to \code{\link[checkmate]{check_numeric}}.
#'
#' @export

check_numeric_vector <- function(
    x, lower = -Inf, upper = Inf, finite = FALSE, any.missing = TRUE,
    all.missing = TRUE, len = NULL, min.len = NULL, max.len = NULL,
    unique = FALSE, sorted = FALSE, names = NULL, typed.missing = FALSE,
    null.ok = FALSE
) {
  # check_atomic_vector() has no `null.ok` argument and rejects NULL, so
  # `null.ok` must be honoured before delegating; otherwise it has no effect.
  if (null.ok && is.null(x)) {
    return(TRUE)
  }
  res <- checkmate::check_atomic_vector(
    x, any.missing = any.missing, all.missing = all.missing, len = len,
    min.len = min.len, max.len = max.len, unique = unique, names = names
  )
  if (!isTRUE(res)) {
    return(res)
  }
  checkmate::check_numeric(
    x, lower = lower, upper = upper, finite = finite, any.missing = any.missing,
    all.missing = all.missing, len = len, min.len = min.len, max.len = max.len,
    unique = unique, sorted = sorted, names = names,
    typed.missing = typed.missing, null.ok = null.ok
  )
}

#' @rdname check_numeric_vector
#' @inheritParams checkmate::assert_numeric
#' @inheritParams checkmate::assert_atomic_vector
#' @export

# Generated from check_numeric_vector() by checkmate's assertion factory.
assert_numeric_vector <- checkmate::makeAssertionFunction(
  check.fun = check_numeric_vector
)

#' @rdname check_numeric_vector
#' @inheritParams checkmate::test_numeric
#' @inheritParams checkmate::test_atomic_vector
#' @export

# Generated from check_numeric_vector() by checkmate's test factory.
test_numeric_vector <- checkmate::makeTestFunction(
  check.fun = check_numeric_vector
)
