Deep Learning Library

Overview

The Deep Learning library (src/DeepLearning) provides a framework for integrating PyTorch models into Neuralyzer for neural data analysis. It supports loading pre-trained models, encoding Neuralyzer data types into tensors, running inference, and decoding output tensors back into data types that can be stored in the DataManager.

The library is built on top of libtorch (the C++ frontend of PyTorch) and supports two inference backends:

  • AOT Inductor (.pt2) — the recommended backend, using ahead-of-time compiled native kernels
  • TorchScript (.pt) — legacy support via torch::jit::load()

Architecture

The library is organized into several subsystems, each handling a distinct part of the model-inference pipeline:

┌──────────────────────────────────────────────────────────────┐
│                    DeepLearning Library                       │
├──────────────┬──────────────┬──────────────┬─────────────────┤
│  Encoders    │  Decoders    │  Models      │  Infrastructure │
│              │              │              │                 │
│ ImageEncoder │ TensorTo     │ ModelBase    │ DeviceManager   │
│ Point2D      │   Point2D    │ RuntimeModel │ ModelRegistry   │
│ Mask2D       │ TensorTo     │ NeuroSAM     │ ModelExecution  │
│ Line2D       │   Mask2D     │              │ InferenceBackend│
│              │ TensorTo     │              │                 │
│              │   Line2D     │              │                 │
├──────────────┴──────────────┴──────────────┴─────────────────┤
│                         libtorch                              │
└──────────────────────────────────────────────────────────────┘

Data Flow

The inference pipeline follows a three-stage encode → forward → decode pattern:

Neuralyzer Data             Tensors                Tensors             Neuralyzer Data
(MediaData,                                                            (MaskData,
 PointData,      ──►      [B, C, H, W]    ──►    [B, C, H, W]    ──►    PointData,
 MaskData,     Encoders                  Model                Decoders  LineData)
 LineData)                             forward()
  1. Encoders convert Neuralyzer geometry types and images into tensor channels
  2. Model forward pass runs the neural network inference
  3. Decoders convert output tensor channels back into geometry types

Subsystem Details

Channel Encoders (channel_encoding/)

Encoders write geometry data into a pre-allocated [B, C, H, W] float tensor at a specified batch index and channel. They are pure functions on geometry primitives with no DataManager dependency, making them independently testable.
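The [B, C, H, W] layout implies a simple flat-index formula for where an encoder writes. A minimal sketch, assuming a contiguous row-major float buffer (the helper name is hypothetical, not part of the library):

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: flat offset of element (b, c, y, x) in a contiguous
// row-major [B, C, H, W] float buffer -- the layout encoders write into.
std::size_t bchw_offset(std::size_t b, std::size_t c, std::size_t y, std::size_t x,
                        std::size_t C, std::size_t H, std::size_t W) {
    return ((b * C + c) * H + y) * W + x;
}
```

Writing a point into batch index 1, channel 2 is then a single store at bchw_offset(1, 2, y, x, C, H, W).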

  • ImageEncoder — input: raw pixel data + ImageSize; modes: Raw. Resizes to (H, W) and normalizes to [0, 1]. Supports grayscale (replicated to 3 channels) and RGB.
  • Point2DEncoder — input: Point2D<float>; modes: Binary, Heatmap. Scales coordinates to tensor resolution. Heatmap places 2D Gaussians; overlapping peaks use the max operator.
  • Mask2DEncoder — input: Mask2D; modes: Binary. Scales pixel coordinates to tensor resolution and sets 1.0 at each mask pixel location.
  • Line2DEncoder — input: Line2D; modes: Binary, Heatmap. Rasterizes the polyline onto the tensor. Binary uses Bresenham line drawing; Heatmap uses a distance-to-line Gaussian.
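The Heatmap mode's max-combination can be sketched as follows, assuming a single pre-zeroed channel stored as a flat row-major buffer (the function is illustrative, not the library's actual encoder code):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Sketch of Gaussian heatmap stamping: write a unit-height 2D Gaussian
// centred at (cy, cx) into one H*W channel, combining overlapping peaks
// with max so that nearby points keep distinct unit-height peaks.
void stampGaussian(std::vector<float> & channel, int H, int W,
                   float cy, float cx, float sigma) {
    for (int y = 0; y < H; ++y) {
        for (int x = 0; x < W; ++x) {
            float dy = y - cy, dx = x - cx;
            float g = std::exp(-(dx * dx + dy * dy) / (2.0f * sigma * sigma));
            float & px = channel[y * W + x];
            px = std::max(px, g);  // max operator: peaks never sum past 1.0
        }
    }
}
```

With summation instead of max, two points one pixel apart would produce a single merged peak above 1.0; the max operator preserves both.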

All encoders are registered in EncoderFactory by string key (e.g. "ImageEncoder") so that JSON model specs and the widget UI can reference them by name.

Channel Decoders (channel_decoding/)

Decoders read from a specified batch index and channel of an output tensor, converting the activation pattern back into a geometry type. Coordinates are scaled back to the original image dimensions via target_image_size.

  • TensorToPoint2D — output: Point2D<float>. Argmax with optional parabolic subpixel refinement; also supports multi-peak detection via local maxima.
  • TensorToMask2D — output: Mask2D. Thresholding: collects all pixels above a configurable threshold.
  • TensorToLine2D — output: Line2D. Threshold → Zhang-Suen skeletonization (thin to 1 px) → endpoint detection → connectivity tracing.
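The argmax-with-parabolic-refinement strategy can be sketched as below, operating on a flat row-major channel. This is an illustration of the standard three-point parabola fit, not the library's actual TensorToPoint2D code:

```cpp
#include <vector>

struct Peak { float x, y; };  // illustrative stand-in for Point2D<float>

// Find the peak pixel, then fit a parabola through the peak and its two
// neighbours along each axis; the parabola vertex gives a subpixel offset
// in (-0.5, 0.5) via the standard formula 0.5*(l - r) / (l - 2c + r).
Peak refineArgmax(std::vector<float> const & ch, int H, int W) {
    int best = 0;
    for (int i = 1; i < H * W; ++i)
        if (ch[i] > ch[best]) best = i;
    int py = best / W, px = best % W;
    float sx = 0.0f, sy = 0.0f;
    if (px > 0 && px < W - 1) {
        float l = ch[py * W + px - 1], c = ch[best], r = ch[py * W + px + 1];
        float den = l - 2.0f * c + r;
        if (den != 0.0f) sx = 0.5f * (l - r) / den;
    }
    if (py > 0 && py < H - 1) {
        float u = ch[(py - 1) * W + px], c = ch[best], d = ch[(py + 1) * W + px];
        float den = u - 2.0f * c + d;
        if (den != 0.0f) sy = 0.5f * (u - d) / den;
    }
    return {px + sx, py + sy};
}
```

A symmetric peak refines to the pixel centre; an asymmetric one shifts toward the larger neighbour, recovering resolution lost when the heatmap was downsampled.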

All decoders are registered in DecoderFactory by string key.

Model Layer (models_v2/)

TensorSlotDescriptor

Each model declares its expected I/O via TensorSlotDescriptor structs. A slot describes one named tensor:

  • name — slot identifier (e.g. "encoder_image")
  • shape — tensor dimensions excluding batch (e.g. {3, 256, 256})
  • recommended_encoder / recommended_decoder — hints for the UI
  • is_static — if true, the user sets this input once (memory frames)
  • is_boolean_mask — if true, values are 0/1 flags
  • sequence_dim — optional axis index for frame sequences
  • dtype — expected tensor data type (default Float32)

ModelBase

The abstract base class for all model wrappers. Subclasses declare their slot metadata and implement forward():

class ModelBase {
public:
    virtual std::string modelId() const = 0;
    virtual std::string displayName() const = 0;
    virtual std::vector<TensorSlotDescriptor> inputSlots() const = 0;
    virtual std::vector<TensorSlotDescriptor> outputSlots() const = 0;
    virtual void loadWeights(std::filesystem::path const & path) = 0;
    virtual bool isReady() const = 0;
    virtual int preferredBatchSize() const { return 0; }
    virtual int maxBatchSize() const { return 0; }
    virtual std::unordered_map<std::string, torch::Tensor>
    forward(std::unordered_map<std::string, torch::Tensor> const & inputs) = 0;
};

The batch dimension is configurable — encoders write at a caller-specified batch_index, and forward() receives tensors with a leading batch dimension whose size the caller controls.

ModelExecution

A strategy-pattern dispatcher that delegates to one of the inference backends. It auto-detects the backend from file extension (.pt → TorchScript, .pt2 → AOT Inductor) or accepts an explicit BackendType.

Inference Backends (models_v2/backends/)

The InferenceBackend abstract interface defines load(), isLoaded(), and execute() methods that operate on torch::Tensor — the common currency across all backends.

  • TorchScriptBackend — format: .pt; API: torch::jit::load(). JIT-interpreted execution. Deprecated by PyTorch but fully functional; supports multiple named methods.
  • AOTInductorBackend — format: .pt2; API: AOTIModelPackageLoader. Ahead-of-time compiled native kernels; best performance. Recommended for new models.
  • ExecuTorchBackend — format: .pte; API: ExecuTorch runtime. Optional (behind the ENABLE_EXECUTORCH CMake flag). Designed for edge devices; not recommended for workstation use.

Model Registry (registry/)

A singleton that stores factory functions for creating model instances. Models self-register at static initialization time using the DL_REGISTER_MODEL macro:

// In MyModel.cpp
DL_REGISTER_MODEL(MyModel);  // registers with modelId() as the key

The widget UI queries the registry for available models and their metadata.
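Self-registration at static initialization is a common C++ pattern: the macro expands to a static variable whose initializer stores a factory in the singleton. A minimal sketch under simplified names (the real DL_REGISTER_MODEL keys on modelId() and the real ModelRegistry has a richer API; this version keys on the type name for brevity):

```cpp
#include <functional>
#include <map>
#include <memory>
#include <string>

struct Model {  // stand-in for ModelBase
    virtual ~Model() = default;
    virtual std::string modelId() const = 0;
};

struct Registry {  // stand-in for the ModelRegistry singleton
    static Registry & instance() { static Registry r; return r; }
    std::map<std::string, std::function<std::unique_ptr<Model>()>> factories;
};

// Expands to a static bool whose initializer runs before main(),
// inserting a factory for T into the registry.
#define DL_REGISTER_MODEL(T)                                        \
    static bool _registered_##T = [] {                              \
        Registry::instance().factories[#T] =                        \
            []() -> std::unique_ptr<Model> {                        \
                return std::make_unique<T>();                       \
            };                                                      \
        return true;                                                \
    }()

struct MyModel : Model {
    std::string modelId() const override { return "MyModel"; }
};
DL_REGISTER_MODEL(MyModel);  // runs at static initialization time
```

Because registration happens as a side effect of static initialization, merely linking a model's .cpp into the binary makes it appear in the registry; no central list needs editing.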

Runtime Models (runtime/)

For models that don’t need custom C++ logic, RuntimeModel is a generic ModelBase subclass driven entirely by a JSON specification (RuntimeModelSpec). This allows defining new models at runtime without recompilation:

{
  "model_id": "my_tracker",
  "display_name": "My Whisker Tracker",
  "weights_path": "model.pt2",
  "inputs": [
    { "name": "image", "shape": [3, 256, 256], "recommended_encoder": "ImageEncoder" }
  ],
  "outputs": [
    { "name": "heatmap", "shape": [1, 256, 256], "recommended_decoder": "TensorToMask2D" }
  ]
}

JSON parsing uses reflect-cpp (rfl::json::read/write) for automatic struct serialization.

Device Management (device/)

DeviceManager is a singleton that manages GPU/CPU device selection. It lazily detects CUDA at first use and provides:

  • device() — the active device (CUDA if available, else CPU)
  • toDevice(tensor) — moves a tensor to the active device
  • setDevice(device) — force a specific device (for testing or user preference)
  • cudaAvailable() — whether CUDA is available

All model wrappers and backends query DeviceManager instead of creating their own device objects.

Qt Integration (PIMPL Pattern)

The deep learning library uses torch::Tensor throughout, but Qt’s #define slots macro conflicts with libtorch’s slots() member function. The widget layer solves this with the PIMPL (Pointer to Implementation) pattern:

  • SlotAssembler — a bridge class with a clean public API (no torch headers in the .hpp). Its Impl struct (in the .cpp) holds all torch-dependent code.
  • DeepLearningBindingData — pure data structs with no Qt or torch dependencies, used to pass binding information between the UI and SlotAssembler. Includes:
    • SlotBindingData — dynamic (per-frame) input binding with a time_offset field for temporal shifting (e.g. -1 reads one frame behind)
    • StaticInputData — static (memory) input binding
    • OutputBindingData — output slot binding
    • computeEncodingFrame(current_frame, batch_index, time_offset, max_frame) — utility that combines frame, batch index, and temporal offset with clamping to [0, max_frame]

This ensures Qt headers and torch headers never appear in the same translation unit.
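The frame-selection utility can be sketched as below. How frame, batch index, and offset combine is an assumption here (current_frame + batch_index + time_offset); only the clamping to [0, max_frame] is stated above, and the real computeEncodingFrame may combine them differently:

```cpp
#include <algorithm>

// Sketch: choose which video frame feeds batch slot `batch_index`,
// shifted by `time_offset` (e.g. -1 reads one frame behind), then
// clamped into the valid range [0, max_frame].
int computeEncodingFrame(int current_frame, int batch_index,
                         int time_offset, int max_frame) {
    int frame = current_frame + batch_index + time_offset;
    return std::clamp(frame, 0, max_frame);
}
```

Clamping means a time_offset of -1 at frame 0 re-reads frame 0 rather than indexing out of range.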

Directory Structure

src/DeepLearning/
├── CMakeLists.txt
├── torch_helpers.hpp              # Legacy helper (device global)
├── channel_encoding/              # Encoders: data → tensor
│   ├── ChannelEncoder.hpp         # Abstract base
│   ├── ImageEncoder.hpp/.cpp
│   ├── Point2DEncoder.hpp/.cpp
│   ├── Mask2DEncoder.hpp/.cpp
│   ├── Line2DEncoder.hpp/.cpp
│   └── EncoderFactory.hpp/.cpp
├── channel_decoding/              # Decoders: tensor → data
│   ├── ChannelDecoder.hpp         # Abstract base
│   ├── TensorToPoint2D.hpp/.cpp
│   ├── TensorToMask2D.hpp/.cpp
│   ├── TensorToLine2D.hpp/.cpp
│   └── DecoderFactory.hpp/.cpp
├── device/                        # GPU/CPU management
│   └── DeviceManager.hpp/.cpp
├── models_v2/                     # Model abstraction
│   ├── ModelBase.hpp              # Abstract base
│   ├── ModelExecution.hpp/.cpp    # Strategy-pattern dispatcher
│   ├── TensorSlotDescriptor.hpp   # I/O slot metadata
│   ├── backends/                  # Inference backends
│   │   ├── InferenceBackend.hpp   # Abstract interface
│   │   ├── TorchScriptBackend.hpp/.cpp
│   │   ├── AOTInductorBackend.hpp/.cpp
│   │   └── ExecuTorchBackend.hpp/.cpp  (optional)
│   └── neurosam/                  # Concrete model
│       └── NeuroSAMModel.hpp/.cpp
├── registry/                      # Model registry
│   └── ModelRegistry.hpp/.cpp
└── runtime/                       # JSON-defined models
    ├── RuntimeModelSpec.hpp/.cpp
    └── RuntimeModel.hpp/.cpp

Testing

Tests mirror the source structure under tests/DeepLearning/. All unit tests use Catch2. Integration tests that require model weight files are tagged [integration].

Run all deep learning tests:

ctest --preset linux-clang-release --output-on-failure -R "DeepLearning"

See Also