#!/usr/bin/env python
"""
RGB Visualization Tool for Harvest Date Validation

Creates 5x3 temporal grids (15 panels at 7-day steps, T-56d ... T+42d) showing
satellite imagery around registered and predicted harvest dates. Extracts RGB
bands from multi-band Planet imagery tiles and clips them to field boundaries
from GeoJSON.

Functions:
- load_field_boundaries(): Load a field geometry from GeoJSON
- find_overlapping_tiles(): Find the tiles of the available date closest to a
  target date that overlap the field
- load_and_clip_tiff_rgb(): Load TIFF, extract RGB, clip to field boundary
- load_and_composite_tiles_rgb(): Clip several tiles and mosaic them into one image
- create_temporal_rgb_grid(): Create the 5x3 temporal grid for one harvest date
- generate_rgb_grids(): Main orchestration function

Usage:
    from rgb_visualization import generate_rgb_grids
    generate_rgb_grids(field_data, field_id, registered_harvest_dates,
                       predicted_harvest_dates, output_dir, tiff_dir, geojson_path)
"""

import json
import numpy as np
import pandas as pd
from pathlib import Path
from datetime import datetime, timedelta
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend to avoid display hangs
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import Normalize
import warnings
# NOTE(review): this silences *all* warnings process-wide as an import side effect.
warnings.filterwarnings('ignore')

try:
    import rasterio
    from rasterio.mask import mask
    import shapely.geometry as shgeom
except ImportError:
    print("Warning: rasterio not available. RGB visualization will be skipped.")
    rasterio = None


# Panel layout used by create_temporal_rgb_grid: (label, days relative to the
# harvest date), filled row-major into a 3-row x 5-column grid. Labels and
# offsets live together so the dates that get loaded always match the panels
# that get drawn.
_GRID_ROWS, _GRID_COLS = 3, 5
_GRID_PANELS = [
    ('T-56d', -56), ('T-49d', -49), ('T-42d', -42), ('T-35d', -35), ('T-28d', -28),
    ('T-21d', -21), ('T-14d', -14), ('T-7d', -7), ('T~0d', 0), ('T+7d', 7),
    ('T+14d', 14), ('T+21d', 21), ('T+28d', 28), ('T+35d', 35), ('T+42d', 42),
]
_HARVEST_PANEL_LABEL = 'T~0d'  # This panel gets a red outline.
_TILE_SEARCH_WINDOW_DAYS = 60  # Max days between a panel's target date and its imagery.


def load_field_boundaries(geojson_path, field_id):
    """
    Load field boundary from GeoJSON file.

    A feature matches when its ``field`` or ``sub_field`` property equals
    ``field_id`` (both compared as strings). For a MultiPolygon only the first
    polygon is used. Interior rings (holes) are kept, so clipping excludes them.

    Args:
        geojson_path (Path): Path to pivot.geojson
        field_id (str): Field identifier (e.g., "13973")

    Returns:
        tuple: (feature, polygon)
            dict: Matching GeoJSON feature, or None if not found
            shapely.geometry.Polygon: Field boundary polygon, or None
        Read/parse errors are printed and reported as (None, None).
    """
    try:
        # RFC 7946 requires GeoJSON to be UTF-8; don't depend on the locale default.
        with open(geojson_path, encoding='utf-8') as f:
            geojson_data = json.load(f)

        target = str(field_id)
        for feature in geojson_data.get('features', []):
            props = feature.get('properties') or {}
            # Try matching on 'field' or 'sub_field'
            if target not in (str(props.get('field', '')), str(props.get('sub_field', ''))):
                continue

            geometry = feature.get('geometry')
            if not geometry:
                continue
            geom_type = geometry.get('type', '')
            coordinates = geometry.get('coordinates', [])
            if not coordinates:
                continue

            if geom_type == 'MultiPolygon':
                # coordinates = [polygon, ...]; each polygon = [exterior, *holes].
                rings = coordinates[0]
            elif geom_type == 'Polygon':
                # coordinates = [exterior, *holes]
                rings = coordinates
            else:
                continue
            if not rings:
                continue

            return feature, shgeom.Polygon(rings[0], rings[1:])

        print(f" ⚠ Field {field_id} not found in GeoJSON")
        return None, None

    except Exception as e:
        print(f" ✗ Error loading GeoJSON: {e}")
        return None, None


def find_overlapping_tiles(target_date, tiff_dir, field_boundary, days_window=60):
    """
    Find all tile files for target_date (or closest date) that overlap with field_boundary.

    Tile files are organized in subdirectories by date: 5x5/YYYY-MM-DD_HH/*.tif.
    Several directories sharing a calendar date (different ``_HH`` suffixes) are
    pooled together. Files under 12 MB are treated as empty and ignored.
    Between two equally close dates, the earlier one is chosen.

    Args:
        target_date (pd.Timestamp): Target date to find tiles near
        tiff_dir (Path): Directory containing 5x5 date subdirectories
        field_boundary (shapely.Polygon): Field boundary for overlap detection
        days_window (int): Max days to search before/after target

    Returns:
        tuple: (list of tile paths, actual_date, days_diff)
            list: Tile paths that overlap the field. If none overlap (or overlap
                cannot be checked), all tiles of the chosen date are returned.
            pd.Timestamp: Actual date of the tiles found (None if nothing found)
            int: Days from target to actual date (None if nothing found)
    """
    target_date = pd.Timestamp(target_date)
    tiff_dir = Path(tiff_dir)

    if not tiff_dir.exists():
        return [], None, None

    min_size_mb = 12.0  # Empty files are ~11.56 MB

    # {date: [tile file paths]}; accumulate so same-date directories don't overwrite each other.
    available_dates = {}
    for date_dir in sorted(tiff_dir.iterdir()):
        if not date_dir.is_dir():
            continue
        try:
            # Directory name is YYYY-MM-DD or YYYY-MM-DD_HH; keep only the date part.
            tile_date = pd.Timestamp(date_dir.name.split('_')[0])
            if pd.isna(tile_date):  # e.g. an empty date part parses to NaT
                continue
            if abs((tile_date - target_date).days) > days_window:
                continue
            tile_files = [
                tile_file
                for tile_file in sorted(date_dir.glob('*.tif'))
                # Skip obviously empty files
                if tile_file.stat().st_size / (1024 * 1024) >= min_size_mb
            ]
        except (ValueError, TypeError, OverflowError, OSError):
            # Unparseable directory name, incomparable dates, or unreadable directory.
            continue
        if tile_files:
            available_dates.setdefault(tile_date, []).extend(tile_files)

    if not available_dates:
        return [], None, None

    # Closest date wins; ties go to the earlier date so the choice is deterministic.
    closest_date = min(available_dates, key=lambda d: (abs((d - target_date).days), d))
    days_diff = (closest_date - target_date).days
    tiles = available_dates[closest_date]

    if rasterio is None or field_boundary is None:
        # Overlap can't be checked: use all tiles (conservative approach).
        return tiles, closest_date, days_diff

    overlapping_tiles = []
    for tile_path in tiles:
        try:
            with rasterio.open(tile_path) as src:
                tile_geom = shgeom.box(*src.bounds)  # (left, bottom, right, top)
            # NOTE(review): assumes the GeoJSON and the tiles share a CRS; nothing is reprojected.
            if tile_geom.intersects(field_boundary):
                overlapping_tiles.append(tile_path)
        except Exception:
            # Best-effort: an unreadable tile or invalid geometry just isn't counted as overlapping.
            continue

    if not overlapping_tiles:
        # No overlapping tiles found, return all tiles for the closest date
        return tiles, closest_date, days_diff
    return overlapping_tiles, closest_date, days_diff


