shredx.modules.transformer.SINDyAttentionSINDyLossTransformerEncoder

class shredx.modules.transformer.SINDyAttentionSINDyLossTransformerEncoder(d_model: int, n_heads: int, forecast_length: int, num_layers: int, dim_feedforward: int, dropout: float, activation: Module, layer_norm_eps: float, norm_first: bool, bias: bool, strict_symmetry: bool, input_length: int, hidden_size: int, sindy_loss_threshold: float, dt: float, device: str = 'cpu')

Bases: SINDyLossMixin, SINDyAttentionTransformerEncoder

Transformer encoder with SINDy attention and SINDy loss regularization.

Combines SINDy-based attention in the final layer with SINDy loss regularization for ODE-based latent rollouts and sparse dynamics.

Parameters:
d_model : int
    Input dimension of the model.
n_heads : int
    Number of attention heads.
forecast_length : int
    Number of future timesteps to predict.
num_layers : int
    Number of transformer encoder layers.
dim_feedforward : int
    Dimension of the feedforward network.
dropout : float
    Dropout probability.
activation : nn.Module
    Activation function for the feedforward layers.
layer_norm_eps : float
    Epsilon for layer normalization.
norm_first : bool
    Whether to apply layer normalization before attention.
bias : bool
    Whether to use bias in linear layers.
strict_symmetry : bool
    If True, enforce strict symmetry in the SINDy coefficients.
input_length : int
    Length of the input sequences.
hidden_size : int
    Hidden dimension size.
sindy_loss_threshold : float
    Threshold for coefficient sparsification.
dt : float
    Time step for SINDy derivatives.
device : str, optional
    Device to place the model on. Default is "cpu".
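A minimal instantiation sketch follows. The concrete argument values (e.g. nn.GELU() as the activation, 1e-5 for layer_norm_eps) are illustrative assumptions chosen only to satisfy the signature above, not documented defaults of the class:

    import torch.nn as nn
    from shredx.modules.transformer import SINDyAttentionSINDyLossTransformerEncoder

    # All values below are illustrative; every argument except `device` is required.
    encoder = SINDyAttentionSINDyLossTransformerEncoder(
        d_model=64,
        n_heads=4,
        forecast_length=10,
        num_layers=3,
        dim_feedforward=256,
        dropout=0.1,
        activation=nn.GELU(),
        layer_norm_eps=1e-5,
        norm_first=True,
        bias=True,
        strict_symmetry=False,
        input_length=50,
        hidden_size=32,
        sindy_loss_threshold=0.1,
        dt=0.01,
        device="cpu",
    )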

Methods

forward(src[, is_causal])

Forward pass through the SINDy attention transformer with SINDy loss.

Notes

Class Methods:

forward(src, is_causal):

  • Forward pass through the transformer encoder with SINDy attention and SINDy loss.

  • Parameters:
    • src : Float[torch.Tensor, "batch seq_len d_model"]. Input tensor.

    • is_causal : bool, optional. Whether to apply causal masking. Default is True.

  • Returns:
    • tuple. The final output tensor of shape (batch_size, forecast_length, seq_len, d_model) and a dictionary of auxiliary losses; the dictionary contains the SINDy loss under the key "sindy_loss".
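A hedged sketch of a forward call, continuing from the instantiation example above (batch size 8, input_length=50, d_model=64 are assumed values); the tuple unpacking follows the documented return structure:

    import torch

    # (batch, seq_len, d_model) input, matching the shape annotation on `src`.
    src = torch.randn(8, 50, 64)

    output, aux_losses = encoder(src, is_causal=True)

    # output: (batch_size, forecast_length, seq_len, d_model) per the documented return shape.
    # aux_losses["sindy_loss"]: SINDy regularization term, typically added to the training loss.
    print(output.shape, aux_losses["sindy_loss"])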