SpikeDE.snn

This module provides the SNNWrapper class, which converts standard Spiking Neural Networks (SNNs) into Fractional Differential Equation (FDE) systems. It supports flexible per-layer configuration of the fractional order \(\alpha\), covering single-term, multi-term (distributed-order), and learnable setups.

Key Features

  • Per-layer fractional orders (single-term or multi-term).
  • Learnable \(\alpha\) and coefficients via backpropagation.
  • Automatic shape inference and parameter registration.
  • Support for various FDE solvers (Grünwald-Letnikov, L1, etc.).
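
For reference, the Grünwald-Letnikov family approximates the fractional derivative of order \(\alpha\) as \(D^\alpha f(t) \approx h^{-\alpha} \sum_{k=0}^{\lfloor t/h \rfloor} w_k\, f(t - kh)\) with weights \(w_k = (-1)^k \binom{\alpha}{k}\). A minimal sketch of the standard weight recursion (illustrative only, not spikeDE's actual solver code):

```python
def gl_weights(alpha: float, n: int) -> list:
    """Grunwald-Letnikov weights w_k = (-1)^k * C(alpha, k), computed with
    the standard recursion w_k = w_{k-1} * (1 - (alpha + 1) / k)."""
    w = [1.0]
    for k in range(1, n):
        w.append(w[-1] * (1.0 - (alpha + 1.0) / k))
    return w
```

For \(\alpha = 1\) the weights collapse to the ordinary first-difference stencil \([1, -1, 0, \dots]\), which is a quick sanity check for any GL implementation.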

PerLayerAlphaConfig

PerLayerAlphaConfig(
    alpha: float | list[float] | list[list[float]],
    n_layers: int,
    multi_coefficient: list[float] | list[list[float]] | None = None,
    learn_alpha: bool | list[bool] = False,
    learn_coefficient: bool | list[bool] = False,
    alpha_mode: str = "auto",
    device: device = None,
    dtype: dtype = float32,
)

Configuration parser and validator for per-layer fractional orders (\(\alpha\)).

This class normalizes user inputs into a consistent internal format using decision-table logic. It handles three primary configuration cases:

  1. Case A (Per-layer Single-term): Each layer has a single \(\alpha\) value. Example: alpha=[0.3, 0.5] with alpha_mode='per_layer'.

  2. Case B (Multi-term Broadcast): The same multi-term configuration applies to all layers. Example: alpha=[0.3, 0.5, 0.7] with alpha_mode='multiterm'.

  3. Case C (Per-layer Multi-term): Each layer has its own multi-term configuration. Example: alpha=[[0.3, 0.5], [0.4, 0.6]] (auto-detected).
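
The case resolution above can be sketched in plain Python. The function below is an illustrative reconstruction of the decision table, not the library's actual parsing code; the exact auto-detection rules are assumptions:

```python
def normalize_alpha(alpha, n_layers, multi_coefficient=None, alpha_mode="auto"):
    """Illustrative reconstruction of the decision table (not spikeDE's parser)."""
    if isinstance(alpha, list) and alpha and isinstance(alpha[0], list):
        # Nested lists are always Case C: per-layer multi-term.
        case, per_layer = "C", [list(a) for a in alpha]
    elif isinstance(alpha, (int, float)):
        # A single float broadcasts as single-term to every layer (Case A).
        case, per_layer = "A", [[float(alpha)] for _ in range(n_layers)]
    elif alpha_mode == "per_layer" or (
        alpha_mode == "auto" and len(alpha) == n_layers and multi_coefficient is None
    ):
        # Flat list read as one single-term alpha per layer (Case A).
        case, per_layer = "A", [[float(a)] for a in alpha]
    else:
        # Flat list broadcast as the same multi-term config to all layers (Case B).
        case, per_layer = "B", [[float(a) for a in alpha] for _ in range(n_layers)]

    if multi_coefficient is None:
        coefs = [[1.0] * len(a) for a in per_layer]  # default to ones
    elif isinstance(multi_coefficient[0], list):
        coefs = [list(c) for c in multi_coefficient]  # per-layer (Case C)
    else:
        coefs = [list(multi_coefficient) for _ in range(n_layers)]  # shared (Case B)
    return case, per_layer, coefs
```

For example, `normalize_alpha([0.3, 0.5], 2, alpha_mode="per_layer")` resolves to Case A with per-layer alphas `[[0.3], [0.5]]` and default unit coefficients.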

Attributes:

  • n_layers (int) –

    Number of neuron layers in the network.

  • case (str) –

    Detected configuration case.

  • per_layer_alpha (list[list[float]]) –

    Normalized alpha values per layer.

  • per_layer_coefficient (list[list[float]]) –

    Normalized coefficients per layer.

  • per_layer_is_multi_term (list[bool]) –

    Boolean flag per layer indicating multi-term usage.

  • per_layer_learn_alpha (list[bool]) –

    Learnable flag for alpha per layer.

  • per_layer_learn_coefficient (list[bool]) –

    Learnable flag for coefficients per layer.

Parameters:

  • alpha (float | list[float] | list[list[float]]) –

    Fractional order(s). Accepts:

    • float: Single value for all layers.
    • List[float]: Interpretation depends on alpha_mode.
    • List[List[float]]: Per-layer multi-term configuration.
    • torch.Tensor: Will be converted to list.
  • n_layers (int) –

    Total number of neuron layers in the SNN.

  • multi_coefficient (list[float] | list[list[float]] | None, default: None ) –

    Coefficients for multi-term FDEs. Accepts:

    • None: Defaults to ones.
    • List[float]: Shared coefficients for all layers (Case B).
    • List[List[float]]: Per-layer coefficients (Case C).
  • learn_alpha (bool | list[bool], default: False ) –

    Whether \(\alpha\) values are learnable parameters.

    • bool: Applied globally to all layers.
    • List[bool]: Per-layer learnable flags.
  • learn_coefficient (bool | list[bool], default: False ) –

    Whether coefficients are learnable parameters.

    • bool: Applied globally.
    • List[bool]: Per-layer learnable flags.
  • alpha_mode (str, default: 'auto' ) –

    Disambiguation mode for flat list inputs.

    • 'auto': Automatic detection (default).
    • 'per_layer': Force Case A.
    • 'multiterm': Force Case B.
  • device (device, default: None ) –

    Target device for tensors.

  • dtype (dtype, default: float32 ) –

    Data type for tensors (default: float32).

Raises:

  • ValueError

    If input dimensions mismatch, invalid modes are provided, or coefficient lengths do not match alpha lengths.
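
One concrete instance of these checks: a layer whose coefficient list differs in length from its alpha list should raise. A sketch (hypothetical helper name, not part of spikeDE's API):

```python
def check_coefficient_lengths(per_layer_alpha, per_layer_coefficient):
    """Raise ValueError when a layer's coefficients do not pair one-to-one
    with its alpha terms (illustrative validation only)."""
    for i, (alphas, coefs) in enumerate(zip(per_layer_alpha, per_layer_coefficient)):
        if len(alphas) != len(coefs):
            raise ValueError(
                f"layer {i}: got {len(coefs)} coefficients for {len(alphas)} alpha terms"
            )
```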

Source code in spikeDE/snn.py
def __init__(
    self,
    alpha: Union[float, List[float], List[List[float]]],
    n_layers: int,
    multi_coefficient: Optional[Union[List[float], List[List[float]]]] = None,
    learn_alpha: Union[bool, List[bool]] = False,
    learn_coefficient: Union[bool, List[bool]] = False,
    alpha_mode: str = "auto",
    device: torch.device = None,
    dtype: torch.dtype = torch.float32,
):
    r"""
    Initialize the Per-Layer Alpha Configuration.

    Args:
        alpha: Fractional order(s). Accepts:

            - `float`: Single value for all layers.
            - `List[float]`: Interpretation depends on `alpha_mode`.
            - `List[List[float]]`: Per-layer multi-term configuration.
            - `torch.Tensor`: Will be converted to list.
        n_layers: Total number of neuron layers in the SNN.
        multi_coefficient: Coefficients for multi-term FDEs. Accepts:

            - `None`: Defaults to ones.
            - `List[float]`: Shared coefficients for all layers (Case B).
            - `List[List[float]]`: Per-layer coefficients (Case C).
        learn_alpha: Whether $\alpha$ values are learnable parameters.

            - `bool`: Applied globally to all layers.
            - `List[bool]`: Per-layer learnable flags.
        learn_coefficient: Whether coefficients are learnable parameters.

            - `bool`: Applied globally.
            - `List[bool]`: Per-layer learnable flags.
        alpha_mode: Disambiguation mode for flat list inputs.

            - `'auto'`: Automatic detection (default).
            - `'per_layer'`: Force Case A.
            - `'multiterm'`: Force Case B.
        device: Target device for tensors.
        dtype: Data type for tensors (default: float32).

    Raises:
        ValueError: If input dimensions mismatch, invalid modes are provided,
                    or coefficient lengths do not match alpha lengths.
    """
    self.n_layers = n_layers
    self.device = device
    self.dtype = dtype

    if alpha_mode not in ("auto", "per_layer", "multiterm"):
        raise ValueError(
            f"alpha_mode must be 'auto', 'per_layer', or 'multiterm', got '{alpha_mode}'"
        )

    # Parse using decision table
    self.case, self.per_layer_alpha, self.per_layer_coefficient = self._parse(
        alpha, multi_coefficient, n_layers, alpha_mode
    )

    # Determine is_multi_term per layer
    self.per_layer_is_multi_term = [len(a) > 1 for a in self.per_layer_alpha]

    # Parse learnable flags
    self._parse_learn_flags(learn_alpha, learn_coefficient)

print_config

print_config()

Print configuration summary to stdout.

Source code in spikeDE/snn.py
def print_config(self):
    """Print configuration summary to stdout."""
    print("\n[Per-Layer Alpha Configuration]")
    print(f"  Case: {self.case}")
    print(f"  Layers: {self.n_layers}")
    for i in range(self.n_layers):
        alpha = self.per_layer_alpha[i]
        coef = self.per_layer_coefficient[i]
        is_multi = self.per_layer_is_multi_term[i]
        learn_a = self.per_layer_learn_alpha[i]
        learn_c = self.per_layer_learn_coefficient[i]

        if is_multi:
            print(
                f"  Layer {i}: {len(alpha)}-term, α={alpha}, coef={coef}, "
                f"learn_α={learn_a}, learn_coef={learn_c}"
            )
        else:
            print(f"  Layer {i}: single-term, α={alpha[0]}, learn_α={learn_a}")

register_parameters

register_parameters(module: Module) -> None

Register alpha and coefficient as parameters or buffers in the given module.

Parameters:

  • module (Module) –

    The nn.Module to register parameters into.

Source code in spikeDE/snn.py
def register_parameters(self, module: nn.Module) -> None:
    """
    Register alpha and coefficient as parameters or buffers in the given module.

    Args:
        module: The nn.Module to register parameters into.
    """
    module.per_layer_alpha_params = nn.ParameterList()
    module.per_layer_coefficient_params = nn.ParameterList()
    module._alpha_is_param = []
    module._coef_is_param = []

    for i in range(self.n_layers):
        alpha_vals = self.per_layer_alpha[i]
        coef_vals = self.per_layer_coefficient[i]
        learn_alpha = self.per_layer_learn_alpha[i]
        learn_coef = self.per_layer_learn_coefficient[i]

        # Create tensors
        alpha_tensor = torch.tensor(alpha_vals, dtype=self.dtype)
        coef_tensor = torch.tensor(coef_vals, dtype=self.dtype)

        # Register alpha
        if learn_alpha:
            module.per_layer_alpha_params.append(nn.Parameter(alpha_tensor))
            module._alpha_is_param.append(True)
        else:
            module.per_layer_alpha_params.append(
                nn.Parameter(alpha_tensor, requires_grad=False)
            )
            module._alpha_is_param.append(False)

        # Register coefficient
        if learn_coef:
            module.per_layer_coefficient_params.append(nn.Parameter(coef_tensor))
            module._coef_is_param.append(True)
        else:
            module.per_layer_coefficient_params.append(
                nn.Parameter(coef_tensor, requires_grad=False)
            )
            module._coef_is_param.append(False)

    # Store metadata
    module._per_layer_is_multi_term = self.per_layer_is_multi_term.copy()
    module._alpha_case = self.case

SNNWrapper

SNNWrapper(
    base: Module,
    integrator: str = "odeint",
    interpolation_method: str = "linear",
    alpha: float | list[float] | list[list[float]] = 0.5,
    multi_coefficient: list[float] | list[list[float]] | None = None,
    learn_alpha: bool | list[bool] = False,
    learn_coefficient: bool | list[bool] = False,
    alpha_mode: str = "auto",
)

Bases: Module

SNN Wrapper with per-layer fractional order support.

This class wraps a standard PyTorch SNN model, converting its forward pass into a numerical integration of Fractional Differential Equations (FDEs). It supports flexible configuration of fractional orders (\(\alpha\)) per layer.

Supported Features:

  • Per-layer alpha (single-term or multi-term).
  • Per-layer learnable alpha and coefficients.
  • Multiple FDE solvers (Grünwald-Letnikov, L1, etc.).
  • Automatic shape inference via FX tracing.

Parameters:

  • base (Module) –

    Base neural network model (nn.Module) containing neuron layers.

  • integrator (str, default: 'odeint' ) –

    Integrator type ('odeint', 'fdeint', etc.).

  • interpolation_method (str, default: 'linear' ) –

    Input interpolation method.

  • alpha (float | list[float] | list[list[float]], default: 0.5 ) –

    Fractional order(s). Can be:

    • float: same alpha for all layers (single-term).
    • List[float]: interpretation depends on alpha_mode.
    • List[List[float]]: per-layer multi-term (Case C).
  • multi_coefficient (list[float] | list[list[float]] | None, default: None ) –

    Coefficients for multi-term FDE. Can be:

    • None: auto-fill with ones.
    • List[float]: same for all multi-term layers (Case B).
    • List[List[float]]: per-layer (Case C).
  • learn_alpha (bool | list[bool], default: False ) –

    Whether alpha is learnable. Can be:

    • bool: applies to all layers.
    • List[bool]: per-layer.
  • learn_coefficient (bool | list[bool], default: False ) –

    Whether coefficients are learnable. Can be:

    • bool: applies to all layers.
    • List[bool]: per-layer.
  • alpha_mode (str, default: 'auto' ) –

    How to interpret flat list alpha. Options:

    • 'auto': Try to detect based on length and multi_coefficient.
    • 'per_layer': Force Case A (each element is one layer's alpha).
    • 'multiterm': Force Case B (broadcast multi-term to all layers). Ignored if alpha contains nested lists (always Case C).
Source code in spikeDE/snn.py
def __init__(
    self,
    base: nn.Module,
    integrator: str = "odeint",
    interpolation_method: str = "linear",
    alpha: Union[float, List[float], List[List[float]]] = 0.5,
    multi_coefficient: Optional[Union[List[float], List[List[float]]]] = None,
    learn_alpha: Union[bool, List[bool]] = False,
    learn_coefficient: Union[bool, List[bool]] = False,
    alpha_mode: str = "auto",
):
    """
    Initialize SNNWrapper with per-layer alpha support.

    Args:
        base: Base neural network model (nn.Module) containing neuron layers.
        integrator: Integrator type (`'odeint'`, `'fdeint'`, etc.).
        interpolation_method: Input interpolation method.
        alpha: Fractional order(s). Can be:

            - `float`: same alpha for all layers (single-term).
            - `List[float]`: interpretation depends on alpha_mode.
            - `List[List[float]]`: per-layer multi-term (Case C).
        multi_coefficient: Coefficients for multi-term FDE. Can be:

            - `None`: auto-fill with ones.
            - `List[float]`: same for all multi-term layers (Case B).
            - `List[List[float]]`: per-layer (Case C).
        learn_alpha: Whether alpha is learnable. Can be:

            - `bool`: applies to all layers.
            - `List[bool]`: per-layer.
        learn_coefficient: Whether coefficients are learnable. Can be:

            - `bool`: applies to all layers.
            - `List[bool]`: per-layer.
        alpha_mode: How to interpret flat list alpha. Options:

            - `'auto'`: Try to detect based on length and multi_coefficient.
            - `'per_layer'`: Force Case A (each element is one layer's alpha).
            - `'multiterm'`: Force Case B (broadcast multi-term to all layers).
              Ignored if alpha contains nested lists (always Case C).
    """
    super().__init__()

    self.integrator_indicator = integrator
    self.interpolation_method = interpolation_method
    self.integrator = get_integrator(integrator)

    # Store alpha config (will be finalized in _set_neuron_shapes)
    self._alpha_spec = alpha
    self._multi_coefficient_spec = multi_coefficient
    self._learn_alpha_spec = learn_alpha
    self._learn_coefficient_spec = learn_coefficient
    self._alpha_mode_spec = alpha_mode

    # Build FX-based ODEFunc
    self.ode_func = ODEFuncFromFX(base, interpolation_method=interpolation_method)
    self.traced_backbone = self.ode_func.traced
    self.post_neuron_module = self.ode_func.get_post_neuron_module()
    self.neuron_instances = None
    # Initialized lazily; _set_neuron_shapes must be called before the first run
    self.neuron_shapes = None

    self._is_initialized = False
    self._alpha_config = None

    # Store direct references before compiling
    self._ode_gm = self.ode_func.ode_gm
    self._ode_func_uncompiled = self.ode_func  # uncompiled handle used by set_inputs

eval

eval() -> SNNWrapper

Sets the module in evaluation mode and propagates to submodules.

Source code in spikeDE/snn.py
def eval(self) -> "SNNWrapper":
    """Sets the module in evaluation mode and propagates to submodules."""
    super().eval()
    self.traced_backbone.eval()
    self._ode_gm.eval()  # Use stored reference
    if hasattr(self, "post_neuron_module") and self.post_neuron_module is not None:
        self.post_neuron_module.eval()
    return self

forward

forward(
    x: Tensor,
    x_time: Tensor,
    output_time: Tensor | None = None,
    method: str = "euler",
    options: dict[str, Any] = {"step_size": 0.1},
) -> Tensor

Perform the forward pass of the Fractional SNN.

Parameters:

  • x (Tensor) –

    Input tensor of shape [Time_Steps, Batch, ...].

  • x_time (Tensor) –

    Time points corresponding to input steps, shape [Time_Steps,].

  • output_time (Tensor | None, default: None ) –

    Optional time points for output. If None, auto-generated.

  • method (str, default: 'euler' ) –

    Integration method ('euler', 'gl', 'trap', 'l1', etc.).

  • options (dict[str, Any], default: {'step_size': 0.1} ) –

    Dictionary of solver options (e.g., {'step_size': 0.1}).

Returns:

  • Tensor

    Output tensor processed by the post-neuron module.
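
When output_time is None, the forward pass auto-generates it by appending one extra point to x_time, using the input spacing when available and falling back to the solver step_size for a single-point input. A plain-Python sketch of that rule (the real code operates on torch tensors):

```python
def extend_output_time(x_time, step_size=1.0):
    """Append one extra time point to x_time: dt comes from the input
    spacing when available, otherwise from the solver step_size."""
    dt = x_time[1] - x_time[0] if len(x_time) > 1 else step_size
    return list(x_time) + [x_time[-1] + dt]
```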

Source code in spikeDE/snn.py
@requires_initialization
def forward(
    self,
    x: torch.Tensor,
    x_time: torch.Tensor,
    output_time: Optional[torch.Tensor] = None,
    method: str = "euler",
    options: Dict[str, Any] = {"step_size": 0.1},
) -> torch.Tensor:
    """
    Perform the forward pass of the Fractional SNN.

    Args:
        x: Input tensor of shape [Time_Steps, Batch, ...].
        x_time: Time points corresponding to input steps, shape [Time_Steps,].
        output_time: Optional time points for output. If None, auto-generated.
        method: Integration method (`'euler'`, `'gl'`, `'trap'`, `'l1'`, etc.).
        options: Dictionary of solver options (e.g., `{'step_size': 0.1}`).

    Returns:
        Output tensor processed by the post-neuron module.
    """
    time_steps, batch_size = x.shape[:2]
    if len(x_time) != time_steps:
        x_time = x_time[0:time_steps]
    self._ode_func_uncompiled.set_inputs(x, x_time)

    # Initialize neuron membrane potentials
    adjusted_neuron_shapes = [
        (batch_size, *shape[1:]) for shape in self.neuron_shapes
    ]
    v_mems = [torch.zeros(s, device=x.device) for s in adjusted_neuron_shapes]
    # Initialize boundary outputs with correct shapes
    boundary_inits = [
        torch.zeros((batch_size, *shape[1:]), device=x.device)
        for shape in self.boundary_shapes
    ]

    # Initial state:
    initial_state = (*v_mems, *boundary_inits)

    # Helper function to avoid code duplication
    def process_boundaries(v_mem_all_time_and_final_spike):
        if self.n_boundaries == 1:
            finalspike_out = torch.stack(v_mem_all_time_and_final_spike[-1], dim=0)
            return self.post_neuron_module(finalspike_out)
        else:
            boundary_outputs = tuple(
                torch.stack(
                    v_mem_all_time_and_final_spike[self.neuron_count + i], dim=0
                )
                for i in range(self.n_boundaries)
            )
            return self.post_neuron_module(boundary_outputs)

    # Run the integration; auto-generate output_time when not provided
    if output_time is None:
        # Extend x_time by one extra step
        if len(x_time) > 1:
            dt = x_time[1] - x_time[0]
        else:
            dt = options.get("step_size", 1.0)

        # Create the next time point on the same device
        next_t = (x_time[-1] + dt).unsqueeze(0)

        # Concatenate to form the full output time vector
        output_time = torch.cat((x_time, next_t), dim=0)

    # Get current per-layer alpha and coefficient
    per_layer_alpha = self.get_per_layer_alpha()
    per_layer_coefficient = self.get_per_layer_coefficient()

    # check if it is fdeint or odeint
    if self.integrator_indicator == "odeint" and method == "euler":
        v_mem_all_time_and_final_spike = euler_integrate_tuple_compiled(
            self.ode_func, initial_state, output_time, self.neuron_count
        )
        return process_boundaries(v_mem_all_time_and_final_spike)

    elif self.integrator_indicator == "fdeint":
        memory = None if options.get("memory", -1) == -1 else options["memory"]

        # Determine solver based on alpha case:
        # Case A: per-layer single-term → use requested solver
        # Case B/C: multi-term involved → always use multiterm solver
        use_multiterm = self._alpha_case in ("B", "C") or method == "glmulti"

        if use_multiterm:
            # Use multiterm solver for Case B, C, or explicit request
            if self._alpha_case in ("B", "C") and method not in ("gl", "glmulti"):
                import warnings

                warnings.warn(
                    f"Alpha configuration is Case {self._alpha_case} (multi-term). "
                    f"Method '{method}' not supported for multi-term, "
                    f"using GrunwaldLetnikovMultitermSNN instead.",
                    UserWarning,
                )

            v_mem_all_time_and_final_spike = glmethod_multiterm_integrate_tuple(
                self.ode_func,
                initial_state,
                per_layer_alpha,
                output_time,
                memory=memory,
                per_layer_coefficient=per_layer_coefficient,
            )
        else:
            # Case A (per-layer single-term): use requested solver
            integrate_method = SOLVERS[method]
            v_mem_all_time_and_final_spike = integrate_method(
                self.ode_func,
                initial_state,
                per_layer_alpha,
                output_time,
                memory=memory,
                per_layer_coefficient=per_layer_coefficient,
            )

        return process_boundaries(v_mem_all_time_and_final_spike)

    # --- Experimental integrator paths below (odeint_mem / adjoint variants) ---

    if self.integrator_indicator == "odeint_mem":
        v_mem_all_time_and_final_spike = step_dynamics(
            self.ode_func, initial_state, output_time
        )
        finalspike_out = torch.stack(v_mem_all_time_and_final_spike, dim=0)
        final_output = self.post_neuron_module(finalspike_out)
        return final_output

    elif (
        self.integrator_indicator == "odeint_adjoint"
        or self.integrator_indicator == "odeint"
    ):
        v_mem_all_time_and_cumulated_spike = self.integrator(
            self.ode_func,
            initial_state,
            output_time,
            method=method,
            options=options,
        )
        finalspike_out_sum = v_mem_all_time_and_cumulated_spike[-1][-1:, ...]

        final_output = self.post_neuron_module(finalspike_out_sum)
        return final_output

    elif self.integrator_indicator == "fdeint_mem":
        v_mem_all_time_and_final_spike = step_dynamics(
            self.ode_func, initial_state, output_time
        )
        finalspike_out = torch.stack(v_mem_all_time_and_final_spike, dim=0)
        final_output = self.post_neuron_module(finalspike_out)
        return final_output

    elif self.integrator_indicator == "fdeint_adjoint":
        memory = None if options.get("memory", -1) == -1 else options["memory"]

        v_mem_all_time_and_cumulated_spike = fdeint_adjoint(
            self.ode_func,
            initial_state,
            per_layer_alpha[0],
            output_time,
            method=method,
            memory=memory,
        )
        finalspike_out_sum = v_mem_all_time_and_cumulated_spike[-1].unsqueeze(0)
        final_output = self.post_neuron_module(finalspike_out_sum)
        return final_output


get_per_layer_alpha

get_per_layer_alpha() -> list[Tensor]

Get current per-layer alpha values as tensors. Always returns tensors to maintain gradient flow for learnable alphas.

Returns:

  • list[Tensor]

    List of alpha tensors for each layer.

Source code in spikeDE/snn.py
def get_per_layer_alpha(self) -> List[torch.Tensor]:
    """
    Get current per-layer alpha values as tensors.
    Always returns tensors to maintain gradient flow for learnable alphas.

    Returns:
        List of alpha tensors for each layer.
    """
    result = []
    for i in range(len(self.per_layer_alpha_params)):
        alpha_param = self.per_layer_alpha_params[i]
        # Always return tensor to maintain gradient flow
        # Don't call .item() as it detaches from computation graph!
        result.append(alpha_param)
    return result

get_per_layer_coefficient

get_per_layer_coefficient() -> list[Tensor]

Get current per-layer coefficient values as tensors. Always returns tensors to maintain gradient flow.

Note

For single-term layers (Case A), coefficients are technically unused by solvers but returned as [1.0] for interface consistency.

Returns:

  • list[Tensor]

    List of coefficient tensors for each layer.

Source code in spikeDE/snn.py
def get_per_layer_coefficient(self) -> List[torch.Tensor]:
    """
    Get current per-layer coefficient values as tensors.
    Always returns tensors to maintain gradient flow.

    Note: 
        For single-term layers (Case A), coefficients are technically unused by solvers but returned as `[1.0]` for interface consistency.

    Returns:
        List of coefficient tensors for each layer.
    """
    result = []
    for i in range(len(self.per_layer_coefficient_params)):
        # Always return the coefficient tensor
        # Don't return None as it breaks gradient flow
        result.append(self.per_layer_coefficient_params[i])
    return result

print_alpha_info

print_alpha_info()

Print current alpha values and their learnable status.

Source code in spikeDE/snn.py
def print_alpha_info(self):
    """Print current alpha values and their learnable status."""
    if not self._is_initialized:
        print("Not initialized yet. Call _set_neuron_shapes first.")
        return

    print("\n[Current Alpha Values]")
    for i in range(len(self.per_layer_alpha_params)):
        alpha = self.per_layer_alpha_params[i]
        is_multi = self._per_layer_is_multi_term[i]
        is_learnable = self._alpha_is_param[i]

        if is_multi:
            coef = self.per_layer_coefficient_params[i]
            coef_learnable = self._coef_is_param[i]
            print(f"  Layer {i} (multi-term):")
            print(f"    Alpha: {alpha.data.tolist()} (learnable: {is_learnable})")
            print(
                f"    Coefficient: {coef.data.tolist()} (learnable: {coef_learnable})"
            )
        else:
            print(
                f"  Layer {i} (single-term): alpha = {alpha.item():.4f} (learnable: {is_learnable})"
            )

train

train(mode: bool = True) -> SNNWrapper

Sets the module in training mode and propagates to submodules.

Source code in spikeDE/snn.py
def train(self, mode: bool = True) -> "SNNWrapper":
    """Sets the module in training mode and propagates to submodules."""
    super().train(mode)
    self.traced_backbone.train(mode)
    self._ode_gm.train(mode)  # Use stored reference
    if hasattr(self, "post_neuron_module") and self.post_neuron_module is not None:
        self.post_neuron_module.train(mode)
    return self