# Segmented regression

# Reference: 24-hour urine volume and 28-day intensive care unit mortality in sepsis (Frontiers in Medicine, 2024)

# https://www.frontiersin.org/journals/medicine/articles/10.3389/fmed.2024.1486232/full

library(mgcv)

library(segmented)

library(ggplot2)

library(dplyr)

library(caret)

library(lattice)

# Data input and preparation

# Group-level summary with one row per 24-hour urine volume band. The values
# are presumably transcribed from the article cited above; verify them
# against its tables.
data <- data.frame(
  urine_volume_group = c("<50", "50-399", "400-999", "1000-1999", ">=2000"),
  n_patients         = c(379, 863, 1649, 2136, 2119),
  mortality_percent  = c(20.32, 19.81, 11.22, 5.90, 3.82)
)

# Back-calculate whole-number death counts from the reported percentages.
data[["mortality_count"]] <- with(
  data,
  round(n_patients * mortality_percent / 100, digits = 0)
)

# Function to generate simulated data
#
# Expands group-level counts into a simulated patient-level data set. Each row
# of `data` becomes `n_patients` rows. Urine volume is drawn uniformly from
# that group's [lower, upper] interval. The 0/1 `mortality` outcome holds
# exactly `mortality_count` ones, in random order.
#
# Arguments:
#   data  - data frame with columns `urine_volume_group`, `n_patients` and
#           `mortality_count`, one row per group.
#   seed  - value passed to set.seed() before any draws. This resets the
#           global RNG stream, as the original script did. NULL skips seeding.
#   lower, upper - per-group bounds (mL) for the uniform draws, one entry per
#           row of `data`. The defaults are the five intervals used in this
#           script.
# Returns a data frame with columns `urine_volume`, `mortality` and `group`.
# It has zero rows when `data` has zero rows.
# Errors if the bounds do not match nrow(data), or if any group has more
# deaths than patients.
generate_simulated_data <- function(data,
                                    seed = 2,
                                    lower = c(0, 50, 400, 1000, 2000),
                                    upper = c(50, 399, 999, 1999, 3000)) {
  n_groups <- nrow(data)
  if (n_groups == 0) {
    return(data.frame(
      urine_volume = numeric(0),
      mortality = numeric(0),
      group = character(0)
    ))
  }
  if (length(lower) != n_groups || length(upper) != n_groups) {
    stop("`lower` and `upper` need one bound per row of `data` (",
         n_groups, " rows).", call. = FALSE)
  }
  if (any(data$mortality_count > data$n_patients)) {
    stop("`mortality_count` cannot exceed `n_patients`.", call. = FALSE)
  }

  if (!is.null(seed)) {
    set.seed(seed)
  }

  # Draw group by group, volumes first and then the outcome shuffle. This is
  # the same RNG call order as the original loop, so the default seed
  # reproduces the same data set.
  per_group <- lapply(seq_len(n_groups), function(i) {
    n <- data$n_patients[i]
    deaths <- data$mortality_count[i]
    urine_volume <- runif(n, lower[i], upper[i])
    outcome <- c(rep(1, deaths), rep(0, n - deaths))
    # Index via sample.int() rather than sample(outcome). sample() treats a
    # single number x >= 1 as 1:x. The random draws are the same either way.
    outcome <- outcome[sample.int(length(outcome))]
    data.frame(
      urine_volume = urine_volume,
      mortality = outcome,
      # Repeat explicitly so a zero-patient group yields a 0-row frame.
      group = rep(data$urine_volume_group[i], n)
    )
  })

  do.call(rbind, per_group)
}

# Generate the initial patient-level dataset from the group-level summary
simulated_data <- generate_simulated_data(data = data)

# Function to find inflection point using segmented regression with multiple starting points
#
# Fits `mortality ~ urine_volume` with lm() once. It then runs segmented() with
# one breakpoint from several quantiles of `urine_volume` and keeps the fit
# with the lowest AIC.
# NOTE(review): this is a linear probability model on a 0/1 outcome. A
# binomial glm() base model would match the GAM used later; confirm which is
# intended.
#
# Arguments:
#   train_data - data frame with numeric columns `urine_volume` and
#                `mortality`.
#   probs      - quantile levels of `urine_volume` used as starting values.
# Returns the estimated breakpoint as a numeric scalar. Returns NA when no
# starting value gave a fit free of errors and warnings.
find_inflection_point <- function(train_data,
                                  probs = c(0.3, 0.4, 0.5, 0.6, 0.7)) {
  start_points <- quantile(train_data$urine_volume, probs = probs)

  # The base model does not depend on the starting value, so fit it once.
  # As before, an error or warning here means no estimate.
  lm_model <- tryCatch(
    lm(mortality ~ urine_volume, data = train_data),
    error = function(e) NULL,
    warning = function(w) NULL
  )
  if (is.null(lm_model)) {
    return(NA)
  }

  best_model <- NULL
  best_aic <- Inf

  for (start_point in start_points) {
    # Any error or warning (e.g. a non-converged fit) discards this start.
    candidate <- tryCatch({
      seg_model <- segmented(lm_model, seg.Z = ~ urine_volume,
                             psi = list(urine_volume = start_point),
                             control = seg.control(display = FALSE))
      list(model = seg_model, aic = AIC(seg_model))
    }, error = function(e) NULL, warning = function(w) NULL)

    # Guard against an NA AIC, which would make `if` fail.
    if (!is.null(candidate) && !is.na(candidate$aic) &&
        candidate$aic < best_aic) {
      best_model <- candidate$model
      best_aic <- candidate$aic
    }
  }

  if (is.null(best_model) || is.null(best_model$psi)) {
    return(NA)
  }
  # `psi` is a matrix with columns "Initial", "Est." and "St.Err". A bare
  # `psi[1]` would return the *starting* value, so select the estimate
  # explicitly.
  unname(best_model$psi[1, "Est."])
}

# Perform 10-fold cross-validation
# Each fold's training split (all other folds) is refit to check how stable
# the estimated breakpoint is. The held-out fold itself is not scored.

set.seed(456)

n_folds <- 10

# Create fold indices
folds <- createFolds(simulated_data$mortality, k = n_folds, list = TRUE)

# Run cross-validation. This yields one estimate per fold, NA when no
# segmented fit succeeded. vapply() collects them without growing a vector
# inside a loop.
fold_estimates <- vapply(seq_along(folds), function(i) {
  cat("Processing fold", i, "of", n_folds, "\n")
  # Use all folds except the current one for training
  train_indices <- unlist(folds[-i])
  find_inflection_point(simulated_data[train_indices, ])
}, numeric(1))

# Only keep non-NA points
inflection_points <- fold_estimates[!is.na(fold_estimates)]

# Check if we have any valid inflection points
if (length(inflection_points) > 0) {

  # Calculate summary statistics
  inflection_stats <- list(
    mean = mean(inflection_points),
    median = median(inflection_points),
    min = min(inflection_points),
    max = max(inflection_points),
    iqr = IQR(inflection_points),
    sd = sd(inflection_points),  # NA when only one fold produced an estimate
    n_valid = length(inflection_points)
  )

  # Box plot of the per-fold estimates. Box-plot outliers are hidden because
  # geom_jitter() already draws every point; otherwise outliers appear twice.
  inflection_plot <- ggplot(
    data.frame(inflection_points = inflection_points),
    aes(y = inflection_points, x = "")
  ) +
    geom_boxplot(fill = "lightblue", outlier.shape = NA) +
    geom_jitter(width = 0.2, alpha = 0.5, color = "blue") +
    labs(
      title = paste("Distribution of Inflection Points\n(n =",
                    length(inflection_points), ")"),
      y = "Urine Volume (mL)",
      x = ""
    ) +
    theme_minimal() +
    theme(plot.title = element_text(hjust = 0.5))

  # Print summary statistics
  cat("\nSummary Statistics for Inflection Points:\n")
  cat("Number of valid inflection points:", inflection_stats$n_valid, "\n")
  cat("Mean:", round(inflection_stats$mean, 2), "mL\n")
  cat("Median:", round(inflection_stats$median, 2), "mL\n")
  cat("Min:", round(inflection_stats$min, 2), "mL\n")
  cat("Max:", round(inflection_stats$max, 2), "mL\n")
  cat("IQR:", round(inflection_stats$iqr, 2), "mL\n")
  cat("SD:", round(inflection_stats$sd, 2), "mL\n")

  # Display the box plot
  print(inflection_plot)

  # Final model with all data, started from the cross-validated median
  final_lm_model <- lm(mortality ~ urine_volume, data = simulated_data)
  final_seg_model <- tryCatch(
    segmented(final_lm_model, seg.Z = ~ urine_volume,
              psi = list(urine_volume = inflection_stats$median),
              control = seg.control(display = FALSE)),
    error = function(e) {
      # Report the failure instead of silently skipping the comparison plot.
      message("Final segmented model failed: ", conditionMessage(e))
      NULL
    }
  )

  # Binomial GAM used as a smooth reference curve for the comparison plot
  final_gam_model <- gam(mortality ~ s(urine_volume), data = simulated_data,
                         family = binomial)

  # Create final comparison plot if segmented model converged
  if (!is.null(final_seg_model)) {
    simulated_data$gam_pred <- predict(final_gam_model, type = "response")
    # The segmented fit is built on lm(), so its predictions can fall
    # outside [0, 1].
    simulated_data$seg_pred <- predict(final_seg_model)

    final_plot <- ggplot(simulated_data, aes(x = urine_volume)) +
      geom_point(aes(y = mortality), color = "blue", alpha = 0.2) +
      # `linewidth` replaces the deprecated `size` for lines (ggplot2 3.4.0).
      geom_line(aes(y = gam_pred, color = "GAM"), linewidth = 1) +
      geom_line(aes(y = seg_pred, color = "Segmented"), linewidth = 1) +
      geom_vline(xintercept = inflection_stats$median,
                 color = "red", linetype = "dashed") +
      scale_color_manual(values = c("GAM" = "green", "Segmented" = "red")) +
      labs(
        title = "Model Comparison with Median Inflection Point",
        subtitle = paste("Median inflection point:",
                         round(inflection_stats$median, 2), "mL"),
        y = "Mortality",
        x = "Urine Volume (mL)",
        color = "Model Type"
      ) +
      theme_minimal()

    print(final_plot)
  }

} else {

  cat("No valid inflection points were found during cross-validation.\n")
  cat("Consider adjusting the model parameters or checking the data distribution.\n")

}

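# Blog page footer captured along with the code (not part of the script):
# 留言 (Comments)
#
# 這個網誌中的熱門文章 (Popular posts from this blog)
#
# 可轉移性、普遍性、代表性和外部有效性 (Transferability, generalizability, representativeness, and external validity)
#
# 頻率學派 vs 貝氏學派 (Frequentist vs. Bayesian schools)
#
# 貝氏分析計算器 (Bayesian analysis calculator)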