# weights.R
#
# Estimand recovery helpers
# Classify an exposure (treatment) vector for estimand recovery.
#
# Args:
#   Tvar: exposure values (logical, character, factor, or numeric). Missing
#     values are dropped before classification.
#   max_categorical_unique: largest number of distinct integer-like numeric
#     values that is still treated as coded categories.
#   integer_tol: tolerance used to decide whether a numeric value is a whole
#     number.
#
# Returns: one of "binary", "categorical", "continuous", or "unsupported".
.dagassist_exposure_kind <- function(Tvar,
                                     max_categorical_unique = 10L,
                                     integer_tol = sqrt(.Machine$double.eps)) {
  # nothing to classify
  if (is.null(Tvar)) return("unsupported")
  v <- stats::na.omit(Tvar)
  if (!length(v)) return("unsupported")

  # logical -> binary
  if (is.logical(v)) return("binary")

  # character -> factor, then classified by its number of levels
  if (is.character(v)) v <- factor(v)

  if (is.factor(v)) {
    # Count only levels that are actually observed: declared-but-empty levels
    # (common after complete-case filtering) must not turn a two-group
    # exposure into a "categorical" one.
    k <- nlevels(droplevels(v))
    if (k == 2L) return("binary")
    if (k >= 3L) return("categorical")
    return("unsupported")
  }

  # numeric/integer -> binary vs coded categories vs continuous
  if (is.numeric(v) || inherits(v, "integer")) {
    u <- sort(unique(v))
    k <- length(u)

    # exactly the two values 0 and 1
    if (k == 2L && all(u %in% c(0, 1))) return("binary")

    # "coded categories" heuristic: few distinct, whole-number values.
    # Non-finite values are never integer-like; checking is.finite() first
    # keeps the comparison below from producing NA (Inf - round(Inf) is NaN).
    integer_like <- all(is.finite(u)) && all(abs(u - round(u)) < integer_tol)
    if (k >= 3L && k <= max_categorical_unique && integer_like) {
      return("categorical")
    }

    # three or more distinct values that are not coded categories
    if (k >= 3L) return("continuous")
  }

  # anything else (e.g. two numeric values other than 0/1, a constant, dates)
  "unsupported"
}

# Normalize `weights_args`: NULL becomes an empty list; any non-list value is an error
# Coerce the user-supplied `weights_args` into a list.
#
# Args:
#   args: NULL or a list of extra weighting arguments.
#
# Returns: `args` unchanged when it is a list, or an empty list when it is
#   NULL. Any other input raises an error.
.dagassist_normalize_weights_args <- function(args) {
  # a list (including a data.frame) passes straight through
  if (is.list(args)) {
    return(args)
  }
  # anything that is neither a list nor NULL is a usage error
  if (!is.null(args)) {
    stop("`weights_args` must be a list.", call. = FALSE)
  }
  list()
}

#keep only args that match internal formal arguments; return list(keep=..., drop=...)
# Split a named argument list into the entries that `fun` declares as formal
# arguments and the names of the entries it does not.
#
# Args:
#   args: list of arguments; must be fully named when non-empty.
#   fun:  function whose formals() define the accepted argument names.
#
# Returns: list(keep = <sub-list of accepted args>, drop = <character vector
#   of rejected names>).
.dagassist_filter_args <- function(args, fun) {
  if (length(args) == 0L) return(list(keep = list(), drop = character(0)))

  nms <- names(args)
  # Reject missing, empty, or NA names. The NA check matters: `nms == ""` is NA
  # for an NA name, which would otherwise crash the `if` with an unrelated
  # "missing value where TRUE/FALSE needed" error.
  if (is.null(nms) || anyNA(nms) || !all(nzchar(nms))) {
    stop("`weights_args` must be a *named* list.", call. = FALSE)
  }

  fmls <- names(formals(fun))
  keep_nms <- intersect(nms, fmls)
  drop_nms <- setdiff(nms, keep_nms)

  list(keep = args[keep_nms], drop = drop_nms)
}

