Source code for aodn_cloud_optimised.lib.GenericZarrHandler

import importlib.resources
import os
import re
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import nullcontext
from functools import cached_property, partial

import numcodecs
import numpy as np
import s3fs
import xarray as xr
import zarr
from dask import array as da
from dask.distributed import Lock
from distributed.client import FutureCancelledError
from distributed.scheduler import KilledWorker
from distributed.shuffle._exceptions import P2PConsistencyError
from xarray.coding.times import CFDatetimeCoder
from xarray.structure.merge import MergeError

from aodn_cloud_optimised.lib.CommonHandler import CommonHandler
from aodn_cloud_optimised.lib.logging import get_logger
from aodn_cloud_optimised.lib.s3Tools import (
    create_fileset,
    delete_objects_in_prefix,
    prefix_exists,
    split_s3_path,
)
from aodn_cloud_optimised.lib.schema import merge_schema_dict


class GridSizeMismatchError(ValueError):
    """Error for a batch whose spatial grid does not fit the target store.

    Raised when a batch has spatial dimensions incompatible with the existing
    Zarr store. It subclasses ``ValueError`` so existing ``except ValueError``
    handlers keep catching it.
    """
def check_variable_values_dask(
    file_path, reference_values, variable_name, dataset_config, uuid_log
):
    """Check whether a variable in a file matches a set of reference values.

    Designed to be run in parallel (e.g., with Dask), hence it is a top-level
    function to avoid serialization issues with class methods.

    Args:
        file_path (str): Path to the NetCDF file to check.
        reference_values (np.ndarray): The reference array to compare against.
        variable_name (str): Name of the variable within the file to compare.
        dataset_config (dict): Dataset configuration; only ``logger_name`` is
            read (defaults to ``"generic"``).
        uuid_log (str): Identifier prepended to every log message.

    Returns:
        tuple[str, bool]: ``(file_path, is_problematic)``. ``is_problematic``
        is ``True`` when the values differ from the reference, or when any
        exception occurs while opening/reading/comparing the file.
    """
    logger = get_logger(dataset_config.get("logger_name", "generic"))

    try:
        # The context manager releases the file handle even when the variable
        # is missing or its values cannot be read.
        with xr.open_dataset(file_path) as ds:
            variable_values = ds[variable_name].values

        logger.debug(
            f"{uuid_log}: Checking if {variable_name} in {file_path} is consistent with reference values."
        )
        logger.debug(f"{uuid_log}: Reference values:\n{reference_values}")

        # The file is problematic when the arrays are NOT identical.
        is_problematic = not np.array_equal(variable_values, reference_values)

        if is_problematic:
            logger.error(
                f"{uuid_log}: {variable_name} in {file_path} is NOT consistent with reference values."
            )
            logger.error(f"{uuid_log}: Variable values:\n{variable_values}")
        else:
            logger.debug(f"{uuid_log}: Variable values:\n{variable_values}")

        return file_path, is_problematic

    except Exception as e:
        # Best effort: any failure flags the file as problematic instead of
        # aborting the whole parallel check.
        logger.error(f"{uuid_log}: Failed to open {file_path}: {e}")
        return file_path, True
def check_append_dim_range_dask(file_path, dim_name, dataset_config, uuid_log):
    """Report the smallest and largest value of one dimension in a single file.

    Designed to be run in parallel (e.g., with Dask) to identify files whose
    append-dimension range overlaps with other files in the same batch.

    Args:
        file_path (str): Path to the NetCDF file.
        dim_name (str): Name of the append dimension (e.g., "TIME").
        dataset_config (dict): Dataset configuration dict (used for logger name).
        uuid_log (str): Unique identifier for log correlation.

    Returns:
        tuple[str, any, any]: ``(file_path, min_val, max_val)`` on success, or
        ``(file_path, None, None)`` when anything goes wrong while opening the
        file or reading the dimension.
    """
    log = get_logger(dataset_config.get("logger_name", "generic"))

    try:
        with xr.open_dataset(file_path) as dataset:
            dim_values = dataset[dim_name].values
            lower = dim_values.min()
            upper = dim_values.max()
    except Exception as e:
        log.error(f"{uuid_log}: Failed to read '{dim_name}' from {file_path}: {e}")
        return file_path, None, None

    return file_path, lower, upper
def get_var_template_shape(ds, var_template_shape):
    """Return the first variable name from var_template_shape that exists in ds.

    Args:
        ds: Object exposing a ``variables`` container (e.g. an xarray Dataset).
        var_template_shape (str | Iterable[str] | None): A single candidate
            name, an iterable of candidate names in priority order, or ``None``
            when the configuration does not define a template variable.

    Returns:
        str | None: The first candidate present in ``ds.variables``; ``None``
        when no candidate matches or when ``var_template_shape`` is ``None``.
    """
    # A configuration without a template variable yields None; iterating over
    # it would raise TypeError.
    if var_template_shape is None:
        return None

    if isinstance(var_template_shape, str):
        candidates = (var_template_shape,)
    else:
        candidates = var_template_shape

    return next((name for name in candidates if name in ds.variables), None)
