SmartCane/python_app/harvest_detection_experiments/tests/test_model_inference.py
Timon fabbf3214d Enhance harvest detection logic and testing framework
- Updated `detect_mosaic_mode` function to check for grid-size subdirectories in addition to tile-named files.
- Added comprehensive tests for DOY reset logic in `test_doy_logic.py`.
- Implemented feature extraction tests in `test_feature_extraction.py`.
- Created tests for growing window method in `test_growing_window_only.py`.
- Developed a complete model inference test in `test_model_inference.py`.
- Added a debug script for testing two-step refinement logic in `test_script22_debug.py`.
2026-01-15 14:30:54 +01:00

124 lines
4.2 KiB
Python

#!/usr/bin/env python3
"""
Complete test: Feature extraction + Model inference + Phase 1 detection
"""
import sys
import pandas as pd
import numpy as np
import torch
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from harvest_date_pred_utils import (
load_model_and_config,
extract_features,
run_phase1_growing_window
)
# --- Configuration -----------------------------------------------------------
# Project whose CI data is tested; files live under the Laravel storage tree.
project_name = "angata"
base_storage = Path("../laravel_app/storage/app") / project_name / "Data"
# CSV exported for the Python side (schema assumed: at least 'field', 'Date',
# 'value' columns — TODO confirm against the exporter).
CI_DATA_FILE = base_storage / "extracted_ci" / "ci_data_for_python" / "ci_data_for_python.csv"
print("="*80)
print("DEBUG: Model Inference + Phase 1 Detection")
print("="*80)
# Load model
print("\n[1] Loading model...")
# Model checkpoint/config/scalers are expected in the current working directory.
model, config, scalers = load_model_and_config(Path("."))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f" Device: {device}")
print(f" Scalers type: {type(scalers)}")
# `scalers` may be a list (one per feature) or a dict/object depending on how
# the checkpoint was saved — the length is only meaningful for the list case.
print(f" Number of scalers: {len(scalers) if isinstance(scalers, list) else 'N/A (dict/object)'}")
# Load CI data
print("\n[2] Loading CI data...")
# 'field' is read as str so IDs like "1" keep their textual form for filtering.
ci_data = pd.read_csv(CI_DATA_FILE, dtype={'field': str})
ci_data['Date'] = pd.to_datetime(ci_data['Date'])
# Test on a known field (field 1)
test_field = "1"
# Chronological order with a clean 0..n-1 index so iloc positions == day index.
field_data = ci_data[ci_data['field'] == test_field].sort_values('Date').reset_index(drop=True)
print(f"\n[3] Testing on field {test_field}...")
print(f" Data points: {len(field_data)}")
# Test with first 100 days
# Copy + reset_index keeps the subset independent of `field_data`.
subset_100 = field_data.iloc[:100].copy().reset_index(drop=True)
print(f"\n[4] Testing model inference on first 100 days...")
# -----------------------------------------------------------------------------
# Step [4]: feature extraction + raw model inference on the 100-day subset.
# Any failure aborts the script with a non-zero exit so step [5] never runs
# against a broken setup.
# -----------------------------------------------------------------------------
try:
    # Build the (timesteps, n_features) matrix the model expects;
    # config['features'] names the feature columns, 'value' is the CI series.
    features = extract_features(subset_100, config['features'], ci_column='value')
    print(f" Features shape: {features.shape}")
    print(f" Features dtype: {features.dtype}")
    # Apply one fitted scaler per feature column. Scaler order is assumed to
    # match the feature-column order from training — TODO confirm.
    features_scaled = features.copy().astype(float)
    print(f" Applying {len(scalers)} scalers...")
    for fi, scaler in enumerate(scalers):
        try:
            col_data = features[:, fi].reshape(-1, 1)  # sklearn transform wants 2-D
            scaled_col = scaler.transform(col_data)
            features_scaled[:, fi] = scaled_col.flatten()
            if fi < 3:  # Show first 3 scalers
                # BUGFIX: the separator between raw and scaled value had been
                # lost (the two numbers printed fused together); restored "→".
                print(f" Scaler {fi}: transformed {features[0, fi]:.4f} → {features_scaled[0, fi]:.4f}")
        except Exception as e:
            print(f" ERROR in scaler {fi}: {e}")
            raise  # re-raise so the outer handler exits the script
    # Run model: add a batch dimension -> tensor of shape (1, timesteps, n_features).
    print(f"\n Running model inference...")
    x_tensor = torch.tensor(features_scaled, dtype=torch.float32).unsqueeze(0).to(device)
    print(f" Tensor shape: {x_tensor.shape}")
    with torch.no_grad():
        # Model returns two heads: "imminent" and "detected" probabilities.
        imminent_probs, detected_probs = model(x_tensor)
    print(f" Imminent probs shape: {imminent_probs.shape}")
    print(f" Detected probs shape: {detected_probs.shape}")
    print(f" Detected probs dtype: {detected_probs.dtype}")
    # Summary statistics of the "detected" head across the window.
    detected_np = detected_probs[0].cpu().numpy()  # Get first (only) batch
    print(f"\n Detected head analysis:")
    print(f" Min: {detected_np.min():.4f}")
    print(f" Max: {detected_np.max():.4f}")
    print(f" Mean: {detected_np.mean():.4f}")
    print(f" Median: {np.median(detected_np):.4f}")
    print(f" > 0.1: {(detected_np > 0.1).sum()} days")
    print(f" > 0.3: {(detected_np > 0.3).sum()} days")
    print(f" > 0.5: {(detected_np > 0.5).sum()} days")
    # Show top 5 peaks, most confident first.
    top_indices = np.argsort(detected_np)[-5:][::-1]
    print(f"\n Top 5 detected peaks:")
    for idx in top_indices:
        date = subset_100.iloc[idx]['Date'].date()
        prob = detected_np[idx]
        print(f" Day {idx} ({date}): {prob:.4f}")
except Exception as e:
    print(f" ERROR: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)
# -----------------------------------------------------------------------------
# Step [5]: Phase 1 growing-window harvest detection on the same subset.
# -----------------------------------------------------------------------------
print(f"\n[5] Testing Phase 1 growing window (threshold=0.3, consecutive=2)...")
try:
    phase1_results = run_phase1_growing_window(
        subset_100, model, config, scalers, 'value', device,
        threshold=0.3, consecutive_days=2
    )
    # Each result is a (harvest_date, index-into-subset) pair.
    print(f" ✓ Phase 1 found {len(phase1_results)} harvest(s):")
    for harvest_date, harvest_idx in phase1_results:
        print(f" {harvest_date.date()}: index {harvest_idx}")
except Exception as e:
    print(f" ERROR: {e}")
    import traceback
    traceback.print_exc()
    # CONSISTENCY FIX: step [4] exits non-zero on failure but this handler fell
    # through and still printed the "complete" banner below, so CI/callers saw
    # success after a failed Phase 1 run. Exit non-zero here as well.
    sys.exit(1)
print("\n✓ Model inference test complete")