shredx.modules.sindy_loss_mixin.SINDyLossMixin
- class shredx.modules.sindy_loss_mixin.SINDyLossMixin(dt: float, hidden_size: int, sindy_loss_threshold: float, *args, **kwargs)
Bases: Module
Mixin providing SINDy loss regularization for neural networks.
Adds learnable SINDy coefficients and methods to compute regularization loss based on how well the latent dynamics follow a sparse polynomial ODE. Designed to be used with multiple inheritance alongside encoder models.
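To make the "sparse polynomial ODE" concrete, the sketch below builds a polynomial feature library and evaluates dz/dt as the product of that library with a coefficient matrix. It is an illustration only: `polynomial_library` and `latent_derivative` are hypothetical names, the feature ordering is an assumption, and plain Python lists stand in for the tensors the mixin presumably uses.

```python
from itertools import combinations_with_replacement
from math import prod


def polynomial_library(z, poly_order):
    """Return all monomials of the entries of z up to total degree poly_order."""
    features = [1.0]  # constant term
    for degree in range(1, poly_order + 1):
        for idx in combinations_with_replacement(range(len(z)), degree):
            features.append(prod(z[i] for i in idx))
    return features


def latent_derivative(z, coefficients, poly_order):
    """Evaluate dz/dt = Theta(z) @ Xi for a single latent state z.

    coefficients has one row per library feature and one column per
    latent dimension (an assumed layout, chosen for this sketch).
    """
    theta = polynomial_library(z, poly_order)
    return [
        sum(t * row[j] for t, row in zip(theta, coefficients))
        for j in range(len(z))
    ]


z = [2.0, 3.0]
theta = polynomial_library(z, poly_order=2)  # [1, z0, z1, z0^2, z0*z1, z1^2]

# Sparse coefficients encoding dz0/dt = -z1 and dz1/dt = z0.
xi = [
    [0.0, 0.0],   # 1
    [0.0, 1.0],   # z0
    [-1.0, 0.0],  # z1
    [0.0, 0.0],   # z0^2
    [0.0, 0.0],   # z0*z1
    [0.0, 0.0],   # z1^2
]
dz = latent_derivative(z, xi, poly_order=2)
```

Most entries of `xi` are zero, which is the kind of sparsity the regularization is meant to encourage.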
- Parameters:
- poly_order : int
Polynomial order for SINDy library features.
- dt : float
Time step for computing derivatives.
- hidden_size : int
Dimension of the hidden state.
- sindy_loss_threshold : float
Threshold for coefficient sparsification.
- *args
Additional positional arguments passed to parent class.
- **kwargs
Additional keyword arguments passed to parent class.
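The following sketch shows how a mixin with this kind of signature can be combined with an encoder through cooperative multiple inheritance. All classes here (`ToyEncoder`, `ToySINDyMixin`, `ToySINDyEncoder`) are hypothetical stand-ins written for illustration; in particular, whether the real mixin forwards `hidden_size` to the parent class is an assumption.

```python
class ToyEncoder:
    """Hypothetical stand-in for an encoder model."""

    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size


class ToySINDyMixin:
    """Hypothetical stand-in for SINDyLossMixin; it only stores its settings."""

    def __init__(self, dt, hidden_size, sindy_loss_threshold, *args, **kwargs):
        # Pass everything the mixin does not consume to the next class in
        # the MRO (assumption: hidden_size is shared with the encoder).
        super().__init__(*args, hidden_size=hidden_size, **kwargs)
        self.dt = dt
        self.sindy_loss_threshold = sindy_loss_threshold


class ToySINDyEncoder(ToySINDyMixin, ToyEncoder):
    """The mixin is listed first so its __init__ runs before the encoder's."""


model = ToySINDyEncoder(
    dt=0.01, hidden_size=8, sindy_loss_threshold=0.05, input_size=3
)
```

Listing the mixin before the encoder in the class bases lets the mixin pick off its own arguments and hand the rest to the encoder via `super()`.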
Methods
compute_sindy_loss(x): Calculate SINDy loss based on derivatives with torchdiffeq.
thresholding([threshold]): Apply thresholding to SINDy coefficients to enforce sparsity.
Notes
Class Methods:
compute_sindy_loss(x):
Calculates SINDy loss based on derivatives with torchdiffeq. Propagates all hidden states forward. Note: the batch size and forecast length are combined into the batch dimension.
- Parameters:
x : torch.Tensor. Transformed sequence of shape (batch_size, sequence_length, hidden_size).
- Returns:
torch.Tensor. SINDy regularization loss.
compute_sindy_loss_original(x):
Calculates SINDy loss based on derivatives with a midpoint integration method. For each time step (t0 to t1), integrates in two steps (t0 to t0.5, then t0.5 to t1).
- Parameters:
x : torch.Tensor. Transformed sequence of shape (batch_size, sequence_length, hidden_size).
- Returns:
torch.Tensor. SINDy regularization loss.
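The two-step scheme described for compute_sindy_loss_original can be sketched as follows. This is one reading of the description, not the library's implementation: it assumes two forward-Euler half steps per interval and a mean-squared-error reduction, and the names `two_half_steps`, `sindy_loss`, and `rhs` are hypothetical. Plain Python lists replace tensors, and `sequence` is a single trajectory rather than a batch.

```python
def two_half_steps(z0, rhs, dt):
    """Advance z0 by dt using two forward-Euler half steps (t0 -> t0.5 -> t1)."""
    half = 0.5 * dt
    z_mid = [z + half * dz for z, dz in zip(z0, rhs(z0))]
    return [z + half * dz for z, dz in zip(z_mid, rhs(z_mid))]


def sindy_loss(sequence, rhs, dt):
    """Mean squared error between each integrated state and the observed next state."""
    total, count = 0.0, 0
    for z_now, z_next in zip(sequence[:-1], sequence[1:]):
        z_pred = two_half_steps(z_now, rhs, dt)
        total += sum((p - q) ** 2 for p, q in zip(z_pred, z_next))
        count += len(z_now)
    return total / count


def decay(z):
    """Toy dynamics dz/dt = -z, standing in for the learned SINDy model."""
    return [-value for value in z]


z1 = two_half_steps([1.0], decay, dt=0.5)
loss_consistent = sindy_loss([[1.0], z1], decay, dt=0.5)
loss_mismatch = sindy_loss([[1.0], [0.0625]], decay, dt=0.5)
```

A sequence that follows the model dynamics exactly gives zero loss, while a sequence that departs from them is penalized in proportion to the squared prediction error.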
thresholding(threshold):
Applies thresholding to SINDy coefficients to enforce sparsity.
- Parameters:
threshold : float, optional. Threshold value. If None, uses the default threshold.
- Returns:
None.
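A minimal sketch of magnitude-based thresholding is shown below. It assumes that coefficients whose absolute value falls strictly below the threshold are set to zero in place, which is the usual SINDy convention but is not confirmed by this page. The function name `threshold_coefficients` and the `default_threshold` argument are hypothetical, and nested lists stand in for the coefficient tensor.

```python
def threshold_coefficients(coefficients, threshold=None, default_threshold=0.05):
    """Zero out small coefficients in place; returns None like the documented method."""
    if threshold is None:
        # Fall back to the default (assumed to play the role of sindy_loss_threshold).
        threshold = default_threshold
    for row in coefficients:
        for j, value in enumerate(row):
            if abs(value) < threshold:
                row[j] = 0.0
    return None


xi = [[0.5, -0.01], [0.03, -2.0]]
result = threshold_coefficients(xi)       # uses the default threshold
xi_after_default = [row[:] for row in xi]

threshold_coefficients(xi, threshold=1.0)  # explicit, stricter threshold
```

Calling the function repeatedly during training progressively prunes terms, so only the dominant library features remain in the learned dynamics.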