Skip to content

Taxpr is a collection of utilities for manipulating Jaxprs. This is achieved by tagging specific arrays at trace time, then extracting and manipulating those tags in the final Jaxpr.

Warning

This package is still very experimental, so expect broken code and breaking changes.

Getting started

Install via pip:

Installation
pip install -U taxpr

Example

The following example shows how you can use taxpr to emulate functions with side effects without violating JAX's pure-function rules.

import itertools as it

import jax
import jax.numpy as jnp
from jax._src.core import eval_jaxpr
import taxpr as tx

_state_counter = it.count()

def get_state(shape, dtype):
    """Create a tagged state placeholder together with its setter.

    Draws the next id from the module-level ``_state_counter`` and tags a
    zero array of the given ``shape`` and ``dtype`` with ``op="get"`` under
    that id.

    Args:
        shape: Shape of the zero placeholder array.
        dtype: Dtype of the zero placeholder array.

    Returns:
        A pair ``(tagged_value, set_state)``. ``set_state(value)`` tags
        ``value`` with ``op="set"`` and the same id as the placeholder.
    """
    state_id = next(_state_counter)
    placeholder = jnp.zeros(shape, dtype=dtype)
    tagged = tx.tag(placeholder, op="get", id=state_id)

    def set_state(value):
        # Same id as the "get" tag, so both tags refer to one state slot.
        return tx.tag(value, op="set", id=state_id)

    return tagged, set_state


def uncurry(fn, *args, **kwargs):
    """Rewrite a function that uses tagged state into an explicit-state one.

    ``fn`` is traced with the example ``args``/``kwargs``. Every ``"get"``
    tag found in the resulting Jaxpr becomes one entry of a ``states`` dict
    (keyed by the tag's ``id``), and ``tx.inject`` is then used to replace
    the tags with reads from / writes to that dict.

    Args:
        fn: Function to trace.
        *args: Example positional arguments forwarded to ``fn`` for tracing.
        **kwargs: Example keyword arguments forwarded to ``fn`` for tracing.

    Returns:
        A pair ``(wrapper, initial_states)``. ``initial_states`` maps each
        state id to a zero-filled array; ``wrapper(states, *args, **kwargs)``
        evaluates the injected Jaxpr.
    """
    # Unpack the example arguments: `fn` must be traced with the same call
    # signature the caller uses, not with an (args, kwargs) pair.
    jaxpr = jax.make_jaxpr(fn)(*args, **kwargs)

    # iterate through all tags in the jaxpr
    # this recurses all child Jaxprs too
    states = {
        params["id"]: shape
        for params, shape in tx.iter_tags(jaxpr.jaxpr)
        if params["op"] == "get"
    }

    initial_states = jax.tree.map(
        lambda x: jnp.full_like(x, 0), states
    )

    def injector(states, token, params):
        """Resolve one tag: return ``(value, updated_states)``."""
        op = params["op"]
        if op == "get":
            return states[params["id"]], states
        if op == "set":
            # Return an updated copy instead of writing into the dict that
            # was passed in, so the caller's context is never mutated.
            return token, {**states, params["id"]: token}
        raise ValueError(f"Unknown tag op: {params['op']}")

    # replace the token with a function that performs the state manipulation
    # here we can pass our own context (`initial_states`)
    jaxpr = tx.inject(jaxpr, injector, initial_states)

    def wrapper(states, *args, **kwargs):
        # NOTE(review): `states`, `args` and `kwargs` are passed as three
        # whole objects; confirm against `tx.inject` that this is the input
        # layout the injected Jaxpr expects.
        return eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, states, args, kwargs)

    return wrapper, initial_states

################################################

# Usage

def running_sum(x):
    """Add ``x`` to the tagged state and tag the result as the new state.

    Returns the value produced by the state's setter for ``state + x``.
    """
    previous, set_state = get_state(x.shape, x.dtype)
    return set_state(previous + x)

# Trace with an example input of the same shape as the inputs passed below,
# so the recorded state slot matches what is being accumulated.
rsum, state = uncurry(running_sum, jnp.zeros(1))

# Thread the state through three calls, adding one each time.
_, state = rsum(state, jnp.ones(1))
_, state = rsum(state, jnp.ones(1))
_, state = rsum(state, jnp.ones(1))

# There is a single state slot; after three additions of 1 it should hold 3.
assert jnp.allclose(next(iter(state.values())), 3)