shredx.modules.transformer.MultiHeadAttention
- class shredx.modules.transformer.MultiHeadAttention(E_q: int, E_k: int, E_v: int, E_total: int, n_heads: int, dropout: float, bias: bool, dtype: dtype | None, device: str = 'cpu')
Bases: Module
Standard multi-head attention mechanism.
Implements scaled dot-product attention with multiple heads, supporting queries, keys, and values with either matching or differing embedding dimensions.
- Parameters:
- E_q : int
Size of embedding dimension for query.
- E_k : int
Size of embedding dimension for key.
- E_v : int
Size of embedding dimension for value.
- E_total : int
Total embedding dimension of combined heads post input projection. Each head has dimension E_total // n_heads.
- n_heads : int
Number of attention heads.
- dropout : float
Dropout probability for attention weights.
- bias : bool
Whether to add bias to input/output projections.
- dtype : torch.dtype, optional
Data type for parameters.
- device : str, optional
Device to place the model on. Default is "cpu".
Methods
forward(query, key, value[, is_causal])
Apply input projection, split heads, run SDPA, and project output.
- Raises:
- ValueError
If E_total is not divisible by n_heads.
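A hedged call sketch continuing from the construction example above. The (batch, seq_len, embed_dim) input layout is an assumption of this sketch; check the forward docstring for the layout the module actually expects:

```python
batch, seq_len = 2, 16

# Assumed (batch, seq_len, embed_dim) layout for all three inputs.
query = torch.randn(batch, seq_len, 64)
key = torch.randn(batch, seq_len, 64)
value = torch.randn(batch, seq_len, 64)

# is_causal=True requests a causal mask inside scaled dot-product attention.
out = mha(query, key, value, is_causal=True)
print(out.shape)
```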