Checkpointing Flax NNX Models with Orbax (Part 1)
Table of Contents
- 🔍 Quick outline
- 🧭 Why checkpointing matters
- 🧩 The mental model: NNX modules and dynamic state
- 🧰 NNX helpers: split, merge, update, state
- ⚙️ Orbax basics
- 💾 Saving a Flax NNX model with Orbax
- 🔁 Restoring a Flax NNX model with Orbax
- 🧪 Example explained: a minimal linear module
- 🧠 Handling the correlation between static graph and dynamic state
- 🔎 Orbax manager conveniences and best practices
- 🔐 Common pitfalls and how I avoid them
- 📦 Saving the optimizer and other training artifacts
- 🔁 Updating vs merging when restoring
- 📚 Advanced topic preview: distributed and sharded checkpoints
- 🧾 Practical checklist I use before checkpointing
- 🧾 Minimal resume pattern I recommend
- 🧩 Example metadata and compatibility notes
- 🛠️ Troubleshooting common errors
- ✅ Summary and what I take away
- 📣 What comes next
- ❓ Frequently asked questions
🔍 Quick outline
I walk through why checkpointing matters for JAX projects, how Flax NNX represents model state, and how Orbax fits into that picture. You will get a clear, practical workflow for saving and restoring a single NNX model instance, plus concrete tips that make checkpointing reliable during development.
🧭 Why checkpointing matters
Checkpointing is basic infrastructure for any training workflow. I save model weights, optimizer moments, training steps, and other mutable things so I do not lose progress if training stops or I need to resume later. It is also how I capture a model version for evaluation or deployment.
In the JAX ecosystem, the common tool for checkpointing is Orbax. When I use Flax NNX, which provides a Pythonic, stateful API reminiscent of PyTorch modules, I still get all the functional benefits of JAX. That combination makes checkpointing slightly different from other frameworks, and the right mental model is important.
🧩 The mental model: NNX modules and dynamic state
Flax NNX modules are plain Python classes that directly hold their mutable parts as attributes. Parameters, batch statistics, and other mutable quantities are stored as instances of classes derived from nnx.Variable. Those variables contain the underlying JAX arrays and optional metadata.
Initialization often requires providing shapes and a PRNG key up front. Typically I declare parameters in the constructor, __init__, by wrapping arrays in nnx.Param. During forward calls I access them as self.weight, self.bias, or similar. That makes the API feel very familiar if you have used PyTorch.
Starting with NNX v0.11, module instances are native JAX pytrees. That means they can be passed to JAX transformations directly. But for checkpointing I rarely want to write out the whole module object. I usually only want to save the dynamic, changeable parts: the variables that represent parameters, batch stats, optimizer-compatible state, and so on.
🧰 NNX helpers: split, merge, update, state
NNX gives me three essential operations to bridge the stateful, Pythonic modules and the functional world of JAX pytrees.
- nnx.split(module) — returns a pair: a static structure (graph_def) and the dynamic state as a pytree of nnx.Variable objects. The state is what I usually want to save or feed into JAX-transformed functions.
- nnx.merge(graph_def, state) — reconstructs a fully materialized module instance from a stored static graph definition and a state py-tree that contains variable values.
- nnx.update(module, state) — writes values from a state py-tree into an existing module instance, mutating the module's variables to match the given state.
There is also nnx.state(module), which acts as a filter: it walks the module pytree and returns a clean pytree containing only the nnx.Variable objects, without the graph definition. This is the exact structure Orbax will serialize.
⚙️ Orbax basics
Orbax separates concerns across two things I care about: low-level serialization and high-level checkpoint lifecycle management.
- Checkpointer — the serializer and deserializer for a particular data shape. For example, a PyTree checkpointer knows how to write and read pytrees of JAX arrays and simple container types to and from disk.
- Checkpoint manager — a convenience layer on top of a checkpointer that manages saving over time. It creates versioned checkpoints by step number, tracks which checkpoints exist, enforces retention policies like "keep the last 3", and exposes a simple save and restore API I can call from my training loop.
For almost all practical training I use the checkpoint manager. It reduces boilerplate and handles concurrency concerns that often appear in real workflows.
💾 Saving a Flax NNX model with Orbax
Saving a model is a short sequence of steps, but each one is important if I want a robust workflow.
- Create a checkpoint manager and configure its directory and retention policy.
- Extract the NNX dynamic state using nnx.split or nnx.state. This is the py-tree Orbax will store.
- Call manager.save with the training step and the state py-tree wrapped in Orbax's argument structure.
- Optionally wait for the save to finish and then close the manager to release resources.
Here is a compact example I use as a template. The code is illustrative rather than copy-paste exact, but it shows the core API flow.
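import orbax.checkpoint as ocp
from flax import nnx

ckpt_dir = "/tmp/checkpoints"  # Orbax expects an absolute path
options = ocp.CheckpointManagerOptions(max_to_keep=3)  # keep the last 3 checkpoints
manager = ocp.CheckpointManager(ckpt_dir, options=options)
graph_def, state = nnx.split(model)  # model is an existing nnx.Module instance
step = 100
manager.save(step, args=ocp.args.StandardSave({'model': state}))
manager.wait_until_finished()  # block until any background write completes
manager.close()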
Notes on the snippet above
- I often place the state under a top-level key such as 'model' to allow saving multiple things in the same checkpoint, like optimizer_state or rng_state.
- When saving in background threads or a multiprocessing environment I always call wait_until_finished before exiting or before a dependent process attempts to read from disk.
🔁 Restoring a Flax NNX model with Orbax
Restoring is slightly more involved because Orbax needs a structure template. Orbax will load binary blobs into the shapes and array dtypes it sees in the template. If you pass the wrong template you either get a mismatch error or incorrect shapes.
To restore safely I use nnx.eval_shape to create an abstract model instance. That returns a model whose variables are not real arrays but light wrappers that only record shape and dtype. No actual memory is allocated. I then split that abstract model into graph_def and abstract_state. The abstract_state serves as the template for Orbax restore.
Workflow for restore
- Create an abstract model with nnx.eval_shape by calling a small constructor function that builds the model with shapes and RNGs.
- Split the abstract model to get graph_def and abstract_state.
- Open the checkpoint manager and find the latest checkpoint step with manager.latest_step(), which returns None when no checkpoint exists.
- Call manager.restore with the step number and the abstract_state passed as the restore template. Orbax will return the restored py-tree with concrete JAX arrays.
- Use nnx.merge(graph_def, restored_state) to materialize a full model instance, or nnx.update(existing_model, restored_state) to write into an existing model.
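def make_model():
    # Same constructor arguments as at save time; no real arrays are created under eval_shape.
    return SimpleLinear(in_features=10, out_features=5, rngs=nnx.Rngs(0))

abstract_model = nnx.eval_shape(make_model)  # shapes and dtypes only
graph_def, abstract_state = nnx.split(abstract_model)
# manager is an open ocp.CheckpointManager pointing at the checkpoint directory
step = manager.latest_step()
restored_state = manager.restore(step, args=ocp.args.StandardRestore({'model': abstract_state}))
restored_model = nnx.merge(graph_def, restored_state['model'])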
Why this works
Orbax needs a structure template so it knows how to map stored array blobs back into a py-tree. The abstract state gives Orbax exactly that mapping while avoiding large allocations. Once Orbax returns real arrays, I can reconstruct a fully functional module with nnx.merge or apply the arrays to an existing object with nnx.update.
🧪 Example explained: a minimal linear module
To make these ideas concrete I use a small linear module as an example. In NNX I would write something like:
import jax.numpy as jnp
from flax import nnx

class SimpleLinear(nnx.Module):
    def __init__(self, in_features, out_features, rngs: nnx.Rngs):
        # Any initializer works here; lecun_normal is only an example.
        init = nnx.initializers.lecun_normal()
        self.weight = nnx.Param(init(rngs.params(), (out_features, in_features)))
        self.bias = nnx.Param(jnp.zeros((out_features,)))

    def __call__(self, x):
        return x @ self.weight.value.T + self.bias.value
Key points
- The nnx.Param objects are nnx.Variable instances (nnx.Param is a subclass of nnx.Variable). They hold a value attribute that contains the actual JAX array.
- When instantiated with a PRNG and shapes, the module creates these variables immediately.
- Using nnx.split on an instance returns the static structure and the dynamic state py-tree which contains the variable wrappers. That py-tree is exactly what Orbax saves.
🧠 Handling the correlation between static graph and dynamic state
Two pieces must be kept aligned for restore: the static graph_def and the dynamic state. The static graph captures module layout and names, while the dynamic state provides the actual numerical values. I always save or record the graph_def that matches the state. If model code changes between save and restore in incompatible ways, merging or updating will fail or produce incorrect results.
When I expect model definition changes across experiments I do one of the following:
- Pin the version of the model class used to create the checkpoint and record it in metadata alongside the checkpoint.
- Save both the graph_def and a serialized copy of the model source or a hash of the code to detect incompatible restores.
🔎 Orbax manager conveniences and best practices
The CheckpointManager provides a small workflow that I rely on.
- Atomic saves — the manager supports saving in a way that avoids partial writes if a process is interrupted. That is essential for long training runs.
- Retention — setting max_to_keep reduces storage usage. I usually keep the last few checkpoints and occasionally archive a checkpoint I know I'll want long term.
- Latest step lookup — manager.latest_step() simplifies automated resume logic. I prefer not to hardcode steps.
Tip: include the training step in the saved payload, even though Orbax already tags the checkpoint with a step number. That redundancy makes managing resume logic easier if you use checkpoints across different systems or tools.
🔐 Common pitfalls and how I avoid them
Checkpointing sounds straightforward, but a few mistakes cause headaches later.
- Saving the wrong object — saving the entire module instance instead of the state pytree is tempting because the module is a pytree, but it often includes ephemeral Python-side attributes that should not be serialized. Always extract the nnx.Variable state with nnx.state or nnx.split.
- Mismatch in shapes or dtypes on restore — always use nnx.eval_shape to build an abstract template that matches how the model was created during save. If you change the initializer or constructor signature, update your template accordingly.
- Forgetting optimizer state — if you want to resume training exactly where you left off, save optimizer moments and any global step counters together with model parameters. I usually store optimizer state under a separate top-level key in the checkpoint payload.
- Not waiting for background saves — if a process spawns async saves and then exits, partial writes or corrupted checkpoints are possible. Call wait_until_finished and close the manager before exiting the process.
📦 Saving the optimizer and other training artifacts
Saving a model alone is rarely enough. To resume training I also need the optimizer state, RNG states, and occasionally the dataset iterator state. Because Orbax can store arbitrary pytrees, I usually pack these into a single dictionary and save them together.
payload = {
    'model': model_state,
    'optimizer': optimizer_state,
    'rng': rng_state,
    'step': step,
}
manager.save(step, args=ocp.args.StandardSave(payload))
This keeps everything atomic. When I restore, I unpack the payload and reconstruct both model and optimizer completely.
🔁 Updating vs merging when restoring
Both nnx.merge and nnx.update are useful in different situations.
- nnx.merge(graph_def, state) — creates a brand new module instance from scratch using the static graph and the state py-tree. I use this when I do not have an existing module instance handy or when I want a clean materialized object that equals the saved state.
- nnx.update(module, state) — writes state values into an existing module instance. I use this when my training script already has a module object created and I want to mutate it to match the checkpoint rather than allocate a new one.
Both paths are valid. Pick the one that fits your program structure.
📚 Advanced topic preview: distributed and sharded checkpoints
So far I covered only single-process single-device checkpointing. Real world training often uses multiple devices and model sharding. Orbax supports distributed checkpointing, but it adds a layer of complexity. The manager and checkpointer must understand how the model variables are partitioned across devices so that the saved checkpoint can be restored into a sharded or replicated setup.
Key concerns in the distributed setting
- Sharding metadata — you need to track how variables are sharded so they can be reassembled correctly on restore.
- Device topology — the number of devices and their arrangement may differ between save and restore. Orbax provides utilities to handle many of these cases, but planning ahead helps.
- Atomicity across processes — every process participating in a distributed run must cooperate to create a consistent checkpoint. The high level manager APIs make this easier, but proper synchronization is required.
I will cover distributed checkpointing and how to save a sharded model and optimizer state atomically in the follow-up material.
🧾 Practical checklist I use before checkpointing
I run through this short checklist before enabling save or restore in a training script.
- Confirm model constructor is deterministic for a given RNG and shapes.
- Use nnx.eval_shape to create an abstract template that matches the save-time model.
- Save both graph_def and the dynamic state py-tree together or record where graph_def came from.
- Include optimizer state and step counters in the saved payload if resuming training is required.
- Decide an appropriate retention policy to limit storage usage.
- Call wait_until_finished before exiting the process that kicked off the save.
🧾 Minimal resume pattern I recommend
When I write scripts that should resume automatically if a checkpoint exists, I use a simple pattern.
- Make an abstract model and split it to get graph_def and abstract_state.
- Create the manager and call manager.latest_step().
- If it returns a step (not None), call manager.restore and apply the returned state with nnx.merge or nnx.update.
- If no checkpoint, initialize model and optimizer normally.
This preserves a clear separation between initialization and restore logic and avoids accidental double-initialization of variables.
🧩 Example metadata and compatibility notes
I often save a metadata blob alongside the checkpoint that includes:
- model name and version
- code hash or git commit
- training hyperparameters that affect shapes
- Orbax version and checkpointer type
That info makes it easier to detect incompatible restores early and to reproduce experiments later.
🛠️ Troubleshooting common errors
Here are a few errors I have run into and how I handle them.
- Shape mismatch on restore — check whether your abstract template matches the saved graph. Use printouts of the shapes in the abstract_state and compare to the saved checkpoint metadata.
- Missing keys in the state tree — confirm that the names of parameters or submodules did not change. Renaming items in your model code will change the structure and keys in the saved state.
- Partial or corrupted checkpoint — ensure wait_until_finished completed and check the storage backend. For cloud-backed storage, look at object store consistency or retry the restore from a different process.
✅ Summary and what I take away
Checkpointing with Orbax and Flax NNX follows a predictable pattern:
- Extract the dynamic state py-tree with nnx.split or nnx.state.
- Use Orbax CheckpointManager to save payloads atomically with step numbers and retention rules.
- To restore, create an abstract model with nnx.eval_shape so Orbax can map blobs back into arrays, then reconstruct the module with nnx.merge or nnx.update.
When I follow this pattern I get robust saves and restores that integrate well with JAX transformations and the NNX stateful API.
📣 What comes next
Saving a single model is the foundation. The next step is to save full training state including optimizers in a single atomic checkpoint and to manage sharded, multi-device checkpoints. That requires paying attention to how arrays are partitioned and how to orchestrate saves across processes. When I work at scale, I treat checkpointing as a feature of the training runtime and test restores frequently.
❓ Frequently asked questions
What exactly is the NNX state that Orbax saves?
The NNX state is the pytree composed of nnx.Variable objects. These are the changeable parts of a module such as parameters and batch statistics. You extract them with nnx.state or nnx.split. Orbax serializes the arrays inside those variable objects, not the full Python module instance.
Why use nnx.eval_shape when restoring?
nnx.eval_shape builds an abstract model where variables are lightweight shape-and-dtype placeholders instead of real arrays. Orbax needs such a template so it can match stored binary blobs to the correct tree structure and array shapes without allocating large arrays until the actual restore.
Can I save the entire NNX module instance directly with Orbax?
Although module instances are pytrees in newer NNX versions, saving the whole module can serialize unwanted Python-side attributes. It is safer and more portable to save only the dynamic state (the pytree of nnx.Variable objects) via nnx.state or nnx.split and to persist or reconstruct the static graph separately.
How do I resume training with the optimizer state?
Include optimizer state and other training artifacts in the same payload you save with Orbax. For example, save {'model': model_state, 'optimizer': optimizer_state, 'rng': rng_state, 'step': step}. On restore, unpack these keys and reinitialize or update your optimizer and RNG to continue training exactly where you left off.
What is the difference between nnx.merge and nnx.update?
nnx.merge constructs a new, fully materialized module from a static graph_def and a state py-tree. nnx.update writes state values into an existing module instance. Use merge when you want a fresh object and update when you already have an instantiated module you want to mutate.
How do I handle checkpoints when training across multiple devices?
Distributed checkpointing requires tracking how arrays are sharded and coordinating saves across processes. Orbax provides utilities for sharded checkpoints, but you must ensure consistent sharding metadata and synchronize processes so saved shards can be reassembled at restore. Planning for sharding in your model design simplifies the process.
What should I do to detect incompatible restores early?
Save metadata with each checkpoint including model version, code hash or git commit, and hyperparameters that affect shapes. On restore, compare these fields and fail early if they do not match the current code or configuration. This prevents subtle issues from creeping in when shapes or names change.
How many checkpoints should I keep?
It depends on your storage budget and experiment needs. A common policy is to keep the last 3 to 5 checkpoints for regular runs, and to archive one checkpoint per important milestone or experiment. Orbax CheckpointManager supports a max_to_keep setting to automate retention.



