Harvest prediction working better now
This commit is contained in:
parent
7975f8ad06
commit
41bbc370f2
|
|
@ -46,6 +46,7 @@ from pathlib import Path
|
||||||
from harvest_date_pred_utils import (
|
from harvest_date_pred_utils import (
|
||||||
load_model_and_config,
|
load_model_and_config,
|
||||||
extract_features,
|
extract_features,
|
||||||
|
run_phase1_growing_window,
|
||||||
run_two_step_refinement,
|
run_two_step_refinement,
|
||||||
build_production_harvest_table
|
build_production_harvest_table
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -307,45 +307,32 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
|
||||||
|
|
||||||
For each detected harvest, reset DOY counter for the next season.
|
For each detected harvest, reset DOY counter for the next season.
|
||||||
This allows the model to detect multiple consecutive harvests in multi-year data.
|
This allows the model to detect multiple consecutive harvests in multi-year data.
|
||||||
|
|
||||||
Algorithm:
|
|
||||||
1. Start with season_anchor_day = 0 (DOY 1 at day 0)
|
|
||||||
2. Expand window: [0:1], [0:2], [0:3], ... until threshold crossed
|
|
||||||
3. When harvest detected: record date, set new season_anchor = day after harvest
|
|
||||||
4. Continue from next season start
|
|
||||||
|
|
||||||
Args:
|
|
||||||
threshold (float): Probability threshold (default 0.45)
|
|
||||||
consecutive_days (int): Required consecutive days above threshold (default 2)
|
|
||||||
|
|
||||||
Returns list of (harvest_date, harvest_idx) tuples.
|
|
||||||
"""
|
"""
|
||||||
harvest_dates = []
|
harvest_dates = []
|
||||||
season_anchor_day = 0 # DOY 1 starts at day 0
|
season_anchor_day = 0
|
||||||
current_pos = 0
|
current_pos = 0
|
||||||
|
|
||||||
while current_pos < len(field_data):
|
while current_pos < len(field_data):
|
||||||
consecutive_above_threshold = 0
|
consecutive_above_threshold = 0
|
||||||
min_window_size = 120 # Need at least 120 days (~4 months) for patterns to establish
|
min_window_size = 120
|
||||||
|
|
||||||
for window_end in range(current_pos + 1, len(field_data) + 1):
|
for window_end in range(current_pos + 1, len(field_data) + 1):
|
||||||
window_data = field_data.iloc[current_pos:window_end].copy().reset_index(drop=True)
|
window_data = field_data.iloc[current_pos:window_end].copy().reset_index(drop=True)
|
||||||
|
|
||||||
# Skip if window is too small (model needs long sequences for pattern learning)
|
|
||||||
if len(window_data) < min_window_size:
|
if len(window_data) < min_window_size:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# CRITICAL: Pass season_anchor_day so DOY resets after harvest
|
reset_doy = current_pos > season_anchor_day
|
||||||
|
|
||||||
features = extract_features(
|
features = extract_features(
|
||||||
window_data,
|
window_data,
|
||||||
config['features'],
|
config['features'],
|
||||||
ci_column=ci_column,
|
ci_column=ci_column,
|
||||||
season_anchor_day=season_anchor_day,
|
season_anchor_day=season_anchor_day if reset_doy else None,
|
||||||
lookback_start=current_pos
|
lookback_start=current_pos
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply scalers
|
|
||||||
features_scaled = features.copy().astype(float)
|
features_scaled = features.copy().astype(float)
|
||||||
for fi, scaler in enumerate(scalers):
|
for fi, scaler in enumerate(scalers):
|
||||||
try:
|
try:
|
||||||
|
|
@ -353,12 +340,10 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# Run model on expanding window
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
x_tensor = torch.tensor(features_scaled, dtype=torch.float32).unsqueeze(0).to(device)
|
x_tensor = torch.tensor(features_scaled, dtype=torch.float32).unsqueeze(0).to(device)
|
||||||
imminent_probs, detected_probs = model(x_tensor)
|
imminent_probs, detected_probs = model(x_tensor)
|
||||||
|
|
||||||
# Check LAST timestep only
|
|
||||||
last_prob = detected_probs[0, -1].item()
|
last_prob = detected_probs[0, -1].item()
|
||||||
|
|
||||||
if last_prob > threshold:
|
if last_prob > threshold:
|
||||||
|
|
@ -366,23 +351,18 @@ def run_phase1_growing_window(field_data, model, config, scalers, ci_column, dev
|
||||||
else:
|
else:
|
||||||
consecutive_above_threshold = 0
|
consecutive_above_threshold = 0
|
||||||
|
|
||||||
# Harvest detected: N consecutive days above threshold
|
|
||||||
if consecutive_above_threshold >= consecutive_days:
|
if consecutive_above_threshold >= consecutive_days:
|
||||||
harvest_idx = current_pos + window_end - consecutive_days
|
harvest_idx = current_pos + window_end - consecutive_days
|
||||||
harvest_date = field_data.iloc[harvest_idx]['Date']
|
harvest_date = field_data.iloc[harvest_idx]['Date']
|
||||||
harvest_dates.append((harvest_date, harvest_idx))
|
|
||||||
|
|
||||||
# CRITICAL: Reset season anchor for next season
|
harvest_dates.append((harvest_date, harvest_idx))
|
||||||
# DOY 1 starts at day after harvest
|
|
||||||
season_anchor_day = harvest_idx + 1
|
season_anchor_day = harvest_idx + 1
|
||||||
current_pos = harvest_idx + 1
|
current_pos = harvest_idx + 1
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Skip window on error
|
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
# No more harvests found
|
|
||||||
break
|
break
|
||||||
|
|
||||||
return harvest_dates
|
return harvest_dates
|
||||||
|
|
@ -413,8 +393,21 @@ def run_phase2_refinement(field_data, phase1_harvests, model, config, scalers, c
|
||||||
window_start_date = season_start_date - pd.Timedelta(days=40)
|
window_start_date = season_start_date - pd.Timedelta(days=40)
|
||||||
window_end_date = phase1_harvest_date + pd.Timedelta(days=40)
|
window_end_date = phase1_harvest_date + pd.Timedelta(days=40)
|
||||||
|
|
||||||
window_start_idx = max(0, (field_data['Date'] >= window_start_date).idxmax() if (field_data['Date'] >= window_start_date).any() else 0)
|
# FIXED: Use proper index selection
|
||||||
window_end_idx = min(len(field_data), (field_data['Date'] <= window_end_date).idxmax() + 1 if (field_data['Date'] <= window_end_date).any() else len(field_data))
|
mask_start = field_data['Date'] >= window_start_date
|
||||||
|
mask_end = field_data['Date'] <= window_end_date
|
||||||
|
|
||||||
|
if mask_start.any():
|
||||||
|
window_start_idx = mask_start.idxmax() # First True index
|
||||||
|
else:
|
||||||
|
window_start_idx = 0
|
||||||
|
|
||||||
|
if mask_end.any():
|
||||||
|
# Last True index: take the final position at which the end-date condition still holds
|
||||||
|
true_indices = np.where(mask_end)[0]
|
||||||
|
window_end_idx = true_indices[-1] + 1 # +1 for slicing (exclusive end)
|
||||||
|
else:
|
||||||
|
window_end_idx = len(field_data)
|
||||||
|
|
||||||
if window_end_idx <= window_start_idx:
|
if window_end_idx <= window_start_idx:
|
||||||
refined_harvests.append((phase1_harvest_date, phase1_idx))
|
refined_harvests.append((phase1_harvest_date, phase1_idx))
|
||||||
|
|
@ -525,6 +518,10 @@ def run_two_step_refinement(df: pd.DataFrame, model, config, scalers, device=Non
|
||||||
|
|
||||||
print() # New line after progress bar
|
print() # New line after progress bar
|
||||||
print(f" ✓ Complete: Found {harvests_found} harvest events across {total_fields} fields")
|
print(f" ✓ Complete: Found {harvests_found} harvest events across {total_fields} fields")
|
||||||
|
if results:
|
||||||
|
print(f" Sample harvest dates: {results[0]['phase2_harvest_date']}")
|
||||||
|
if len(results) > 1:
|
||||||
|
print(f" {results[-1]['phase2_harvest_date']}")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue