shredx.datasets.datasets.load_sst_data
- shredx.datasets.datasets.load_sst_data(input_length: int, forecast_length: int, device: str = 'cpu')
Loads the SST dataset from the datasets/sst/ directory.
- Parameters:
- input_length : int
The length of the input sequence.
- forecast_length : int
The length of the forecast sequence.
- device : str, optional
The device to load the data on. Default is "cpu".
- Returns:
- tuple: (train_ds, valid_ds, test_ds, metadata), where train_ds, valid_ds, and test_ds are TimeSeriesDataset instances and metadata is a dictionary containing the scalers.
Notes
Each dataset is built from windows of shape (input_length + forecast_length, rows, cols). Each dataloader returns a tuple (input_window, output_window), where each element is a Float[torch.Tensor, "length rows cols 1"]: input_window has length input_length and output_window has length forecast_length.
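The windowing described in the Notes can be illustrated with a small sliding-window dataset. This is a minimal sketch, not the shredx TimeSeriesDataset: the class name SlidingWindowDataset is hypothetical, NumPy arrays stand in for torch tensors, and it assumes the raw SST field is a (time, rows, cols) array with consecutive windows taken at stride 1.

```python
import numpy as np


class SlidingWindowDataset:
    """Hypothetical stand-in for TimeSeriesDataset (not the shredx class)."""

    def __init__(self, data: np.ndarray, input_length: int, forecast_length: int):
        # data is assumed to be a (time, rows, cols) array of SST fields
        if data.ndim != 3:
            raise ValueError("expected data of shape (time, rows, cols)")
        self.data = data
        self.input_length = input_length
        self.forecast_length = forecast_length
        self.window = input_length + forecast_length

    def __len__(self) -> int:
        # Number of full windows that fit along the time axis (stride 1)
        return max(0, self.data.shape[0] - self.window + 1)

    def __getitem__(self, idx: int):
        if not 0 <= idx < len(self):
            raise IndexError(idx)
        # One window of shape (input_length + forecast_length, rows, cols)
        chunk = self.data[idx : idx + self.window]
        # Append a trailing channel axis -> (length, rows, cols, 1)
        chunk = chunk[..., np.newaxis]
        # Split along time into the input and forecast parts
        input_window = chunk[: self.input_length]
        output_window = chunk[self.input_length :]
        return input_window, output_window


# Usage: 20 time steps on a 4 x 5 grid, 8 input steps, 4 forecast steps
data = np.random.rand(20, 4, 5).astype(np.float32)
ds = SlidingWindowDataset(data, input_length=8, forecast_length=4)
x, y = ds[0]
print(len(ds), x.shape, y.shape)  # 9 (8, 4, 5, 1) (4, 4, 5, 1)
```

With 20 time steps and a window of 8 + 4 = 12, nine windows fit; each input_window and output_window carries the trailing singleton axis that the Notes describe.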