shredx.modules.transformer.SINDyAttentionTransformerEncoder
- class shredx.modules.transformer.SINDyAttentionTransformerEncoder(d_model: int, n_heads: int, forecast_length: int, num_layers: int, dim_feedforward: int, dropout: float, activation: Module, layer_norm_eps: float, norm_first: bool, bias: bool, strict_symmetry: bool, input_length: int, hidden_size: int, device: str = 'cpu')
Bases: TransformerEncoder
Transformer encoder with SINDy-based attention in the final layer.
Extends the standard Transformer by replacing the attention mechanism in the last encoder layer with
MultiHeadSINDyAttention, enabling ODE-based latent space rollouts for multi-step forecasting.
- Parameters:
- d_model : int
Input dimension of the model.
- n_heads : int
Number of attention heads.
- forecast_length : int
Number of future timesteps to predict.
- num_layers : int
Number of transformer encoder layers.
- dim_feedforward : int
Dimension of the feedforward network.
- dropout : float
Dropout probability.
- activation : nn.Module
Activation function for the feedforward layers.
- layer_norm_eps : float
Epsilon for layer normalization.
- norm_first : bool
Whether to apply layer norm before attention.
- bias : bool
Whether to use bias in linear layers.
- strict_symmetry : bool
If True, enforce strict symmetry in SINDy coefficients.
- input_length : int
Length of input sequences.
- hidden_size : int
Hidden dimension size.
- device : str, optional
Device to place the model on. Default is "cpu".
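The forecast_length parameter sets how many future timesteps the ODE-based latent rollout produces. The idea can be illustrated with a minimal sketch, assuming a linear latent ODE dz/dt = A z integrated with forward Euler. The helper euler_rollout, the matrix A, and the step size dt are hypothetical and are not part of shredx; the integrator and function library that MultiHeadSINDyAttention actually uses are not specified here.

```python
def euler_rollout(z0, A, forecast_length, dt=0.1):
    """Roll a latent state forward with forward Euler on dz/dt = A z.

    Hypothetical helper for illustration only; not the shredx implementation.
    """
    states = []
    z = list(z0)
    for _ in range(forecast_length):
        # Right-hand side of the ODE: dz[i] = sum_j A[i][j] * z[j]
        dz = [sum(a_ij * z_j for a_ij, z_j in zip(row, z)) for row in A]
        # Forward Euler update: z <- z + dt * dz
        z = [z_i + dt * dz_i for z_i, dz_i in zip(z, dz)]
        states.append(z)
    return states


# Assumed 2-D latent system: a lightly damped rotation.
A = [[-0.1, 1.0],
     [-1.0, -0.1]]
forecast = euler_rollout([1.0, 0.0], A, forecast_length=5)
print(len(forecast))  # 5: one latent state per forecast step
```

In this sketch, changing forecast_length only changes how many steps are rolled out; the dynamics matrix stays the same, which is presumably why set_forecast_length can adjust the horizon without retraining.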
Methods
- forward(src[, is_causal])
Forward pass through the SINDy attention transformer.
- get_dense_sindy_coefficients()
Return a list of dense SINDy coefficient matrices, one per attention head.
- get_sindy_layer_coefficients_eigenvalues()
Get eigenvalues of SINDy coefficient matrices for all attention heads.
- get_sindy_layer_coefficients_sum()
Sum of squared SINDy coefficients in all heads of the last layer.
- print_sindy_layer_coefficients()
Print the SINDy layer coefficients for all attention heads in human-readable format.
- set_forecast_length(forecast_length)
Set the forecast length for all SINDy attention layers.
- threshold_sindy_layer_coefficients(threshold)
Threshold all SINDy coefficients in all heads of the last layer.
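The thresholding and sum-of-squares methods follow a common SINDy pattern: small coefficients are zeroed to promote sparsity, and the squared magnitude of the remaining coefficients can be monitored or penalized. The sketch below illustrates that pattern on plain nested lists. It assumes that thresholding zeroes every entry whose absolute value is below the threshold; whether shredx uses a strict or non-strict comparison is not stated here, and the function names and coefficient values are hypothetical.

```python
def threshold_coefficients(coeffs, threshold):
    """Zero every entry whose magnitude is below `threshold` (assumed rule)."""
    return [[c if abs(c) >= threshold else 0.0 for c in row] for row in coeffs]


def sum_of_squares(heads):
    """Sum of squared entries across all per-head coefficient matrices."""
    return sum(c * c for head in heads for row in head for c in row)


# Hypothetical dense coefficient matrices, one per attention head.
heads = [
    [[0.5, 0.01], [-0.02, -0.3]],
    [[0.001, 0.2], [0.4, -0.005]],
]

# Apply the same threshold to every head, as the method description suggests.
sparse_heads = [threshold_coefficients(h, threshold=0.05) for h in heads]

before = sum_of_squares(heads)
after = sum_of_squares(sparse_heads)
print(sparse_heads[0])  # [[0.5, 0.0], [0.0, -0.3]]
```

Thresholding can only remove magnitude, so the sum of squares after thresholding never exceeds the value before it.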