SpikeDE.odefunc
This module serves as the core engine for Spiking Neural ODEs within the SpikeDE framework. It provides the ODEFuncFromFX class, which automatically transforms standard Spiking Neural Networks (SNNs) into continuous-time neural ODE systems suitable for integration with adaptive solvers (e.g., torchdiffeq).
Key Features
- Automatic Graph Transformation: Leverages PyTorch FX to symbolically trace SNN backbones, separating the continuous dynamics (membrane potential evolution) from discrete post-processing layers (e.g., voting or classification).
- Continuous Input Reconstruction: Supports high-precision reconstruction of discrete input spike trains at arbitrary time steps \(t\) during integration.
- Pure PyTorch Interpolation: Implements four interpolation strategies entirely in PyTorch to ensure seamless GPU acceleration without CPU-GPU data transfer overhead:
  - `linear`: Standard linear interpolation.
  - `nearest`: Zero-order hold (nearest neighbor).
  - `cubic`: Catmull-Rom cubic splines for smooth trajectories.
  - `akima`: Akima splines for robustness against oscillations.
- Solver Compatibility: The generated vector field functions are fully compatible with standard ODE solvers, enabling efficient backpropagation through the integration process.
ODEFuncFromFX
Bases: Module
Wrapper that converts a Spiking Neural Network (SNN) into an ODE-compatible form.
This class leverages PyTorch FX to symbolically trace the input backbone SNN and
restructure its computation graph into two distinct parts:
- ODE Graph (`ode_gm`): Contains all operations up to and including the last spiking neuron layer. It outputs:
  - The time derivatives of all neuronal membrane potentials (\(dv/dt\)).
  - Boundary values (spikes or intermediate tensors) required by downstream layers.
- Post-Neuron Module (`post_neuron_module`): Contains all operations occurring after the last neuron (e.g., voting layers, classifiers). This module is decoupled from the ODE integration loop and applied only after solving the ODE system.
Input signals are assumed to be sampled at discrete time points. To support continuous-time
ODE solvers, inputs are interpolated on-the-fly using the specified interpolation_method.
The resulting object can be passed directly to numerical ODE solvers (e.g., torchdiffeq)
as the vector field function \(f(t, v) = \text{d}v/\text{d}t\).
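The calling convention expected by such solvers can be illustrated with a toy vector field and a fixed-step Euler loop. This is a framework-free sketch: `leaky_field` below is a hypothetical stand-in for the traced SNN dynamics, not the actual `ODEFuncFromFX` output, and real use pairs the class with an adaptive solver such as torchdiffeq's `odeint`.

```python
# Minimal sketch of how an ODE solver consumes f(t, v) = dv/dt.
# `leaky_field` is an illustrative stand-in modeling a leaky integrator
# dv/dt = -v; it is not the traced SNN graph itself.

def leaky_field(t, v):
    """Toy vector field: exponential decay of the membrane potential."""
    return -v

def euler_solve(f, v0, t0, t1, steps):
    """Fixed-step forward Euler; adaptive solvers follow the same f(t, v) contract."""
    v, t = v0, t0
    dt = (t1 - t0) / steps
    for _ in range(steps):
        v = v + dt * f(t, v)
        t += dt
    return v

v_final = euler_solve(leaky_field, v0=1.0, t0=0.0, t1=1.0, steps=1000)
# v_final approximates exp(-1) ≈ 0.3679
```

The solver only ever calls `f(t, v)`; everything specific to the SNN (graph evaluation, input interpolation, boundary values) is hidden inside that callable.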
Attributes:
- `interpolation_method` (str) – Interpolation scheme for continuous input reconstruction.
- `neuron_count` (int) – Number of `BaseNeuron` instances detected in the backbone.
- `x` (Tensor) – Cached input tensor (shape: `(T, ...)`).
- `x_time` (Tensor) – Time stamps corresponding to input samples (shape: `(T,)`).
- `nfe` (int) – Number of function evaluations performed (useful for profiling solver cost).
- `ode_gm` (GraphModule) – The ODE-compatible computation graph.
- `post_neuron_module` (Module) – Module containing post-neuron operations.
- `traced` (GraphModule) – The original traced backbone, kept for reference.
Parameters:
- `backbone` (Module) – The original SNN model containing `BaseNeuron` layers. Must be FX-traceable and contain at least one neuron layer.
- `interpolation_method` (str, default: `'linear'`) – Method used to interpolate discrete inputs to continuous time. Supported options:
  - `'linear'`: Linear interpolation between adjacent samples.
  - `'nearest'`: Hold last value (zero-order hold).
  - `'cubic'`: Catmull-Rom cubic spline interpolation.
  - `'akima'`: Akima spline interpolation (reduces overshoot).
Raises:
- `ValueError` – If an unsupported node operation is encountered during graph rewriting.
Source code in spikeDE/odefunc.py
forward
Computes the vector field \(f(t, v) = \text{d}v/\text{d}t\) for ODE solvers.
This method is called repeatedly by the ODE integrator. It evaluates the ODE graph
at time t using the current membrane potentials v_mems and interpolated input.
Parameters:
- `t` (float) – Current time (scalar).
- `v_mems` (tuple[Tensor, ...]) – Tuple of membrane potential tensors, one per neuron layer, each of shape `(batch_size, ...)`.

Returns:
- `Tuple[torch.Tensor, ...]` – A tuple containing:
  - `dv_dt_i`: Time derivative of the membrane potential for the i-th neuron.
  - `boundary_val_j`: Intermediate values needed by the post-neuron module, in topological order.

  The total length is `neuron_count + num_boundary_values`.
Source code in spikeDE/odefunc.py
get_ode_module
get_ode_module() -> GraphModule
Returns the FX graph module implementing the ODE vector field.
This module encapsulates the entire ODE-compatible computation graph and can be inspected, saved, or modified independently.
Returns:
- `fx.GraphModule` – The internal ODE evaluation graph.
Source code in spikeDE/odefunc.py
get_post_neuron_module
get_post_neuron_module() -> Module
Returns the module that processes outputs after the last spiking neuron.
This module should be applied to the boundary values returned by the ODE solver to produce the final network prediction.
Returns:
- `nn.Module` – The post-neuron computation path.
Source code in spikeDE/odefunc.py
set_inputs
Caches the input signal and its sampling timestamps for interpolation.
These inputs are used during the ODE solve to reconstruct \(x(t)\) at arbitrary times.
Parameters:
- `x` (Tensor) – Input tensor of shape `(T, batch_size, ...)`, where `T` is the number of time steps.
- `x_time` (Tensor) – Corresponding time stamps of shape `(T,)` or `(batch_size, T)`, typically monotonically increasing.
Source code in spikeDE/odefunc.py
SNNLeafTracer
Bases: Tracer
Custom FX Tracer that treats specific modules as leaf nodes.
This tracer ensures that BaseNeuron, VotingLayer, and ClassificationHead
modules are not decomposed during symbolic tracing, preserving their internal
logic as single graph nodes.
is_leaf_module
Determine if a module should be treated as a leaf node.
Parameters:
- `m` (Module) – The module instance being traced.
- `module_qualified_name` (str) – The qualified name of the module.

Returns:
- `bool` – True if the module is a leaf node (i.e., should not be traced internally).
Source code in spikeDE/odefunc.py
linear_interpolate_batched
Perform batched linear interpolation.
Parameters:
- `x` (Tensor) – Input tensor `[T, B, ...]`.
- `x_time` (Tensor) – Time points `[B, T]`.
- `t` (Tensor) – Query times `[B]`.

Returns:
- `torch.Tensor` – Interpolated values `[B, ...]`.
Source code in spikeDE/odefunc.py
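The per-element computation can be sketched in plain Python. This is a simplified, unbatched illustration of the same algorithm (including the boundary clamping described under `interpolate` below); the module's actual implementation is vectorized in PyTorch and the function name here is illustrative.

```python
import bisect

def linear_interp(x_time, x, t):
    """Unbatched linear interpolation with boundary clamping."""
    # Queries outside the sampled range are clamped to the boundary values.
    if t <= x_time[0]:
        return x[0]
    if t >= x_time[-1]:
        return x[-1]
    # Index of the right neighbor of t, then blend the two adjacent samples.
    i = bisect.bisect_right(x_time, t)
    t0, t1 = x_time[i - 1], x_time[i]
    w = (t - t0) / (t1 - t0)
    return (1 - w) * x[i - 1] + w * x[i]

x_time = [0.0, 1.0, 2.0, 3.0]
x = [0.0, 1.0, 4.0, 9.0]
linear_interp(x_time, x, 1.5)   # midway between 1.0 and 4.0 → 2.5
linear_interp(x_time, x, -1.0)  # clamped to x[0] → 0.0
```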
nearest_interpolate_batched
Perform batched nearest neighbor interpolation.
Parameters:
- `x` (Tensor) – Input tensor `[T, B, ...]`.
- `x_time` (Tensor) – Time points `[B, T]`.
- `t` (Tensor) – Query times `[B]`.

Returns:
- `torch.Tensor` – Values at the nearest time points `[B, ...]`.
Source code in spikeDE/odefunc.py
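The idea reduces to a closest-timestamp lookup, sketched here unbatched in plain Python (the function name is illustrative; this sketch uses symmetric nearest-neighbor rounding, which coincides with a zero-order hold at the sample points themselves):

```python
def nearest_interp(x_time, x, t):
    """Unbatched nearest-neighbor lookup (implicitly clamped at the boundaries)."""
    # Return the sample whose timestamp is closest to the query time t.
    best = min(range(len(x_time)), key=lambda i: abs(x_time[i] - t))
    return x[best]

nearest_interp([0.0, 1.0, 2.0], [10.0, 20.0, 30.0], 1.4)  # closest to t=1.0 → 20.0
```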
cubic_interpolate_batched
Perform batched cubic (Catmull-Rom) interpolation.
Requires at least 4 time points (T >= 4).
Parameters:
- `x` (Tensor) – Input tensor `[T, B, ...]`.
- `x_time` (Tensor) – Time points `[B, T]`.
- `t` (Tensor) – Query times `[B]`.

Returns:
- `torch.Tensor` – Interpolated values `[B, ...]`.
Source code in spikeDE/odefunc.py
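The Catmull-Rom basis can be written down for a single uniform-spacing segment (a sketch of the underlying polynomial only; the batched implementation also locates segments and handles timestamps):

```python
def catmull_rom(p0, p1, p2, p3, u):
    """Uniform Catmull-Rom segment between p1 and p2, with u in [0, 1]."""
    # Standard Catmull-Rom blending of four consecutive samples.
    return 0.5 * (
        2 * p1
        + (p2 - p0) * u
        + (2 * p0 - 5 * p1 + 4 * p2 - p3) * u ** 2
        + (-p0 + 3 * p1 - 3 * p2 + p3) * u ** 3
    )

catmull_rom(0.0, 1.0, 2.0, 3.0, 0.5)  # linear data is reproduced exactly → 1.5
catmull_rom(0.0, 1.0, 4.0, 9.0, 0.0)  # u = 0 returns the left endpoint p1 → 1.0
```

Because the spline passes through `p1` and `p2` with slopes derived from the outer neighbors `p0` and `p3`, evaluating a query time inside any interior segment needs exactly four consecutive samples, which is why at least 4 time points are required.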
akima_interpolate_batched
Perform batched Akima interpolation.
Akima interpolation uses local slopes to reduce oscillations common in cubic splines. Requires at least 5 time points (T >= 5).
Parameters:
- `x` (Tensor) – Input tensor `[T, B, ...]`.
- `x_time` (Tensor) – Time points `[B, T]`.
- `t` (Tensor) – Query times `[B]`.

Returns:
- `torch.Tensor` – Interpolated values `[B, ...]`.
Source code in spikeDE/odefunc.py
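The classic Akima construction can be sketched unbatched in plain Python: per-segment slopes are extended by two artificial slopes on each side, derivatives at the data points are slope-weighted averages (which is what damps the oscillations), and each segment is then a cubic Hermite polynomial. The function name and the exact boundary extension are illustrative, not necessarily identical to the module's batched PyTorch version.

```python
import bisect

def akima_interp(xs, ys, t):
    """Unbatched Akima interpolation; requires at least 5 points."""
    n = len(xs)
    assert n >= 5 and len(ys) == n
    # Per-segment slopes m_0 .. m_{n-2}.
    m = [(ys[i + 1] - ys[i]) / (xs[i + 1] - xs[i]) for i in range(n - 1)]
    # Extend by two artificial slopes on each side (linear extrapolation).
    mm1 = 2 * m[0] - m[1]            # m_{-1}
    mm2 = 2 * mm1 - m[0]             # m_{-2}
    mp1 = 2 * m[-1] - m[-2]          # m_{n-1}
    mp2 = 2 * mp1 - m[-1]            # m_{n}
    M = [mm2, mm1] + m + [mp1, mp2]  # M[k] == m_{k-2}
    # Akima derivative at each data point i.
    d = []
    for i in range(n):
        m_im2, m_im1, m_i, m_ip1 = M[i], M[i + 1], M[i + 2], M[i + 3]
        w1, w2 = abs(m_ip1 - m_i), abs(m_im1 - m_im2)
        if w1 + w2 == 0:             # locally linear data: average the slopes
            d.append(0.5 * (m_im1 + m_i))
        else:
            d.append((w1 * m_im1 + w2 * m_i) / (w1 + w2))
    # Clamp queries outside the sampled range, then evaluate the Hermite cubic.
    if t <= xs[0]:
        return ys[0]
    if t >= xs[-1]:
        return ys[-1]
    i = bisect.bisect_right(xs, t) - 1
    h = xs[i + 1] - xs[i]
    u = (t - xs[i]) / h
    h00 = 2 * u**3 - 3 * u**2 + 1
    h10 = u**3 - 2 * u**2 + u
    h01 = -2 * u**3 + 3 * u**2
    h11 = u**3 - u**2
    return h00 * ys[i] + h10 * h * d[i] + h01 * ys[i + 1] + h11 * h * d[i + 1]
```

The two extra slopes on each side are what push the minimum sample count to 5: the derivative at each data point looks two segments backward and one segment forward.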
interpolate
Interpolate batched input data at arbitrary query time(s) t.
This function reconstructs continuous-time input signals from discrete samples to support numerical ODE solvers that evaluate the vector field at non-integer time steps.
Parameters:
- `x` (Tensor) – Input tensor of shape `[T, B, ...]`, where `T` is the number of time points and `B` is the batch size.
- `x_time` (Tensor) – Time points tensor. Can be:
  - `[T]`: Shared timestamps for all batches.
  - `[B, T]` or `[T, B]`: Batch-specific timestamps.
- `t` (float | Tensor) – Query time(s). Can be:
  - `float` or scalar tensor: Same time for all batches.
  - `[B]` tensor: A different time per batch.
- `method` (str, default: `'linear'`) – Interpolation algorithm. Options:
  - `'linear'`: Linear interpolation (default).
  - `'nearest'`: Nearest neighbor.
  - `'cubic'`: Catmull-Rom cubic spline (requires T >= 4).
  - `'akima'`: Akima spline, robust against oscillations (requires T >= 5).

Returns:
- `torch.Tensor` – Interpolated tensor of shape `[B, ...]`.

Raises:
- `ValueError` – If an unsupported interpolation method is specified.
- `AssertionError` – If input shapes are inconsistent.
Note
Values outside the time range [min(x_time), max(x_time)] are clamped
to the boundary values.
Source code in spikeDE/odefunc.py
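The dispatch, validation, and clamping behavior can be summarized in an unbatched Python sketch (the function name is illustrative and only the default linear branch is spelled out; the real function operates on batched tensors):

```python
import bisect

def interpolate_1d(x_time, x, t, method="linear"):
    """Unbatched sketch of the interpolate dispatch and boundary clamping."""
    if method not in ("linear", "nearest", "cubic", "akima"):
        raise ValueError(f"unsupported interpolation method: {method!r}")
    # Queries outside [min(x_time), max(x_time)] are clamped to boundary values.
    if t <= x_time[0]:
        return x[0]
    if t >= x_time[-1]:
        return x[-1]
    # Only the default linear branch is sketched here.
    i = bisect.bisect_right(x_time, t)
    w = (t - x_time[i - 1]) / (x_time[i] - x_time[i - 1])
    return (1 - w) * x[i - 1] + w * x[i]

interpolate_1d([0.0, 1.0, 2.0], [0.0, 10.0, 20.0], 0.5)   # → 5.0
interpolate_1d([0.0, 1.0, 2.0], [0.0, 10.0, 20.0], 99.0)  # clamped → 20.0
```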
remove_dead_code
Remove dead code from a traced FX graph.
Parameters:
- `m` (Module) – A traced GraphModule.

Returns:
- `nn.Module` – A new GraphModule with unused nodes eliminated.
Source code in spikeDE/odefunc.py
symbolic_trace_leaf_neurons
symbolic_trace_leaf_neurons(module: Module) -> GraphModule
Symbolically trace a module using the custom SNNLeafTracer.
Parameters:
- `module` (Module) – The PyTorch module to trace.

Returns:
- `fx.GraphModule` – The traced graph module with leaf neurons preserved.