From 41bbc370f20a23795b8fc7f00897f4731ffab00e Mon Sep 17 00:00:00 2001 From: Timon Date: Sun, 18 Jan 2026 08:22:02 +0100 Subject: [PATCH] harvest pred working better now --- python_app/22_harvest_baseline_prediction.py | 1 + python_app/harvest_date_pred_utils.py | 53 +++++++++----------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/python_app/22_harvest_baseline_prediction.py b/python_app/22_harvest_baseline_prediction.py index f39dca6..4184608 100644 --- a/python_app/22_harvest_baseline_prediction.py +++ b/python_app/22_harvest_baseline_prediction.py @@ -46,6 +46,7 @@ from pathlib import Path from harvest_date_pred_utils import ( load_model_and_config, extract_features, + run_phase1_growing_window, run_two_step_refinement, build_production_harvest_table ) diff --git a/python_app/harvest_date_pred_utils.py b/python_app/harvest_date_pred_utils.py index aa4199c..012a9f2 100644 --- a/python_app/harvest_date_pred_utils.py +++ b/python_app/harvest_date_pred_utils.py @@ -307,45 +307,32 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev For each detected harvest, reset DOY counter for the next season. This allows the model to detect multiple consecutive harvests in multi-year data. - - Algorithm: - 1. Start with season_anchor_day = 0 (DOY 1 at day 0) - 2. Expand window: [0:1], [0:2], [0:3], ... until threshold crossed - 3. When harvest detected: record date, set new season_anchor = day after harvest - 4. Continue from next season start - - Args: - threshold (float): Probability threshold (default 0.45) - consecutive_days (int): Required consecutive days above threshold (default 2) - - Returns list of (harvest_date, harvest_idx) tuples. """ harvest_dates = [] - season_anchor_day = 0 # DOY 1 starts at day 0 + season_anchor_day = 0 current_pos = 0 while current_pos < len(field_data): consecutive_above_threshold = 0 - min_window_size = 120 # Need at least 120 days (~4 months) for patterns to establish + min_window_size = 120 for window_end in range(current_pos + 1, len(field_data) + 1): window_data = field_data.iloc[current_pos:window_end].copy().reset_index(drop=True) - # Skip if window is too small (model needs long sequences for pattern learning) if len(window_data) < min_window_size: continue try: - # CRITICAL: Pass season_anchor_day so DOY resets after harvest + reset_doy = current_pos > season_anchor_day + features = extract_features( window_data, config['features'], ci_column=ci_column, - season_anchor_day=season_anchor_day, + season_anchor_day=season_anchor_day if reset_doy else None, lookback_start=current_pos ) - # Apply scalers features_scaled = features.copy().astype(float) for fi, scaler in enumerate(scalers): try: @@ -353,12 +340,10 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev except Exception: pass - # Run model on expanding window with torch.no_grad(): x_tensor = torch.tensor(features_scaled, dtype=torch.float32).unsqueeze(0).to(device) imminent_probs, detected_probs = model(x_tensor) - # Check LAST timestep only last_prob = detected_probs[0, -1].item() if last_prob > threshold: @@ -366,23 +351,18 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev else: consecutive_above_threshold = 0 - # Harvest detected: N consecutive days above threshold if consecutive_above_threshold >= consecutive_days: harvest_idx = current_pos + window_end - consecutive_days harvest_date = field_data.iloc[harvest_idx]['Date'] - harvest_dates.append((harvest_date, harvest_idx)) - # CRITICAL: Reset season anchor for next season - # DOY 1 starts at day after harvest + harvest_dates.append((harvest_date, harvest_idx)) season_anchor_day = harvest_idx + 1 current_pos = harvest_idx + 1 break except Exception as e: - # Skip window on error continue else: - # No more harvests found break return harvest_dates @@ -413,8 +393,21 @@ def run_phase2_refinement(field_data, phase1_harvests, model, config, scalers, c window_start_date = season_start_date - pd.Timedelta(days=40) window_end_date = phase1_harvest_date + pd.Timedelta(days=40) - window_start_idx = max(0, (field_data['Date'] >= window_start_date).idxmax() if (field_data['Date'] >= window_start_date).any() else 0) - window_end_idx = min(len(field_data), (field_data['Date'] <= window_end_date).idxmax() + 1 if (field_data['Date'] <= window_end_date).any() else len(field_data)) + # FIXED: Use proper index selection + mask_start = field_data['Date'] >= window_start_date + mask_end = field_data['Date'] <= window_end_date + + if mask_start.any(): + window_start_idx = mask_start.idxmax() # First True index + else: + window_start_idx = 0 + + if mask_end.any(): + # Last True index: find where condition becomes False from the right + true_indices = np.where(mask_end)[0] + window_end_idx = true_indices[-1] + 1 # +1 for slicing (exclusive end) + else: + window_end_idx = len(field_data) if window_end_idx <= window_start_idx: refined_harvests.append((phase1_harvest_date, phase1_idx)) @@ -525,6 +518,10 @@ def run_two_step_refinement(df: pd.DataFrame, model, config, scalers, device=Non print() # New line after progress bar print(f" ✓ Complete: Found {harvests_found} harvest events across {total_fields} fields") + if results: + print(f" Sample harvest dates: {results[0]['phase2_harvest_date']}") + if len(results) > 1: + print(f" {results[-1]['phase2_harvest_date']}") return results