def preprocess_xarray(ds, dataset_config):
    """Performs preprocessing steps on an xarray Dataset.

    Steps, in order:

    1. Drop data variables that are not in the schema or are flagged with
       ``drop_var``.
    2. Optionally interpolate dimensions listed under
       ``schema_transformation.target_dimension_grids`` onto a target grid.
    3. Create string variables from global attributes listed under
       ``schema_transformation.add_variables``.
    4. Add a ``filename`` variable along the append dimension.
    5. Check that every dimension has a coordinate index.
    6. Add missing numeric variables as NaN arrays and cast variables to the
       type declared in the configuration.
    7. Update the global attributes through ``_update_ds_gattr``.

    Note:
        This is a top-level function (not a method) so it can be passed as the
        ``preprocess`` callable of ``xr.open_mfdataset`` (see the TODO below for
        the serialization / memory issues that motivated this).

    Args:
        ds (xr.Dataset): The input xarray Dataset to preprocess.
        dataset_config (dict): Dataset configuration. Keys read here:
            ``logger_name``, ``schema``, ``dataset_name`` and, under
            ``schema_transformation``: ``dimensions``,
            ``target_dimension_grids``, ``add_variables``,
            ``var_template_shape``, ``global_attributes``.

    Returns:
        xr.Dataset: The preprocessed xarray Dataset.

    Raises:
        ValueError: If no data variable is left after filtering; if an added
            variable declares an invalid or non-string dtype; if more than one
            dimension is flagged ``append_dim``; or if a dimension has no
            coordinate index.
    """
    # TODO: this is part a rewritten function available in the GenericHandler class below.
    #       running the class method with xarray as preprocess=self.preprocess_xarray lead to many issues
    #       1) serialization of the arguments with pickle.
    #       2) Running in a dask remote cluster, for some netcdf files, xarray would send the data back to the local machine
    #       and using ALL of the local ram. Complete nonsense. s3fs bug
    #       https://github.com/fsspec/filesystem_spec/issues/1747
    #       https://discourse.pangeo.io/t/remote-cluster-with-dask-distributed-uses-the-deployment-machines-memory-and-internet-bandwitch/4637
    #       https://github.com/dask/distributed/discussions/8913
    # NOTE(review): this body was reconstructed from a whitespace-collapsed
    #       listing; nesting levels flagged below should be confirmed against
    #       the repository.

    logger_name = dataset_config.get("logger_name", "generic")
    dimensions = dataset_config["schema_transformation"].get("dimensions")
    schema = dataset_config.get("schema")

    logger = get_logger(logger_name)

    # Drop variables not in the list
    vars_to_drop = set(ds.data_vars) - set(schema)

    # Add variables with "drop_var": true in the schema to vars_to_drop
    for var_name, var_details in schema.items():
        if var_details.get("drop_var", False):
            vars_to_drop.add(var_name)

    ds_filtered = ds.drop_vars(vars_to_drop, errors="ignore")
    ds = ds_filtered

    if not ds.data_vars:
        logger.error(
            "The dataset has no data variables left. Check the dataset configuration."
        )
        raise ValueError(
            "The dataset has no data variables left. Check the dataset configuration."
        )

    # Regrid spectral (or any non-time) dimensions to a standard target grid when
    # source files have inconsistent coordinate values (e.g. ACS WAVELENGTH_a/c).
    target_dimension_grids = dataset_config["schema_transformation"].get(
        "target_dimension_grids"
    )
    if target_dimension_grids:
        for dim_name, grid_spec in target_dimension_grids.items():
            if dim_name not in ds.coords:
                logger.warning(
                    f"target_dimension_grids: dimension '{dim_name}' not found in dataset, skipping."
                )
                continue

            # The target grid is either an explicit list of values or a
            # min/max/step range; the half-step added to "max" keeps the upper
            # bound inside np.arange's half-open interval.
            if "values" in grid_spec:
                target_coords = np.array(grid_spec["values"], dtype=np.float64)
            else:
                target_coords = np.arange(
                    grid_spec["min"],
                    grid_spec["max"] + grid_spec["step"] * 0.5,
                    grid_spec["step"],
                    dtype=np.float64,
                )

            current_coords = ds[dim_name].values
            # Only interpolate when the current grid differs from the target.
            if not (
                len(current_coords) == len(target_coords)
                and np.allclose(current_coords, target_coords, atol=1e-6)
            ):
                # NOTE(review): the first and last coordinate are printed back
                # to back with no separator between them.
                logger.info(
                    f"Regridding '{dim_name}' from n={len(current_coords)} "
                    f"[{current_coords[0]:.2f}{current_coords[-1]:.2f}] "
                    f"to n={len(target_coords)} [{target_coords[0]:.2f}{target_coords[-1]:.2f}]"
                )
                ds = ds.interp(
                    {dim_name: target_coords},
                    method="linear",
                    kwargs={"bounds_error": False, "fill_value": None},
                )

    ##########
    # var_required: schema entries that must exist in the output dataset, i.e.
    # the schema minus dimension variables and minus dropped variables.
    var_required = schema.copy()
    for dim_info in dimensions.values():
        var_required.pop(dim_info["name"], None)

    # Remove variables with "drop_var": true from the var_required list
    for var_name, var_details in schema.items():
        if var_details.get("drop_var", False):
            var_required.pop(var_name, None)

    schema_transformation = dataset_config["schema_transformation"]

    if schema_transformation.get("add_variables") is not None:
        variables_to_add = schema_transformation.get("add_variables")
        for variable_to_add in variables_to_add.items():
            variable_to_add_name = variable_to_add[0]
            variable_to_add_info = variable_to_add[1]
            var_type = variable_to_add_info["schema"].get("type")

            if variable_to_add_info["source"].startswith("@filename"):
                # TODO: prob not needed for zarr. the code creates the filename anyway further below. Would be cleaner
                pass
            elif variable_to_add_info["source"].startswith("@global_attribute:"):
                gattr = variable_to_add_info["source"].split(":")[1]
                dim_name = variable_to_add_info["schema"].get("dimensions")

                # Validate dtype
                try:
                    dtype_obj = np.dtype(var_type)
                except TypeError:
                    raise ValueError(f"Invalid dtype: {var_type}")

                if dtype_obj.kind == "U":
                    # Repeat the global attribute value along dim_name as an
                    # object array of strings.
                    gatt_var_value = getattr(ds, gattr, None)
                    length = ds.sizes[dim_name]
                    string_array = np.full(length, gatt_var_value, dtype=object)
                    ds[variable_to_add_name] = (dim_name, string_array)
                    ds[variable_to_add_name].encoding = {
                        "_FillValue": "",
                        "dtype": object,
                        "compressor": None,
                        "filters": None,
                        "object_codec": numcodecs.VLenUTF8(),  # crucial!
                    }

                    # add other variable attributes defined in the config, except for type and dimensions
                    for key, value in variable_to_add_info["schema"].items():
                        if key not in {"type", "dimensions"}:
                            ds[variable_to_add_name].attrs[key] = value

                    # if not (ds[variable_to_add_name] != "").all():
                    #     logger.debug(
                    #         f"{variable_to_add_name} variable created some empty values for dataset with time_coverage_start {getattr(ds, 'time_coverage_start', None)} and time_coverage_end {getattr(ds, 'time_coverage_end', None)}"
                    #     )

                else:
                    # TODO: issues implementing numbers for zarrs. heaps of issues.
                    #       if _FillValue is set, the type automatically becomes float64.
                    #       and when removing fillvalue some random values get written as 0
                    #       when to_zarr is called. but only on the append.
                    #       Could maybe try doing a minimal reproducible example?
                    #       for now, raise so users know numeric global attributes are not supported
                    #
                    raise ValueError(
                        "gattrs to variables is currently not supported for non string objects. Modify the dataset configuration"
                    )

                    # NOTE(review): everything below, up to the end of this
                    # else branch, is unreachable because of the unconditional
                    # raise above.
                    # Get the attribute value or fallback
                    raw_attr_value = getattr(ds, gattr, None)

                    # Type-safe conversion based on dtype
                    if np.issubdtype(dtype_obj, np.integer):
                        fill_value = -99
                        try:
                            attr_value = int(raw_attr_value)
                        except (TypeError, ValueError):
                            attr_value = fill_value
                    elif np.issubdtype(dtype_obj, np.floating):
                        fill_value = np.nan
                        try:
                            attr_value = float(raw_attr_value)
                        except (TypeError, ValueError):
                            attr_value = fill_value
                    else:
                        raise ValueError(
                            f"Unsupported dtype for numeric fallback: {dtype_obj}"
                        )

                    if attr_value == 0:
                        raise ValueError("shoud not be None")

                    # TODO: known bug https://discourse.pangeo.io/t/dtype-is-ignored-if-fillvalue-in-encoding-is-provided-for-xr-to-zarr/2653
                    #       if fill_Value is added, then on append, some fill_values are created, and the type of the data becomes float64
                    encoding = {
                        # "_FillValue": fill_value,
                        "dtype": dtype_obj,
                        # "compressor": None,
                        # "filters": None,
                    }

                    ds[variable_to_add_name] = (
                        (dim_name,),
                        np.full(ds.sizes[dim_name], attr_value, dtype=dtype_obj),
                    )
                    ds[variable_to_add_name] = ds[variable_to_add_name].astype(
                        dtype_obj
                    )
                    ds[variable_to_add_name].encoding = encoding

    # TODO: make the variable below something more generic? a parameter?
    var_template_shape = dataset_config["schema_transformation"].get(
        "var_template_shape"
    )
    # Resolve the configured candidate(s) to the first name present in ds.
    var_template_shape = get_var_template_shape(ds, var_template_shape)

    # retrieve filename from ds: prefer the first variable's encoding, then the
    # dataset's encoding, otherwise keep the placeholder.
    filename_placeholder = "UNKNOWN_FILENAME.nc"
    filename = filename_placeholder
    try:
        var = next(var for var in ds)  # get the first variable
        if hasattr(ds[var], "encoding") and "source" in ds[var].encoding:
            filename = os.path.basename(ds[var].encoding["source"])
        elif hasattr(ds, "encoding") and "source" in ds.encoding:
            filename = os.path.basename(ds.encoding["source"])
            logger.debug(f"Filename found in ds.encoding: {filename}")
    except Exception as e:
        logger.debug(
            f"Original filename not available in the xarray dataset: {e} - will default to use {filename_placeholder}"
        )
        filename = filename_placeholder

    #
    # Normalise filename to remove trailing > for encoding source
    filename = filename.strip().strip(">")

    # TODO: get filename; Should be from https://github.com/pydata/xarray/issues/9142

    def get_append_dim(dimensions):
        # Local counterpart of GenericHandler.get_append_dim: returns the
        # variable name of the single dimension flagged "append_dim", falling
        # back to dimensions["time"] when none is flagged.
        append_dims = [
            varname for varname, props in dimensions.items() if props.get("append_dim")
        ]

        if len(append_dims) > 1:
            raise ValueError(
                f"Dataset configuration error: Multiple dimensions with 'append_dim: true': {append_dims}"
            )
        elif not append_dims:
            logger.warning(
                "Dataset configuration does not have a default append_dim set. "
                'Please modify. Will default to using dimensions["time"].'
            )
            append_dim_varname = dimensions["time"]["name"]
        else:
            append_dim_varname = dimensions[append_dims[0]]["name"]

        return append_dim_varname

    # Add a "filename" variable holding the source file name for every step of
    # the append dimension.
    dest_name = "filename"
    append_dim = get_append_dim(dimensions)
    length = ds.sizes[append_dim]
    string_array = np.full(
        length, filename, dtype=object
    )  # having da.full could break things
    ds[dest_name] = (append_dim, string_array)
    ds[dest_name].encoding = {
        "_FillValue": "",
        "dtype": object,
        "compressor": None,
        "filters": None,
        "object_codec": numcodecs.VLenUTF8(),  # crucial!
    }

    # Extend var_required with add_variables
    if schema_transformation.get("add_variables") is not None:
        for variable_to_add_name, variable_to_add_info in schema_transformation[
            "add_variables"
        ].items():
            if variable_to_add_name not in var_required:
                var_required[variable_to_add_name] = variable_to_add_info["schema"]

    # Add manually created filename variable (if it was created)
    # NOTE(review): which `if` the `else` below belongs to could not be
    # recovered from the collapsed listing — confirm against the repository.
    if "filename" in ds and "filename" not in var_required:
        if schema_transformation.get("add_variables"):
            if "filename" in schema_transformation["add_variables"]:
                var_required["filename"] = schema_transformation["add_variables"][
                    "filename"
                ].get("schema")
            else:
                # fallback schema in case it's not in schema
                var_required["filename"] = {
                    "type": "object",
                    "units": "1",
                    "long_name": "Filename of the source file",
                    "dimensions": append_dim,
                }

    # Add other known constructed variables if present
    # NOTE(review): this loop looks unreachable — every add_variables name was
    # already inserted into var_required above — and the lookup
    # schema_transformation[extra_var] presumably should go through
    # schema_transformation["add_variables"]; confirm before relying on it.
    if schema_transformation.get("add_variables"):
        for extra_var in schema_transformation["add_variables"]:
            if extra_var in ds and extra_var not in var_required:
                if extra_var in schema:
                    var_required[extra_var] = schema_transformation[extra_var].get(
                        "schema"
                    )

    # TODO: when filename is added, this can be used to find the equivalent data already stored as CO, and create a NAN
    # version to push back as CO in order to "delete". If UNKOWN_FILENAME.nc, either run an error, or have another approach,
    # Maybe the file to delete, would be a filename, OR the physical og NetCDF, which means then that we have all of the info in it, and simply making it NAN

    # Structural validation: every dimension must have a 1D coordinate index so
    # that xr.open_mfdataset's combine_by_coords can infer concatenation order.
    # Files that fail this check are structurally incomplete and must be excluded
    # from the batch (the caller catches this ValueError).
    for dim_name in list(ds.dims):
        if dim_name not in ds.indexes:
            raise ValueError(
                f"File '{filename}' has dimension '{dim_name}' with no corresponding "
                f"1D coordinate index — file is structurally incomplete."
            )

    logger.info(f"Applying preprocessing on dataset from {filename}")
    for variable_name in var_required:
        datatype = var_required[variable_name].get("type")

        # if variable doesn't exist
        if variable_name not in ds:
            logger.warning(
                f"Adding missing variable '{variable_name}' to xarray dataset with NaN values."
            )
            # NOTE(review): this indexes `schema`, not `var_required`; a
            # required variable that only comes from add_variables would raise
            # KeyError here.
            var_dims = schema[variable_name].get("dims", None)

            # check the type of the variable (numerical of string)
            if np.issubdtype(datatype, np.number):
                # Add the missing variable to the dataset
                if var_dims:
                    dim_names = tuple(var_dims)
                    missing_coords = [dim for dim in dim_names if dim not in ds.coords]

                    # create missing dimension for missing variable
                    # TODO: this is dangerous in case the first file to process of a dataset doesnt have ALL the dimensions. If this is the case, the dimension will be full of NaNs.
                    #       How to avoid this? create a dummy NetCDF with epoq date, NaN variables but ALL valid dimensions?
                    if missing_coords:
                        for dim in missing_coords:
                            logger.warning(
                                f"Adding missing coordinate '{dim}' to xarray dataset."
                            )
                            # size = dimensions[dim][
                            #     "size"
                            #
                            # ]  # get size from dataset's .dims
                            # nan_array = da.full(size, da.nan, dtype=da.float64)
                            # ds.coords[dim] = (dim, nan_array)
                            ds = add_missing_coordinate_from_template(
                                ds, dim, dataset_config["dataset_name"]
                            )

                    shape = tuple([ds[dim].shape[0] for dim in dim_names])
                    nan_array = da.full(shape, da.nan, dtype=da.float64)
                else:
                    logger.warning(
                        f"The 'dims' key is missing for '{variable_name}' in the schema. Defaulting to creating '{variable_name}' with all available dimensions: {dimensions}"
                    )
                    dim_names = tuple(
                        dim_info["name"] for dim_info in dimensions.values()
                    )
                    # TODO: to clean and remove this abomination. nan_array should always be created based on the dims set for each variable!
                    nan_array = da.full(
                        ds[var_template_shape].shape, da.nan, dtype=da.float64
                    )

                ds[variable_name] = (
                    dim_names,
                    nan_array,
                )

            else:
                # for strings variables, it's quite likely that the variables don't have any dimensions associated.
                # This can be an issue to know which datapoint is associated to which string value.
                # In this case we might have to repeat the string the times of ('TIME')
                # ds[variable_name] = (
                #     (dimensions["time"]["name"],),
                #     empty_string_array,
                # )
                logger.warning(
                    f"Variable '{variable_name}' is not of a numeric type (np.number)."
                )

            # NOTE(review): for a missing non-numeric variable nothing was
            # created above, so this lookup would raise KeyError. The nesting
            # level of this statement is a best guess — confirm.
            ds[variable_name] = ds[variable_name].astype(datatype)

        # if variable already exists
        else:
            if (datatype == "timestamp[ns]") or (
                datatype == "datetime64[ns]"
            ):  # Timestamps do not have an astype method. But numpy.datetime64 do.
                datatype = "datetime64[ns]"

            # TODO: commented as never triggered, and causing some issues
            # elif not np.issubdtype(datatype, np.number):
            #     # we repeat the string variable to match the size of the TIME dimension
            #     ds[variable_name] = (
            #         (dimensions["time"]["name"],),
            #         da.full_like(
            #             ds[dimensions["time"]["name"]],
            #             ds[variable_name],
            #             dtype="<S1",
            #         ),
            #     )
            #
            # Expand dimensions for variables missing the append_dim.
            # Some source NetCDF files have spatial-only variables (e.g. GDOP
            # with dims (I, J)) that need to be broadcast to include the
            # append_dim (e.g. TIME) for Zarr compatibility.
            # NOTE: This is NOT done here in preprocess_xarray because
            # expanding dims during open_mfdataset preprocessing can corrupt
            # TIME encoding in the Zarr store. Instead, dimension mismatches
            # are resolved defensively in _validate_and_fix_dims() just before
            # writing to Zarr.

            ds[variable_name] = ds[variable_name].astype(datatype)

    ds = _update_ds_gattr(ds, dataset_config)

    logger.info(f"Successfully applied preprocessing to the dataset from {filename}")
    return ds
def _update_ds_gattr(ds, dataset_config): ## Update all global attributes. # Add Global attributes into metadata (no schema) dataset_metadata = dict() global_attributes_dict = dataset_config["schema_transformation"].get( "global_attributes" ) if global_attributes_dict: gattrs_to_delete = global_attributes_dict.get("delete") if gattrs_to_delete: for gattr_to_delete in gattrs_to_delete: if gattr_to_delete in ds.attrs: del ds.attrs[gattr_to_delete] gattrs_to_set = global_attributes_dict.get("set") if gattrs_to_set: for gattr in gattrs_to_set: dataset_metadata[gattr] = gattrs_to_set[gattr] if "metadata_uuid" in dataset_config.keys(): dataset_metadata["metadata_uuid"] = dataset_config["metadata_uuid"] dataset_metadata["dataset_name"] = dataset_config["dataset_name"] # Update global attributes of the Dataset ds.attrs.update(dataset_metadata) return ds def _update_store_gattr(store, dataset_config: dict) -> None: """ Update global attributes directly in a Zarr store based on dataset configuration. Parameters ---------- store : zarr.Group The Zarr group/store to update. dataset_config : dict Dataset configuration dict including global attribute transforms. """ dataset_metadata = {} global_attributes_dict = dataset_config.get("schema_transformation", {}).get( "global_attributes", {} ) gattrs_to_delete = global_attributes_dict.get("delete", []) for gattr in gattrs_to_delete: store.attrs.pop(gattr, None) gattrs_to_set = global_attributes_dict.get("set", {}) dataset_metadata.update(gattrs_to_set) if "metadata_uuid" in dataset_config: dataset_metadata["metadata_uuid"] = dataset_config["metadata_uuid"] dataset_metadata["dataset_name"] = dataset_config["dataset_name"] store.attrs.update(dataset_metadata)
def add_missing_coordinate_from_template(ds, dim, dataset_name):
    """
    Copy a coordinate that ``ds`` lacks from the dataset's NetCDF template.

    The template is looked up inside the ``aodn_cloud_optimised`` package at
    ``config/dataset/<dataset_name>.nc``.

    Args:
        ds (xarray.Dataset): The target dataset to modify.
        dim (str): The name of the missing dimension.
        dataset_name (str): Name of the dataset to locate the template file.

    Returns:
        xarray.Dataset: The dataset with the added coordinate.

    Raises:
        ValueError: If the template file is missing or does not contain the requested dimension.
    """
    import logging

    template_logger = logging.getLogger(__name__)

    package_root = importlib.resources.files("aodn_cloud_optimised")
    template_file = (
        package_root.joinpath("config")
        .joinpath("dataset")
        .joinpath(f"{dataset_name}.nc")
    )
    nc_template_path = str(template_file)

    template_found = os.path.exists(nc_template_path)
    if not template_found:
        raise ValueError(
            f"Missing required NetCDF template: {nc_template_path}. Please create it with nullify_netcdf_variables."
        )

    with xr.open_dataset(nc_template_path) as ds_template:
        template_coords = ds_template.coords
        if dim not in template_coords:
            raise ValueError(
                f"Dimension '{dim}' not found in template NetCDF: {nc_template_path}"
            )

        template_logger.info(
            f"Adding coordinate '{dim}' from template file: {nc_template_path}"
        )
        # Read the values while the template is still open.
        ds.coords[dim] = (dim, template_coords[dim].values)

    return ds
