harvest pred working better now

This commit is contained in:
Timon 2026-01-18 08:22:02 +01:00
parent 7975f8ad06
commit 41bbc370f2
2 changed files with 26 additions and 28 deletions

View file

@ -46,6 +46,7 @@ from pathlib import Path
from harvest_date_pred_utils import ( from harvest_date_pred_utils import (
load_model_and_config, load_model_and_config,
extract_features, extract_features,
run_phase1_growing_window,
run_two_step_refinement, run_two_step_refinement,
build_production_harvest_table build_production_harvest_table
) )

View file

@ -307,45 +307,32 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
For each detected harvest, reset DOY counter for the next season. For each detected harvest, reset DOY counter for the next season.
This allows the model to detect multiple consecutive harvests in multi-year data. This allows the model to detect multiple consecutive harvests in multi-year data.
Algorithm:
1. Start with season_anchor_day = 0 (DOY 1 at day 0)
2. Expand window: [0:1], [0:2], [0:3], ... until threshold crossed
3. When harvest detected: record date, set new season_anchor = day after harvest
4. Continue from next season start
Args:
threshold (float): Probability threshold (default 0.45)
consecutive_days (int): Required consecutive days above threshold (default 2)
Returns list of (harvest_date, harvest_idx) tuples.
""" """
harvest_dates = [] harvest_dates = []
season_anchor_day = 0 # DOY 1 starts at day 0 season_anchor_day = 0
current_pos = 0 current_pos = 0
while current_pos < len(field_data): while current_pos < len(field_data):
consecutive_above_threshold = 0 consecutive_above_threshold = 0
min_window_size = 120 # Need at least 120 days (~4 months) for patterns to establish min_window_size = 120
for window_end in range(current_pos + 1, len(field_data) + 1): for window_end in range(current_pos + 1, len(field_data) + 1):
window_data = field_data.iloc[current_pos:window_end].copy().reset_index(drop=True) window_data = field_data.iloc[current_pos:window_end].copy().reset_index(drop=True)
# Skip if window is too small (model needs long sequences for pattern learning)
if len(window_data) < min_window_size: if len(window_data) < min_window_size:
continue continue
try: try:
# CRITICAL: Pass season_anchor_day so DOY resets after harvest reset_doy = current_pos > season_anchor_day
features = extract_features( features = extract_features(
window_data, window_data,
config['features'], config['features'],
ci_column=ci_column, ci_column=ci_column,
season_anchor_day=season_anchor_day, season_anchor_day=season_anchor_day if reset_doy else None,
lookback_start=current_pos lookback_start=current_pos
) )
# Apply scalers
features_scaled = features.copy().astype(float) features_scaled = features.copy().astype(float)
for fi, scaler in enumerate(scalers): for fi, scaler in enumerate(scalers):
try: try:
@ -353,12 +340,10 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
except Exception: except Exception:
pass pass
# Run model on expanding window
with torch.no_grad(): with torch.no_grad():
x_tensor = torch.tensor(features_scaled, dtype=torch.float32).unsqueeze(0).to(device) x_tensor = torch.tensor(features_scaled, dtype=torch.float32).unsqueeze(0).to(device)
imminent_probs, detected_probs = model(x_tensor) imminent_probs, detected_probs = model(x_tensor)
# Check LAST timestep only
last_prob = detected_probs[0, -1].item() last_prob = detected_probs[0, -1].item()
if last_prob > threshold: if last_prob > threshold:
@ -366,23 +351,18 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
else: else:
consecutive_above_threshold = 0 consecutive_above_threshold = 0
# Harvest detected: N consecutive days above threshold
if consecutive_above_threshold >= consecutive_days: if consecutive_above_threshold >= consecutive_days:
harvest_idx = current_pos + window_end - consecutive_days harvest_idx = current_pos + window_end - consecutive_days
harvest_date = field_data.iloc[harvest_idx]['Date'] harvest_date = field_data.iloc[harvest_idx]['Date']
harvest_dates.append((harvest_date, harvest_idx))
# CRITICAL: Reset season anchor for next season harvest_dates.append((harvest_date, harvest_idx))
# DOY 1 starts at day after harvest
season_anchor_day = harvest_idx + 1 season_anchor_day = harvest_idx + 1
current_pos = harvest_idx + 1 current_pos = harvest_idx + 1
break break
except Exception as e: except Exception as e:
# Skip window on error
continue continue
else: else:
# No more harvests found
break break
return harvest_dates return harvest_dates
@ -413,8 +393,21 @@ def run_phase2_refinement(field_data, phase1_harvests, model, config, scalers, c
window_start_date = season_start_date - pd.Timedelta(days=40) window_start_date = season_start_date - pd.Timedelta(days=40)
window_end_date = phase1_harvest_date + pd.Timedelta(days=40) window_end_date = phase1_harvest_date + pd.Timedelta(days=40)
window_start_idx = max(0, (field_data['Date'] >= window_start_date).idxmax() if (field_data['Date'] >= window_start_date).any() else 0) # FIXED: Use proper index selection
window_end_idx = min(len(field_data), (field_data['Date'] <= window_end_date).idxmax() + 1 if (field_data['Date'] <= window_end_date).any() else len(field_data)) mask_start = field_data['Date'] >= window_start_date
mask_end = field_data['Date'] <= window_end_date
if mask_start.any():
window_start_idx = mask_start.idxmax() # First True index
else:
window_start_idx = 0
if mask_end.any():
# Last True index: find where condition becomes False from the right
true_indices = np.where(mask_end)[0]
window_end_idx = true_indices[-1] + 1 # +1 for slicing (exclusive end)
else:
window_end_idx = len(field_data)
if window_end_idx <= window_start_idx: if window_end_idx <= window_start_idx:
refined_harvests.append((phase1_harvest_date, phase1_idx)) refined_harvests.append((phase1_harvest_date, phase1_idx))
@ -525,6 +518,10 @@ def run_two_step_refinement(df: pd.DataFrame, model, config, scalers, device=Non
print() # New line after progress bar print() # New line after progress bar
print(f" ✓ Complete: Found {harvests_found} harvest events across {total_fields} fields") print(f" ✓ Complete: Found {harvests_found} harvest events across {total_fields} fields")
if results:
print(f" Sample harvest dates: {results[0]['phase2_harvest_date']}")
if len(results) > 1:
print(f" {results[-1]['phase2_harvest_date']}")
return results return results