# Map a model column label back to the formula that produced it.
#
# Args:
#   x:          report object; reads x$formulas (original, bivariate, minimal,
#               minimal_list, canonical, canonical_excl) and, for ACDE labels,
#               x$settings$acde.
#   model_name: column label, optionally ending in an estimand suffix such as
#               " (SATE)", " (SATT)", " (SACDE)" or " (SCDE)".
#
# Returns: the matching formula, or NULL when the label is not recognized.
.dagassist_formula_for_model_name <- function(x, model_name) {
  # parse the trailing estimand suffix, if any
  is_weighted <- grepl("\\((SATE|SATT)\\)\\s*$", model_name, ignore.case = TRUE)
  is_acde <- grepl("\\((SACDE|SCDE)\\)\\s*$", model_name, ignore.case = TRUE)
  # strip the suffix to recover the base model label
  base_name <- sub("\\s*\\((SATE|SATT|SACDE|SCDE)\\)\\s*$", "", model_name, ignore.case = TRUE)

  # ACDE label: build the sequential_g formula from the *base* model formula
  if (is_acde) {
    base_fml <- .dagassist_formula_for_model_name(x, base_name)
    if (is.null(base_fml)) return(NULL)
    acde <- .dagassist_normalize_acde_spec(x$settings$acde)
    return(.dagassist_build_acde_formula(base_fml, x, acde))
  }

  # weighted models use the same regression formula as their base model
  if (is_weighted) {
    return(.dagassist_formula_for_model_name(x, base_name))
  }

  of <- x$formulas$original
  bf <- x$formulas$bivariate
  mf <- x$formulas$minimal
  cf <- x$formulas$canonical
  cex <- x$formulas$canonical_excl

  if (identical(base_name, "Original")) return(of)
  if (identical(base_name, "Bivariate")) return(bf)

  # "Minimal <i>" variants
  if (grepl("^Minimal\\s+[0-9]+$", base_name)) {
    # as.integer() yields NA (with a warning) for digit strings beyond the
    # integer range; that case is handled by the fallback below.
    idx <- suppressWarnings(as.integer(sub("^Minimal\\s+", "", base_name)))
    minimal_list <- x$formulas$minimal_list
    # Only index when idx is a valid 1-based position; "Minimal 0" or an
    # out-of-range index falls back to the single minimal formula.
    if (!is.na(idx) && idx >= 1L && idx <= length(minimal_list)) {
      return(minimal_list[[idx]])
    }
    return(mf)
  }

  if (identical(base_name, "Canonical")) return(cf)

  # Canonical exclusion variants
  if (!is.null(cex)) {
    if (is.list(cex)) {
      # named list of excluded variants: match each entry's display label
      for (nm in names(cex)) {
        lbl <- switch(
          nm,
          nct = "Canon. (-NCT)",
          nco = "Canon. (-NCO)",
          paste0("Canonical (", nm, ")")
        )
        if (identical(lbl, base_name)) return(cex[[nm]])
      }
    } else {
      # old single-model behavior; map any exclusion label to the same formula
      if (base_name %in% c("Canon. (-NCT)", "Canon. (-NCO)") || grepl("^Canonical\\s*\\(", base_name)) {
        return(cex)
      }
    }
  }

  NULL
}

# Choose controls for the treatment model
# Pick the covariates used on the right-hand side of the treatment model.
#
# Preference order:
#   1. canonical controls (x$controls_canonical)
#   2. the first minimal adjustment set (x$controls_minimal_all[[1]])
#   3. x$controls_minimal
#   4. RHS terms of the original formula, minus the exposure and the outcome
#
# Args:
#   x:        report object holding the control sets, roles, formulas and data.
#   exposure: name of the exposure, removed from the fallback in step 4.
#
# Returns: unique character vector of controls that are columns of x$.__data.
.dagassist_treatment_controls <- function(x, exposure) {
  canonical <- x$controls_canonical

  if (length(canonical)) {
    chosen <- unname(canonical)
  } else if (length(x$controls_minimal_all)) {
    chosen <- x$controls_minimal_all[[1L]]
  } else if (length(x$controls_minimal)) {
    chosen <- x$controls_minimal
  } else {
    # last resort: whatever the author's formula adjusts for
    formula_terms <- .rhs_terms_safe(x$formulas$original)
    outcome <- get_by_role(x$roles, "outcome")
    chosen <- setdiff(formula_terms, c(exposure, outcome))
  }

  # drop anything that is not an actual data column
  available <- names(x$.__data)
  unique(intersect(chosen, available))
}

# ---- Estimand normalization (supports vectors) ----
# Normalize one or more requested estimands to canonical upper-case labels.
#
# Args:
#   estimand: NULL or a vector of estimand names (case-insensitive; unique
#     prefixes are accepted, following match.arg()).
#
# Returns: unique character vector drawn from "RAW", "SATE", "SATT", "SACDE".
#   NULL maps to "RAW"; "NONE" is an alias for "RAW" and "SCDE" for "SACDE".
#   Any unrecognized value raises an error.
.dagassist_normalize_estimand <- function(estimand) {
  if (is.null(estimand)) return("RAW")

  choices <- c("RAW", "NONE", "SATE", "SATT", "SACDE", "SCDE")
  est <- toupper(as.character(estimand))

  # match.arg(several.ok = TRUE) silently discards entries it cannot match as
  # long as at least one entry is valid (e.g. c("SATE", "ATE") -> "SATE").
  # A lone invalid value already errors, so make the mixed case error too
  # instead of quietly ignoring part of the request.
  hits <- pmatch(est, choices, nomatch = 0L, duplicates.ok = TRUE)
  if (any(hits > 0L) && any(hits == 0L)) {
    stop(
      "Unrecognized `estimand` value(s): ",
      paste(unique(est[hits == 0L]), collapse = ", "),
      ".\nValid options are: ", paste(choices, collapse = ", "), ".",
      call. = FALSE
    )
  }

  est <- match.arg(est, choices = choices, several.ok = TRUE)

  # collapse aliases onto their canonical label
  est[est == "NONE"] <- "RAW"
  est[est == "SCDE"] <- "SACDE"
  unique(est)
}