[docs] class GenericHandler(CommonHandler): """Handles the creation of cloud-optimised datasets in Zarr format. Provides methods to process NetCDF files (potentially in batches), apply preprocessing, handle inconsistencies, and write the data to an S3 Zarr store, managing chunking, consolidation, and appending. Inherits: CommonHandler: Provides common functionality like cluster management, logging, and configuration loading. """ def __init__(self, **kwargs): """Initialises the GenericZarrHandler. Sets up configuration, validates schema, defines chunking strategy, initialises S3 filesystem and Zarr store mapper, and sets default Zarr writing options. Args: **kwargs: Keyword arguments passed to the parent `CommonHandler` and used for specific Zarr handling configuration. Expected kwargs include those needed by `CommonHandler` and potentially `rechunk_drop_vars`. Configuration is loaded from the dataset config file specified in kwargs or defaults. """ super().__init__(**kwargs) json_validation_path = str( importlib.resources.files("aodn_cloud_optimised") .joinpath("config") .joinpath("schema_validation_zarr.json") ) self.validate_json( json_validation_path ) # we cannot validate the json config until self.dataset_config and self.logger["schema_transformation"] are set self.dimensions = self.dataset_config["schema_transformation"].get("dimensions") if self.dimensions is None: raise ValueError( "Missing required 'dimensions' key in 'schema_transformation'" ) self.rechunk_drop_vars = kwargs.get("rechunk_drop_vars", None) self.vars_incompatible_with_region = self.dataset_config[ "schema_transformation" ].get("vars_incompatible_with_region", None) self.append_dim_varname = self.get_append_dim() # Generic chunks structure self.chunks = { dim_info["name"]: dim_info["chunk"] for dim_info in self.dimensions.values() if "chunk" in dim_info } self.compute = bool(True) # TODO: fix this ugly abomination. 
Unfortunately, patching the s3_fs value in the unittest is not enough for # zarr! why? it works fine for parquet, but if I remove this if condition, my unittest breaks! maybe # self.s3_fs is overwritten somewhere?? need to check if os.getenv("RUNNING_UNDER_UNITTEST") == "true": port = os.getenv("MOTO_PORT", "5555") self.s3_fs_output = s3fs.S3FileSystem( anon=False, client_kwargs={ "endpoint_url": f"http://127.0.0.1:{port}/", "region_name": "us-east-1", }, ) self.s3_fs_input = s3fs.S3FileSystem( anon=False, client_kwargs={ "endpoint_url": f"http://127.0.0.1:{port}/", "region_name": "us-east-1", }, ) self.store = s3fs.S3Map( root=f"{self.cloud_optimised_output_path}", s3=self.s3_fs_output, check=False, ) # to_zarr common options self.safe_chunks = False self.align_chunks = True self.write_empty_chunks = False self.consolidated = True self.full_schema = merge_schema_dict( self.dataset_config["schema"], self.dataset_config["schema_transformation"] ) self.time_coder = CFDatetimeCoder(use_cftime=True) @cached_property def _get_decode_times(self): """Build decode_times parameter for xarray open operations. Returns a Mapping for selective CF time decoding when skip_cftime_decode is configured, otherwise returns self.time_coder (CFDatetimeCoder). This ensures that variables in skip_cftime_decode remain as numeric float64 throughout zarr write/append cycles, avoiding CF encoding conflicts. Returns: dict or CFDatetimeCoder: Mapping for per-variable decode or global decoder """ skip_vars = self.dataset_config["schema_transformation"].get( "skip_cftime_decode", [] ) if skip_vars: # Create a mapping: True (decode) by default, False (skip) for listed vars decode_times_map = {} for var_name in self.schema: decode_times_map[var_name] = var_name not in skip_vars return decode_times_map else: # If no skip list, use the original time_coder return self.time_coder
[docs] def delete_cloud_optimised_data(self, filename: str): """ Deletes data in the cloud-optimised Zarr dataset corresponding to a specific filename by replacing all data variables with NaNs, except dimension coordinates and the filename variable. Args: filename (str): The filename identifying the subset of data to delete. Returns: None Notes: - The function expects that the filename variable varies along a single dimension. - Before writing, the function checks that exactly one slice along the filename dimension is selected to prevent accidental overwrites. """ if not filename: self.logger.error("Filename to delete is not defined") if prefix_exists( self.cloud_optimised_output_path, s3_client_opts=self.s3_client_opts_output ): with xr.open_zarr( self.store, consolidated=True, decode_cf=True, decode_times=self._get_decode_times, decode_coords=True, ) as ds_org: # Compute only the filename variable to memory filenames = ds_org["filename"].compute().values # Normalise: remove trailing '>' which came from ds.encoding and strip whitespace filenames_clean = np.array( [str(f).strip().strip(">") for f in filenames] ) # Find matching indices (filename varies along a single dimension) matching_indices = np.where(filenames_clean == filename)[0] if matching_indices.size == 0: self.logger.error( f'No data matching the filename "{filename}" was found in the existing zarr dataset. 
Nothing to delete' ) return # Identify the dimension that filename varies over filename_dim = ds_org["filename"].dims[0] # Subset ds_org to just the matching indices (along filename_dim) ds_mod = ds_org.isel({filename_dim: matching_indices}) # Variables to preserve keep_vars = list(ds_mod.dims) + ["filename"] # Replace variables with NaN / None (only inside ds_mod) for var in ds_mod.data_vars: if var not in keep_vars: if np.issubdtype(ds_mod[var].dtype, np.number): ds_mod[var] = ds_mod[var].copy(deep=True) ds_mod[var].values = np.full(ds_mod[var].shape, np.nan) else: ds_mod[var] = ds_mod[var].copy(deep=True) ds_mod[var].values = np.full(ds_mod[var].shape, None) if ds_mod.sizes.get(filename_dim, 0) != 1: self.logger.error( f"Unexpected number of slices for filename '{filename}' along dimension '{filename_dim}': " f"expected 1, got {ds_mod.sizes.get(filename_dim)}. Deletion of data aborted" ) return # Write back modified dataset self.logger.info( f'Data matching "{filename}" replaced with NaNs. Writing back to zarr store.' ) self._write_ds(ds_mod, idx=0) self.logger.info( f'Data matching "{filename}" successfully replaced with NaN values' ) else: self.logger.warning( f"The dataset {self.cloud_optimised_output_path} does not exist yet. Nothing to delete" )
def _update_metadata(self): """ Update dataset variable attributes and global attributes without having to process new file """ if prefix_exists( self.cloud_optimised_output_path, s3_client_opts=self.s3_client_opts_output ): self.logger.info( f"{self.uuid_log}: Existing Zarr store found at {self.cloud_optimised_output_path}. Updating Metadata" ) self.logger.info(f"Dataset {self.dataset_name}: Updating Global Attributes") self.zarr_org = zarr.open_group(self.store, mode="a") _update_store_gattr(self.zarr_org, self.dataset_config) self.logger.info( f"Dataset {self.dataset_name}: Updating Variable Attributes" ) self.update_store_varattrs_from_schema( self.zarr_org, self.full_schema ) # self.full_schema is the merge of schema and schema_transformation.get("add_variables") zarr.consolidate_metadata(self.zarr_org.store) with xr.open_zarr( self.store, consolidated=True, decode_cf=True, decode_times=self._get_decode_times, decode_coords=True, ) as ds_mod: # assert self.dataset_config["schema_transformation"]["global_attributes"]["set"].items() <= ds_mod.attrs.items() expected_attrs = self.dataset_config["schema_transformation"][ "global_attributes" ]["set"] missing_or_mismatched = { k: v for k, v in expected_attrs.items() if ds_mod.attrs.get(k) != v } if missing_or_mismatched: self.logger.error( f"{self.uuid_log}: Global attribute update mismatch for dataset '{self.dataset_name}'. " f"The following attributes are missing or incorrect in the dataset: {missing_or_mismatched}" ) raise AssertionError( f"Metadata update failed for keys: {list(missing_or_mismatched.keys())}" ) else: self.logger.info( f"{self.uuid_log}: All expected global attributes successfully updated for dataset '{self.dataset_name}'." ) else: self.logger.error( f"Dataset {self.dataset_name} does not exist yet - cannot update metadata" )
[docs] def update_store_varattrs_from_schema(self, store, schema: dict) -> None: """ Update variable attributes directly in a Zarr store based on a schema dictionary. For each variable: - Add or update attributes defined in the schema - Delete attributes present in the store but not in the schema - Log type mismatches and other observations Parameters ---------- store : zarr.Group The Zarr group/store to update. schema : dict A dictionary where keys are variable names and values are dictionaries of metadata. """ non_attributes = {"drop_var", "type"} preserve_attributes = { "_ARRAY_DIMENSIONS", "units", "calendar", "_FillValue", "coordinates", "grid_mapping", } for var, attrs in schema.items(): if var in store: var_array = store[var] current_attrs = dict(var_array.attrs) # type check (schema vs actual Zarr array dtype) declared_type = attrs.get("type") actual_dtype = str(var_array.dtype) if declared_type and declared_type != actual_dtype: self.logger.warning( f"{self.uuid_log}: ⚠️ Type mismatch for '{var}': schema says '{declared_type}', Zarr store has '{actual_dtype}'" ) schema_attrs = { k: v for k, v in attrs.items() if k not in non_attributes } # ---- Delete unexpected attributes ---- for attr_key in current_attrs: if ( attr_key not in schema_attrs and not attr_key.startswith("_") and attr_key not in preserve_attributes ): del var_array.attrs[attr_key] self.logger.info( f"{self.uuid_log}: Deleted unexpected attribute '{attr_key}' from variable '{var}'" ) # ---- Set or update expected attributes ---- for attr_key, attr_value in schema_attrs.items(): var_array.attrs[attr_key] = attr_value else: self.logger.warning( f"{self.uuid_log}: ⚠️ Variable '{var}' in schema not found in Zarr store. Skipping." )
def log_chunking_strategy(self, ds):
    """Log, at debug level, the chunk layout of the dataset about to be written.

    Args:
        ds: Object exposing a mapping-like ``chunks`` attribute
            (presumably an ``xr.Dataset`` - confirm against callers).
    """
    chunk_layout = dict(ds.chunks.items())
    self.logger.debug(
        f"{self.uuid_log}: Chunking strategy for the dataset to write: {chunk_layout}"
    )


def create_sync_lock(self):
    """Create, store on ``self.lock`` and return the lock guarding Zarr writes.

    In cluster mode a ``dask.distributed.Lock`` with the fixed name
    ``"zarr-append-lock"`` is created; otherwise a no-op ``nullcontext`` is
    used so callers can always write ``with self.lock:``.

    Returns:
        The object assigned to ``self.lock``.
    """
    if not self.cluster_mode:
        self.lock = nullcontext()
        self.logger.info("No lock needed as no cluster is used")
        return self.lock

    # Use a consistent name for locking access across all workers
    self.lock = Lock("zarr-append-lock")
    self.logger.info(f"Creating lock {self.lock} to prevent chunk loss/overwrite")
    return self.lock
def get_append_dim(self):
    """
    Return the variable name of the dimension flagged ``append_dim: true``.

    ``self.dimensions`` is scanned for entries whose ``append_dim`` property
    is truthy:

    - exactly one match: the ``name`` of that dimension is returned;
    - no match: a warning is logged and ``self.dimensions["time"]["name"]``
      is returned as the default;
    - more than one match: the configuration is ambiguous and an error is
      raised.

    Returns:
        str: The ``name`` entry of the selected dimension.

    Raises:
        ValueError: If more than one dimension is marked with
            `append_dim: true`.
    """
    # TODO: add a pydantic check
    flagged = []
    for dim_key, dim_props in self.dimensions.items():
        if dim_props.get("append_dim"):
            flagged.append(dim_key)

    if not flagged:
        self.logger.warning(
            'Dataset configuration does not have a default append_dim set. Please modify. Will default to using dimensions["time"].'
        )
        return self.dimensions["time"]["name"]

    if len(flagged) == 1:
        return self.dimensions[flagged[0]]["name"]

    raise ValueError(
        f"Dataset configuration error: Multiple dimensions with 'append_dim: true': {flagged}"
    )
