import os
import sys
import argparse
from pathlib import Path

import numpy as np
from osgeo import gdal
from osgeo import osr
import rasterio as rio
from rasterio.enums import Resampling
from rasterio.warp import reproject

# Attempt to import OmniCloudMask and set a flag so the script can report a
# helpful message instead of failing at import time when it is missing.
try:
    from omnicloudmask import predict_from_array, load_multiband
    HAS_OCM = True
except ImportError:
    HAS_OCM = False


def calculate_utm_zone_and_hemisphere(longitude, latitude):
    """
    Calculate the UTM zone and hemisphere for a WGS84 longitude/latitude.

    Args:
        longitude: Longitude in decimal degrees (expected range [-180, 180]).
        latitude: Latitude in decimal degrees.

    Returns:
        tuple: ``(utm_zone, is_southern)``. ``utm_zone`` is clamped to 1..60.
        ``is_southern`` is True for latitudes below the equator.

    The standard 6-degree zoning is always used; the Norway/Svalbard
    special cases are not applied.
    """
    utm_zone = int((longitude + 180) // 6) + 1
    # longitude == 180 would otherwise give zone 61, which does not exist.
    utm_zone = min(max(utm_zone, 1), 60)
    is_southern = latitude < 0
    return utm_zone, is_southern


def _center_lon_lat(input_ds, source_srs):
    """Return the WGS84 (longitude, latitude) of the dataset's centre point."""
    gt = input_ds.GetGeoTransform()
    half_w = input_ds.RasterXSize / 2
    half_h = input_ds.RasterYSize / 2
    # Full affine transform, including the (usually zero) rotation terms.
    center_x = gt[0] + half_w * gt[1] + half_h * gt[2]
    center_y = gt[3] + half_w * gt[4] + half_h * gt[5]

    wgs84 = osr.SpatialReference()
    wgs84.ImportFromEPSG(4326)
    # GDAL >= 3 follows the authority axis order (lat, lon for EPSG:4326)
    # unless told otherwise. Force x=easting/longitude ordering on both sides
    # so TransformPoint's inputs and outputs are unambiguous.
    if hasattr(osr, "OAMS_TRADITIONAL_GIS_ORDER"):
        source_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        wgs84.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)

    to_wgs84 = osr.CoordinateTransformation(source_srs, wgs84)
    point = to_wgs84.TransformPoint(center_x, center_y)
    return point[0], point[1]


def reproject_to_projected_crs(input_path, output_path):
    """
    Reproject a raster to the WGS84 / UTM zone containing its centre.

    The source CRS may be geographic or projected. The raster centre is
    converted to WGS84 lon/lat before the zone is chosen.

    Args:
        input_path: Path of the raster to reproject.
        output_path: Destination GeoTIFF path.

    Returns:
        The ``output_path`` argument, unchanged.

    Raises:
        ValueError: If the input cannot be opened, has no CRS, or the warp fails.
    """
    input_ds = gdal.Open(str(input_path))
    if not input_ds:
        raise ValueError(f"Failed to open input raster: {input_path}")

    try:
        source_wkt = input_ds.GetProjection()
        if not source_wkt:
            raise ValueError(f"Input raster has no coordinate reference system: {input_path}")
        source_srs = osr.SpatialReference()
        source_srs.ImportFromWkt(source_wkt)

        # The zone must be computed from degrees, not from source-CRS units
        # (e.g. metres when the input is already projected).
        lon, lat = _center_lon_lat(input_ds, source_srs)
        utm_zone, is_southern = calculate_utm_zone_and_hemisphere(lon, lat)

        # WGS84 / UTM: EPSG 326xx for the north, 327xx for the south.
        epsg_code = (32700 if is_southern else 32600) + utm_zone

        warp_options = gdal.WarpOptions(dstSRS=f"EPSG:{epsg_code}", format="GTiff")
        result = gdal.Warp(str(output_path), input_ds, options=warp_options)
        if result is None:
            raise ValueError(f"Failed to reproject raster to {output_path}")
        result = None  # Dereference so GDAL flushes and closes the output file.
    finally:
        input_ds = None  # Close the source dataset.

    print(f"Reprojected raster saved to: {output_path}")
    return output_path


