Debugging JAX & Flax NNX (Part 2)


Debugging compiled machine learning code can feel like a different sport from the usual Python workflow. JAX delivers excellent performance and Flax NNX provides a clean model API, but JAX's just-in-time compilation and transformation stack mean the usual Python debugging tricks sometimes fail. I’ve been working with these tools and following the practical guidance Robert Crowe shares, and I want to lay out a clear, usable toolkit you can adopt today.

Below I cover the most effective, practical techniques for hunting down errors and inspecting models built with JAX and Flax NNX. I explain when to fall back to standard Python debugging, how to use JAX’s built-in NaN checker, and how Flax NNX’s inspection helpers, nnx.display and Module.sow, let you peek into intermediate state without messy plumbing. Expect clear rules of thumb, tradeoffs, and step-by-step suggestions you can apply to your own projects.


Why debugging JAX feels different 🛠️

JAX’s biggest strength is also the thing that changes how you debug: just-in-time compilation and functional transforms such as jax.jit, jax.vmap, and jax.lax.scan. When you wrap a function in jax.jit, JAX first traces it with abstract placeholder values (tracers), compiles the traced computation with XLA, and then runs the compiled program. That is fantastic for speed, but it means normal runtime introspection (prints, pdb.set_trace, IDE breakpoints) often won’t behave the way you expect. A print inside a jitted function, for example, runs only while JAX traces it and shows a tracer rather than numbers.
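Here is a minimal sketch of what that looks like in practice: the Python body of a jitted function runs only during tracing, and what it sees are tracers, not arrays of numbers. The function name scale and the trace_log list are illustrative.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records what the Python body sees each time it actually runs


@jax.jit
def scale(x):
    # These Python lines execute only while JAX traces the function,
    # not every time the compiled kernel runs.
    trace_log.append(type(x).__name__)
    print("inside scale, x =", x)  # prints a tracer, not concrete numbers
    return 2.0 * x


scale(jnp.ones(3))  # first call: trace + compile, so the print fires
scale(jnp.ones(3))  # same shape and dtype: cached executable, no print
scale(jnp.ones(4))  # new shape: JAX re-traces, so the print fires again
print(trace_log)    # two entries, each a tracer type name
```

Three calls produce only two prints, which is exactly why print-based debugging inside jit can be misleading.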

Think of it like this: the Python you write is a blueprint. JAX compiles that blueprint into efficient kernels, possibly fusing operations and moving computation to accelerators. The compiled program no longer runs statement by statement in Python, so the debugging approach has to adapt. I treat debugging in three tiers:

  • Eager tools: use Python prints, PDB, and IDE debuggers by running code eagerly (no JIT).
  • JAX-native inspection: use jax.debug_nans and related configuration flags that check values produced inside compiled code.
  • Model-level introspection: use Flax NNX helpers like nnx.display and Module.sow to inspect parameters, state, and intermediate activations.

Each tier has a place. I’ll walk through how and when to use each one, and how to combine them into a workflow that helps you find bugs quickly without accidentally shipping slower code.

jax.disable_jit: the escape hatch 🐢

If you want the full power of classic Python debugging (standard print(), pdb.set_trace(), or stepping through code in the VS Code debugger), JAX provides a reliable escape hatch: jax.disable_jit. When you disable JIT, JAX stops compiling your functions and runs them eagerly, operation by operation. Your code then behaves like normal Python, so breakpoints and prints work as you expect.

There are three ways to turn off JIT:

  • Wrap code in a context manager: with jax.disable_jit(): ...
  • Set the global configuration flag with jax.config.update("jax_disable_jit", True) (less common for interactive debugging).
  • Use the environment variable JAX_DISABLE_JIT=1 before launching your process.

Use the context manager when you want to temporarily run a small portion of code eagerly and keep the rest of your system unchanged. For example, wrap your model inference or a suspect function in the with block and run tests there. When debugging locally in a notebook or script, the context manager is usually the simplest and least invasive approach.
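A small sketch of the context-manager approach, using a hypothetical safe_log function that branches on array values. That kind of Python control flow fails under jax.jit (the argument is a tracer) but runs fine once JIT is disabled, which is also where print(), breakpoint(), and pdb.set_trace() become usable.

```python
import jax
import jax.numpy as jnp


@jax.jit
def safe_log(x):
    # Python control flow on array *values*: under jit, x is a tracer and
    # this `if` raises an error; with JIT disabled, x is a concrete array.
    if jnp.any(x <= 0):
        print("non-positive input:", x)  # breakpoint() or pdb.set_trace() also work here
    return jnp.log(jnp.maximum(x, 1e-6))


x = jnp.array([1.0, 0.0, 2.0])

with jax.disable_jit():
    y = safe_log(x)  # runs eagerly, statement by statement

# Outside the block, the same call fails: jit cannot branch on a traced value.
try:
    safe_log(x)
    jit_failed = False
except jax.errors.ConcretizationTypeError:
    jit_failed = True

# The global switch is jax.config.update("jax_disable_jit", True);
# remember to set it back to False when you are done.
```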

Important practical notes:

  • Performance: disabling JIT removes all compilation benefits, so expect much slower runs. Only use it temporarily for debugging.
  • Transformation depth: even under jax.disable_jit, higher-order transformations like jax.vmap or jax.lax.scan may still hide their inner workings from a step debugger (jax.vmap, for instance, still passes batched tracers into your function rather than ordinary per-example arrays). The context manager lets you inspect the outer Python structure, but you may need different strategies to probe deeply inside those transforms.
  • IDE compatibility: once code is running eagerly, IDE debuggers like VS Code and PyCharm work. You can set breakpoints, inspect variables, step line-by-line.

I use jax.disable_jit to reproduce tricky bugs in a familiar environment. When an error is nondeterministic or tied to control flow, stepping through the Python version often reveals where shapes, dtypes, or logic diverge from expectations. After I fix the issue, I run the JIT version again to ensure performance and numerical behavior remain correct.

How I use jax.disable_jit in practice

My typical approach is:

  1. Reproduce the problem in a minimal snippet.
  2. Wrap the suspect function or the forward pass in with jax.disable_jit(): and run it.
  3. Use print() and pdb.set_trace() to inspect shapes, values, and control flow.
  4. Apply a fix and re-enable JIT to verify behavior under compilation.
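Steps 2 through 4 can be folded into a small harness. The sketch below uses a hypothetical helper, check_eager_vs_jit, which runs a function once eagerly and once compiled and compares the results. layer_norm is just an example function under test, and the tolerances are arbitrary choices.

```python
import jax
import jax.numpy as jnp
import numpy as np


def check_eager_vs_jit(fn, *args, rtol=1e-5, atol=1e-6):
    """Hypothetical helper: run fn eagerly and compiled, then compare outputs."""
    with jax.disable_jit():
        eager_out = fn(*args)          # debuggable eager run
    jit_out = jax.jit(fn)(*args)       # same code under compilation
    # Compare every array leaf, so nested outputs (tuples, dicts) also work.
    jax.tree_util.tree_map(
        lambda a, b: np.testing.assert_allclose(a, b, rtol=rtol, atol=atol),
        eager_out,
        jit_out,
    )
    return jit_out


def layer_norm(x, eps=1e-5):
    # Example function under test: normalize each row to zero mean, unit variance.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


out = check_eager_vs_jit(layer_norm, jnp.arange(6.0).reshape(2, 3))
```

Small floating-point differences between eager and compiled runs are normal, which is why the comparison uses tolerances rather than exact equality.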

This method finds a surprising number of bugs quickly. If the error only appears under JIT and not eager execution, proceed to the next tool: jax.debug_nans.

Finding NaNs with jax.debug_nans 🔎

NaNs and infinities are among the most common and painful failures in numerical code. Tracking down the exact operation that produced a NaN in a JIT-compiled function can feel impossible if you only use prints. JAX provides a specialized flag for this case, jax.debug_nans, which you enable with jax.config.update("jax_debug_nans", True) or by setting the JAX_DEBUG_NANS=True environment variable.

When jax.debug_nans is enabled, JAX checks the outputs of JIT-compiled computations for NaNs. If it detects one, JAX automatically re-runs the offending function eagerly, operation by operation, until it finds the primitive that produced the bad value. It then raises a FloatingPointError that points at that primitive. Infinities are not covered by this flag; JAX has a separate jax_debug_infs flag for them.

Robert Crowe: "When JAX Debug NaNs is active, JAX watches the outputs of operations inside JIT. If it spots a NaN, it automatically reruns that part of the code eagerly, step by step, until it finds the exact primitive operation that produced the NaN, and then raises an error pointing right at it."

This behavior is incredibly helpful because it eliminates a lot of manual trial and error. Instead of inserting prints around every suspect operation, you enable the flag and let JAX tell you where the NaN originated.
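A minimal sketch of turning the flag on around a suspect call. The function unstable is hypothetical: it takes the square root of a negative input, which yields NaN inside compiled code. The try/finally makes sure the flag is switched off again afterwards.

```python
import jax
import jax.numpy as jnp


@jax.jit
def unstable(x):
    # sqrt of a negative number produces NaN inside the compiled function.
    return jnp.sqrt(x) * 2.0


nan_error = None
jax.config.update("jax_debug_nans", True)  # or launch with JAX_DEBUG_NANS=True
try:
    unstable(jnp.array([-1.0, 4.0]))
except FloatingPointError as err:
    nan_error = err  # the message and traceback should point at the offending op
    print(err)
finally:
    jax.config.update("jax_debug_nans", False)  # don't leave it on

# With the flag off, the NaN passes through silently.
silent = unstable(jnp.array([-1.0, 4.0]))
```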

Key tradeoffs and usage notes:

  • Overhead: enabling jax.debug_nans adds runtime overhead because JAX must monitor intermediate outputs and may re-run parts eagerly. Use it for debugging sessions, not in production training loops.
  • Scope: the flag is most useful when a NaN appears inside a jax.jit region. If the NaN occurs outside JIT, normal Python debugging will show it.
  • Combine with logging: once you find the primitive, it’s often useful to re-run with jax.disable_jit or to capture values with sow (discussed below) so you can see local context and intermediates around the failing operation.

In my experience, jax.debug_nans is the quickest route to the root cause for numerical explosions. Common causes include accidental divisions by zero, invalid log inputs, unstable normalization steps, or poor initialization. Once you have the primitive pinpointed, typical fixes are clamping inputs, adjusting epsilon values, improving numerical stability, or fixing shape/broadcasting mistakes.
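As an illustration of the "adjust epsilon" fix, here is a sketch of L2 normalization on an all-zero vector, a common source of 0/0. The function names and the eps value of 1e-12 are illustrative choices, not a recommendation for every model.

```python
import jax
import jax.numpy as jnp


def l2_normalize_naive(x):
    # 0 / 0 -> NaN when x is all zeros (e.g. a dead feature vector)
    return x / jnp.sqrt(jnp.sum(x ** 2))


def l2_normalize_stable(x, eps=1e-12):
    # A small epsilon keeps the denominator strictly positive
    return x / jnp.sqrt(jnp.sum(x ** 2) + eps)


zeros = jnp.zeros(3)
naive = l2_normalize_naive(zeros)    # all NaN
stable = l2_normalize_stable(zeros)  # all 0.0
grad = jax.grad(lambda v: l2_normalize_stable(v).sum())(zeros)  # finite
```

Checking the gradient as well as the forward value matters: a fix that only patches the forward pass can still leave NaNs in the backward pass.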

Inspecting model structure with nnx.display 🧭

Flax NNX encourages clear separation between model definition and parameters/state, and it provides great introspection tools. nnx.display is a simple but powerful helper for quickly visualizing a model's structure.

Call nnx.display on modules, layers, or optimizer objects to get a snapshot of their internal structure. The output includes names, types, shapes, and current values of parameters and any state variables. In a notebook, nnx.display gives you an interactive tree you can expand to explore the model. Outside a notebook, it prints a readable representation you can scan to verify architecture and variable shapes.

Why NNX.display matters:

  • Quick sanity checks: confirm that layers are present, parameters are initialized with expected shapes, and batch stats or optimizer states exist where you expect them.
  • PyTree visibility: NNX module state is organized as a tree that JAX transformations consume as a pytree. nnx.display surfaces that tree in a friendly form, making it easier to reason about how your parameters and state will be passed into transformations like jax.jit or jax.grad.
  • Debugging mismatches: many runtime errors stem from incorrect parameter shapes or unexpected None values. nnx.display helps catch these before you hit compiled failures.

When I’m building or refactoring a model, I run nnx.display right after initialization and again after a forward pass. That way I catch missing variables, shape mismatches, or incorrect defaults early.

Capturing intermediates with Module.sow 🧩

One recurring pattern in debugging neural networks is wanting to inspect intermediate activations inside a model. Passing dictionaries around or returning extra debug-specific values from every layer quickly becomes messy. Flax NNX solves this with a built-in method: sow.

sow is a method on nnx.Module that lets you capture and store values computed during the forward pass. Think of it as planting a named marker inside the module where you want to save a value (activations, pre-activations, statistics, etc.). Later, you can retrieve those stored values from the module instance instead of threading them through your call stack.

How sow works

The method takes these arguments:

  • variable_type: an NNX Variable type that categorizes the stored value (nnx.Intermediate is the usual choice). This is useful for filtering and for integrating with NNX state management utilities later.
  • name: a string used as the attribute name on the module instance where the value will be stored.
  • value: the actual object to store, typically a JAX array.
  • optional init_fn and reduce_fn: these control how repeated calls to sow with the same name are handled (more on this below).

When you call self.sow(nnx.Intermediate, "dense_output", x) inside a module’s __call__ method, NNX creates an attribute on the module instance named dense_output. That attribute’s .value property holds the stored values. By default, every call to sow with the same name appends the new value to a tuple stored in .value, so even a single call yields a one-element tuple. Because the attribute lives on the instance, this also applies across forward passes: calling the model again appends again unless you clear the stored values in between.

Example behavior and important caveats:

  • Default append semantics: multiple calls produce a tuple of values. This is convenient for saving outputs across iterations, but can grow large if used inside loops or every batch.
  • Memory cost: accumulating many intermediate tensors can exhaust RAM or device memory. Be cautious about sowing values in inner loops, with large batches, or over long sequences.
  • Reduce and init: if you don’t want unbounded tuples, pass a reduce_fn that combines repeated values (sum, max, or concatenation along a controlled axis). The init_fn supplies the starting aggregate that the first value is reduced into.
  • Retrieval: access stored values via model_instance.dense_output.value. Because sown values share a variable type, you can also use nnx.split, nnx.state, or nnx.pop to operate on that subset of the module state.

I found NNX.so particularly helpful for:

  • Visualizing activations from several layers to understand representational bottlenecks.
  • Implementing custom losses that depend on intermediate features without changing function signatures throughout the model.
  • Capturing layer outputs during a failing forward pass so you can inspect numeric ranges, sparsity, or unexpected broadcasting.

Because sow attaches values directly to the module instance, a good practice is to call nnx.pop or otherwise clear sown values after inspection so debug artifacts don’t carry over into training runs.

Using sow responsibly

Some pragmatic rules I follow:

  • Only sow the minimum tensors needed to diagnose the issue.
  • Prefer reduce functions that keep memory bounded when sowing inside loops or scans.
  • After debugging, remove sow calls or guard them behind a debug flag so they don’t affect training performance or memory.
  • Combine sow with nnx.display to get both structural and dynamic views of the module state.

Putting it together: a pragmatic debugging workflow 🔧

Here’s a reproducible workflow I use when a model run fails or shows unexpected behavior. The workflow combines the tools above in a practical order that minimizes time spent fumbling around.

  1. Reproduce with a minimal example. Reduce the failing code to the smallest snippet that still demonstrates the problem. This includes fixed random seeds and minimal input shapes.
  2. Inspect the module. Use nnx.display immediately after initialization and after a forward pass to verify parameter shapes and state variables.
  3. Try eager execution. Wrap the suspect function in with jax.disable_jit(): to see if the error reproduces in eager mode. Use print statements and pdb.set_trace as needed.
  4. If the failure only occurs with JIT, enable jax.debug_nans. This will often report the exact primitive that produced a NaN or infinity.
  5. Use sow to capture relevant intermediates. Sow outputs around the layers that the jax.debug_nans error points to, or around areas you want to inspect numerically.
  6. Run a controlled batch. With debug instrumentation in place, run a single forward/backward pass. Inspect the sown values and nnx.display output to confirm suspicions (shape mismatches, large magnitudes, unexpected zeros).
  7. Apply a fix and re-enable JIT. After making the change, run the compiled version again to ensure the fix works with optimizations.
  8. Profile and monitor. Once fixed, use the JAX profiler (or its TensorBoard integration) to make sure you didn’t regress performance or introduce memory issues.

These steps are deliberately iterative. It’s common to go from eager debugging to JAX-native flags and back to reduce the search space quickly.

Advanced tips and limitations ⚠️

There are a few recurring gotchas and advanced behaviors to watch for when debugging JAX and Flax NNX.

Transforms hide control flow

jax.vmap, jax.lax.scan, and other transforms can hide internal behavior. Disabling JIT doesn’t always let you step through every transformed inner operation. If you suspect an issue inside a vmapped or scanned region, try:

  • Running the inner function directly on a single example without vmap/scan applied.
  • Using sow inside the inner function to capture per-iteration intermediates, then splitting or popping them for inspection.
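A sketch of the first tactic: keep the per-example function separate from its vmapped wrapper so you can call it directly on single examples. The loss function and data here are made up for illustration.

```python
import jax
import jax.numpy as jnp


def per_example_loss(w, x, y):
    # Inner function written for ONE example: x has shape (features,)
    pred = jnp.dot(x, w)
    return (pred - y) ** 2


batched_loss = jax.vmap(per_example_loss, in_axes=(None, 0, 0))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
ys = jnp.array([0.0, 1.0])

vmapped = batched_loss(w, xs, ys)

# Debug path: no vmap, one example at a time, fully eager, so prints and
# breakpoints inside per_example_loss see concrete per-example values.
with jax.disable_jit():
    manual = jnp.stack([per_example_loss(w, xs[i], ys[i]) for i in range(xs.shape[0])])

print(vmapped, manual)  # both should be [1. 4.]
```

If the two results disagree, the bug is in how the batch axes are mapped (in_axes, broadcasting) rather than in the per-example math.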

PRNG and statefulness

JAX’s PRNG design (functional PRNGKeys) means that random behavior depends explicitly on keys you pass. If nondeterministic behavior appears while debugging, ensure keys are used consistently and that you aren’t accidentally reusing keys across calls. When you go to eager mode, make sure the same key behavior is preserved to reproduce bugs reliably.
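A short sketch of explicit key handling: reusing a key reproduces the same draws exactly, which is what makes a failing run reproducible. Splitting produces a fresh key when you actually want new randomness.

```python
import jax

key = jax.random.PRNGKey(0)

a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))  # same key reused: identical "random" numbers

key, subkey = jax.random.split(key)  # split to get a fresh, independent key
c = jax.random.normal(subkey, (3,))
```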

Device placement and memory

Swapping between CPU and accelerator when debugging can change memory behavior. If you sow many large tensors, you may exceed device memory. Consider copying intermediates to the host as needed or reducing batch sizes during debugging.
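A sketch of copying values to host memory with jax.device_get, which accepts a whole pytree and returns NumPy arrays. The tuple below stands in for sown intermediates, for example the result of nnx.pop.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Stand-in for a tuple of sown intermediates living on the accelerator.
intermediates = (jnp.ones((2, 3)), jnp.zeros((2, 3)))

# Copy to host memory as NumPy arrays; the device buffers can be freed
# once nothing else references them.
host = jax.device_get(intermediates)
del intermediates
```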

Don’t leave debug flags on

A practical but critical rule: remove jax.disable_jit context blocks, jax.debug_nans, and heavy sow instrumentation before committing training runs. These tools are meant for short-lived debugging; they add overhead and can change performance characteristics.

Practical examples and common fixes 💡

Here are a few typical problems I encounter and the concrete ways I fix them with the tools above.

1. Sudden NaN in loss after a few iterations

  • Symptoms: Training runs fine for several steps, then loss becomes NaN and gradients blow up.
  • Approach: Enable jax.debug_nans. Run a single failing step and let JAX point to the primitive producing the NaN. Use sow to capture the activations or logits around that primitive. Often the root cause is a division by zero in a normalization or an exponential of a large input. Fix by adding small epsilons, gradient clipping, or better initialization.

2. Shape mismatch inside jax.jit but not in eager

  • Symptoms: Eager run succeeds; JIT fails with a shape broadcasting error deep inside fused ops.
  • Approach: Use jax.disable_jit on the minimal snippet to step through and print shapes. Add explicit assertions or Chex checks to enforce shapes. Use nnx.display to validate parameter shapes before compilation. A common culprit is Python logic that derives shapes from data (for example, conditional shape creation or data-driven list appends): under jax.jit every shape must be known at trace time, so that logic either fails or produces different shapes than the eager run.
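Because shapes are static under jax.jit, a plain Python assert on shapes runs once at trace time and adds nothing to the compiled kernel. The sketch below uses a hypothetical add_bias function. Chex's assert_shape offers a more expressive version of the same idea; this example sticks to a bare assert to avoid the extra dependency.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_bias(x, b):
    # Shapes are static under jit, so this check runs at trace time and
    # costs nothing in the compiled kernel.
    assert x.ndim == 2 and b.shape == (x.shape[-1],), (x.shape, b.shape)
    return x + b


ok = add_bias(jnp.ones((2, 3)), jnp.zeros(3))

try:
    add_bias(jnp.ones((2, 3)), jnp.zeros(2))  # would otherwise fail deep in broadcasting
    caught = False
except AssertionError as err:
    caught = True
    print("shape check failed:", err)
```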

3. Inconsistent behavior with vmapped function

  • Symptoms: Single-example inference is good; when using vmap to batch inputs, results are wrong or NaN.
  • Approach: Extract the inner function and test it on a single input. Re-run with jax.disable_jit and call the inner function per sample manually. Use sow inside the inner function to capture per-sample intermediates. Often the problem involves incorrect axis handling (in_axes/out_axes), broadcasting, or a reduction over the wrong axis.

Resources and next steps 📚

If you want to keep improving your JAX and Flax NNX debugging skillset, here are the practical next steps I recommend.

  • Read the JAX documentation sections on debugging and configuration flags to learn exact environment variables and config APIs.
  • Explore the Flax NNX docs for details on Module.sow, nnx.display, and state utilities such as nnx.split and nnx.pop.
  • Integrate small test harnesses that can reproduce failures with fixed seeds and minimal input — these are invaluable for isolating bugs.
  • Use a small suite of assertions (Chex is a popular choice) to catch shape and dtype problems early, and ensure these checks are JIT-friendly.
  • When performance tuning, learn to use the JAX profiler and tools like Model Explorer to correlate debugging fixes with runtime and memory behavior.

All of these tools fit into a broader workflow where you alternate between quick, eager debugging and targeted JAX-native tracing. Mastery comes from practice: the more you debug under both modes, the better you get at predicting which tool will give the fastest path to the root cause.

FAQ ❓

How do I temporarily run code eagerly so I can use pdb and print statements?

Wrap the suspect code in the context manager with jax.disable_jit(): or set the environment variable JAX_DISABLE_JIT=1. Inside the jax.disable_jit() block, prints and pdb.set_trace behave like normal Python. Remember to remove the block before running production training, because it disables JIT optimizations and slows execution.

What does jax.debug_nans do and when should I use it?

jax.debug_nans checks the outputs of JIT-compiled functions for NaNs. If it finds one, JAX re-runs that function eagerly, pinpoints the primitive that produced the bad value, and raises a FloatingPointError. (Infinities need the separate jax_debug_infs flag.) Use it when you suspect numerical issues inside compiled code. It adds overhead, so only enable it during debugging sessions.

How do I access values saved with sow after a forward pass?

Values stored with sow become attributes on the module instance under the name you supplied. For example, calling self.sow(nnx.Intermediate, "dense_output", x) stores the value in model.dense_output.value. By default, multiple calls are appended to a tuple; you can pass init_fn and reduce_fn to change the aggregation behavior.

Will sow cause out-of-memory errors if I store many tensors?

It can. sow stores tensors on the module instance, and repeated sowing inside loops or over many batches can exhaust memory. Mitigations include sowing only the smallest necessary subset of tensors, providing a reduce_fn to aggregate values, copying sown values to host memory, or clearing them after inspection with nnx.pop (or separating them out with nnx.split).

Can I use sow inside jax.vmap or jax.lax.scan?

You can call sow from inside functions that are vmapped or scanned, but be aware of the semantics: repeated calls with the same name aggregate (by default, by appending), which can grow quickly. To avoid large accumulations, use reduce functions or only sow per-step summaries. When the module itself goes through the transform, the NNX-aware transforms (such as nnx.vmap) are the ones designed to propagate state changes like sown values back out. Also remember that with transformations applied, an external step debugger may not step into each transformed iteration.

Why does code behave differently in eager mode versus JIT mode?

JIT compilation may fuse operations and handle control flow differently because it compiles to XLA. Under jax.jit, array shapes must be known at trace time and Python control flow cannot branch on traced array values, so some dynamic Python constructs that work eagerly fail or behave differently when compiled. Use jax.disable_jit to reproduce behavior in plain Python and jax.debug_nans to find numerical problems that only show up under JIT.

How do I use the VS Code debugger with JAX?

Run the code eagerly using jax.disable_jit or set JAX_DISABLE_JIT=1 so that it runs as regular Python. Then set breakpoints in VS Code and use the debugger normally. Be careful: stepping through code that involves large array allocations can be slow, so limit the amount of work in the eager region while you debug.

What are init_fn and reduce_fn for sow, and when should I use them?

They control how multiple values sown under the same name are combined. init_fn produces the starting aggregate, and reduce_fn merges each new value into it (the first value is reduced into init_fn's result). Use them when you expect multiple values (for example, across timesteps) but want a bounded summary such as a running sum or maximum instead of a growing tuple.

How should I instrument long training runs without slowing everything down?

Use targeted instrumentation. For intermittent debugging, enable jax.debug_nans only when you detect NaNs. Use sow sparingly and guard it with a debug flag so it runs only during short diagnostic runs. For long-running monitoring, report lightweight metrics to TensorBoard or another monitoring system, and use profiling tools for performance hotspots instead of heavy per-step introspection.

What should I do if jax.debug_nans reports a primitive but I still need more context?

Once you have the offending primitive, sow nearby intermediates or re-run the relevant region under jax.disable_jit to print surrounding values. Often you need to inspect the inputs to the primitive to understand why it produced a NaN. Also consider adding small numerical stabilizers (epsilons) or clamping to guard divisions and logs.

Can asserts from libraries like Chex be used with JIT?

Yes, with one distinction. Shape and dtype assertions such as chex.assert_shape run at trace time, when shapes are static, so they work inside jax.jit without breaking compilation. Assertions on array values need extra support (Chex provides chex.chexify for this). Use them in model code to validate inputs and outputs; they make failures more informative than cryptic XLA errors at runtime.

How do I clean up variables created by sow?

Use nnx.pop to remove sown variables from the module state after inspection. nnx.split and nnx.state can be used to separate and operate on particular variable collections. Always clean up debug artifacts before starting production training to avoid unnecessary memory use and unpredictable state carried between runs.

