shredx.modules.transformer.MultiHeadSINDyAttention#
- class shredx.modules.transformer.MultiHeadSINDyAttention(E_q: int, E_k: int, E_v: int, E_total: int, n_heads: int, forecast_length: int, dropout: float, strict_symmetry: bool, bias: bool, dtype: dtype | None, device='cpu')#
Bases: Module
Multi-head attention with SINDy-based latent-space rollout.
Replaces the standard scaled dot-product attention output with ODE-based rollouts using learned SINDy dynamics. Each attention head has its own SINDy layer for independent dynamics learning.
- Parameters:
- E_q : int
Size of the embedding dimension for queries.
- E_k : int
Size of the embedding dimension for keys.
- E_v : int
Size of the embedding dimension for values.
- E_total : int
Total embedding dimension of the combined heads after input projection. Each head has dimension E_total // n_heads.
- n_heads : int
Number of attention heads.
- forecast_length : int
Number of future timesteps to predict via ODE rollout.
- dropout : float
Dropout probability for attention weights.
- strict_symmetry : bool
If True, enforce strict symmetry in the SINDy coefficients.
- bias : bool
Whether to add bias to the input/output projections.
- dtype : torch.dtype, optional
Data type for parameters.
- device : str, optional
Device to place the model on. Default is "cpu".
Methods
forward(query, key, value[, is_causal])
Apply input projection, split heads, run SDPA, perform the SINDy rollout, and project the output.
- Raises:
- ValueError
If E_total is not divisible by n_heads.
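The divisibility constraint that triggers this ValueError can be illustrated with a minimal stand-alone sketch; the helper name `split_head_dim` is hypothetical and not part of the shredx API, but the check mirrors the one documented above:

```python
def split_head_dim(E_total: int, n_heads: int) -> int:
    """Hypothetical sketch of the per-head dimension check described above.

    Mirrors the documented constraint: E_total must divide evenly across
    n_heads, and each head then receives E_total // n_heads dimensions.
    """
    if E_total % n_heads != 0:
        raise ValueError(
            f"E_total ({E_total}) must be divisible by n_heads ({n_heads})"
        )
    return E_total // n_heads
```

For example, `split_head_dim(512, 8)` returns a per-head dimension of 64, while `split_head_dim(512, 7)` raises the ValueError, matching the behavior documented in the Raises section.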