# SmartCane/r_app/11_yield_prediction_comparison.R
# 2026-01-06 14:17:37 +01:00
#
# 1068 lines
# 39 KiB
# R

# 11_YIELD_PREDICTION_COMPARISON.R
# ==================================
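# This script compares random-forest yield prediction models (fitted with
# forward feature selection) built from different candidate predictor sets:
# 1. CI-only model (cumulative_CI, DOY, CI_per_day)
# 2. CI + Ratoon model
# 3. CI + Ratoon + additional variables (irrigation, variety)
# 4. CI + Ratoon + CI time-series features (growth rate, early/late season CI,
#    growth consistency, peak increment, stress events)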
#
# Outputs include:
# - Model performance metrics (RMSE, R², MAE)
# - Predicted vs Actual yield scatter plots
# - Feature importance analysis
# - Cross-validation results
#
# Usage: Rscript 11_yield_prediction_comparison.R [project_dir]
# - project_dir: Project directory name (e.g., "esa", "aura")
# 1. Load required libraries
# -------------------------
suppressPackageStartupMessages({
library(here)
library(sf)
library(terra)
library(dplyr)
library(tidyr)
library(lubridate)
library(readr)
library(readxl)
library(caret)
library(CAST)
library(randomForest)
library(ggplot2)
library(gridExtra)
})
# 2. Helper Functions
# -----------------
#' Print a timestamped log line to standard output
#'
#' Output format: "[YYYY-mm-dd HH:MM:SS] LEVEL: message", one line per element
#' of `message`.
#'
#' @param message Text to log.
#' @param level Severity label written before the message (default "INFO").
#' @return `NULL`, invisibly (called for its side effect).
safe_log <- function(message, level = "INFO") {
  stamp <- format(Sys.time(), "%Y-%m-%d %H:%M:%S")
  line_format <- "[%s] %s: %s\n"
  invisible(cat(sprintf(line_format, stamp, level, message)))
}
#' Attach rounded model predictions to the rows they were made for
#'
#' @param predictions Vector of predictions, one per row of `newdata`
#'   (coerced with `as.numeric()`).
#' @param newdata Data frame the predictions were generated from.
#' @return `newdata` with an added `predicted_TCH` column, rounded to one
#'   decimal place.
prepare_predictions <- function(predictions, newdata) {
  dplyr::mutate(
    newdata,
    predicted_TCH = round(as.numeric(predictions), digits = 1)
  )
}
#' Compute RMSE, MAE and R² between predictions and observations
#'
#' Pairs where either value is NA are dropped before scoring. R² here is the
#' squared Pearson correlation between `predicted` and `actual` (not
#' 1 - SSE/SST), so it does not penalise systematic bias.
#'
#' @param predicted Numeric vector of predictions.
#' @param actual Numeric vector of observations, aligned with `predicted`.
#' @return A list with `RMSE` and `MAE` (rounded to 2 decimals), `R2` (rounded
#'   to 3 decimals) and `n`, the number of complete pairs. When no complete
#'   pair remains, all metrics are `NA` and `n` is 0.
calculate_metrics <- function(predicted, actual) {
  keep <- !(is.na(predicted) | is.na(actual))
  pred_ok <- predicted[keep]
  obs_ok <- actual[keep]

  if (length(pred_ok) == 0) {
    return(list(RMSE = NA, MAE = NA, R2 = NA, n = 0))
  }

  residuals <- pred_ok - obs_ok
  list(
    RMSE = round(sqrt(mean(residuals^2)), 2),
    MAE = round(mean(abs(residuals)), 2),
    R2 = round(cor(pred_ok, obs_ok)^2, 3),
    n = length(pred_ok)
  )
}
#' Create a predicted-vs-actual scatter plot for one model
#'
#' Draws the points, a dashed 1:1 reference line and a linear fit with its
#' confidence band. Both axes share the same limits so that deviations from the
#' 1:1 line are not visually distorted.
#'
#' @param predicted Numeric vector of predicted TCH (t/ha).
#' @param actual Numeric vector of observed TCH (t/ha), aligned with `predicted`.
#' @param model_name Label used in the plot title.
#' @param metrics List with `RMSE`, `MAE`, `R2` and `n` (as returned by
#'   `calculate_metrics()`), shown in the subtitle.
#' @return A ggplot object.
create_prediction_plot <- function(predicted, actual, model_name, metrics) {
  plot_data <- data.frame(
    Predicted = predicted,
    Actual = actual
  )
  # Keep only pairs where both values are present
  plot_data <- plot_data[stats::complete.cases(plot_data), , drop = FALSE]

  # Shared limits keep the axes equal. With no complete pairs, leave the limits
  # to ggplot instead of letting min()/max() return Inf/-Inf with warnings.
  axis_limits <- if (nrow(plot_data) > 0) {
    range(c(plot_data$Predicted, plot_data$Actual))
  } else {
    NULL
  }

  ggplot(plot_data, aes(x = Actual, y = Predicted)) +
    geom_point(alpha = 0.6, size = 3, color = "#2E86AB") +
    geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "red", linewidth = 1) +
    # An explicit formula avoids geom_smooth()'s "using formula" message
    geom_smooth(
      method = "lm", formula = y ~ x, se = TRUE,
      color = "#A23B72", fill = "#A23B72", alpha = 0.2
    ) +
    coord_fixed(xlim = axis_limits, ylim = axis_limits) +
    labs(
      title = paste("Yield Prediction:", model_name),
      subtitle = sprintf("RMSE: %.2f t/ha | MAE: %.2f t/ha | R²: %.3f | n: %d",
                         metrics$RMSE, metrics$MAE, metrics$R2, metrics$n),
      x = "Actual TCH (t/ha)",
      y = "Predicted TCH (t/ha)"
    ) +
    theme_minimal() +
    theme(
      plot.title = element_text(face = "bold", size = 10),
      plot.subtitle = element_text(size = 9, color = "gray40"),
      axis.title = element_text(size = 10),
      axis.text = element_text(size = 9),
      panel.grid.minor = element_blank(),
      panel.border = element_rect(color = "gray80", fill = NA, linewidth = 1)
    )
}
#' Load and prepare yield data from a single-sheet Excel file
#'
#' Reads the first sheet of `excel_path` and derives the harvest year and the
#' ratoon/irrigation categories. It keeps the columns used downstream and
#' renames `Field` to `sub_field` to match the CI data.
#'
#' Note: `main()` reads the multi-sheet workbook itself and does not call this
#' function.
#'
#' @param excel_path Path to the yield workbook.
#' @return A tibble with one row per yield record. `Ratoon_Category` is `NA`
#'   when `Ratoon` is missing.
#' Stops with an error if `excel_path` does not exist.
load_yield_data <- function(excel_path) {
  safe_log(paste("Loading yield data from:", excel_path))

  if (!file.exists(excel_path)) {
    stop(paste("Yield data file not found:", excel_path))
  }

  yield_data <- readxl::read_excel(excel_path) %>%
    dplyr::mutate(
      # Extract year from Harvest_Date
      year = lubridate::year(Harvest_Date),
      # Rename columns for consistency
      tonnage_ha = TCH,
      # Ensure Ratoon is numeric
      Ratoon = as.numeric(Ratoon),
      # 0 = plant cane, 1-2 = young ratoon, 3-5 = mid ratoon, 6+ = old ratoon.
      # A missing Ratoon stays NA instead of falling through to the catch-all
      # "Old Ratoon (6+)" branch.
      Ratoon_Category = dplyr::case_when(
        is.na(Ratoon) ~ NA_character_,
        Ratoon == 0 ~ "Plant Cane",
        Ratoon <= 2 ~ "Young Ratoon (1-2)",
        Ratoon <= 5 ~ "Mid Ratoon (3-5)",
        TRUE ~ "Old Ratoon (6+)"
      ),
      # Irrigation category from free-text type; unmatched or missing -> "Other"
      Irrigation_Category = dplyr::case_when(
        grepl("Pivot|pivot", Irrig_Type) ~ "Center Pivot",
        grepl("Drip|drip", Irrig_Type) ~ "Drip",
        grepl("Sprinkler|Spl", Irrig_Type) ~ "Sprinkler",
        TRUE ~ "Other"
      )
    ) %>%
    dplyr::select(
      GROWER, Field, `Area (Ha)`, Cane_Variety, Ratoon, Ratoon_Category,
      Irrig_Type, Irrigation_Category, Cut_age, Harvest_Date, year,
      tonnage_ha, TSH
    ) %>%
    # Rename Field to sub_field for consistency with CI data
    dplyr::rename(sub_field = Field)

  safe_log(paste("Loaded", nrow(yield_data), "yield records"))
  safe_log(paste("Years covered:", paste(sort(unique(yield_data$year)), collapse = ", ")))
  safe_log(paste("Ratoon range:", min(yield_data$Ratoon, na.rm = TRUE), "to",
                 max(yield_data$Ratoon, na.rm = TRUE)))

  yield_data
}
# 3. Main Function
# --------------
#' Fit a random forest with CAST forward feature selection (ffs)
#'
#' Every model in `main()` is fitted through this wrapper so that all models
#' share the same tuning and resampling settings.
#'
#' @param x Data frame of candidate predictors.
#' @param y Numeric response vector, one value per row of `x`.
#' @param ctrl A `caret::trainControl()` object.
#' @param seed Seed passed to `set.seed()` immediately before fitting.
#' @return The model object returned by `CAST::ffs()`.
fit_ffs_rf <- function(x, y, ctrl, seed) {
  set.seed(seed)
  CAST::ffs(
    x,
    y,
    method = "rf",
    trControl = ctrl,
    importance = TRUE,
    withinSE = TRUE,
    tuneLength = 3,
    na.action = na.omit
  )
}

#' Log mean, SD and range of the per-fold CV RMSE and R² of a fitted model
#'
#' @param model Fitted model exposing a `resample` data frame with `RMSE` and
#'   `Rsquared` columns.
#' @return The `resample` data frame, invisibly.
log_cv_summary <- function(model) {
  cv <- model$resample
  safe_log(sprintf(" CV Folds - RMSE: %.2f ± %.2f (range: %.2f-%.2f)",
                   mean(cv$RMSE), sd(cv$RMSE),
                   min(cv$RMSE), max(cv$RMSE)))
  safe_log(sprintf(" CV Folds - R²: %.3f ± %.3f (range: %.3f-%.3f)",
                   mean(cv$Rsquared), sd(cv$Rsquared),
                   min(cv$Rsquared), max(cv$Rsquared)))
  invisible(cv)
}

#' Predict on the training rows and log fit metrics plus CV fold statistics
#'
#' Metrics computed on the training rows are in-sample, so they are typically
#' optimistic for a random forest. The CV fold statistics logged afterwards are
#' the better guide to out-of-sample error.
#'
#' @param model Fitted model.
#' @param data Training data containing the predictors and `tonnage_ha`.
#' @param label Prefix for log lines, e.g. "Model 1".
#' @return A list with `predictions` (`data` plus `predicted_TCH`) and
#'   `metrics` (output of `calculate_metrics()`).
evaluate_training_fit <- function(model, data, label) {
  predictions <- prepare_predictions(
    stats::predict(model, newdata = data),
    data
  )
  metrics <- calculate_metrics(predictions$predicted_TCH, data$tonnage_ha)
  safe_log(sprintf("%s - RMSE: %.2f | MAE: %.2f | R²: %.3f",
                   label, metrics$RMSE, metrics$MAE, metrics$R2))
  log_cv_summary(model)
  list(predictions = predictions, metrics = metrics)
}

#' Predict yields for current/future seasons
#'
#' @param model Fitted model.
#' @param newdata Rows to predict on.
#' @param label Prefix for log lines, e.g. "Model 1".
#' @return `newdata` with `predicted_TCH`, or `NULL` when `newdata` has no rows.
predict_future <- function(model, newdata, label) {
  if (nrow(newdata) == 0) {
    safe_log(sprintf("%s - No future data to predict on", label), "WARNING")
    return(NULL)
  }
  predictions <- prepare_predictions(
    stats::predict(model, newdata = newdata),
    newdata
  )
  safe_log(sprintf("%s - Future predictions: %d fields", label, nrow(predictions)))
  predictions
}

#' Derive the harvest season (4-digit year) from a yield sheet name
#'
#' Examples: "2023-24" -> 2024, "2023-2024" -> 2024, "2024" -> 2024 and
#' "FY24" -> 2024. When two or more numbers are present, the second one (the
#' harvest year) is used. Returns `NA_real_` when the name contains no 2- or
#' 4-digit number.
#'
#' @param sheet_name Excel sheet name.
#' @return A single numeric year or `NA_real_`.
parse_sheet_season <- function(sheet_name) {
  tokens <- regmatches(sheet_name, gregexpr("[0-9]{4}|[0-9]{2}", sheet_name))[[1]]
  if (length(tokens) == 0) {
    return(NA_real_)
  }
  token <- if (length(tokens) >= 2) tokens[2] else tokens[1]
  # Expand 2-digit years (in either position) to 20xx
  if (nchar(token) == 2) {
    token <- paste0("20", token)
  }
  as.numeric(token)
}

#' Run the yield prediction model comparison end to end
#'
#' Reads the project name from the first command-line argument (default
#' "esa"). It loads the multi-sheet yield workbook and the cumulative CI data,
#' then fits up to four random-forest models with forward feature selection.
#' Plots, CSV summaries and the fitted models are written to
#' `file.path(reports_dir, "yield_prediction")`.
#'
#' Side effects:
#' - assigns `project_dir` in the global environment;
#' - sources `r_app/parameters_project.R`;
#' - writes output files.
#'
#' @return Called for its side effects.
main <- function() {
  # Single source of truth for every set.seed() call below
  seed <- 206L

  # Process command line arguments
  args <- commandArgs(trailingOnly = TRUE)
  if (length(args) >= 1 && !is.na(args[1])) {
    project_dir <- as.character(args[1])
  } else {
    project_dir <- "esa" # Default project
  }

  # NOTE(review): presumably read by parameters_project.R, which is sourced
  # below — confirm before removing.
  assign("project_dir", project_dir, envir = .GlobalEnv)

  safe_log("=== YIELD PREDICTION MODEL COMPARISON ===")
  safe_log(paste("Project:", project_dir))

  # 4. Load project configuration
  # -----------------------------
  # `cumulative_CI_vals_dir` and `reports_dir` are used below but are not
  # defined in this file; presumably they come from parameters_project.R.
  tryCatch({
    source(here("r_app", "parameters_project.R"))
  }, error = function(e) {
    stop("Error loading parameters_project.R: ", e$message)
  })

  # 5. Load yield data from multi-tab Excel file
  # --------------------------------------------
  yield_excel_path <- file.path(
    "laravel_app", "storage", "app", project_dir, "Data",
    paste0(project_dir, "_yield_data.xlsx")
  )
  safe_log(paste("Loading yield data from:", yield_excel_path))

  sheet_names <- readxl::excel_sheets(here(yield_excel_path))
  safe_log(paste("Found", length(sheet_names), "sheets:", paste(sheet_names, collapse = ", ")))

  # Read every sheet, tagging rows with the sheet name and harvest season
  yield_data_list <- lapply(sheet_names, function(sheet_name) {
    safe_log(paste("Reading sheet:", sheet_name))

    year_value <- parse_sheet_season(sheet_name)
    safe_log(paste(" Sheet:", sheet_name, "-> Year:", year_value))

    df <- readxl::read_excel(here(yield_excel_path), sheet = sheet_name)
    has_year_col <- "year" %in% names(df)

    df <- df %>%
      dplyr::mutate(
        sheet_name = sheet_name,
        # Prefer the season encoded in the sheet name; otherwise fall back to
        # the per-row `year` column. (A scalar ifelse() here would recycle only
        # the first row's year across the whole sheet.)
        season = if (!is.na(year_value)) {
          year_value
        } else if (has_year_col) {
          .data$year
        } else {
          NA_real_
        }
      )

    # Try to standardize column names
    if ("Field" %in% names(df) && !"sub_field" %in% names(df)) {
      df <- df %>% dplyr::rename(sub_field = Field)
    }

    df
  })

  # Combine all sheets; rows without a season cannot be matched to CI data
  yield_data_raw <- dplyr::bind_rows(yield_data_list) %>%
    dplyr::filter(!is.na(season))

  has_irrig_type <- "Irrig_Type" %in% names(yield_data_raw)
  has_tch <- "TCH" %in% names(yield_data_raw)
  has_tonnage_ha <- "tonnage_ha" %in% names(yield_data_raw)

  yield_data_full <- yield_data_raw %>%
    dplyr::mutate(
      # Ensure Ratoon is numeric
      Ratoon = as.numeric(Ratoon),
      # A missing Ratoon stays NA instead of falling through to "Old Ratoon (6+)"
      Ratoon_Category = dplyr::case_when(
        is.na(Ratoon) ~ NA_character_,
        Ratoon == 0 ~ "Plant Cane",
        Ratoon <= 2 ~ "Young Ratoon (1-2)",
        Ratoon <= 5 ~ "Mid Ratoon (3-5)",
        TRUE ~ "Old Ratoon (6+)"
      ),
      # Irrigation category if Irrig_Type exists (missing/unmatched -> "Other")
      Irrigation_Category = if (has_irrig_type) {
        dplyr::case_when(
          grepl("Pivot|pivot", Irrig_Type) ~ "Center Pivot",
          grepl("Drip|drip", Irrig_Type) ~ "Drip",
          grepl("Sprinkler|Spl", Irrig_Type) ~ "Sprinkler",
          TRUE ~ "Other"
        )
      } else {
        NA_character_
      },
      # Standardize the tonnage column name
      tonnage_ha = if (has_tch) TCH else if (has_tonnage_ha) tonnage_ha else NA_real_
    )

  safe_log(paste("Loaded", nrow(yield_data_full), "yield records from all sheets"))
  safe_log(paste("Years covered:", paste(sort(unique(yield_data_full$season)), collapse = ", ")))
  safe_log(paste("Ratoon range:", min(yield_data_full$Ratoon, na.rm = TRUE), "to",
                 max(yield_data_full$Ratoon, na.rm = TRUE)))
  safe_log(paste("Fields with yield data:", length(unique(yield_data_full$sub_field[!is.na(yield_data_full$tonnage_ha)]))))

  # 6. Load CI data
  # ---------------
  safe_log("Loading cumulative CI data")
  CI_quadrant <- readRDS(here(cumulative_CI_vals_dir, "All_pivots_Cumulative_CI_quadrant_year_v2.rds")) %>%
    dplyr::group_by(model) %>%
    tidyr::fill(field, sub_field, .direction = "downup") %>%
    dplyr::ungroup()

  # 7. Summarise CI per field/season
  # --------------------------------
  safe_log("Summarising latest CI observation per field and season")

  # Row with the maximum DOY per field/season (end of season for completed
  # seasons, latest observation for ongoing ones)
  CI_summary <- CI_quadrant %>%
    dplyr::group_by(sub_field, season) %>%
    dplyr::slice(which.max(DOY)) %>%
    dplyr::select(field, sub_field, cumulative_CI, DOY, season) %>%
    dplyr::mutate(CI_per_day = cumulative_CI / DOY) %>%
    dplyr::ungroup()

  # 7a. Calculate time series features from CI data
  # -----------------------------------------------
  safe_log("Calculating time series-derived features")

  CI_features <- CI_quadrant %>%
    dplyr::group_by(sub_field, season) %>%
    dplyr::arrange(DOY, .by_group = TRUE) %>%
    dplyr::mutate(
      # Change in cumulative CI between consecutive observations. lag() uses a
      # default of 0, so the first increment equals the first cumulative value.
      daily_CI_increment = cumulative_CI - dplyr::lag(cumulative_CI, default = 0)
    ) %>%
    dplyr::summarise(
      # 1. Growth rate (linear slope of cumulative CI over DOY)
      CI_growth_rate = if (dplyr::n() > 2) {
        unname(coef(lm(cumulative_CI ~ DOY))[2])
      } else {
        NA_real_
      },
      # 2. Early season: SUM of cumulative CI values at DOY <= 150.
      #    NOTE(review): this depends on observation frequency; it is not the
      #    CI accumulated by DOY 150.
      early_season_CI = sum(cumulative_CI[DOY <= 150], na.rm = TRUE),
      # 3. Growth consistency: SD of all increments / mean of positive increments
      growth_consistency_cv = sd(daily_CI_increment, na.rm = TRUE) /
        mean(daily_CI_increment[daily_CI_increment > 0], na.rm = TRUE),
      # 4. Largest increment between consecutive observations
      #    (NA when no increment is available, instead of -Inf from max())
      peak_CI_per_day = if (all(is.na(daily_CI_increment))) {
        NA_real_
      } else {
        max(daily_CI_increment, na.rm = TRUE)
      },
      # 5. Number of stress events (cumulative CI drops)
      stress_events = sum(daily_CI_increment < 0, na.rm = TRUE),
      # 6. Late season: SUM of cumulative CI values within 60 days of max DOY
      late_season_CI = sum(cumulative_CI[DOY >= max(DOY) - 60], na.rm = TRUE),
      .groups = "drop"
    ) %>%
    # Ratio above is Inf/NaN when there are no positive increments
    dplyr::mutate(
      growth_consistency_cv = ifelse(is.finite(growth_consistency_cv),
                                     growth_consistency_cv,
                                     NA_real_)
    )

  CI_summary <- CI_summary %>%
    dplyr::left_join(CI_features, by = c("sub_field", "season"))

  safe_log(sprintf("Added %d time series features", ncol(CI_features) - 2))

  # 7b. Merge CI and yield data
  # ---------------------------
  safe_log("Merging CI and yield data")

  # NOTE(review): duplicate sub_field/season rows in the yield data would
  # fan out this join.
  combined_data_all <- CI_summary %>%
    dplyr::left_join(
      yield_data_full,
      by = c("sub_field", "season")
    )

  # Training data: completed seasons with yield data (mature fields only)
  training_data <- combined_data_all %>%
    dplyr::filter(
      !is.na(tonnage_ha),
      !is.na(cumulative_CI),
      DOY >= 240 # Only mature fields (>= 8 months)
    )

  # Prediction data: current/future seasons without yield (mature fields only)
  current_year <- lubridate::year(Sys.Date())
  prediction_data <- combined_data_all %>%
    dplyr::filter(
      is.na(tonnage_ha),
      !is.na(cumulative_CI),
      !is.na(DOY),
      !is.na(CI_per_day),
      !is.na(Ratoon), # Ensure Ratoon is not NA for Model 2
      DOY >= 240, # Only mature fields
      season >= current_year # Current and future seasons
    )

  safe_log(paste("Training dataset:", nrow(training_data), "records"))
  safe_log(paste("Training fields:", length(unique(training_data$sub_field))))
  safe_log(paste("Training seasons:", paste(sort(unique(training_data$season)), collapse = ", ")))
  safe_log(paste("\nPrediction dataset:", nrow(prediction_data), "records"))
  safe_log(paste("Prediction fields:", length(unique(prediction_data$sub_field))))
  safe_log(paste("Prediction seasons:", paste(sort(unique(prediction_data$season)), collapse = ", ")))

  if (nrow(training_data) < 10) {
    stop("Insufficient training data (need at least 10 records)")
  }

  # 8. Prepare datasets for modeling
  # --------------------------------
  safe_log("Preparing datasets for modeling")

  # Candidate predictors per model (order matters for ffs' starting pairs)
  ci_predictors <- c("cumulative_CI", "DOY", "CI_per_day")
  ci_ratoon_predictors <- c(ci_predictors, "Ratoon")
  ci_ratoon_full_predictors <- c(ci_ratoon_predictors, "Irrigation_Category", "Cane_Variety")
  ts_feature_cols <- c("CI_growth_rate", "early_season_CI", "growth_consistency_cv",
                       "peak_CI_per_day", "stress_events", "late_season_CI")
  ci_timeseries_predictors <- c(ci_ratoon_predictors, ts_feature_cols)

  response <- "tonnage_ha"

  # 5-fold CV; each model also re-seeds before fitting (see fit_ffs_rf)
  set.seed(seed)
  ctrl <- caret::trainControl(
    method = "cv",
    number = 5,
    savePredictions = TRUE,
    verboseIter = TRUE
  )

  # 9. Model 1: CI-only
  # -------------------
  safe_log("\n=== MODEL 1: CI PREDICTORS ONLY ===")
  model_ci <- fit_ffs_rf(training_data[, ci_predictors], training_data[[response]], ctrl, seed)
  fit_ci <- evaluate_training_fit(model_ci, training_data, "Model 1")
  pred_ci_train <- fit_ci$predictions
  metrics_ci <- fit_ci$metrics
  pred_ci_future <- predict_future(model_ci, prediction_data, "Model 1")

  # 10. Model 2: CI + Ratoon
  # ------------------------
  safe_log("\n=== MODEL 2: CI + RATOON ===")
  model_ci_ratoon <- fit_ffs_rf(training_data[, ci_ratoon_predictors], training_data[[response]], ctrl, seed)
  fit_ci_ratoon <- evaluate_training_fit(model_ci_ratoon, training_data, "Model 2")
  pred_ci_ratoon_train <- fit_ci_ratoon$predictions
  metrics_ci_ratoon <- fit_ci_ratoon$metrics
  pred_ci_ratoon_future <- predict_future(model_ci_ratoon, prediction_data, "Model 2")

  # 11. Model 3: CI + Ratoon + irrigation + variety
  # -----------------------------------------------
  safe_log("\n=== MODEL 3: CI + RATOON + IRRIGATION + VARIETY ===")

  # Drop records with missing categorical variables
  training_data_full <- training_data %>%
    dplyr::filter(!is.na(Irrigation_Category), !is.na(Cane_Variety))
  prediction_data_full <- prediction_data %>%
    dplyr::filter(!is.na(Irrigation_Category), !is.na(Cane_Variety))

  if (nrow(training_data_full) >= 10) {
    model_ci_ratoon_full <- fit_ffs_rf(
      training_data_full[, ci_ratoon_full_predictors],
      training_data_full[[response]],
      ctrl, seed
    )
    fit_full <- evaluate_training_fit(model_ci_ratoon_full, training_data_full, "Model 3")
    pred_ci_ratoon_full_train <- fit_full$predictions
    metrics_ci_ratoon_full <- fit_full$metrics
    pred_ci_ratoon_full_future <- predict_future(model_ci_ratoon_full, prediction_data_full, "Model 3")
  } else {
    safe_log("Insufficient data for full model, skipping", "WARNING")
    model_ci_ratoon_full <- NULL
    metrics_ci_ratoon_full <- NULL
    pred_ci_ratoon_full_train <- NULL
    pred_ci_ratoon_full_future <- NULL
  }

  # 11d. Model 4: CI + Ratoon + time series features
  # ------------------------------------------------
  safe_log("\n=== MODEL 4: CI + RATOON + TIME SERIES FEATURES ===")

  # Only rows with every time series feature present
  training_data_ts <- training_data %>%
    tidyr::drop_na(dplyr::all_of(ts_feature_cols))
  prediction_data_ts <- prediction_data %>%
    tidyr::drop_na(dplyr::all_of(ts_feature_cols))

  safe_log(sprintf("Model 4 training records: %d", nrow(training_data_ts)))

  if (nrow(training_data_ts) >= 10) {
    model_ci_timeseries <- fit_ffs_rf(
      training_data_ts[, ci_timeseries_predictors],
      training_data_ts[[response]],
      ctrl, seed
    )
    fit_ts <- evaluate_training_fit(model_ci_timeseries, training_data_ts, "Model 4")
    pred_ci_timeseries_train <- fit_ts$predictions
    metrics_ci_timeseries <- fit_ts$metrics
    pred_ci_timeseries_future <- predict_future(model_ci_timeseries, prediction_data_ts, "Model 4")
  } else {
    safe_log("Insufficient data for time series model, skipping", "WARNING")
    model_ci_timeseries <- NULL
    metrics_ci_timeseries <- NULL
    pred_ci_timeseries_train <- NULL
    pred_ci_timeseries_future <- NULL
  }

  # 12. Create comparison plots
  # ---------------------------
  safe_log("\n=== CREATING VISUALIZATION ===")

  output_dir <- file.path(reports_dir, "yield_prediction")
  dir.create(output_dir, recursive = TRUE, showWarnings = FALSE)

  plot_ci <- create_prediction_plot(
    pred_ci_train$predicted_TCH,
    training_data$tonnage_ha,
    "CI Only (Training Data)",
    metrics_ci
  )

  plot_ci_ratoon <- create_prediction_plot(
    pred_ci_ratoon_train$predicted_TCH,
    training_data$tonnage_ha,
    "CI + Ratoon (Training Data)",
    metrics_ci_ratoon
  )

  if (!is.null(model_ci_ratoon_full)) {
    model3_vars <- paste(model_ci_ratoon_full$selectedvars, collapse = ", ")
    plot_ci_ratoon_full <- create_prediction_plot(
      pred_ci_ratoon_full_train$predicted_TCH,
      training_data_full$tonnage_ha,
      paste0("Model 3: ", model3_vars),
      metrics_ci_ratoon_full
    )
  } else {
    plot_ci_ratoon_full <- NULL
  }

  if (!is.null(model_ci_timeseries)) {
    model4_vars <- paste(model_ci_timeseries$selectedvars, collapse = ", ")
    plot_ci_timeseries <- create_prediction_plot(
      pred_ci_timeseries_train$predicted_TCH,
      training_data_ts$tonnage_ha,
      paste0("Model 4: ", model4_vars),
      metrics_ci_timeseries
    )
  } else {
    plot_ci_timeseries <- NULL
  }

  # Prediction table source: prefer Model 4, then Model 2, then Model 3
  if (!is.null(pred_ci_timeseries_future) && nrow(pred_ci_timeseries_future) > 0) {
    future_preds_for_table <- pred_ci_timeseries_future
    table_model_name <- "Model 4 (Time Series)"
  } else if (!is.null(pred_ci_ratoon_future) && nrow(pred_ci_ratoon_future) > 0) {
    future_preds_for_table <- pred_ci_ratoon_future
    table_model_name <- "Model 2 (CI + Ratoon)"
  } else if (!is.null(pred_ci_ratoon_full_future) && nrow(pred_ci_ratoon_full_future) > 0) {
    future_preds_for_table <- pred_ci_ratoon_full_future
    table_model_name <- "Model 3 (Full)"
  } else {
    future_preds_for_table <- NULL
    table_model_name <- NULL
  }

  if (!is.null(future_preds_for_table)) {
    pred_table_data <- future_preds_for_table %>%
      dplyr::select(sub_field, season, predicted_TCH, Ratoon, DOY) %>%
      dplyr::arrange(desc(predicted_TCH)) %>%
      head(10)

    pred_table <- gridExtra::tableGrob(
      pred_table_data,
      rows = NULL,
      theme = gridExtra::ttheme_default(
        core = list(fg_params = list(cex = 0.6)),
        colhead = list(fg_params = list(cex = 0.7, fontface = "bold"))
      )
    )

    pred_text <- grid::textGrob(
      paste0("Future Yield Predictions - ", table_model_name, "\n",
             "(", nrow(future_preds_for_table), " fields in seasons ",
             paste(unique(future_preds_for_table$season), collapse = ", "), ")\n",
             "Top 10 predicted yields shown"),
      x = 0.5, y = 0.98,
      just = c("center", "top"),
      gp = grid::gpar(fontsize = 8, fontface = "bold")
    )

    pred_panel <- gridExtra::arrangeGrob(pred_text, pred_table, ncol = 1, heights = c(0.25, 0.75))
  } else {
    pred_panel <- grid::textGrob(
      "No future predictions available\n(No mature fields in current/future seasons)",
      gp = grid::gpar(fontsize = 10, col = "gray50")
    )
  }

  # Feature explanation panel (used in the full 3x2 layout)
  feature_text <- paste0(
    "SELECTED FEATURES BY MODEL (via Forward Feature Selection)\n\n",
    "Model 1: ", paste(model_ci$selectedvars, collapse = ", "), "\n",
    "Model 2: ", paste(model_ci_ratoon$selectedvars, collapse = ", "), "\n",
    if (!is.null(model_ci_ratoon_full)) paste0("Model 3: ", paste(model_ci_ratoon_full$selectedvars, collapse = ", "), "\n") else "",
    if (!is.null(model_ci_timeseries)) paste0("Model 4: ", paste(model_ci_timeseries$selectedvars, collapse = ", "), "\n\n") else "\n",
    "FEATURE DEFINITIONS:\n",
    "• cumulative_CI: Total CI accumulated from planting to harvest\n",
    "• DOY: Day of year at harvest (crop age proxy)\n",
    "• CI_per_day: Daily average CI (cumulative_CI / DOY)\n",
    "• Ratoon: Crop cycle number (0=plant cane, 1+=ratoon)\n",
    if (!is.null(model_ci_timeseries)) paste0(
      "• CI_growth_rate: Linear slope of CI over time (vigor)\n",
      "• growth_consistency_cv: CV of daily CI increments (stability)\n",
      "• early_season_CI: Sum of cumulative CI over DOY <= 150\n",
      "• peak_CI_per_day: Maximum daily CI increment observed\n",
      "• stress_events: Count of negative CI changes\n",
      "• late_season_CI: Sum of cumulative CI over final 60 days\n"
    ) else "",
    if (!is.null(model_ci_ratoon_full)) paste0(
      "• Irrigation_Category: Irrigation system type\n",
      "• Cane_Variety: Sugarcane variety planted"
    ) else ""
  )

  feature_panel <- grid::textGrob(
    feature_text,
    x = 0.05, y = 0.95,
    just = c("left", "top"),
    gp = grid::gpar(fontsize = 7, fontfamily = "mono")
  )

  plot_title <- grid::textGrob(
    paste("Yield Prediction Model Comparison -", project_dir,
          "\nTraining on", paste(sort(unique(training_data$season)), collapse = ", ")),
    gp = grid::gpar(fontsize = 16, fontface = "bold")
  )

  has_model3_plot <- !is.null(plot_ci_ratoon_full)
  has_model4_plot <- !is.null(plot_ci_timeseries)

  # arrangeGrob() builds the layout without drawing it to the current device;
  # ggsave() renders it below.
  if (has_model3_plot && has_model4_plot) {
    # Rows: Models 1-2, Models 3-4, feature panel + prediction table
    combined_plot <- gridExtra::arrangeGrob(
      plot_ci, plot_ci_ratoon,
      plot_ci_ratoon_full, plot_ci_timeseries,
      feature_panel, pred_panel,
      ncol = 2,
      nrow = 3,
      heights = c(1.2, 1.2, 0.9), # Make plots bigger, bottom row smaller
      layout_matrix = rbind(c(1, 2), c(3, 4), c(5, 6)),
      top = plot_title
    )
  } else if (has_model3_plot || has_model4_plot) {
    # Exactly one of Models 3/4 is available: show it beside the table
    third_plot <- if (has_model3_plot) plot_ci_ratoon_full else plot_ci_timeseries
    combined_plot <- gridExtra::arrangeGrob(
      plot_ci, plot_ci_ratoon, third_plot, pred_panel,
      ncol = 2,
      nrow = 2,
      top = plot_title
    )
  } else if (!is.null(pred_ci_ratoon_future) && nrow(pred_ci_ratoon_future) > 0) {
    # Two models only: full Model 2 prediction table across the bottom
    pred_table_data <- pred_ci_ratoon_future %>%
      dplyr::select(sub_field, season, predicted_TCH, Ratoon, DOY) %>%
      dplyr::arrange(desc(predicted_TCH))

    pred_table <- gridExtra::tableGrob(
      pred_table_data,
      rows = NULL,
      theme = gridExtra::ttheme_default(
        core = list(fg_params = list(cex = 0.7)),
        colhead = list(fg_params = list(cex = 0.8, fontface = "bold"))
      )
    )

    pred_text <- grid::textGrob(
      paste0("Future Yield Predictions (", nrow(pred_ci_ratoon_future),
             " fields in seasons ", paste(unique(pred_ci_ratoon_future$season), collapse = ", "), ")"),
      x = 0.5, y = 0.95,
      gp = grid::gpar(fontsize = 10, fontface = "bold")
    )

    pred_panel <- gridExtra::arrangeGrob(pred_text, pred_table, ncol = 1, heights = c(0.1, 0.9))

    combined_plot <- gridExtra::arrangeGrob(
      plot_ci, plot_ci_ratoon, pred_panel,
      layout_matrix = rbind(c(1, 2), c(3, 3)),
      top = plot_title
    )
  } else {
    combined_plot <- gridExtra::arrangeGrob(
      plot_ci, plot_ci_ratoon,
      ncol = 2,
      top = plot_title
    )
  }

  plot_file <- file.path(output_dir, paste0(project_dir, "_yield_prediction_comparison.png"))
  ggsave(plot_file, combined_plot, width = 16, height = 12, dpi = 300)
  safe_log(paste("Comparison plot saved to:", plot_file))

  # 12b. Save future predictions to CSV (Model 2 predictions)
  # ---------------------------------------------------------
  if (!is.null(pred_ci_ratoon_future) && nrow(pred_ci_ratoon_future) > 0) {
    future_pred_file <- file.path(output_dir, paste0(project_dir, "_future_predictions.csv"))
    future_pred_export <- pred_ci_ratoon_future %>%
      dplyr::select(field, sub_field, season, DOY, predicted_TCH,
                    cumulative_CI, CI_per_day, Ratoon) %>%
      dplyr::arrange(desc(predicted_TCH))
    readr::write_csv(future_pred_export, future_pred_file)
    safe_log(paste("Future predictions saved to:", future_pred_file))
  }

  # 13. Feature importance plot (CI + Ratoon model)
  # -----------------------------------------------
  safe_log("Creating feature importance plot")

  var_imp <- caret::varImp(model_ci_ratoon)$importance %>%
    tibble::rownames_to_column("Variable") %>%
    dplyr::arrange(desc(Overall)) %>%
    dplyr::mutate(Variable = factor(Variable, levels = Variable))

  imp_plot <- ggplot(var_imp, aes(x = Overall, y = Variable)) +
    geom_col(fill = "#2E86AB") +
    labs(
      title = "Feature Importance: CI + Ratoon Model",
      x = "Importance",
      y = "Variable"
    ) +
    theme_minimal() +
    theme(
      plot.title = element_text(face = "bold", size = 14),
      axis.title = element_text(size = 12)
    )

  imp_file <- file.path(output_dir, paste0(project_dir, "_feature_importance.png"))
  ggsave(imp_file, imp_plot, width = 8, height = 6, dpi = 300)
  safe_log(paste("Feature importance plot saved to:", imp_file))

  # 14. Comparison table (in-sample metrics; NA for skipped models)
  # ----------------------------------------------------------------
  metric_or_na <- function(metrics, name) {
    if (is.null(metrics)) NA else metrics[[name]]
  }

  comparison_table <- data.frame(
    Model = c("CI Only", "CI + Ratoon", "CI + Ratoon + Full", "CI + Ratoon + Time Series"),
    Predictors = c(
      paste(ci_predictors, collapse = ", "),
      paste(ci_ratoon_predictors, collapse = ", "),
      paste(ci_ratoon_full_predictors, collapse = ", "),
      paste(ci_timeseries_predictors, collapse = ", ")
    ),
    RMSE = c(
      metrics_ci$RMSE,
      metrics_ci_ratoon$RMSE,
      metric_or_na(metrics_ci_ratoon_full, "RMSE"),
      metric_or_na(metrics_ci_timeseries, "RMSE")
    ),
    MAE = c(
      metrics_ci$MAE,
      metrics_ci_ratoon$MAE,
      metric_or_na(metrics_ci_ratoon_full, "MAE"),
      metric_or_na(metrics_ci_timeseries, "MAE")
    ),
    R2 = c(
      metrics_ci$R2,
      metrics_ci_ratoon$R2,
      metric_or_na(metrics_ci_ratoon_full, "R2"),
      metric_or_na(metrics_ci_timeseries, "R2")
    ),
    N = c(
      metrics_ci$n,
      metrics_ci_ratoon$n,
      metric_or_na(metrics_ci_ratoon_full, "n"),
      metric_or_na(metrics_ci_timeseries, "n")
    )
  )

  comparison_file <- file.path(output_dir, paste0(project_dir, "_model_comparison.csv"))
  readr::write_csv(comparison_table, comparison_file)
  safe_log(paste("Model comparison table saved to:", comparison_file))

  # 15. Save model objects
  # ----------------------
  models_file <- file.path(output_dir, paste0(project_dir, "_yield_models.rds"))
  saveRDS(list(
    model1 = model_ci,
    model2 = model_ci_ratoon,
    model3 = model_ci_ratoon_full,
    model4 = model_ci_timeseries,
    metrics_ci = metrics_ci,
    metrics_ci_ratoon = metrics_ci_ratoon,
    metrics_ci_ratoon_full = metrics_ci_ratoon_full,
    metrics_ci_timeseries = metrics_ci_timeseries,
    training_predictions_ci = pred_ci_train,
    training_predictions_ci_ratoon = pred_ci_ratoon_train,
    training_predictions_ci_ratoon_full = pred_ci_ratoon_full_train,
    training_predictions_ci_timeseries = pred_ci_timeseries_train,
    future_predictions_ci = pred_ci_future,
    future_predictions_ci_ratoon = pred_ci_ratoon_future,
    future_predictions_ci_ratoon_full = pred_ci_ratoon_full_future,
    future_predictions_ci_timeseries = pred_ci_timeseries_future,
    training_data = training_data,
    prediction_data = prediction_data
  ), models_file)
  safe_log(paste("Model objects saved to:", models_file))

  # 16. Print summary
  # -----------------
  cat("\n=== YIELD PREDICTION MODEL COMPARISON SUMMARY ===\n")
  print(comparison_table)

  cat("\n=== IMPROVEMENT ANALYSIS ===\n")

  rmse_improvement <- ((metrics_ci$RMSE - metrics_ci_ratoon$RMSE) / metrics_ci$RMSE) * 100
  r2_improvement <- ((metrics_ci_ratoon$R2 - metrics_ci$R2) / metrics_ci$R2) * 100

  cat(sprintf("Adding Ratoon to CI model:\n"))
  cat(sprintf(" - RMSE improvement: %.1f%% (%.2f → %.2f t/ha)\n",
              rmse_improvement, metrics_ci$RMSE, metrics_ci_ratoon$RMSE))
  cat(sprintf(" - R² improvement: %.1f%% (%.3f → %.3f)\n",
              r2_improvement, metrics_ci$R2, metrics_ci_ratoon$R2))

  if (!is.null(metrics_ci_ratoon_full)) {
    rmse_improvement_full <- ((metrics_ci_ratoon$RMSE - metrics_ci_ratoon_full$RMSE) /
                                metrics_ci_ratoon$RMSE) * 100
    r2_improvement_full <- ((metrics_ci_ratoon_full$R2 - metrics_ci_ratoon$R2) /
                              metrics_ci_ratoon$R2) * 100

    cat(sprintf("\nAdding Irrigation + Variety to CI + Ratoon model:\n"))
    cat(sprintf(" - RMSE improvement: %.1f%% (%.2f → %.2f t/ha)\n",
                rmse_improvement_full, metrics_ci_ratoon$RMSE, metrics_ci_ratoon_full$RMSE))
    cat(sprintf(" - R² improvement: %.1f%% (%.3f → %.3f)\n",
                r2_improvement_full, metrics_ci_ratoon$R2, metrics_ci_ratoon_full$R2))
  }

  cat("\n=== TRAINING/PREDICTION SUMMARY ===\n")
  cat(sprintf("Training seasons: %s\n", paste(sort(unique(training_data$season)), collapse = ", ")))
  cat(sprintf("Training records: %d fields\n", nrow(training_data)))

  if (!is.null(pred_ci_ratoon_future)) {
    cat(sprintf("\nPrediction seasons: %s\n", paste(sort(unique(prediction_data$season)), collapse = ", ")))
    cat(sprintf("Prediction records: %d fields\n", nrow(pred_ci_ratoon_future)))
    cat("\nTop 5 predicted yields:\n")
    print(pred_ci_ratoon_future %>%
            dplyr::select(sub_field, season, predicted_TCH, Ratoon) %>%
            dplyr::arrange(desc(predicted_TCH)) %>%
            head(5))
  }

  cat("\n=== OUTPUT FILES ===\n")
  cat(paste("Comparison plot:", plot_file, "\n"))
  cat(paste("Feature importance:", imp_file, "\n"))
  cat(paste("Comparison table:", comparison_file, "\n"))
  cat(paste("Model objects:", models_file, "\n"))
  if (!is.null(pred_ci_ratoon_future) && nrow(pred_ci_ratoon_future) > 0) {
    future_pred_file <- file.path(output_dir, paste0(project_dir, "_future_predictions.csv"))
    cat(paste("Future predictions:", future_pred_file, "\n"))
  }

  cat("\n=== SEED SENSITIVITY ANALYSIS ===\n")
  cat(sprintf("Current seed: %d\n", seed))
  cat("Using same seed ensures:\n")
  cat(" - Identical fold assignments across runs\n")
  cat(" - Identical bootstrap samples in random forest\n")
  cat(" - Reproducible results\n\n")
  cat("Expected variation with different seeds:\n")
  cat(" - RMSE: ±1-3 t/ha (typical range)\n")
  cat(" - R²: ±0.02-0.05 (typical range)\n")
  cat(" - Feature selection may change slightly\n")
  cat(" - Predictions will vary but trends remain consistent\n\n")
  cat("To test seed sensitivity, change `seed` at the top of main() to different values\n")
  cat("and re-run the script to compare results.\n")

  cat("\n=== YIELD PREDICTION COMPARISON COMPLETED ===\n")
}
# Script execution
# ----------------
# Run main() only at top level (e.g. via Rscript). Inside source(),
# sys.nframe() is non-zero, so sourcing this file just defines the functions.
if (identical(sys.nframe(), 0L)) {
  main()
}