def resample_image(input_path, output_path, resolution=(10, 10), resample_alg="bilinear"):
    """
    Reproject a raster to UTM, then resample it to ``resolution`` with gdal.Translate.

    An intermediate ``<output stem>_reprojected.tif`` is written next to
    ``output_path``.

    Args:
        input_path: Source raster path.
        output_path: Destination GeoTIFF path.
        resolution: ``(xRes, yRes)`` in target-CRS units (metres for UTM).
        resample_alg: GDAL resampling algorithm name passed to gdal.Translate.

    Returns:
        The ``output_path`` argument, unchanged.

    Raises:
        ValueError: If reprojection or resampling fails.
    """
    print(f"Resampling {input_path} to {resolution}m resolution -> {output_path}")

    # Reproject the input image to a projected CRS first so the resolution is in metres.
    reprojected_path = str(Path(output_path).with_name(f"{Path(output_path).stem}_reprojected.tif"))
    reproject_to_projected_crs(input_path, reprojected_path)

    input_ds = gdal.Open(reprojected_path)
    if not input_ds:
        raise ValueError(f"Failed to open reprojected raster: {reprojected_path}")

    try:
        result = gdal.Translate(
            str(output_path),
            input_ds,
            xRes=resolution[0],
            yRes=resolution[1],
            resampleAlg=resample_alg,
        )
        if result is None:
            raise ValueError(f"Failed to resample image to {output_path}")
        result = None  # Dereference so GDAL flushes and closes the output file.
    finally:
        input_ds = None  # Explicitly dereference the GDAL dataset.

    print(f"Successfully resampled image saved to: {output_path}")
    return output_path


def run_ocm_on_image(image_path_10m, ocm_output_dir, save_mask=True):
    """
    Run OmniCloudMask on a 10m image and optionally save the class mask.

    Adapted from process_with_ocm in the notebook.

    Args:
        image_path_10m: Path of the 10m (4-band) image.
        ocm_output_dir: Directory for the mask (created if missing).
        save_mask: Whether to write the mask GeoTIFF.

    Returns:
        ``(mask_path_str, profile)`` on success. ``mask_path_str`` is returned
        even when ``save_mask`` is False, in which case no file is written.
        ``(None, None)`` if OCM is unavailable, the image is smaller than
        50x50 pixels, or any error occurs (the error is printed).
    """
    if not HAS_OCM:
        print("OmniCloudMask not available. Please install with: pip install omnicloudmask")
        return None, None

    image_path_10m = Path(image_path_10m)
    ocm_output_dir = Path(ocm_output_dir)
    ocm_output_dir.mkdir(exist_ok=True, parents=True)
    mask_10m_path = ocm_output_dir / f"{image_path_10m.stem}_ocm_mask_10m.tif"

    try:
        with rio.open(image_path_10m) as src:
            width, height = src.width, src.height

        # Check if the image is too small for OmniCloudMask.
        if width < 50 or height < 50:
            print(f"Warning: Image {image_path_10m} is too small for OmniCloudMask (width: {width}, height: {height}). Skipping.")
            return None, None

        # PlanetScope 4-band images are typically [B,G,R,NIR].
        # OCM expects [R,G,NIR] for its default model.
        # Band numbers for load_multiband are 1-based:
        # with B(1),G(2),R(3),NIR(4) that is R=3, G=2, NIR=4.
        band_order = [3, 2, 4]

        print(f"Loading 10m image for OCM: {image_path_10m}")
        # The image is already at 10m; resample_res=10 is passed explicitly to
        # pin the resolution OCM runs at.
        rgn_data, profile = load_multiband(
            input_path=str(image_path_10m),
            resample_res=10,
            band_order=band_order,
        )

        print("Applying OmniCloudMask...")
        prediction = np.asarray(predict_from_array(rgn_data))
        # NOTE(review): predict_from_array may return a (1, H, W) array.
        # Collapse it to (H, W) so it can be written as band 1.
        if prediction.ndim == 3 and prediction.shape[0] == 1:
            prediction = prediction[0]
        if prediction.ndim != 2:
            raise ValueError(f"Unexpected OCM prediction shape: {prediction.shape}")
        prediction = prediction.astype("uint8")

        if save_mask:
            # 0 is a real class (clear). Inheriting the source nodata (often
            # 0) would flag clear pixels as nodata, and an out-of-range value
            # such as -9999 is invalid for uint8, so drop it.
            profile.update(count=1, dtype="uint8", nodata=None)
            with rio.open(mask_10m_path, "w", **profile) as dst:
                dst.write(prediction, 1)
            print(f"Saved 10m OCM mask to: {mask_10m_path}")

        # Summary of the OCM classes: 0=clear, 1=thick, 2=thin, 3=shadow.
        n_total = prediction.size
        n_clear = np.count_nonzero(prediction == 0)
        n_thick = np.count_nonzero(prediction == 1)
        n_thin = np.count_nonzero(prediction == 2)
        n_shadow = np.count_nonzero(prediction == 3)
        print(f"  OCM: Clear: {100*n_clear/n_total:.1f}%, Thick: {100*n_thick/n_total:.1f}%, Thin: {100*n_thin/n_total:.1f}%, Shadow: {100*n_shadow/n_total:.1f}%")

        return str(mask_10m_path), profile

    except Exception as e:
        # Best-effort boundary: the caller checks for a None path.
        print(f"Error processing 10m image with OmniCloudMask: {str(e)}")
        return None, None


