SpikeDE.snn
This module provides the SNNWrapper class, which converts standard Spiking Neural Networks (SNNs)
into Fractional Differential Equation (FDE) systems. It supports flexible per-neuron-layer
configuration of fractional orders (\(\alpha\)): single-term, multi-term (distributed order),
or learnable.
Key Features
- Per-layer fractional orders (single-term or multi-term).
- Learnable \(\alpha\) and coefficients via backpropagation.
- Automatic shape inference and parameter registration.
- Support for various FDE solvers (Grunwald-Letnikov, L1, etc.).
PerLayerAlphaConfig
PerLayerAlphaConfig(
    alpha: float | list[float] | list[list[float]],
    n_layers: int,
    multi_coefficient: list[float] | list[list[float]] | None = None,
    learn_alpha: bool | list[bool] = False,
    learn_coefficient: bool | list[bool] = False,
    alpha_mode: str = "auto",
    device: device = None,
    dtype: dtype = float32,
)
Configuration parser and validator for per-layer fractional orders (\(\alpha\)).
This class normalizes user inputs into a consistent internal format based on decision-table logic. It handles three primary configuration cases:
- Case A (Per-layer Single-term): Each layer has a single \(\alpha\) value. Example: alpha=[0.3, 0.5] with alpha_mode='per_layer'.
- Case B (Multi-term Broadcast): The same multi-term configuration applies to all layers. Example: alpha=[0.3, 0.5, 0.7] with alpha_mode='multiterm'.
- Case C (Per-layer Multi-term): Each layer has its own multi-term configuration. Example: alpha=[[0.3, 0.5], [0.4, 0.6]] (auto-detected).
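The three cases can be told apart with a small stand-alone sketch. This is an illustrative reconstruction, not the library's actual code; the helper name detect_case and the exact 'auto' tie-breaking rule (preferring Case A when a flat list's length equals n_layers and no coefficients are given) are assumptions:

```python
def detect_case(alpha, n_layers, multi_coefficient=None, alpha_mode="auto"):
    """Classify an alpha spec into Case A, B, or C (sketch of the decision table)."""
    if isinstance(alpha, (int, float)):
        return "A"  # one scalar broadcast to every layer, single-term
    if any(isinstance(a, list) for a in alpha):
        return "C"  # nested lists are always per-layer multi-term
    if alpha_mode == "per_layer":
        return "A"
    if alpha_mode == "multiterm":
        return "B"
    # 'auto': a flat list matching n_layers with no coefficients reads as per-layer
    if len(alpha) == n_layers and multi_coefficient is None:
        return "A"
    return "B"

print(detect_case([0.3, 0.5], n_layers=2))                       # Case A
print(detect_case([0.3, 0.5, 0.7], 2, alpha_mode="multiterm"))   # Case B
print(detect_case([[0.3, 0.5], [0.4, 0.6]], 2))                  # Case C
```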
Attributes:
- n_layers (int) – Number of neuron layers in the network.
- case (str) – Detected configuration case.
- per_layer_alpha (list[list[float]]) – Normalized alpha values per layer.
- per_layer_coefficient (list[list[float]]) – Normalized coefficients per layer.
- per_layer_is_multi_term (list[bool]) – Boolean flag per layer indicating multi-term usage.
- per_layer_learn_alpha (list[bool]) – Learnable flag for alpha per layer.
- per_layer_learn_coefficient (list[bool]) – Learnable flag for coefficients per layer.
Parameters:
- alpha (float | list[float] | list[list[float]]) – Fractional order(s). Accepts:
  - float: Single value for all layers.
  - List[float]: Interpretation depends on alpha_mode.
  - List[List[float]]: Per-layer multi-term configuration.
  - torch.Tensor: Will be converted to a list.
- n_layers (int) – Total number of neuron layers in the SNN.
- multi_coefficient (list[float] | list[list[float]] | None, default: None) – Coefficients for multi-term FDEs. Accepts:
  - None: Defaults to ones.
  - List[float]: Shared coefficients for all layers (Case B).
  - List[List[float]]: Per-layer coefficients (Case C).
- learn_alpha (bool | list[bool], default: False) – Whether \(\alpha\) values are learnable parameters.
  - bool: Applied globally to all layers.
  - List[bool]: Per-layer learnable flags.
- learn_coefficient (bool | list[bool], default: False) – Whether coefficients are learnable parameters.
  - bool: Applied globally.
  - List[bool]: Per-layer learnable flags.
- alpha_mode (str, default: 'auto') – Disambiguation mode for flat list inputs.
  - 'auto': Automatic detection (default).
  - 'per_layer': Force Case A.
  - 'multiterm': Force Case B.
- device (device, default: None) – Target device for tensors.
- dtype (dtype, default: float32) – Data type for tensors.
Raises:
- ValueError – If input dimensions mismatch, invalid modes are provided, or coefficient lengths do not match alpha lengths.
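The length checks behind that ValueError can be illustrated with a stand-alone sketch (check_multi_term is a hypothetical helper, not part of the library):

```python
def check_multi_term(per_layer_alpha, per_layer_coefficient):
    """Raise ValueError when coefficient lengths do not match alpha lengths."""
    if len(per_layer_alpha) != len(per_layer_coefficient):
        raise ValueError(
            f"Expected {len(per_layer_alpha)} coefficient lists, "
            f"got {len(per_layer_coefficient)}"
        )
    for i, (a, c) in enumerate(zip(per_layer_alpha, per_layer_coefficient)):
        if len(a) != len(c):
            raise ValueError(
                f"Layer {i}: {len(a)} alpha terms but {len(c)} coefficients"
            )

check_multi_term([[0.3, 0.5], [0.4, 0.6]], [[1.0, 0.5], [1.0, 0.2]])  # passes
try:
    check_multi_term([[0.3, 0.5]], [[1.0]])
except ValueError as e:
    print(e)  # Layer 0: 2 alpha terms but 1 coefficients
```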
Source code in spikeDE/snn.py
print_config
Print configuration summary to stdout.
register_parameters
register_parameters(module: Module) -> None
Register alpha and coefficient as parameters or buffers in the given module.
Parameters:
- module (Module) – The nn.Module to register parameters into.
SNNWrapper
SNNWrapper(
    base: Module,
    integrator: str = "odeint",
    interpolation_method: str = "linear",
    alpha: float | list[float] | list[list[float]] = 0.5,
    multi_coefficient: list[float] | list[list[float]] | None = None,
    learn_alpha: bool | list[bool] = False,
    learn_coefficient: bool | list[bool] = False,
    alpha_mode: str = "auto",
)
Bases: Module
SNN Wrapper with per-layer fractional order support.
This class wraps a standard PyTorch SNN model, converting its forward pass into a numerical integration of Fractional Differential Equations (FDEs). It supports flexible configuration of fractional orders (\(\alpha\)) per layer.
Supported Features:
- Per-layer alpha (single-term or multi-term).
- Per-layer learnable alpha and coefficients.
- Multiple FDE solvers (Grunwald-Letnikov, L1, etc.).
- Automatic shape inference via FX tracing.
Parameters:
- base (Module) – Base neural network model (nn.Module) containing neuron layers.
- integrator (str, default: 'odeint') – Integrator type ('odeint', 'fdeint', etc.).
- interpolation_method (str, default: 'linear') – Input interpolation method.
- alpha (float | list[float] | list[list[float]], default: 0.5) – Fractional order(s). Can be:
  - float: same alpha for all layers (single-term).
  - List[float]: interpretation depends on alpha_mode.
  - List[List[float]]: per-layer multi-term (Case C).
- multi_coefficient (list[float] | list[list[float]] | None, default: None) – Coefficients for multi-term FDE. Can be:
  - None: auto-fill with ones.
  - List[float]: same for all multi-term layers (Case B).
  - List[List[float]]: per-layer (Case C).
- learn_alpha (bool | list[bool], default: False) – Whether alpha is learnable. Can be:
  - bool: applies to all layers.
  - List[bool]: per-layer.
- learn_coefficient (bool | list[bool], default: False) – Whether coefficients are learnable. Can be:
  - bool: applies to all layers.
  - List[bool]: per-layer.
- alpha_mode (str, default: 'auto') – How to interpret a flat list alpha. Ignored if alpha contains nested lists (always Case C). Options:
  - 'auto': Try to detect based on length and multi_coefficient.
  - 'per_layer': Force Case A (each element is one layer's alpha).
  - 'multiterm': Force Case B (broadcast multi-term to all layers).
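The defaulting rules above, None coefficients auto-filled with ones and a bool learn flag broadcast to every layer, can be sketched as follows (fill_defaults is a hypothetical helper for illustration, not the library's API):

```python
def fill_defaults(per_layer_alpha, multi_coefficient=None, learn_alpha=False):
    """Broadcast scalar-style arguments to per-layer lists (sketch)."""
    n_layers = len(per_layer_alpha)
    if multi_coefficient is None:
        # None: one coefficient of 1.0 per alpha term in each layer
        coeffs = [[1.0] * len(a) for a in per_layer_alpha]
    else:
        coeffs = multi_coefficient
    if isinstance(learn_alpha, bool):
        learn = [learn_alpha] * n_layers  # bool applies to all layers
    else:
        learn = list(learn_alpha)         # already per-layer
    return coeffs, learn

coeffs, learn = fill_defaults([[0.3, 0.5], [0.4]], learn_alpha=True)
print(coeffs)  # [[1.0, 1.0], [1.0]]
print(learn)   # [True, True]
```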
eval
eval() -> SNNWrapper
Sets the module in evaluation mode and propagates to submodules.
forward
forward(
    x: Tensor,
    x_time: Tensor,
    output_time: Tensor | None = None,
    method: str = "euler",
    options: dict[str, Any] = {"step_size": 0.1},
) -> Tensor
Perform the forward pass of the Fractional SNN.
Parameters:
- x (Tensor) – Input tensor of shape [Time_Steps, Batch, ...].
- x_time (Tensor) – Time points corresponding to input steps, shape [Time_Steps,].
- output_time (Tensor | None, default: None) – Optional time points for output. If None, auto-generated.
- method (str, default: 'euler') – Integration method ('euler', 'gl', 'trap', 'l1', etc.).
- options (dict[str, Any], default: {'step_size': 0.1}) – Dictionary of solver options (e.g., {'step_size': 0.1}).
Returns:
- Tensor – Output tensor processed by the post-neuron module.
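For intuition, method='euler' with options={'step_size': h} advances the state in fixed steps, much like a plain explicit Euler loop. The fractional solvers ('gl', 'l1', etc.) additionally weight the memory of past states; this generic stand-alone sketch omits that history term:

```python
import math

def euler_integrate(f, y0, t0, t1, step_size=0.1):
    """Fixed-step explicit Euler: y_{n+1} = y_n + h * f(t_n, y_n)."""
    t, y = t0, y0
    while t < t1 - 1e-12:
        h = min(step_size, t1 - t)  # do not overshoot the final time
        y = y + h * f(t, y)
        t += h
    return y

# dy/dt = -y, y(0) = 1  ->  y(1) ~ e^{-1}
y1 = euler_integrate(lambda t, y: -y, 1.0, 0.0, 1.0, step_size=0.001)
print(abs(y1 - math.exp(-1)) < 1e-3)  # True
```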
get_per_layer_alpha
Get current per-layer alpha values as tensors. Always returns tensors to maintain gradient flow for learnable alphas.
get_per_layer_coefficient
Get current per-layer coefficient values as tensors. Always returns tensors to maintain gradient flow.
Note
For single-term layers (Case A), coefficients are technically unused by solvers but returned as [1.0] for interface consistency.
print_alpha_info
Print current alpha values and their learnable status.
train
train(mode: bool = True) -> SNNWrapper
Sets the module in training mode and propagates to submodules.