shredx.datasets.datasets.load_sst_data

shredx.datasets.datasets.load_sst_data(input_length: int, forecast_length: int, device: str = 'cpu')

Loads the SST dataset from the datasets/sst/ directory.

Parameters:
input_length : int

The length of the input sequence.

forecast_length : int

The length of the forecast sequence.

device : str, optional

The device to load the data onto. Default is "cpu".

Returns:
tuple: (train_ds, valid_ds, test_ds, metadata), where train_ds, valid_ds, and test_ds are TimeSeriesDataset objects and metadata is a dictionary containing the scalers.

Notes

The shape of each dataset is (input_length + forecast_length, rows, cols). Each dataloader returns a tuple (input_window, output_window), where each element is a Float[torch.Tensor, "length rows cols 1"], with length equal to input_length for input_window and forecast_length for output_window.
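To illustrate how a window of input_length + forecast_length time steps can be split into an input part and a forecast part, here is a minimal standard-library sketch. The helper split_windows is hypothetical and not part of shredx. It assumes consecutive windows with a stride of one and uses integer time-step labels in place of (rows, cols) frames, so the trailing channel dimension of size 1 is omitted.

```python
from typing import List, Sequence, Tuple


def split_windows(
    frames: Sequence, input_length: int, forecast_length: int
) -> List[Tuple[list, list]]:
    """Hypothetical helper (not the shredx implementation).

    Slides a window of input_length + forecast_length steps over `frames`
    with an assumed stride of 1, and splits each window into
    (input_window, output_window).
    """
    window = input_length + forecast_length
    pairs = []
    # The last valid start index leaves exactly `window` frames remaining.
    for start in range(len(frames) - window + 1):
        chunk = list(frames[start:start + window])
        # First input_length steps feed the model; the rest are the forecast target.
        pairs.append((chunk[:input_length], chunk[input_length:]))
    return pairs


# Ten time steps, 3 input steps and 2 forecast steps -> 6 windows.
pairs = split_windows(list(range(10)), input_length=3, forecast_length=2)
print(pairs[0])   # ([0, 1, 2], [3, 4])
print(pairs[-1])  # ([5, 6, 7], [8, 9])
```

In the actual dataset each label would correspond to a (rows, cols) frame, giving input and output windows of shape (input_length, rows, cols, 1) and (forecast_length, rows, cols, 1) respectively.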