def _read_clipped_rgb(tiff_path, field_boundary, rgb_bands=(1, 2, 3)):
    """Clip three bands of one raster to the field; return (rgb, affine transform) or (None, None).

    The rgb array is float32 (height, width, 3) scaled to 0-1. The scale is chosen
    from the data maximum: /255 if <= 255, /65535 if <= 65535, otherwise /max.
    NaN pixels are replaced by 0. The transform is the one returned by
    ``rasterio.mask.mask`` for the cropped window.
    """
    if rasterio is None or field_boundary is None:
        return None, None
    try:
        bands = [int(band) for band in rgb_bands]
        if len(bands) != 3:
            return None, None

        with rasterio.open(tiff_path) as src:
            if src.count < max(bands):
                return None, None
            # crop=True trims the window to the field; raises ValueError when the
            # field does not overlap this raster.
            masked_data, out_transform = mask(
                src, [shgeom.mapping(field_boundary)], crop=True, indexes=bands
            )

        # (band, row, col) -> (row, col, band)
        rgb = np.moveaxis(masked_data, 0, -1).astype(np.float32)

        max_val = np.nanmax(rgb)
        if max_val > 0:
            if max_val <= 255:          # looks 8-bit
                rgb = rgb / 255.0
            elif max_val <= 65535:      # looks 16-bit
                rgb = rgb / 65535.0
            else:
                rgb = rgb / max_val
        rgb = np.clip(rgb, 0, 1)

        if np.all(np.isnan(rgb)):
            return None, None
        # Replace any remaining NaN with 0 (cloud/invalid pixels)
        return np.nan_to_num(rgb, nan=0.0), out_transform
    except Exception:
        return None, None


def load_and_clip_tiff_rgb(tiff_path, field_boundary, rgb_bands=(1, 2, 3)):
    """
    Load TIFF and extract RGB bands clipped to field boundary.

    For merged_final_tif files (cloud-masked and filtered):
    - Band 1: Red
    - Band 2: Green
    - Band 3: Blue
    - Band 4: NIR
    - Band 5: CI

    Args:
        tiff_path (Path): Path to TIFF file
        field_boundary (shapely.Polygon): Field boundary for clipping
        rgb_bands (tuple): Band indices for R, G, B (1-indexed, defaults to 1,2,3
            for merged_final_tif). Exactly three indices are required.

    Returns:
        np.ndarray: RGB data (height, width, 3) with values 0-1, or None if the
        file can't be read, has too few bands, doesn't overlap the field, or is all NaN.
    """
    rgb, _ = _read_clipped_rgb(tiff_path, field_boundary, rgb_bands)
    return rgb


def _mosaic_max(clipped):
    """Max-composite a list of (rgb, transform) clips into one array on a shared pixel grid."""
    if len(clipped) == 1:
        return clipped[0][0]

    ref = clipped[0][1]
    px_w, px_h = ref.a, ref.e
    # Offsets can only be computed for north-up, unrotated grids with a common pixel size.
    aligned = px_w > 0 and px_h < 0 and all(
        np.isclose(t.b, 0) and np.isclose(t.d, 0)
        and np.isclose(t.a, px_w) and np.isclose(t.e, px_h)
        for _, t in clipped
    )
    if not aligned:
        arrays = [rgb for rgb, _ in clipped]
        if len({a.shape for a in arrays}) == 1:
            return np.max(np.stack(arrays, axis=0), axis=0)
        # Can't place them consistently: keep the clip covering the most pixels.
        return max(arrays, key=lambda a: np.count_nonzero(a.any(axis=-1)))

    left = min(t.c for _, t in clipped)
    top = max(t.f for _, t in clipped)
    placements = [
        (int(round((top - t.f) / -px_h)), int(round((t.c - left) / px_w)), rgb)
        for rgb, t in clipped
    ]
    height = max(row + rgb.shape[0] for row, _, rgb in placements)
    width = max(col + rgb.shape[1] for _, col, rgb in placements)

    # Pixels outside the field are 0 after clipping, so 0 is a neutral fill for max().
    canvas = np.zeros((height, width, 3), dtype=np.float32)
    for row, col, rgb in placements:
        window = canvas[row:row + rgb.shape[0], col:col + rgb.shape[1]]
        np.maximum(window, rgb, out=window)
    return canvas


