Show HN: Grampax, a "torch.autocast"-style interface for mixed precision in JAX


Installation | Basic Usage | How does it work? | Custom Usage | Misc. Q&A

Grampax offers GRanular Automatic Mixed Precision for JAX. Basically torch.autocast, but for JAX and with fine-grained control over which operations run in which precision.

Grampax is available on PyPI for Python 3.10+.
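Installation is presumably the usual PyPI route (package name assumed from the project name):

pip install grampax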

The main way to use Grampax is via the autocast transformation. autocast takes any callable, such as a Python function or an Equinox model, and returns a version of it that runs in mixed precision:

import jax.numpy as jnp
from grampax import autocast

# use `autocast` as a decorator
@autocast(dtype=jnp.bfloat16)
def fn(x):
    return x @ x.T

# or call it directly
model = MyAwesomeEqxModel()
model = autocast(model, dtype=jnp.tensorfloat32)
y = model(x)

By default, an autocast-wrapped function executes as normal until it reaches a matrix multiplication or convolution operation. At that point, all inputs to the operation are cast to the specified dtype before the operation is executed, and the output is cast back to the input's original dtype. Because of the way function transformations compose in JAX, this is automatically reflected in backward passes (i.e. jax.grad, jax.vjp, etc.) as well.
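For example, a small sketch of what this looks like in practice (the dtypes shown assume a float32 input and the default behavior described above):

import jax
import jax.numpy as jnp
from grampax import autocast

@autocast(dtype=jnp.bfloat16)
def fn(x):
    return jnp.sum(x @ x.T)

x = jnp.ones((4, 4), dtype=jnp.float32)
print(fn(x).dtype)            # float32: the matmul runs in bfloat16, the result is cast back
print(jax.grad(fn)(x).dtype)  # float32: the backward pass is autocast in the same way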

Internally, Grampax uses Quax to modify the behavior of the passed callable at trace-time. Specifically, the first positional argument (only the first!) to the function is wrapped into a grampax.AutocastArray. When an AutocastArray is used as input to a JAX primitive, the default behavior is simply to wrap the output in an AutocastArray as well to propagate the wrapping, without modifying the calculation itself.
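As a consequence of this propagation rule, only computations connected to the first argument can trigger autocasting. A small sketch to illustrate the caveat:

import jax.numpy as jnp
from grampax import autocast

@autocast(dtype=jnp.bfloat16)
def fn(x, y):
    # x is wrapped into an AutocastArray, y is not:
    # x @ x.T is autocast, while y @ y.T runs in full precision
    return x @ x.T + y @ y.T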

But when the primitive is among the specified set of autocast primitives (by default, dot_general_p and conv_general_dilated_p), all inputs to the primitive are cast to the specified lower precision dtype before performing the operation, and the output is cast back to the original precision afterwards.

This behavior only applies to floating point arrays; integer outputs such as that of argmin_p are not affected.
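Conceptually, the rule for an autocast primitive amounts to something like the following sketch (not Grampax's actual code, just the effective behavior for e.g. dot_general):

import jax.numpy as jnp

def autocast_dot(a, b, compute_dtype=jnp.bfloat16):
    out_dtype = jnp.result_type(a, b)                        # remember the original precision
    out = a.astype(compute_dtype) @ b.astype(compute_dtype)  # compute in lower precision
    return out.astype(out_dtype)                             # cast the result back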

Specifying autocast behavior

Passing the dtype argument to autocast results in dot_general_p and conv_general_dilated_p being autocast to that dtype. However, you can also specify a custom mapping of primitives to dtypes via the config argument instead:

from jax import lax
import jax.numpy as jnp

# default
config = {
    lax.dot_general_p: jnp.bfloat16,
    lax.conv_general_dilated_p: jnp.bfloat16,
}
# add a middle-ground precision for log operations
config.update({
    lax.log_p: jnp.tensorfloat32,
    lax.logistic_p: jnp.tensorfloat32,
})
# can also cast to higher precisions
config.update({
    lax.reduce_prod_p: jnp.float64,
})

model = autocast(model, config=config)
y = model(x)

Only casting parts of a model

Since autocast simply wraps a callable, it's possible to only apply it to specific parts of a model. For example, in Equinox, you can use its model surgery idiom:

import equinox as eqx
import jax.numpy as jnp
from grampax import autocast

mlp = eqx.nn.MLP( ... )
where = lambda m: m.layers[1]
mlp = eqx.tree_at(where, mlp, autocast(where(mlp), dtype=jnp.bfloat16))
y = mlp(x)  # only mlp.layers[1] will run in mixed precision

Does Grampax support loss scaling?

Not at the moment. Loss scaling is often not necessary with the higher dynamic range of bfloat16, which Grampax defaults to. If you need loss scaling, constant scaling can be implemented relatively easily:

import jax
import jax.numpy as jnp
from functools import partial
from grampax import autocast

def loss_fn(model, x):
    ...

def scaled_value_and_grad(model, x, scale=65536.0):
    out, loss_vjp = jax.vjp(partial(loss_fn, x=x), model)
    (grad,) = loss_vjp(scale)                             # apply scaling in the backward pass
    return out, jax.tree.map(lambda g: g / scale, grad)   # unscale gradients

model = autocast(model, dtype=jnp.float16)
out, grad = scaled_value_and_grad(model, x)

For implementations of adaptive loss scaling algorithms in JAX, have a look at e.g. JMP.

How does Grampax interact with other Quax types?

In general, Grampax plays nice with other Quax types by redispatching the primitive binds with nested quaxify calls. AutocastArrays can be combined with e.g. quax.lora.LoraArray without issue.

However, Grampax needs to specify a default behavior for all primitives to enable propagating the AutocastArray wrapper throughout the computation. If Quax encounters a primitive with two different Quax inputs that both specify a default behavior, it will raise an error. If you run into this, you can work around it by registering specific behavior for the combination of types (see the Quax documentation).

Another mixed precision library for JAX?

Libraries like JMP and the recent MPX implement wrappers for mixed precision training that take care of dtype casting and loss scaling.

However, they don't offer automatic mixed precision training. They cast the model parameters to lower precision and then rely on you manually specifying, in the Python code itself, which critical computation regions need to run in full precision.

Grampax approaches the problem from the opposite direction: All computations run in full precision by default, and only specific operations are automatically cast to lower precision. This is more similar to the convenience offered by torch.autocast.
