shredx.modules.transformer.TransformerEncoderModule

class shredx.modules.transformer.TransformerEncoderModule(encoder_layer: Module, num_layers: int, norm: Module | None, dtype: dtype | None, device: str = 'cpu')

Bases: Module

Stack of transformer encoder layers.

Applies multiple encoder layers sequentially with optional final normalization.
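The stacking behaviour described above can be sketched in plain PyTorch. This is a minimal illustration, not the shredx source. EncoderStackSketch is a hypothetical name, and the sketch assumes that each layer accepts a src_mask keyword (as torch.nn.TransformerEncoderLayer does), that cloning means copy.deepcopy, and that dtype/device are applied with Module.to. It also leaves out the is_causal argument of the real forward.

```python
import copy
from typing import Optional

import torch
from torch import nn


class EncoderStackSketch(nn.Module):
    """Hypothetical stand-in for TransformerEncoderModule; not the shredx code."""

    def __init__(self, encoder_layer: nn.Module, num_layers: int,
                 norm: Optional[nn.Module] = None,
                 dtype: Optional[torch.dtype] = None, device: str = "cpu") -> None:
        super().__init__()
        # Deep copies give every layer its own parameters (no weight sharing).
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.norm = norm  # optional final normalization, may be None
        self.to(device)
        if dtype is not None:
            self.to(dtype)

    def forward(self, src: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        out = src
        # Apply the cloned layers one after another.
        for layer in self.layers:
            out = layer(out, src_mask=src_mask)
        # Normalize the last layer's output only if a norm module was given.
        if self.norm is not None:
            out = self.norm(out)
        return out


layer = nn.TransformerEncoderLayer(d_model=16, nhead=4, dim_feedforward=32,
                                   dropout=0.0, batch_first=True)
stack = EncoderStackSketch(layer, num_layers=3, norm=nn.LayerNorm(16))
y = stack(torch.randn(2, 5, 16))
print(y.shape)  # torch.Size([2, 5, 16])
```

Deep-copying a template layer mirrors the pattern used by torch.nn.TransformerEncoder. Each copy is trained independently, so the stack has num_layers times the parameters of the template layer.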

Parameters:
encoder_layer : nn.Module

Single encoder layer, cloned to form each of the num_layers layers in the stack.

num_layers : int

Number of encoder layers.

norm : nn.Module, optional

Normalization applied to the output of the last encoder layer.

dtype : torch.dtype, optional

Data type for parameters.

device : str, optional

Device to place the model on. Default is "cpu".

Methods

forward(src[, is_causal])

Forward pass through all encoder layers.
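The is_causal argument is not described further on this page. Assuming it carries its usual PyTorch meaning (each position may attend only to itself and earlier positions), the snippet below builds the additive causal mask such a flag typically corresponds to and checks the property on a single torch.nn.TransformerEncoderLayer. The helper causal_mask is hypothetical. How shredx builds or forwards its mask is not shown here.

```python
import torch
from torch import nn


def causal_mask(seq_len: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Additive mask: 0 on and below the diagonal, -inf strictly above it."""
    return torch.triu(torch.full((seq_len, seq_len), float("-inf"), dtype=dtype),
                      diagonal=1)


torch.manual_seed(0)
# dropout=0.0 keeps the layer deterministic so two runs can be compared.
layer = nn.TransformerEncoderLayer(d_model=8, nhead=2, dim_feedforward=16,
                                   dropout=0.0, batch_first=True)
x = torch.randn(1, 4, 8)
mask = causal_mask(4)
y = layer(x, src_mask=mask)

# Perturb only the last position; outputs at earlier positions must not change.
x2 = x.clone()
x2[:, -1] += 1.0
y2 = layer(x2, src_mask=mask)
unchanged = torch.allclose(y[:, :-1], y2[:, :-1], atol=1e-6)
print(unchanged)  # True
```

Masked entries become exp(-inf) = 0 after the softmax, so a change at a later position cannot reach earlier outputs. The feed-forward and normalization sublayers act on each position separately, so they preserve this property.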