spikeDE.neuron

This module provides a flexible, extensible framework for building Spiking Neural Networks (SNNs) that bridges standard integer-order dynamics and fractional-order calculus. Rather than relying on discrete step-by-step updates, spikeDE treats neurons as continuous dynamical systems. This architectural shift lets users upgrade standard models into Fractional-Order Spiking Neurons, endowing them with long-range (power-law) memory and complex temporal dependencies without altering core logic.

At the heart of this module is the separation of concerns: neuron classes define instantaneous dynamics (computing derivatives), while external solvers (via SNNWrapper) handle state evolution and fractional integration. This design supports a wide range of models, from classic Integrate-and-Fire variants to sophisticated noisy-threshold and hard-reset mechanisms, all compatible with surrogate gradient learning.

Key Features

  • Modular Architecture: Stateless neuron modules compute derivatives (\(dv/dt\)) independently of state history, allowing them to work interchangeably with standard (odeint) and fractional (fdeint) solvers.
  • Learnable Parameters: Supports learnable membrane time constants (\(\tau\)) via exponential reparameterization and customizable surrogate gradient functions (e.g., arctan, sigmoid) for effective backpropagation through non-differentiable spikes.
  • Extensibility: Provides a clear BaseNeuron interface for defining custom dynamics, ensuring that user-defined neurons automatically inherit fractional capabilities when wrapped in the appropriate solver.
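
The derivative/solver split described above can be sketched in plain Python. This is an illustrative stand-in, not the spikeDE API: `lif_derivative` and `euler_solve` are hypothetical names standing in for a neuron module and for `odeint`/`fdeint`.

```python
# Sketch of the design: the "neuron" only computes dv/dt for the current
# state; a separate solver loop owns the state and evolves it over time.

def lif_derivative(v, current, tau=2.0):
    """Instantaneous LIF dynamics: tau * dv/dt = -v + I(t)."""
    return (-v + current) / tau

def euler_solve(deriv, v0, inputs, dt=1.0):
    """Minimal fixed-step solver standing in for odeint/fdeint."""
    v, trace = v0, []
    for current in inputs:
        v = v + dt * deriv(v, current)  # explicit Euler update
        trace.append(v)
    return trace

trace = euler_solve(lif_derivative, 0.0, [1.0, 1.0, 1.0])
# trace -> [0.5, 0.75, 0.875]: the membrane charges toward the input value
```

Because the derivative function never touches state history, swapping `euler_solve` for a fractional integrator changes the memory kernel without changing the neuron.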

BaseNeuron

BaseNeuron(
    tau: float = 0.5,
    threshold: float = 1.0,
    surrogate_grad_scale: float = 5.0,
    surrogate_opt: str = "arctan_surrogate",
    tau_learnable: bool = False,
)

Bases: Module

Base class for spiking neuron models with configurable membrane time constant and surrogate gradients.

This abstract class provides the foundational structure for spiking neurons. It supports learnable or fixed membrane time constant (\(\tau\)) and customizable surrogate gradient functions for backpropagation through non-differentiable spikes.

The effective membrane time constant is computed as:

\[ \tau = \begin{cases} \tau_0 \cdot (1 + e^{\theta}) & \text{if } \texttt{tau_learnable=True} \\ \tau_0 & \text{otherwise} \end{cases} \]

where \(\tau_0\) is the initial value and \(\theta\) is a learnable parameter.
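
As a quick numeric check of this reparameterization (a plain-Python sketch of the formula, not library code):

```python
import math

def effective_tau(tau0, theta=None):
    """tau = tau0 * (1 + exp(theta)) when learnable, else tau0."""
    if theta is None:
        return tau0
    return tau0 * (1.0 + math.exp(theta))

# theta is initialized to 0 in BaseNeuron, so a learnable tau starts at
# 2 * tau0; since 1 + e^theta > 1, the effective tau always exceeds tau0.
effective_tau(0.5, theta=0.0)
```

The exponential keeps \(\tau\) strictly positive for any value the optimizer assigns to \(\theta\).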

Subclasses must implement the forward method to define specific dynamics.

Attributes:

  • initial_tau (float) –

    Initial value of the membrane time constant \(\tau_0\).

  • tau_param (Parameter | None) –

    Learnable parameter \(\theta\) if tau_learnable=True; otherwise None.

  • tau (float) –

    Fixed \(\tau\) used when tau_learnable=False.

  • threshold (float) –

    Firing threshold \(V_{\text{th}}\).

  • surrogate_grad_scale (float) –

    Scaling factor for surrogate gradient steepness.

  • surrogate_f (Callable) –

    Surrogate gradient function (e.g., arctan-based).

  • tau_learnable (bool) –

    Whether \(\tau\) is trainable.

Parameters:

  • tau (float, default: 0.5 ) –

    The base membrane time constant \(\tau\). Used directly if tau_learnable=False, or as a scaling factor if tau_learnable=True.

  • threshold (float, default: 1.0 ) –

    The membrane potential threshold at which the neuron fires a spike.

  • surrogate_grad_scale (float, default: 5.0 ) –

    Scaling factor applied inside the surrogate gradient function to control gradient magnitude during backpropagation.

  • surrogate_opt (str, default: 'arctan_surrogate' ) –

    Name of the surrogate gradient function to use. Must be a key in the global surrogate_f dictionary (e.g., "arctan_surrogate").

  • tau_learnable (bool, default: False ) –

    If True, \(\tau\) becomes a learnable parameter. If False, \(\tau\) remains fixed.

Source code in spikeDE/neuron.py
def __init__(
    self,
    tau: float = 0.5,
    threshold: float = 1.0,
    surrogate_grad_scale: float = 5.0,
    surrogate_opt: str = "arctan_surrogate",
    tau_learnable: bool = False,
) -> None:
    r"""Initializes the BaseNeuron module.

    Args:
        tau: The base membrane time constant $\tau$. Used directly if `tau_learnable=False`,
            or as a scaling factor if `tau_learnable=True`.
        threshold: The membrane potential threshold at which the neuron fires a spike.
        surrogate_grad_scale: Scaling factor applied inside the surrogate gradient function
            to control gradient magnitude during backpropagation.
        surrogate_opt: Name of the surrogate gradient function to use.
            Must be a key in the global `surrogate_f` dictionary (e.g., `"arctan_surrogate"`).
        tau_learnable: If `True`, $\tau$ becomes a learnable parameter.
            If `False`, $\tau$ remains fixed.
    """
    super(BaseNeuron, self).__init__()
    self.initial_tau = tau

    if tau_learnable:
        self.tau_param = nn.Parameter(torch.tensor(0.0), requires_grad=True)
    else:
        self.tau_param = None
        self.tau = tau

    self.threshold = threshold
    self.surrogate_grad_scale = surrogate_grad_scale
    self.surrogate_f = surrogate_f[surrogate_opt]
    self.tau_learnable = tau_learnable
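
For intuition, one common arctan-shaped surrogate looks like the following. This is a sketch only; the exact form registered under "arctan_surrogate" in the `surrogate_f` dictionary may differ.

```python
import math

def arctan_surrogate(x, scale=5.0):
    """A common arctan-shaped soft spike: a smooth approximation of the
    Heaviside step, equal to 0.5 at x = 0 (i.e., at v = threshold)."""
    return 0.5 + math.atan(scale * x) / math.pi

arctan_surrogate(0.0)   # 0.5 exactly at threshold
```

During backpropagation, its derivative `scale / (pi * (1 + (scale * x)**2))` peaks at the threshold, so gradients flow most strongly for near-threshold membrane potentials; `scale` controls how sharply this peak is concentrated.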

forward

forward(v_mem: Tensor, current_input: Tensor) -> tuple[Tensor, Tensor]

Performs one step of neuron state update.

Must be overridden by subclasses to implement specific spiking dynamics.

Parameters:

  • v_mem (Tensor) –

    Membrane potential tensor of shape (batch_size, ...).

  • current_input (Tensor) –

    Input current tensor, same shape as v_mem.

Returns:

  • tuple[Tensor, Tensor]

    A tuple (dv_dt, spike) where:

    • dv_dt: Effective derivative of membrane potential.
    • spike: Continuous spike approximation in [0, 1].

Raises:

  • NotImplementedError –

    Always raised here; subclasses must implement this method.

Source code in spikeDE/neuron.py
def forward(
    self, v_mem: torch.Tensor, current_input: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Performs one step of neuron state update.

    Must be overridden by subclasses to implement specific spiking dynamics.

    Args:
        v_mem: Membrane potential tensor of shape `(batch_size, ...)`.
        current_input: Input current tensor, same shape as `v_mem`.

    Returns:
        A tuple `(dv_dt, spike)` where:

            - `dv_dt`: Effective derivative of membrane potential.
            - `spike`: Continuous spike approximation in [0, 1].

    Raises:
        NotImplementedError: Always raised here; subclass must implement.
    """
    raise NotImplementedError("Neuron forward method must be overridden.")

get_tau

get_tau() -> float

Returns the effective membrane time constant \(\tau\).

Ensures positivity via exponential reparameterization when learnable.

Returns:

  • float | Tensor

    The effective \(\tau\): a scalar tensor when tau_learnable=True, otherwise the fixed float.

Source code in spikeDE/neuron.py
def get_tau(self) -> float:
    r"""Returns the effective membrane time constant $\tau$.

    Ensures positivity via exponential reparameterization when learnable.

    Returns:
        Scalar tensor representing $\tau$.
    """
    if self.tau_learnable:
        return self.initial_tau * (1 + torch.exp(self.tau_param))
    else:
        return self.tau

IFNeuron

IFNeuron(
    tau: float = 0.5,
    threshold: float = 1.0,
    surrogate_grad_scale: float = 5.0,
    surrogate_opt: str = "arctan_surrogate",
    tau_learnable: bool = False,
)

Bases: BaseNeuron

Integrate-and-Fire (IF) spiking neuron model with surrogate gradients.

This model integrates input without leakage. The dynamics follow:

\[ \tau\frac{\text{d}v}{\text{d}t} = I(t), \quad \text{spike} = \sigma(v - V_{\text{th}}) \]

where \(\sigma\) is a differentiable surrogate.

Note

Although this model inherits tau, the time constant only scales the input here: with no decay term on v_mem, the neuron is a pure integrator.

Source code in spikeDE/neuron.py
def __init__(
    self,
    tau: float = 0.5,
    threshold: float = 1.0,
    surrogate_grad_scale: float = 5.0,
    surrogate_opt: str = "arctan_surrogate",
    tau_learnable: bool = False,
) -> None:
    r"""Initializes the BaseNeuron module.

    Args:
        tau: The base membrane time constant $\tau$. Used directly if `tau_learnable=False`,
            or as a scaling factor if `tau_learnable=True`.
        threshold: The membrane potential threshold at which the neuron fires a spike.
        surrogate_grad_scale: Scaling factor applied inside the surrogate gradient function
            to control gradient magnitude during backpropagation.
        surrogate_opt: Name of the surrogate gradient function to use.
            Must be a key in the global `surrogate_f` dictionary (e.g., `"arctan_surrogate"`).
        tau_learnable: If `True`, $\tau$ becomes a learnable parameter.
            If `False`, $\tau$ remains fixed.
    """
    super(BaseNeuron, self).__init__()
    self.initial_tau = tau

    if tau_learnable:
        self.tau_param = nn.Parameter(torch.tensor(0.0), requires_grad=True)
    else:
        self.tau_param = None
        self.tau = tau

    self.threshold = threshold
    self.surrogate_grad_scale = surrogate_grad_scale
    self.surrogate_f = surrogate_f[surrogate_opt]
    self.tau_learnable = tau_learnable

forward

forward(
    v_mem: Tensor, current_input: Tensor | None = None
) -> tuple[Tensor, Tensor]

Forward pass for IF neuron dynamics (discrete-time, dt=1.0).

Parameters:

  • v_mem (Tensor) –

    Current membrane potential.

  • current_input (Tensor | None, default: None ) –

    Input current (same shape as v_mem).

Returns:

  • tuple[Tensor, Tensor]

    Tuple (dv_dt, spike) representing effective derivative and spike output.

Source code in spikeDE/neuron.py
def forward(
    self, v_mem: torch.Tensor, current_input: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Forward pass for IF neuron dynamics (discrete-time, `dt=1.0`).

    Args:
        v_mem: Current membrane potential.
        current_input: Input current (same shape as `v_mem`).

    Returns:
        Tuple `(dv_dt, spike)` representing effective derivative and spike output.
    """
    if current_input is None:
        return v_mem
    tau = self.get_tau()
    dt = 1.0
    dv_no_reset = (current_input) / tau
    v_post_charge = v_mem + dt * dv_no_reset
    spike = self.surrogate_f(
        v_post_charge - self.threshold, self.surrogate_grad_scale
    )
    dv_dt = dv_no_reset - (spike.detach() * self.threshold) / tau
    return dv_dt, spike
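
The soft-reset arithmetic above can be traced with a plain-Python sketch that swaps the surrogate for a hard Heaviside step (illustrative only, not the library's forward):

```python
def if_step(v, current, tau=1.0, threshold=1.0):
    """One IF update mirroring the forward pass above, with a hard
    Heaviside spike in place of the surrogate (inference-style)."""
    dv_no_reset = current / tau
    v_post_charge = v + dv_no_reset                  # dt = 1.0
    spike = 1.0 if v_post_charge >= threshold else 0.0
    dv_dt = dv_no_reset - spike * threshold / tau    # soft reset: subtract threshold
    return v + dv_dt, spike

v, spikes = 0.0, []
for _ in range(5):
    v, s = if_step(v, 0.75)
    spikes.append(s)
# spikes -> [0.0, 1.0, 1.0, 1.0, 0.0]: a periodic train under constant input
```

Because the reset subtracts the threshold rather than zeroing v_mem, any charge above threshold is preserved, which is what makes the update compatible with a continuous-time solver.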

LIFNeuron

LIFNeuron(
    tau: float = 0.5,
    threshold: float = 1.0,
    surrogate_grad_scale: float = 5.0,
    surrogate_opt: str = "arctan_surrogate",
    tau_learnable: bool = False,
)

Bases: BaseNeuron

Leaky Integrate-and-Fire (LIF) spiking neuron model with surrogate gradients.

Implements classic leaky dynamics governed by:

\[ \tau \frac{\text{d}v}{\text{d}t} = -v + I(t), \quad \text{spike} = \sigma(v - V_{\text{th}}) \]

where \(\sigma\) is a differentiable surrogate.

Source code in spikeDE/neuron.py
def __init__(
    self,
    tau: float = 0.5,
    threshold: float = 1.0,
    surrogate_grad_scale: float = 5.0,
    surrogate_opt: str = "arctan_surrogate",
    tau_learnable: bool = False,
) -> None:
    r"""Initializes the BaseNeuron module.

    Args:
        tau: The base membrane time constant $\tau$. Used directly if `tau_learnable=False`,
            or as a scaling factor if `tau_learnable=True`.
        threshold: The membrane potential threshold at which the neuron fires a spike.
        surrogate_grad_scale: Scaling factor applied inside the surrogate gradient function
            to control gradient magnitude during backpropagation.
        surrogate_opt: Name of the surrogate gradient function to use.
            Must be a key in the global `surrogate_f` dictionary (e.g., `"arctan_surrogate"`).
        tau_learnable: If `True`, $\tau$ becomes a learnable parameter.
            If `False`, $\tau$ remains fixed.
    """
    super(BaseNeuron, self).__init__()
    self.initial_tau = tau

    if tau_learnable:
        self.tau_param = nn.Parameter(torch.tensor(0.0), requires_grad=True)
    else:
        self.tau_param = None
        self.tau = tau

    self.threshold = threshold
    self.surrogate_grad_scale = surrogate_grad_scale
    self.surrogate_f = surrogate_f[surrogate_opt]
    self.tau_learnable = tau_learnable

forward

forward(
    v_mem: Tensor, current_input: Tensor | None = None
) -> tuple[Tensor, Tensor]

Forward pass for LIF neuron dynamics (discrete-time, dt=1.0).

Parameters:

  • v_mem (Tensor) –

    Current membrane potential.

  • current_input (Tensor | None, default: None ) –

    Input current (same shape as v_mem).

Returns:

  • tuple[Tensor, Tensor]

    Tuple (dv_dt, spike) representing effective derivative and spike output.

Source code in spikeDE/neuron.py
def forward(
    self, v_mem: torch.Tensor, current_input: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""Forward pass for LIF neuron dynamics (discrete-time, `dt=1.0`).

    Args:
        v_mem: Current membrane potential.
        current_input: Input current (same shape as `v_mem`).

    Returns:
        Tuple `(dv_dt, spike)` representing effective derivative and spike output.
    """
    if current_input is None:
        return v_mem
    tau = self.get_tau()
    dt = 1.0
    dv_no_reset = (-v_mem + current_input) / tau
    v_post_charge = v_mem + dt * dv_no_reset
    spike = self.surrogate_f(
        v_post_charge - self.threshold, self.surrogate_grad_scale
    )
    dv_dt = dv_no_reset - (spike.detach() * self.threshold) / tau
    return dv_dt, spike
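
The leak term is what separates LIF from IF: with zero input, the membrane decays toward rest instead of holding its value. A plain-Python sketch of the update above, with a hard Heaviside step in place of the surrogate (illustrative only, not the library's forward):

```python
def lif_step(v, current, tau=2.0, threshold=1.0):
    """One LIF update mirroring the forward pass above."""
    dv_no_reset = (-v + current) / tau               # the -v leak distinguishes LIF from IF
    v_post_charge = v + dv_no_reset                  # dt = 1.0
    spike = 1.0 if v_post_charge >= threshold else 0.0
    dv_dt = dv_no_reset - spike * threshold / tau    # soft reset
    return v + dv_dt, spike

# With no input the potential halves each step, since dt / tau = 0.5:
v, trace = 1.0, []
for _ in range(3):
    v, _ = lif_step(v, 0.0)
    trace.append(v)
# trace -> [0.5, 0.25, 0.125]
```

An IF neuron in the same situation would hold v_mem at 1.0 indefinitely; the leak gives LIF a finite effective memory of past input.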