def load_and_composite_tiles_rgb(tile_paths, field_boundary):
    """
    Load RGB from multiple overlapping tiles and composite them into a single image.

    Each tile is clipped to the field. The clips are then placed on a common pixel
    grid using their geotransforms and combined with a per-pixel maximum. Finally,
    a 2nd-98th percentile contrast stretch over non-zero values is applied.

    Args:
        tile_paths (list[Path]): List of tile file paths
        field_boundary (shapely.Polygon): Field boundary for clipping

    Returns:
        np.ndarray: Composited RGB data (height, width, 3) with values 0-1,
        or None if no tile could be clipped or an error occurs
    """
    if rasterio is None or field_boundary is None or not tile_paths:
        return None

    try:
        clipped = []
        for tile_path in tile_paths:
            rgb, transform = _read_clipped_rgb(tile_path, field_boundary)
            if rgb is not None:
                clipped.append((rgb, transform))
        if not clipped:
            return None

        composited = _mosaic_max(clipped).astype(np.float32)

        # Stretch contrast so dim images are visible; percentiles resist outliers.
        valid_data = composited[composited > 0]
        if valid_data.size > 0:
            data_min = np.percentile(valid_data, 2)
            data_max = np.percentile(valid_data, 98)
            if data_max > data_min:
                composited = np.clip((composited - data_min) / (data_max - data_min), 0, 1)

        return composited.astype(np.float32)
    except Exception:
        return None


