harvest pred working better now
This commit is contained in:
parent
7975f8ad06
commit
41bbc370f2
|
|
@ -46,6 +46,7 @@ from pathlib import Path
|
|||
from harvest_date_pred_utils import (
|
||||
load_model_and_config,
|
||||
extract_features,
|
||||
run_phase1_growing_window,
|
||||
run_two_step_refinement,
|
||||
build_production_harvest_table
|
||||
)
|
||||
|
|
|
|||
|
|
@ -307,45 +307,32 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
|
|||
|
||||
For each detected harvest, reset DOY counter for the next season.
|
||||
This allows the model to detect multiple consecutive harvests in multi-year data.
|
||||
|
||||
Algorithm:
|
||||
1. Start with season_anchor_day = 0 (DOY 1 at day 0)
|
||||
2. Expand window: [0:1], [0:2], [0:3], ... until threshold crossed
|
||||
3. When harvest detected: record date, set new season_anchor = day after harvest
|
||||
4. Continue from next season start
|
||||
|
||||
Args:
|
||||
threshold (float): Probability threshold (default 0.45)
|
||||
consecutive_days (int): Required consecutive days above threshold (default 2)
|
||||
|
||||
Returns list of (harvest_date, harvest_idx) tuples.
|
||||
"""
|
||||
harvest_dates = []
|
||||
season_anchor_day = 0 # DOY 1 starts at day 0
|
||||
season_anchor_day = 0
|
||||
current_pos = 0
|
||||
|
||||
while current_pos < len(field_data):
|
||||
consecutive_above_threshold = 0
|
||||
min_window_size = 120 # Need at least 120 days (~4 months) for patterns to establish
|
||||
min_window_size = 120
|
||||
|
||||
for window_end in range(current_pos + 1, len(field_data) + 1):
|
||||
window_data = field_data.iloc[current_pos:window_end].copy().reset_index(drop=True)
|
||||
|
||||
# Skip if window is too small (model needs long sequences for pattern learning)
|
||||
if len(window_data) < min_window_size:
|
||||
continue
|
||||
|
||||
try:
|
||||
# CRITICAL: Pass season_anchor_day so DOY resets after harvest
|
||||
reset_doy = current_pos > season_anchor_day
|
||||
|
||||
features = extract_features(
|
||||
window_data,
|
||||
config['features'],
|
||||
ci_column=ci_column,
|
||||
season_anchor_day=season_anchor_day,
|
||||
season_anchor_day=season_anchor_day if reset_doy else None,
|
||||
lookback_start=current_pos
|
||||
)
|
||||
|
||||
# Apply scalers
|
||||
features_scaled = features.copy().astype(float)
|
||||
for fi, scaler in enumerate(scalers):
|
||||
try:
|
||||
|
|
@ -353,12 +340,10 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
|
|||
except Exception:
|
||||
pass
|
||||
|
||||
# Run model on expanding window
|
||||
with torch.no_grad():
|
||||
x_tensor = torch.tensor(features_scaled, dtype=torch.float32).unsqueeze(0).to(device)
|
||||
imminent_probs, detected_probs = model(x_tensor)
|
||||
|
||||
# Check LAST timestep only
|
||||
last_prob = detected_probs[0, -1].item()
|
||||
|
||||
if last_prob > threshold:
|
||||
|
|
@ -366,23 +351,18 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
|
|||
else:
|
||||
consecutive_above_threshold = 0
|
||||
|
||||
# Harvest detected: N consecutive days above threshold
|
||||
if consecutive_above_threshold >= consecutive_days:
|
||||
harvest_idx = current_pos + window_end - consecutive_days
|
||||
harvest_date = field_data.iloc[harvest_idx]['Date']
|
||||
harvest_dates.append((harvest_date, harvest_idx))
|
||||
|
||||
# CRITICAL: Reset season anchor for next season
|
||||
# DOY 1 starts at day after harvest
|
||||
harvest_dates.append((harvest_date, harvest_idx))
|
||||
season_anchor_day = harvest_idx + 1
|
||||
current_pos = harvest_idx + 1
|
||||
break
|
||||
|
||||
except Exception as e:
|
||||
# Skip window on error
|
||||
continue
|
||||
else:
|
||||
# No more harvests found
|
||||
break
|
||||
|
||||
return harvest_dates
|
||||
|
|
@ -413,8 +393,21 @@ def run_phase2_refinement(field_data, phase1_harvests, model, config, scalers, c
|
|||
window_start_date = season_start_date - pd.Timedelta(days=40)
|
||||
window_end_date = phase1_harvest_date + pd.Timedelta(days=40)
|
||||
|
||||
window_start_idx = max(0, (field_data['Date'] >= window_start_date).idxmax() if (field_data['Date'] >= window_start_date).any() else 0)
|
||||
window_end_idx = min(len(field_data), (field_data['Date'] <= window_end_date).idxmax() + 1 if (field_data['Date'] <= window_end_date).any() else len(field_data))
|
||||
# FIXED: Use proper index selection
|
||||
mask_start = field_data['Date'] >= window_start_date
|
||||
mask_end = field_data['Date'] <= window_end_date
|
||||
|
||||
if mask_start.any():
|
||||
window_start_idx = mask_start.idxmax() # First True index
|
||||
else:
|
||||
window_start_idx = 0
|
||||
|
||||
if mask_end.any():
|
||||
# Last True index: find where condition becomes False from the right
|
||||
true_indices = np.where(mask_end)[0]
|
||||
window_end_idx = true_indices[-1] + 1 # +1 for slicing (exclusive end)
|
||||
else:
|
||||
window_end_idx = len(field_data)
|
||||
|
||||
if window_end_idx <= window_start_idx:
|
||||
refined_harvests.append((phase1_harvest_date, phase1_idx))
|
||||
|
|
@ -525,6 +518,10 @@ def run_two_step_refinement(df: pd.DataFrame, model, config, scalers, device=Non
|
|||
|
||||
print() # New line after progress bar
|
||||
print(f" ✓ Complete: Found {harvests_found} harvest events across {total_fields} fields")
|
||||
if results:
|
||||
print(f" Sample harvest dates: {results[0]['phase2_harvest_date']}")
|
||||
if len(results) > 1:
|
||||
print(f" {results[-1]['phase2_harvest_date']}")
|
||||
|
||||
return results
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue