How-To: Exporting an AOTI Encoder Model

Goal

You have a pre-trained image encoder (e.g. ConvNeXt, ViT, ResNet) and want to export it as an AOT Inductor .pt2 package so it can be loaded in Neuralyzer’s Deep Learning widget. This guide walks through the process using ConvNeXt-Tiny from torchvision as a concrete example.

The exported encoder takes an RGB image and produces a spatial feature map — for ConvNeXt-Tiny with 224×224 input, this is a [B, 768, 7, 7] tensor. Post-encoder modules in Neuralyzer can then reduce this to a per-frame feature vector.
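The 7×7 grid comes from ConvNeXt-Tiny's total stride of 32. There is a 4×4, stride-4 stem, then three 2×2, stride-2 downsampling convolutions between stages, and the blocks inside each stage keep the spatial size unchanged. The last stage has 768 channels. Here is a pure-Python sketch of that arithmetic, assuming torchvision's ConvNeXt-Tiny layout (the function names are illustrative):

```python
def conv_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Standard convolution output-size formula (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def convnext_tiny_feature_shape(batch: int, height: int, width: int) -> tuple[int, int, int, int]:
    """Output shape of ConvNeXt-Tiny's `features` for a [batch, 3, height, width] input."""
    # Stem: 4x4 conv, stride 4, no padding.
    h, w = conv_out(height, 4, 4), conv_out(width, 4, 4)
    # Three downsampling layers between the four stages: 2x2 conv, stride 2.
    for _ in range(3):
        h, w = conv_out(h, 2, 2), conv_out(w, 2, 2)
    # Blocks inside each stage preserve H and W; the last stage has 768 channels.
    return (batch, 768, h, w)


print(convnext_tiny_feature_shape(1, 224, 224))  # (1, 768, 7, 7)
print(convnext_tiny_feature_shape(4, 256, 256))  # (4, 768, 8, 8)
```

The same calculation gives you the output shape to put in the JSON spec if you export at another resolution, for example [768, 8, 8] at 256×256.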

Prerequisites

  • Python 3.10+
  • PyTorch 2.9+ (pip install torch torchvision)
  • timm, only for the timm-based examples (pip install timm)

Step 1: Define the Encoder Wrapper

ConvNeXt (and most torchvision/timm models) includes a classification head that we don’t need. We wrap the model to return only the encoder’s spatial feature map, before any pooling or classification.

import torch
import torch.nn as nn
import torchvision.models as models


class ConvNeXtEncoder(nn.Module):
    """
    Wraps ConvNeXt-Tiny to return the spatial feature map from the 
    final convolutional stage, before the classification head.

    Input:  [B, 3, 224, 224]   (RGB image, normalized)
    Output: [B, 768, 7, 7]     (spatial feature map)
    """

    def __init__(self):
        super().__init__()
        full_model = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT)
        # Keep only the feature extraction stages, discard avgpool + classifier
        self.features = full_model.features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.features(x)

Other Encoders

The same pattern works for any encoder backbone. For timm models:

import timm

class TimmEncoder(nn.Module):
    def __init__(self, model_name: str = "convnext_tiny"):
        super().__init__()
        self.backbone = timm.create_model(model_name, pretrained=True, 
                                          num_classes=0, global_pool="")
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(x)

For a ViT that outputs [B, 197, 384] (patch tokens + CLS), you may want to reshape the patch tokens to a spatial grid (e.g. [B, 384, 14, 14]) so that spatial post-encoder modules work correctly:

class ViTEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model("vit_small_patch16_224", pretrained=True,
                                          num_classes=0, global_pool="")
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.backbone(x)          # [B, 197, 384]
        patches = tokens[:, 1:, :]         # drop CLS token -> [B, 196, 384]
        B, N, C = patches.shape
        H = W = int(N ** 0.5)              # 14 x 14
        return patches.permute(0, 2, 1).reshape(B, C, H, W)  # [B, 384, 14, 14]
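You can check that the permute/reshape puts each patch back at its grid position by running the same transformation on a synthetic tensor, without timm. This sketch assumes timm's usual row-major patch ordering and a single CLS prefix token. tokens_to_grid is a hypothetical helper, not a timm function. timm ViTs report their prefix-token count (CLS plus any register tokens) as num_prefix_tokens.

```python
import torch


def tokens_to_grid(tokens: torch.Tensor, num_prefix_tokens: int = 1) -> torch.Tensor:
    """[B, prefix + H*W, C] -> [B, C, H, W], assuming a square, row-major patch grid."""
    patches = tokens[:, num_prefix_tokens:, :]        # drop CLS (and any register) tokens
    B, N, C = patches.shape
    H = W = int(N ** 0.5)
    assert H * W == N, f"{N} patch tokens do not form a square grid"
    return patches.permute(0, 2, 1).reshape(B, C, H, W)


tokens = torch.randn(2, 197, 384)                     # same layout as vit_small_patch16_224
grid = tokens_to_grid(tokens)
print(grid.shape)                                     # torch.Size([2, 384, 14, 14])

# Patch k (counting after the CLS token) should land at row k // 14, column k % 14.
k = 20
row, col = divmod(k, 14)
print(torch.equal(grid[:, :, row, col], tokens[:, 1 + k, :]))  # True
```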

Step 2: Verify Output Shape

Before exporting, verify the encoder produces the expected output:

model = ConvNeXtEncoder()
model.eval()

example_input = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    output = model(example_input)
    print(f"Input shape:  {example_input.shape}")   # [1, 3, 224, 224]
    print(f"Output shape: {output.shape}")           # [1, 768, 7, 7]
    assert output.shape == (1, 768, 7, 7)

Step 3: Export with AOT Inductor

import os

# Save alongside this script (__file__ is undefined in a notebook; use os.getcwd() there)
output_dir = os.path.dirname(os.path.abspath(__file__))

# Dynamic batch dimension — allows variable batch sizes at inference time
batch_dim = torch.export.Dim("batch", min=1, max=32)
dynamic_shapes = {"x": {0: batch_dim}}

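# torch.export specializes any dimension whose example size is 0 or 1, which
# would conflict with batch_dim, so export with a batch of at least 2.
export_input = torch.randn(2, 3, 224, 224)

with torch.no_grad():
    exported = torch.export.export(
        model, (export_input,), dynamic_shapes=dynamic_shapes
    )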

    pt2_path = torch._inductor.aoti_compile_and_package(
        exported,
        package_path=os.path.join(output_dir, "convnext_encoder.pt2"),
    )

print(f"Exported to: {pt2_path}")

Dynamic Shapes

The dynamic_shapes argument marks the batch dimension as dynamic, meaning the compiled model accepts any batch size between 1 and 32 at runtime. The spatial dimensions (224×224) remain fixed — the model was compiled for that resolution.
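A clip with more frames than the compiled maximum has to be split into batches of at most 32. This guide does not describe how Neuralyzer batches frames internally. The pure-Python sketch below (hypothetical helper name) only shows how to cover N frames with batches that stay inside the exported range of 1 to 32:

```python
def batch_ranges(num_frames: int, max_batch: int = 32) -> list[tuple[int, int]]:
    """Split frame indices [0, num_frames) into (start, stop) chunks of at most max_batch."""
    # With min=1 in the exported range, a shorter final chunk is always valid.
    return [(start, min(start + max_batch, num_frames))
            for start in range(0, num_frames, max_batch)]


chunks = batch_ranges(70)
print(chunks)                                # [(0, 32), (32, 64), (64, 70)]
print([stop - start for start, stop in chunks])  # [32, 32, 6]
```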

If you need multiple input resolutions, export a separate .pt2 package for each resolution and use the weights_variants field in the JSON spec.

Step 4: Create the JSON Model Spec

Create convnext_encoder.json alongside the .pt2 file:

{
  "model_id": "convnext_encoder",
  "display_name": "ConvNeXt-Tiny Encoder",
  "description": "ConvNeXt-Tiny backbone producing spatial features (no classification head)",
  "weights_path": "convnext_encoder.pt2",
  "batch_mode": {
    "dynamic": { "min": 1, "max": 32 }
  },
  "inputs": [
    {
      "name": "image",
      "shape": [3, 224, 224],
      "description": "Input RGB image, normalized to ImageNet statistics",
      "recommended_encoder": "ImageEncoder"
    }
  ],
  "outputs": [
    {
      "name": "features",
      "shape": [768, 7, 7],
      "description": "Spatial feature map from the last ConvNeXt stage"
    }
  ]
}

The weights_path is relative to this JSON file’s directory — keep both files together.
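weights_path is resolved against the JSON file's directory, so a quick pre-flight check can catch a misplaced file before you open Neuralyzer. The sketch below is hypothetical. It checks only the fields used in this guide and is not Neuralyzer's own loader or schema validator.

```python
import json
import os
import tempfile


def check_model_spec(json_path: str) -> str:
    """Sanity-check a model spec and return the absolute path of its weights file."""
    with open(json_path) as f:
        spec = json.load(f)

    # weights_path is interpreted relative to the directory containing the JSON file.
    spec_dir = os.path.dirname(os.path.abspath(json_path))
    weights = os.path.join(spec_dir, spec["weights_path"])
    if not os.path.isfile(weights):
        raise FileNotFoundError(f"weights_path does not exist: {weights}")

    dyn = spec.get("batch_mode", {}).get("dynamic")
    if dyn is not None and not 1 <= dyn["min"] <= dyn["max"]:
        raise ValueError(f"invalid dynamic batch range: {dyn}")

    for slot in spec["inputs"] + spec["outputs"]:
        if not slot.get("name") or not slot.get("shape"):
            raise ValueError(f"slot missing name or shape: {slot}")
    return weights


# Demo in a temporary directory with an empty placeholder .pt2 file.
with tempfile.TemporaryDirectory() as d:
    open(os.path.join(d, "convnext_encoder.pt2"), "wb").close()
    demo = {"weights_path": "convnext_encoder.pt2",
            "batch_mode": {"dynamic": {"min": 1, "max": 32}},
            "inputs": [{"name": "image", "shape": [3, 224, 224]}],
            "outputs": [{"name": "features", "shape": [768, 7, 7]}]}
    with open(os.path.join(d, "convnext_encoder.json"), "w") as f:
        json.dump(demo, f)
    print(check_model_spec(os.path.join(d, "convnext_encoder.json")))
```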

Step 5: (Optional) TorchScript Fallback

If AOT Inductor export fails for your encoder, fall back to TorchScript:

with torch.no_grad():
    traced = torch.jit.trace(model, example_input)
    traced.save(os.path.join(output_dir, "convnext_encoder.pt"))

Then set weights_path in the JSON spec to "convnext_encoder.pt". The backend detects the model format from the file extension.
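Below is a hypothetical sketch of extension-based dispatch. It covers only the two extensions this guide produces (.pt2 for AOT Inductor, .pt for TorchScript). Neuralyzer's actual detection logic may differ and may accept other formats.

```python
from pathlib import Path

# Hypothetical mapping, based only on the two formats produced in this guide.
BACKENDS = {".pt2": "aot_inductor", ".pt": "torchscript"}


def detect_backend(weights_path: str) -> str:
    """Pick a backend name from the model file's extension."""
    suffix = Path(weights_path).suffix.lower()
    try:
        return BACKENDS[suffix]
    except KeyError:
        raise ValueError(f"unrecognized model file extension: {suffix!r}") from None


print(detect_backend("convnext_encoder.pt2"))  # aot_inductor
print(detect_backend("convnext_encoder.pt"))   # torchscript
```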

Step 6: Load in Neuralyzer

  1. Select View → Tools → Deep Learning to open the widget.
  2. Load the convnext_encoder.json model spec from the model selector.
  3. The widget will show one input slot ("image") and one output slot ("features").
  4. Bind "image" to your video media data.
  5. Configure a post-encoder module if desired (e.g. Global Average Pooling to reduce [768, 7, 7] to [768]).
  6. Bind the output to a TensorData key.
  7. Run inference.
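Global average pooling, as in step 5, averages each channel over the spatial grid. The PyTorch sketch below shows the equivalent tensor operation so you can reason about the shapes. It is not Neuralyzer's post-encoder implementation.

```python
import torch

features = torch.randn(4, 768, 7, 7)           # [B, 768, 7, 7] encoder output for 4 frames
pooled = features.mean(dim=(-2, -1))           # average over H and W -> [B, 768]
print(pooled.shape)                            # torch.Size([4, 768])

# Same result as adaptive average pooling to 1x1 followed by flattening.
reference = torch.nn.functional.adaptive_avg_pool2d(features, 1).flatten(1)
print(torch.allclose(pooled, reference, atol=1e-5))  # True
```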

Complete Export Script

#!/usr/bin/env python3
"""
Export a ConvNeXt-Tiny encoder for Neuralyzer.
Produces both AOT Inductor (.pt2) and TorchScript (.pt) formats,
plus the JSON model spec.
"""

import json
import os

import torch
import torch.nn as nn
import torchvision.models as models


class ConvNeXtEncoder(nn.Module):
    """ConvNeXt-Tiny feature extractor (no classification head)."""

    def __init__(self):
        super().__init__()
        full_model = models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT)
        self.features = full_model.features

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.features(x)


def main():
    output_dir = os.path.dirname(os.path.abspath(__file__))
    model = ConvNeXtEncoder()
    model.eval()

    example_input = torch.randn(1, 3, 224, 224)

    # Verify
    with torch.no_grad():
        out = model(example_input)
        print(f"Output shape: {out.shape}")  # [1, 768, 7, 7]
        assert out.shape == (1, 768, 7, 7)

    # --- AOT Inductor export (with dynamic batch) ---
    print("Exporting with AOT Inductor...")
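    batch_dim = torch.export.Dim("batch", min=1, max=32)
    dynamic_shapes = {"x": {0: batch_dim}}
    # torch.export specializes example dimensions of size 0 or 1, which would
    # conflict with batch_dim, so export with a batch of at least 2.
    export_input = torch.randn(2, 3, 224, 224)

    with torch.no_grad():
        exported = torch.export.export(
            model, (export_input,), dynamic_shapes=dynamic_shapes
        )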
        pt2_path = torch._inductor.aoti_compile_and_package(
            exported,
            package_path=os.path.join(output_dir, "convnext_encoder.pt2"),
        )
    print(f"  -> {pt2_path}")

    # --- TorchScript fallback ---
    print("Exporting with TorchScript...")
    with torch.no_grad():
        traced = torch.jit.trace(model, example_input)
        pt_path = os.path.join(output_dir, "convnext_encoder.pt")
        traced.save(pt_path)
    print(f"  -> {pt_path}")

    # --- JSON model spec ---
    spec = {
        "model_id": "convnext_encoder",
        "display_name": "ConvNeXt-Tiny Encoder",
        "description": "ConvNeXt-Tiny backbone producing spatial features",
        "weights_path": "convnext_encoder.pt2",
        "batch_mode": {"dynamic": {"min": 1, "max": 32}},
        "inputs": [
            {
                "name": "image",
                "shape": [3, 224, 224],
                "description": "Input RGB image",
                "recommended_encoder": "ImageEncoder",
            }
        ],
        "outputs": [
            {
                "name": "features",
                "shape": [768, 7, 7],
                "description": "Spatial feature map from last ConvNeXt stage",
            }
        ],
    }

    json_path = os.path.join(output_dir, "convnext_encoder.json")
    with open(json_path, "w") as f:
        json.dump(spec, f, indent=2)
    print(f"  -> {json_path}")

    print("\nDone! Files ready for Neuralyzer:")
    print(f"  Model (AOT Inductor): {pt2_path}")
    print(f"  Model (TorchScript):  {pt_path}")
    print(f"  JSON spec:            {json_path}")


if __name__ == "__main__":
    main()

Troubleshooting

torch.export.export() fails with ConvNeXt

Some ConvNeXt operations (e.g. LayerNorm with unusual shapes) may not trace cleanly. Try:

  • Update to the latest PyTorch — AOT Inductor coverage improves with each release.
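  • If the error says the batch dimension was specialized to a constant, export with an example input whose batch size is at least 2. torch.export specializes dimensions whose example size is 0 or 1.
  • Use the TorchScript fallback.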
  • If using timm, ensure exportable=True is passed to create_model().

Wrong output shape

If your encoder returns a different shape than expected:

  • Check whether the model includes a global pooling layer. Many torchvision models have avgpool before the classifier — make sure you’re stopping before it.
  • For timm models, use global_pool="" and num_classes=0 to strip pooling and the classification head.
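A quick diagnostic is the rank of the output tensor. 4-D means a spatial map, 3-D usually means ViT-style tokens, and 2-D means pooling (or a head) is still attached. The helper below is hypothetical and is demonstrated on small stand-in modules rather than real backbones:

```python
import torch
import torch.nn as nn


def describe_output(model: nn.Module, example_input: torch.Tensor) -> str:
    """Classify an encoder's output by tensor rank."""
    model.eval()
    with torch.no_grad():
        out = model(example_input)
    kinds = {4: "spatial [B, C, H, W]", 3: "tokens [B, N, C]", 2: "pooled [B, C]"}
    return f"{tuple(out.shape)}: {kinds.get(out.dim(), 'unexpected rank')}"


x = torch.randn(1, 3, 32, 32)
spatial = nn.Conv2d(3, 8, kernel_size=3, padding=1)
pooled = nn.Sequential(spatial, nn.AdaptiveAvgPool2d(1), nn.Flatten())
print(describe_output(spatial, x))  # (1, 8, 32, 32): spatial [B, C, H, W]
print(describe_output(pooled, x))   # (1, 8): pooled [B, C]
```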

CUDA vs CPU

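AOT Inductor compiles for the device that the model's parameters and the example inputs are on at export time, not simply the hardware that happens to be present. The scripts in this guide never move the model off the CPU, so they produce a CPU package even on a GPU machine. To target CUDA, move both the model and the example input to the GPU (e.g. model.to("cuda") and example_input.to("cuda")) before calling torch.export.export(). The resulting .pt2 contains CUDA kernels and needs a compatible GPU at load time.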
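Here is a minimal sketch of putting the model and the example input on the same device before export. prepare_for_export is a hypothetical helper. The device string should match where you plan to load the package, and a small Conv2d stands in for the encoder.

```python
import torch
import torch.nn as nn


def prepare_for_export(model: nn.Module, example_input: torch.Tensor, device: str):
    """Move the model and example input to the deployment device."""
    model = model.to(device).eval()
    example_input = example_input.to(device)
    # Every parameter and buffer should now live on the target device type.
    target = torch.device(device).type
    for t in list(model.parameters()) + list(model.buffers()):
        assert t.device.type == target, f"tensor left on {t.device}"
    return model, example_input


device = "cuda" if torch.cuda.is_available() else "cpu"
model, example_input = prepare_for_export(nn.Conv2d(3, 8, 3), torch.randn(2, 3, 224, 224), device)
print(device, example_input.device.type)
```

Pass the returned model and example_input to torch.export.export() as in Step 3. The package then targets whichever device was selected.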