shredx.modules.sindy_loss_mixin.SINDyLossMixin

class shredx.modules.sindy_loss_mixin.SINDyLossMixin(dt: float, hidden_size: int, sindy_loss_threshold: float, *args, **kwargs)

Bases: Module

Mixin providing SINDy loss regularization for neural networks.

Adds learnable SINDy coefficients and methods that compute a regularization loss measuring how well the latent dynamics follow a sparse polynomial ODE. Designed to be combined with encoder models through multiple inheritance.

Parameters:
poly_order : int

Polynomial order for SINDy library features.

dt : float

Time step for computing derivatives.

hidden_size : int

Dimension of the hidden state.

sindy_loss_threshold : float

Threshold for coefficient sparsification.

*args

Additional positional arguments passed to parent class.

**kwargs

Additional keyword arguments passed to parent class.
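As a rough illustration of the sparse polynomial ODE that the learnable coefficients parameterize, the sketch below builds a polynomial feature library Θ(z) up to `poly_order` and multiplies it by a coefficient matrix Ξ, so that dz/dt = Θ(z) Ξ. The names `polynomial_library` and `SparsePolynomialDynamics` are hypothetical; the library construction and parameter names used inside shredx may differ.

```python
import itertools

import torch
import torch.nn as nn


def polynomial_library(z: torch.Tensor, poly_order: int) -> torch.Tensor:
    """Hypothetical helper: features [1, z_i, z_i*z_j, ...] up to poly_order.

    z has shape (batch, hidden_size); the result has shape (batch, n_features).
    """
    batch, hidden_size = z.shape
    features = [torch.ones(batch, 1, dtype=z.dtype, device=z.device)]
    for order in range(1, poly_order + 1):
        # Each monomial of this degree exactly once (indices i <= j <= k ...).
        for idx in itertools.combinations_with_replacement(range(hidden_size), order):
            features.append(z[:, list(idx)].prod(dim=1, keepdim=True))
    return torch.cat(features, dim=1)


class SparsePolynomialDynamics(nn.Module):
    """Hypothetical stand-in for the mixin's SINDy coefficients: dz/dt = Theta(z) @ Xi."""

    def __init__(self, hidden_size: int, poly_order: int):
        super().__init__()
        self.poly_order = poly_order
        # Count library columns by evaluating the library on a dummy state.
        n_features = polynomial_library(torch.zeros(1, hidden_size), poly_order).shape[1]
        self.coefficients = nn.Parameter(torch.zeros(n_features, hidden_size))

    def forward(self, t: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        # The (t, z) signature is the one ODE solvers such as torchdiffeq call.
        return polynomial_library(z, self.poly_order) @ self.coefficients


# Usage: 3 latent variables, quadratic library -> 1 + 3 + 6 = 10 features.
z = torch.randn(5, 3)
theta = polynomial_library(z, 2)
model = SparsePolynomialDynamics(hidden_size=3, poly_order=2)
dz_dt = model(torch.tensor(0.0), z)
```

Because the right-hand side is linear in Ξ, sparsity in Ξ translates directly into a short, readable ODE for each latent variable.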

Methods

compute_sindy_loss(x)

Calculate the SINDy loss by integrating the SINDy derivatives with torchdiffeq.

thresholding([threshold])

Apply thresholding to SINDy coefficients to enforce sparsity.

Notes

Method details:

compute_sindy_loss(x):

  • Calculates the SINDy loss by integrating the SINDy derivatives with torchdiffeq, propagating all hidden states forward. Note: the batch-size and forecast-length dimensions are combined into a single batch dimension.

  • Parameters:
    • x : torch.Tensor. Transformed sequence of shape (batch_size, sequence_length, hidden_size).

  • Returns:
    • torch.Tensor. SINDy regularization loss.
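A minimal sketch of the idea behind compute_sindy_loss, under stated assumptions: every hidden state except the last is treated as an initial condition, batch and time are folded into one batch dimension, each state is advanced by dt, and the prediction is compared with the next observed state by mean squared error. To stay self-contained it uses a hand-written fixed-step RK4 step in place of torchdiffeq; `rk4_step` and `sindy_loss_one_step` are hypothetical names, and the real method's loss may be defined differently.

```python
import torch


def rk4_step(rhs, z: torch.Tensor, dt: float) -> torch.Tensor:
    """One classical RK4 step of size dt for rhs(t, z) (stand-in for torchdiffeq)."""
    t = torch.zeros((), dtype=z.dtype)
    k1 = rhs(t, z)
    k2 = rhs(t + dt / 2, z + dt / 2 * k1)
    k3 = rhs(t + dt / 2, z + dt / 2 * k2)
    k4 = rhs(t + dt, z + dt * k3)
    return z + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)


def sindy_loss_one_step(x: torch.Tensor, rhs, dt: float) -> torch.Tensor:
    """x: (batch_size, sequence_length, hidden_size) latent trajectory."""
    hidden_size = x.shape[-1]
    # Fold batch and time together: each of the first T-1 states is an initial condition.
    z0 = x[:, :-1, :].reshape(-1, hidden_size)
    target = x[:, 1:, :].reshape(-1, hidden_size)
    # Propagate all hidden states forward by dt in a single batched step.
    pred = rk4_step(rhs, z0, dt)
    return torch.mean((pred - target) ** 2)


# Usage: a trajectory of dz/dt = -z is matched almost exactly by the right rhs.
torch.manual_seed(0)
dt = 0.1
times = torch.arange(6, dtype=torch.float64) * dt
x = torch.randn(4, 1, 3, dtype=torch.float64) * torch.exp(-times)[None, :, None]
good = sindy_loss_one_step(x, lambda t, z: -z, dt)
bad = sindy_loss_one_step(x, lambda t, z: z, dt)
```

With torchdiffeq installed, the `rk4_step` call could instead be a `torchdiffeq.odeint` call over the time points `[0, dt]`, keeping only the last output.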

compute_sindy_loss_original(x):

  • Calculates the SINDy loss by integrating the SINDy derivatives with a midpoint integration method. Each time step from t0 to t1 is integrated in two half steps: from t0 to the midpoint t0.5, then from t0.5 to t1.

  • Parameters:
    • x : torch.Tensor. Transformed sequence of shape (batch_size, sequence_length, hidden_size).

  • Returns:
    • torch.Tensor. SINDy regularization loss.
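The two half-step scheme of compute_sindy_loss_original can be sketched the same way. Here each half step of size dt/2 uses the explicit midpoint rule, z + h·f(z + (h/2)·f(z)); whether the original uses exactly this update inside each half step is an assumption, and `midpoint_step` and `sindy_loss_two_half_steps` are hypothetical names.

```python
import torch


def midpoint_step(rhs, z: torch.Tensor, h: float) -> torch.Tensor:
    """Explicit midpoint rule: evaluate the derivative halfway through the step."""
    return z + h * rhs(z + 0.5 * h * rhs(z))


def sindy_loss_two_half_steps(x: torch.Tensor, rhs, dt: float) -> torch.Tensor:
    """x: (batch_size, sequence_length, hidden_size) latent trajectory."""
    hidden_size = x.shape[-1]
    z0 = x[:, :-1, :].reshape(-1, hidden_size)     # states at t0
    target = x[:, 1:, :].reshape(-1, hidden_size)  # observed states at t1
    z_half = midpoint_step(rhs, z0, dt / 2)         # t0 -> t0.5
    z1 = midpoint_step(rhs, z_half, dt / 2)         # t0.5 -> t1
    return torch.mean((z1 - target) ** 2)


# Usage: same decaying trajectory as before, dz/dt = -z.
torch.manual_seed(0)
dt = 0.1
times = torch.arange(6, dtype=torch.float64) * dt
x = torch.randn(4, 1, 3, dtype=torch.float64) * torch.exp(-times)[None, :, None]
good = sindy_loss_two_half_steps(x, lambda z: -z, dt)
bad = sindy_loss_two_half_steps(x, lambda z: z, dt)
```

Splitting each interval in two reduces the integration error of a second-order rule without requiring an ODE-solver dependency.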

thresholding(threshold):

  • Applies thresholding to SINDy coefficients to enforce sparsity.

  • Parameters:
    • threshold : float, optional. Threshold value. If None, uses the default threshold.

  • Returns:
    • None.
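A minimal sketch of thresholding, assuming the coefficients live in a matrix parameter and that pruned entries should stay pruned. The persistent `mask` buffer is a common pattern in SINDy autoencoders but is an assumption here, as are the class name `ThresholdedCoefficients` and the `masked_coefficients` helper. Falling back to `sindy_loss_threshold` when `threshold` is None is a guess at what "the default threshold" refers to.

```python
import torch
import torch.nn as nn


class ThresholdedCoefficients(nn.Module):
    """Hypothetical container for SINDy coefficients with sequential thresholding."""

    def __init__(self, n_features: int, hidden_size: int, sindy_loss_threshold: float):
        super().__init__()
        self.sindy_loss_threshold = sindy_loss_threshold
        self.coefficients = nn.Parameter(torch.randn(n_features, hidden_size))
        # 1 = active library term, 0 = pruned term (assumed design).
        self.register_buffer("mask", torch.ones(n_features, hidden_size))

    def masked_coefficients(self) -> torch.Tensor:
        # What a forward pass would use, so pruned terms cannot come back.
        return self.coefficients * self.mask

    def thresholding(self, threshold=None) -> None:
        if threshold is None:
            threshold = self.sindy_loss_threshold  # assumed default
        with torch.no_grad():
            keep = (self.coefficients.abs() >= threshold).to(self.mask.dtype)
            self.mask.mul_(keep)                # once pruned, a term stays pruned
            self.coefficients.mul_(self.mask)   # zero the pruned entries in place


# Usage: entries with |value| < 0.5 are zeroed.
m = ThresholdedCoefficients(n_features=4, hidden_size=2, sindy_loss_threshold=0.5)
with torch.no_grad():
    m.coefficients.copy_(torch.tensor([[0.1, 1.0], [-0.7, 0.2], [0.0, -2.0], [0.49, 0.5]]))
m.thresholding()
```

Keeping a mask, rather than only zeroing the parameter, matters because a later optimizer step could otherwise move a zeroed coefficient away from zero again.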