shredx.modules.transformer.MultiHeadSINDyAttention

class shredx.modules.transformer.MultiHeadSINDyAttention(E_q: int, E_k: int, E_v: int, E_total: int, n_heads: int, forecast_length: int, dropout: float, strict_symmetry: bool, bias: bool, dtype: dtype | None, device='cpu')

Bases: Module

Multi-head attention with SINDy-based latent space rollout.

Replaces the standard scaled dot-product attention output with ODE-based rollouts using learned SINDy dynamics. Each attention head has its own SINDy layer, so each head learns its dynamics independently.

Parameters:
E_q : int

Size of embedding dimension for query.

E_k : int

Size of embedding dimension for key.

E_v : int

Size of embedding dimension for value.

E_total : int

Total embedding dimension of the combined heads after the input projection. Each head has dimension E_total // n_heads.

n_heads : int

Number of attention heads.

forecast_length : int

Number of future timesteps to predict via ODE rollout.

dropout : float

Dropout probability for attention weights.

strict_symmetry : bool

If True, enforce strict symmetry in SINDy coefficients.

bias : bool

Whether to add bias to the input and output projections.

dtype : torch.dtype, optional

Data type for parameters.

device : str, optional

Device to place the model on. Default is "cpu".
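
A minimal construction sketch is shown below. The import follows the module path above; all argument values are illustrative assumptions, not defaults documented on this page.

>>> import torch
>>> from shredx.modules.transformer import MultiHeadSINDyAttention
>>> attn = MultiHeadSINDyAttention(
...     E_q=64, E_k=64, E_v=64,
...     E_total=64,          # must be divisible by n_heads; head dim = 64 // 8 = 8
...     n_heads=8,
...     forecast_length=16,  # future timesteps produced by the ODE rollout
...     dropout=0.1,
...     strict_symmetry=False,
...     bias=True,
...     dtype=torch.float32,
...     device="cpu",
... )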

Methods

forward(query, key, value[, is_causal])

Apply the input projection, split heads, run scaled dot-product attention (SDPA), perform the SINDy rollout, and project the output.

Raises:
ValueError

If E_total is not divisible by n_heads.
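
A forward-pass sketch, continuing the construction example above. The (batch, sequence, embedding) input layout and the keyword use of is_causal are assumptions inferred from the forward signature and standard multi-head attention modules; check forward() for the exact shapes and return value.

>>> batch, seq_len = 4, 32
>>> query = torch.randn(batch, seq_len, 64)
>>> key = torch.randn(batch, seq_len, 64)
>>> value = torch.randn(batch, seq_len, 64)
>>> out = attn(query, key, value, is_causal=True)  # rollout covers forecast_length steps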