# ---- Normalize ACDE spec list ----
# Fill in defaults for the ACDE specification list.
#
# Args:
#   acde: NULL or a list of user overrides.
#
# Returns: list with entries m, x, z, fe (NULL by default),
#   fe_as_factor (TRUE) and include_descendants (FALSE), with every named
#   element of `acde` written over the default. Errors when `acde` is neither
#   NULL nor a list.
.dagassist_normalize_acde_spec <- function(acde) {
  if (is.null(acde)) {
    acde <- list()
  }
  if (!is.list(acde)) {
    stop("`sacde` must be a list.", call. = FALSE)
  }

  spec <- list(
    m = NULL,                    # mediators (character)
    x = NULL,                    # baseline covariates override (character)
    z = NULL,                    # intermediate covariates override (character)
    fe = NULL,                   # fixed-effects vars override (character)
    fe_as_factor = TRUE,         # wrap FE vars as factor()
    include_descendants = FALSE  # treat Dmediator as mediators
  )

  # fold the user's entries over the defaults, one name at a time
  Reduce(
    function(current, key) {
      current[[key]] <- acde[[key]]
      current
    },
    names(acde),
    spec
  )
}

# ---- Detect if author formula conditions on mediator(s) ----
# Does the formula's right-hand side include any mediator?
#
# Args:
#   formula:             model formula whose RHS terms are inspected.
#   roles:               table with `variable` and `role` columns.
#   include_descendants: when TRUE, variables with role "Dmediator" count as
#                        mediators too.
#
# Returns: TRUE when at least one RHS term is a mediator, otherwise FALSE.
.dagassist_formula_controls_mediator <- function(formula, roles, include_descendants = FALSE) {
  formula_terms <- .rhs_terms_safe(formula)

  mediator_vars <- roles$variable[roles$role == "mediator"]
  if (isTRUE(include_descendants)) {
    descendant_vars <- roles$variable[roles$role == "Dmediator"]
    mediator_vars <- unique(c(mediator_vars, descendant_vars))
  }

  overlap <- intersect(formula_terms, mediator_vars)
  length(overlap) > 0L
}

# Validate the requested estimand(s) against the mediator structure of the DAG.
#
# Args:
#   estimand:            requested estimand(s), in any form accepted by
#                        .dagassist_normalize_estimand().
#   formula:             the author's model formula.
#   roles:               role table; mediators are detected from a `role`
#                        column (value "mediator") or, failing that, a logical
#                        `is_mediator` column.
#   auto_acde:           when not TRUE the mediator check on the formula is
#                        skipped.
#   include_descendants: passed to .dagassist_formula_controls_mediator().
#
# Returns: `estimand`, unchanged. Raises an error when SACDE/SCDE is requested
#   but no mediator is found in `roles`.
.dagassist_apply_auto_acde <- function(estimand,
                                       formula,
                                       roles,
                                       auto_acde = TRUE,
                                       include_descendants = FALSE) {
  ests <- unique(.dagassist_normalize_estimand(estimand))

  # ACDE/CDE requires at least one mediator in the DAG / formula
  wants_acde <- any(ests %in% c("SACDE", "SCDE"))
  if (isTRUE(wants_acde)) {
    has_med <- FALSE
    if (!is.null(roles)) {
      if ("role" %in% names(roles)) {
        # %in% never yields NA, so missing roles cannot mask a mediator
        has_med <- any(roles$role %in% "mediator")
      }
      if (!isTRUE(has_med) && "is_mediator" %in% names(roles)) {
        # isTRUE() is FALSE for any vector longer than one, so it cannot be
        # applied to the whole column; test element-wise instead.
        has_med <- any(roles$is_mediator %in% TRUE)
      }
    }
    if (!isTRUE(has_med)) {
      stop(
        paste0(
          "You requested estimand = 'SACDE' (alias: 'SCDE'), but no mediator node(s) were detected in your DAG ",
          "for this exposure/outcome pair.\n",
          "SACDE/SCDE is only defined when at least one mediator exists.\n\n",
          "Fix options:\n",
          "  1) Use estimand = 'SATE'/'SATT' for total effects (when no mediators are present), OR\n",
          "  2) Use estimand = 'RAW' to report the naive regression output.\n"
        ),
        call. = FALSE
      )
    }
  }

  # SATE/SATT are allowed even when the formula includes mediators
  if (!isTRUE(auto_acde)) return(estimand)

  wants_total <- any(ests %in% c("SATE", "SATT"))
  if (!isTRUE(wants_total)) return(estimand)

  # NOTE(review): the result of this check does not currently change the
  # return value; both outcomes hand back `estimand` as supplied.
  controls_mediator <- .dagassist_formula_controls_mediator(
    formula,
    roles,
    include_descendants = include_descendants
  )
  if (!isTRUE(controls_mediator)) return(estimand)

  estimand
}

# Descendants of `node` in `dag` via dagitty::descendants(); any error from
# that call is converted into an empty character vector.
.dagassist_safe_descendants <- function(dag, node) {
  empty_on_error <- function(err) character(0)
  tryCatch(
    dagitty::descendants(dag, node),
    error = empty_on_error
  )
}

# Ancestors of `node` in `dag` via dagitty::ancestors(); any error from that
# call is converted into an empty character vector.
.dagassist_safe_ancestors <- function(dag, node) {
  empty_on_error <- function(err) character(0)
  tryCatch(
    dagitty::ancestors(dag, node),
    error = empty_on_error
  )
}

