Reference
High-level routines
dissolve_tags(jaxpr, predicate=None)
Remove tags from the given jaxpr.
This traverses the given jaxpr and all nested jaxprs, removing:
- all tags, if predicate is None;
- otherwise, only the tags for which predicate(...) returns True.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| jaxpr | | A JAXPR potentially containing tagged primitives. | required |
| predicate | | An optional function that takes tag parameters and the token shape and returns True if the tag should be dissolved, or False to keep it. | None |

Returns: A new JAXPR with the selected tags removed (all tags if predicate is None).
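A common way to remove an identity-like primitive is the custom-interpreter pattern from the JAX documentation. The approach has three steps: evaluate the jaxpr equation by equation, forward each dissolved tag's input straight to its output, then re-trace the evaluator with jax.make_jaxpr.

The sketch below rests on several assumptions:
- a hypothetical tag primitive `my_tag_p` with a single array token;
- a predicate called as `predicate(params, shape_dtype)`;
- top-level equations only. Unlike the function documented here, it does not descend into nested jaxprs.

```python
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Hypothetical identity primitive standing in for a tag.
my_tag_p = Primitive("my_tag")
my_tag_p.def_impl(lambda x, **params: x)
my_tag_p.def_abstract_eval(lambda aval, **params: aval)


def eval_without_tags(closed_jaxpr, *args, predicate=None):
    """Evaluate `closed_jaxpr`, skipping the tags selected by `predicate`.

    Only top-level equations are inspected; nested jaxprs are not visited.
    """
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(v):
        # Literals carry their value; variables are looked up in `env`.
        return v.val if hasattr(v, "val") else env[v]

    env.update(zip(jaxpr.constvars, closed_jaxpr.consts))
    env.update(zip(jaxpr.invars, args))
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        drop = False
        if eqn.primitive is my_tag_p:
            aval = eqn.invars[0].aval
            shape = jax.ShapeDtypeStruct(aval.shape, aval.dtype)
            drop = predicate is None or predicate(eqn.params, shape)
        if drop:
            outvals = invals  # identity: the tag's output is simply its input
        else:
            outvals = eqn.primitive.bind(*invals, **eqn.params)
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
        env.update(zip(eqn.outvars, outvals))
    return [read(v) for v in jaxpr.outvars]


def dissolve_sketch(closed_jaxpr, predicate=None):
    """Re-trace the evaluator to obtain a closed jaxpr without the dropped tags."""
    dummies = [jnp.zeros(a.shape, a.dtype) for a in closed_jaxpr.in_avals]
    return jax.make_jaxpr(
        lambda *args: eval_without_tags(closed_jaxpr, *args, predicate=predicate)
    )(*dummies)
```

Re-tracing keeps all untagged equations as they are. Tags the predicate declines are re-bound, so they survive in the new jaxpr.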
inject(closed_jaxpr, injector, ctx, predicate=None)
Inject tags into the given JAXPR using the provided injector function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| closed_jaxpr | | A closed JAXPR potentially containing tagged primitives. | required |
| injector | | A function that takes the context, the token, and tag parameters, and returns a new token and modified context. The injector is converted to a JAXPR internally and inlined at each tag point. | required |
| ctx | | The initial context to pass to the injector. | required |
| predicate | | An optional function that takes tag parameters and the token shape and returns True if the injector should be applied at that tag, or False to skip injection. If None, the injector is applied at all tags. | None |

Returns:

| Type | Description |
|---|---|
| | A new closed JAXPR with the injector inlined at each tag point. The modified jaxpr also returns the final context as an additional output. |

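The context threading can be sketched with a custom interpreter. It walks the equations and, at every tag, calls the injector with the current context, replacing the tag's output with the injector's token. Re-tracing the evaluator inlines the injector's operations into a new closed jaxpr whose last output is the final context.

Everything here is hypothetical:
- `my_tag_p` stands in for the tag primitive;
- the injector is assumed to be called as `injector(ctx, token, **params)`;
- only top-level equations are visited, and there is no predicate;
- the initial context is baked in as a constant.

```python
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Hypothetical identity primitive standing in for a tag.
my_tag_p = Primitive("my_tag")
my_tag_p.def_impl(lambda x, **params: x)
my_tag_p.def_abstract_eval(lambda aval, **params: aval)


def eval_with_injector(closed_jaxpr, injector, ctx, *args):
    """Evaluate the jaxpr, calling `injector` at each top-level tag.

    Returns the original outputs followed by the final context.
    """
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(v):
        # Literals carry their value; variables are looked up in `env`.
        return v.val if hasattr(v, "val") else env[v]

    env.update(zip(jaxpr.constvars, closed_jaxpr.consts))
    env.update(zip(jaxpr.invars, args))
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        if eqn.primitive is my_tag_p:
            # The injector replaces the tag: (ctx, token) -> (new token, new ctx).
            token, ctx = injector(ctx, invals[0], **eqn.params)
            outvals = [token]
        else:
            outvals = eqn.primitive.bind(*invals, **eqn.params)
            if not eqn.primitive.multiple_results:
                outvals = [outvals]
        env.update(zip(eqn.outvars, outvals))
    return [read(v) for v in jaxpr.outvars] + [ctx]


def inject_sketch(closed_jaxpr, injector, ctx):
    """Re-trace the evaluator so the injector's operations become part of a new jaxpr."""
    dummies = [jnp.zeros(a.shape, a.dtype) for a in closed_jaxpr.in_avals]
    return jax.make_jaxpr(
        lambda *args: eval_with_injector(closed_jaxpr, injector, ctx, *args)
    )(*dummies)
```

For example, an injector returning `(token * 2.0, ctx + jnp.sum(token))` scales every tagged value while accumulating the sum of the tagged values in the context.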
iter_tags(jaxpr)
Iterate over all tags in the given JAXPR.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| jaxpr | | A JAXPR potentially containing tagged primitives. | required |

Yields:

| Type | Description |
|---|---|
| | Tuples of (parameters, ShapeDtypeStruct \| None) for each tag. |

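A traversal in this spirit walks the equations of a jaxpr. It also recurses into any jaxpr stored in an equation's parameters, such as the body of a jit call or the branches of a cond.

This sketch makes a few assumptions:
- a hypothetical identity primitive `my_tag_p` plays the role of the tag;
- the token is a single array, so it always yields a ShapeDtypeStruct rather than None;
- nested jaxprs are detected by duck-typing.

```python
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Hypothetical identity primitive standing in for a tag.
my_tag_p = Primitive("my_tag")
my_tag_p.def_impl(lambda x, **params: x)
my_tag_p.def_abstract_eval(lambda aval, **params: aval)


def _sub_jaxprs(params):
    """Yield the open jaxprs stored in an equation's params (e.g. jit, cond)."""
    for value in params.values():
        items = value if isinstance(value, (tuple, list)) else (value,)
        for item in items:
            if hasattr(item, "jaxpr") and hasattr(item, "consts"):
                yield item.jaxpr  # a ClosedJaxpr: descend into its open jaxpr
            elif hasattr(item, "eqns"):
                yield item  # already an open Jaxpr


def iter_tags_sketch(jaxpr):
    """Yield (params, ShapeDtypeStruct) for every my_tag equation, recursively."""
    for eqn in jaxpr.eqns:
        if eqn.primitive is my_tag_p:
            aval = eqn.invars[0].aval
            yield dict(eqn.params), jax.ShapeDtypeStruct(aval.shape, aval.dtype)
        for sub in _sub_jaxprs(eqn.params):
            yield from iter_tags_sketch(sub)
```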
tag(token, **params)
Tag a specific point in a computation with given parameters.
Note: You must consume the output of this function for the tag to appear in the JAXPR. Simply calling this function without using its output may lead to the tag being optimized away.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| token | | An input token representing a point in the computation. The token can be any PyTree of JAX arrays. | required |
| **params | | Arbitrary keyword parameters to associate with the tag. | {} |

Returns: The unchanged input token tagged with the provided parameters.
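One way to picture a tag is as an identity primitive that carries keyword parameters. The sketch below is an illustration under stated assumptions, not this library's implementation:
- the primitive `my_tag_p` and the helper `my_tag` are hypothetical;
- the token is restricted to a single array instead of an arbitrary PyTree.

Because the tagged value feeds the rest of the computation, the tag equation shows up in the traced JAXPR with its parameters.

```python
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Hypothetical identity primitive standing in for a tag.
my_tag_p = Primitive("my_tag")
my_tag_p.def_impl(lambda token, **params: token)  # eager: return the token unchanged
my_tag_p.def_abstract_eval(lambda aval, **params: aval)  # same shape and dtype out


def my_tag(token, **params):
    """Attach `params` to `token` (a single array) and return it unchanged."""
    return my_tag_p.bind(token, **params)


def f(x):
    # The tagged value is consumed downstream, so the tag stays on the data path.
    return jnp.sin(my_tag(x, name="input")) * 2.0


closed = jax.make_jaxpr(f)(jnp.ones(3, jnp.float32))
tag_eqns = [e for e in closed.jaxpr.eqns if e.primitive.name == "my_tag"]
print(tag_eqns[0].params)  # {'name': 'input'}
```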
transpose(jaxpr)
Transpose a closed JAXPR so that it returns the outputs of each tag instead of its original outputs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| jaxpr | | A closed JAXPR potentially containing tagged primitives. | required |

Returns:

| Type | Description |
|---|---|
| | A new closed JAXPR that returns the outputs of each tag. |
| | A list of tag output parameters. |
| | A list of tag output structures. |

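One plausible sketch of this transformation evaluates the jaxpr and collects the value produced at each tag. It then re-traces that evaluator so the new jaxpr's outputs are exactly those values, and it reads the tag parameters and output structures alongside.

The sketch assumes:
- a hypothetical identity primitive `my_tag_p` and single-array tokens;
- top-level equations only.

It makes no attempt to prune equations that no longer feed any output.

```python
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import Primitive  # newer JAX releases
except ImportError:
    from jax.core import Primitive  # older JAX releases

# Hypothetical identity primitive standing in for a tag.
my_tag_p = Primitive("my_tag")
my_tag_p.def_impl(lambda x, **params: x)
my_tag_p.def_abstract_eval(lambda aval, **params: aval)


def eval_tag_outputs(closed_jaxpr, *args):
    """Evaluate the jaxpr and return the value produced by each top-level tag."""
    jaxpr = closed_jaxpr.jaxpr
    env = {}

    def read(v):
        # Literals carry their value; variables are looked up in `env`.
        return v.val if hasattr(v, "val") else env[v]

    env.update(zip(jaxpr.constvars, closed_jaxpr.consts))
    env.update(zip(jaxpr.invars, args))
    tag_outputs = []
    for eqn in jaxpr.eqns:
        invals = [read(v) for v in eqn.invars]
        outvals = eqn.primitive.bind(*invals, **eqn.params)
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        if eqn.primitive is my_tag_p:
            tag_outputs.append(outvals[0])  # returned instead of the jaxpr's outputs
        env.update(zip(eqn.outvars, outvals))
    return tag_outputs


def transpose_sketch(closed_jaxpr):
    """Return (jaxpr of tag outputs, tag parameters, tag output structures)."""
    params = [dict(e.params) for e in closed_jaxpr.jaxpr.eqns if e.primitive is my_tag_p]
    dummies = [jnp.zeros(a.shape, a.dtype) for a in closed_jaxpr.in_avals]
    new_jaxpr = jax.make_jaxpr(lambda *args: eval_tag_outputs(closed_jaxpr, *args))(*dummies)
    structs = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in new_jaxpr.out_avals]
    return new_jaxpr, params, structs
```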
inline_jaxpr(eqn)
Inline a jaxpr contained in an equation.
partial(jaxpr, /, args)
Partially evaluate a ClosedJaxpr by fixing certain input arguments.
Similar to functools.partial, but for JAX ClosedJaxprs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| jaxpr | | The ClosedJaxpr to partially evaluate. | required |
| args | | A dictionary mapping input argument indices to their fixed values. | required |

Returns:

| Type | Description |
|---|---|
| | A new ClosedJaxpr with the specified inputs fixed. The number of inputs in the returned jaxpr is the original number of inputs minus the number of fixed inputs. |

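A simple way to get this effect, assuming the cost of re-tracing is acceptable, has three steps:
1. wrap the ClosedJaxpr as a Python function;
2. close over the fixed values;
3. trace again with jax.make_jaxpr.

The helper `partial_sketch` below is hypothetical. The fixed values end up as constants of the new jaxpr, and the location of `jaxpr_as_fun` differs between JAX versions.

```python
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import jaxpr_as_fun  # newer JAX releases
except ImportError:
    from jax.core import jaxpr_as_fun  # older JAX releases


def partial_sketch(closed_jaxpr, args):
    """Fix the inputs in `args` ({input index: value}) and re-trace the rest."""
    fun = jaxpr_as_fun(closed_jaxpr)
    n_in = len(closed_jaxpr.in_avals)
    free = [i for i in range(n_in) if i not in args]

    def wrapped(*free_vals):
        # Merge fixed and free values back into the original argument order.
        full = dict(args)
        full.update(zip(free, free_vals))
        return fun(*(full[i] for i in range(n_in)))

    avals = closed_jaxpr.in_avals
    dummies = [jnp.zeros(avals[i].shape, avals[i].dtype) for i in free]
    return jax.make_jaxpr(wrapped)(*dummies)
```

For `f(x, y, z) = x * y + z`, fixing input 1 to 10.0 leaves a two-input jaxpr computing `x * 10.0 + z`.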
partition_out(jaxpr, outvar_indices)
Partition a ClosedJaxpr into multiple ClosedJaxprs, each computing a single outvar.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| jaxpr | | The ClosedJaxpr to partition. | required |
| outvar_indices | | The list of outvar indices each partition should compute. | required |

Returns: A list of ClosedJaxprs, each computing the specified outvars.
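The sketch below re-traces the jaxpr once per partition and keeps only the selected outputs. The description above leaves open whether each entry of outvar_indices is a single index or a group of indices, so the hypothetical `partition_sketch` accepts either. It does not attempt to prune equations that a partition no longer needs.

```python
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import jaxpr_as_fun  # newer JAX releases
except ImportError:
    from jax.core import jaxpr_as_fun  # older JAX releases


def _selector(fun, group):
    """Wrap `fun` so it returns only the outputs at the indices in `group`."""
    def select(*args):
        outs = fun(*args)
        return [outs[i] for i in group]
    return select


def partition_sketch(closed_jaxpr, outvar_indices):
    """Return one ClosedJaxpr per entry of `outvar_indices` (an int or a sequence)."""
    fun = jaxpr_as_fun(closed_jaxpr)
    dummies = [jnp.zeros(a.shape, a.dtype) for a in closed_jaxpr.in_avals]
    parts = []
    for group in outvar_indices:
        group = [group] if isinstance(group, int) else list(group)
        parts.append(jax.make_jaxpr(_selector(fun, group))(*dummies))
    return parts
```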
rewrite_invars(eqn, varmap)
Rewrite the invars of an equation according to a variable mapping.
rewrite_outvars(eqn, varmap)
Rewrite the outvars of an equation according to a variable mapping.
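Assuming varmap is a dict from existing jaxpr variables to their replacements, both rewrites can be sketched with the equation's `replace` method, which returns a modified copy. The helpers below are hypothetical, and literal operands are left untouched.

```python
import jax
import jax.numpy as jnp


def rewrite_invars_sketch(eqn, varmap):
    """Return a copy of `eqn` with each input variable renamed through `varmap`."""
    # Literal operands (which carry a `.val`) are kept as they are.
    invars = [v if hasattr(v, "val") else varmap.get(v, v) for v in eqn.invars]
    return eqn.replace(invars=invars)


def rewrite_outvars_sketch(eqn, varmap):
    """Return a copy of `eqn` with each output variable renamed through `varmap`."""
    return eqn.replace(outvars=[varmap.get(v, v) for v in eqn.outvars])


closed = jax.make_jaxpr(lambda x, y: jax.lax.sin(x) + y)(jnp.float32(1.0), jnp.float32(2.0))
x_var, y_var = closed.jaxpr.invars
sin_eqn = closed.jaxpr.eqns[0]
new_eqn = rewrite_invars_sketch(sin_eqn, {x_var: y_var})  # the copy now reads y
```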
rewrite_vars(jaxpr, varmap)
Rewrite the invars and outvars of a jaxpr contained in an equation.
strip_outputs(jaxpr, /, indices)
Strip specified outputs from a ClosedJaxpr.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| jaxpr | | The ClosedJaxpr to strip outputs from. | required |
| indices | | A set of output indices to remove. | required |

Returns: A new ClosedJaxpr with the specified outputs removed.
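A minimal sketch, assuming re-tracing is acceptable, wraps the ClosedJaxpr as a function and drops the unwanted outputs. It then traces the wrapper again. The helper `strip_outputs_sketch` is hypothetical and makes no attempt to prune equations that only fed the removed outputs.

```python
import jax
import jax.numpy as jnp

try:
    from jax.extend.core import jaxpr_as_fun  # newer JAX releases
except ImportError:
    from jax.core import jaxpr_as_fun  # older JAX releases


def strip_outputs_sketch(closed_jaxpr, indices):
    """Return a ClosedJaxpr without the outputs whose positions are in `indices`."""
    fun = jaxpr_as_fun(closed_jaxpr)

    def kept(*args):
        # Keep every output whose position is not listed in `indices`.
        return [o for i, o in enumerate(fun(*args)) if i not in indices]

    dummies = [jnp.zeros(a.shape, a.dtype) for a in closed_jaxpr.in_avals]
    return jax.make_jaxpr(kept)(*dummies)
```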
Internal
assert_tree_match(struct1, struct2)
Asserts that two JAX tree structures are the same.
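Assuming the arguments are PyTreeDefs (for example from jax.tree_util.tree_structure), the check reduces to an equality test. The helper below is hypothetical. It raises AssertionError explicitly so the check still runs under `python -O`.

```python
import jax


def assert_tree_match_sketch(struct1, struct2):
    """Raise AssertionError unless two PyTreeDefs describe the same structure."""
    if struct1 != struct2:
        raise AssertionError(f"tree structures differ: {struct1} vs {struct2}")


s1 = jax.tree_util.tree_structure({"w": 1.0, "b": (2.0, 3.0)})
s2 = jax.tree_util.tree_structure({"w": 0.0, "b": (0.0, 0.0)})
assert_tree_match_sketch(s1, s2)  # same structure, different leaves: no error
```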