def publish_cloud_optimised_fileset_batch(self, s3_file_uri_list):
    """Processes and publishes batches of NetCDF files to the Zarr store.

    Iterates through the input list of S3 URIs, processing them in batches
    determined by cluster resources or a default size of one. For each batch,
    it opens the files as a single dataset with `try_open_dataset` and writes
    it to the target Zarr store using `_write_ds`.

    Failure handling, per batch:

    - Cluster-level failures (shuffle/connection/worker-death/deserialisation
      errors, recognised by keyword in the error message, or any
      ``KilledWorker``): the Dask cluster is reset and the batch retried, up
      to ``max_retries`` times, after which the batch falls back to
      individual file processing.
    - ``MergeError``: the batch is retried against the same retry budget and
      then falls back to individual file processing (previously this looped
      forever).
    - ``GridSizeMismatchError``: the batch is skipped and its files recorded
      in the run summary.
    - Any other error: the batch falls back to individual file processing.

    Args:
        s3_file_uri_list (list[str]): A list of S3 URIs pointing to the NetCDF
            files to be processed.

    Raises:
        ValueError: If `s3_file_uri_list` is None.
    """
    if s3_file_uri_list is None:
        raise ValueError("input_objects is not defined")

    drop_vars_list = [
        var_name
        for var_name, attrs in self.schema.items()
        if attrs.get("drop_var", False)
    ]

    partial_preprocess = partial(
        preprocess_xarray, dataset_config=self.dataset_config
    )

    if self.cluster_mode:
        batch_size = self.get_batch_size(client=self.client)
    else:
        batch_size = 1

    # Substrings that mark an error as a cluster problem worth a reset + retry.
    retryable_keywords = (
        # Shuffle failures
        "P2P",
        "failed during transfer phase",
        "failed during barrier phase",
        "failed during shuffle phase",
        "shuffle failure",
        "No active shuffle",
        "Unexpected error encountered during P2P",
        # Connection problems
        "Timed out trying to connect",
        "CancelledError",
        "Too many open files",
        # Killed workers
        "workers died",
        "all those workers died",
        "KilledWorker",
        "worker died",
        # Deserialisation problems
        "Error during deserialization",
        "different environments",
        # Broken cluster
        "'NoneType' object has no attribute 'asyncio_loop'",
    )
    max_retries = 5

    def _fall_back_individually(batch_files, idx, ds):
        # Process the batch file by file, then post-process whatever dataset
        # the failed attempt managed to open (if any).
        self.fallback_to_individual_processing(
            batch_files, partial_preprocess, drop_vars_list, idx
        )
        if ds is not None:
            self.postprocess(ds)

    def _give_up_retrying(batch_files, idx, ds):
        # Retry budget exhausted: record it and process files one by one.
        self.logger.error(
            f"{self.uuid_log}: Batch {idx + 1} has exceeded retry limit "
            f"({max_retries}). Falling back to individual processing."
        )
        self.run_summary.record_batch_outcome(idx + 1, "individual_fallback")
        _fall_back_individually(batch_files, idx, ds)

    for idx, batch_uri_list in enumerate(
        self.batch_process_fileset(s3_file_uri_list, batch_size=batch_size)
    ):
        self.uuid_log = str(uuid.uuid4())  # value per batch

        self.logger.info(
            f"{self.uuid_log}: Listing objects to process and creating S3 file handle list for batch {idx + 1}."
        )

        batch_files = create_fileset(batch_uri_list, self.s3_fs_input)

        self.logger.info(
            f"{self.uuid_log}: Processing batch {idx + 1} with files: {batch_files}"
        )

        if drop_vars_list:
            self.logger.warning(
                f"{self.uuid_log}: Dropping variables from the dataset: {drop_vars_list}"
            )

        batch_is_processed = False
        retry_count = 0

        while not batch_is_processed:
            # Reset on every attempt so the handlers below never see the
            # dataset of a previous batch or a previous attempt.
            ds = None
            try:
                ds = self.try_open_dataset(
                    batch_files,
                    partial_preprocess,
                    drop_vars_list,
                )

                if ds is None:
                    raise RuntimeError(
                        "try_open_dataset returned None — no dataset could be opened for this batch."
                    )

                self._write_ds(ds, idx)
                self.logger.info(
                    f"{self.uuid_log}: Batch {idx + 1} successfully published to Zarr store: {self.store}"
                )
                self.run_summary.record_batch_outcome(idx + 1, "success")
                batch_is_processed = True  # Exit loop

            except (
                FutureCancelledError,
                P2PConsistencyError,
                KilledWorker,
                RuntimeError,
                AttributeError,
            ) as e:
                error_text = str(e)

                # Determine if the error should trigger a retry (cluster reset)
                retryable = any(
                    keyword in error_text for keyword in retryable_keywords
                ) or isinstance(e, KilledWorker)

                if isinstance(e, (RuntimeError, AttributeError)) and not retryable:
                    # Not a cluster problem: treat as a regular exception and
                    # fall back to individual processing.
                    self.logger.error(
                        f"{self.uuid_log}: Unexpected {type(e).__name__} during batch {idx + 1}: {e}.\n"
                        f"{traceback.format_exc()}"
                    )
                    _fall_back_individually(batch_files, idx, ds)
                    batch_is_processed = True  # mark as done
                else:
                    # Retryable: check retry count first
                    retry_count += 1
                    if retry_count > max_retries:
                        _give_up_retrying(batch_files, idx, ds)
                        batch_is_processed = True
                    else:
                        # Scheduler or shuffle failure: reset cluster and retry batch
                        self.logger.error(
                            f"{self.uuid_log}: Scheduler/worker/shuffle failure during batch {idx + 1}: {e}. "
                            f"Recreating Dask cluster and retrying batch…\nRetry number for failing batch:{retry_count}/{max_retries}"
                            f"{traceback.format_exc()}"
                        )
                        self.run_summary.record_batch_retry(idx + 1)
                        self._reset_cluster()
                        batch_is_processed = False

            except MergeError as e:
                self.logger.error(
                    f"{self.uuid_log}: Failed to merge datasets for batch {idx + 1}: {e}"
                )
                # A merge failure is usually deterministic, so bound the
                # retries instead of looping forever on the same batch.
                retry_count += 1
                if retry_count > max_retries:
                    _give_up_retrying(batch_files, idx, ds)
                    batch_is_processed = True
                else:
                    if ds is not None:
                        self.postprocess(ds)
                    batch_is_processed = False  # Loop again

            except GridSizeMismatchError as e:
                # All files in this batch have a spatial grid incompatible with the
                # existing Zarr store. Individual processing won't help — every file
                # has the same wrong grid. Log the full file list and skip the batch.
                file_basenames = [
                    os.path.basename(getattr(f, "path", str(f)))
                    for f in batch_files
                ]
                self.logger.error(
                    f"{self.uuid_log}: Batch {idx + 1} rejected — {e}\n"
                    f"The following {len(batch_files)} file(s) have an incompatible "
                    f"spatial grid and will be skipped:\n"
                    + "\n".join(f"  - {n}" for n in file_basenames)
                )
                self.run_summary.record_grid_mismatch(idx + 1, file_basenames)
                batch_is_processed = True

            except Exception as e:
                self.logger.error(
                    f"{self.uuid_log}: An unexpected error occurred during batch {idx + 1} processing: {e}.\n {traceback.format_exc()}"
                )
                _fall_back_individually(batch_files, idx, ds)
                batch_is_processed = True  # Batch handled; exit loop
def try_open_dataset(
    self,
    batch_files,
    partial_preprocess,
    drop_vars_list,
    primary_engine="h5netcdf",
    fallback_engine="scipy",
):
    """Attempts to open a batch of files using multiple strategies.

    Tries opening the batch with the primary engine via `_open_mfds`. When
    that raises a `ValueError` or `TypeError`, the formatted traceback is
    matched against a set of known error messages and the matching recovery
    handler is called (inconsistent coordinate grid, overlapping append
    dimension, missing coordinate index, invalid NetCDF4 signature).

    The fallback engine is only tried when a handler raises
    ``ValueError("Switch to fallback engine")`` (as
    `handle_invalid_format_in_batch` does). If the fallback engine fails too,
    `handle_multi_engine_fallback` opens the files with mixed engines.

    Args:
        batch_files (list[str]): List of S3 file paths in the current batch.
        partial_preprocess (callable): The pre-configured preprocessing function.
        drop_vars_list (list[str]): List of variable names to drop.
        primary_engine (str): The preferred engine for `xr.open_mfdataset`.
            Defaults to "h5netcdf".
        fallback_engine (str): The engine to try if the primary one fails due
            to specific format errors. Defaults to "scipy".

    Returns:
        xr.Dataset: The opened and potentially preprocessed dataset for the batch.

    Raises:
        RuntimeError: If all attempts, including individual file opening, fail
            (raised by the recovery handlers).
        ValueError: If specific unrecoverable errors occur (e.g., invalid NetCDF
            format after trying both engines), or for any unrecognised
            `ValueError` from `_open_mfds`, which is re-raised.
        TypeError: For any unrecognised `TypeError` from `_open_mfds`, which is
            re-raised.
    """

    def handle_engine(engine):
        # Open the whole batch with `engine`; on a known failure, dispatch to
        # the matching recovery handler and return its dataset.
        try:
            return self._open_mfds(
                partial_preprocess, drop_vars_list, batch_files, engine=engine
            )
        except (ValueError, TypeError):
            # Capture and inspect the traceback. Six message patterns are
            # recognised; the order of the checks below matters because the
            # first matching branch wins.
            tb = traceback.format_exc()
            match_grid_not_consistent = re.search(
                r"Coordinate variable (\w+) is neither monotonically increasing nor monotonically decreasing on all datasets",
                tb,
            )
            match_non_monotonic_combine = re.search(
                r"Resulting object does not have monotonic global indexes along dimension (\w+)",
                tb,
            )
            match_not_netcdf4_signature = re.search(
                r"is not the signature of a valid netCDF4 file", tb
            )
            match_not_netcdf3_signature = re.search(
                r"is not a valid NetCDF 3 file", tb
            )
            match_missing_coord_index = re.search(
                r"coordinate '(\w+)' has no corresponding index", tb
            )
            match_missing_coord_index_preprocess = re.search(
                r"has dimension '(\w+)' with no corresponding 1D coordinate index",
                tb,
            )

            if match_grid_not_consistent:
                # The offending coordinate name is captured by the regex.
                variable_name = match_grid_not_consistent.group(1)
                return self.handle_coordinate_variable_issue(
                    batch_files,
                    variable_name,
                    partial_preprocess,
                    drop_vars_list,
                    engine,
                )

            elif match_non_monotonic_combine:
                variable_name = match_non_monotonic_combine.group(1)
                if variable_name == self.append_dim_varname:
                    # combine_by_coords failed on the append dimension — this means
                    # at least one file has TIME values that overlap with another
                    # file in the batch. This is a data-provider bug: files should
                    # have non-overlapping, monotonically increasing append-dim values.
                    self.logger.warning(
                        f"{self.uuid_log}: combine_by_coords failed — append dimension "
                        f"'{variable_name}' is not globally monotonic across the batch. "
                        f"This indicates overlapping {variable_name} values in the source "
                        f"files (data bug). Identifying and excluding the problematic files."
                    )
                    return self.handle_append_dim_overlap(
                        batch_files,
                        variable_name,
                        partial_preprocess,
                        drop_vars_list,
                        engine,
                    )
                else:
                    # Non-monotonic on a fixed coordinate (e.g. LONGITUDE, WAVELENGTH):
                    # some files have a different spatial grid — find and exclude them.
                    self.logger.warning(
                        f"{self.uuid_log}: combine_by_coords failed — dimension '{variable_name}' "
                        f"is not globally monotonic. Checking for files with inconsistent grids."
                    )
                    return self.handle_coordinate_variable_issue(
                        batch_files,
                        variable_name,
                        partial_preprocess,
                        drop_vars_list,
                        engine,
                    )

            elif match_missing_coord_index or match_missing_coord_index_preprocess:
                # One or more files has a dimension without a 1D coordinate index.
                # preprocess_xarray embeds the filename in its error message; the
                # retry helper extracts it from the exception message and removes
                # the file before retrying.
                return self._open_mfds_retrying(
                    partial_preprocess,
                    drop_vars_list,
                    batch_files,
                    engine,
                )

            elif engine == primary_engine and match_not_netcdf4_signature:
                # One or more files may be NetCDF3 or corrupt. Try to identify them
                # locally and open each file with the engine matching its format.
                # The handler raises ValueError("Switch to fallback engine") if no
                # bad files are found locally (cluster-env issue); that sentinel is
                # handled by the outer try/except below.
                return self.handle_invalid_format_in_batch(
                    batch_files,
                    primary_engine,
                    fallback_engine,
                    partial_preprocess,
                    drop_vars_list,
                )

            elif engine == fallback_engine and match_not_netcdf3_signature:
                # Should be a NetCDF3, but didn't work with scipy!
                # Final failure point for fallback engine
                self.logger.error(
                    f"{self.uuid_log}: open_mfdataset with '{engine}' failed — "
                    f"not a valid NetCDF3 file either. Both engines exhausted.\n{tb}"
                )
                raise ValueError(f"Invalid NetCDF file format in {batch_files}")

            else:
                # Unknown or unhandled error — log the full traceback and
                # re-raise the original exception unchanged.
                self.logger.error(
                    f"{self.uuid_log}: open_mfdataset with '{engine}' failed with "
                    f"an unrecognised error:\n{tb}"
                )
                raise

    try:
        # First attempt with the primary engine
        return handle_engine(primary_engine)
    except ValueError as e:
        # The exact message is a sentinel requesting the fallback engine; any
        # other ValueError is a genuine failure and propagates to the caller.
        if str(e) == "Switch to fallback engine":
            try:
                # Second attempt with the fallback engine
                return handle_engine(fallback_engine)
            except Exception:
                # If fallback also fails, handle multi-engine fallback
                self.logger.warning(
                    f"{self.uuid_log}: Both '{primary_engine}' and '{fallback_engine}' "
                    f"engines failed to open batch as mfdataset. "
                    f"Falling back to opening files with mixed engines:\n"
                    f"{traceback.format_exc()}"
                )
                return self.handle_multi_engine_fallback(
                    batch_files, partial_preprocess, drop_vars_list
                )
        else:
            raise
# Leading-byte signatures that identify a file format without parsing it.
_HDF5_MAGIC = b"\x89HDF\r\n\x1a\n"  # NetCDF4 / HDF5
_CDF_MAGIC = (b"CDF\x01", b"CDF\x02")  # NetCDF3 classic and 64-bit offset


def _is_netcdf4_magic(self, path: str) -> bool:
    """Tell whether the file at *path* begins with the HDF5 signature.

    Only the first 8 bytes are read through ``self.s3_fs_input``, so the
    check stays cheap even across thousands of files.

    Args:
        path: Location of the file, as understood by ``self.s3_fs_input``.

    Returns:
        True when the leading bytes equal ``_HDF5_MAGIC``; False otherwise,
        including when the file cannot be opened or read.
    """
    signature_matches = False
    try:
        with self.s3_fs_input.open(path, "rb") as handle:
            leading_bytes = handle.read(8)
        signature_matches = leading_bytes == self._HDF5_MAGIC
    except Exception:
        # An unreadable file is treated as incompatible so that it is routed
        # to the fallback engine (and gets a proper error if both fail).
        signature_matches = False
    return signature_matches