def upsample_mask_to_3m(mask_10m_path, target_3m_image_path, output_3m_mask_path):
    """
    Upsample a 10m OCM mask onto the exact grid and CRS of the 3m image.

    Nearest-neighbour resampling keeps the class values intact. Target
    pixels not covered by the 10m mask stay 0 (clear).

    Adapted from upsample_mask_to_highres in the notebook.

    Args:
        mask_10m_path: Path of the 10m single-band OCM mask.
        target_3m_image_path: 3m image whose grid/CRS/profile the mask adopts.
        output_3m_mask_path: Destination GeoTIFF for the upsampled mask.

    Returns:
        ``str(output_3m_mask_path)``.
    """
    print(f"Upsampling 10m mask {mask_10m_path} to 3m, referencing {target_3m_image_path}")

    with rio.open(mask_10m_path) as src_mask, rio.open(target_3m_image_path) as src_img_3m:
        mask_data_10m = src_mask.read(1)
        upsampled_mask_3m_data = np.zeros((src_img_3m.height, src_img_3m.width), dtype=mask_data_10m.dtype)

        reproject(
            source=mask_data_10m,
            destination=upsampled_mask_3m_data,
            src_transform=src_mask.transform,
            src_crs=src_mask.crs,
            dst_transform=src_img_3m.transform,
            dst_crs=src_img_3m.crs,
            resampling=Resampling.nearest,
        )
        profile_3m_mask = src_img_3m.profile.copy()

    # The image's nodata may be out of range for the mask dtype, and 0 is the
    # "clear" class, so the mask carries no nodata value. A multi-band
    # photometric setting (e.g. RGB) is invalid for a single band.
    profile_3m_mask.update(count=1, dtype=upsampled_mask_3m_data.dtype.name, nodata=None)
    profile_3m_mask.pop("photometric", None)

    with rio.open(output_3m_mask_path, "w", **profile_3m_mask) as dst:
        dst.write(upsampled_mask_3m_data, 1)

    print(f"Upsampled 3m OCM mask saved to: {output_3m_mask_path}")
    return str(output_3m_mask_path)


def apply_3m_mask_to_3m_image(image_3m_path, mask_3m_path, final_masked_output_path):
    """
    Set every band of the 3m image to nodata wherever the OCM mask is non-zero.

    OCM classes: 0=clear, 1=thick cloud, 2=thin cloud, 3=shadow. All non-clear
    classes are masked. The image's nodata is used, or 0 if it has none; the
    value is recorded in the output profile.

    Adapted from apply_upsampled_mask_to_highres in the notebook.

    Args:
        image_3m_path: Original 3m image.
        mask_3m_path: Single-band mask on the same grid as the image.
        final_masked_output_path: Destination GeoTIFF (parent dirs created).

    Returns:
        ``str(final_masked_output_path)`` on success, or None on any error
        (the error is printed).
    """
    print(f"Applying 3m mask {mask_3m_path} to 3m image {image_3m_path}")
    image_3m_path = Path(image_3m_path)
    mask_3m_path = Path(mask_3m_path)
    final_masked_output_path = Path(final_masked_output_path)
    final_masked_output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        with rio.open(image_3m_path) as src_img_3m, rio.open(mask_3m_path) as src_mask_3m:
            img_data_3m = src_img_3m.read()
            img_profile_3m = src_img_3m.profile.copy()
            mask_data_3m = src_mask_3m.read(1)

        if img_data_3m.shape[1:] != mask_data_3m.shape:
            raise ValueError(
                f"3m image shape {img_data_3m.shape[1:]} and 3m mask shape {mask_data_3m.shape} do not match."
            )

        # rasterio profiles always contain a 'nodata' key (None when unset),
        # so .get('nodata', 0) would yield None. Fall back to 0 explicitly
        # and record it so masked pixels are marked as nodata downstream.
        nodata_val = img_profile_3m.get("nodata")
        if nodata_val is None:
            nodata_val = 0
        # Keep the original image dtype for the output.
        img_profile_3m.update(dtype=img_data_3m.dtype, nodata=nodata_val)

        cloud_or_shadow = mask_data_3m > 0
        masked_img_data_3m = img_data_3m.copy()
        masked_img_data_3m[:, cloud_or_shadow] = nodata_val  # All bands at once.

        with rio.open(final_masked_output_path, "w", **img_profile_3m) as dst:
            dst.write(masked_img_data_3m)

        print(f"Final masked 3m image saved to: {final_masked_output_path}")
        return str(final_masked_output_path)

    except Exception as e:
        # Best-effort boundary: the caller checks for a None result.
        print(f"Error applying 3m mask to 3m image: {str(e)}")
        return None


