Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate from jax.core to jax.extend.core for several deprecated symbols #307

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions kfac_jax/_src/curvature_blocks/curvature_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Any, Sequence

import jax
import jax.extend as jex
import jax.numpy as jnp
import jax.scipy
from kfac_jax._src import layers_and_loss_tags as tags
Expand Down Expand Up @@ -81,22 +82,22 @@ def name(self) -> str:

@property
def layer_tag_primitive(self) -> tags.LayerTag:
"""The :class:`jax.core.Primitive` corresponding to the block's tag equation."""
"""The :class:`jex.core.Primitive` corresponding to the block's tag equation."""

primitive = self._layer_tag_eq.primitive
assert isinstance(primitive, tgm.tags.LayerTag)

return primitive

@property
def parameter_variables(self) -> tuple[jax.core.Var, ...]:
def parameter_variables(self) -> tuple[jex.core.Var, ...]:
"""The parameter variables of the underlying Jax equation."""

param_vars = []

for p in tags.layer_eqn_data(self._layer_tag_eq).params:

assert isinstance(p, jax.core.Var)
assert isinstance(p, jex.core.Var)
param_vars.append(p)

return tuple(param_vars)
Expand Down
26 changes: 13 additions & 13 deletions kfac_jax/_src/layers_and_loss_tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, Generic, Sequence, TypeVar

import jax
from jax import core
import jax.extend as jex


