SmartCane/python_app/rgb_visualization.py

572 lines
22 KiB
Python

#!/usr/bin/env python
"""
RGB Visualization Tool for Harvest Date Validation
Creates temporal grids (3 rows x 5 columns, 15 images at 7-day intervals) showing satellite imagery around registered and predicted harvest dates.
Extracts RGB from 8-band Planet scope data and clips to field boundaries from GeoJSON.
Functions:
- load_field_boundaries(): Load field geometries from GeoJSON
- find_overlapping_tiles(): Find tile files for the available date closest to a target date that overlap the field
- load_and_clip_tiff_rgb(): Load TIFF, extract RGB, clip to field boundary
- load_and_composite_tiles_rgb(): Composite the clipped RGB of several overlapping tiles into one image
- create_temporal_rgb_grid(): Create the temporal grid (8 pre-harvest, 1 near, 6 post-harvest images)
- generate_rgb_grids(): Main orchestration function
Usage:
from rgb_visualization import generate_rgb_grids
generate_rgb_grids(field_data, field_id, registered_harvest_dates, predicted_harvest_dates, output_dir, tiff_dir, geojson_path)
"""
import json
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime, timedelta
import matplotlib
matplotlib.use('Agg') # Use non-interactive backend to avoid display hangs
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import Normalize
import warnings
warnings.filterwarnings('ignore')
try:
import rasterio
from rasterio.mask import mask
import shapely.geometry as shgeom
except ImportError:
print("Warning: rasterio not available. RGB visualization will be skipped.")
rasterio = None
def load_field_boundaries(geojson_path, field_id):
    """
    Load the boundary of a single field from a GeoJSON file.

    Features are scanned in file order. The first feature whose ``field`` or
    ``sub_field`` property equals ``field_id`` (compared as strings) and whose
    geometry is a non-empty Polygon or MultiPolygon is used.

    Only the exterior ring is kept (interior rings are ignored), and for a
    MultiPolygon only the first member polygon is used.

    Args:
        geojson_path (Path): Path to the GeoJSON file (e.g. pivot.geojson)
        field_id (str): Field identifier (e.g., "13973")

    Returns:
        tuple: (feature, polygon)
            dict: The matching GeoJSON feature, or None if not found / on error
            shapely.geometry.Polygon: Exterior boundary of the field, or None
    """
    try:
        # GeoJSON is UTF-8 by specification; do not depend on the platform default.
        with open(geojson_path, encoding='utf-8') as f:
            geojson_data = json.load(f)

        wanted = str(field_id)
        for feature in geojson_data.get('features') or []:
            # "properties" may legally be null in GeoJSON. Treat it as empty so
            # that one such feature does not abort the search for later ones.
            props = feature.get('properties') or {}
            if wanted not in (str(props.get('field', '')), str(props.get('sub_field', ''))):
                continue

            geometry = feature.get('geometry')
            if not geometry:
                continue
            geom_type = geometry.get('type', '')
            coordinates = geometry.get('coordinates', [])
            if not coordinates:
                continue

            if geom_type == 'MultiPolygon':
                # coordinates[i] is one polygon: [exterior ring, inner ring, ...]
                exterior_ring = coordinates[0][0]
            elif geom_type == 'Polygon':
                # coordinates is [exterior ring, inner ring, ...]
                exterior_ring = coordinates[0]
            else:
                # Unsupported geometry type: keep looking at later features.
                continue
            return feature, shgeom.Polygon(exterior_ring)

        print(f" ⚠ Field {field_id} not found in GeoJSON")
        return None, None
    except Exception as e:
        # Best-effort loader: any I/O, parse or geometry error is reported and
        # turned into a "not available" result for the caller.
        print(f" ✗ Error loading GeoJSON: {e}")
        return None, None