def create_temporal_rgb_grid(harvest_date, field_data, field_id, tiff_dir, field_boundary,
                             title, output_dir, harvest_type='registered', model_name=None,
                             harvest_index=None):
    """
    Create 5x3 temporal grid around harvest date (15 images, 7-day intervals).

    Layout:
        Row 1: T-56d, T-49d, T-42d, T-35d, T-28d  (pre-harvest)
        Row 2: T-21d, T-14d, T-7d,  T~0d,  T+7d   (near harvest; T~0d outlined red)
        Row 3: T+14d, T+21d, T+28d, T+35d, T+42d  (post-harvest progression)

    Each panel shows the imagery date closest to its target (within 60 days),
    titled with its offset from the target and the actual imagery date.

    Args:
        harvest_date (pd.Timestamp): Target harvest date
        field_data (pd.DataFrame): Field data with Date column (not used here; kept for API compatibility)
        field_id (str): Field identifier
        tiff_dir (Path): Directory with TIFF files
        field_boundary (shapely.Polygon): Field boundary
        title (str): Plot title
        output_dir (Path): Output directory
        harvest_type (str): 'registered' or 'predicted'
        model_name (str): Model name for predicted harvests (e.g., 'Original', 'Long-Season');
            'predicted' is used in the filename when None
        harvest_index (int): Index of harvest within same model (for multiple harvests)

    Returns:
        Path: Path to saved PNG or None if saving failed
    """
    harvest_date = pd.Timestamp(harvest_date)

    # Find and composite imagery for each panel's target date.
    rgb_images, days_offsets, actual_dates = [], [], []
    for _, offset in _GRID_PANELS:
        target = harvest_date + timedelta(days=offset)
        tile_paths, actual_date, days_diff = find_overlapping_tiles(
            target, tiff_dir, field_boundary, days_window=_TILE_SEARCH_WINDOW_DAYS
        )
        if not tile_paths or actual_date is None:
            rgb_images.append(None)
            days_offsets.append(None)
            actual_dates.append(None)
            print(f" ⚠ No tiles found within {_TILE_SEARCH_WINDOW_DAYS} days of {target.strftime('%Y-%m-%d')} with sufficient data")
            continue

        rgb = load_and_composite_tiles_rgb(tile_paths, field_boundary)
        rgb_images.append(rgb)
        days_offsets.append(days_diff)
        actual_dates.append(actual_date)
        if rgb is not None:
            print(f" ✓ Loaded {len(tile_paths)} tile(s) for {actual_date.strftime('%Y-%m-%d')} ({days_diff:+d}d from target)")
        else:
            print(f" ⚠ Loaded {len(tile_paths)} tile(s) but RGB data is None")

    fig, axes = plt.subplots(_GRID_ROWS, _GRID_COLS, figsize=(25, 15))
    try:
        fig.suptitle(f'{title}\nField {field_id} - {harvest_type.upper()} Harvest: {harvest_date.strftime("%Y-%m-%d")}',
                     fontsize=16, fontweight='bold')

        for idx, (label, _) in enumerate(_GRID_PANELS):
            row, col = divmod(idx, _GRID_COLS)
            ax = axes[row, col]
            rgb_data = rgb_images[idx]

            if rgb_data is not None:
                data_min, data_max = np.nanmin(rgb_data), np.nanmax(rgb_data)
                print(f" DEBUG: {label} RGB range: {data_min:.4f} - {data_max:.4f}, shape: {rgb_data.shape}")
                # Explicit vmin/vmax: data is already normalized to 0-1.
                ax.imshow(rgb_data, vmin=0, vmax=1)
                offset_str = f"{days_offsets[idx]:+d}d" if days_offsets[idx] is not None else "?"
                date_str = actual_dates[idx].strftime('%Y-%m-%d') if actual_dates[idx] is not None else "No Date"
                ax.set_title(f'{label}\n{offset_str}\n{date_str}', fontsize=10, fontweight='bold')
            else:
                ax.text(0.5, 0.5, 'No Data', ha='center', va='center', fontsize=12, color='gray')
                ax.set_title(label, fontsize=10)

            # Red box around the harvest-date panel, with or without data.
            if label == _HARVEST_PANEL_LABEL:
                for spine in ax.spines.values():
                    spine.set_edgecolor('red')
                    spine.set_linewidth(4)

            ax.set_xticks([])
            ax.set_yticks([])

        fig.tight_layout()

        # Filename: field_ID_harvestdate_{registered|model}[_harvestN]_rgb.png
        harvest_date_str = harvest_date.strftime('%Y%m%d')
        if harvest_type == 'registered':
            filename = f'field_{field_id}_{harvest_date_str}_registered_harvest_rgb.png'
        else:
            model_label = model_name if model_name is not None else 'predicted'
            if harvest_index is not None and harvest_index > 0:
                filename = f'field_{field_id}_{harvest_date_str}_{model_label}_harvest{harvest_index}_rgb.png'
            else:
                filename = f'field_{field_id}_{harvest_date_str}_{model_label}_harvest_rgb.png'

        output_path = Path(output_dir) / filename
        try:
            fig.savefig(output_path, dpi=100, format='png')
        except Exception as e:
            print(f" ✗ Error saving PNG: {e}")
            return None
        print(f" ✓ Saved: {filename}")
        return output_path
    finally:
        # Always release the figure, even if rendering raised before saving.
        plt.close(fig)


def _as_list(values):
    """Return ``values`` as a list (None -> []); works for lists, Series and arrays alike."""
    return [] if values is None else list(values)


