harvest pred working better now

This commit is contained in:
Timon 2026-01-18 08:22:02 +01:00
parent 7975f8ad06
commit 41bbc370f2
2 changed files with 26 additions and 28 deletions

View file

@ -46,6 +46,7 @@ from pathlib import Path
from harvest_date_pred_utils import (
load_model_and_config,
extract_features,
run_phase1_growing_window,
run_two_step_refinement,
build_production_harvest_table
)

View file

@ -307,45 +307,32 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
For each detected harvest, reset DOY counter for the next season.
This allows the model to detect multiple consecutive harvests in multi-year data.
Algorithm:
1. Start with season_anchor_day = 0 (DOY 1 at day 0)
2. Expand window: [0:1], [0:2], [0:3], ... until threshold crossed
3. When harvest detected: record date, set new season_anchor = day after harvest
4. Continue from next season start
Args:
threshold (float): Probability threshold (default 0.45)
consecutive_days (int): Required consecutive days above threshold (default 2)
Returns list of (harvest_date, harvest_idx) tuples.
"""
harvest_dates = []
season_anchor_day = 0 # DOY 1 starts at day 0
season_anchor_day = 0
current_pos = 0
while current_pos < len(field_data):
consecutive_above_threshold = 0
min_window_size = 120 # Need at least 120 days (~4 months) for patterns to establish
min_window_size = 120
for window_end in range(current_pos + 1, len(field_data) + 1):
window_data = field_data.iloc[current_pos:window_end].copy().reset_index(drop=True)
# Skip if window is too small (model needs long sequences for pattern learning)
if len(window_data) < min_window_size:
continue
try:
# CRITICAL: Pass season_anchor_day so DOY resets after harvest
reset_doy = current_pos > season_anchor_day
features = extract_features(
window_data,
config['features'],
ci_column=ci_column,
season_anchor_day=season_anchor_day,
season_anchor_day=season_anchor_day if reset_doy else None,
lookback_start=current_pos
)
# Apply scalers
features_scaled = features.copy().astype(float)
for fi, scaler in enumerate(scalers):
try:
@ -353,12 +340,10 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
except Exception:
pass
# Run model on expanding window
with torch.no_grad():
x_tensor = torch.tensor(features_scaled, dtype=torch.float32).unsqueeze(0).to(device)
imminent_probs, detected_probs = model(x_tensor)
# Check LAST timestep only
last_prob = detected_probs[0, -1].item()
if last_prob > threshold:
@ -366,23 +351,18 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
else:
consecutive_above_threshold = 0
# Harvest detected: N consecutive days above threshold
if consecutive_above_threshold >= consecutive_days:
harvest_idx = current_pos + window_end - consecutive_days
harvest_date = field_data.iloc[harvest_idx]['Date']
harvest_dates.append((harvest_date, harvest_idx))
# CRITICAL: Reset season anchor for next season
# DOY 1 starts at day after harvest
harvest_dates.append((harvest_date, harvest_idx))
season_anchor_day = harvest_idx + 1
current_pos = harvest_idx + 1
break
except Exception as e:
# Skip window on error
continue
else:
# No more harvests found
break
return harvest_dates
@ -413,8 +393,21 @@ def run_phase2_refinement(field_data, phase1_harvests, model, config, scalers, c
window_start_date = season_start_date - pd.Timedelta(days=40)
window_end_date = phase1_harvest_date + pd.Timedelta(days=40)
window_start_idx = max(0, (field_data['Date'] >= window_start_date).idxmax() if (field_data['Date'] >= window_start_date).any() else 0)
window_end_idx = min(len(field_data), (field_data['Date'] <= window_end_date).idxmax() + 1 if (field_data['Date'] <= window_end_date).any() else len(field_data))
# FIXED: Use proper index selection
mask_start = field_data['Date'] >= window_start_date
mask_end = field_data['Date'] <= window_end_date
if mask_start.any():
window_start_idx = mask_start.idxmax() # First True index
else:
window_start_idx = 0
if mask_end.any():
# Last True index: find where condition becomes False from the right
true_indices = np.where(mask_end)[0]
window_end_idx = true_indices[-1] + 1 # +1 for slicing (exclusive end)
else:
window_end_idx = len(field_data)
if window_end_idx <= window_start_idx:
refined_harvests.append((phase1_harvest_date, phase1_idx))
@ -525,6 +518,10 @@ def run_two_step_refinement(df: pd.DataFrame, model, config, scalers, device=Non
print() # New line after progress bar
print(f" ✓ Complete: Found {harvests_found} harvest events across {total_fields} fields")
if results:
print(f" Sample harvest dates: {results[0]['phase2_harvest_date']}")
if len(results) > 1:
print(f" {results[-1]['phase2_harvest_date']}")
return results