# 11_YIELD_PREDICTION_COMPARISON.R
# ==================================
# This script compares yield prediction models with different predictor variables:
# 1. CI-only model (cumulative_CI, DOY, CI_per_day)
# 2. CI + Ratoon model
# 3. CI + Ratoon + Additional variables (irrigation, variety)
# 4. CI + Ratoon + time-series features derived from the cumulative CI curve
#
# Outputs include:
# - Model performance metrics (RMSE, R², MAE), in-sample and cross-validated
# - Predicted vs Actual yield scatter plots
# - Feature importance analysis
# - Cross-validation results
#
# Usage: Rscript 11_yield_prediction_comparison.R [project_dir]
# - project_dir: Project directory name (e.g., "esa", "aura")

# 1. Load required libraries
# -------------------------
suppressPackageStartupMessages({
  library(here)
  library(sf)
  library(terra)
  library(dplyr)
  library(tidyr)
  library(lubridate)
  library(readr)
  library(readxl)
  library(caret)
  library(CAST)
  library(randomForest)
  library(ggplot2)
  library(gridExtra)
})

# 2. Helper Functions
# -----------------

#' Safe logging function
#'
#' Prints a timestamped, level-tagged line to stdout.
#'
#' @param message Text to log.
#' @param level Severity tag shown in the line (default "INFO").
safe_log <- function(message, level = "INFO") {
  timestamp <- format(Sys.time(), "%Y-%m-%d %H:%M:%S")
  cat(sprintf("[%s] %s: %s\n", timestamp, level, message))
}

#' Prepare predictions with consistent naming and formatting
#'
#' @param predictions Model predictions, one per row of `newdata`.
#' @param newdata Data frame the predictions were produced for.
#' @return `newdata` with a `predicted_TCH` column rounded to one decimal.
prepare_predictions <- function(predictions, newdata) {
  newdata %>%
    dplyr::mutate(predicted_TCH = round(as.numeric(predictions), 1))
}

#' Calculate model performance metrics
#'
#' Pairs with a missing value on either side are dropped first.
#'
#' @param predicted Numeric vector of predictions.
#' @param actual Numeric vector of observed values (same length).
#' @return List with RMSE and MAE (2 dp), R2 (squared Pearson correlation,
#'   3 dp) and n (number of complete pairs). R2 is NA when fewer than two
#'   pairs remain or either side is constant, since the correlation is
#'   undefined there.
calculate_metrics <- function(predicted, actual) {
  valid_idx <- !is.na(predicted) & !is.na(actual)
  predicted <- predicted[valid_idx]
  actual <- actual[valid_idx]

  if (length(predicted) == 0) {
    return(list(RMSE = NA, MAE = NA, R2 = NA, n = 0))
  }

  rmse <- sqrt(mean((predicted - actual)^2))
  mae <- mean(abs(predicted - actual))
  has_variance <- length(predicted) > 1 &&
    stats::sd(predicted) > 0 && stats::sd(actual) > 0
  r2 <- if (has_variance) stats::cor(predicted, actual)^2 else NA_real_

  list(
    RMSE = round(rmse, 2),
    MAE = round(mae, 2),
    R2 = round(r2, 3),
    n = length(predicted)
  )
}

#' Create predicted vs actual plot
#'
#' @param predicted Numeric vector of predictions.
#' @param actual Numeric vector of observed yields.
#' @param model_name Text appended to the plot title.
#' @param metrics List as returned by `calculate_metrics()`.
#' @return A ggplot object with equal axes, a dashed 1:1 line and an lm fit.
create_prediction_plot <- function(predicted, actual, model_name, metrics) {
  plot_data <- data.frame(
    Predicted = predicted,
    Actual = actual
  ) %>%
    dplyr::filter(!is.na(Predicted) & !is.na(Actual))

  # Equal axis limits keep the 1:1 line on the diagonal; fall back to a unit
  # range when nothing is plottable (min/max of an empty vector is +/-Inf).
  axis_limits <- if (nrow(plot_data) > 0) {
    range(c(plot_data$Predicted, plot_data$Actual))
  } else {
    c(0, 1)
  }

  ggplot(plot_data, aes(x = Actual, y = Predicted)) +
    geom_point(alpha = 0.6, size = 3, color = "#2E86AB") +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed",
                color = "red", linewidth = 1) +
    geom_smooth(method = "lm", se = TRUE, color = "#A23B72",
                fill = "#A23B72", alpha = 0.2) +
    coord_fixed(xlim = axis_limits, ylim = axis_limits) +
    labs(
      title = paste("Yield Prediction:", model_name),
      subtitle = sprintf("RMSE: %.2f t/ha | MAE: %.2f t/ha | R²: %.3f | n: %d",
                         metrics$RMSE, metrics$MAE, metrics$R2, metrics$n),
      x = "Actual TCH (t/ha)",
      y = "Predicted TCH (t/ha)"
    ) +
    theme_minimal() +
    theme(
      plot.title = element_text(face = "bold", size = 10),
      plot.subtitle = element_text(size = 9, color = "gray40"),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 9),
      panel.grid.minor = element_blank(),
      panel.border = element_rect(color = "gray80", fill = NA, linewidth = 1)
    )
}

#' Map ratoon counts to crop-cycle categories
#'
#' @param ratoon Numeric vector (0 = plant cane).
#' @return Character vector; NA where `ratoon` is NA (previously NA fell
#'   through to "Old Ratoon (6+)").
categorize_ratoon <- function(ratoon) {
  dplyr::case_when(
    is.na(ratoon) ~ NA_character_,
    ratoon == 0 ~ "Plant Cane",
    ratoon <= 2 ~ "Young Ratoon (1-2)",
    ratoon <= 5 ~ "Mid Ratoon (3-5)",
    TRUE ~ "Old Ratoon (6+)"
  )
}

