shredx.modules.transformer.SINDyAttentionTransformerEncoder

class shredx.modules.transformer.SINDyAttentionTransformerEncoder(d_model: int, n_heads: int, forecast_length: int, num_layers: int, dim_feedforward: int, dropout: float, activation: Module, layer_norm_eps: float, norm_first: bool, bias: bool, strict_symmetry: bool, input_length: int, hidden_size: int, device: str = 'cpu')

Bases: TransformerEncoder

Transformer encoder with SINDy-based attention in the final layer.

Extends the standard Transformer encoder by replacing the attention mechanism in the last encoder layer with MultiHeadSINDyAttention, which enables ODE-based rollouts in the latent space for multi-step forecasting.
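To make "ODE-based latent space rollout" more concrete, here is a minimal NumPy sketch, assuming a SINDy model of the form dz/dt = Θ(z)Ξ whose candidate library Θ contains only a constant and the linear terms, integrated with forward Euler. The function names, the library, the integrator, and the step size dt are illustrative assumptions, not the internals of MultiHeadSINDyAttention; the steps argument plays the role of forecast_length.

```python
import numpy as np


def polynomial_library(z: np.ndarray) -> np.ndarray:
    """Hypothetical candidate library Theta(z) = [1, z_1, ..., z_d] (constant + linear terms)."""
    return np.concatenate([np.ones(1), z])


def sindy_rollout(z0: np.ndarray, xi: np.ndarray, dt: float, steps: int) -> np.ndarray:
    """Forward-Euler rollout of dz/dt = Theta(z) @ xi (illustrative sketch only).

    xi has shape (n_library_terms, d): one column of coefficients per latent dimension.
    Returns an array of shape (steps, d) holding the predicted future latent states.
    """
    z = z0.astype(float)  # work on a float copy of the initial latent state
    trajectory = []
    for _ in range(steps):
        dz = polynomial_library(z) @ xi  # evaluate the learned ODE right-hand side
        z = z + dt * dz                  # one explicit Euler step
        trajectory.append(z.copy())
    return np.stack(trajectory)


# Usage: dz0/dt = -z0 and dz1/dt = -2*z1, rolled out for 3 steps.
z0 = np.array([1.0, 1.0])
xi = np.zeros((3, 2))   # rows: [1, z0, z1]; columns: [dz0/dt, dz1/dt]
xi[1, 0] = -1.0
xi[2, 1] = -2.0
forecast = sindy_rollout(z0, xi, dt=0.1, steps=3)
print(forecast)
```

With these coefficients each step scales z0 by 0.9 and z1 by 0.8, so the third forecast step is approximately (0.729, 0.512).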

Parameters:
d_model : int

Input dimension of the model.

n_heads : int

Number of attention heads.

forecast_length : int

Number of future timesteps to predict.

num_layers : int

Number of transformer encoder layers.

dim_feedforward : int

Dimension of the feedforward network.

dropout : float

Dropout probability.

activation : nn.Module

Activation function for feedforward layers.

layer_norm_eps : float

Epsilon for layer normalization.

norm_first : bool

Whether to apply layer normalization before attention (pre-norm) rather than after it (post-norm).

bias : bool

Whether to use bias terms in the linear layers.

strict_symmetry : bool

If True, enforce strict symmetry in the SINDy coefficients.

input_length : int

Length of the input sequences.

hidden_size : int

Hidden dimension size.

device : str, optional

Device to place the model on. Default is "cpu".
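The norm_first flag presumably follows the usual pre-norm versus post-norm distinction, as in PyTorch's nn.TransformerEncoderLayer, where norm_first=True applies layer normalization before the attention and feedforward operations; this page does not confirm how shredx forwards the flag. The NumPy sketch below contrasts the two orderings for a single residual sublayer. The helper names are hypothetical, and the normalization omits the learnable scale and shift, so this is an illustration rather than shredx code.

```python
import numpy as np


def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Layer normalization over the last axis, without learnable scale/shift."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def residual_sublayer(x: np.ndarray, sublayer, norm_first: bool, eps: float = 1e-5) -> np.ndarray:
    """Wrap `sublayer` (e.g. attention or feedforward) in a residual connection."""
    if norm_first:
        # Pre-norm: normalize the input, run the sublayer, then add the skip path.
        return x + sublayer(layer_norm(x, eps))
    # Post-norm: add the skip path first, then normalize the sum.
    return layer_norm(x + sublayer(x), eps)


# Usage with a toy "sublayer" that just doubles its input.
x = np.array([[1.0, 2.0, 3.0]])
double = lambda h: 2.0 * h
post = residual_sublayer(x, double, norm_first=False)
pre = residual_sublayer(x, double, norm_first=True)
print(post, pre)
```

In the post-norm case the output is always normalized (zero mean per row here), while the pre-norm output keeps the unnormalized skip path, which is one reason pre-norm stacks are often reported to train more stably.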

Methods

forward(src[, is_causal])

Forward pass through the SINDy attention transformer.

get_dense_sindy_coefficients()

Return a list of dense SINDy coefficient matrices, one per attention head.

get_sindy_layer_coefficients_eigenvalues()

Return the eigenvalues of the SINDy coefficient matrices for all attention heads.

get_sindy_layer_coefficients_sum()

Return the sum of squared SINDy coefficients over all heads of the last layer.

print_sindy_layer_coefficients()

Print the SINDy layer coefficients for all attention heads in a human-readable format.

set_forecast_length(forecast_length)

Set the forecast length for all SINDy attention layers.

threshold_sindy_layer_coefficients(threshold)

Threshold all SINDy coefficients in all heads of the last layer.
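The coefficient utilities above all operate on per-head SINDy coefficient matrices. The following NumPy sketch shows, on plain arrays, the kind of computation each description suggests. It assumes each head is a dense 2-D array (like the list get_dense_sindy_coefficients() is documented to return), that eigenvalues are only taken for square matrices acting as the linear operator A in dz/dt = Az, and that thresholding means hard thresholding (zeroing entries whose magnitude falls below the threshold, as in sequentially thresholded least squares). All helper names are hypothetical; the shredx methods may differ in details such as in-place updates or the comparison used at the threshold.

```python
import numpy as np

# Hypothetical stand-ins that take a list of dense per-head coefficient matrices.


def coefficient_eigenvalues(heads):
    """Eigenvalues of each head's matrix; assumes every matrix is square."""
    return [np.linalg.eigvals(xi) for xi in heads]


def is_linearly_stable(eigenvalues) -> bool:
    """True if every eigenvalue has negative real part, i.e. dz/dt = A z decays.

    Assumes the coefficient matrix plays the role of A (an assumption, see lead-in).
    """
    return bool(np.all(np.real(eigenvalues) < 0))


def coefficient_sum_of_squares(heads) -> float:
    """Sum of squared entries over all heads, i.e. sum_h ||Xi_h||_F^2."""
    return float(sum(np.sum(np.square(xi)) for xi in heads))


def threshold_coefficients(heads, threshold: float):
    """Hard threshold: return copies with entries of magnitude below `threshold` set to 0."""
    return [np.where(np.abs(xi) < threshold, 0.0, xi) for xi in heads]


# Usage with two toy 2x2 heads.
heads = [np.array([[-1.0, 0.05], [0.0, -2.0]]),
         np.array([[0.5, 0.0], [0.0, -0.02]])]
eigs = coefficient_eigenvalues(heads)
print([is_linearly_stable(e) for e in eigs])
print(coefficient_sum_of_squares(heads))
print(threshold_coefficients(heads, 0.1))
```

A sum of squared coefficients like this is commonly added to a training loss as an L2 penalty, and hard thresholding is the standard way SINDy promotes sparse, interpretable equations; how shredx schedules either during training is not described on this page.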