# Types for annotation
Expand Down Expand Up @@ -94,7 +94,7 @@ def get_loss_outputs(
return tuple(kwargs[name] for name in meta.parameter_dependants)


class LossTag(core.Primitive):
class LossTag(jex.core.Primitive):
"""A Jax primitive for tagging K-FAC losses.

The primitive is no-op at runtime, however its goal is to tag (annotate) the
Expand All @@ -103,7 +103,7 @@ class LossTag(core.Primitive):
curvature matrix.
"""

# Whether the primitive returns multiple outputs (from core.Primitive)
# Whether the primitive returns multiple outputs (from jex.core.Primitive)
multiple_results = True

def __init__(self):
Expand Down Expand Up @@ -175,9 +175,9 @@ def _batching(


def loss_eqn_parameter_dependants(
eqn: jax.core.JaxprEqn,
eqn: jex.core.JaxprEqn,
raise_an_error: bool = True,
) -> list[jax.core.Var]:
) -> list[jex.core.Var]:
"""Returns the parameter dependants variables from the give loss equation."""
if not isinstance(eqn.primitive, LossTag):
if raise_an_error:
Expand All @@ -192,7 +192,7 @@ def loss_eqn_parameter_dependants(


def loss_eqn_construct_loss(
eqn: jax.core.JaxprEqn,
eqn: jex.core.JaxprEqn,
*args: Array,
) -> Any:
"""Constructs an instance of the corresponding :class:`~LossFunction` class."""
Expand All @@ -206,7 +206,7 @@ def loss_eqn_construct_loss(
return meta.loss_class(**kwargs)


def loss_eqn_class_name(eqn: jax.core.JaxprEqn) -> str:
def loss_eqn_class_name(eqn: jex.core.JaxprEqn) -> str:
"""The name of the underlying `~LossFunction` class."""

if not isinstance(eqn.primitive, LossTag):
Expand Down Expand Up @@ -253,7 +253,7 @@ def get_and_verify_layer_meta(
return meta


class LayerTag(core.Primitive):
class LayerTag(jex.core.Primitive):
"""A Jax primitive for tagging K-FAC layers.

The primitive is no-op at runtime, however its goal is to tag (annotate) the
Expand Down Expand Up @@ -347,9 +347,9 @@ def _batching(


def layer_eqn_data( # pytype: disable=invalid-annotation
eqn: jax.core.JaxprEqn,
eqn: jex.core.JaxprEqn,
raise_an_error: bool = True,
) -> LayerData[jax.core.Var]:
) -> LayerData[jex.core.Var]:

if isinstance(eqn.primitive, LayerTag):
return eqn.primitive.layer_data(eqn.invars, eqn.params, str(eqn))
Expand All @@ -360,7 +360,7 @@ def layer_eqn_data( # pytype: disable=invalid-annotation
return LayerData(inputs=(), outputs=(), params=())


def layer_eqn_name(eqn: jax.core.JaxprEqn) -> str:
def layer_eqn_name(eqn: jex.core.JaxprEqn) -> str:
meta = get_and_verify_layer_meta(eqn.invars, eqn.params)
if meta.name is None:
raise ValueError("Layer name must be provided at this stage.")
Expand Down Expand Up @@ -460,11 +460,11 @@ def register_scale_and_shift(
)


class LossTagEqn(core.JaxprEqn):
class LossTagEqn(jex.core.JaxprEqn):
"""A class used only for annotation purposes."""
primitive: LossTag


class LayerTagEqn(core.JaxprEqn):
class LayerTagEqn(jex.core.JaxprEqn):
"""A class used only for annotation purposes."""
primitive: LayerTag
61 changes: 26 additions & 35 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,7 @@
from absl import logging
import immutabledict
import jax

jax_version = (
jax.__version_info__ if hasattr(jax, "__version_info__")
else tuple(map(int, jax.__version__.split("."))))

if jax_version > (0, 4, 11):
import jax.extend as jax_extend # pylint: disable=g-import-not-at-top
import jax.extend as jex

import jax.numpy as jnp # pylint: disable=g-import-not-at-top
from kfac_jax._src import layers_and_loss_tags as tags
Expand All @@ -42,11 +36,11 @@
# Types for annotation
Array = utils.Array
PyTreeDef = utils.PyTreeDef
Var = jax.core.Var
Var = jex.core.Var
Vars = Sequence[Var]
Jaxpr = jax.core.Jaxpr
ClosedJaxpr = jax.core.ClosedJaxpr
JaxprEqn = jax.core.JaxprEqn
Jaxpr = jex.core.Jaxpr
ClosedJaxpr = jex.core.ClosedJaxpr
JaxprEqn = jex.core.JaxprEqn
JaxprEqns = Sequence[JaxprEqn]
T = TypeVar("T")
J = TypeVar("J", Jaxpr, ClosedJaxpr)
Expand All @@ -64,10 +58,7 @@ def eval_jaxpr_eqn(eqn: JaxprEqn, in_values: list[T]) -> list[T]:

subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params)

if jax_version > (0, 4, 11):
user_context = jax_extend.source_info_util.user_context
else:
user_context = jax.core.source_info_util.user_context # pytype: disable=module-attr
user_context = jex.source_info_util.user_context

with user_context(eqn.source_info.traceback):
output = eqn.primitive.bind(*subfuns, *in_values, **bind_params)
Expand Down Expand Up @@ -245,9 +236,9 @@ class JaxprGraph:
it.
manual_registrations: Any layer tag equations that have been manually
registered.
jaxpr: The underlying :class:`jax.core.Jaxpr` part of ``self.closed_jaxpr``.
jaxpr: The underlying :class:`Jaxpr` part of ``self.closed_jaxpr``.
consts: The underlying constants part ``self.closed_jaxpr``.
outvars: The output variables of the underlying :class:`jax.core.Jaxpr` part
outvars: The output variables of the underlying :class:`Jaxpr` part
of ``self.closed_jaxpr``.
"""
name: str
Expand Down Expand Up @@ -294,7 +285,7 @@ def sub_graph_eqns(
eqns.append(next_eqn)

for v in next_eqn.invars:
if (not isinstance(v, jax.core.Literal) and v not in root_vars and
if (not isinstance(v, jex.core.Literal) and v not in root_vars and
v not in processed_vars and v in self.var_to_creation_op):
to_process_eqns.append(self.var_to_creation_op[v])
processed_vars.add(v)
Expand Down Expand Up @@ -383,7 +374,7 @@ def make_jax_graph(
eqns.append(eqn)

sub_graph_vars.update(
v for v in eqn.invars if not isinstance(v, jax.core.Literal)
v for v in eqn.invars if not isinstance(v, jex.core.Literal)
)

consts_i = [
Expand Down Expand Up @@ -461,8 +452,8 @@ class GraphPattern:
in_values_preprocessor: A function that can optionally modify the in_vals
passed to the tag_primitive, from those that are usually the input to
the jaxpr.
jaxpr: The underlying :class:`jax.core.Jaxpr` represented by the pattern.
param_vars: The list of :class:`jax.core.Var` that correspond to parameters
jaxpr: The underlying :class:`Jaxpr` represented by the pattern.
param_vars: The list of :class:`Var` that correspond to parameters
in the pattern.
graph: A :class:`JaxprGraph` representation of the pattern.
"""
Expand Down Expand Up @@ -633,7 +624,7 @@ def add_vars_if_possible(
If at least one of the pattern variables is a parameter, but the
corresponding graph variable is not or vise-versa, the method does not
update the current variables map and returns ``False``. Similarly, if at
least one of the graph variables is a :class:`~jax.core.Literal` (meaning a
least one of the graph variables is a :class:`iteral` (meaning a
constant, independent of the function inputs) and the corresponding
pattern variable is not an input to the pattern, it returns ``False``. In
all other cases it updates the map and returns ``True``.
Expand All @@ -648,12 +639,12 @@ def add_vars_if_possible(
"""
for var1, var2 in zip(eqn_vars, graph_vars):

var2_matchable = isinstance(var2, jax.core.Var) and (
var2_matchable = isinstance(var2, jex.core.Var) and (
var2 in matchable_graph_params)

if (var1 in param_variables and not var2_matchable or
var1 not in param_variables and var2_matchable or
(isinstance(var2, jax.core.Literal) and var1 not in input_vars)):
(isinstance(var2, jex.core.Literal) and var1 not in input_vars)):
return False

current_variables_map.update(zip(eqn_vars, graph_vars))
Expand Down Expand Up @@ -788,7 +779,7 @@ def match_pattern(
for k, v in match_variables_map.items():

if (k not in pattern.graph.jaxpr.invars and
not isinstance(v, jax.core.Literal)):
not isinstance(v, jex.core.Literal)):

creation_op = graph.var_to_creation_op[v]

Expand Down Expand Up @@ -883,14 +874,14 @@ def find_layer_tags_and_patterns(


def read_env(
env: dict[jax.core.Var, T],
env: dict[jex.core.Var, T],
variables: list[jax.core.Atom],
) -> list[T]:
"""Reads from the variable-to-array environment during tracing."""
result = []
assert isinstance(variables, list)
for v in variables:
if isinstance(v, jax.core.Literal):
if isinstance(v, jex.core.Literal):
# Literals are values baked into the Jaxpr
result.append(v.val)
elif isinstance(v, jax.core.DropVar):
Expand All @@ -901,8 +892,8 @@ def read_env(


def write_env(
env: dict[jax.core.Var, T],
variables: list[jax.core.Var],
env: dict[jex.core.Var, T],
variables: list[jex.core.Var],
values: list[T],
) -> None:
"""Writes to the variable-to-array environment during tracing."""
Expand Down Expand Up @@ -979,7 +970,7 @@ def clean_jaxpr(

final_outvars.append(var)

if not isinstance(var, jax.core.Literal):
if not isinstance(var, jex.core.Literal):
dependants.add(var)

for eqn in reversed(closed_jaxpr.jaxpr.eqns):
Expand Down Expand Up @@ -1035,7 +1026,7 @@ def clean_jaxpr(
if check:
eqns.append(eqn)
new_dependants = set(v for v in eqn.invars
if not isinstance(v, jax.core.Literal))
if not isinstance(v, jex.core.Literal))
dependants = dependants.union(new_dependants)

# Dependants should only be invars
Expand Down Expand Up @@ -1112,7 +1103,7 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J:

# We ignore broadcasting of constants
if (eqn.primitive.name == "broadcast_in_dim" and
not all(isinstance(v, jax.core.Literal) for v in eqn.invars)):
not all(isinstance(v, jex.core.Literal) for v in eqn.invars)):

if eqn.invars[0] in broadcasts_outputs:
# Construct a merged equation from the previous and current one
Expand All @@ -1139,7 +1130,7 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J:

else:
for v in eqn.invars:
if not isinstance(v, jax.core.Literal) and v in broadcasts_outputs:
if not isinstance(v, jex.core.Literal) and v in broadcasts_outputs:
eqns.append(broadcasts_outputs[v])

eqns.append(eqn)
Expand Down Expand Up @@ -1688,7 +1679,7 @@ def __init__(
):
self._func_graph = func_graph
self._tag_locations = tag_locations
self._flat_func = jax.core.jaxpr_as_fun(func_graph.closed_jaxpr)
self._flat_func = jex.core.jaxpr_as_fun(func_graph.closed_jaxpr)
self._param_labels = self._compute_parameter_labels()

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -1770,7 +1761,7 @@ def _auto_register_tags(

eqns_for_registration.append(eqn)
sub_graph_vars.update(
v for v in eqn.invars if not isinstance(v, jax.core.Literal))
v for v in eqn.invars if not isinstance(v, jex.core.Literal))

eqns_for_registration = eqns_for_registration[::-1]

Expand Down
15 changes: 8 additions & 7 deletions kfac_jax/_src/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from absl import logging
import jax
import jax.extend as jex
import jax.numpy as jnp
from kfac_jax._src import layers_and_loss_tags as tags
from kfac_jax._src import loss_functions
Expand All @@ -32,7 +33,7 @@
Params = utils.Params
FuncArgs = utils.FuncArgs
FuncOuts = utils.FuncOuts
Var = jax.core.Var
Var = jex.core.Var
LossFunction = loss_functions.LossFunction
LossFunctionInputs = loss_functions.LossFunctionInputs

Expand Down Expand Up @@ -80,7 +81,7 @@ def tree_unflatten(cls, aux_data, children):
tuple[LayerVjpData[Array], ...], # pytype: disable=invalid-annotation
],
]
JaxprOrClosedJaxpr = jax.core.Jaxpr | jax.core.ClosedJaxpr
JaxprOrClosedJaxpr = jex.core.Jaxpr | jex.core.ClosedJaxpr


def shape_and_type(x: Array) -> tuple[Shape, jnp.dtype]:
Expand All @@ -99,7 +100,7 @@ def make_cache_key(


def extract_tags(
jaxpr: jax.core.Jaxpr,
jaxpr: jex.core.Jaxpr,
) -> tuple[tuple[tags.LayerTagEqn, ...], tuple[tags.LossTagEqn, ...]]:
"""Extracts the layer and the loss tags from the given Jaxpr."""

Expand Down Expand Up @@ -199,7 +200,7 @@ class ProcessedJaxpr(utils.Finalizable):

def __init__(
self,
jaxpr: jax.core.Jaxpr,
jaxpr: jex.core.Jaxpr,
consts: list[Any],
in_tree: utils.PyTreeDef,
params_index: int,
Expand Down Expand Up @@ -819,16 +820,16 @@ def forward_aux(
own_func_args = primal_func_args

# Mapping from variable -> value
env: dict[jax.core.Var, Array] = {}
env: dict[jex.core.Var, Array] = {}
read = functools.partial(tgm.read_env, env)

def write(variables: list[jax.core.Var], values: list[Array]) -> None:
def write(variables: list[jex.core.Var], values: list[Array]) -> None:
# if not isinstance(variables, list):
# variables = [variables]
tgm.write_env(env, variables, values)

for v in variables:
if not isinstance(v, jax.core.Literal) and v in aux:
if not isinstance(v, jex.core.Literal) and v in aux:
env[v] = env[v] + aux[v]

# Bind args and consts to environment
Expand Down
Loading
Loading