# SmartCane/python_app/22_harvest_baseline_prediction.py
#
# 137 lines
# 6 KiB
# Python

"""
Script: 22_harvest_baseline_prediction.py
Purpose: BASELINE PREDICTION - Run ONCE to establish harvest date baseline for all fields and seasons
This script processes COMPLETE historical CI data (all available dates) and uses Model 307
to predict ALL harvest dates across the entire dataset. This becomes your reference baseline
for monitoring and comparisons going forward.
RUN FREQUENCY: Once during initial setup
INPUT: ci_data_for_python.csv (complete historical CI data from 02b_convert_rds_to_csv.R)
Location: laravel_app/storage/app/{project}/Data/extracted_ci/ci_data_for_python/ci_data_for_python.csv
OUTPUT: harvest_production_export.xlsx (baseline harvest predictions for all fields/seasons)
Workflow:
1. Load ci_data_for_python.csv (daily interpolated, all historical dates)
2. Group data by field and season (Model 307 detects season boundaries internally)
3. Run two-step harvest detection (Phase 1: fast detection, Phase 2: ±40 day refinement)
4. Export harvest_production_export.xlsx with columns:
- field, sub_field, season, year, season_start_date, season_end_date, phase1_harvest_date
Two-Step Detection Algorithm:
Phase 1 (Growing Window): Expands daily, checks when detected_prob > 0.3 for 2 consecutive days (the values main() passes as phase1_threshold / phase1_consecutive)
Phase 2 (Refinement): Extracts ±40 day window, finds peak harvest signal with argmax
This is your GROUND TRUTH - compare all future predictions against this baseline.
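Usage:
python 22_harvest_baseline_prediction.py [project_name] [field]
Run from the python_app directory (input, output and model paths are relative to it):
conda activate pytorch_gpu
cd python_app
python 22_harvest_baseline_prediction.py angata
Examples:
python 22_harvest_baseline_prediction.py angata
python 22_harvest_baseline_prediction.py esa
python 22_harvest_baseline_prediction.py chemba
python 22_harvest_baseline_prediction.py angata <field_id> (test mode: process a single field only)
If no project is specified, defaults to 'angata'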
"""
import pandas as pd
import numpy as np
import torch
import sys
from pathlib import Path
from harvest_date_pred_utils import (
load_model_and_config,
extract_features,
run_phase1_growing_window,
run_two_step_refinement,
build_production_harvest_table
)
# Phase 1 detection settings for the baseline run. Kept in one place so the
# log message and the actual call cannot drift apart.
PHASE1_THRESHOLD = 0.3
PHASE1_CONSECUTIVE_DAYS = 2


def _load_ci_data(ci_data_file):
    """Load the CI CSV and parse its ``Date`` column.

    Args:
        ci_data_file: Path to ``ci_data_for_python.csv``.

    Returns:
        The loaded DataFrame, or ``None`` (after printing an error) when the
        file is empty or lacks the ``field`` / ``Date`` columns used here.
    """
    try:
        # Force field as string so ids like "007" keep their leading zeros.
        ci_data = pd.read_csv(ci_data_file, dtype={'field': str})
    except pd.errors.EmptyDataError:
        print(f"ERROR: {ci_data_file} is empty")
        return None

    missing = [col for col in ('field', 'Date') if col not in ci_data.columns]
    if missing:
        print(f"ERROR: {ci_data_file} is missing required column(s): {', '.join(missing)}")
        return None
    if ci_data.empty:
        # Without rows, Date.min() is NaT and the date-range summary would crash.
        print(f"ERROR: {ci_data_file} contains no data rows")
        return None

    ci_data['Date'] = pd.to_datetime(ci_data['Date'])
    return ci_data


def _select_field(ci_data, field_filter):
    """Restrict ``ci_data`` to a single field (test mode).

    Args:
        ci_data: DataFrame with a string ``field`` column.
        field_filter: Field id to keep.

    Returns:
        The filtered DataFrame, or ``None`` (after printing the available
        field ids) when no row matches ``field_filter``.
    """
    subset = ci_data[ci_data['field'] == field_filter]
    if subset.empty:
        print(f"\n✗ ERROR: No data found for field '{field_filter}'")
        # dropna(): rows without a field id are NaN floats, and sorting a mix
        # of str and float raises TypeError.
        available_fields = sorted(ci_data['field'].dropna().unique())
        print(f" Available fields ({len(available_fields)}): {', '.join(available_fields[:10])}")
        if len(available_fields) > 10:
            print(f" ... and {len(available_fields) - 10} more")
        return None

    print(f" ✓ Filtered to single field: {field_filter}")
    print(f" Data points: {len(subset)} days")
    return subset


def main() -> int:
    """Run the baseline harvest prediction for one project and export it.

    Command-line arguments (read from ``sys.argv``):
        1. project name (optional, defaults to ``"angata"``)
        2. field id (optional; restricts the run to a single field for testing)

    All paths are relative to the current working directory, so the script
    is expected to be started from ``python_app/``.

    Returns:
        ``0`` on success, ``1`` when the input file is missing or invalid or
        the requested field has no data. The value is used as the process
        exit status when the file is run as a script.
    """
    # Get project name from command line or use default
    project_name = sys.argv[1] if len(sys.argv) > 1 else "angata"
    field_filter = sys.argv[2] if len(sys.argv) > 2 else None  # Optional: test single field

    # Construct paths
    base_storage = Path("../laravel_app/storage/app") / project_name / "Data"
    ci_data_file = base_storage / "extracted_ci" / "ci_data_for_python" / "ci_data_for_python.csv"
    harvest_data_dir = base_storage / "HarvestData"
    output_xlsx = harvest_data_dir / "harvest_production_export.xlsx"
    model_dir = Path(".")  # Model files in python_app/

    # Check the input before touching the filesystem, so a failed run leaves
    # no empty output directories behind.
    if not ci_data_file.exists():
        print(f"ERROR: {ci_data_file} not found")
        print(f" Expected at: {ci_data_file.resolve()}")
        print("\n Run 02b_convert_rds_to_csv.R first to generate this file:")
        print(f" Rscript r_app/02b_convert_ci_rds_to_csv.R {project_name}")
        return 1

    print("=" * 80)
    print(f"HARVEST DATE PREDICTION - LSTM MODEL 307 ({project_name})")
    if field_filter:
        print(f"TEST MODE: Single field ({field_filter})")
    print("=" * 80)

    # [1/4] Load model
    print("\n[1/4] Loading Model 307...")
    model, config, scalers = load_model_and_config(model_dir)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f" Device: {device}")

    # [2/4] Load and prepare CI data
    print("\n[2/4] Loading CI data...")
    print(f" From: {ci_data_file}")
    ci_data = _load_ci_data(ci_data_file)
    if ci_data is None:
        return 1
    print(f" Loaded {len(ci_data)} daily rows across {ci_data['field'].nunique()} fields")
    print(f" Date range: {ci_data['Date'].min().date()} to {ci_data['Date'].max().date()}")

    # Optional: Filter to single field for testing
    if field_filter:
        ci_data = _select_field(ci_data, field_filter)
        if ci_data is None:
            return 1

    # Create the output directory once the inputs are known to be valid but
    # before the expensive detection step, so an unwritable location fails early.
    harvest_data_dir.mkdir(parents=True, exist_ok=True)

    # [3/4] Run model predictions with two-step detection
    print("\n[3/4] Running two-step harvest detection...")
    print(f" (Using threshold={PHASE1_THRESHOLD}, consecutive_days={PHASE1_CONSECUTIVE_DAYS}"
          " - tuned baseline with DAH reset)")
    refined_results = run_two_step_refinement(
        ci_data, model, config, scalers, device=device,
        phase1_threshold=PHASE1_THRESHOLD,
        phase1_consecutive=PHASE1_CONSECUTIVE_DAYS,
    )

    # Build and export
    print("\nBuilding production harvest table...")
    prod_table = build_production_harvest_table(refined_results)
    prod_table.to_excel(output_xlsx, index=False)

    print(f"\n✓ Exported {len(prod_table)} predictions to {output_xlsx}")
    print(f"\nOutput location: {output_xlsx.resolve()}")
    print("\nStorage structure:")
    print(f" Input: laravel_app/storage/app/{project_name}/Data/extracted_ci/ci_data_for_python/")
    print(f" Output: laravel_app/storage/app/{project_name}/Data/HarvestData/")
    print("\nColumn structure:")
    print(" field, sub_field, season, season_start_date, season_end_date, phase2_harvest_date")
    print("\nNext steps:")
    print(" 1. Review predictions in harvest_production_export.xlsx")
    print(f" 2. Run weekly monitoring: python 31_harvest_imminent_weekly.py {project_name}")
    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shell pipelines can detect a failed run.
    sys.exit(main())