#' Map free-text irrigation types to a small set of categories
#'
#' @param irrig_type Character vector of irrigation descriptions.
#' @return Character vector; unmatched or NA inputs become "Other".
categorize_irrigation <- function(irrig_type) {
  dplyr::case_when(
    grepl("Pivot|pivot", irrig_type) ~ "Center Pivot",
    grepl("Drip|drip", irrig_type) ~ "Drip",
    grepl("Sprinkler|Spl", irrig_type) ~ "Sprinkler",
    TRUE ~ "Other"
  )
}

#' Load and prepare yield data from Excel
#'
#' Reads the first sheet of `excel_path` and derives harvest year, ratoon
#' and irrigation categories. Not called by `main()`, which reads all sheets.
#'
#' @param excel_path Path to the yield workbook.
#' @return Tibble of yield records with `Field` renamed to `sub_field`.
load_yield_data <- function(excel_path) {
  safe_log(paste("Loading yield data from:", excel_path))

  if (!file.exists(excel_path)) {
    stop(paste("Yield data file not found:", excel_path))
  }

  yield_data <- readxl::read_excel(excel_path) %>%
    dplyr::mutate(
      # Extract year from Harvest_Date
      year = lubridate::year(Harvest_Date),
      # Rename columns for consistency
      tonnage_ha = TCH,
      # Ensure Ratoon is numeric
      Ratoon = as.numeric(Ratoon),
      Ratoon_Category = categorize_ratoon(Ratoon),
      Irrigation_Category = categorize_irrigation(Irrig_Type)
    ) %>%
    dplyr::select(
      GROWER, Field, `Area (Ha)`, Cane_Variety, Ratoon, Ratoon_Category,
      Irrig_Type, Irrigation_Category, Cut_age, Harvest_Date, year,
      tonnage_ha, TSH
    ) %>%
    # Rename Field to sub_field for consistency with CI data
    dplyr::rename(sub_field = Field)

  safe_log(paste("Loaded", nrow(yield_data), "yield records"))
  safe_log(paste("Years covered:", paste(unique(yield_data$year), collapse = ", ")))
  safe_log(paste("Ratoon range:", min(yield_data$Ratoon, na.rm = TRUE), "to",
                 max(yield_data$Ratoon, na.rm = TRUE)))

  yield_data
}

#' Derive the harvest year from a sheet name
#'
#' "2023-24" and "2023-2024" both give 2024 (the second number is the
#' harvest year). A single number is used as-is; 2-digit years are expanded
#' to 20YY in every case.
#'
#' @param sheet_name Excel sheet name.
#' @return Numeric year, or NA when the name holds no 2- or 4-digit number.
parse_sheet_season <- function(sheet_name) {
  year_matches <- regmatches(
    sheet_name, gregexpr("[0-9]{4}|[0-9]{2}", sheet_name)
  )[[1]]
  if (length(year_matches) == 0) {
    return(NA_real_)
  }
  harvest_year <- if (length(year_matches) >= 2) year_matches[2] else year_matches[1]
  if (nchar(harvest_year) == 2) {
    harvest_year <- paste0("20", harvest_year)
  }
  as.numeric(harvest_year)
}

#' Read one yield sheet and tag it with its season
#'
#' @param path Path to the workbook.
#' @param sheet_name Sheet to read.
#' @return Tibble with `sheet_name` and `season` columns added and `Field`
#'   renamed to `sub_field` (unless `sub_field` already exists).
read_yield_sheet <- function(path, sheet_name) {
  safe_log(paste("Reading sheet:", sheet_name))
  year_value <- parse_sheet_season(sheet_name)
  safe_log(paste(" Sheet:", sheet_name, "-> Year:", year_value))

  df <- readxl::read_excel(path, sheet = sheet_name)

  # Fall back to a per-row `year` column when the sheet name has no year.
  # (Plain `$<-` avoids dplyr data-masking clashes with same-named columns.)
  season <- if (!is.na(year_value)) {
    year_value
  } else if ("year" %in% names(df)) {
    as.numeric(df$year)
  } else {
    NA_real_
  }
  df$sheet_name <- sheet_name
  df$season <- season

  if ("Field" %in% names(df) && !"sub_field" %in% names(df)) {
    df <- dplyr::rename(df, sub_field = Field)
  }
  df
}

#' Slope of a simple linear regression of y on x
#'
#' @return The slope, or NA with fewer than three complete points or a
#'   constant x.
linear_slope <- function(x, y) {
  ok <- !is.na(x) & !is.na(y)
  if (sum(ok) <= 2 || length(unique(x[ok])) < 2) {
    return(NA_real_)
  }
  xs <- x[ok]
  ys <- y[ok]
  unname(stats::coef(stats::lm(ys ~ xs))[2])
}

#' Value of a cumulative series at the last observation on/before a cutoff
#'
#' @param cum_ci Cumulative values.
#' @param doy Matching day values.
#' @param cutoff Day threshold (inclusive).
#' @return The cumulative value at the latest `doy <= cutoff`, or 0 when no
#'   observation qualifies.
cumulative_at <- function(cum_ci, doy, cutoff) {
  idx <- which(!is.na(cum_ci) & !is.na(doy) & doy <= cutoff)
  if (length(idx) == 0) {
    return(0)
  }
  cum_ci[idx][which.max(doy[idx])]
}

#' max() that returns NA instead of -Inf (with a warning) for all-NA input
max_or_na <- function(x) {
  if (all(is.na(x))) NA_real_ else max(x, na.rm = TRUE)
}

