###############################################################################
##  jmSurface — Complete Worked Example
##  Using the Bundled CKD/CVD/Diabetes Multi-State Cohort (N=2,000)
##
##  This script demonstrates the full analysis pipeline:
##    1. Load bundled data
##    2. Explore the cohort
##    3. Fit the joint model (Stage 1 + Stage 2)
##    4. EDF diagnostics and interpretation
##    5. Dynamic prediction for individual patients
##    6. Visualization: surfaces, heatmaps, marginal slices
##    7. Compare multiple patients
##    8. Population-level risk stratification
##    9. Save results (model, EDF table, risk table, predictions)
##
##  Prerequisites:
##    install.packages("jmSurface_0.1.0.tar.gz", repos = NULL, type = "source")
###############################################################################

library(jmSurface)

# Opening banner: blank line, title framed by two rules, blank line.
banner_rule <- "================================================================\n"
cat("\n",
    banner_rule,
    "  jmSurface — Worked Example with Bundled CKD/CVD/Diabetes Data\n",
    banner_rule,
    "\n",
    sep = "")


## ═══════════════════════════════════════════════════════════════════
##  1. LOAD THE BUNDLED DATA
## ═══════════════════════════════════════════════════════════════════

cat(">>> Step 1: Loading bundled data...\n")

# load_example_data() yields an object with `long_data` and `surv_data` parts
dat <- load_example_data()
long_data <- dat$long_data
surv_data <- dat$surv_data

# Size and content of the bundled cohort
n_long_rows <- nrow(long_data)
n_surv_rows <- nrow(surv_data)
n_patients <- length(unique(long_data$patient_id))
biomarker_names <- paste(unique(long_data$biomarker), collapse = ", ")

cat("  Longitudinal data:", n_long_rows, "rows\n")
cat("  Survival data:    ", n_surv_rows, "rows\n")
cat("  Patients:         ", n_patients, "\n")
cat("  Biomarkers:       ", biomarker_names, "\n\n")


## ═══════════════════════════════════════════════════════════════════
##  2. EXPLORE THE COHORT
## ═══════════════════════════════════════════════════════════════════

cat(">>> Step 2: Exploring the cohort...\n\n")

## --- 2a. Biomarker summary statistics ---
## Missing measurements are dropped before summarising (as is already done for
## BMI in 2c), so a single NA does not turn every statistic into NA; N counts
## the non-missing values. A biomarker with no observed values is reported
## explicitly instead of printing NaN/Inf along with warnings.
cat("--- Biomarker Summary ---\n")
for (bm in c("eGFR", "BNP", "HbA1c")) {
  vals <- long_data$value[which(long_data$biomarker == bm)]
  vals <- vals[!is.na(vals)]
  if (length(vals) == 0) {
    cat(sprintf("  %-6s  N=%6d  (no observed values)\n", bm, 0L))
    next
  }
  cat(sprintf("  %-6s  N=%6d  Mean=%.1f  SD=%.1f  Range=[%.1f, %.1f]\n",
              bm, length(vals), mean(vals), sd(vals), min(vals), max(vals)))
}
cat("\n")

## --- 2b. Transition event counts ---
## Observed transitions (status == 1), most frequent first.
cat("--- Transition Event Counts ---\n")
events <- surv_data[which(surv_data$status == 1), ]
trans_tab <- sort(table(events$transition), decreasing = TRUE)
cat(sprintf("  %-25s  %d events\n", names(trans_tab), as.integer(trans_tab)),
    sep = "")
cat("\n")

## --- 2c. Demographics ---
cat("--- Demographic Summary ---\n")
# One row per patient (the first survival record of each patient)
pat_info <- surv_data[!duplicated(surv_data$patient_id),
                      c("patient_id", "age_baseline", "sex", "smoking",
                        "bmi", "entry_disease")]

ages <- pat_info$age_baseline
cat(sprintf("  Age:    Mean=%.1f, SD=%.1f, Range=[%.0f, %.0f]\n",
            mean(ages, na.rm = TRUE), sd(ages, na.rm = TRUE),
            min(ages, na.rm = TRUE), max(ages, na.rm = TRUE)))