def generate_rgb_grids(field_data, field_id, registered_harvest_dates, predicted_harvest_dates,
                       output_dir, tiff_dir, geojson_path):
    """
    Main orchestration function for RGB visualization.

    Creates 5x3 temporal grids (15 panels) for:
    1. Registered harvest dates (if available)
    2. Predicted harvest dates (if available), numbered per model

    Args:
        field_data (pd.DataFrame): Field data with Date, CI columns
        field_id (str): Field identifier
        registered_harvest_dates (list): Registered harvest dates (anything pd.Timestamp accepts;
            list, Series or array). NaN/None entries are skipped.
        predicted_harvest_dates (list): Predicted harvest dates, each either a date or a dict
            with 'harvest_date' and optional 'model_name' keys
        output_dir (Path): Output directory for plots (created if missing)
        tiff_dir (Path): Directory containing TIFF files
        geojson_path (Path): Path to pivot.geojson

    Returns:
        dict: Summary of generated plots with keys 'registered' and 'predicted'
            (lists of saved PNG paths)
    """
    if rasterio is None:
        print(" ⚠ Rasterio not available - skipping RGB visualization")
        return {'registered': [], 'predicted': []}

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    tiff_dir = Path(tiff_dir)
    geojson_path = Path(geojson_path)

    if not tiff_dir.exists():
        print(f" ✗ TIFF directory not found: {tiff_dir}")
        return {'registered': [], 'predicted': []}

    if not geojson_path.exists():
        print(f" ✗ GeoJSON not found: {geojson_path}")
        return {'registered': [], 'predicted': []}

    # Load field boundary
    print(f" Loading field boundary for {field_id}...")
    _feature, field_boundary = load_field_boundaries(geojson_path, field_id)
    if field_boundary is None:
        print(f" ✗ Could not load field boundary for {field_id}")
        return {'registered': [], 'predicted': []}

    results = {'registered': [], 'predicted': []}

    # Process registered harvest dates
    registered = _as_list(registered_harvest_dates)
    if registered:
        print(f" Processing {len(registered)} registered harvest dates...")
        for i, harvest_date in enumerate(registered):
            if pd.isna(harvest_date):
                continue
            harvest_date = pd.Timestamp(harvest_date)
            print(f" [{i+1}/{len(registered)}] {harvest_date.strftime('%Y-%m-%d')}")
            output_path = create_temporal_rgb_grid(
                harvest_date, field_data, field_id, tiff_dir, field_boundary,
                title='Registered Harvest Validation',
                output_dir=output_dir,
                harvest_type='registered',
                model_name=None,
                harvest_index=i
            )
            if output_path:
                results['registered'].append(output_path)

    # Process predicted harvest dates - grouped by model
    predicted = _as_list(predicted_harvest_dates)
    if predicted:
        print(f" Processing {len(predicted)} predicted harvest dates...")

        # Group by model so each model's harvests get their own 0-based index.
        harvest_by_model = {}
        for harvest_info in predicted:
            # Handle both dict and Timestamp formats
            if isinstance(harvest_info, dict):
                harvest_date = harvest_info.get('harvest_date')
                model_name = harvest_info.get('model_name', 'predicted')
            else:
                harvest_date = harvest_info
                model_name = 'predicted'
            harvest_by_model.setdefault(model_name, []).append(harvest_date)

        overall_index = 1
        for model_name, harvest_dates in harvest_by_model.items():
            for model_harvest_idx, harvest_date in enumerate(harvest_dates):
                if pd.isna(harvest_date):
                    continue
                harvest_date = pd.Timestamp(harvest_date)
                print(f" [{overall_index}/{len(predicted)}] {harvest_date.strftime('%Y-%m-%d')} ({model_name})")
                output_path = create_temporal_rgb_grid(
                    harvest_date, field_data, field_id, tiff_dir, field_boundary,
                    title=f'Predicted Harvest Validation ({model_name})',
                    output_dir=output_dir,
                    harvest_type='predicted',
                    model_name=model_name,
                    harvest_index=model_harvest_idx
                )
                if output_path:
                    results['predicted'].append(output_path)
                overall_index += 1

    return results


if __name__ == '__main__':
    # Example usage
    print("RGB Visualization Tool")
    print("This module is intended to be imported and called from compare_307_models_production.py")
    print("\nExample:")
    print(" from rgb_visualization import generate_rgb_grids")
    print(" generate_rgb_grids(field_data, field_id, registered_dates, predicted_dates, output_dir, tiff_dir, geojson_path)")