# Source code for xbatcher.generators

"""Classes for iterating through xarray datarrays / datasets in batches."""

import itertools
from collections import OrderedDict
from typing import Any, Dict, Hashable, Iterator

import xarray as xr


def _slices(dimsize, size, overlap=0):
    # return a list of slices to chop up a single dimension
    if overlap >= size:
        raise ValueError(
            "input overlap must be less than the input sample length, but "
            f"the input sample length is {size} and the overlap is {overlap}"
        )
    slices = []
    stride = size - overlap
    for start in range(0, dimsize, stride):
        end = start + size
        if end <= dimsize:
            slices.append(slice(start, end))
    return slices
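
# Worked example (illustrative, not part of the library): with dimsize=10,
# size=4, and overlap=2, the stride is 2, so windows start at 0, 2, 4, 6, 8;
# the window starting at 8 would end at 12 > 10 and is dropped. Hence
# _slices(10, 4, overlap=2) == [slice(0, 4), slice(2, 6), slice(4, 8), slice(6, 10)]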


def _iterate_through_dataset(ds, dims, overlap=None):
    # yield a selector dict for each combination of slices along the
    # requested dimensions
    if overlap is None:
        overlap = {}
    dim_slices = []
    for dim in dims:
        dimsize = ds.sizes[dim]
        size = dims[dim]
        olap = overlap.get(dim, 0)
        if size > dimsize:
            raise ValueError(
                "input sample length must be less than or equal to the "
                f"dimension length, but the sample length of {size} "
                f"is greater than the dimension length of {dimsize} "
                f"for {dim}"
            )
        dim_slices.append(_slices(dimsize, size, olap))

    for slices in itertools.product(*dim_slices):
        selector = dict(zip(dims, slices))
        yield selector
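
# Worked example (illustrative, not part of the library): for a ds with sizes
# {'x': 4, 'y': 3}, dims={'x': 2, 'y': 3}, and no overlap, the x slices are
# [slice(0, 2), slice(2, 4)] and the single y slice is [slice(0, 3)], so the
# generator yields the Cartesian product of selectors:
#     {'x': slice(0, 2), 'y': slice(0, 3)}
#     {'x': slice(2, 4), 'y': slice(0, 3)}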


def _drop_input_dims(ds, input_dims, suffix="_input"):
    # remove input_dims coordinates from datasets, rename the dimensions,
    # then put input_dims back in as coordinates
    out = ds.copy()
    for dim in input_dims:
        newdim = dim + suffix
        out = out.rename({dim: newdim})
        # extra steps needed if there is a coordinate
        if newdim in out:
            out = out.drop_vars(newdim)
            out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs
    return out
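
# Worked example (illustrative, not part of the library): for a dataset with a
# 'lat' dimension and a matching 'lat' index coordinate,
# _drop_input_dims(ds, ['lat']) renames the dimension to 'lat_input', drops the
# renamed index coordinate, and re-attaches the original 'lat' values as a
# non-index coordinate along 'lat_input', so chunks concatenated later can keep
# their original coordinate labels.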


def _maybe_stack_batch_dims(ds, input_dims, stacked_dim_name="sample"):
    batch_dims = [d for d in ds.sizes if d not in input_dims]
    if len(batch_dims) < 2:
        return ds
    ds_stack = ds.stack(**{stacked_dim_name: batch_dims})
    # ensure correct order
    dim_order = (stacked_dim_name,) + tuple(input_dims)
    return ds_stack.transpose(*dim_order)
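
# Worked example (illustrative, not part of the library): if ds has dims
# ('time', 'lat', 'lon') and input_dims=('lon',), the two leftover dims
# ('time', 'lat') are stacked into a single 'sample' dimension and the result
# is transposed to ('sample', 'lon'). With only one leftover dim (e.g.
# input_dims=('lat', 'lon')), ds is returned unchanged.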


class BatchGenerator:
    """Create a generator for iterating through xarray dataarrays / datasets
    in batches.

    Parameters
    ----------
    ds : ``xarray.Dataset`` or ``xarray.DataArray``
        The data to iterate over
    input_dims : dict
        A dictionary specifying the size of the inputs in each dimension,
        e.g. ``{'lat': 30, 'lon': 30}``
        These are the dimensions the ML library will see. All other dimensions
        will be stacked into one dimension called ``sample``.
    input_overlap : dict, optional
        A dictionary specifying the overlap along each dimension
        e.g. ``{'lat': 3, 'lon': 3}``
    batch_dims : dict, optional
        A dictionary specifying the size of the batch along each dimension
        e.g. ``{'time': 10}``. These will always be iterated over.
    concat_input_dims : bool, optional
        If ``True``, the dimension chunks specified in ``input_dims`` will be
        concatenated and stacked into the ``sample`` dimension. The batch index
        will be included as a new level ``input_batch`` in the ``sample``
        coordinate.
        If ``False``, the dimension chunks specified in ``input_dims`` will be
        iterated over.
    preload_batch : bool, optional
        If ``True``, each batch will be loaded into memory before reshaping /
        processing, triggering any dask arrays to be computed.

    Yields
    ------
    ds_slice : ``xarray.Dataset`` or ``xarray.DataArray``
        Slices of the array matching the given batch size specification.
    """
    def __init__(
        self,
        ds: xr.Dataset,
        input_dims: Dict[Hashable, int],
        input_overlap: Dict[Hashable, int] = {},
        batch_dims: Dict[Hashable, int] = {},
        concat_input_dims: bool = False,
        preload_batch: bool = True,
    ):
        self.ds = ds
        # should be a dict
        self.input_dims = OrderedDict(input_dims)
        self.input_overlap = input_overlap
        self.batch_dims = OrderedDict(batch_dims)
        self.concat_input_dims = concat_input_dims
        self.preload_batch = preload_batch
        # dict cache for batches
        self._batches: Dict[int, Any] = self._gen_batches()
    def __iter__(self) -> Iterator[xr.Dataset]:
        for idx in self._batches:
            yield self[idx]

    def __len__(self) -> int:
        return len(self._batches)

    def __getitem__(self, idx: int) -> xr.Dataset:
        if not isinstance(idx, int):
            raise NotImplementedError(
                f"{type(self).__name__}.__getitem__ currently requires a "
                "single integer key"
            )

        if idx < 0:
            idx = list(self._batches)[idx]

        if idx in self._batches:
            if self.concat_input_dims:
                new_dim_suffix = "_input"
                all_dsets = [
                    _drop_input_dims(
                        self.ds.isel(**ds_input_select),
                        list(self.input_dims),
                        suffix=new_dim_suffix,
                    )
                    for ds_input_select in self._batches[idx]
                ]
                dsc = xr.concat(all_dsets, dim="input_batch")
                new_input_dims = [
                    str(dim) + new_dim_suffix for dim in self.input_dims
                ]
                return _maybe_stack_batch_dims(dsc, new_input_dims)
            else:
                return _maybe_stack_batch_dims(
                    self.ds.isel(**self._batches[idx]), list(self.input_dims)
                )
        else:
            raise IndexError("list index out of range")

    def _gen_batches(self) -> dict:
        # in the future, we will want to do the batch generation lazily;
        # going the eager route for now allows us to fill out the loader API,
        # but it is likely to perform poorly.
        batches = []
        for ds_batch_selector in self._iterate_batch_dims(self.ds):
            ds_batch = self.ds.isel(**ds_batch_selector)
            if self.preload_batch:
                ds_batch.load()
            input_generator = self._iterate_input_dims(ds_batch)
            if self.concat_input_dims:
                batches.append(list(input_generator))
            else:
                batches += list(input_generator)
        return dict(enumerate(batches))

    def _iterate_batch_dims(self, ds):
        return _iterate_through_dataset(ds, self.batch_dims)

    def _iterate_input_dims(self, ds):
        return _iterate_through_dataset(ds, self.input_dims, self.input_overlap)
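

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the library):
    # slice a 3-D DataArray into batches along 'lon'. The array contents and
    # dimension names here are arbitrary assumptions for the demo.
    import numpy as np

    da = xr.DataArray(
        np.random.rand(10, 4, 6), dims=("time", "lat", "lon"), name="foo"
    )
    bgen = BatchGenerator(da, input_dims={"lon": 3})
    print(len(bgen))  # 2 batches: lon slices (0, 3) and (3, 6)
    for batch in bgen:
        # the leftover dims ('time', 'lat') are stacked into 'sample',
        # so each batch has dims ('sample', 'lon') with sizes 40 and 3
        print(batch.sizes)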