# Decide which mediators an ACDE fit should use.
#
# Args:
#   x:    report object; reads x$roles (`variable`, `role`) and
#         x$formulas$original.
#   acde: ACDE spec list; reads `m` and `include_descendants`.
#
# Returns: character vector of mediator names, chosen as
#   1. acde$m when the user supplied a non-empty value;
#   2. otherwise the DAG mediators that appear in the author's formula;
#   3. otherwise every DAG mediator (plus "Dmediator" nodes when
#      acde$include_descendants is TRUE).
.dagassist_infer_acde_mediators <- function(x, acde) {
  user_mediators <- acde$m
  if (length(user_mediators) > 0L) {
    return(unique(as.character(user_mediators)))
  }

  roles <- x$roles
  dag_mediators <- roles$variable[roles$role == "mediator"]
  if (isTRUE(acde$include_descendants)) {
    dag_mediators <- unique(
      c(dag_mediators, roles$variable[roles$role == "Dmediator"])
    )
  }

  # mediators the author already conditions on take priority
  author_terms <- .rhs_terms_safe(x$formulas$original)
  conditioned_on <- intersect(author_terms, dag_mediators)

  if (length(conditioned_on)) conditioned_on else unique(dag_mediators)
}

# Wrap factor() ONLY for bare symbols that exist as columns in `data`.
# This prevents factor(TRUE) and other length-1 constants from ever being created.
# Wrap plain variable names in factor().
#
# Args:
#   terms:      character vector of formula terms.
#   data_names: optional character vector of data column names; when given,
#               only terms that are also in `data_names` are wrapped.
#
# Returns: `terms` with each qualifying bare symbol replaced by
#   "factor(<symbol>)"; terms that are calls, interactions, etc. are left
#   untouched. Empty input yields character(0).
.dagassist_factorize_plain_terms <- function(terms, data_names = NULL) {
  if (length(terms) == 0L) {
    return(character(0))
  }

  # a bare symbol: starts with a letter or dot, then letters/digits/dots/underscores
  wrap <- grepl("^[.A-Za-z][.A-Za-z0-9._]*$", terms)

  # optionally restrict to names that are real data columns
  if (!is.null(data_names)) {
    wrap <- wrap & (terms %in% data_names)
  }

  wrapped <- paste0("factor(", terms[wrap], ")")
  terms[wrap] <- wrapped
  terms
}

# add whichever estimands are requested
# Append the model columns for every estimand requested in x$settings$estimand.
#
# Args:
#   x:    report object (reads x$settings$estimand).
#   mods: named list of fitted models.
#
# Returns: `mods` unchanged when only "RAW" (or nothing) is requested;
#   otherwise `mods` extended, in order, by the SATE, SATT and SACDE columns
#   that were asked for.
.dagassist_add_estimand_models <- function(x, mods) {
  requested <- .dagassist_normalize_estimand(x$settings$estimand)
  if (length(requested) == 0L || identical(requested, "RAW")) {
    return(mods)
  }

  result <- mods

  # weighted total-effect estimands, SATE before SATT
  for (total_est in c("SATE", "SATT")) {
    if (total_est %in% requested) {
      result <- .dagassist_add_weighted_models(x, result, estimand = total_est)
    }
  }

  # controlled direct effect
  if ("SACDE" %in% requested) {
    result <- .dagassist_add_sacde_models(x, result)
  }

  result
}