def find_overlapping_tiles(target_date, tiff_dir, field_boundary, days_window=60, min_size_mb=12.0):
    """
    Find all tile files for target_date (or closest date) that overlap with field_boundary.

    Tile files are organized in subdirectories by date: 5x5/YYYY-MM-DD_HH/*.tif

    Directories are scanned in sorted order so the result is deterministic:
    when several directories share the same date (different ``_HH`` suffix)
    the lexicographically last one is used, and when two dates are equally
    close to the target the earlier one is chosen.

    Args:
        target_date (pd.Timestamp): Target date to find tiles near
        tiff_dir (Path): Directory containing 5x5 date subdirectories
        field_boundary (shapely.Polygon): Field boundary for overlap detection
        days_window (int): Max days to search before/after target
        min_size_mb (float): Tiles smaller than this many MiB are skipped as
            empty (the original note says empty files are ~11.56 MB)

    Returns:
        tuple: (list of tile paths, actual_date, days_diff)
            list: tile paths that overlap field (all tiles of the chosen date
                if overlap cannot be checked or no tile overlaps)
            pd.Timestamp: actual date of tiles found
            int: days difference from target to actual date found
        ``([], None, None)`` when the directory is missing or no usable tile
        lies within the window.
    """
    target_date = pd.Timestamp(target_date)
    tiff_dir = Path(tiff_dir)
    if not tiff_dir.exists():
        return [], None, None

    bytes_per_mb = 1024 * 1024
    available_dates = {}  # {tile_date: [tile file paths]}
    for date_dir in sorted(tiff_dir.iterdir()):
        if not date_dir.is_dir():
            continue
        try:
            # Directory name is YYYY-MM-DD or YYYY-MM-DD_HH; keep the date part.
            tile_date = pd.Timestamp(date_dir.name.split('_')[0])
            # An empty date part parses to NaT, which cannot be ranked.
            if pd.isna(tile_date):
                continue
            if abs((tile_date - target_date).days) > days_window:
                continue
            tile_files = [
                tile_file
                for tile_file in sorted(date_dir.glob('*.tif'))
                if tile_file.stat().st_size / bytes_per_mb >= min_size_mb
            ]
        except (ValueError, TypeError, OverflowError, OSError):
            # Not a date-named directory, or its files could not be inspected.
            continue
        if tile_files:
            available_dates[tile_date] = tile_files

    if not available_dates:
        return [], None, None

    closest_date = min(available_dates, key=lambda d: abs((d - target_date).days))
    days_diff = (closest_date - target_date).days
    tiles = available_dates[closest_date]

    # Without rasterio or a boundary the overlap cannot be checked: use all tiles.
    if field_boundary is None or rasterio is None:
        return tiles, closest_date, days_diff

    # NOTE(review): assumes field_boundary uses the same CRS as the tiles - confirm.
    overlapping_tiles = []
    for tile_path in tiles:
        try:
            with rasterio.open(tile_path) as src:
                # src.bounds is (left, bottom, right, top)
                if shgeom.box(*src.bounds).intersects(field_boundary):
                    overlapping_tiles.append(tile_path)
        except Exception:
            # Unreadable tile: skip it, the remaining tiles may still be usable.
            continue

    # If nothing overlaps, fall back to every tile of the closest date.
    return (overlapping_tiles or tiles), closest_date, days_diff
def load_and_clip_tiff_rgb(tiff_path, field_boundary, rgb_bands=(1, 2, 3)):
    """
    Load TIFF and extract RGB bands clipped to field boundary.

    For merged_final_tif files (cloud-masked and filtered):
    - Band 1: Red
    - Band 2: Green
    - Band 3: Blue
    - Band 4: NIR
    - Band 5: CI

    Args:
        tiff_path (Path): Path to TIFF file
        field_boundary (shapely.Polygon): Field boundary for clipping
        rgb_bands (tuple): Band indices for RGB (1-indexed, defaults to 1,2,3 for merged_final_tif)

    Returns:
        np.ndarray: RGB data (height, width, 3) with values 0-1,
        or None if rasterio/boundary is unavailable, the file lacks the
        requested bands, the clipped data is all NaN, or any error occurs.
    """
    if field_boundary is None or rasterio is None:
        return None
    try:
        bands_to_read = list(rgb_bands)
        with rasterio.open(tiff_path) as src:
            # The file must contain every requested band.
            if src.count < max(bands_to_read):
                return None
            geom = shgeom.mapping(field_boundary)
            masked_data, _ = mask(src, [geom], crop=True, indexes=bands_to_read)

        # (bands, height, width) -> (height, width, bands), as float32
        rgb = np.stack(list(masked_data), axis=-1).astype(np.float32)

        # Normalize to 0-1. The bit depth is guessed from the largest value:
        # <= 255 is treated as 8-bit, <= 65535 as 16-bit, otherwise scale by max.
        max_val = np.nanmax(rgb)
        if max_val > 0:
            if max_val <= 255:
                rgb = rgb / 255.0
            elif max_val <= 65535:
                rgb = rgb / 65535.0
            else:
                rgb = rgb / max_val
        rgb = np.clip(rgb, 0, 1)

        if np.all(np.isnan(rgb)):
            return None
        # Remaining NaN pixels (cloud/invalid) are shown as black.
        return np.nan_to_num(rgb, nan=0.0)
    except Exception:
        # Best-effort loader: unreadable files, shapes that do not overlap the
        # raster, bad band indices etc. all mean "no image for this tile".
        return None
