Taxpr is a collection of utilities for performing manipulation of Jaxprs. This is achieved by tag-ing specific arrays at trace time, then extracting and manipulating those tags in the final Jaxpr.
Warning
This package is still very experimental, so expect broken code and breaking changes.
Getting started¶
Install via pip:
Example¶
The following example shows how you can use taxpr to emulate functions with side effects without violating Jax's pure function rules.
import itertools as it
import jax
import jax.numpy as jnp
from jax._src.core import eval_jaxpr
import taxpr as tx
_state_counter = it.count()
def get_state(shape, dtype):
count = next(_state_counter)
def set_state(value):
return tx.tag(value, op="set", id=count)
value = jax.numpy.zeros(shape, dtype=dtype)
return tx.tag(value, op="get", id=count), set_state
def uncurry(fn, *args, **kwargs):
jaxpr = jax.make_jaxpr(fn)(args, kwargs)
states = {}
# iterate through all tags in the jaxpr
# this recurses all child Jaxprs too
for params, shape in tx.iter_tags(jaxpr.jaxpr):
if params["op"] == "get":
states[params["id"]] = shape
initial_states = jax.tree.map(
lambda x: jax.numpy.full_like(x, 0), states
)
def injector(states, token, params):
if params["op"] == "get":
state = states[params["id"]]
return state, states
elif params["op"] == "set":
states[params["id"]] = token
return token, states
raise ValueError(f"Unknown tag op: {params['op']}")
# replace the token with a function that performs the state manipulation
# here we can pass our own context (`initial_states`)
jaxpr = tx.inject(jaxpr, injector, initial_states)
def wrapper(states, *args, **kwargs):
return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, states, args, kwargs)
return wrapper, initial_states
################################################
# Usage
def running_sum(x):
a, set_state = get_state(x.shape, x.dtype)
sum = set_state(a + x)
return sum
rsum, state = uncurry(running_sum, jnp.zeros(0))
_, state = rsum(state, jnp.ones(1))
_, state = rsum(state, jnp.ones(1))
_, state = rsum(state, jnp.ones(1))
assert jnp.allclose(next(iter(state.values())), 3)