319 lines
13 KiB
Python
319 lines
13 KiB
Python
import os
|
|
import argparse
|
|
import numpy as np
|
|
from pathlib import Path
|
|
from osgeo import gdal
|
|
import rasterio as rio
|
|
from rasterio.enums import Resampling
|
|
from rasterio.warp import reproject
|
|
from osgeo import osr
|
|
|
|
# Attempt to import OmniCloudMask and set a flag
|
|
try:
|
|
from omnicloudmask import predict_from_array, load_multiband
|
|
HAS_OCM = True
|
|
except ImportError:
|
|
HAS_OCM = False
|
|
|
|
def calculate_utm_zone_and_hemisphere(longitude, latitude):
    """
    Calculate the UTM zone and hemisphere based on longitude and latitude.

    Args:
        longitude: Longitude in decimal degrees. Values outside [-180, 180)
            are wrapped, so 180 falls in the same zone as -180.
        latitude: Latitude in decimal degrees.

    Returns:
        Tuple ``(utm_zone, is_southern)``: ``utm_zone`` is an int in 1..60
        and ``is_southern`` is True when ``latitude`` is negative.
    """
    # Floor division (instead of int() truncation) keeps the 6-degree bins
    # correct for values below -180, and the modulo wraps the antimeridian so
    # that longitude == 180 gives zone 1 rather than the non-existent zone 61.
    utm_zone = int((longitude + 180) // 6) % 60 + 1
    is_southern = latitude < 0
    return utm_zone, is_southern
|
|
|
|
def reproject_to_projected_crs(input_path, output_path):
    """
    Reprojects a raster to the WGS84 / UTM zone containing its center.

    Args:
        input_path: Path of the raster to reproject (anything ``str()`` can
            turn into a GDAL-openable name).
        output_path: Path of the GeoTIFF to create.

    Returns:
        ``output_path``, unchanged.

    Raises:
        ValueError: If the input cannot be opened, carries no spatial
            reference, or the warp does not produce an output dataset.
    """
    input_ds = gdal.Open(str(input_path))
    if not input_ds:
        raise ValueError(f"Failed to open input raster: {input_path}")

    # Get the source spatial reference
    projection_wkt = input_ds.GetProjection()
    if not projection_wkt:
        input_ds = None  # Close the dataset
        raise ValueError(f"Input raster has no spatial reference: {input_path}")
    source_srs = osr.SpatialReference()
    source_srs.ImportFromWkt(projection_wkt)

    # Locate the image center in the source CRS. The rotation/shear terms
    # (indices 2 and 4) are zero for north-up rasters but are included so
    # rotated rasters are handled as well.
    geo_transform = input_ds.GetGeoTransform()
    half_width = input_ds.RasterXSize / 2
    half_height = input_ds.RasterYSize / 2
    center_x = (geo_transform[0]
                + half_width * geo_transform[1]
                + half_height * geo_transform[2])
    center_y = (geo_transform[3]
                + half_width * geo_transform[4]
                + half_height * geo_transform[5])

    # The center is expressed in source CRS units, which are metres when the
    # input is already projected. Convert it to WGS84 lon/lat before deriving
    # the UTM zone; otherwise eastings/northings would be read as degrees.
    wgs84_srs = osr.SpatialReference()
    wgs84_srs.SetWellKnownGeogCS("WGS84")
    if hasattr(osr, "OAMS_TRADITIONAL_GIS_ORDER"):
        # GDAL >= 3 honours authority axis order (lat/lon for EPSG:4326);
        # force x=lon, y=lat so it matches the geotransform convention.
        source_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
        wgs84_srs.SetAxisMappingStrategy(osr.OAMS_TRADITIONAL_GIS_ORDER)
    to_wgs84 = osr.CoordinateTransformation(source_srs, wgs84_srs)
    center_lon, center_lat = to_wgs84.TransformPoint(center_x, center_y)[:2]

    # Calculate the UTM zone and hemisphere dynamically
    utm_zone, is_southern = calculate_utm_zone_and_hemisphere(center_lon, center_lat)

    # Define the target spatial reference. SetUTM's second argument means
    # "north", so the southern-hemisphere flag has to be inverted.
    target_srs = osr.SpatialReference()
    target_srs.SetWellKnownGeogCS("WGS84")
    target_srs.SetUTM(utm_zone, not is_southern)

    # Create the warp options
    warp_options = gdal.WarpOptions(
        dstSRS=target_srs.ExportToWkt(),
        format="GTiff"
    )

    # Perform the reprojection
    warped_ds = gdal.Warp(str(output_path), input_ds, options=warp_options)
    input_ds = None  # Close the dataset
    if warped_ds is None:
        raise ValueError(f"Failed to reproject raster to: {output_path}")
    warped_ds = None  # Dereference so GDAL flushes the output to disk

    print(f"Reprojected raster saved to: {output_path}")
    return output_path
|
|
|
|
def resample_image(input_path, output_path, resolution=(10, 10), resample_alg="bilinear"):
    """
    Resample a raster to the requested pixel size with gdal.Translate.

    The input is first passed through reproject_to_projected_crs, which
    writes "<output stem>_reprojected.tif" next to ``output_path``; that
    intermediate file is what gets resampled and it is not removed here.

    Args:
        input_path: Raster to resample.
        output_path: Destination raster path.
        resolution: ``(xRes, yRes)`` pair forwarded to gdal.Translate.
        resample_alg: Resampling algorithm name forwarded to gdal.Translate.

    Returns:
        ``output_path``, unchanged.

    Raises:
        ValueError: If the reprojected raster cannot be opened or
            gdal.Translate returns None.
    """
    print(f"Resampling {input_path} to {resolution}m resolution -> {output_path}")

    # Stage 1: bring the input into a projected CRS, stored beside the output.
    destination = Path(output_path)
    reprojected_path = str(destination.with_name(f"{destination.stem}_reprojected.tif"))
    reproject_to_projected_crs(input_path, reprojected_path)

    # Stage 2: open the intermediate raster and translate it to the new grid.
    reprojected_ds = gdal.Open(reprojected_path)
    if not reprojected_ds:
        raise ValueError(f"Failed to open reprojected raster: {reprojected_path}")

    translate_kwargs = {
        "xRes": resolution[0],
        "yRes": resolution[1],
        "resampleAlg": resample_alg,
    }
    translated = gdal.Translate(str(output_path), reprojected_ds, **translate_kwargs)
    reprojected_ds = None  # Drop the GDAL handle on the intermediate raster

    if translated is None:
        raise ValueError(f"Failed to resample image to {output_path}")

    print(f"Successfully resampled image saved to: {output_path}")
    return output_path
|
|
|
|
def run_ocm_on_image(image_path_10m, ocm_output_dir, save_mask=True):
    """
    Processes a 10m resolution image with OmniCloudMask.
    Adapted from process_with_ocm in the notebook.

    Args:
        image_path_10m: Path to the 10m image to classify.
        ocm_output_dir: Directory for the mask file; created if missing.
        save_mask: When True, write the mask to
            ``<ocm_output_dir>/<image stem>_ocm_mask_10m.tif``.

    Returns:
        ``(mask_path, profile)`` on success, where ``mask_path`` is a str and
        ``profile`` is the profile returned by ``load_multiband`` (updated to
        a single uint8 band when the mask is saved). Note that ``mask_path``
        is returned even when ``save_mask`` is False, in which case no file
        is written. Returns ``(None, None)`` when OmniCloudMask is not
        installed, when the image is narrower or shorter than 50 pixels, or
        when any exception is raised while processing.
    """
    if not HAS_OCM:
        print("OmniCloudMask not available. Please install with: pip install omnicloudmask")
        return None, None

    image_path_10m = Path(image_path_10m)
    ocm_output_dir = Path(ocm_output_dir)
    ocm_output_dir.mkdir(exist_ok=True, parents=True)

    mask_10m_path = ocm_output_dir / f"{image_path_10m.stem}_ocm_mask_10m.tif"

    try:
        # Open the image to check dimensions
        with rio.open(image_path_10m) as src:
            width, height = src.width, src.height

        # Check if the image is too small for OmniCloudMask
        if width < 50 or height < 50:
            print(f"Warning: Image {image_path_10m} is too small for OmniCloudMask (width: {width}, height: {height}). Skipping.")
            return None, None

        # PlanetScope 4-band images are typically [B,G,R,NIR]
        # OCM expects [R,G,NIR] for its default model.
        # Band numbers for load_multiband are 1-based.
        # If original is B(1),G(2),R(3),NIR(4), then R=3, G=2, NIR=4
        band_order = [3, 2, 4]

        print(f"Loading 10m image for OCM: {image_path_10m}")
        # The image is expected to be at 10m already; resample_res=10 is
        # still passed explicitly so OCM always receives 10m data.
        rgn_data, profile = load_multiband(
            input_path=str(image_path_10m),
            resample_res=10,  # Explicitly set target resolution for OCM
            band_order=band_order
        )

        print("Applying OmniCloudMask...")
        prediction = np.asarray(predict_from_array(rgn_data))
        # NOTE(review): predict_from_array appears to return the mask with a
        # leading singleton band axis, i.e. (1, H, W) - confirm against the
        # installed version. Writing to a single band index needs a 2-D
        # array, so drop that axis when present; 2-D input is left as is.
        if prediction.ndim == 3 and prediction.shape[0] == 1:
            prediction = prediction[0]

        if save_mask:
            profile.update(count=1, dtype='uint8')
            with rio.open(mask_10m_path, 'w', **profile) as dst:
                dst.write(prediction.astype('uint8'), 1)
            print(f"Saved 10m OCM mask to: {mask_10m_path}")

        # Summary (optional, can be removed for cleaner script output)
        n_total = prediction.size
        n_clear = np.count_nonzero(prediction == 0)
        n_thick = np.count_nonzero(prediction == 1)
        n_thin = np.count_nonzero(prediction == 2)
        n_shadow = np.count_nonzero(prediction == 3)
        print(f"  OCM: Clear: {100*n_clear/n_total:.1f}%, Thick: {100*n_thick/n_total:.1f}%, Thin: {100*n_thin/n_total:.1f}%, Shadow: {100*n_shadow/n_total:.1f}%")

        return str(mask_10m_path), profile
    except Exception as e:
        # Deliberate best-effort boundary: report the problem and signal
        # failure through the (None, None) return value.
        print(f"Error processing 10m image with OmniCloudMask: {str(e)}")
        return None, None
|
|
|
|
|
|
def upsample_mask_to_3m(mask_10m_path, target_3m_image_path, output_3m_mask_path):
    """
    Upsamples a 10m OCM mask to match the 3m target image.
    Adapted from upsample_mask_to_highres in the notebook.

    The mask is warped onto the target image's grid (same shape, transform
    and CRS) with nearest-neighbour resampling so class values are never
    blended. Destination pixels that receive no source data stay 0.

    Args:
        mask_10m_path: Path to the single-band 10m mask.
        target_3m_image_path: Path to the image whose grid the mask must match.
        output_3m_mask_path: Path of the upsampled mask to write; its parent
            directory is created if it does not exist.

    Returns:
        ``output_3m_mask_path`` as a str.
    """
    print(f"Upsampling 10m mask {mask_10m_path} to 3m, referencing {target_3m_image_path}")

    with rio.open(mask_10m_path) as src_mask, rio.open(target_3m_image_path) as src_img_3m:
        mask_data_10m = src_mask.read(1)

        img_shape_3m = (src_img_3m.height, src_img_3m.width)
        upsampled_mask_3m_data = np.zeros(img_shape_3m, dtype=mask_data_10m.dtype)

        reproject(
            source=mask_data_10m,
            destination=upsampled_mask_3m_data,
            src_transform=src_mask.transform,
            src_crs=src_mask.crs,
            dst_transform=src_img_3m.transform,
            dst_crs=src_img_3m.crs,
            resampling=Resampling.nearest
        )

        profile_3m_mask = src_img_3m.profile.copy()
        profile_3m_mask.update({
            'count': 1,
            'dtype': upsampled_mask_3m_data.dtype,
            # Do not inherit the image's nodata: it may be outside the mask
            # dtype's range, and a value such as 0 would flag every "clear"
            # (class 0) pixel as nodata.
            'nodata': None
        })

    # Same convention as the other writers in this file: make sure the
    # destination directory exists before opening the output for writing.
    Path(output_3m_mask_path).parent.mkdir(parents=True, exist_ok=True)

    with rio.open(output_3m_mask_path, 'w', **profile_3m_mask) as dst:
        dst.write(upsampled_mask_3m_data, 1)

    print(f"Upsampled 3m OCM mask saved to: {output_3m_mask_path}")
    return str(output_3m_mask_path)
|
|
|
|
|
|
def apply_3m_mask_to_3m_image(image_3m_path, mask_3m_path, final_masked_output_path):
    """
    Applies an upsampled 3m OCM mask to the original 3m image.
    Adapted from apply_upsampled_mask_to_highres in the notebook.

    Every pixel whose mask value is greater than 0 is overwritten, in all
    bands, with the image's nodata value (or 0 when the image declares none).

    Args:
        image_3m_path: Path to the 3m image to mask.
        mask_3m_path: Path to the single-band mask on the same grid.
        final_masked_output_path: Path of the masked image to write; its
            parent directory is created if it does not exist.

    Returns:
        The output path as a str on success, or None when reading, masking
        or writing fails (the error is printed, not raised).
    """
    print(f"Applying 3m mask {mask_3m_path} to 3m image {image_3m_path}")
    image_3m_path = Path(image_3m_path)
    mask_3m_path = Path(mask_3m_path)
    final_masked_output_path = Path(final_masked_output_path)
    final_masked_output_path.parent.mkdir(parents=True, exist_ok=True)

    try:
        with rio.open(image_3m_path) as src_img_3m, rio.open(mask_3m_path) as src_mask_3m:
            img_data_3m = src_img_3m.read()
            img_profile_3m = src_img_3m.profile.copy()
            mask_data_3m = src_mask_3m.read(1)

        if img_data_3m.shape[1:] != mask_data_3m.shape:
            print(f"Warning: 3m image shape {img_data_3m.shape[1:]} and 3m mask shape {mask_data_3m.shape} do not match.")
            # This should ideally not happen if upsampling was correct.
            # A mismatched boolean mask cannot index the image, so fail with
            # a clear message instead of an opaque indexing error.
            raise ValueError("3m image and 3m mask dimensions differ; cannot apply mask.")

        # OCM: 0=clear, 1=thick cloud, 2=thin cloud, 3=shadow
        # We want to mask out (set to nodata) pixels where OCM is > 0
        cloud_or_shadow = mask_data_3m > 0

        # The profile can hold the 'nodata' key with a value of None, so a
        # .get() default would never apply; fall back to 0 explicitly.
        nodata_val = img_profile_3m.get('nodata')
        if nodata_val is None:
            nodata_val = 0

        # The array returned by read() is only used here, so mask it in
        # place across all bands at once.
        img_data_3m[:, cloud_or_shadow] = nodata_val

        # Ensure dtype of profile matches data to be written
        # If original image was float, but nodata is int (0), rasterio might complain
        # It's safer to use the original image's dtype for the output.
        img_profile_3m.update(dtype=img_data_3m.dtype)

        with rio.open(final_masked_output_path, 'w', **img_profile_3m) as dst:
            dst.write(img_data_3m)

        print(f"Final masked 3m image saved to: {final_masked_output_path}")
        return str(final_masked_output_path)

    except Exception as e:
        # Deliberate best-effort boundary: report and signal failure via None.
        print(f"Error applying 3m mask to 3m image: {str(e)}")
        return None
|
|
|
|
|
|
def main():
    """
    Command-line entry point: cloud-mask a 3m PlanetScope image with OCM.

    Takes two positional arguments, the input 3m GeoTIFF and an existing
    output directory, then runs four steps:

    1. Resample the 3m input to 10m.
    2. Run OmniCloudMask on the 10m image.
    3. Upsample the 10m mask back onto the 3m grid.
    4. Apply the 3m mask to the original 3m image.

    Intermediate files go to ``<output_dir>/intermediate_ocm_files``. A
    failure in any step is printed and ends the run early; nothing is
    returned in either case.
    """
    parser = argparse.ArgumentParser(description="Process PlanetScope 3m imagery with OmniCloudMask.")
    parser.add_argument("input_3m_image", type=str, help="Path to the input merged 3m PlanetScope GeoTIFF image.")
    parser.add_argument("output_dir", type=str, help="Directory to save processed files (10m image, masks, final 3m masked image).")

    args = parser.parse_args()

    try:
        # Resolve paths to absolute paths immediately
        input_3m_path = Path(args.input_3m_image).resolve(strict=True)
        # output_base_dir is the directory where outputs will be saved.
        # It should exist when the script is called (created by the notebook).
        output_base_dir = Path(args.output_dir).resolve(strict=True)
    except FileNotFoundError as e:
        print(f"Error: Path resolution failed. Input image or output base directory may not exist or is not accessible: {e}")
        return
    except Exception as e:
        print(f"Error resolving paths: {e}")
        return

    # The check for input_3m_path.exists() is now covered by resolve(strict=True)

    # Define intermediate and final file paths using absolute base paths
    intermediate_dir = output_base_dir / "intermediate_ocm_files"
    intermediate_dir.mkdir(parents=True, exist_ok=True)

    image_10m_path = intermediate_dir / f"{input_3m_path.stem}_10m.tif"
    # OCM mask (10m) will be saved inside run_ocm_on_image, in a subdir of intermediate_dir
    ocm_mask_output_dir = intermediate_dir / "ocm_10m_mask_output"

    # Upsampled OCM mask (3m)
    mask_3m_upsampled_path = intermediate_dir / f"{input_3m_path.stem}_ocm_mask_3m_upsampled.tif"

    # Final masked image (3m)
    final_masked_3m_path = output_base_dir / f"{input_3m_path.stem}_ocm_masked_3m.tif"

    print(f"--- Starting OCM processing for {input_3m_path.name} ---")
    print(f"Input 3m image (absolute): {input_3m_path}")
    print(f"Output base directory (absolute): {output_base_dir}")
    print(f"Intermediate 10m image path: {image_10m_path}")

    # 1. Resample 3m input to 10m for OCM
    try:
        resample_image(input_3m_path, image_10m_path, resolution=(10, 10))
    except Exception as e:
        print(f"Failed to resample to 10m: {e}")
        return

    # 2. Run OCM on the 10m image
    mask_10m_generated_path, _ = run_ocm_on_image(image_10m_path, ocm_mask_output_dir)
    if not mask_10m_generated_path:
        print("OCM processing failed. Exiting.")
        return

    # 3. Upsample the 10m OCM mask to 3m
    try:
        upsample_mask_to_3m(mask_10m_generated_path, input_3m_path, mask_3m_upsampled_path)
    except Exception as e:
        print(f"Failed to upsample 10m OCM mask to 3m: {e}")
        return

    # 4. Apply the 3m upsampled mask to the original 3m image
    try:
        masked_output = apply_3m_mask_to_3m_image(input_3m_path, mask_3m_upsampled_path, final_masked_3m_path)
    except Exception as e:
        print(f"Failed to apply 3m mask to 3m image: {e}")
        return

    # apply_3m_mask_to_3m_image reports most failures by returning None
    # rather than raising, so the return value must be checked before
    # announcing success.
    if not masked_output:
        print("Failed to apply 3m mask to 3m image. Exiting.")
        return

    print(f"--- Successfully completed OCM processing for {input_3m_path.name} ---")
    print(f"Final 3m masked output: {final_masked_3m_path}")
|
|
|
|
if __name__ == "__main__":
    # Run the pipeline only when OmniCloudMask could be imported; otherwise
    # tell the user how to install it.
    if HAS_OCM:
        main()
    else:
        print("OmniCloudMask library is not installed. Please install it to run this script.")
        print("You can typically install it using: pip install omnicloudmask")