def handle_invalid_format_in_batch(
    self,
    batch_files,
    primary_engine,
    fallback_engine,
    partial_preprocess,
    drop_vars_list,
):
    """Handles a batch where the primary engine fails due to incompatible files.

    When `open_mfdataset` with `primary_engine` fails because one or more files
    are not in the expected format (e.g. a NetCDF3 file in a NetCDF4 batch),
    this method:

    1. Scans each file by reading its first 8 bytes (magic-byte check) to
       determine whether it is NetCDF4/HDF5 — parallelised with threads.
    2. Opens the compatible files individually with the primary engine.
    3. Opens the incompatible files individually with the fallback engine.
    4. Concatenates and returns the combined dataset.

    Files are scanned and opened in parallel, but results are always collected
    in the order of `batch_files`, so the datasets handed to ``xr.concat`` are
    in a deterministic order rather than thread-completion order.

    If no incompatible files are found locally, the failure is likely a
    cluster-environment issue rather than a bad file. In that case, a
    ``ValueError("Switch to fallback engine")`` is raised to let the caller
    try the fallback engine on the full batch.

    Args:
        batch_files (list): List of S3 file-like objects in the batch; each
            must expose a ``path`` attribute.
        primary_engine (str): Engine that failed ('h5netcdf').
        fallback_engine (str): Engine to try for incompatible files ('scipy').
        partial_preprocess (callable): The pre-configured preprocessing function.
        drop_vars_list (list[str]): List of variable names to drop.

    Returns:
        xr.Dataset: The dataset combining all successfully opened files.

    Raises:
        ValueError("Switch to fallback engine"): If no incompatible files are
            found locally (cluster-env issue; let the caller try scipy batch open).
        RuntimeError: If no files at all could be opened.
    """
    batch_files = list(batch_files)
    self.logger.warning(
        f"{self.uuid_log}: Scanning {len(batch_files)} files via magic-byte check "
        f"to identify those incompatible with '{primary_engine}' engine."
    )

    # Parallelise the magic-byte reads — each is a tiny S3 range-GET.
    # 50 threads keep S3 latency well-hidden without overwhelming the scheduler.
    # Executor.map yields results in submission order, keeping batch order.
    with ThreadPoolExecutor(max_workers=50) as executor:
        hdf5_flags = list(
            executor.map(
                self._is_netcdf4_magic, [file.path for file in batch_files]
            )
        )

    # h5netcdf needs HDF5 files; any other primary engine is assumed to need
    # non-HDF5 (NetCDF3) files.
    primary_needs_hdf5 = primary_engine == "h5netcdf"
    good_files = []
    bad_files = []
    file_engine_pairs = []
    for file, is_hdf5 in zip(batch_files, hdf5_flags):
        if is_hdf5 == primary_needs_hdf5:
            good_files.append(file)
            file_engine_pairs.append((file, primary_engine))
        else:
            bad_files.append(file)
            file_engine_pairs.append((file, fallback_engine))

    self.logger.info(
        f"{self.uuid_log}: Format scan complete — "
        f"{len(good_files)} file(s) compatible with '{primary_engine}', "
        f"{len(bad_files)} file(s) need '{fallback_engine}'."
    )

    if not bad_files:
        # All files are HDF5 locally — failure must be a cluster-env issue.
        self.logger.warning(
            f"{self.uuid_log}: All files appear to be NetCDF4/HDF5. "
            f"Batch failure is likely a cluster-environment issue. Switching to '{fallback_engine}'."
        )
        raise ValueError("Switch to fallback engine")

    self.logger.warning(
        f"{self.uuid_log}: Contact the data provider. "
        f"Found {len(bad_files)}/{len(batch_files)} file(s) with NetCDF3 format "
        f"incompatible with '{primary_engine}' — will open with '{fallback_engine}':\n{bad_files}"
    )

    # Open each file individually with the known correct engine.
    # Using _open_ds (not _open_mfds) for all files avoids chunk-structure
    # mismatches on object-dtype variables (e.g. 'filename') that arise when
    # mixing a batch-opened (open_mfdataset) dataset with individually-opened
    # datasets before xr.concat. The performance benefit of batch-opening the
    # good files is outweighed by the correctness risk.
    def _open_one(file, engine):
        # Each thread rewinds its own file object before opening it.
        if hasattr(file, "seek"):
            file.seek(0)
        return self._open_ds(
            file, partial_preprocess, drop_vars_list, engine=engine
        )

    datasets = []
    with ThreadPoolExecutor(max_workers=50) as executor:
        submitted = [
            (file, engine, executor.submit(_open_one, file, engine))
            for file, engine in file_engine_pairs
        ]
        # Collect in batch order (not completion order) so the concatenation
        # below is deterministic.
        for file, engine, future in submitted:
            try:
                ds = future.result()
            except Exception as e:
                err_suffix = (
                    f" either: {e}. Skipping."
                    if engine == fallback_engine
                    else f": {e}"
                )
                self.logger.error(
                    f"{self.uuid_log}: Could not open '{file}' with '{engine}'{err_suffix}"
                )
            else:
                datasets.append(ds)
                if engine == fallback_engine:
                    self.logger.info(
                        f"{self.uuid_log}: Opened incompatible file '{file}' with '{fallback_engine}'."
                    )

    if not datasets:
        raise RuntimeError(
            f"{self.uuid_log}: No files in the batch could be opened with any engine."
        )

    if len(datasets) == 1:
        # Nothing to concatenate: return the single dataset as is.
        self.logger.info(
            f"{self.uuid_log}: Successfully concatenated files from batch."
        )
        return datasets[0]

    self.logger.info(
        f"{self.uuid_log}: Concatenating {len(datasets)} dataset(s) from mixed-engine open."
    )
    ds = xr.concat(
        datasets,
        compat="override",
        coords="minimal",
        data_vars="all",
        dim=self.append_dim_varname,
    )
    ds = ds.chunk(chunks=self.chunks)
    ds = ds.unify_chunks()
    self.logger.info(
        f"{self.uuid_log}: Successfully concatenated files from batch."
    )
    return ds
# Pattern used by _open_mfds_retrying to extract a bad filename from the
# preprocess_xarray error message.
_PREPROCESS_BAD_FILE_RE = re.compile(
    r"File '([^']+)' has dimension '\w+' with no corresponding 1D coordinate index"
)


def _open_mfds_retrying(
    self, partial_preprocess, drop_vars_list, batch_files, engine
):
    """Calls ``_open_mfds``, retrying after excluding any file whose
    preprocessing raises a recognisable "structurally incomplete" error.

    ``preprocess_xarray`` embeds the source filename in its ``ValueError``
    message when it detects a structural problem (e.g. a dimension with no
    1D coordinate index). This method extracts the filename from the
    exception message (``str(exc)``), removes the offending file from the
    batch, logs and records it in the run summary, and retries. The loop
    repeats until the batch opens cleanly or every file has been excluded.

    The excluded file is recorded by its ``path`` attribute when it has one
    (falling back to ``str(file)``), the same convention used to match the
    filename against the batch.

    Args:
        partial_preprocess (callable): Preprocessing function passed to
            ``open_mfdataset``. Must embed the source filename in any
            ``ValueError`` it raises (as ``preprocess_xarray`` does).
        drop_vars_list (list[str]): Variables to drop when opening.
        batch_files (list): S3 file-like objects to open.
        engine (str): xarray engine to use.

    Returns:
        xr.Dataset: The dataset opened from the surviving files.

    Raises:
        RuntimeError: If every file in the batch was excluded.
        ValueError, TypeError: Re-raised unchanged when the error message does
            not match the pattern or names a file not in the batch.
    """
    remaining = list(batch_files)
    excluded = []

    while remaining:
        try:
            return self._open_mfds(
                partial_preprocess, drop_vars_list, remaining, engine
            )
        except (ValueError, TypeError) as exc:
            # Search str(exc) ONLY — not traceback.format_exc() — because on
            # the second (and later) retries the full traceback includes the
            # chained context from earlier iterations, which would cause the
            # regex to match an already-excluded filename and then fail to find
            # it in `remaining`.
            match = self._PREPROCESS_BAD_FILE_RE.search(str(exc))
            if not match:
                raise  # unrelated error — propagate to the caller

            bad_basename = match.group(1)
            bad_file = next(
                (
                    f
                    for f in remaining
                    if os.path.basename(getattr(f, "path", str(f))) == bad_basename
                ),
                None,
            )
            if bad_file is None:
                # Filename in error doesn't match any file we know — give up.
                raise

            # Resolve the path the same way the match above does, so the run
            # summary gets a real path/basename and not an object repr.
            bad_path = str(getattr(bad_file, "path", str(bad_file)))

            self.logger.error(
                f"{self.uuid_log}: Contact the data provider. "
                f"Excluding structurally incomplete file from batch: {bad_file}"
            )
            self.run_summary.record_bad_file(
                "structural", bad_path, [os.path.basename(bad_path)]
            )
            excluded.append(bad_file)
            remaining = [f for f in remaining if f is not bad_file]

    raise RuntimeError(
        f"{self.uuid_log}: All {len(batch_files)} file(s) in the batch were "
        f"excluded due to structural issues. Excluded files: {excluded}"
    )
def handle_coordinate_variable_issue(
    self, batch_files, variable_name, partial_preprocess, drop_vars_list, engine
):
    """Re-opens a batch after dropping files with an inconsistent coordinate.

    Called when opening the batch failed because of a non-monotonic
    coordinate. The files whose `variable_name` values are flagged by
    `check_variable_values_parallel` are removed from the batch, and the
    remaining files are opened again with `_open_mfds`.

    Args:
        batch_files (list[str]): List of S3 file paths in the batch.
        variable_name (str): The name of the inconsistent coordinate variable.
        partial_preprocess (callable): The pre-configured preprocessing function.
        drop_vars_list (list[str]): List of variable names to drop.
        engine (str): The engine ('h5netcdf' or 'scipy') being used when the
            error occurred.

    Returns:
        xr.Dataset: The dataset opened from the cleaned batch of files.
    """
    log = self.logger
    prefix = self.uuid_log

    log.error(
        f"{prefix}: Detected issue with variable '{variable_name}': Inconsistent grid."
    )
    log.warning(
        f"{prefix}: Running variable consistency check across files in the batch."
    )

    rejected = self.check_variable_values_parallel(batch_files, variable_name)

    kept = []
    for candidate in batch_files:
        if candidate in rejected:
            continue
        kept.append(candidate)

    log.warning(f"{prefix}: Processing batch without problematic files: {rejected}")
    log.info(f"{prefix}: Processing the following clean files:\n{kept}")

    return self._open_mfds(partial_preprocess, drop_vars_list, kept, engine)
def handle_append_dim_overlap(
    self, batch_files, dim_name, partial_preprocess, drop_vars_list, engine
):
    """Handles batches where files have overlapping values on the append dimension.

    When ``combine_by_coords`` fails because the append dimension (e.g. TIME)
    is not globally monotonic, it means at least one file in the batch contains
    values that overlap with another file. This is a data quality issue — all
    files in a correctly formed dataset should have non-overlapping,
    monotonically increasing values along the append dimension.

    This method obtains the min/max of ``dim_name`` for each file via
    ``check_append_dim_range_dask`` (on the cluster when available), sorts the
    readable files by their minimum value, then flags any file whose minimum
    falls within the range already covered by a preceding file. Flagged files
    are logged as data-provider errors and recorded in the run summary. Files
    whose range could not be read are excluded as well. The remaining files
    are re-opened with ``_open_mfds``.

    Args:
        batch_files (list): List of S3 file paths in the current batch.
        dim_name (str): Name of the append dimension (e.g., ``"TIME"``).
        partial_preprocess (callable): The pre-configured preprocessing function.
        drop_vars_list (list[str]): List of variable names to drop.
        engine (str): The xarray engine to use for opening files.

    Returns:
        xr.Dataset: The dataset opened from the non-overlapping files.

    Raises:
        RuntimeError: If no overlapping files can be identified (unexpected
            failure), or if every file in the batch was excluded.
    """
    self.logger.warning(
        f"{self.uuid_log}: Running {dim_name} range check across files in the batch."
    )

    if self.cluster_mode:
        futures = self.client.map(
            check_append_dim_range_dask,
            batch_files,
            dim_name=dim_name,
            dataset_config=self.dataset_config,
            uuid_log=self.uuid_log,
        )
        ranges = self.client.gather(futures)
    else:
        ranges = [
            check_append_dim_range_dask(
                f, dim_name, self.dataset_config, self.uuid_log
            )
            for f in batch_files
        ]

    # Separate files that could not be read (their lower bound is None)
    bad_open = [f for f, lo, hi in ranges if lo is None]
    valid = sorted(
        [(f, lo, hi) for f, lo, hi in ranges if lo is not None],
        key=lambda x: x[1],
    )

    problematic_files = list(bad_open)
    # Upper bound of the range covered by the files accepted so far.
    current_max = None
    for f, lo, hi in valid:
        if current_max is not None and lo <= current_max:
            problematic_files.append(f)
            self.logger.error(
                f"{self.uuid_log}: Contact the data provider. File '{f}' has "
                f"{dim_name} values that overlap with a preceding file in the batch "
                f"(file range: [{lo}, {hi}], previous max: {current_max}). "
                f"Excluding it from this batch."
            )
            file_path = getattr(f, "path", str(f))
            self.run_summary.record_bad_file(
                "time_overlap",
                str(file_path),
                [os.path.basename(file_path)],
            )
        else:
            current_max = hi

    if not problematic_files:
        raise RuntimeError(
            f"{self.uuid_log}: combine_by_coords failed on append dimension "
            f"'{dim_name}' but no overlapping files were identified. Cannot recover."
        )

    clean_batch_files = [f for f in batch_files if f not in problematic_files]

    if not clean_batch_files:
        # Honour the documented contract instead of calling _open_mfds with
        # an empty file list, which would fail with an unrelated error.
        raise RuntimeError(
            f"{self.uuid_log}: Every file in the batch was excluded while checking "
            f"append dimension '{dim_name}'. Excluded files: {problematic_files}"
        )

    self.logger.warning(
        f"{self.uuid_log}: Reprocessing batch without overlapping files: {problematic_files}"
    )
    self.logger.info(
        f"{self.uuid_log}: Processing the following clean files:\n{clean_batch_files}"
    )
    return self._open_mfds(
        partial_preprocess, drop_vars_list, clean_batch_files, engine
    )
