shredx.modules.transformer.SINDyAttentionSINDyLossTransformerEncoder
- class shredx.modules.transformer.SINDyAttentionSINDyLossTransformerEncoder(d_model: int, n_heads: int, forecast_length: int, num_layers: int, dim_feedforward: int, dropout: float, activation: Module, layer_norm_eps: float, norm_first: bool, bias: bool, strict_symmetry: bool, input_length: int, hidden_size: int, sindy_loss_threshold: float, dt: float, device: str = 'cpu')
Bases:
SINDyLossMixin, SINDyAttentionTransformerEncoder

Transformer encoder with SINDy attention and SINDy loss regularization.
Combines SINDy-based attention in the final layer with SINDy loss regularization for ODE-based latent rollouts and sparse dynamics.
- Parameters:
- d_model : int
Input dimension of the model.
- n_heads : int
Number of attention heads.
- forecast_length : int
Number of future timesteps to predict.
- num_layers : int
Number of transformer encoder layers.
- dim_feedforward : int
Dimension of the feedforward network.
- dropout : float
Dropout probability.
- activation : nn.Module
Activation function for feedforward layers.
- layer_norm_eps : float
Epsilon for layer normalization.
- norm_first : bool
Whether to apply layer normalization before attention.
- bias : bool
Whether to use bias in linear layers.
- strict_symmetry : bool
If True, enforce strict symmetry in SINDy coefficients.
- input_length : int
Length of input sequences.
- hidden_size : int
Hidden dimension size.
- sindy_loss_threshold : float
Threshold for coefficient sparsification.
- dt : float
Time step for SINDy derivatives.
- device : str, optional
Device to place the model on. Default is "cpu".
Methods

forward(src[, is_causal])
Forward pass through the SINDy attention transformer with SINDy loss.
Notes
Class Methods:
forward(src, is_causal):
Forward pass through the transformer encoder with SINDy attention and SINDy loss.
- Parameters:
src : Float[torch.Tensor, "batch seq_len d_model"]
Input tensor.
is_causal : bool, optional
Whether to apply causal masking. Default is True.
- Returns:
tuple
Tuple containing the final output tensor of shape (batch_size, forecast_length, seq_len, d_model) and a dictionary of auxiliary losses. The dictionary contains the SINDy loss under the key "sindy_loss".
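The I/O contract of forward() can be summarized as a shape mapping. The helper below is a dependency-free sketch of that contract only (the real call requires torch and shredx); the function name and example dimensions are illustrative, taken from the shapes documented above.

```python
# Sketch of the forward() I/O contract: an input of shape
# (batch, seq_len, d_model) maps to an output of shape
# (batch, forecast_length, seq_len, d_model) plus a loss dictionary
# holding "sindy_loss". Shapes only; no model is actually run here.
def forward_output_shape(batch, seq_len, d_model, forecast_length):
    """Return the documented output tensor shape for a given input shape."""
    return (batch, forecast_length, seq_len, d_model)

# Example: batch of 8 sequences of length 52 with d_model=64,
# forecasting 10 steps ahead.
out_shape = forward_output_shape(batch=8, seq_len=52, d_model=64,
                                 forecast_length=10)
assert out_shape == (8, 10, 52, 64)
aux_losses = {"sindy_loss": 0.0}  # placeholder for the returned dict
assert "sindy_loss" in aux_losses
```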