Skip to content

Manage State

Use `tag()` together with `inject()` to thread state through a traced computation:

import jax
import jax.numpy as jnp
from jax._src.core import eval_jaxpr
from taxpr import tag, inject

def simple_fn(x):
    """Read a tagged state value, add one, and write the result back.

    Both tags use the same ``var_id`` so that an injector can treat the
    "get" and the "set" as operations on the same piece of state.
    """
    slot = 0
    # "get" tag: the value that stands in for the current state.
    loaded = tag(x, op="get", var_id=slot)
    # "set" tag: the incremented value is written back and also returned.
    return tag(loaded + 1, op="set", var_id=slot)

# Trace simple_fn into a ClosedJaxpr. The 0-d example input fixes the traced
# input's shape and dtype; the value 5.0 itself is not baked into the jaxpr.
closed_jaxpr = jax.make_jaxpr(simple_fn)(jnp.array(5.0))

# Define an injector that handles tagged operations
def state_injector(state_val, token, params):
    """Handle one tagged operation by reading or writing a single state value.

    This example tracks only one piece of state, so ``params["var_id"]`` is
    not consulted.

    Args:
        state_val: The current state (the context threaded through the
            computation).
        token: The value associated with the tag (presumably the tagged
            operand; confirm against taxpr's ``inject`` documentation).
        params: The tag's keyword parameters. ``op`` selects the behaviour:
            ``"get"`` or ``"set"``. Tags with a missing or unrecognised
            ``op`` pass through unchanged.

    Returns:
        A ``(new_token, new_state)`` pair.
    """
    # .get() rather than indexing, so tags without an "op" entry reach the
    # pass-through branch instead of raising KeyError.
    op = params.get("op")
    if op == "get":
        # Return the current state as the value; the state is unchanged.
        return state_val, state_val
    if op == "set":
        # The value becomes the new state and is also returned.
        return token, token
    # Any other tag: leave both the value and the state untouched.
    return token, state_val

# Rewrite the traced function so that its tagged operations are handled by
# state_injector. initial_state is also passed to inject (presumably to fix
# the context's shape/dtype; confirm against taxpr's documentation).
initial_state = jnp.array(5.0)
injected = inject(closed_jaxpr, state_injector, initial_state)

# Execute the injected jaxpr. Inputs are the context (initial state) first,
# followed by the original function's arguments. Outputs are the original
# result followed by the final state.
result, final_state = eval_jaxpr(
    injected.jaxpr,
    injected.consts,
    initial_state,
    jnp.array(5.0)
)

print(f"Result: {result}, Final state: {final_state}")
# Expected: result and final state are both 6.0 ("get" yields 5.0, and
# "set" stores 5.0 + 1).
# NOTE(review): this listing originally showed "Result: [6.], Final state: [6.]",
# but the inputs are 0-d arrays, so the printed form is likely "6.0".
# Confirm by running.

Key points:

- The injector receives `(context, token, params)` and returns `(new_token, new_context)`.
- The context must be JAX-traceable (arrays, not dicts).
- Execute the injected jaxpr with `eval_jaxpr(jaxpr, consts, context, *original_args)`: the context comes first, followed by the original function's arguments.