# sex is coded 0 = female, 1 = male; percentages are among patients whose
# sex is recorded
is_female <- pat_info$sex == 0
is_male <- pat_info$sex == 1
cat(sprintf("  Sex:    Female=%d (%.0f%%), Male=%d (%.0f%%)\n",
            sum(is_female, na.rm = TRUE), 100 * mean(is_female, na.rm = TRUE),
            sum(is_male, na.rm = TRUE), 100 * mean(is_male, na.rm = TRUE)))

cat(sprintf("  BMI:    Mean=%.1f, SD=%.1f\n",
            mean(pat_info$bmi, na.rm = TRUE), sd(pat_info$bmi, na.rm = TRUE)))

# Patient counts per entry disease, formatted as "name=count"
entry_tab <- table(pat_info$entry_disease)
cat("  Entry:  ", paste(names(entry_tab), entry_tab, sep = "=", collapse = ", "),
    "\n\n")


## ═══════════════════════════════════════════════════════════════════
##  3. FIT THE JOINT MODEL
## ═══════════════════════════════════════════════════════════════════

cat(">>> Step 3: Fitting joint model (this takes ~30-60 seconds)...\n\n")

# Wall-clock timer around the model fit
fit_started <- proc.time()[["elapsed"]]

# Called directly (not via do.call) so the call stored in `fit`, if any,
# stays readable.
fit <- jmSurf(
  long_data  = long_data,
  surv_data  = surv_data,
  covariates = c("age_baseline", "sex"),
  k_marginal = c(5, 5),
  k_additive = 6,
  method     = "REML",
  verbose    = TRUE
)

fit_seconds <- round(proc.time()[["elapsed"]] - fit_started, 1)
cat(sprintf("\n  Total fitting time: %.1f seconds\n\n", fit_seconds))


## ═══════════════════════════════════════════════════════════════════
##  4. EDF DIAGNOSTICS AND INTERPRETATION
## ═══════════════════════════════════════════════════════════════════

cat(">>> Step 4: EDF Diagnostics\n\n")

## --- 4a. Print summary ---
summary(fit)

## --- 4b. Detailed EDF table ---
edf_df <- edf_diagnostics(fit)
cat("--- EDF Diagnostics Table ---\n")
print(edf_df, row.names = FALSE)
cat("\n")

## --- 4c. Interpretation ---
## For each transition, choose a message pair from its complexity class
## ("Linear", "Moderate", anything else) and report the EDF together with
## the deviance explained (the stored value multiplied by 100).
cat("--- Clinical Interpretation of EDF ---\n")
for (row_idx in seq_len(nrow(edf_df))) {
  row_edf <- edf_df$edf[row_idx]
  row_dev <- edf_df$deviance_explained[row_idx]
  row_class <- edf_df$complexity[row_idx]

  cat(sprintf("  %s:\n", edf_df$transition[row_idx]))

  if (row_class == "Linear") {
    headline <- "    EDF=%.1f -> Nearly linear. Standard parametric JM would suffice.\n"
    detail <- "    Simple dose-response relationship between biomarkers and transition risk.\n"
  } else if (row_class == "Moderate") {
    headline <- "    EDF=%.1f -> Moderate nonlinearity. Threshold or saturation effects present.\n"
    detail <- "    Some nonlinear structure; flexible model adds value over linear JM.\n"
  } else {
    headline <- "    EDF=%.1f -> Substantial nonlinearity/interaction.\n"
    detail <- "    Biomarker interactions drive risk; linear model would miss these patterns.\n"
  }

  cat(sprintf(headline, row_edf))
  cat(detail)
  cat(sprintf("    Deviance explained: %.1f%%\n\n", row_dev * 100))
}


## ═══════════════════════════════════════════════════════════════════
##  5. DYNAMIC PREDICTION FOR INDIVIDUAL PATIENTS
## ═══════════════════════════════════════════════════════════════════

cat(">>> Step 5: Dynamic Prediction — Individual Patients\n\n")

## Helpers for this step ------------------------------------------------------

