
Reference

High level routines

dissolve_tags(jaxpr, predicate=None)

Remove tags (by default, all of them) from the given jaxpr.

This traverses the given jaxpr and all nested jaxprs, removing:

  • all tags, if predicate is None
  • otherwise, only the tags for which predicate(params, token_shape) returns True

Parameters:

  • jaxpr (Jaxpr, required): A JAXPR potentially containing tagged primitives.
  • predicate (Callable[[dict[str, Any], PyTree[AbstractValue]], bool] | None, default None): An optional function that takes the tag parameters and the token shape and returns True if the tag should be dissolved, or False to keep it.

Returns: A new JAXPR with the selected tags removed (all tags if predicate is None).
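The traversal above can be sketched on a toy IR. The `Eqn` and `Program` classes and `dissolve_tags_sketch` below are hypothetical stand-ins, not this library's implementation or JAX's `Jaxpr` classes. The sketch assumes a tag is an identity, so dissolving a tag amounts to forwarding its input to every use of its output. To keep it short, the toy predicate receives only the tag parameters and not the token shape.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


# Hypothetical toy IR: variables are strings; an equation's params may hold
# nested programs. This is NOT JAX's Jaxpr/JaxprEqn, just a stand-in.
@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class Program:
    invars: list
    eqns: list
    outvars: list


def dissolve_tags_sketch(prog: Program, predicate: Optional[Callable[[dict], bool]] = None) -> Program:
    """Drop "tag" equations (all, or those selected by predicate), recursing into nested programs.

    Simplification: the toy predicate sees only the tag params, not the token shape.
    """
    subst = {}  # output var of a dissolved tag -> the var it forwards

    def resolve(v):
        return subst.get(v, v)

    new_eqns = []
    for eqn in prog.eqns:
        if eqn.prim == "tag" and (predicate is None or predicate(eqn.params)):
            # A tag is an identity, so its outputs can be replaced by its inputs.
            for out, inp in zip(eqn.outvars, eqn.invars):
                subst[out] = resolve(inp)
            continue
        # Nested programs have their own variable scope, so recurse with a fresh map.
        params = {
            k: dissolve_tags_sketch(v, predicate) if isinstance(v, Program) else v
            for k, v in eqn.params.items()
        }
        new_eqns.append(Eqn(eqn.prim, [resolve(v) for v in eqn.invars], list(eqn.outvars), params))
    return Program(list(prog.invars), new_eqns, [resolve(v) for v in prog.outvars])
```

The returned program is a new object, and the input program is left unmodified.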

inject(closed_jaxpr, injector, ctx, predicate=None)

Inject tags into the given JAXPR using the provided injector function.

Parameters:

  • closed_jaxpr (ClosedJaxpr, required): A closed JAXPR potentially containing tagged primitives.
  • injector (Callable[[Ctx, PyTree[Array], dict[str, Any]], tuple[PyTree[Array], Ctx]], required): A function that takes the context, the token, and the tag parameters, and returns a new token and the modified context. The injector is converted to a JAXPR internally and inlined at each tag point.
  • ctx (Ctx, required): The initial context to pass to the injector.
  • predicate (Callable[[dict[str, Any], PyTree[AbstractValue]], bool] | None, default None): An optional function that takes the tag parameters and the token shape and returns True if the injector should be applied at that tag, or False to skip injection. If None, the injector is applied at all tags.

Returns: ClosedJaxpr: A new closed JAXPR with the injector inlined at each tag point. The modified JAXPR also returns the final context as an additional output.
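One rough way to picture what the injected JAXPR computes is an interpreter that calls the injector at every selected tag and threads the context through. The sketch below is a toy built on several assumptions: `Eqn`, `Program`, `PRIMS`, and `run_with_injector` are hypothetical, and tokens are single scalars. Unlike `inject`, it evaluates the program directly and does not convert the injector to a JAXPR and inline it.

```python
from dataclasses import dataclass, field


# Hypothetical toy IR (not JAX's classes). Tokens are single scalar values here.
@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class Program:
    invars: list
    eqns: list
    outvars: list


# Toy primitive table; "tag" is handled specially below.
PRIMS = {"mul2": lambda x: 2 * x, "neg": lambda x: -x}


def run_with_injector(prog, injector, ctx, args, predicate=None):
    """Evaluate prog, calling injector(ctx, token, params) -> (token, ctx) at each selected tag.

    Returns (outputs, final_ctx), mirroring how inject(...) appends the final
    context to the outputs. Interpreting instead of rewriting is a simplification.
    """
    env = dict(zip(prog.invars, args))
    for eqn in prog.eqns:
        vals = [env[v] for v in eqn.invars]
        if eqn.prim == "tag":
            token = vals[0]
            if predicate is None or predicate(eqn.params):
                token, ctx = injector(ctx, token, eqn.params)
            outs = [token]  # skipped tags pass the token through unchanged
        else:
            outs = [PRIMS[eqn.prim](*vals)]
        env.update(zip(eqn.outvars, outs))
    return [env[v] for v in prog.outvars], ctx
```

The injector returns the context rather than mutating it, so the same pattern also works when the context is an immutable value such as a counter.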

iter_tags(jaxpr)

Iterate over all tags in the given JAXPR.

Parameters:

  • jaxpr (Jaxpr, required): A JAXPR potentially containing tagged primitives.

Yields: tuple[dict[str, Any], PyTree[AbstractValue]]: A (parameters, ShapeDtypeStruct | None) tuple for each tag.
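Here is a minimal sketch of the recursive walk. It assumes a toy IR in which nested programs live in equation params and each program records a `shapes` map, which is a hypothetical stand-in for abstract values. All class and function names are illustrative. Like `iter_tags`, the sketch yields `None` when no shape is available.

```python
from dataclasses import dataclass, field


# Hypothetical toy IR (not JAX's classes).
@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class Program:
    invars: list
    eqns: list
    outvars: list
    shapes: dict = field(default_factory=dict)  # var -> shape; toy stand-in for avals


def iter_tags_sketch(prog):
    """Yield (params, shape or None) for every tag, depth-first through nested programs."""
    for eqn in prog.eqns:
        if eqn.prim == "tag":
            # None when the toy program records no shape for the tagged variable.
            yield eqn.params, prog.shapes.get(eqn.invars[0])
        for value in eqn.params.values():
            if isinstance(value, Program):
                yield from iter_tags_sketch(value)
```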

tag(token, **params)

Tag a specific point in a computation with given parameters.

Note: You must consume the output of this function for the tag to appear in the JAXPR. Simply calling this function without using its output may lead to the tag being optimized away.

Parameters:

  • token (T, required): An input token representing a point in the computation. The token can be any PyTree of JAX arrays.
  • **params (default {}): Arbitrary keyword parameters to associate with the tag.

Returns: The unchanged input token tagged with the provided parameters.
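The note above comes down to dead-code elimination: a pass that keeps only the equations whose outputs are needed will drop a tag that nothing reads. The toy `dce` below is a hypothetical illustration, not JAX's own pass, and the `name` parameter is just an example keyword. It compares two patterns. In the first, you call `tag(x, name="in")` and keep using the original `x`. In the second, you write `x = tag(x, name="in")` and use the returned `x`.

```python
from dataclasses import dataclass, field


# Hypothetical toy IR (not JAX's classes).
@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class Program:
    invars: list
    eqns: list
    outvars: list


def dce(prog):
    """Keep only equations whose outputs are (transitively) needed by prog.outvars.

    Toy pass: assumes no equation has side effects, which is why an unused tag disappears.
    """
    live = set(prog.outvars)
    kept = []
    for eqn in reversed(prog.eqns):
        if any(v in live for v in eqn.outvars):
            kept.append(eqn)
            live.update(eqn.invars)
    return Program(list(prog.invars), kept[::-1], list(prog.outvars))


# Roughly `tag(x, name="in"); y = sin(x)`: the tag output t is never used.
unused = Program(["x"], [Eqn("tag", ["x"], ["t"], {"name": "in"}), Eqn("sin", ["x"], ["y"])], ["y"])
# Roughly `x = tag(x, name="in"); y = sin(x)`: sin consumes the tag output.
consumed = Program(["x"], [Eqn("tag", ["x"], ["t"], {"name": "in"}), Eqn("sin", ["t"], ["y"])], ["y"])
```

Running `dce` on `unused` leaves only the `sin` equation. Running it on `consumed` keeps the tag.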

transpose(jaxpr)

Transposes a closed JAXPR such that it returns the outputs of each tag rather than its original outputs.

Parameters:

  • jaxpr (ClosedJaxpr, required): A closed JAXPR potentially containing tagged primitives.

Returns: A new closed JAXPR that returns the outputs of each tag, a list of the tag parameters, and a list of the tag output structures.
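Here is a minimal sketch of the idea on a toy IR, using hypothetical `Eqn`, `Program`, and `transpose_sketch`. It keeps the equations but makes the tag outputs the program outputs. It only looks at top-level tags, and it uses the number of output variables as a stand-in for each tag's output structure. Both simplifications are assumptions of the sketch.

```python
from dataclasses import dataclass, field


# Hypothetical toy IR (not JAX's classes).
@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class Program:
    invars: list
    eqns: list
    outvars: list


def transpose_sketch(prog):
    """Return (new_prog, tag_params, tag_structures) for the top-level tags of prog.

    Toy simplifications: nested programs are ignored, and a tag's "structure"
    is just its number of output variables.
    """
    tags = [eqn for eqn in prog.eqns if eqn.prim == "tag"]
    new_outvars = [v for eqn in tags for v in eqn.outvars]
    # Equations after the last tag are now dead; a DCE pass could remove them.
    new_prog = Program(list(prog.invars), list(prog.eqns), new_outvars)
    return new_prog, [eqn.params for eqn in tags], [len(eqn.outvars) for eqn in tags]
```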

inline_jaxpr(eqn)

Inline a jaxpr contained in an equation.
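Inlining replaces a call equation with the equations of the program it carries. The sketch below assumes a toy IR (hypothetical `Eqn`, `Program`, and `inline_sketch`) in which the callee is stored under `params["body"]`. It renames the callee's internal variables to fresh names. It also returns a map that tells the caller which variable now holds each of the call's outputs.

```python
import itertools
from dataclasses import dataclass, field


# Hypothetical toy IR (not JAX's classes).
@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class Program:
    invars: list
    eqns: list
    outvars: list


def inline_sketch(eqn, fresh):
    """Return (body equations renamed into the caller's scope, map call-outvar -> new var).

    `fresh` is an iterator of ints used to build clash-free names (an assumption).
    """
    body = eqn.params["body"]
    rename = dict(zip(body.invars, eqn.invars))  # callee inputs -> caller's arguments
    new_eqns = []
    for inner in body.eqns:
        invars = [rename[v] for v in inner.invars]
        for v in inner.outvars:
            rename[v] = f"{v}_{next(fresh)}"  # fresh names avoid clashes with the caller
        new_eqns.append(Eqn(inner.prim, invars, [rename[v] for v in inner.outvars], dict(inner.params)))
    outmap = {out: rename[v] for out, v in zip(eqn.outvars, body.outvars)}
    return new_eqns, outmap


body = Program(["a"], [Eqn("sin", ["a"], ["b"]), Eqn("neg", ["b"], ["c"])], ["c"])
call = Eqn("call", ["x"], ["y"], {"body": body})
eqns, outmap = inline_sketch(call, itertools.count())
```

The caller then rewrites later uses of `y` through `outmap`.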

partial(jaxpr, /, args)

Partially evaluate a ClosedJaxpr by fixing certain input arguments.

Similar to functools.partial, but for JAX ClosedJaxprs.

Parameters:

  • jaxpr (ClosedJaxpr, required): The ClosedJaxpr to partially evaluate.
  • args (dict[int, Any], required): A dictionary mapping input argument indices to their fixed values.

Returns: A new ClosedJaxpr with the specified inputs fixed. The number of inputs in the returned jaxpr is the original number of inputs minus the number of fixed inputs.
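One way to picture fixing inputs is to move them into the closed-over constants, which is the role that `ClosedJaxpr` constants play. The toy `Program`, `partial_sketch`, and `evaluate` below are hypothetical. They only illustrate the index bookkeeping and are not necessarily how this library implements `partial`.

```python
from dataclasses import dataclass, field


# Hypothetical toy IR (not JAX's classes).
@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class Program:
    invars: list
    eqns: list
    outvars: list
    consts: dict = field(default_factory=dict)  # var -> fixed value, like closed-over consts


OPS = {"add": lambda a, b: a + b, "mul": lambda a, b: a * b}


def partial_sketch(prog, args):
    """Fix the inputs at the given indices by turning them into constants."""
    unknown = set(args) - set(range(len(prog.invars)))
    if unknown:
        raise IndexError(f"no such input indices: {sorted(unknown)}")
    consts = {**prog.consts, **{prog.invars[i]: v for i, v in args.items()}}
    invars = [v for i, v in enumerate(prog.invars) if i not in args]
    return Program(invars, list(prog.eqns), list(prog.outvars), consts)


def evaluate(prog, *args):
    """Toy interpreter: constants plus positional inputs, one output per equation."""
    if len(args) != len(prog.invars):
        raise TypeError(f"expected {len(prog.invars)} inputs, got {len(args)}")
    env = {**prog.consts, **dict(zip(prog.invars, args))}
    for eqn in prog.eqns:
        env[eqn.outvars[0]] = OPS[eqn.prim](*(env[v] for v in eqn.invars))
    return [env[v] for v in prog.outvars]
```

Fixing index 1 of a three-input program leaves two inputs, which matches the input-count rule above.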

partition_out(jaxpr, outvar_indices)

Partition a ClosedJaxpr into multiple ClosedJaxprs, each computing one group of outvars.

Parameters:

  • jaxpr (ClosedJaxpr, required): The ClosedJaxpr to partition.
  • outvar_indices (list[list[int]], required): For each partition, the list of outvar indices it should compute.

Returns: A list of ClosedJaxprs, each computing the specified outvars.
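Here is a minimal sketch. It assumes each partition is obtained by restricting the outputs to one group and then removing the equations that no longer contribute (dead-code elimination). The toy IR and helper names are hypothetical. As a design choice of the sketch, every partition keeps the full input list, so all partitions share the original signature.

```python
from dataclasses import dataclass, field


# Hypothetical toy IR (not JAX's classes).
@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class Program:
    invars: list
    eqns: list
    outvars: list


def dce(prog):
    """Keep only equations whose outputs are (transitively) needed by prog.outvars."""
    live = set(prog.outvars)
    kept = []
    for eqn in reversed(prog.eqns):
        if any(v in live for v in eqn.outvars):
            kept.append(eqn)
            live.update(eqn.invars)
    return Program(list(prog.invars), kept[::-1], list(prog.outvars))


def partition_out_sketch(prog, outvar_indices):
    """One pruned program per group of output indices."""
    return [
        dce(Program(prog.invars, prog.eqns, [prog.outvars[i] for i in group]))
        for group in outvar_indices
    ]
```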

rewrite_invars(eqn, varmap)

Rewrite the invars of an equation according to a variable mapping.

rewrite_outvars(eqn, varmap)

Rewrite the outvars of an equation according to a variable mapping.

rewrite_vars(jaxpr, varmap)

Rewrite the invars and outvars of a jaxpr contained in an equation.
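These three helpers amount to variable renaming. The toy version below rests on two assumptions: variables missing from `varmap` are left unchanged, and `rewrite_vars` renames throughout a program. The `Eqn`, `Program`, and `*_sketch` names are hypothetical, and each helper returns a new object instead of mutating its input.

```python
from dataclasses import dataclass, field, replace


# Hypothetical toy IR (not JAX's classes).
@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class Program:
    invars: list
    eqns: list
    outvars: list


def rewrite_invars_sketch(eqn, varmap):
    """Return a copy of eqn with its inputs renamed (unmapped vars unchanged)."""
    return replace(eqn, invars=[varmap.get(v, v) for v in eqn.invars])


def rewrite_outvars_sketch(eqn, varmap):
    """Return a copy of eqn with its outputs renamed (unmapped vars unchanged)."""
    return replace(eqn, outvars=[varmap.get(v, v) for v in eqn.outvars])


def rewrite_vars_sketch(prog, varmap):
    """Rename program inputs, outputs, and every equation's vars."""
    eqns = [rewrite_outvars_sketch(rewrite_invars_sketch(e, varmap), varmap) for e in prog.eqns]
    return Program(
        [varmap.get(v, v) for v in prog.invars],
        eqns,
        [varmap.get(v, v) for v in prog.outvars],
    )
```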

strip_outputs(jaxpr, /, indices)

Strip specified outputs from a ClosedJaxpr.

Parameters:

  • jaxpr (ClosedJaxpr, required): The ClosedJaxpr to strip outputs from.
  • indices (Set[int], required): A set of output indices to remove.

Returns: A new ClosedJaxpr with the specified outputs removed.
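Here is a minimal sketch on a toy IR with hypothetical names. It drops the selected outputs and then removes the equations that only fed them. The text above does not say whether `strip_outputs` itself prunes dead equations, so treat the DCE step as an assumption.

```python
from dataclasses import dataclass, field


# Hypothetical toy IR (not JAX's classes).
@dataclass
class Eqn:
    prim: str
    invars: list
    outvars: list
    params: dict = field(default_factory=dict)


@dataclass
class Program:
    invars: list
    eqns: list
    outvars: list


def dce(prog):
    """Keep only equations whose outputs are (transitively) needed by prog.outvars."""
    live = set(prog.outvars)
    kept = []
    for eqn in reversed(prog.eqns):
        if any(v in live for v in eqn.outvars):
            kept.append(eqn)
            live.update(eqn.invars)
    return Program(list(prog.invars), kept[::-1], list(prog.outvars))


def strip_outputs_sketch(prog, indices):
    """Remove outputs at the given indices, then prune equations that fed only them."""
    unknown = set(indices) - set(range(len(prog.outvars)))
    if unknown:
        raise IndexError(f"no such output indices: {sorted(unknown)}")
    kept = [v for i, v in enumerate(prog.outvars) if i not in indices]
    return dce(Program(prog.invars, prog.eqns, kept))
```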

Internal

assert_tree_match(struct1, struct2)

Asserts that two JAX tree structures are the same.
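Here is a sketch that uses the public `jax.tree_util` API. It assumes each argument is either a pytree or an existing `PyTreeDef`, and the name `assert_tree_match_sketch` is hypothetical. Only structure is compared, so leaf values are ignored. It raises `AssertionError` explicitly, so the check still runs under `python -O`.

```python
import jax


def _as_treedef(x):
    # Accept an existing PyTreeDef as-is; otherwise compute the structure of the pytree.
    if isinstance(x, jax.tree_util.PyTreeDef):
        return x
    return jax.tree_util.tree_structure(x)


def assert_tree_match_sketch(struct1, struct2):
    """Raise AssertionError if the two trees have different structures."""
    t1, t2 = _as_treedef(struct1), _as_treedef(struct2)
    if t1 != t2:
        raise AssertionError(f"tree structures differ: {t1} vs {t2}")
```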