def load_and_composite_tiles_rgb(tile_paths, field_boundary):
    """
    Load RGB from multiple overlapping tiles and composite them into a single image.

    Each tile is clipped with load_and_clip_tiff_rgb(). Tiles whose clipped
    arrays share one shape are merged with a per-pixel maximum. If the shapes
    differ, the arrays cannot be stacked pixel-for-pixel, so the tile with the
    most non-zero values is used on its own instead of discarding every tile.
    The result is then contrast-stretched between the 2nd and 98th percentile
    of its non-zero values.

    Args:
        tile_paths (list[Path]): List of tile file paths
        field_boundary (shapely.Polygon): Field boundary for clipping

    Returns:
        np.ndarray: Composited RGB data (height, width, 3) with values 0-1,
        or None if nothing could be loaded or an error occurs.
    """
    if not tile_paths or field_boundary is None or rasterio is None:
        return None
    try:
        rgb_arrays = []
        for tile_path in tile_paths:
            rgb = load_and_clip_tiff_rgb(tile_path, field_boundary)
            if rgb is not None:
                rgb_arrays.append(rgb)
        if not rgb_arrays:
            return None

        if len({rgb.shape for rgb in rgb_arrays}) == 1:
            # Same pixel grid (also the single-tile case): max composite.
            composited = np.max(np.stack(rgb_arrays, axis=0), axis=0)
        else:
            # Mismatched shapes: keep the tile that covers the most pixels.
            composited = max(rgb_arrays, key=np.count_nonzero)
        composited = composited.astype(np.float32)

        # Stretch contrast so dim images become visible; percentiles rather
        # than min/max keep a few outlier pixels from dominating the range.
        valid_data = composited[composited > 0]
        if valid_data.size > 0:
            data_min = np.percentile(valid_data, 2)
            data_max = np.percentile(valid_data, 98)
            if data_max > data_min:
                composited = (composited - data_min) / (data_max - data_min)
                composited = np.clip(composited, 0, 1)
        return composited.astype(np.float32)
    except Exception:
        # Best-effort: any failure means "no image" for this grid cell.
        return None