def handle_multi_engine_fallback(
    self, batch_files, partial_preprocess, drop_vars_list
):
    """Opens a batch whose files need different engines.

    Used once `xr.open_mfdataset` has failed with both 'h5netcdf' and 'scipy'
    (likely because the batch mixes NetCDF3 and NetCDF4 files). The work is
    delegated to `_concatenate_files_different_engines`, which opens the
    files individually and concatenates them.

    Args:
        batch_files (list[str]): List of S3 file paths in the batch.
        partial_preprocess (callable): The pre-configured preprocessing function.
        drop_vars_list (list[str]): List of variable names to drop.

    Returns:
        xr.Dataset: The concatenated dataset from individually opened files.

    Raises:
        RuntimeError: If the mixed-engine concatenation fails for any reason;
            the original error is logged as a warning with its traceback.
    """
    log = self.logger
    log.warning(
        f"{self.uuid_log}: Neither 'h5netcdf' nor 'scipy' engine could concatenate the dataset. "
        f"Falling back to opening files individually with different engines."
    )

    try:
        combined = self._concatenate_files_different_engines(
            batch_files, partial_preprocess, drop_vars_list
        )
        log.info(
            f"{self.uuid_log}: Successfully concatenated files using different engines."
        )
        return combined
    except Exception as failure:
        details = traceback.format_exc()
        log.warning(
            f"{self.uuid_log}: Error during multi-engine fallback: {failure}.\n {details}"
        )
        # The caller reacts to this by processing the files one by one.
        raise RuntimeError("Fallback to individual processing needed.")
def fallback_to_individual_processing(
    self, batch_files, partial_preprocess, drop_vars_list, idx
):
    """Last-resort handling of a failed batch: process its files one by one.

    The work is delegated to `_process_individual_file_fallback`. Any
    exception it raises is logged with its traceback and swallowed, so this
    method never propagates an error to the caller.

    Args:
        batch_files (list[str]): List of S3 file paths in the failed batch.
        partial_preprocess (callable): The pre-configured preprocessing function.
        drop_vars_list (list[str]): List of variable names to drop.
        idx (int): The index of the current batch (for logging).
    """
    batch_number = idx + 1
    self.logger.info(
        f"{self.uuid_log}: Batch processing methods failed for batch {batch_number}. Falling back to processing files individually."
    )

    try:
        self._process_individual_file_fallback(
            batch_files, partial_preprocess, drop_vars_list, idx
        )
    except Exception as failure:
        details = traceback.format_exc()
        self.logger.error(
            f"{self.uuid_log}: An unexpected error occurred during individual file fallback processing for batch {batch_number}: {failure}.\n {details}"
        )
def check_variable_values_parallel(self, file_paths, variable_name):
    """Checks variable consistency across files in parallel.

    Compares the values of a specified variable in multiple NetCDF files against
    the values in the first file of the list. Uses the configured Dask cluster
    (if available) for parallel execution via the top-level
    `check_variable_values_dask` function; otherwise the files are checked
    sequentially. Every inconsistent file is recorded in the run summary.

    Args:
        file_paths (list[str]): List of S3 file paths to check. The first file
            is used as the reference.
        variable_name (str): The name of the variable to compare across files.

    Returns:
        list[str]: A list of file paths identified as having values for the
                   specified variable that are inconsistent with the first file.
                   Returns all file paths if the reference file cannot be opened
                   or does not contain the variable.
    """
    # Open the first file and store its variable values as the reference.
    # The context manager guarantees the dataset is closed even when the
    # variable lookup fails; `.values` loads the data before the file closes.
    try:
        with xr.open_dataset(file_paths[0]) as reference_ds:
            reference_values = reference_ds[variable_name].values
    except Exception as e:
        self.logger.error(
            f"{self.uuid_log}: Failed to open the first file: {file_paths[0]}: {e}"
        )
        return file_paths  # If the first file fails, consider all files problematic

    if self.cluster_mode:
        # Use Dask to process files in parallel
        futures = self.client.map(
            check_variable_values_dask,
            file_paths[1:],
            reference_values=reference_values,
            variable_name=variable_name,
            dataset_config=self.dataset_config,
            uuid_log=self.uuid_log,
        )
        results = self.client.gather(futures)
    else:
        # TODO: not running on a cluster. In that case files are most likely
        #  processed one by one, so there is nothing to compare the reference
        #  against. An alternative would be to take the reference values
        #  directly from the original Zarr store, but that only works if the
        #  store already exists... too annoying/complicated to implement as it
        #  shouldn't happen. Easier to debug if it does.
        results = [
            check_variable_values_dask(
                file_path,
                reference_values,
                variable_name,
                self.dataset_config,
                self.uuid_log,
            )
            for file_path in file_paths[1:]
        ]

    # Collect problematic files
    problematic_files = [
        file_path for file_path, is_problematic in results if is_problematic
    ]

    if problematic_files:
        self.logger.error(
            f"{self.uuid_log}: Contact the data provider. The following files have an inconsistent grid with the rest of the dataset for variable '{variable_name}':\n{problematic_files}"
        )
        for f in problematic_files:
            self.run_summary.record_bad_file(
                "grid_inconsistency",
                str(getattr(f, "path", str(f))),
                [os.path.basename(getattr(f, "path", str(f)))],
            )

    return problematic_files
def _handle_duplicate_regions(
    self, ds: xr.Dataset, idx: int, region_n: int = 1
) -> xr.Dataset | None:
    """Handles writing data when time values overlap with existing Zarr store.

    Processes one contiguous overlapping region at a time, then recursively
    calls itself to handle any remaining data.

    Identifies the first contiguous region in the existing Zarr store that
    overlaps in append_dim (or time) with the new dataset (`ds`) and writes the
    corresponding slice of `ds` into it using `mode='r+'` and `region=...`.
    Also handles potential size mismatches (padding) and drops variables that
    are incompatible with region writing. Values of `ds` that are not part of
    that region are passed to a recursive call; when a call finds no overlap at
    all, the dataset is handed to `_write_ds`.

    Args:
        ds (xr.Dataset): The new dataset batch to write.
        idx (int): The index of the current batch (for logging).
        region_n (int): The number of the current region being processed.

    Returns:
        xr.Dataset | None: `ds` itself when it had no overlap with the store
            (it has then already been passed to `_write_ds`), or None once all
            values of the batch have been written through region overwrites.
    """
    append_dim = self.append_dim_varname
    append_dim_values_new = ds[append_dim].values

    with xr.open_zarr(
        self.store,
        consolidated=True,
        decode_cf=True,
        decode_times=self._get_decode_times,
        decode_coords=True,
    ) as ds_zarr:
        ds_zarr = ds_zarr.unify_chunks()
        append_dim_values_org = ds_zarr[append_dim].values
        # Snapshot {dim: size} before the overwrite. ``Dataset.sizes`` is used
        # (as in ``_write_ds``) because mapping-style access on ``Dataset.dims``
        # is deprecated in xarray.
        ds_stored_org_sizes = dict(ds_zarr.sizes)

    common_values = np.intersect1d(append_dim_values_org, append_dim_values_new)

    if len(common_values) == 0:
        self.logger.info(
            f"{self.uuid_log}: No overlapping {append_dim} values found in batch {idx + 1}. Writing full dataset."
        )
        self._write_ds(ds, idx)
        return ds

    self.logger.info(
        f"{self.uuid_log}: Duplicate values of '{append_dim}' found in the existing dataset. Overwriting."
    )

    # Get indices of common time values in the original dataset
    common_indices = np.nonzero(np.isin(append_dim_values_org, common_values))[0]

    # Identify the first contiguous region.
    # Regions must be CONTIGUOUS (very important): the first run of consecutive
    # store indices ends right before the first jump greater than one.
    start = common_indices[0]
    run_breaks = np.flatnonzero(np.diff(common_indices) != 1)
    end = common_indices[run_breaks[0]] if run_breaks.size else common_indices[-1]

    region_slice = slice(start, end + 1)
    region = {append_dim: region_slice}

    region_values = append_dim_values_org[region_slice]
    matching_indexes = np.where(np.isin(append_dim_values_new, region_values))[0]
    matching_indexes_array_size = matching_indexes.size
    region_num_elements = end - start + 1

    ##########################################
    # extremely rare!! only happened once, and why??
    if matching_indexes_array_size != region_num_elements:
        amount_to_pad = region_num_elements - matching_indexes_array_size
        self.logger.warning(
            f"{self.uuid_log}: Mismatch in size for region {region}. Original region size: {region_num_elements}, new data size: {matching_indexes_array_size}. "
            f"Attempting to pad the new dataset with {amount_to_pad} NaN index(es)."
        )
    else:
        amount_to_pad = 0

    sub_ds_for_region = (
        ds.isel({append_dim: matching_indexes})
        .drop_vars(self.vars_incompatible_with_region or [], errors="ignore")
        .pad({append_dim: (0, amount_to_pad)})
    )

    # Safety net: auto-detect any remaining variables/coordinates that have no
    # dimensions in common with the region. These cannot be written with
    # to_zarr(region=...) and would cause a ValueError.
    region_dims = set(region.keys())
    auto_incompatible = [
        v
        for v in list(sub_ds_for_region.data_vars) + list(sub_ds_for_region.coords)
        if v not in region_dims
        and not any(d in region_dims for d in sub_ds_for_region[v].dims)
    ]
    if auto_incompatible:
        self.logger.error(
            f"{self.uuid_log}: Auto-dropping variables/coordinates with no dimensions "
            f"in common with region dims {region_dims}: {auto_incompatible}. "
            f"Consider adding these to 'vars_incompatible_with_region' in the dataset config."
        )
        region_dims_str = ", ".join(f"'{d}'" for d in sorted(region_dims))
        auto_incompatible_str = ", ".join(f"'{v}'" for v in auto_incompatible)
        self.run_summary.record_config_hint(region_dims_str, auto_incompatible_str)
        sub_ds_for_region = sub_ds_for_region.drop_vars(
            auto_incompatible, errors="ignore"
        )
    ##########################################

    # Drop inherited chunk encoding so it cannot conflict with the chunking
    # of the data being written.
    for var in sub_ds_for_region:
        sub_ds_for_region[var].encoding.pop("chunks", None)

    # TODO:
    # compute() was added as unittests failed on github, but not locally. related to
    # https://github.com/pydata/xarray/issues/5219
    # missing = [
    #     var
    #     for var in self.vars_incompatible_with_region
    #     if var not in sub_ds_for_region.variables
    #     and var not in sub_ds_for_region.dims
    # ]
    #
    # if missing:
    #     raise ValueError(
    #         f"The following variables or dimensions are not present/mispelled in the dataset: {missing}"
    #     )
    # self.lock = self.create_sync_lock()

    self.logger.info(
        f"{self.uuid_log}: Batch {idx + 1}, Region {region_n} - Overwriting Zarr dataset in region: {region}, "
        f"with matching indexes in the new dataset: {matching_indexes}"
    )
    self.logger.debug(
        f"{self.uuid_log}: Attempt at overwriting the following dataset to an existing Zarr region:\n{sub_ds_for_region}"
    )
    self.log_chunking_strategy(sub_ds_for_region)
    self.logger.debug(f"{self.uuid_log}: The region used for this is: {region}")

    # NOTE(review): relies on ``self.lock`` having been set before this call
    # (the local ``create_sync_lock()`` call above is commented out) - confirm.
    with self.lock:
        sub_ds_for_region.to_zarr(
            self.store,
            region=region,
            mode="r+",
            compute=True,
            consolidated=self.consolidated,
            safe_chunks=self.safe_chunks,
            align_chunks=self.align_chunks,
            write_empty_chunks=self.write_empty_chunks,
        )

    # Sanity check: a region overwrite must not change any dimension size.
    # NOTE: the messages below say "incoming dataset", but the sizes compared
    # are those of the store *before* vs *after* the overwrite.
    with xr.open_zarr(
        self.store,
        consolidated=True,
        decode_cf=True,
        decode_times=self._get_decode_times,
        decode_coords=True,
    ) as ds_stored_zarr:
        for dim, size in ds_stored_org_sizes.items():
            updated_size = ds_stored_zarr.sizes.get(dim)
            if updated_size is None:
                self.logger.error(
                    f"{self.uuid_log}: Dimension '{dim}' is present in the incoming dataset but missing in the updated Zarr store."
                )
            elif updated_size != size:
                self.logger.error(
                    f"{self.uuid_log}: Mismatch in dimension '{dim}': incoming dataset has size {size}, "
                    f"but updated Zarr store has size {updated_size}."
                )
            else:
                self.logger.debug(
                    f"{self.uuid_log}: {dim} has the same size after region overwrite"
                )

        self.logger.debug(
            f"{self.uuid_log}: Zarr dataset after region overwriting has the following definition: {ds_stored_zarr.dims}"
        )

    self.logger.info(
        f"{self.uuid_log}: Batch {idx + 1}, Region {region_n} - Successfully published to the Zarr store: {self.store}"
    )

    # Whatever was not part of this region is handled by the next call.
    remaining_values = np.setdiff1d(append_dim_values_new, region_values)
    if len(remaining_values) > 0:
        self.logger.info(
            f"{self.uuid_log}: Continuing to recursively handle remaining overlapping regions."
        )
        ds_remaining = ds.sel({append_dim: remaining_values})
        return self._handle_duplicate_regions(ds_remaining, idx, region_n + 1)

    self.logger.info(
        f"{self.uuid_log}: All overlapping regions from Batch {idx + 1} were successfully published to the Zarr store: {self.store}"
    )
    return None