#' Summarise the cumulative CI curve of each sub_field/season
#'
#' @param ci_data Long CI data with sub_field, season, DOY, cumulative_CI.
#' @param early_window Days counted as "early season" (default 150).
#' @param late_window Days before the last observation counted as "late
#'   season" (default 60).
#' @return One row per sub_field/season with the six time-series features.
compute_ci_features <- function(ci_data, early_window = 150, late_window = 60) {
  ci_data %>%
    dplyr::group_by(sub_field, season) %>%
    dplyr::arrange(DOY, .by_group = TRUE) %>%
    dplyr::mutate(
      # Change in cumulative CI between consecutive observations
      daily_CI_increment = cumulative_CI - dplyr::lag(cumulative_CI, default = 0)
    ) %>%
    dplyr::summarise(
      # 1. Growth rate (linear slope of CI over time)
      CI_growth_rate = linear_slope(DOY, cumulative_CI),
      # 2. CI accumulated by the end of the early window. cumulative_CI is
      #    already a running total, so read it off the curve instead of
      #    summing it (a sum depends on how densely the season was sampled).
      early_season_CI = cumulative_at(cumulative_CI, DOY, early_window),
      # 3. Growth consistency (coefficient of variation of increments)
      growth_consistency_cv = stats::sd(daily_CI_increment, na.rm = TRUE) /
        mean(daily_CI_increment[daily_CI_increment > 0], na.rm = TRUE),
      # 4. Peak growth rate
      peak_CI_per_day = max_or_na(daily_CI_increment),
      # 5. Number of stress events (CI drops)
      stress_events = sum(daily_CI_increment < 0, na.rm = TRUE),
      # 6. CI accumulated during the final `late_window` days
      late_season_CI = cumulative_at(cumulative_CI, DOY, max_or_na(DOY)) -
        cumulative_at(cumulative_CI, DOY, max_or_na(DOY) - late_window),
      .groups = "drop"
    ) %>%
    # Handle infinite / NaN values from the CV ratio
    dplyr::mutate(
      growth_consistency_cv = ifelse(is.finite(growth_consistency_cv),
                                     growth_consistency_cv, NA_real_)
    )
}

#' Run CAST forward feature selection with a random forest
#'
#' @param data Training data.
#' @param predictors Candidate predictor column names.
#' @param response Response column name.
#' @param ctrl caret::trainControl object.
#' @param seed RNG seed set immediately before fitting.
#' @return The fitted model returned by CAST::ffs.
fit_ffs_model <- function(data, predictors, response, ctrl, seed = 206) {
  set.seed(seed)
  CAST::ffs(
    data[, predictors],
    data[[response]], # Use [[ to extract as vector
    method = "rf",
    trControl = ctrl,
    importance = TRUE,
    withinSE = TRUE,
    tuneLength = 3,
    na.action = na.omit
  )
}

#' Log mean, sd and range of the cross-validation fold scores
log_cv_summary <- function(model) {
  cv <- model$resample
  safe_log(sprintf(" CV Folds - RMSE: %.2f ± %.2f (range: %.2f-%.2f)",
                   mean(cv$RMSE, na.rm = TRUE), stats::sd(cv$RMSE, na.rm = TRUE),
                   min(cv$RMSE, na.rm = TRUE), max(cv$RMSE, na.rm = TRUE)))
  safe_log(sprintf(" CV Folds - R²: %.3f ± %.3f (range: %.3f-%.3f)",
                   mean(cv$Rsquared, na.rm = TRUE), stats::sd(cv$Rsquared, na.rm = TRUE),
                   min(cv$Rsquared, na.rm = TRUE), max(cv$Rsquared, na.rm = TRUE)))
}

#' In-sample predictions and metrics for a model, plus CV log lines
#'
#' @return list(predictions = <data with predicted_TCH>, metrics = <list>).
evaluate_model <- function(model, data, label) {
  predictions <- prepare_predictions(stats::predict(model, newdata = data), data)
  metrics <- calculate_metrics(predictions$predicted_TCH, data$tonnage_ha)
  safe_log(sprintf("%s - RMSE: %.2f | MAE: %.2f | R²: %.3f",
                   label, metrics$RMSE, metrics$MAE, metrics$R2))
  log_cv_summary(model)
  list(predictions = predictions, metrics = metrics)
}

#' Predict future seasons, or return NULL (with a warning) if there are none
predict_future <- function(model, newdata, label) {
  if (is.null(newdata) || nrow(newdata) == 0) {
    safe_log(sprintf("%s - No future data to predict on", label), "WARNING")
    return(NULL)
  }
  predictions <- prepare_predictions(stats::predict(model, newdata = newdata), newdata)
  safe_log(sprintf("%s - Future predictions: %d fields", label, nrow(predictions)))
  predictions
}

#' Re-encode categorical columns of `newdata` with the levels of `reference`
#'
#' Rows with NA or with a category unseen in `reference` are dropped, since
#' a model trained on `reference` cannot score them.
align_categorical_levels <- function(newdata, reference, vars) {
  for (var in vars) {
    known_levels <- levels(reference[[var]])
    values <- as.character(newdata[[var]])
    newdata <- newdata[!is.na(values) & values %in% known_levels, , drop = FALSE]
    newdata[[var]] <- factor(as.character(newdata[[var]]), levels = known_levels)
  }
  newdata
}

#' Extract a metric from a (possibly NULL) metrics list as a number
metric_or_na <- function(metrics, name) {
  if (is.null(metrics) || is.null(metrics[[name]])) NA_real_ else as.numeric(metrics[[name]])
}

#' Mean cross-validated fold score of a (possibly NULL) model
cv_mean <- function(model, metric, digits = 3) {
  values <- model$resample[[metric]]
  if (is.null(values) || all(is.na(values))) {
    return(NA_real_)
  }
  round(mean(values, na.rm = TRUE), digits)
}