# First patient id in `data` whose entry disease is `entry` and who has at
# least one observed transition (status == 1). which() discards NA
# comparisons, so missing entry_disease/status values can never produce an NA
# id. Stops with an informative message when no such patient exists.
first_event_patient <- function(data, entry) {
  hit <- which(data$entry_disease == entry & data$status == 1)
  if (length(hit) == 0) {
    stop(sprintf("No '%s' entry patient with an observed transition.", entry),
         call. = FALSE)
  }
  data$patient_id[hit[1]]
}

# Risk for transition `tr` at the last point of the prediction grid, i.e. at
# the prediction horizon (rows are assumed to be in time order).
risk_at_horizon <- function(pred, tr) {
  tail(pred$risk[pred$transition == tr], 1)
}

# Risk for transition `tr` at time `t`, linearly interpolated between grid
# points. Unlike a fixed +/- 0.1 window, this does not depend on the grid
# spacing used by dynPred(). Returns NA when `t` is outside the grid or fewer
# than two usable points exist.
risk_at_time <- function(pred, tr, t) {
  sel <- which(pred$transition == tr)
  times <- pred$time[sel]
  risks <- pred$risk[sel]
  ok <- !is.na(times) & !is.na(risks)
  if (sum(ok) < 2) {
    return(NA_real_)
  }
  approx(times[ok], risks[ok], xout = t, ties = mean)$y
}

# Print the horizon risk of every transition present in `pred`.
print_horizon_risks <- function(pred, horizon_years) {
  for (tr in unique(pred$transition)) {
    cat(sprintf("    %-25s  %d-year: %.1f%%\n",
                tr, horizon_years, risk_at_horizon(pred, tr) * 100))
  }
}

## --- 5a. Find an interesting patient (CKD entry with events) ---
pid1 <- first_event_patient(surv_data, "CKD")

cat(sprintf("--- Patient %d (CKD entry) ---\n", pid1))

# Show patient history
p_surv <- surv_data[which(surv_data$patient_id == pid1), ]
cat("  Disease trajectory:\n")
for (j in seq_len(nrow(p_surv))) {
  # as.character() keeps state labels even if the columns are factors
  # (ifelse() on a factor would print its integer codes instead).
  to_state <- p_surv$state_to[j]
  to_label <- if (is.na(to_state)) "censored" else as.character(to_state)
  cat(sprintf("    t=[%.1f, %.1f]  %s -> %s  (status=%d)\n",
              p_surv$start_time[j], p_surv$stop_time[j],
              as.character(p_surv$state_from[j]), to_label,
              p_surv$status[j]))
}

# Show last biomarker values
p_long <- long_data[which(long_data$patient_id == pid1), ]
cat("  Last biomarker values:\n")
for (bm in c("eGFR", "BNP", "HbA1c")) {
  bm_vals <- p_long[which(p_long$biomarker == bm), ]
  if (nrow(bm_vals) == 0) {
    cat(sprintf("    %-6s = (no measurements)\n", bm))
    next
  }
  last_row <- bm_vals[which.max(bm_vals$visit_time_years), ]
  cat(sprintf("    %-6s = %.1f at t=%.1f years\n",
              bm, last_row$value, last_row$visit_time_years))
}

# Dynamic prediction
pred1 <- dynPred(fit, patient_id = pid1, landmark = 0, horizon = 5)
cat("\n  Predicted transition risks (5-year horizon from t=0):\n")
for (tr in unique(pred1$transition)) {
  cat(sprintf("    %-25s  2-year: %.1f%%   5-year: %.1f%%\n",
              tr, risk_at_time(pred1, tr, 2) * 100,
              risk_at_horizon(pred1, tr) * 100))
}
cat("\n")

## --- 5b. CVD patient ---
pid2 <- first_event_patient(surv_data, "CVD")

cat(sprintf("--- Patient %d (CVD entry) ---\n", pid2))
pred2 <- dynPred(fit, patient_id = pid2, landmark = 0, horizon = 5)
cat("  Predicted transition risks (5-year horizon from t=0):\n")
print_horizon_risks(pred2, 5L)
cat("\n")

## --- 5c. Diabetes patient ---
pid3 <- first_event_patient(surv_data, "Diabetes")

cat(sprintf("--- Patient %d (Diabetes entry) ---\n", pid3))
pred3 <- dynPred(fit, patient_id = pid3, landmark = 0, horizon = 5)
cat("  Predicted transition risks (5-year horizon from t=0):\n")
print_horizon_risks(pred3, 5L)
cat("\n")