# Add weighted versions of each model column (ATE/ATT) using WeightIt
# Key fix: compute weights *per model spec* (Minimal k vs Canonical, etc.)
# and do NOT create "Original (ATE)" or "Original (ATT)" columns.
# Refit each model column with weights from WeightIt::weightit() and summarise
# the exposure effect with marginaleffects.
#
# Args:
#   x:        report object. Reads x$.__data, x$roles, x$formulas$original,
#             the x$controls_* control sets and x$settings (estimand,
#             weights_args, engine, engine_args, eval_all, wts_omit).
#   mods:     named list of fitted models; the names are the column labels.
#   estimand: optional override for x$settings$estimand. After normalization
#             only the first of "SATE"/"SATT" is used.
#
# Returns: `mods` with a "<name> (<estimand>)" element spliced in directly
#   after each base column for which a weighted fit was produced. "Original"
#   and columns that already carry an estimand suffix are never re-weighted.
#   `mods` is returned unchanged when no total-effect estimand is requested.
#
# Errors: when the data, a single exposure, the engine, or the WeightIt /
#   marginaleffects packages are missing; when the exposure is an interaction
#   term or of an unsupported kind; when `trim_at` is invalid; or when the
#   number of weights does not match the analytic sample.
.dagassist_add_weighted_models <- function(x, mods, estimand = NULL) {

  ests <- .dagassist_normalize_estimand(
    if (!is.null(estimand)) estimand else x$settings$estimand
  )

  # Weighting only applies to total-effect estimands
  ests <- intersect(ests, c("SATE", "SATT"))
  if (!length(ests)) return(mods)
  est <- ests[1L]

  data0 <- x$.__data
  if (is.null(data0)) {
    stop(
      "Original data not found on the report object.\n",
      "Estimand recovery requires calling DAGassist() with the `data` argument.",
      call. = FALSE
    )
  }

  exp_nm <- get_by_role(x$roles, "exposure")
  out_nm <- get_by_role(x$roles, "outcome")

  if (length(exp_nm) != 1L || is.na(exp_nm) || !nzchar(exp_nm)) {
    stop(
      "Estimand recovery currently supports exactly one exposure node.\n",
      "DAGassist found ", length(exp_nm), " exposure nodes in the DAG.\n\n",
      "Please either:\n",
      "  * simplify the DAG to a single exposure for this call, or\n",
      "  * set `estimand = \"none\"`.\n",
      call. = FALSE
    )
  }

  if (!exp_nm %in% names(data0)) {
    stop(
      "Exposure variable '", exp_nm, "' was identified in the DAG but not found in `data`.\n",
      "Please check that the DAG node names and data column names match.",
      call. = FALSE
    )
  }

  # Guardrail: estimand recovery is not supported when the exposure is specified as an interaction term
  # (e.g., exposure = "A:B" or "A*B"). Precompute a single treatment variable in `data` instead.
  if (.dagassist_is_interaction_exposure(exp_nm)) {
    stop(
      "Estimand recovery is not supported when the exposure is an interaction term (e.g., X1:X2 or X1*X2).\n\n",
      "To proceed, precompute a single treatment variable in your data (e.g., treat_post) and use that as the exposure node, ",
      "or set estimand = 'raw'/'none'.",
      call. = FALSE
    )
  }

  # Weight args: user-configurable, but filtered to WeightIt::weightit() formals
  wargs <- .dagassist_normalize_weights_args(x$settings$weights_args)
  trim_at <- wargs[["trim_at"]]          # DAGassist-specific (optional)
  wargs[["trim_at"]] <- NULL

  # stabilize by default
  if (is.null(wargs[["stabilize"]])) wargs[["stabilize"]] <- TRUE

  # dependency checks
  if (!requireNamespace("WeightIt", quietly = TRUE)) {
    stop(
      "Estimand recovery requires the 'WeightIt' package.\n",
      "Install it (install.packages('WeightIt')) or set estimand = 'raw'.",
      call. = FALSE
    )
  }

  if (!requireNamespace("marginaleffects", quietly = TRUE)) {
    stop(
      "Estimand recovery requires the 'marginaleffects' package.\n",
      "Install it (install.packages('marginaleffects')) or set estimand = 'raw'.",
      call. = FALSE
    )
  }

  # Keep only the args accepted by WeightIt::weightit(). The split does not
  # depend on the model spec, so it is done (and warned about) once here
  # rather than once per model column.
  fa <- .dagassist_filter_args(wargs, WeightIt::weightit)
  if (length(fa$drop)) {
    warning(
      "Ignoring these weights_args for WeightIt::weightit(): ",
      paste(fa$drop, collapse = ", "),
      call. = FALSE
    )
  }

  engine <- x$settings$engine
  engine_args <- x$settings$engine_args
  if (is.null(engine)) {
    stop(
      "Modeling engine not found on the report object.\n",
      "Please ensure DAGassist() stores `engine` in report$settings.",
      call. = FALSE
    )
  }
  if (!is.list(engine_args)) engine_args <- list()

  # Labels that already end in an estimand suffix are derived columns.
  estimand_suffix <- "\\((SATE|SATT|SACDE|SCDE)\\)\\s*$"

  # --- helper: compute weights and fit one weighted spec on complete-case data ---
  # Returns NULL when the column must not (or cannot) be weighted.
  .fit_weighted_one <- function(model_name) {

    # Do NOT produce "Original (SATE)/(SATT)" columns
    if (identical(model_name, "Original")) return(NULL)

    # Do NOT re-weight columns that are themselves estimand columns. Without
    # this, requesting SATE and SATT together re-fits e.g. "Canonical (SATE)"
    # with an empty control set and emits a "Canonical (SATE) (SATT)" column.
    if (grepl(estimand_suffix, model_name, ignore.case = TRUE)) return(NULL)

    # ---- DAG-based controls for this model column ----
    controls <- character(0)

    if (identical(model_name, "Bivariate")) {
      controls <- character(0)

    } else if (identical(model_name, "Canonical")) {
      controls <- x$controls_canonical

    } else if (grepl("^Minimal\\s+[0-9]+$", model_name)) {
      idx <- as.integer(sub("^Minimal\\s+", "", model_name))
      if (length(x$controls_minimal_all) && idx <= length(x$controls_minimal_all)) {
        controls <- x$controls_minimal_all[[idx]]
      } else {
        controls <- x$controls_minimal
      }

    } else if (identical(model_name, "Canon. (-NCT)")) {
      controls <- x$controls_canonical_excl[["nct"]]

    } else if (identical(model_name, "Canon. (-NCO)")) {
      controls <- x$controls_canonical_excl[["nco"]]

    } else if (grepl("^Canonical\\s*\\((.+)\\)$", model_name)) {
      # fallback for any other canonical-excl labels
      nm <- sub("^Canonical\\s*\\((.+)\\)$", "\\1", model_name)
      controls <- x$controls_canonical_excl[[nm]]

    } else {
      # unrecognized label: weight with no controls
      controls <- character(0)
    }

    controls <- unique(intersect(controls, names(data0)))

    # Outcome-model controls: DAG controls plus, under eval_all, the author's
    # RHS terms that are not DAG nodes.
    rhs_extras <- character(0)
    if (isTRUE(x$settings$eval_all)) {
      rhs_extras <- setdiff(.rhs_terms_safe(x$formulas$original), x$roles$variable)

      # keep transformed terms (e.g., factor(x), poly(x,2)) as long as their symbols exist in data
      rhs_extras <- rhs_extras[vapply(
        rhs_extras,
        function(tt) {
          vv <- tryCatch(all.vars(stats::as.formula(paste0("~", tt))), error = function(e) character(0))
          length(vv) > 0 && all(vv %in% names(data0))
        },
        logical(1)
      )]
    }

    controls_out <- unique(c(controls, rhs_extras))

    fml <- .build_formula_with_controls(x$formulas$original, exp_nm, out_nm, controls_out)

    # Treatment-model controls: the DAG controls, or the same set as the
    # outcome model (including non-DAG extras) when eval_all = TRUE.
    controls_treat <- controls
    if (isTRUE(x$settings$eval_all)) {
      controls_treat <- controls_out
    }

    # wts_omit = keep terms in outcome model but omit them from weighting formula
    wts_omit <- x$settings$wts_omit
    if (is.character(wts_omit) && length(wts_omit)) {
      # drop any control term that contains an omitted token as a word component
      # (e.g., drops "year", "factor(year)", "i(year)", "year:treated", etc.)
      pats <- paste0(
        "\\b(",
        paste(vapply(wts_omit, .escape_regex, character(1)), collapse = "|"),
        ")\\b"
      )
      controls_treat <- controls_treat[!grepl(pats, controls_treat)]
    }

    # treatment model formula: X ~ controls (or ~1 if none)
    if (length(controls_treat)) {
      f_treat <- stats::as.formula(paste(exp_nm, "~", paste(controls_treat, collapse = " + ")))
    } else {
      f_treat <- stats::as.formula(paste(exp_nm, "~ 1"))
    }

    # Build complete-case analytic data for THIS spec, keeping the variables
    # needed for the treatment model, the outcome model, fixest tail parts
    # (so refits with FE don't fail) and clustering.
    sp <- .strip_fixest_parts(fml)
    base_fml <- sp$base

    # vars from base part (y, x's, etc.)
    vars_need <- unique(c(all.vars(base_fml), exp_nm, out_nm, controls_treat))

    # vars from fixest tail (e.g., FE like `year + province`, IV parts, etc.)
    # We parse each top-level '|' segment as "~ <segment>" and take all.vars.
    s_full <- paste(deparse(fml, width.cutoff = 500L), collapse = " ")
    parts  <- .split_top_level(s_full, sep = "|")
    tail_vars <- character(0)
    if (length(parts) >= 2L) {
      tails <- trimws(parts[-1])
      for (tt in tails) {
        # Protect against empty pieces
        if (!nzchar(tt)) next
        f_tt <- stats::as.formula(paste("~", tt), env = environment(fml))
        tail_vars <- c(tail_vars, all.vars(f_tt))
      }
    }
    vars_need <- unique(c(vars_need, tail_vars))

    # keep cluster vars too (given as a formula or a single column name)
    cluster_vars <- character(0)
    if ("cluster" %in% names(engine_args)) {
      cl <- engine_args$cluster
      if (inherits(cl, "formula")) cluster_vars <- all.vars(cl)
      if (is.character(cl) && length(cl) == 1L) cluster_vars <- cl
    }
    if ("clusters" %in% names(engine_args)) {
      cl <- engine_args$clusters
      if (inherits(cl, "formula")) cluster_vars <- c(cluster_vars, all.vars(cl))
      if (is.character(cl) && length(cl) == 1L) cluster_vars <- c(cluster_vars, cl)
    }
    vars_need <- unique(c(vars_need, cluster_vars))

    vars_need <- intersect(vars_need, names(data0))
    data_cc <- stats::na.omit(data0[, vars_need, drop = FALSE])

    if (!nrow(data_cc)) return(NULL)

    # Determine exposure kind on this model's analytic sample
    kind <- .dagassist_exposure_kind(data_cc[[exp_nm]])
    if (!kind %in% c("binary", "categorical", "continuous")) {
      u <- try(sort(unique(stats::na.omit(data_cc[[exp_nm]]))), silent = TRUE)
      stop(
        "Estimand recovery supports:\n",
        "  * Binary exposures (0/1 numeric, logical, or 2-level factor)\n",
        "  * Categorical exposures (factor/character or small-unique integer-like codes)\n",
        "  * Continuous numeric exposures\n\n",
        "Exposure '", exp_nm, "' is class: ", paste(class(data_cc[[exp_nm]]), collapse = "/"),
        if (!inherits(u, "try-error")) paste0("\nObserved values (unique): ", paste(u, collapse = ", ")) else "",
        "\n\nPlease recode it or set estimand = 'raw'.",
        call. = FALSE
      )
    }

    # Categorical-coded exposures that are not already factors are passed to
    # WeightIt as ordered factors.
    data_wt <- data_cc
    if (identical(kind, "categorical") && !is.factor(data_wt[[exp_nm]])) {
      data_wt[[exp_nm]] <- factor(data_wt[[exp_nm]], ordered = TRUE)
    }

    # Internal/display terminology is SATE/SATT; WeightIt::weightit() is
    # called with "ATE"/"ATT" instead.
    est_wt <- switch(
      toupper(est),
      SATE = "ATE",
      SATT = "ATT",
      toupper(est)
    )

    # run weightit, but keep its warnings out of the dagassist console to
    # reduce clutter and confusion
    wtobj <- suppressWarnings(
      do.call(
        WeightIt::weightit,
        c(
          list(
            formula = f_treat,
            data = data_wt,
            method = "glm",
            estimand = est_wt
          ),
          fa$keep
        )
      )
    )

    w <- wtobj$weights

    # Optional trimming/capping
    if (!is.null(trim_at)) {
      if (!is.numeric(trim_at) || length(trim_at) != 1L || trim_at <= 0 || trim_at >= 1) {
        stop("weights_args$trim_at must be a single number in (0, 1).", call. = FALSE)
      }
      if ("trim" %in% getNamespaceExports("WeightIt")) {
        w <- WeightIt::trim(w, at = trim_at)
      } else {
        # fallback: cap weights at the requested quantile
        cap <- as.numeric(stats::quantile(w, probs = trim_at, na.rm = TRUE, names = FALSE))
        w <- pmin(w, cap)
      }
    }

    if (length(w) != nrow(data_cc)) {
      stop(
        "WeightIt returned ", length(w), " weights for data with ",
        nrow(data_cc), " rows.\n",
        "This indicates a mismatch between the analytic sample and WeightIt's internal sample.\n",
        "Inspect the treatment model and missingness in controls.",
        call. = FALSE
      )
    }

    # If clusters were captured as a full-length vector, subset it to the
    # complete-case rows. This is specifically for lmrobust, which throws
    # errors under missingness because it stores cluster as its own
    # full-length vector.
    .subset_cluster_vec <- function(cl, data_cc) {
      if (is.null(cl)) return(cl)
      if (!is.atomic(cl) || length(cl) <= 1L) return(cl)

      rn <- rownames(data_cc)
      idx <- suppressWarnings(as.integer(rn))

      # only subset when it looks like the vector corresponds to the original data0 rows
      if (!anyNA(idx) && length(cl) >= max(idx)) {
        return(cl[idx])
      }
      cl
    }

    # These assignments create a copy of engine_args local to this call; the
    # enclosing engine_args is left untouched for the next model column.
    if ("clusters" %in% names(engine_args)) {
      engine_args$clusters <- .subset_cluster_vec(engine_args$clusters, data_cc)
    }
    if ("cluster" %in% names(engine_args)) {
      engine_args$cluster <- .subset_cluster_vec(engine_args$cluster, data_cc)
    }

    # Fit the weighted version of THIS model on THIS model's complete-case data
    engine_args_w <- utils::modifyList(engine_args, list(weights = w))

    # Specifically muffle the binomial "non-integer #successes" warning raised
    # by weighted binomial fits; every other warning propagates.
    fit_w <- withCallingHandlers(
      .safe_fit(engine, fml, data_cc, engine_args_w),
      warning = function(wrn) {
        msg <- conditionMessage(wrn)
        if (grepl("non-integer #successes in a binomial glm", msg, fixed = TRUE)) {
          invokeRestart("muffleWarning")
        }
      }
    )

    # Restrict output to the exposure only because high-dimensional FE terms
    # make marginaleffects hang. If marginaleffects fails, fall back to the
    # weighted fit itself.
    me <- tryCatch(
      {
        if (identical(kind, "continuous")) {
          marginaleffects::avg_slopes(
            fit_w,
            variables = exp_nm,
            type = "response"
          )
        } else {
          marginaleffects::avg_comparisons(
            fit_w,
            variables = exp_nm,
            type = "response"
          )
        }
      },
      error = function(e) fit_w
    )

    if (!inherits(me, "DAGassist_fit_error")) {
      # When the exposure is multi-level, avg_comparisons returns multiple
      # contrasts. Keep all exposure contrasts and rename them to the model's
      # coefficient names so they align with the modelsummary rows set up by
      # the raw models (e.g., X1, X2 for a 3-level factor).
      if ("term" %in% names(me) && any(me$term == exp_nm)) {
        rows_exp <- which(me$term == exp_nm)

        if (length(rows_exp) > 1L) {
          exp_coef_names <- character(0)

          # lmer/glmer: fixed effects live in fixef()
          if (inherits(fit_w, "merMod")) {
            exp_coef_names <- names(lme4::fixef(fit_w))
          } else {
            # fallback: try coef() names for other engines
            exp_coef_names <- tryCatch(names(stats::coef(fit_w)), error = function(e) character(0))
          }

          exp_coef_names <- exp_coef_names[grepl(paste0("^", exp_nm), exp_coef_names)]

          if (length(exp_coef_names) == length(rows_exp)) {
            me$term[rows_exp] <- exp_coef_names
          } else {
            # last-resort: make terms unique so nothing is silently dropped
            me$term[rows_exp] <- paste0(exp_nm, "__", seq_along(rows_exp))
            warning(
              sprintf(
                "Exposure '%s' has >2 levels and DAGassist could not map all contrasts to coefficient names; using generic labels in the (%s) column.",
                exp_nm, est
              ),
              call. = FALSE
            )
          }
        }
      }

      # Preserve metadata for diagnostics/debugging
      attr(me, "dagassist_estimand") <- est
      attr(me, "dagassist_weightit") <- wtobj
      attr(me, "dagassist_treat_formula") <- f_treat
      attr(me, "dagassist_trim_at") <- trim_at
      attr(me, "dagassist_weights") <- w
      attr(me, "dagassist_weighted_fit") <- fit_w
    }

    me
  }

  # Compute weighted fit per spec
  weighted_mods <- list()
  for (nm in names(mods)) {
    fit_w <- .fit_weighted_one(nm)
    if (!is.null(fit_w)) weighted_mods[[nm]] <- fit_w
  }

  # Splice weighted columns in directly after their base column
  est_label <- paste0(" (", est, ")")
  mods_out <- list()
  for (nm in names(mods)) {
    mods_out[[nm]] <- mods[[nm]]
    if (!is.null(weighted_mods[[nm]])) {
      mods_out[[paste0(nm, est_label)]] <- weighted_mods[[nm]]
    }
  }

  mods_out
}