def _open_file_with_fallback(self, file, partial_preprocess, drop_vars_list):
    """Opens a single file, trying 'h5netcdf' then 'scipy' engine.

    Used as part of the fallback strategy when `open_mfdataset` fails.
    h5netcdf is tried first since most files are NetCDF4. scipy is the
    fallback for NetCDF3 (classic) files; for that attempt the file is opened
    through ``self.s3_fs_input`` and passed as a file-like object.

    Args:
        file (str): The S3 path of the file to open.
        partial_preprocess (callable): The pre-configured preprocessing function.
        drop_vars_list (list[str]): List of variable names to drop.

    Returns:
        xr.Dataset: The opened and preprocessed dataset.

    Raises:
        Exception: If the file cannot be opened with either engine. Errors
            other than ValueError/TypeError/OSError from the first attempt are
            not retried.
    """
    try:
        engine = "h5netcdf"
        ds = self._open_ds(file, partial_preprocess, drop_vars_list, engine=engine)
        self.logger.info(
            f"{self.uuid_log}: Successfully opened {file} with '{engine}' engine."
        )
        return ds
    except (ValueError, TypeError, OSError) as e:
        self.logger.debug(
            f"{self.uuid_log}: Error opening {file} with 'h5netcdf' engine: {e}. Trying 'scipy'."
        )
        engine = "scipy"
        with self.s3_fs_input.open(file, "rb") as f:
            ds = self._open_ds(f, partial_preprocess, drop_vars_list, engine=engine)
        self.logger.info(
            f"{self.uuid_log}: Successfully opened {file} with '{engine}' engine."
        )
        return ds


def _process_individual_file_fallback(
    self, batch_files, partial_preprocess, drop_vars_list, idx
):
    """Processes files individually as a fallback.

    Iterates through `batch_files`, opens each using `_open_file_with_fallback`,
    and writes it immediately using `_write_ds`. Files that fail are logged,
    recorded in ``self.run_summary`` and skipped so that one bad file does not
    prevent the rest of the batch from being processed.

    Args:
        batch_files (list[str]): List of S3 file paths in the batch.
        partial_preprocess (callable): The pre-configured preprocessing function.
        drop_vars_list (list[str]): List of variable names to drop.
        idx (int): The index of the current batch (for logging).
    """
    failed_files: list[tuple[str, str]] = []

    for file in batch_files:
        try:
            ds = self._open_file_with_fallback(
                file, partial_preprocess, drop_vars_list
            )
            self._write_ds(ds, idx)
        except Exception as e:
            # Deliberately broad: any failure on one file must not abort the batch.
            self.logger.error(
                f"{self.uuid_log}: Failed to process file '{file}' during "
                f"individual fallback for batch {idx + 1}: {e}",
                exc_info=True,
            )
            failed_files.append((file, str(e)))

    if failed_files:
        file_list = "\n ".join(f"{f}: {err}" for f, err in failed_files)
        self.logger.error(
            f"{self.uuid_log}: {len(failed_files)}/{len(batch_files)} file(s) "
            f"failed during individual fallback for batch {idx + 1}:\n {file_list}"
        )
        for f, err in failed_files:
            self.run_summary.record_individual_failure(
                idx + 1, os.path.basename(str(f)), err
            )


def _concatenate_files_different_engines(
    self, batch_files, partial_preprocess, drop_vars_list
):
    """Opens files individually and concatenates them.

    Used when `open_mfdataset` fails even with fallback engines. Opens each
    file using `_open_file_with_fallback` and then concatenates the list of
    resulting datasets along the append_dim dimension. Files that cannot be
    opened with either engine are skipped, logged and recorded in
    ``self.run_summary`` so the remaining clean files can still be processed.

    Args:
        batch_files (list[str]): List of S3 file paths in the batch.
        partial_preprocess (callable): The pre-configured preprocessing function.
        drop_vars_list (list[str]): List of variable names to drop.

    Returns:
        xr.Dataset: The concatenated dataset from all successfully opened files,
            rechunked with ``self.chunks``.

    Raises:
        RuntimeError: If no files in the batch could be opened.
    """
    datasets = []
    failed_files: list[tuple[str, str]] = []

    for file in batch_files:
        try:
            ds = self._open_file_with_fallback(
                file, partial_preprocess, drop_vars_list
            )
            datasets.append(ds)
        except Exception as e:
            self.logger.warning(
                f"{self.uuid_log}: Skipping file '{file}' — could not be opened "
                f"with any engine: {e}"
            )
            failed_files.append((file, str(e)))

    if failed_files:
        file_list = "\n ".join(f"{f}: {err}" for f, err in failed_files)
        self.logger.error(
            f"{self.uuid_log}: Contact the data provider. "
            f"{len(failed_files)}/{len(batch_files)} file(s) could not be opened "
            f"and were excluded from the batch:\n {file_list}"
        )
        for f, _err in failed_files:
            self.run_summary.record_bad_file(
                "cant_open", str(f), [os.path.basename(str(f))]
            )

    if not datasets:
        raise RuntimeError(
            f"No files in the batch could be opened. All {len(batch_files)} file(s) failed."
        )

    # Concatenate the datasets
    self.logger.info(
        f"{self.uuid_log}: Successfully read {len(datasets)}/{len(batch_files)} file(s) in the batch. Concatenating them now."
    )
    ds = xr.concat(
        datasets,
        compat="override",
        coords="minimal",
        data_vars="all",
        dim=self.append_dim_varname,
    )
    ds = ds.chunk(chunks=self.chunks)
    ds = ds.unify_chunks()
    # ds = ds.persist()
    self.logger.info(
        f"{self.uuid_log}: Successfully concatenated files from batch."
    )
    return ds


def _open_mfds(
    self, partial_preprocess, drop_vars_list, batch_files, engine="h5netcdf"
):
    """Opens multiple files as a single dataset using xr.open_mfdataset.

    Configures and calls `xr.open_mfdataset` with appropriate parameters for
    parallel processing, preprocessing, chunking, and decoding. Afterwards the
    dataset is sorted (by the configured ``dataset_sort_by`` if set, otherwise
    by the append dimension) and rechunked with ``self.chunks``.

    Args:
        partial_preprocess (callable): The pre-configured preprocessing function
            to be passed to `open_mfdataset`.
        drop_vars_list (list[str]): List of variable names to drop.
        batch_files (list[str]): List of S3 file paths to open.
        engine (str): The engine to use ('h5netcdf' or 'scipy').
            Defaults to "h5netcdf".

    Returns:
        xr.Dataset: The multi-file dataset opened by xarray.
    """
    # Note:
    # if using preprocess within open_mfdataset, data_vars should probably be set to "minimal" as the preprocess
    # function will create the missing variables anyway. However I noticed that in some cases, some variables gets
    # "emptied" such as 20190205111000-ABOM-L3S_GHRSST-SSTfnd-MultiSensor-1d_dn_Southern.nc when processed within a
    # batch. Need to be careful
    # if preprocess is called after the open_mfdataset, then data_vars should probably be set to "all" as some variables
    # might be changed to NaN for a specific batch, if some variables aren't common to all NetCDF

    # Use selective decode_times mapping to skip cftime decoding for problematic variables
    # Variables in skip_cftime_decode will remain as numeric (e.g., float64) instead of being
    # converted to datetime64[ns], which is important for time variables with NaN values.
    decode_times_map = self._get_decode_times
    skip_vars = (
        [k for k, v in decode_times_map.items() if not v]
        if isinstance(decode_times_map, dict)
        else []
    )
    if skip_vars:
        self.logger.info(
            f"{self.uuid_log}: Skip CF time decoding for variables: {skip_vars}"
        )

    open_mfdataset_params = {
        "engine": engine,
        "parallel": True,
        "preprocess": partial_preprocess,  # this sometimes hangs the process. to monitor
        "data_vars": "all",  # EXTREMELY IMPORTANT! Could lead to some variables being empty silently when writing to zarr
        "concat_characters": True,
        "mask_and_scale": True,
        "decode_cf": True,
        "decode_times": decode_times_map,
        "decode_coords": True,
        "compat": "override",
        "coords": "minimal",
        "drop_variables": drop_vars_list,
    }

    ds = xr.open_mfdataset(batch_files, **open_mfdataset_params)
    self.logger.info(
        f"{self.uuid_log}: Engine {engine} used to open the batch of files"
    )

    # Clean up attributes for skip_cftime_decode variables
    # These variables stay as numeric, so they shouldn't have CF time encoding attributes
    # that would conflict when xarray tries to CF-encode them for zarr
    for var_name in skip_vars:
        if var_name not in ds:
            continue
        # Remove CF time attributes that conflict during zarr encoding
        for attr_key in ["units", "calendar"]:
            if attr_key in ds[var_name].attrs:
                del ds[var_name].attrs[attr_key]
                self.logger.debug(
                    f"{self.uuid_log}: Removed '{attr_key}' from {var_name} "
                    "(variable stored as numeric, not CF time)"
                )

    dataset_sort_by = self.dataset_config["schema_transformation"].get(
        "dataset_sort_by", None
    )
    if dataset_sort_by:
        self.logger.info(
            f"{self.uuid_log}: sorting the dataset by {dataset_sort_by}"
        )
        ds = ds.sortby(dataset_sort_by)
    else:
        ds = ds.sortby(self.append_dim_varname)

    ds = ds.chunk(chunks=self.chunks)
    ds = ds.unify_chunks()

    # ds = ds.map_blocks(partial_preprocess) ## EXTREMELY DANGEROUS TO USE. CORRUPTS SOME DATA CHUNKS SILENTLY while it's working fine with preprocess
    # ds = ds.persist()
    return ds


# @delayed # Cannot use delayed with xr.concat
def _open_ds(self, file, partial_preprocess, drop_vars_list, engine="h5netcdf"):
    """Opens and preprocesses a single file.

    Uses `xr.open_dataset` with specified engine and parameters, rechunks the
    result with ``self.chunks`` and then applies `preprocess_xarray`.

    Args:
        file (str | object): The file path or file-like object to open.
        partial_preprocess (callable): The pre-configured preprocessing function
            (unused here, `preprocess_xarray` is called directly).
        drop_vars_list (list[str]): List of variable names to drop during opening.
        engine (str): The engine to use ('h5netcdf' or 'scipy').
            Defaults to "h5netcdf".

    Returns:
        xr.Dataset: The opened, chunked, and preprocessed dataset.
    """
    open_dataset_params = {
        "engine": engine,
        "mask_and_scale": True,
        "decode_cf": True,
        "decode_times": self._get_decode_times,
        "decode_coords": True,
        "drop_variables": drop_vars_list,
    }
    ds = xr.open_dataset(file, **open_dataset_params)

    # got to do ds.chunks here as we can get a
    # *** NotImplementedError: Can not use auto rechunking with object dtype. We are unable to estimate the size in bytes of object data
    # for GSLA files
    # https://github.com/pydata/xarray/issues/4055
    ds = ds.chunk(chunks=self.chunks)
    ds = ds.unify_chunks()

    ds = preprocess_xarray(ds, self.dataset_config)
    return ds


def _find_duplicated_values(self, ds_org):
    """Checks for duplicate self.append_dim_varname values in the existing Zarr store.

    Logs a warning if duplicate values are found in the append_dim dimension
    of the provided dataset (assumed to be the existing Zarr store).

    Args:
        ds_org (xr.Dataset): The dataset representing the existing Zarr store.

    Returns:
        bool: Always returns True. The main purpose is logging the warning.
    """
    # Find duplicates
    append_dim_values_org = ds_org[self.append_dim_varname].values
    unique, counts = np.unique(append_dim_values_org, return_counts=True)
    duplicates = unique[counts > 1]

    # TODO: Raise error if duplicates are found? Not necessarily an issue. For example, some SOOP dataset, same TIME, 2 different NetCDF files, 2 different vessel and location.
    if len(duplicates) > 0:
        self.logger.warning(
            f"{self.uuid_log}: Duplicate values of '{self.append_dim_varname}' dimension "
            f"found in the original Zarr dataset. This could lead to a corrupted dataset. Duplicates: {duplicates}"
        )
    return True