## --- 5d. Landmark prediction (predict from t=2 onward) ---
cat(sprintf("--- Patient %d: Landmark prediction from t=2 years ---\n", pid1))
pred_lm <- dynPred(fit, patient_id = pid1, landmark = 2, horizon = 3)
cat("  Predicted risks (3-year horizon from t=2):\n")
print_horizon_risks(pred_lm, 3L)
cat("\n")


## ═══════════════════════════════════════════════════════════════════
##  6. VISUALIZATION: SURFACES, HEATMAPS, SLICES
## ═══════════════════════════════════════════════════════════════════

cat(">>> Step 6: Generating Visualizations\n\n")

## Save all plots to a PDF. The device is closed in `finally`, so an error in
## any plotting call still closes this PDF device instead of leaving it open
## (and the file unfinished) for the rest of the session.
pdf_file <- "jmSurface_example_plots.pdf"
pdf(pdf_file, width = 10, height = 8)
pdf_device <- dev.cur()

invisible(tryCatch({
  for (tr in fit$transitions) {
    cat(sprintf("  Plotting: %s (EDF=%.1f)\n", tr, fit$edf[tr]))

    ## 6a. 3D Surface
    plot_surface(fit, transition = tr,
                 main = paste0("Association Surface: ", tr,
                               " (EDF=", round(fit$edf[tr], 1), ")"))

    ## 6b. Contour Heatmap
    contour_heatmap(fit, transition = tr,
                    main = paste0("Contour Heatmap: ", tr))

    ## 6c. Marginal Effect Slices
    marginal_slices(fit, transition = tr,
                    main = paste0("Marginal Slices: ", tr))
  }
}, finally = dev.off(which = pdf_device)))

cat(sprintf("\n  All plots saved to: %s\n\n", pdf_file))


## ═══════════════════════════════════════════════════════════════════
##  7. COMPARE MULTIPLE PATIENTS SIDE BY SIDE
## ═══════════════════════════════════════════════════════════════════

cat(">>> Step 7: Multi-Patient Risk Comparison\n\n")

## Patients with at least one observed transition. which() keeps NA status
## values from injecting an NA id. `all_pids` is also used by Step 8.
all_pids <- unique(surv_data$patient_id[which(surv_data$status == 1)])

## Pick up to 6 diverse patients, evenly spaced through `all_pids`. floor()
## reproduces the truncation R applies to fractional indices. unique()
## prevents the same patient appearing twice when fewer than 6 exist.
n_pids <- length(all_pids)
pick <- if (n_pids == 0) {
  integer(0)
} else {
  unique(floor(seq(1, n_pids, length.out = min(6, n_pids))))
}
sample_pids <- all_pids[pick]

## Predict once per patient. The cached results feed both the header (union
## of transitions, in first-seen order) and the table rows. A failed
## prediction is stored as NULL and shown as "N/A".
sample_preds <- lapply(sample_pids, function(pid) {
  tryCatch(dynPred(fit, patient_id = pid, landmark = 0, horizon = 5),
           error = function(e) NULL)
})
all_trans <- unique(as.character(unlist(
  lapply(sample_preds, function(p) as.character(unique(p$transition)))
)))

## Every transition column is 14 characters wide (2-space gap + 12), for the
## header, the risk values and "N/A" alike, so the columns line up.
cat("--- Risk Comparison Table (5-year, landmark=0) ---\n")
cat(sprintf("%-10s %-8s %-6s %-5s %-6s %-5s",
            "PatientID", "Entry", "Age", "Sex", "BMI", "Events"))
for (tr in all_trans) {
  cat(sprintf("  %12s", gsub(" -> ", ">", tr, fixed = TRUE)))
}
cat("\n")
# 45 = width of the six fixed columns including their single-space separators
cat(strrep("-", 45 + 14 * length(all_trans)), "\n", sep = "")

