shredx.modules.transformer.MultiHeadAttention

class shredx.modules.transformer.MultiHeadAttention(E_q: int, E_k: int, E_v: int, E_total: int, n_heads: int, dropout: float, bias: bool, dtype: dtype | None, device: str = 'cpu')

Bases: Module

Standard multi-head attention mechanism.

Implements scaled dot-product attention with multiple heads. The query, key, and value inputs may share one embedding dimension or use different ones (E_q, E_k, E_v).
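For reference, the standard multi-head attention computation is written out below. The mapping of its symbols onto this class's arguments (h = n_heads, d_h = E_total / n_heads, E_out for the output width, which this page does not state) is inferred from the parameter descriptions rather than taken from the shredx source; biases, dropout, and causal masking are omitted.

```latex
\begin{aligned}
d_h &= E_{\mathrm{total}} / h, \qquad h = \texttt{n\_heads},\\
\mathrm{head}_i &= \operatorname{softmax}\!\left(\frac{(Q W_i^{Q})(K W_i^{K})^{\top}}{\sqrt{d_h}}\right) V W_i^{V},
\qquad W_i^{Q} \in \mathbb{R}^{E_q \times d_h},\; W_i^{K} \in \mathbb{R}^{E_k \times d_h},\; W_i^{V} \in \mathbb{R}^{E_v \times d_h},\\
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^{O},
\qquad W^{O} \in \mathbb{R}^{E_{\mathrm{total}} \times E_{\mathrm{out}}}.
\end{aligned}
```

Stacking the h per-head matrices W_i^Q side by side gives a single E_q × E_total matrix (likewise for keys and values), which is why an implementation can apply one input projection to E_total features and then split the result into heads.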

Parameters:
E_q : int

Size of embedding dimension for query.

E_k : int

Size of embedding dimension for key.

E_v : int

Size of embedding dimension for value.

E_total : int

Total embedding dimension of the combined heads after the input projection. Each head has dimension E_total // n_heads.

n_heads : int

Number of attention heads.

dropout : float

Dropout probability for attention weights.

bias : bool

Whether to add bias to input/output projections.

dtype : torch.dtype, optional

Data type for parameters.

device : str, optional

Device to place the model on. Default is "cpu".
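The parameters above fix the shape of each learned projection. The sketch below tabulates those shapes with a hypothetical helper, projection_shapes, which is not part of shredx. It assumes each input is projected to E_total features and, as an assumption not stated on this page, that the output projection maps E_total back to E_q. It also assumes the divisibility check described under Raises has already passed.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class ProjectionShapes:
    """Hypothetical summary: (in_features, out_features) of each linear map."""
    q_proj: Tuple[int, int]
    k_proj: Tuple[int, int]
    v_proj: Tuple[int, int]
    out_proj: Tuple[int, int]
    head_dim: int


def projection_shapes(E_q: int, E_k: int, E_v: int, E_total: int, n_heads: int) -> ProjectionShapes:
    # Assumes E_total % n_heads == 0; otherwise the class raises ValueError.
    return ProjectionShapes(
        q_proj=(E_q, E_total),    # query: E_q -> E_total
        k_proj=(E_k, E_total),    # key:   E_k -> E_total
        v_proj=(E_v, E_total),    # value: E_v -> E_total
        out_proj=(E_total, E_q),  # ASSUMPTION: output projected back to E_q
        head_dim=E_total // n_heads,
    )


shapes = projection_shapes(E_q=32, E_k=16, E_v=24, E_total=64, n_heads=8)
print(shapes.head_dim)  # 8
```

With E_total = 64 and n_heads = 8, each head works in an 8-dimensional subspace regardless of how E_q, E_k, and E_v differ.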

Methods

forward(query, key, value[, is_causal])

Apply the input projections, split the result into heads, run scaled dot-product attention (SDPA), and apply the output projection.
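The four steps of forward can be traced in a dependency-free sketch. It uses plain Python lists instead of tensors, drops the batch dimension, omits biases and dropout, and passes weights in a hypothetical params dictionary; the output width E_q is again an assumption. It illustrates the data flow only and is not the shredx implementation.

```python
import math
import random
from typing import Dict, List

Matrix = List[List[float]]


def linear(x: Matrix, w: Matrix) -> Matrix:
    """y = x @ w for x of shape (L, in) and w of shape (in, out); no bias."""
    out_features = len(w[0])
    return [[sum(xi * w[i][j] for i, xi in enumerate(row)) for j in range(out_features)] for row in x]


def split_heads(x: Matrix, n_heads: int) -> List[Matrix]:
    """(L, E_total) -> n_heads matrices of shape (L, E_total // n_heads)."""
    d = len(x[0]) // n_heads
    return [[row[h * d:(h + 1) * d] for row in x] for h in range(n_heads)]


def merge_heads(heads: List[Matrix]) -> Matrix:
    """Inverse of split_heads: concatenate the heads' features at each position."""
    return [[v for head in heads for v in head[t]] for t in range(len(heads[0]))]


def sdpa(q: Matrix, k: Matrix, v: Matrix, is_causal: bool = False) -> Matrix:
    """softmax(q k^T / sqrt(d)) v for one head, without dropout."""
    scale = 1.0 / math.sqrt(len(q[0]))
    out = []
    for i, qi in enumerate(q):
        scores = [scale * sum(a * b for a, b in zip(qi, kj)) for kj in k]
        if is_causal:
            # Query position i may attend only to key positions j <= i.
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
        m = max(scores)  # subtract the max for a numerically stable softmax
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        out.append([sum(e * vj[c] for e, vj in zip(exps, v)) / total for c in range(len(v[0]))])
    return out


def mha_forward(query: Matrix, key: Matrix, value: Matrix,
                params: Dict[str, Matrix], n_heads: int,
                is_causal: bool = False) -> Matrix:
    # 1. Input projections: E_q, E_k, E_v -> E_total, then split into heads.
    q = split_heads(linear(query, params["w_q"]), n_heads)
    k = split_heads(linear(key, params["w_k"]), n_heads)
    v = split_heads(linear(value, params["w_v"]), n_heads)
    # 2. Run SDPA independently in every head.
    heads = [sdpa(qh, kh, vh, is_causal) for qh, kh, vh in zip(q, k, v)]
    # 3. Concatenate the heads and apply the output projection.
    return linear(merge_heads(heads), params["w_o"])


def rand_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    return [[rng.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
E_q, E_k, E_v, E_total, n_heads = 6, 4, 5, 8, 2
params = {
    "w_q": rand_matrix(E_q, E_total, rng),
    "w_k": rand_matrix(E_k, E_total, rng),
    "w_v": rand_matrix(E_v, E_total, rng),
    "w_o": rand_matrix(E_total, E_q, rng),  # ASSUMPTION: output width is E_q
}
query = rand_matrix(3, E_q, rng)  # 3 query positions
key = rand_matrix(5, E_k, rng)    # 5 key positions
value = rand_matrix(5, E_v, rng)  # one value row per key position
out = mha_forward(query, key, value, params, n_heads)
print(len(out), len(out[0]))  # 3 6
```

Projecting once to E_total and then slicing each row into n_heads contiguous chunks is equivalent to running n_heads separate projections of width E_total // n_heads, but needs only one matrix product per input.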

Raises:
ValueError

If E_total is not divisible by n_heads.
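The documented ValueError amounts to a divisibility check on the constructor arguments. A minimal sketch using a hypothetical helper, check_head_config (the error message is illustrative, not the library's):

```python
def check_head_config(E_total: int, n_heads: int) -> int:
    """Return the per-head dimension, raising ValueError as documented above."""
    if E_total % n_heads != 0:
        raise ValueError(
            f"E_total ({E_total}) must be divisible by n_heads ({n_heads})"
        )
    return E_total // n_heads


print(check_head_config(64, 8))  # 8
try:
    check_head_config(10, 3)
except ValueError as err:
    print(err)  # E_total (10) must be divisible by n_heads (3)
```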