#labels for weight columns
# Column-label suffix for an estimand.
#
# Args:
#   estimand: a single estimand name (case-insensitive).
#
# Returns: "(SATE)", "(SATT)", "(SACDE)" or "(seqg raw)" for the matching
#   estimand, and "" for "RAW", "NONE" or any other value.
.dagassist_model_name_labels <- function(estimand) {
  key <- toupper(as.character(estimand))
  # RAW, NONE and unknown keys all resolve to the unnamed default ""
  switch(
    key,
    SATE = "(SATE)",
    SATT = "(SATT)",
    SACDE = "(SACDE)",
    SEQG_RAW = "(seqg raw)",
    ""
  )
}

# ---- Guardrail: ACDE mediator types ----
# DirectEffects::sequential_g() is fragile when mediators are not numeric columns.
# This guard identifies non-numeric mediators and returns an actionable message.
# Check that the ACDE mediators are numeric columns.
#
# Args:
#   data:    data set whose columns are inspected.
#   m_terms: mediator names; names that are not columns of `data` are ignored.
#
# Returns: NULL when there is nothing to report (no mediators, none present
#   in `data`, or none of them factor/character/logical); otherwise a single
#   string explaining which mediators are non-numeric and how to fix them.
.dagassist_acde_guard_mediators <- function(data, m_terms) {
  if (length(m_terms) == 0L) return(NULL)

  present <- intersect(unique(as.character(m_terms)), names(data))
  if (length(present) == 0L) return(NULL)

  # mediators stored as factor, character or logical columns
  non_numeric <- Filter(
    function(nm) {
      col <- data[[nm]]
      is.factor(col) || is.character(col) || is.logical(col)
    },
    present
  )
  if (length(non_numeric) == 0L) return(NULL)

  # one bullet per offending mediator: name, class, and level information
  describe_one <- function(nm) {
    col <- data[[nm]]
    class_txt <- paste(class(col), collapse = "/")

    if (is.character(col)) col <- factor(col)
    if (is.factor(col)) {
      lv <- levels(col)
      # list the level names only when there are few of them
      shown <- if (length(lv) && length(lv) <= 8) {
        paste0(" (", paste(lv, collapse = ", "), ")")
      } else {
        ""
      }
      detail <- paste0("levels=", length(lv), shown)
    } else if (is.logical(col)) {
      detail <- "logical"
    } else {
      detail <- ""
    }

    detail_txt <- if (nzchar(detail)) paste0("; ", detail) else ""
    paste0("  - ", nm, "  [class: ", class_txt, detail_txt, "]")
  }
  bullets <- vapply(non_numeric, describe_one, character(1), USE.NAMES = FALSE)

  paste0(
    "SACDE fit aborted before calling DirectEffects::sequential_g().\n\n",
    "Reason:\n",
    "  At least one mediator is non-numeric (factor/character/logical).\n",
    "  DirectEffects::sequential_g() can throw `subscript out of bounds` in this case, ",
    "  because categorical mediators expand to multiple model-matrix columns which do not ",
    "  match the mediator term labels.\n\n",
    "Problematic mediator(s):\n",
    paste(bullets, collapse = "\n"), "\n\n",
    "How to fix:\n",
    "  1) Recode mediator(s) to numeric before calling DAGassist (e.g., binary 0/1).\n",
    "  2) One-hot encode multi-category mediators into numeric dummy columns, then pass\n",
    "     those dummy names explicitly via `acde = list(m = c(\"M1\",\"M2\", ...))`. \n",
    "     Either ensure your DAG nodes match those column names, or use imply = FALSE to prevent mismatch issues. \n",
    "  3) Exclude the categorical mediator(s) from SACDE by explicitly setting `sacde$m`.\n"
  )
}

# Split a sequential_g-style formula, `Y ~ block1 | blockZ | blockM`, into its
# text components.
#
# Args:
#   f_seqg: formula with up to three "|"-separated right-hand-side blocks.
#
# Returns: list(y, block1, blockZ, blockM) of trimmed strings. `blockZ` is "0"
#   when the formula has no second block. Errors when the third (mediator)
#   block is missing.
.dagassist_parse_seqg_formula <- function(f_seqg) {
  # flatten the deparsed formula to one line with single spaces
  flat <- gsub("\\s+", " ", paste(deparse(f_seqg), collapse = ""))

  # cut on every "|" and trim each piece
  segments <- trimws(strsplit(flat, "\\|", fixed = FALSE)[[1]])
  n_segments <- length(segments)

  # the first piece is "Y ~ block1"
  sides <- strsplit(segments[1], "~", fixed = TRUE)[[1]]
  outcome <- trimws(sides[1])
  first_block <- trimws(sides[2])

  z_block <- if (n_segments >= 2) trimws(segments[2]) else "0"

  if (n_segments < 3) {
    stop("seqg formula missing mediator block")
  }
  m_block <- trimws(segments[3])

  list(y = outcome, block1 = first_block, blockZ = z_block, blockM = m_block)
}