for (k in seq_along(sample_pids)) {
  pid <- sample_pids[k]
  pt <- pat_info[which(pat_info$patient_id == pid), ]
  if (nrow(pt) == 0) next
  n_ev <- sum(surv_data$patient_id == pid & surv_data$status == 1,
              na.rm = TRUE)
  sex_lab <- if (is.na(pt$sex)) "?" else if (pt$sex == 1) "M" else "F"

  cat(sprintf("%-10d %-8s %-6.0f %-5s %-6.1f %-5d",
              pid, as.character(pt$entry_disease), pt$age_baseline,
              sex_lab, pt$bmi, n_ev))

  pred <- sample_preds[[k]]
  for (tr in all_trans) {
    if (!is.null(pred) && tr %in% pred$transition) {
      # Risk at the prediction horizon (last grid point)
      risk_5y <- tail(pred$risk[pred$transition == tr], 1) * 100
      cat(sprintf("  %11.1f%%", risk_5y))
    } else {
      cat(sprintf("  %12s", "N/A"))
    }
  }
  cat("\n")
}
cat("\n")


## ═══════════════════════════════════════════════════════════════════
##  8. POPULATION-LEVEL RISK STRATIFICATION
## ═══════════════════════════════════════════════════════════════════

cat(">>> Step 8: Population-Level Risk Stratification\n\n")

## Compute 3-year risk for a sample of up to 50 patients.
## With two or more ids, indexing via sample.int() draws exactly what
## sample(all_pids, k) draws. It avoids sample()'s behaviour of sampling from
## 1:x when `all_pids` is a single number.
set.seed(42)
n_all <- length(all_pids)
sample_ids <- all_pids[sample.int(n_all, min(50, n_all))]

# Empty template fixing the column set and types of the risk table
risk_table <- data.frame(
  patient_id = integer(),
  age = numeric(),
  sex = integer(),
  bmi = numeric(),
  entry = character(),
  transition = character(),
  risk_3y = numeric(),
  stringsAsFactors = FALSE
)

cat("  Computing 3-year risks for", length(sample_ids), "patients...\n")
# One data frame per patient (one row per transition). They are bound once at
# the end rather than grown with rbind() inside the loop (quadratic copying).
risk_rows <- vector("list", length(sample_ids))
for (k in seq_along(sample_ids)) {
  pid <- sample_ids[k]
  pred <- tryCatch(
    dynPred(fit, patient_id = pid, landmark = 0, horizon = 3),
    error = function(e) NULL
  )
  if (is.null(pred)) next

  pt <- pat_info[which(pat_info$patient_id == pid), ]
  if (nrow(pt) == 0) next

  trs <- as.character(unique(pred$transition))
  if (length(trs) == 0) next
  # Risk at the prediction horizon (last grid point) for each transition
  risks <- vapply(trs, function(tr) {
    as.numeric(tail(pred$risk[pred$transition == tr], 1))
  }, numeric(1), USE.NAMES = FALSE)

  risk_rows[[k]] <- data.frame(
    patient_id = pid,
    age = pt$age_baseline,
    sex = pt$sex,
    bmi = pt$bmi,
    entry = pt$entry_disease,
    transition = trs,
    risk_3y = risks,
    stringsAsFactors = FALSE
  )
}
risk_table <- do.call(rbind, c(list(risk_table),
                               Filter(Negate(is.null), risk_rows)))

cat("  Done. Risk table has", nrow(risk_table), "rows.\n\n")

## Summary by transition
cat("--- 3-Year Risk Distribution by Transition ---\n")
cat(sprintf("%-25s %8s %8s %8s %8s %8s\n",
            "Transition", "N", "Mean%", "SD%", "Min%", "Max%"))
cat(paste(rep("-", 75), collapse = ""), "\n")
for (tr in unique(risk_table$transition)) {
  r <- risk_table$risk_3y[risk_table$transition == tr] * 100
  cat(sprintf("%-25s %8d %8.1f %8.1f %8.1f %8.1f\n",
              tr, length(r), mean(r), sd(r), min(r), max(r)))
}
cat("\n")