def _validate_and_fix_dims(self, ds: xr.Dataset, ds_org: xr.Dataset) -> xr.Dataset:
    """Validate variable dimensions against the existing Zarr store and fix mismatches.

    Compares each variable's dimensions in the new dataset against the existing
    Zarr store and applies the minimal transformation needed:

    * **Same dims, same order** — no-op.
    * **Same dims, different order** — ``transpose`` to match the store order.
    * **Missing dims (subset)** — ``broadcast_like`` a template built from the
      expected dimension order so that the variable gains the missing axis.
    * **Extra or completely different dims** — raises ``ValueError`` because the
      mismatch cannot be resolved automatically.

    Variables that do not exist in the store are left untouched.

    Args:
        ds: The new dataset to validate (modified in place and returned).
        ds_org: The existing Zarr dataset to compare against.

    Returns:
        The dataset with any fixable dimension mismatches resolved.

    Raises:
        ValueError: If a variable has incompatible dimensions that cannot be
            fixed by reordering or broadcast expansion.
    """
    for var_name in ds.data_vars:
        if var_name not in ds_org.data_vars:
            continue

        new_dims = ds[var_name].dims
        existing_dims = ds_org[var_name].dims
        if new_dims == existing_dims:
            continue

        missing_dims = set(existing_dims) - set(new_dims)
        extra_dims = set(new_dims) - set(existing_dims)

        if not missing_dims and not extra_dims:
            # Same set of dims but in a different order — transpose to match.
            self.logger.warning(
                f"{self.uuid_log}: Variable '{var_name}' has dims {new_dims} "
                f"but existing Zarr store expects {existing_dims}. "
                f"Transposing to match store order."
            )
            ds[var_name] = ds[var_name].transpose(*existing_dims)
        elif missing_dims and not extra_dims:
            self.logger.warning(
                f"{self.uuid_log}: Variable '{var_name}' has dims {new_dims} "
                f"but existing Zarr store expects {existing_dims}. "
                f"Broadcasting to add missing dims: {missing_dims}"
            )
            # Build a broadcast template with the expected dims from ds
            template_coords = {d: ds[d] for d in existing_dims if d in ds.coords}
            template = xr.DataArray(
                dims=existing_dims,
                coords=template_coords,
            )
            ds[var_name] = ds[var_name].broadcast_like(template)
        else:
            raise ValueError(
                f"Variable '{var_name}' has incompatible dimensions: "
                f"new={new_dims}, existing={existing_dims}. "
                f"Cannot reconcile (extra={extra_dims}, missing={missing_dims})."
            )

    return ds


def _write_ds(self, ds, idx):
    """Writes a dataset batch to the Zarr store.

    Rechunks the dataset, removes the inherited "chunks" encoding, and then
    determines whether to write to a new store, append to an existing store,
    or handle overlapping regions based on append_dim values. The dataset is
    not sorted here; callers are expected to pass it in the desired order.

    Args:
        ds (xr.Dataset): The preprocessed dataset batch to write.
        idx (int): The index of the current batch (used for logging).

    Raises:
        GridSizeMismatchError: If a non-append dimension shared with the
            existing store has a different size in the batch.
        ValueError: Propagated from `_validate_and_fix_dims` when a variable's
            dimensions cannot be reconciled with the store.
    """
    if any(chunk == 0 for chunk in self.chunks.values()):
        self.logger.warning(
            f"{self.uuid_log}: One or more dimensions in the dataset configuration have a chunk size of 0: {self.chunks}. "
            f"Please modify the configuration file. Defaulting to 'chunks=auto'."
        )
        ds = ds.chunk(chunks="auto")
    else:
        ds = ds.chunk(chunks=self.chunks)

    # Remove inherited chunk encoding to avoid conflicts when rechunking before write.
    # Only the "chunks" key is problematic (see https://github.com/pydata/xarray/issues/5219
    # and https://github.com/pydata/xarray/issues/5286). Clearing the full encoding dict is
    # too aggressive: it also strips the object_codec, _FillValue etc. that are set for
    # string variables (e.g. 'filename'), which causes xarray to load the dask array into
    # memory to infer the dtype and emit a SerializationWarning.
    for var in ds.variables:
        ds[var].encoding.pop("chunks", None)

    # First time writing the dataset
    if not prefix_exists(
        self.cloud_optimised_output_path, s3_client_opts=self.s3_client_opts_output
    ):
        self._write_new_zarr_store(ds)
        return

    self.logger.info(
        f"{self.uuid_log}: Existing Zarr store found at {self.cloud_optimised_output_path}. Appending data."
    )
    # NOTE: In the next section, we need to figure out if we're reprocessing existing data.
    # For this, the logic is to open the original zarr store and compare with the new ds from
    # this batch if they have time values in common.
    # If this is the case, we need then to find the CONTIGUOUS regions as we can't assume that
    # the data is well-ordered. The logic below is looking for the matching regions and indexes
    with xr.open_zarr(
        self.store,
        consolidated=True,
        decode_cf=True,
        decode_times=self._get_decode_times,
        decode_coords=True,
    ) as ds_org:
        ds_org = ds_org.unify_chunks()
        self._find_duplicated_values(ds_org)

        self.logger.debug(
            f"{self.uuid_log}: loading existing Zarr dataset:\n{ds_org}"
        )
        self.logger.debug(
            f"{self.uuid_log}: Existing Zarr dataset has the following chunk definition: {ds_org.chunks.items()}"
        )

        # Validate spatial (non-append) dimension sizes before attempting to
        # append. If the batch has a different spatial grid than the store, all
        # files in the batch are fundamentally incompatible — bail out early with
        # a specific error so the caller can log and skip them cleanly.
        spatial_mismatches = {
            dim: (ds.sizes[dim], ds_org.sizes[dim])
            for dim in ds_org.dims
            if dim != self.append_dim_varname
            and dim in ds.dims
            and ds.sizes[dim] != ds_org.sizes[dim]
        }
        if spatial_mismatches:
            mismatch_str = ", ".join(
                f"{dim}: batch={actual} vs store={expected}"
                for dim, (actual, expected) in spatial_mismatches.items()
            )
            raise GridSizeMismatchError(
                f"Batch has incompatible spatial grid — {mismatch_str}. "
                f"Files in this batch cannot be appended to the existing store."
            )

        ds = self._validate_and_fix_dims(ds, ds_org)

        append_dim_values_org = ds_org[self.append_dim_varname].values
        append_dim_values_new = ds[self.append_dim_varname].values

        # Find common append_dim (or time) values
        common_append_dim_values = np.intersect1d(
            append_dim_values_org, append_dim_values_new
        )

        # Handle the 2 scenarios, reprocessing of a batch, or append new data
        if len(common_append_dim_values) > 0:
            self._handle_duplicate_regions(
                ds,
                idx,
            )
        # No reprocessing needed
        else:
            self._append_zarr_store(ds)
            self.logger.info(
                f"{self.uuid_log}: Batch {idx + 1} successfully appended to the Zarr store: {self.store}"
            )


def _write_new_zarr_store(self, ds):
    """Writes the dataset to a new Zarr store (mode='w').

    Used for the very first batch when the Zarr store doesn't exist yet.
    A fresh lock from ``create_sync_lock`` is stored on ``self.lock`` and held
    for the duration of the write.

    Args:
        ds (xr.Dataset): The dataset to write.
    """
    self.logger.info(
        f"{self.uuid_log}: Writing data to a new Zarr store at {self.store}."
    )
    self.lock = self.create_sync_lock()
    self.log_chunking_strategy(ds)

    with self.lock:
        ds.to_zarr(
            self.store,
            mode="w",  # Overwrite mode for the first batch
            write_empty_chunks=self.write_empty_chunks,
            compute=True,  # Compute the result immediately
            consolidated=self.consolidated,
            safe_chunks=self.safe_chunks,
            align_chunks=self.align_chunks,
        )


def _append_zarr_store(self, ds):
    """Appends the dataset to an existing Zarr store (mode='a-').

    Used when writing subsequent batches that do not overlap in time with the
    existing data. Appends along ``self.append_dim_varname``. A fresh lock from
    ``create_sync_lock`` is stored on ``self.lock`` and held for the duration
    of the write.

    Args:
        ds (xr.Dataset): The dataset to append.
    """
    self.logger.info(
        f"{self.uuid_log}: Appending data to the existing Zarr store at {self.store}."
    )
    self.lock = self.create_sync_lock()
    self.log_chunking_strategy(ds)

    with self.lock:
        ds.to_zarr(
            self.store,
            mode="a-",
            write_empty_chunks=self.write_empty_chunks,
            compute=True,  # Compute the result immediately
            consolidated=self.consolidated,
            append_dim=self.append_dim_varname,
            safe_chunks=self.safe_chunks,
            align_chunks=self.align_chunks,
        )
def to_cloud_optimised(self, s3_file_uri_list=None):
    """Main entry point to convert NetCDF files to a Zarr dataset.

    Handles optional clearing of existing data, sets up the Dask cluster if
    configured, de-duplicates the input file list (preserving order), and
    calls `publish_cloud_optimised_fileset_batch` to process the files.
    The cluster, if one was created, is closed afterwards even when the
    publishing step raises.

    Args:
        s3_file_uri_list (list[str], optional): List of S3 URIs of NetCDF files
            to process. If None, the method returns after the optional clearing
            step without creating a cluster. An empty list is still handed to
            `publish_cloud_optimised_fileset_batch`. Defaults to None.
    """
    if self.clear_existing_data:
        self.logger.warning(
            f"Option 'clear_existing_data' is True. DELETING all existing Zarr objects at {self.cloud_optimised_output_path} if they exist."
        )
        # TODO: delete all objects
        if prefix_exists(
            self.cloud_optimised_output_path,
            s3_client_opts=self.s3_client_opts_output,
        ):
            bucket_name, prefix = split_s3_path(self.cloud_optimised_output_path)
            self.logger.info(
                f"Deleting existing Zarr objects from bucket '{bucket_name}' with prefix '{prefix}'."
            )
            delete_objects_in_prefix(
                bucket_name, prefix, self.s3_client_opts_output
            )

    # Nothing to process.
    if s3_file_uri_list is None:
        return

    # Multiple file processing with cluster
    # dict.fromkeys keeps the first occurrence of each URI, in order.
    s3_file_uri_list = list(dict.fromkeys(s3_file_uri_list))
    self.s3_file_uri_list = s3_file_uri_list

    if self.cluster_mode:
        # creating a cluster to process multiple files at once
        self.client, self.cluster = self.create_cluster()
        if self.cluster_mode == "coiled":
            self.cluster_id = self.cluster.cluster_id
        else:
            self.cluster_id = self.cluster.name
    else:
        self.cluster_id = "local_execution"

    try:
        self.publish_cloud_optimised_fileset_batch(s3_file_uri_list)
    finally:
        # Always release the cluster, otherwise a failed run leaks it.
        if self.cluster_mode:
            self.cluster_manager.close_cluster(self.client, self.cluster)
[docs] @staticmethod def filter_rechunk_dimensions(dimensions): """Filters dimensions based on the 'rechunk' flag in the config. Static method used potentially by the (currently commented out) `rechunk` method to determine which dimensions and their target chunk sizes should be used for rechunking. Args: dimensions (dict): The 'dimensions' section of the dataset configuration. Returns: dict: A dictionary where keys are dimension names and values are target chunk sizes for dimensions marked with 'rechunk: true'. """ rechunk_dimensions = {} for key, value in dimensions.items(): if "rechunk" in value and value["rechunk"]: rechunk_dimensions[value["name"]] = value["chunk"] return rechunk_dimensions
# TODO: # The following function is not quite ready. # Also to be defined if this is needed. # def rechunk(self, max_mem="8.0GB"): # """ # Rechunk a Zarr dataset stored on S3. # # Parameters: # - max_mem (str, optional): Maximum memory to use during rechunking (default is '8.0GB'). # # Returns: # None # # Example: # >>> gridded_handler_instance.rechunk(max_mem='12.0GB') # # This method rechunks a Zarr dataset stored on S3 using Dask and the specified target chunks. # The rechunked data is stored in a new Zarr dataset on S3. # The intermediate and previous rechunked data are deleted before the rechunking process. # The 'max_mem' parameter controls the maximum memory to use during rechunking. # # Note: The target chunks for rechunking are determined based on the dimensions and chunking information. # """ # # with worker_client(): # with Client() as client: # # target_chunks = self.filter_rechunk_dimensions( # self.dimensions # ) # only return a dict with the dimensions to rechunk # # # s3 = s3fs.S3FileSystem(anon=False) # # org_url = ( # self.cloud_optimised_output_path # ) # f's3://{self.optimised_bucket_name}/zarr/{self.dataset_name}.zarr' # # org_store = s3fs.S3Map(root=f'{org_url}', s3=s3, check=False) # # target_url = org_url.replace( # f"{self.dataset_name}", f"{self.dataset_name}_rechunked" # ) # target_store = s3fs.S3Map(root=f"{target_url}", s3=self.s3_fs, check=False) # # zarr.consolidate_metadata(org_store) # # ds = xr.open_zarr(fsspec.get_mapper(org_url, anon=True), consolidated=True) # # temp_url = org_url.replace( # f"{self.dataset_name}", f"{self.dataset_name}_intermediate" # ) # # temp_store = s3fs.S3Map(root=f"{temp_url}", s3=self.s3_fs, check=False) # # # delete previous version of intermediate and rechunked data # s3_client = boto3.resource("s3") # bucket = s3_client.Bucket(f"{self.optimised_bucket_name}") # self.logger.info("Delete previous rechunked version in progress") # # bucket.objects.filter( # 
Prefix=temp_url.replace(f"s3://{self.optimised_bucket_name}/", "") # ).delete() # bucket.objects.filter( # Prefix=target_url.replace(f"s3://{self.optimised_bucket_name}/", "") # ).delete() # self.logger.info( # f"Rechunking in progress with target chunks: {target_chunks}" # ) # # options = dict(overwrite=True) # array_plan = rechunk( # ds, # target_chunks, # max_mem, # target_store, # temp_store=temp_store, # temp_options=options, # target_options=options, # ) # with ProgressBar(): # self.logger.info(f"{array_plan.execute()}") # # zarr.consolidate_metadata(target_store)