def main():
    """
    CLI entry point for the pipeline.

    Steps: resample the 3m image to 10m, run OmniCloudMask on it, upsample
    the 10m mask back to 3m, and apply it to the original 3m image.

    Returns:
        int: 0 on success, 1 on failure (used as the process exit code).
    """
    parser = argparse.ArgumentParser(description="Process PlanetScope 3m imagery with OmniCloudMask.")
    parser.add_argument("input_3m_image", type=str, help="Path to the input merged 3m PlanetScope GeoTIFF image.")
    parser.add_argument("output_dir", type=str, help="Directory to save processed files (10m image, masks, final 3m masked image).")
    args = parser.parse_args()

    try:
        # resolve(strict=True) also verifies that both paths exist.
        input_3m_path = Path(args.input_3m_image).resolve(strict=True)
        # The output directory must already exist (created by the notebook).
        output_base_dir = Path(args.output_dir).resolve(strict=True)
    except FileNotFoundError as e:
        print(f"Error: Path resolution failed. Input image or output base directory may not exist or is not accessible: {e}")
        return 1
    except Exception as e:
        print(f"Error resolving paths: {e}")
        return 1

    intermediate_dir = output_base_dir / "intermediate_ocm_files"
    intermediate_dir.mkdir(parents=True, exist_ok=True)

    image_10m_path = intermediate_dir / f"{input_3m_path.stem}_10m.tif"
    # run_ocm_on_image writes the 10m mask into this subdirectory.
    ocm_mask_output_dir = intermediate_dir / "ocm_10m_mask_output"
    mask_3m_upsampled_path = intermediate_dir / f"{input_3m_path.stem}_ocm_mask_3m_upsampled.tif"
    final_masked_3m_path = output_base_dir / f"{input_3m_path.stem}_ocm_masked_3m.tif"

    print(f"--- Starting OCM processing for {input_3m_path.name} ---")
    print(f"Input 3m image (absolute): {input_3m_path}")
    print(f"Output base directory (absolute): {output_base_dir}")
    print(f"Intermediate 10m image path: {image_10m_path}")

    # 1. Resample 3m input to 10m for OCM.
    try:
        resample_image(input_3m_path, image_10m_path, resolution=(10, 10))
    except Exception as e:
        print(f"Failed to resample to 10m: {e}")
        return 1

    # 2. Run OCM on the 10m image.
    mask_10m_generated_path, _ = run_ocm_on_image(image_10m_path, ocm_mask_output_dir)
    if not mask_10m_generated_path:
        print("OCM processing failed. Exiting.")
        return 1

    # 3. Upsample the 10m OCM mask to 3m.
    try:
        upsample_mask_to_3m(mask_10m_generated_path, input_3m_path, mask_3m_upsampled_path)
    except Exception as e:
        print(f"Failed to upsample 10m OCM mask to 3m: {e}")
        return 1

    # 4. Apply the 3m upsampled mask to the original 3m image.
    # apply_3m_mask_to_3m_image reports its own errors and returns None, so
    # the return value must be checked to avoid a false success message.
    try:
        masked_result = apply_3m_mask_to_3m_image(input_3m_path, mask_3m_upsampled_path, final_masked_3m_path)
    except Exception as e:
        print(f"Failed to apply 3m mask to 3m image: {e}")
        return 1
    if masked_result is None:
        print("Failed to apply 3m mask to 3m image. Exiting.")
        return 1

    print(f"--- Successfully completed OCM processing for {input_3m_path.name} ---")
    print(f"Final 3m masked output: {final_masked_3m_path}")
    return 0


if __name__ == "__main__":
    if not HAS_OCM:
        print("OmniCloudMask library is not installed. Please install it to run this script.")
        print("You can typically install it using: pip install omnicloudmask")
        sys.exit(1)
    sys.exit(main())