## High-risk patients (>50% 3-year risk for any transition).
## which() drops rows whose risk is NA instead of yielding all-NA rows.
high_risk <- risk_table[which(risk_table$risk_3y > 0.50), ]
if (nrow(high_risk) > 0) {
  cat("--- High-Risk Patients (>50% 3-year risk) ---\n")
  high_risk <- high_risk[order(-high_risk$risk_3y), ]
  for (i in seq_len(min(10, nrow(high_risk)))) {
    h <- high_risk[i, ]
    cat(sprintf("  Patient %d (Age=%.0f, %s): %.1f%% for %s\n",
                h$patient_id, h$age,
                ifelse(h$sex == 1, "Male", "Female"),
                h$risk_3y * 100, h$transition))
  }
} else {
  cat("  No patients exceed 50% 3-year risk threshold.\n")
}
cat("\n")

## Risk by age group. Intervals are closed on the left ([50, 60) etc.) so the
## labels are accurate. With cut()'s default right = TRUE, an age of exactly
## 50 would fall in "<50" and 60 in "50-59".
cat("--- Mean 3-Year Risk by Age Group ---\n")
risk_table$age_group <- cut(risk_table$age, breaks = c(0, 50, 60, 70, 80, Inf),
                            labels = c("<50", "50-59", "60-69", "70-79", "80+"),
                            right = FALSE)
# aggregate() errors when no complete rows remain (e.g. all predictions failed)
usable <- complete.cases(risk_table[, c("age_group", "transition", "risk_3y")])
if (any(usable)) {
  age_risk <- aggregate(risk_3y ~ age_group + transition, data = risk_table,
                        FUN = function(x) round(mean(x) * 100, 1))
  names(age_risk)[3] <- "mean_risk_pct"
  print(age_risk[order(age_risk$transition, age_risk$age_group), ],
        row.names = FALSE)
} else {
  cat("  No risk estimates available for stratification.\n")
}
cat("\n")


## ═══════════════════════════════════════════════════════════════════
##  9. SAVE RESULTS
## ═══════════════════════════════════════════════════════════════════

cat(">>> Step 9: Saving Results\n\n")

## Save fitted model
saveRDS(fit, "jmSurface_fitted_model.rds")
cat("  Saved: jmSurface_fitted_model.rds\n")

## Save EDF diagnostics
write.csv(edf_df, "jmSurface_edf_diagnostics.csv", row.names = FALSE)
cat("  Saved: jmSurface_edf_diagnostics.csv\n")

## Save risk table
write.csv(risk_table, "jmSurface_risk_table.csv", row.names = FALSE)
cat("  Saved: jmSurface_risk_table.csv\n")

## Save example predictions. Each prediction set is tagged with its patient
## id, but only when dynPred() output has no patient_id column already. This
## keeps the rows of the combined CSV attributable to a patient.
tag_patient <- function(pred, pid) {
  if ("patient_id" %in% names(pred)) {
    return(pred)
  }
  data.frame(patient_id = rep(pid, nrow(pred)), pred,
             check.names = FALSE, stringsAsFactors = FALSE)
}
all_preds <- rbind(tag_patient(pred1, pid1),
                   tag_patient(pred2, pid2),
                   tag_patient(pred3, pid3))
write.csv(all_preds, "jmSurface_example_predictions.csv", row.names = FALSE)
cat("  Saved: jmSurface_example_predictions.csv\n")

cat("  Saved: jmSurface_example_plots.pdf (generated in Step 6)\n")


## ═══════════════════════════════════════════════════════════════════
##  DONE
## ═══════════════════════════════════════════════════════════════════

## Patients with at least one successful 3-year prediction in Step 8. Failed
## dynPred() calls are skipped there, so this can be smaller than the sample.
n_predicted <- length(unique(risk_table$patient_id))

cat("\n")
cat("================================================================\n")
cat("  EXAMPLE COMPLETE\n")
cat("================================================================\n")
cat("  Model fitted:          ", length(fit$transitions), "transitions\n")
cat("  EDF range:             ", round(min(fit$edf), 1), "-",
    round(max(fit$edf), 1), "\n")
cat("  Patients predicted:    ", n_predicted, "of", length(sample_ids),
    "sampled\n")
cat("  Output files:\n")
cat("    - jmSurface_fitted_model.rds\n")
cat("    - jmSurface_edf_diagnostics.csv\n")
cat("    - jmSurface_risk_table.csv\n")
cat("    - jmSurface_example_predictions.csv\n")
cat("    - jmSurface_example_plots.pdf\n")
cat("================================================================\n\n")