# 3. Main Function
# --------------
main <- function() {
  # Process command line arguments
  args <- commandArgs(trailingOnly = TRUE)

  # Process project_dir argument (default project: "esa")
  project_dir <- if (length(args) >= 1 && !is.na(args[1])) {
    as.character(args[1])
  } else {
    "esa"
  }

  # Make project_dir available globally; presumably read by
  # parameters_project.R (not visible here) - TODO confirm.
  assign("project_dir", project_dir, envir = .GlobalEnv)

  safe_log("=== YIELD PREDICTION MODEL COMPARISON ===")
  safe_log(paste("Project:", project_dir))

  # 4. Load project configuration
  # ---------------------------
  # Must define cumulative_CI_vals_dir and reports_dir, which are used below.
  tryCatch({
    source(here("r_app", "parameters_project.R"))
  }, error = function(e) {
    stop("Error loading parameters_project.R: ", e$message)
  })

  # 5. Load yield data from multi-tab Excel file
  # -------------------------------------------
  yield_excel_path <- file.path(
    "laravel_app", "storage", "app", project_dir, "Data",
    paste0(project_dir, "_yield_data.xlsx")
  )

  safe_log(paste("Loading yield data from:", yield_excel_path))

  # Get all sheet names
  sheet_names <- readxl::excel_sheets(here(yield_excel_path))
  safe_log(paste("Found", length(sheet_names), "sheets:", paste(sheet_names, collapse = ", ")))

  # Read all sheets and combine them
  yield_data_list <- lapply(sheet_names, function(sheet_name) {
    read_yield_sheet(here(yield_excel_path), sheet_name)
  })

  yield_data_full <- dplyr::bind_rows(yield_data_list) %>%
    dplyr::filter(!is.na(season))
  yield_cols <- names(yield_data_full)

  # Yield may be stored as TCH or tonnage_ha depending on the sheet: prefer
  # TCH and fall back to tonnage_ha row by row so neither naming loses data.
  tch_values <- if ("TCH" %in% yield_cols) as.numeric(yield_data_full$TCH) else NA_real_
  tonnage_values <- if ("tonnage_ha" %in% yield_cols) {
    as.numeric(yield_data_full$tonnage_ha)
  } else {
    NA_real_
  }

  yield_data_full <- yield_data_full %>%
    dplyr::mutate(
      # Ensure Ratoon is numeric
      Ratoon = as.numeric(Ratoon),
      Ratoon_Category = categorize_ratoon(Ratoon),
      # Create irrigation category if Irrig_Type exists
      Irrigation_Category = if ("Irrig_Type" %in% yield_cols) {
        categorize_irrigation(Irrig_Type)
      } else {
        NA_character_
      }
    )
  yield_data_full$tonnage_ha <- dplyr::coalesce(tch_values, tonnage_values)

  safe_log(paste("Loaded", nrow(yield_data_full), "yield records from all sheets"))
  safe_log(paste("Years covered:", paste(sort(unique(yield_data_full$season)), collapse = ", ")))
  safe_log(paste("Ratoon range:", min(yield_data_full$Ratoon, na.rm = TRUE), "to",
                 max(yield_data_full$Ratoon, na.rm = TRUE)))
  safe_log(paste("Fields with yield data:",
                 length(unique(yield_data_full$sub_field[!is.na(yield_data_full$tonnage_ha)]))))

  # 6. Load CI data
  # -------------
  safe_log("Loading cumulative CI data")
  CI_quadrant <- readRDS(here(cumulative_CI_vals_dir,
                              "All_pivots_Cumulative_CI_quadrant_year_v2.rds")) %>%
    dplyr::group_by(model) %>%
    tidyr::fill(field, sub_field, .direction = "downup") %>%
    dplyr::ungroup()

  # 7. Merge CI and yield data
  # ------------------------
  safe_log("Merging CI and yield data")

  # Get maximum DOY (end of season) for each field/season combination
  CI_summary <- CI_quadrant %>%
    dplyr::group_by(sub_field, season) %>%
    dplyr::slice(which.max(DOY)) %>%
    dplyr::select(field, sub_field, cumulative_CI, DOY, season) %>%
    dplyr::mutate(CI_per_day = cumulative_CI / DOY) %>%
    dplyr::ungroup()

  # 7a. Calculate advanced time series features from CI data
  # -------------------------------------------------------
  safe_log("Calculating time series-derived features")
  CI_features <- compute_ci_features(CI_quadrant)

  # Merge features back into CI_summary
  CI_summary <- CI_summary %>%
    dplyr::left_join(CI_features, by = c("sub_field", "season"))

  safe_log(sprintf("Added %d time series features", ncol(CI_features) - 2))

  # 7b. Merge CI and yield data
  # -------------------------
  safe_log("Merging CI and yield data")

  # Join with yield data to get yield, ratoon, and other information
  combined_data_all <- CI_summary %>%
    dplyr::left_join(yield_data_full, by = c("sub_field", "season"))

  # Training data: completed seasons with yield data (mature fields only).
  # Ratoon is required because it is a candidate predictor for models 2-4 and
  # randomForest does not accept NA predictors by default; this also mirrors
  # the prediction_data filter so both sets come from the same population.
  training_data <- combined_data_all %>%
    dplyr::filter(
      !is.na(tonnage_ha),
      !is.na(cumulative_CI),
      !is.na(Ratoon),
      DOY >= 240 # Only mature fields (>= 8 months)
    )

  # Prediction data: future/current seasons without yield data (mature fields only)
  current_year <- lubridate::year(Sys.Date())
  prediction_data <- combined_data_all %>%
    dplyr::filter(
      is.na(tonnage_ha),
      !is.na(cumulative_CI),
      !is.na(DOY),
      !is.na(CI_per_day),
      !is.na(Ratoon), # Ensure Ratoon is not NA for Model 2
      DOY >= 240, # Only mature fields
      season >= current_year # Current and future seasons
    )

  safe_log(paste("Training dataset:", nrow(training_data), "records"))
  safe_log(paste("Training fields:", length(unique(training_data$sub_field))))
  safe_log(paste("Training seasons:", paste(sort(unique(training_data$season)), collapse = ", ")))
  safe_log(paste("\nPrediction dataset:", nrow(prediction_data), "records"))
  safe_log(paste("Prediction fields:", length(unique(prediction_data$sub_field))))
  safe_log(paste("Prediction seasons:", paste(sort(unique(prediction_data$season)), collapse = ", ")))

  # Check if we have enough data
  if (nrow(training_data) < 10) {
    stop("Insufficient training data (need at least 10 records)")
  }

  # 8. Prepare datasets for modeling
  # ------------------------------
  safe_log("Preparing datasets for modeling")

  # Define predictors for each model
  ci_predictors <- c("cumulative_CI", "DOY", "CI_per_day")
  ci_ratoon_predictors <- c("cumulative_CI", "DOY", "CI_per_day", "Ratoon")
  ci_ratoon_full_predictors <- c("cumulative_CI", "DOY", "CI_per_day", "Ratoon",
                                 "Irrigation_Category", "Cane_Variety")
  ci_timeseries_predictors <- c("cumulative_CI", "DOY", "CI_per_day", "Ratoon",
                                "CI_growth_rate", "early_season_CI", "growth_consistency_cv",
                                "peak_CI_per_day", "stress_events", "late_season_CI")
  response <- "tonnage_ha"

  # Configure cross-validation (5-fold CV)
  ctrl <- caret::trainControl(
    method = "cv",
    number = 5,
    savePredictions = TRUE,
    verboseIter = TRUE
  )

  # 9. Train Model 1: CI-only
  # -----------------------
  safe_log("\n=== MODEL 1: CI PREDICTORS ONLY ===")
  model_ci <- fit_ffs_model(training_data, ci_predictors, response, ctrl)
  eval_ci <- evaluate_model(model_ci, training_data, "Model 1")
  pred_ci_train <- eval_ci$predictions
  metrics_ci <- eval_ci$metrics
  pred_ci_future <- predict_future(model_ci, prediction_data, "Model 1")

  # 10. Train Model 2: CI + Ratoon
  # ----------------------------
  safe_log("\n=== MODEL 2: CI + RATOON ===")
  model_ci_ratoon <- fit_ffs_model(training_data, ci_ratoon_predictors, response, ctrl)
  eval_ci_ratoon <- evaluate_model(model_ci_ratoon, training_data, "Model 2")
  pred_ci_ratoon_train <- eval_ci_ratoon$predictions
  metrics_ci_ratoon <- eval_ci_ratoon$metrics
  pred_ci_ratoon_future <- predict_future(model_ci_ratoon, prediction_data, "Model 2")

  # 11. Train Model 3: CI + Ratoon + Full variables
  # ---------------------------------------------
  safe_log("\n=== MODEL 3: CI + RATOON + IRRIGATION + VARIETY ===")
  categorical_vars <- c("Irrigation_Category", "Cane_Variety")
  model_ci_ratoon_full <- NULL
  metrics_ci_ratoon_full <- NULL
  pred_ci_ratoon_full_train <- NULL
  pred_ci_ratoon_full_future <- NULL
  training_data_full <- NULL

  if (all(categorical_vars %in% names(training_data))) {
    # Drop records with missing categories and encode them as factors.
    # randomForest() converts data-frame input with data.matrix(), which
    # would map character columns to integer codes that depend on the values
    # present in each call, so training and prediction codes could disagree.
    # Note: randomForest supports at most 53 levels per factor predictor.
    training_data_full <- training_data %>%
      dplyr::filter(!is.na(Irrigation_Category), !is.na(Cane_Variety))
    for (var in categorical_vars) {
      training_data_full[[var]] <- factor(as.character(training_data_full[[var]]))
    }
  }

  if (!is.null(training_data_full) && nrow(training_data_full) >= 10) {
    model_ci_ratoon_full <- fit_ffs_model(training_data_full, ci_ratoon_full_predictors,
                                          response, ctrl)
    eval_full <- evaluate_model(model_ci_ratoon_full, training_data_full, "Model 3")
    pred_ci_ratoon_full_train <- eval_full$predictions
    metrics_ci_ratoon_full <- eval_full$metrics

    # Score only rows whose categories were seen during training
    prediction_data_full <- align_categorical_levels(prediction_data, training_data_full,
                                                     categorical_vars)
    pred_ci_ratoon_full_future <- predict_future(model_ci_ratoon_full, prediction_data_full,
                                                 "Model 3")
  } else {
    safe_log("Insufficient data for full model, skipping", "WARNING")
  }

  # 11d. Train Model 4: CI + Ratoon + Time Series Features
  # ------------------------------------------------------
  safe_log("\n=== MODEL 4: CI + RATOON + TIME SERIES FEATURES ===")
  model_ci_timeseries <- NULL
  metrics_ci_timeseries <- NULL
  pred_ci_timeseries_train <- NULL
  pred_ci_timeseries_future <- NULL

  # Keep only records where every time series feature is present
  drop_incomplete_ts <- function(df) {
    df %>%
      dplyr::filter(
        !is.na(CI_growth_rate),
        !is.na(early_season_CI),
        !is.na(growth_consistency_cv),
        !is.na(peak_CI_per_day),
        !is.na(stress_events),
        !is.na(late_season_CI)
      )
  }
  training_data_ts <- drop_incomplete_ts(training_data)
  prediction_data_ts <- drop_incomplete_ts(prediction_data)

  safe_log(sprintf("Model 4 training records: %d", nrow(training_data_ts)))

  if (nrow(training_data_ts) >= 10) {
    model_ci_timeseries <- fit_ffs_model(training_data_ts, ci_timeseries_predictors,
                                         response, ctrl)
    eval_ts <- evaluate_model(model_ci_timeseries, training_data_ts, "Model 4")
    pred_ci_timeseries_train <- eval_ts$predictions
    metrics_ci_timeseries <- eval_ts$metrics
    pred_ci_timeseries_future <- predict_future(model_ci_timeseries, prediction_data_ts,
                                                "Model 4")
  } else {
    safe_log("Insufficient data for time series model, skipping", "WARNING")
  }

  # 12. Create comparison plots
  # -------------------------
  safe_log("\n=== CREATING VISUALIZATION ===")

  # Create output directory
  output_dir <- file.path(reports_dir, "yield_prediction")
  dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)

  # Create plots for training/validation
  plot_ci <- create_prediction_plot(
    pred_ci_train$predicted_TCH,
    training_data$tonnage_ha,
    "CI Only (Training Data)",
    metrics_ci
  )

  plot_ci_ratoon <- create_prediction_plot(
    pred_ci_ratoon_train$predicted_TCH,
    training_data$tonnage_ha,
    "CI + Ratoon (Training Data)",
    metrics_ci_ratoon
  )

  plot_ci_ratoon_full <- NULL
  if (!is.null(model_ci_ratoon_full)) {
    # Get actual selected variables for Model 3
    model3_vars <- paste(model_ci_ratoon_full$selectedvars, collapse = ", ")
    plot_ci_ratoon_full <- create_prediction_plot(
      pred_ci_ratoon_full_train$predicted_TCH,
      training_data_full$tonnage_ha,
      paste0("Model 3: ", model3_vars),
      metrics_ci_ratoon_full
    )
  }

  plot_ci_timeseries <- NULL
  if (!is.null(model_ci_timeseries)) {
    # Get actual selected variables for Model 4
    model4_vars <- paste(model_ci_timeseries$selectedvars, collapse = ", ")
    plot_ci_timeseries <- create_prediction_plot(
      pred_ci_timeseries_train$predicted_TCH,
      training_data_ts$tonnage_ha,
      paste0("Model 4: ", model4_vars),
      metrics_ci_timeseries
    )
  }

  # Determine which prediction data to use for table (prioritize Model 4, then 2, then 3)
  if (!is.null(pred_ci_timeseries_future) && nrow(pred_ci_timeseries_future) > 0) {
    future_preds_for_table <- pred_ci_timeseries_future
    table_model_name <- "Model 4 (Time Series)"
  } else if (!is.null(pred_ci_ratoon_future) && nrow(pred_ci_ratoon_future) > 0) {
    future_preds_for_table <- pred_ci_ratoon_future
    table_model_name <- "Model 2 (CI + Ratoon)"
  } else if (!is.null(pred_ci_ratoon_full_future) && nrow(pred_ci_ratoon_full_future) > 0) {
    future_preds_for_table <- pred_ci_ratoon_full_future
    table_model_name <- "Model 3 (Full)"
  } else {
    future_preds_for_table <- NULL
    table_model_name <- NULL
  }

  # Create prediction table
  if (!is.null(future_preds_for_table)) {
    pred_table_data <- future_preds_for_table %>%
      dplyr::select(sub_field, season, predicted_TCH, Ratoon, DOY) %>%
      dplyr::arrange(desc(predicted_TCH)) %>%
      head(10)

    pred_table <- gridExtra::tableGrob(
      pred_table_data,
      rows = NULL,
      theme = gridExtra::ttheme_default(
        core = list(fg_params = list(cex = 0.6)),
        colhead = list(fg_params = list(cex = 0.7, fontface = "bold"))
      )
    )

    pred_text <- grid::textGrob(
      paste0("Future Yield Predictions - ", table_model_name, "\n",
             "(", nrow(future_preds_for_table), " fields in seasons ",
             paste(unique(future_preds_for_table$season), collapse = ", "), ")\n",
             "Top 10 predicted yields shown"),
      x = 0.5, y = 0.98, just = c("center", "top"),
      gp = grid::gpar(fontsize = 8, fontface = "bold")
    )

    pred_panel <- gridExtra::arrangeGrob(pred_text, pred_table, ncol = 1,
                                         heights = c(0.25, 0.75))
  } else {
    pred_panel <- grid::textGrob(
      "No future predictions available\n(No mature fields in current/future seasons)",
      gp = grid::gpar(fontsize = 10, col = "gray50")
    )
  }

  # Create feature explanation panel
  feature_text <- paste0(
    "SELECTED FEATURES BY MODEL (via Forward Feature Selection)\n\n",
    "Model 1: ", paste(model_ci$selectedvars, collapse = ", "), "\n",
    "Model 2: ", paste(model_ci_ratoon$selectedvars, collapse = ", "), "\n",
    if (!is.null(model_ci_ratoon_full)) {
      paste0("Model 3: ", paste(model_ci_ratoon_full$selectedvars, collapse = ", "), "\n")
    } else "",
    if (!is.null(model_ci_timeseries)) {
      paste0("Model 4: ", paste(model_ci_timeseries$selectedvars, collapse = ", "), "\n\n")
    } else "\n",
    "FEATURE DEFINITIONS:\n",
    "• cumulative_CI: Total CI accumulated from planting to harvest\n",
    "• DOY: Day of year at harvest (crop age proxy)\n",
    "• CI_per_day: Daily average CI (cumulative_CI / DOY)\n",
    "• Ratoon: Crop cycle number (0=plant cane, 1+=ratoon)\n",
    if (!is.null(model_ci_timeseries)) paste0(
      "• CI_growth_rate: Linear slope of CI over time (vigor)\n",
      "• growth_consistency_cv: CV of daily CI increments (stability)\n",
      "• early_season_CI: CI accumulated in first 150 days\n",
      "• peak_CI_per_day: Maximum daily CI increment observed\n",
      "• stress_events: Count of negative CI changes\n",
      "• late_season_CI: CI in final 60 days before harvest\n"
    ) else "",
    if (!is.null(model_ci_ratoon_full)) paste0(
      "• Irrigation_Category: Irrigation system type\n",
      "• Cane_Variety: Sugarcane variety planted"
    ) else ""
  )

  feature_panel <- grid::textGrob(
    feature_text,
    x = 0.05, y = 0.95, just = c("left", "top"),
    gp = grid::gpar(fontsize = 7, fontfamily = "mono")
  )

  title_grob <- grid::textGrob(
    paste("Yield Prediction Model Comparison -", project_dir,
          "\nTraining on", paste(sort(unique(training_data$season)), collapse = ", ")),
    gp = grid::gpar(fontsize = 16, fontface = "bold")
  )

  # Models 3 and 4 are optional; lay out whichever of them were fitted so a
  # fitted Model 4 is shown even when Model 3 was skipped. arrangeGrob (not
  # grid.arrange) builds the layout without drawing to the active device.
  optional_plots <- Filter(Negate(is.null), list(plot_ci_ratoon_full, plot_ci_timeseries))

  if (length(optional_plots) == 2) {
    # 3x2 grid: Models 1-2, Models 3-4, feature text + prediction table
    combined_plot <- gridExtra::arrangeGrob(
      plot_ci, plot_ci_ratoon, optional_plots[[1]], optional_plots[[2]],
      feature_panel, pred_panel,
      ncol = 2, nrow = 3,
      heights = c(1.2, 1.2, 0.9), # Make plots bigger, bottom row smaller
      layout_matrix = rbind(c(1, 2), c(3, 4), c(5, 6)),
      top = title_grob
    )
  } else if (length(optional_plots) == 1) {
    # Only 3 models available
    combined_plot <- gridExtra::arrangeGrob(
      plot_ci, plot_ci_ratoon, optional_plots[[1]], pred_panel,
      ncol = 2, nrow = 2,
      top = title_grob
    )
  } else if (!is.null(pred_ci_ratoon_future) && nrow(pred_ci_ratoon_future) > 0) {
    # Create prediction table for bottom (2-plot layout)
    pred_table_data <- pred_ci_ratoon_future %>%
      dplyr::select(sub_field, season, predicted_TCH, Ratoon, DOY) %>%
      dplyr::arrange(desc(predicted_TCH))

    pred_table <- gridExtra::tableGrob(
      pred_table_data,
      rows = NULL,
      theme = gridExtra::ttheme_default(
        core = list(fg_params = list(cex = 0.7)),
        colhead = list(fg_params = list(cex = 0.8, fontface = "bold"))
      )
    )

    pred_text <- grid::textGrob(
      paste0("Future Yield Predictions (", nrow(pred_ci_ratoon_future), " fields in seasons ",
             paste(unique(pred_ci_ratoon_future$season), collapse = ", "), ")"),
      x = 0.5, y = 0.95,
      gp = grid::gpar(fontsize = 10, fontface = "bold")
    )

    pred_panel <- gridExtra::arrangeGrob(pred_text, pred_table, ncol = 1,
                                         heights = c(0.1, 0.9))

    # Combine two plots + prediction table
    combined_plot <- gridExtra::arrangeGrob(
      plot_ci, plot_ci_ratoon, pred_panel,
      layout_matrix = rbind(c(1, 2), c(3, 3)),
      top = title_grob
    )
  } else {
    # Combine two plots only
    combined_plot <- gridExtra::arrangeGrob(
      plot_ci, plot_ci_ratoon,
      ncol = 2,
      top = title_grob
    )
  }

  # Save plot
  plot_file <- file.path(output_dir, paste0(project_dir, "_yield_prediction_comparison.png"))
  ggsave(plot_file, combined_plot, width = 16, height = 12, dpi = 300)
  safe_log(paste("Comparison plot saved to:", plot_file))

  # 12b. Save future predictions to CSV
  # ---------------------------------
  if (!is.null(pred_ci_ratoon_future) && nrow(pred_ci_ratoon_future) > 0) {
    future_pred_file <- file.path(output_dir, paste0(project_dir, "_future_predictions.csv"))
    future_pred_export <- pred_ci_ratoon_future %>%
      dplyr::select(field, sub_field, season, DOY, predicted_TCH, cumulative_CI,
                    CI_per_day, Ratoon) %>%
      dplyr::arrange(desc(predicted_TCH))

    readr::write_csv(future_pred_export, future_pred_file)
    safe_log(paste("Future predictions saved to:", future_pred_file))
  }

  # 13. Create feature importance plot
  # --------------------------------
  safe_log("Creating feature importance plot")

  # Extract variable importance from CI + Ratoon model
  var_imp <- caret::varImp(model_ci_ratoon)$importance %>%
    tibble::rownames_to_column("Variable") %>%
    dplyr::arrange(desc(Overall)) %>%
    dplyr::mutate(Variable = factor(Variable, levels = Variable))

  imp_plot <- ggplot(var_imp, aes(x = Overall, y = Variable)) +
    geom_col(fill = "#2E86AB") +
    labs(
      title = "Feature Importance: CI + Ratoon Model",
      x = "Importance",
      y = "Variable"
    ) +
    theme_minimal() +
    theme(
      plot.title = element_text(face = "bold", size = 14),
      axis.title = element_text(size = 12)
    )

  imp_file <- file.path(output_dir, paste0(project_dir, "_feature_importance.png"))
  ggsave(imp_file, imp_plot, width = 8, height = 6, dpi = 300)
  safe_log(paste("Feature importance plot saved to:", imp_file))

  # 14. Create comparison table
  # -------------------------
  model_objects <- list(model_ci, model_ci_ratoon, model_ci_ratoon_full, model_ci_timeseries)
  model_metrics <- list(metrics_ci, metrics_ci_ratoon, metrics_ci_ratoon_full,
                        metrics_ci_timeseries)

  comparison_table <- data.frame(
    Model = c("CI Only", "CI + Ratoon", "CI + Ratoon + Full", "CI + Ratoon + Time Series"),
    Predictors = c(
      paste(ci_predictors, collapse = ", "),
      paste(ci_ratoon_predictors, collapse = ", "),
      paste(ci_ratoon_full_predictors, collapse = ", "),
      paste(ci_timeseries_predictors, collapse = ", ")
    ),
    RMSE = vapply(model_metrics, metric_or_na, numeric(1), name = "RMSE"),
    MAE = vapply(model_metrics, metric_or_na, numeric(1), name = "MAE"),
    R2 = vapply(model_metrics, metric_or_na, numeric(1), name = "R2"),
    N = vapply(model_metrics, metric_or_na, numeric(1), name = "n"),
    # Appended columns: the in-sample metrics above are optimistic for a
    # random forest, so also report mean cross-validated fold scores and the
    # variables forward feature selection actually kept.
    CV_RMSE = vapply(model_objects, cv_mean, numeric(1), metric = "RMSE", digits = 2),
    CV_R2 = vapply(model_objects, cv_mean, numeric(1), metric = "Rsquared", digits = 3),
    Selected_Predictors = vapply(model_objects, function(m) {
      if (is.null(m)) NA_character_ else paste(m$selectedvars, collapse = ", ")
    }, character(1))
  )

  # Save comparison table
  comparison_file <- file.path(output_dir, paste0(project_dir, "_model_comparison.csv"))
  readr::write_csv(comparison_table, comparison_file)
  safe_log(paste("Model comparison table saved to:", comparison_file))

  # 15. Save model objects
  # --------------------
  models_file <- file.path(output_dir, paste0(project_dir, "_yield_models.rds"))
  saveRDS(list(
    model1 = model_ci,
    model2 = model_ci_ratoon,
    model3 = model_ci_ratoon_full,
    model4 = model_ci_timeseries,
    metrics_ci = metrics_ci,
    metrics_ci_ratoon = metrics_ci_ratoon,
    metrics_ci_ratoon_full = metrics_ci_ratoon_full,
    metrics_ci_timeseries = metrics_ci_timeseries,
    training_predictions_ci = pred_ci_train,
    training_predictions_ci_ratoon = pred_ci_ratoon_train,
    training_predictions_ci_ratoon_full = pred_ci_ratoon_full_train,
    training_predictions_ci_timeseries = pred_ci_timeseries_train,
    future_predictions_ci = pred_ci_future,
    future_predictions_ci_ratoon = pred_ci_ratoon_future,
    future_predictions_ci_ratoon_full = pred_ci_ratoon_full_future,
    future_predictions_ci_timeseries = pred_ci_timeseries_future,
    training_data = training_data,
    prediction_data = prediction_data
  ), models_file)
  safe_log(paste("Model objects saved to:", models_file))

  # 16. Print summary
  # ---------------
  cat("\n=== YIELD PREDICTION MODEL COMPARISON SUMMARY ===\n")
  print(comparison_table)

  cat("\n=== IMPROVEMENT ANALYSIS ===\n")

  rmse_improvement <- ((metrics_ci$RMSE - metrics_ci_ratoon$RMSE) / metrics_ci$RMSE) * 100
  r2_improvement <- ((metrics_ci_ratoon$R2 - metrics_ci$R2) / metrics_ci$R2) * 100

  cat(sprintf("Adding Ratoon to CI model:\n"))
  cat(sprintf(" - RMSE improvement: %.1f%% (%.2f → %.2f t/ha)\n",
              rmse_improvement, metrics_ci$RMSE, metrics_ci_ratoon$RMSE))
  cat(sprintf(" - R² improvement: %.1f%% (%.3f → %.3f)\n",
              r2_improvement, metrics_ci$R2, metrics_ci_ratoon$R2))

  if (!is.null(metrics_ci_ratoon_full)) {
    rmse_improvement_full <- ((metrics_ci_ratoon$RMSE - metrics_ci_ratoon_full$RMSE) /
                                metrics_ci_ratoon$RMSE) * 100
    r2_improvement_full <- ((metrics_ci_ratoon_full$R2 - metrics_ci_ratoon$R2) /
                              metrics_ci_ratoon$R2) * 100

    cat(sprintf("\nAdding Irrigation + Variety to CI + Ratoon model:\n"))
    cat(sprintf(" - RMSE improvement: %.1f%% (%.2f → %.2f t/ha)\n",
                rmse_improvement_full, metrics_ci_ratoon$RMSE, metrics_ci_ratoon_full$RMSE))
    cat(sprintf(" - R² improvement: %.1f%% (%.3f → %.3f)\n",
                r2_improvement_full, metrics_ci_ratoon$R2, metrics_ci_ratoon_full$R2))
  }

  cat("\n=== TRAINING/PREDICTION SUMMARY ===\n")
  cat(sprintf("Training seasons: %s\n", paste(sort(unique(training_data$season)), collapse = ", ")))
  cat(sprintf("Training records: %d fields\n", nrow(training_data)))

  if (!is.null(pred_ci_ratoon_future)) {
    cat(sprintf("\nPrediction seasons: %s\n",
                paste(sort(unique(prediction_data$season)), collapse = ", ")))
    cat(sprintf("Prediction records: %d fields\n", nrow(pred_ci_ratoon_future)))
    cat("\nTop 5 predicted yields:\n")
    print(pred_ci_ratoon_future %>%
            dplyr::select(sub_field, season, predicted_TCH, Ratoon) %>%
            dplyr::arrange(desc(predicted_TCH)) %>%
            head(5))
  }

  cat("\n=== OUTPUT FILES ===\n")
  cat(paste("Comparison plot:", plot_file, "\n"))
  cat(paste("Feature importance:", imp_file, "\n"))
  cat(paste("Comparison table:", comparison_file, "\n"))
  cat(paste("Model objects:", models_file, "\n"))
  if (!is.null(pred_ci_ratoon_future) && nrow(pred_ci_ratoon_future) > 0) {
    future_pred_file <- file.path(output_dir, paste0(project_dir, "_future_predictions.csv"))
    cat(paste("Future predictions:", future_pred_file, "\n"))
  }

  cat("\n=== SEED SENSITIVITY ANALYSIS ===\n")
  cat("Current seed: 206\n")
  cat("Using same seed ensures:\n")
  cat(" - Identical fold assignments across runs\n")
  cat(" - Identical bootstrap samples in random forest\n")
  cat(" - Reproducible results\n\n")
  cat("Expected variation with different seeds:\n")
  cat(" - RMSE: ±1-3 t/ha (typical range)\n")
  cat(" - R²: ±0.02-0.05 (typical range)\n")
  cat(" - Feature selection may change slightly\n")
  cat(" - Predictions will vary but trends remain consistent\n\n")
  cat("To test seed sensitivity, modify set.seed(206) to different values\n")
  cat("and re-run the script to compare results.\n")

  cat("\n=== YIELD PREDICTION COMPARISON COMPLETED ===\n")
}

# 6. Script execution
# -----------------
if (sys.nframe() == 0) {
  main()
}