# Source code for xbatcher.accessors

from typing import Union

import xarray as xr

from .generators import BatchGenerator


def _as_xarray_dataarray(xr_obj: Union[xr.Dataset, xr.DataArray]) -> xr.DataArray:
    """
    Convert xarray.Dataset to xarray.DataArray if needed, so that it can
    be converted into a Tensor object.

    Parameters
    ----------
    xr_obj : xarray.Dataset or xarray.DataArray
        Object to convert. A Dataset must hold exactly one data variable: it
        is stacked along a new ``"variable"`` dimension with
        ``Dataset.to_array`` and that length-1 dimension is then squeezed
        away. Any non-Dataset input is returned unchanged.

    Returns
    -------
    xarray.DataArray
        The input itself when it is not a Dataset, otherwise the Dataset's
        single data variable as a DataArray.

    Raises
    ------
    ValueError
        If ``xr_obj`` is a Dataset that does not contain exactly one data
        variable.
    """
    if isinstance(xr_obj, xr.Dataset):
        var_names = list(xr_obj.data_vars)
        if len(var_names) != 1:
            # Fail early with an actionable message instead of the opaque
            # error raised by ``to_array``/``squeeze`` when the "variable"
            # dimension does not have length one.
            raise ValueError(
                "Only an xarray.Dataset with exactly one data variable can be "
                f"converted to a tensor; got {len(var_names)}: {var_names}"
            )
        xr_obj = xr_obj.to_array().squeeze(dim="variable")

    return xr_obj


@xr.register_dataarray_accessor("batch")
@xr.register_dataset_accessor("batch")
class BatchAccessor:
    """Expose ``BatchGenerator`` construction as ``obj.batch.generator(...)``."""

    def __init__(self, xarray_obj):
        """
        Batch accessor returning a BatchGenerator object via the
        :meth:`generator` method.

        Parameters
        ----------
        xarray_obj : xarray.DataArray or xarray.Dataset
            The object the accessor is attached to.
        """
        self._obj = xarray_obj

    def generator(self, *args, **kwargs):
        """
        Return a BatchGenerator built from the wrapped xarray object.

        Parameters
        ----------
        *args : iterable
            Positional arguments forwarded, after the wrapped object, to the
            ``BatchGenerator`` constructor.
        **kwargs : dict
            Keyword arguments forwarded to the ``BatchGenerator`` constructor.
        """
        source = self._obj
        return BatchGenerator(source, *args, **kwargs)
@xr.register_dataarray_accessor("tf")
@xr.register_dataset_accessor("tf")
class TFAccessor:
    """Convert xarray objects to TensorFlow tensors via ``obj.tf``."""

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def to_tensor(self):
        """
        Convert the wrapped object to a ``tensorflow.Tensor``.

        A Dataset is first turned into a DataArray with
        ``_as_xarray_dataarray``; the tensor is created from that
        DataArray's underlying ``.data`` array.
        """
        # Imported lazily so TensorFlow stays an optional dependency.
        import tensorflow as tf

        as_array = _as_xarray_dataarray(xr_obj=self._obj)
        payload = as_array.data
        return tf.convert_to_tensor(payload)
@xr.register_dataarray_accessor("torch")
@xr.register_dataset_accessor("torch")
class TorchAccessor:
    """Convert xarray objects to PyTorch tensors via ``obj.torch``."""

    def __init__(self, xarray_obj):
        self._obj = xarray_obj

    def to_tensor(self):
        """
        Convert the wrapped object to a ``torch.Tensor``.

        A Dataset is first turned into a DataArray with
        ``_as_xarray_dataarray``; the tensor is created from that
        DataArray's underlying ``.data`` array.
        """
        # Imported lazily so PyTorch stays an optional dependency.
        import torch

        as_array = _as_xarray_dataarray(xr_obj=self._obj)
        return torch.tensor(data=as_array.data)

    def to_named_tensor(self):
        """
        Convert the wrapped object to a ``torch.Tensor`` with named dimensions.

        Tensor dimension names are taken, in order, from the xarray
        dimension names. See https://pytorch.org/docs/stable/named_tensor.html
        """
        import torch

        as_array = _as_xarray_dataarray(xr_obj=self._obj)
        # ``dims`` lists the dimension names in axis order, matching the
        # layout of ``.data``.
        dim_names = as_array.dims
        return torch.tensor(data=as_array.data, names=dim_names)