def create_temporal_rgb_grid(harvest_date, field_data, field_id, tiff_dir, field_boundary,
                             title, output_dir, harvest_type='registered', model_name=None, harvest_index=None):
    """
    Create a temporal grid around a harvest date (15 images, 7-day intervals).

    The figure has 3 rows and 5 columns:
        Row 1: T-56d, T-49d, T-42d, T-35d, T-28d
        Row 2: T-21d, T-14d, T-7d,  T~0d,  T+7d
        Row 3: T+14d, T+21d, T+28d, T+35d, T+42d
    The T~0d panel is outlined in red. Each panel shows the tiles of the
    closest available date (within 60 days) to its target date, or "No Data".

    Args:
        harvest_date (pd.Timestamp): Target harvest date
        field_data (pd.DataFrame): Field data with Date column (not used here)
        field_id (str): Field identifier
        tiff_dir (Path): Directory with TIFF files
        field_boundary (shapely.Polygon): Field boundary
        title (str): Plot title
        output_dir (Path): Output directory
        harvest_type (str): 'registered' or 'predicted'
        model_name (str): Model name for predicted harvests (e.g., 'Original', 'Long-Season')
        harvest_index (int): Index of harvest within same model (for multiple harvests)

    Returns:
        Path: Path to saved PNG or None if saving failed
    """
    harvest_date = pd.Timestamp(harvest_date)

    # Single source of truth for the panels: 15 weekly offsets, -56 .. +42 days.
    n_rows, n_cols = 3, 5
    day_offsets = list(range(-56, 43, 7))

    # Find and load tiles for each target date.
    rgb_images = []
    days_offsets = []
    actual_dates = []  # actual dates of the tiles found
    for offset in day_offsets:
        target = harvest_date + timedelta(days=offset)
        tile_paths, actual_date, days_diff = find_overlapping_tiles(target, tiff_dir, field_boundary, days_window=60)
        if not tile_paths or actual_date is None:
            rgb_images.append(None)
            days_offsets.append(None)
            actual_dates.append(None)
            print(f" ⚠ No tiles found within 60 days of {target.strftime('%Y-%m-%d')} with sufficient data")
            continue
        rgb = load_and_composite_tiles_rgb(tile_paths, field_boundary)
        rgb_images.append(rgb)
        days_offsets.append(days_diff)
        actual_dates.append(actual_date)
        if rgb is not None:
            print(f" ✓ Loaded {len(tile_paths)} tile(s) for {actual_date.strftime('%Y-%m-%d')} ({days_diff:+d}d from target)")
        else:
            print(f" ⚠ Loaded {len(tile_paths)} tile(s) but RGB data is None")

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(25, 15))
    try:
        fig.suptitle(f'{title}\nField {field_id} - {harvest_type.upper()} Harvest: {harvest_date.strftime("%Y-%m-%d")}',
                     fontsize=16, fontweight='bold')

        for idx, offset in enumerate(day_offsets):
            label = 'T~0d' if offset == 0 else f'T{offset:+d}d'
            row, col = divmod(idx, n_cols)
            ax = axes[row, col]
            rgb_data = rgb_images[idx]
            if rgb_data is not None:
                # Debug: check data range
                data_min, data_max = np.nanmin(rgb_data), np.nanmax(rgb_data)
                print(f" DEBUG: {label} RGB range: {data_min:.4f} - {data_max:.4f}, shape: {rgb_data.shape}")
                # Data is already normalized to 0-1.
                ax.imshow(rgb_data, vmin=0, vmax=1)
                # Panel title: label + offset from target + actual tile date
                offset_str = f"{days_offsets[idx]:+d}d" if days_offsets[idx] is not None else "?"
                date_str = actual_dates[idx].strftime('%Y-%m-%d') if actual_dates[idx] is not None else "No Date"
                ax.set_title(f'{label}\n{offset_str}\n{date_str}', fontsize=10, fontweight='bold')
            else:
                ax.text(0.5, 0.5, 'No Data', ha='center', va='center', fontsize=12, color='gray')
                ax.set_title(label, fontsize=10)

            # Red box around the harvest-date panel, with or without data.
            if offset == 0:
                for spine in ax.spines.values():
                    spine.set_edgecolor('red')
                    spine.set_linewidth(4)
            ax.set_xticks([])
            ax.set_yticks([])

        fig.tight_layout()

        # File name: field_<id>_<harvestdate>_<registered|model>[_harvest<N>]_rgb.png
        harvest_date_str = harvest_date.strftime('%Y%m%d')
        if harvest_type == 'registered':
            filename = f'field_{field_id}_{harvest_date_str}_registered_harvest_rgb.png'
        elif harvest_index is not None and harvest_index > 0:
            # Later harvests of the same model carry their index.
            filename = f'field_{field_id}_{harvest_date_str}_{model_name}_harvest{harvest_index}_rgb.png'
        else:
            filename = f'field_{field_id}_{harvest_date_str}_{model_name}_harvest_rgb.png'
        output_path = Path(output_dir) / filename

        try:
            fig.savefig(output_path, dpi=100, format='png')
        except Exception as e:
            print(f" ✗ Error saving PNG: {e}")
            return None
        print(f" ✓ Saved: {filename}")
        return output_path
    finally:
        # Always release this figure, including when plotting raises.
        plt.close(fig)
def _coerce_harvest_date(value):
    """Return value as a pd.Timestamp, or None if it is missing or cannot be parsed."""
    try:
        if pd.isna(value):
            return None
        return pd.Timestamp(value)
    except (ValueError, TypeError):
        print(f" ⚠ Skipping unparseable harvest date: {value!r}")
        return None


