SpikeDE.surrogate
This module provides a comprehensive collection of surrogate gradient functions and stochastic spiking mechanisms designed for training Spiking Neural Networks (SNNs) using backpropagation.
Since the spiking operation (Heaviside step function) is non-differentiable, this library implements various smooth approximations to estimate gradients during the backward pass while maintaining discrete binary spikes in the forward pass. Additionally, it offers a noisy threshold approach that enables stochastic firing during training for improved regularization and biological plausibility.
Key Features
- Multiple Surrogate Gradients: Includes Sigmoid, Arctangent, Piecewise Linear, and Gaussian derivatives, each with distinct mathematical properties suited for different network depths and convergence requirements.
- Stochastic Spiking: Implements `NoisyThresholdSpike`, which injects logistic noise into the threshold to create a differentiable soft-spike mechanism during training, automatically reverting to hard spikes during inference.
- Flexible API: Available both as reusable `torch.autograd.Function` classes for custom layer integration and as functional wrappers for concise usage.
SigmoidSurrogate
Bases: Function
Sigmoid-based surrogate gradient function for Spiking Neural Networks (SNNs).
This class implements a custom autograd function where the forward pass uses a hard Heaviside step function to generate discrete spikes, while the backward pass approximates the undefined gradient using the derivative of a scaled sigmoid function.
Forward Pass:

$$S = \Theta(x) = \begin{cases} 1 & \text{if } x \geq 0 \\ 0 & \text{otherwise} \end{cases}$$

Backward Pass (Surrogate Gradient):

$$\frac{\partial S}{\partial x} \approx \kappa \, \sigma(\kappa x)\bigl(1 - \sigma(\kappa x)\bigr)$$

Where \(x\) is the input (membrane potential minus threshold, \(U - \theta\)) and \(\kappa\) (scale) controls the sharpness of the approximation.
Attributes:
- `scale` (float) – The scaling factor \(\kappa\). Larger values approximate the true step function more closely but may lead to vanishing gradients.
backward
staticmethod
Computes the gradient using the sigmoid derivative as a surrogate.
Parameters:
- `ctx` (FunctionCtx) – Context object containing saved tensors from the forward pass.
- `grad_output` (Tensor) – Gradient of the loss with respect to the output of the forward pass.
Returns:
- `tuple[Tensor, None]` – A tuple containing the gradient with respect to the input and None for the non-differentiable scale parameter.
Source code in spikeDE/surrogate.py
forward
staticmethod
Performs the forward pass using a hard threshold (Heaviside step function).
Parameters:
- `ctx` (FunctionCtx) – Context object to save tensors for the backward pass.
- `input` (Tensor) – Input tensor representing the membrane potential minus threshold (\(U - \theta\)).
- `scale` (float) – Scaling factor (\(\kappa\)) controlling the sharpness of the surrogate gradient.
Returns:
- `Tensor` – A binary tensor of spikes (0.0 or 1.0).
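The forward/backward split above can be sketched as a standalone `torch.autograd.Function`. This is an illustrative re-implementation of the pattern under the definitions in this section, not the library's exact source:

```python
import torch

class SigmoidSurrogateSketch(torch.autograd.Function):
    """Illustrative sketch of the sigmoid-surrogate pattern (not library source)."""

    @staticmethod
    def forward(ctx, input, scale=5.0):
        ctx.save_for_backward(input)
        ctx.scale = scale
        # Hard Heaviside step: spike whenever U - theta >= 0.
        return (input >= 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        s = torch.sigmoid(ctx.scale * input)
        # Derivative of the scaled sigmoid: kappa * s * (1 - s);
        # None for the non-differentiable scale argument.
        return grad_output * ctx.scale * s * (1 - s), None

x = torch.tensor([-0.5, 0.0, 0.5], requires_grad=True)
spikes = SigmoidSurrogateSketch.apply(x, 5.0)
spikes.sum().backward()
# spikes is binary; x.grad is smooth and peaks at the threshold (x = 0).
```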
ArctanSurrogate
Bases: Function
Arctangent-based surrogate gradient function for SNNs.
This method uses the derivative of the arctangent function as the surrogate gradient. It features heavier tails compared to the sigmoid, allowing gradients to propagate even when the membrane potential is far from the threshold.
Forward Pass:

$$S = \Theta(x) = \begin{cases} 1 & \text{if } x \geq 0 \\ 0 & \text{otherwise} \end{cases}$$

Backward Pass (Surrogate Gradient):

$$\frac{\partial S}{\partial x} \approx \frac{\kappa / 2}{1 + \left(\frac{\pi}{2} \kappa x\right)^2}$$
Note
The implementation includes a normalization factor involving \(\pi/2\) to ensure stable gradient magnitudes, slightly modifying the standard arctan derivative.
Attributes:
- `scale` (float) – The scaling factor \(\kappa\).
backward
staticmethod
Computes the gradient using the normalized arctangent derivative.
Parameters:
- `ctx` (FunctionCtx) – Context object containing saved tensors.
- `grad_output` (Tensor) – Upstream gradient from the loss function.
Returns:

- `tuple[Tensor, None]` – Gradient with respect to the input and None for the non-differentiable scale parameter.
forward
staticmethod
Performs the forward pass using a hard threshold.
Parameters:
- `ctx` (FunctionCtx) – Context object to save tensors for the backward pass.
- `input` (Tensor) – Input tensor (\(U - \theta\)).
- `scale` (float) – Scaling factor (\(\kappa\)).
Returns:
- `Tensor` – Binary spike tensor.
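The heavy-tailed behavior is easy to verify numerically. Below is an illustrative sketch of the normalized arctan derivative described above, compared against the sigmoid derivative; the exact normalization constant is an assumption based on the \(\pi/2\) note, not confirmed library source:

```python
import math
import torch

def arctan_surrogate_grad(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
    # Sketch: (kappa/2) / (1 + ((pi/2) * kappa * x)^2).
    # The normalization constant is an assumption, not the library's exact source.
    return (scale / 2) / (1 + (math.pi / 2 * scale * x) ** 2)

def sigmoid_surrogate_grad(x: torch.Tensor, scale: float = 2.0) -> torch.Tensor:
    s = torch.sigmoid(scale * x)
    return scale * s * (1 - s)

x = torch.tensor([0.0, 2.0, 10.0])
# Far from the threshold (x = 10) the arctan tail decays polynomially while the
# sigmoid tail decays exponentially, so the arctan gradient stays much larger.
```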
PiecewiseLinearSurrogate
Bases: Function
Piecewise Linear (Triangular) surrogate gradient function.
A computationally efficient approximation that defines a triangular window around the threshold. Gradients are constant within the window and zero outside.
Forward Pass:

$$S = \Theta(x) = \begin{cases} 1 & \text{if } x \geq 0 \\ 0 & \text{otherwise} \end{cases}$$

Backward Pass (Surrogate Gradient):

$$\frac{\partial S}{\partial x} \approx \frac{1}{2\gamma}\,\mathbf{1}\bigl[\,|x| \leq \gamma\,\bigr]$$

Where \(\gamma\) (gamma) defines the half-width of the active region.
Attributes:
- `gamma` (float) – Half-width of the linear region.
backward
staticmethod
Computes the gradient using a rectangular window function.
Parameters:
- `ctx` (FunctionCtx) – Context object containing saved tensors.
- `grad_output` (Tensor) – Upstream gradient.
Returns:

- `tuple[Tensor, None]` – Gradient with respect to the input and None for the non-differentiable gamma parameter.
forward
staticmethod
Performs the forward pass using a hard threshold.
Parameters:
- `ctx` (FunctionCtx) – Context object to save tensors.
- `input` (Tensor) – Input tensor (\(U - \theta\)).
- `gamma` (float) – Width parameter (\(\gamma\)).
Returns:
- `Tensor` – Binary spike tensor.
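A sketch of the rectangular-window gradient. The \(1/(2\gamma)\) normalization, which corresponds to a soft forward ramp of width \(2\gamma\), is an assumption rather than confirmed library source:

```python
import torch

def piecewise_linear_grad(x: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    # Constant gradient inside |x| <= gamma, zero outside (illustrative sketch).
    # The 1/(2*gamma) normalization is an assumption, not the library's source.
    return (x.abs() <= gamma).float() / (2 * gamma)

g = piecewise_linear_grad(torch.tensor([-2.0, 0.0, 0.5, 2.0]), gamma=1.0)
# Only entries inside the window receive a (constant) nonzero gradient,
# which is what makes the updates cheap and sparse.
```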
GaussianSurrogate
Bases: Function
Gaussian-based surrogate gradient function.
Uses a normalized Gaussian function to approximate the derivative. It offers the smoothest profile with exponential decay, providing very localized gradient updates.
Forward Pass:

$$S = \Theta(x) = \begin{cases} 1 & \text{if } x \geq 0 \\ 0 & \text{otherwise} \end{cases}$$

Backward Pass (Surrogate Gradient):

$$\frac{\partial S}{\partial x} \approx \frac{1}{\sigma\sqrt{2\pi}} \exp\!\left(-\frac{x^2}{2\sigma^2}\right)$$

Where \(\sigma\) (sigma) controls the spread (standard deviation) of the gradient.
Attributes:
- `sigma` (float) – Standard deviation of the Gaussian.
backward
staticmethod
Computes the gradient using the Gaussian PDF.
Parameters:
- `ctx` (FunctionCtx) – Context object containing saved tensors.
- `grad_output` (Tensor) – Upstream gradient.
Returns:

- `tuple[Tensor, None]` – Gradient with respect to the input and None for the non-differentiable sigma parameter.
forward
staticmethod
Performs the forward pass using a hard threshold.
Parameters:
- `ctx` (FunctionCtx) – Context object to save tensors.
- `input` (Tensor) – Input tensor (\(U - \theta\)).
- `sigma` (float) – Standard deviation parameter (\(\sigma\)).
Returns:
- `Tensor` – Binary spike tensor.
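Since the gradient here is just the Gaussian PDF, a sketch is a one-liner (illustrative, assuming the standard normalization):

```python
import math
import torch

def gaussian_surrogate_grad(x: torch.Tensor, sigma: float = 1.0) -> torch.Tensor:
    # Normalized Gaussian PDF: exp(-x^2 / (2 sigma^2)) / (sigma * sqrt(2 pi)).
    return torch.exp(-x ** 2 / (2 * sigma ** 2)) / (sigma * math.sqrt(2 * math.pi))

g = gaussian_surrogate_grad(torch.tensor([0.0, 3.0]))
# The peak at x = 0 is 1/sqrt(2*pi); three sigmas out the gradient is near zero,
# which is what makes the updates so localized.
```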
NoisyThresholdSpikeModule
Bases: Module
PyTorch `Module` wrapper for `noisy_threshold_spike`.
Automatically tracks the model's training state (`self.training`) to switch between stochastic soft spikes and deterministic hard spikes.
Attributes:
- `scale` (float) – Sharpness parameter (\(\kappa\)).
- `sample` (bool) – Whether to sample noise or use mean-field.
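The wrapper pattern can be sketched as below. The class name and the exact placement of the noise are illustrative assumptions based on this section's description, not the library source:

```python
import torch
import torch.nn as nn

class NoisyThresholdSpikeSketch(nn.Module):
    """Illustrative sketch: soft stochastic spikes while self.training is True,
    hard deterministic spikes in eval mode (not library source)."""

    def __init__(self, scale: float = 5.0, sample: bool = True):
        super().__init__()
        self.scale = scale
        self.sample = sample

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return (x >= 0).float()                   # hard spike at inference
        if self.sample:
            u = torch.rand_like(x)
            eps = torch.log(u) - torch.log1p(-u)      # Logistic(0, 1) via inverse CDF
            return torch.sigmoid(self.scale * (x + eps))
        return torch.sigmoid(self.scale * x)          # mean-field: plain sigmoid

m = NoisyThresholdSpikeSketch(sample=False)  # mean-field for a deterministic demo
x = torch.tensor([-1.0, 1.0])
m.train()
soft = m(x)        # values strictly in (0, 1)
m.eval()
hard = m(x)        # exactly 0.0 or 1.0
```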
sigmoid_surrogate
Functional wrapper for the Sigmoid surrogate gradient.
Allows gradients to flow through the non-differentiable spiking operation during backpropagation by replacing the step function's derivative with a smooth sigmoid derivative.
Parameters:
- `input` (Tensor) – Input tensor representing membrane potential minus threshold (\(U - \theta\)).
- `scale` (float, default: 5.0) – Scaling factor (\(\kappa\)). Higher values make the surrogate sharper.
Returns:
- `Tensor` – A tensor of binary spikes (0.0 or 1.0) with custom gradient flow.
arctan_surrogate
Functional wrapper for the Arctan surrogate gradient.
Ideal for deep networks where gradient vanishing is a concern due to its heavy-tailed gradient distribution.
Parameters:
- `input` (Tensor) – Input tensor (\(U - \theta\)).
- `scale` (float, default: 2.0) – Scaling factor (\(\kappa\)).
Returns:
- `Tensor` – Binary spike tensor with arctan-based gradient flow.
piecewise_linear_surrogate
Functional wrapper for the Piecewise Linear surrogate gradient.
Best for high-speed training on resource-constrained hardware or models requiring sparse gradient updates.
Parameters:
- `input` (Tensor) – Input tensor (\(U - \theta\)).
- `gamma` (float, default: 1.0) – Width of the active region (\(\gamma\)).
Returns:
- `Tensor` – Binary spike tensor with linear-based gradient flow.
gaussian_surrogate
Functional wrapper for the Gaussian surrogate gradient.
Best for precision tasks where only neurons very close to firing should receive updates.
Parameters:
- `input` (Tensor) – Input tensor (\(U - \theta\)).
- `sigma` (float, default: 1.0) – Spread of the gradient (\(\sigma\)).
Returns:
- `Tensor` – Binary spike tensor with Gaussian-based gradient flow.
noisy_threshold_spike
noisy_threshold_spike(
input: Tensor,
scale: float = 5.0,
training: bool = True,
sample: bool = True,
) -> Tensor
Stochastic spiking function using a noisy threshold.
Instead of a hard spike in the forward pass, this method injects logistic noise into the threshold, creating a stochastic soft spike during training. During inference (eval mode), it reverts to a hard spike. Unlike the surrogate methods above, the soft spike is its own differentiable path: the backward pass uses the true gradient of the forward computation rather than a surrogate.
Training Mode:

$$S = \sigma\bigl(\kappa\,(x + \epsilon)\bigr)$$

Where \(\epsilon \sim \text{Logistic}(0, 1)\) sampled via inverse CDF:

$$\epsilon = \log(u) - \log(1 - u), \quad u \sim \mathcal{U}(0, 1)$$

Inference Mode:

$$S = \Theta(x)$$
Parameters:
- `input` (Tensor) – Input tensor (\(U - \theta\)).
- `scale` (float, default: 5.0) – Sharpness parameter (\(\kappa\)). Higher values make the sigmoid sharper.
- `training` (bool, default: True) – If True, applies noise and soft sigmoid. If False, uses hard threshold.
- `sample` (bool, default: True) – If True, samples noise per element. If False, uses the mean-field approximation (standard sigmoid without noise).
Returns:
- `Tensor` – Soft probabilities during training, binary spikes during eval.
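Because the mean-field forward pass is just \(\sigma(\kappa x)\), autograd differentiates it directly and returns the exact derivative \(\kappa\, s (1 - s)\), with no custom backward needed. A minimal check using plain torch:

```python
import torch

scale = 5.0
x = torch.tensor([0.1], requires_grad=True)

# Mean-field soft spike (sample=False): an ordinary scaled sigmoid, fully
# differentiable, so no surrogate gradient is required.
s = torch.sigmoid(scale * x)
s.backward()

expected = scale * s.detach() * (1 - s.detach())   # exact sigmoid derivative
```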