Source code for xbatcher.tests.test_keras_loaders

import numpy as np
import pytest
import xarray as xr

tf = pytest.importorskip("tensorflow")

from xbatcher import BatchGenerator
from xbatcher.loaders.keras import CustomTFDataset


[docs]@pytest.fixture(scope="module") def ds_xy(): n_samples = 100 n_features = 5 ds = xr.Dataset( { "x": ( ["sample", "feature"], np.random.random((n_samples, n_features)), ), "y": (["sample"], np.random.random(n_samples)), }, ) return ds
[docs]def test_custom_dataarray(ds_xy): x = ds_xy["x"] y = ds_xy["y"] x_gen = BatchGenerator(x, {"sample": 10}) y_gen = BatchGenerator(y, {"sample": 10}) dataset = CustomTFDataset(x_gen, y_gen) # test __getitem__ x_batch, y_batch = dataset[0] assert x_batch.shape == (10, 5) assert y_batch.shape == (10,) assert tf.is_tensor(x_batch) assert tf.is_tensor(y_batch) # test __len__ assert len(dataset) == len(x_gen)
[docs]def test_custom_dataarray_with_transform(ds_xy): x = ds_xy["x"] y = ds_xy["y"] x_gen = BatchGenerator(x, {"sample": 10}) y_gen = BatchGenerator(y, {"sample": 10}) def x_transform(batch): return batch * 0 + 1 def y_transform(batch): return batch * 0 - 1 dataset = CustomTFDataset( x_gen, y_gen, transform=x_transform, target_transform=y_transform ) x_batch, y_batch = dataset[0] assert x_batch.shape == (10, 5) assert y_batch.shape == (10,) assert tf.is_tensor(x_batch) assert tf.is_tensor(y_batch) assert tf.experimental.numpy.all(x_batch == 1) assert tf.experimental.numpy.all(y_batch == -1)