def generate_rgb_grids(field_data, field_id, registered_harvest_dates, predicted_harvest_dates,
                       output_dir, tiff_dir, geojson_path):
    """
    Main orchestration function for RGB visualization.

    Creates one temporal grid (see create_temporal_rgb_grid) for:
    1. Each registered harvest date (if available)
    2. Each predicted harvest date (if available), grouped by model

    Missing (NaN/NaT/None) dates are skipped. Dates may be pd.Timestamp or
    anything pd.Timestamp() accepts; the date collections may be lists or
    other sized iterables such as a pandas Series.

    Args:
        field_data (pd.DataFrame): Field data with Date, CI columns
        field_id (str): Field identifier
        registered_harvest_dates (list): List of registered harvest dates (pd.Timestamp)
        predicted_harvest_dates (list): List of predicted harvest dates; each item is
            a date or a dict with 'harvest_date' and optional 'model_name'
        output_dir (Path): Output directory for plots (created if missing)
        tiff_dir (Path): Directory containing TIFF files
        geojson_path (Path): Path to pivot.geojson

    Returns:
        dict: Summary of generated plots with keys 'registered' and 'predicted',
        each a list of saved PNG paths (empty when prerequisites are missing).
    """
    results = {'registered': [], 'predicted': []}
    if rasterio is None:
        print(" ⚠ Rasterio not available - skipping RGB visualization")
        return results

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    tiff_dir = Path(tiff_dir)
    geojson_path = Path(geojson_path)
    if not tiff_dir.exists():
        print(f" ✗ TIFF directory not found: {tiff_dir}")
        return results
    if not geojson_path.exists():
        print(f" ✗ GeoJSON not found: {geojson_path}")
        return results

    print(f" Loading field boundary for {field_id}...")
    feature, field_boundary = load_field_boundaries(geojson_path, field_id)
    if field_boundary is None:
        print(f" ✗ Could not load field boundary for {field_id}")
        return results

    # len() instead of truthiness: a pandas Series / numpy array has no
    # unambiguous truth value.
    n_registered = len(registered_harvest_dates) if registered_harvest_dates is not None else 0
    if n_registered > 0:
        print(f" Processing {n_registered} registered harvest dates...")
        for i, raw_date in enumerate(registered_harvest_dates):
            harvest_date = _coerce_harvest_date(raw_date)
            if harvest_date is None:
                continue
            print(f" [{i+1}/{n_registered}] {harvest_date.strftime('%Y-%m-%d')}")
            output_path = create_temporal_rgb_grid(
                harvest_date, field_data, field_id, tiff_dir, field_boundary,
                title='Registered Harvest Validation',
                output_dir=output_dir,
                harvest_type='registered',
                model_name=None,
                harvest_index=i
            )
            if output_path:
                results['registered'].append(output_path)

    n_predicted = len(predicted_harvest_dates) if predicted_harvest_dates is not None else 0
    if n_predicted > 0:
        print(f" Processing {n_predicted} predicted harvest dates...")
        # Group by model so each model gets its own harvest index (used in file names).
        harvest_by_model = {}
        for harvest_info in predicted_harvest_dates:
            # Items are either dicts or bare dates.
            if isinstance(harvest_info, dict):
                raw_date = harvest_info.get('harvest_date')
                model_name = harvest_info.get('model_name', 'predicted')
            else:
                raw_date = harvest_info
                model_name = 'predicted'
            harvest_by_model.setdefault(model_name, []).append(raw_date)

        overall_index = 1  # progress counter across all models
        for model_name, harvest_dates in harvest_by_model.items():
            for model_harvest_idx, raw_date in enumerate(harvest_dates):
                harvest_date = _coerce_harvest_date(raw_date)
                if harvest_date is None:
                    continue
                print(f" [{overall_index}/{n_predicted}] {harvest_date.strftime('%Y-%m-%d')} ({model_name})")
                output_path = create_temporal_rgb_grid(
                    harvest_date, field_data, field_id, tiff_dir, field_boundary,
                    title=f'Predicted Harvest Validation ({model_name})',
                    output_dir=output_dir,
                    harvest_type='predicted',
                    model_name=model_name,
                    harvest_index=model_harvest_idx
                )
                if output_path:
                    results['predicted'].append(output_path)
                overall_index += 1
    return results
if __name__ == '__main__':
    # Running the module directly only prints a short usage reminder.
    usage_lines = (
        "RGB Visualization Tool",
        "This module is intended to be imported and called from compare_307_models_production.py",
        "\nExample:",
        " from rgb_visualization import generate_rgb_grids",
        " generate_rgb_grids(field_data, field_id, registered_dates, predicted_dates, output_dir, tiff_dir, geojson_path)",
    )
    for usage_line in usage_lines:
        print(usage_line)