Skip to content

Commit

Permalink
Adding support for models that contain conds, as long as the branches…
Browse files Browse the repository at this point in the history
… don't register any losses or directly use model parameters.

PiperOrigin-RevId: 713703957
  • Loading branch information
james-martens authored and KfacJaxDev committed Jan 9, 2025
1 parent 729bb56 commit 23155a8
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions kfac_jax/_src/tag_graph_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -1859,12 +1859,16 @@ def _auto_register_tags(
final_tag_locations.append(sub_tag_locations)

if eqn_name == "cond":
# TODO(botev): We need to check each branch has identical registrations
raise NotImplementedError()
if final_tag_locations[0] or final_tag_locations[1]:
# TODO(botev): We need to check each branch has identical registrations
raise NotImplementedError()
sub_tag_locations = []
else:
# Extract the sub jaxpr parameter tag registrations and input vars
[sub_tag_locations] = final_tag_locations # pylint:disable=unbalanced-tuple-unpacking

del final_tag_locations

# Update the jaxpr parameter in the equation
eqn_params = dict(**eqn.params)
if eqn_name == "cond":
Expand All @@ -1880,7 +1884,7 @@ def _auto_register_tags(

eqns.append(eqn.replace(params=eqn_params))

del sub_graph, final_jaxprs, final_tag_locations
del final_jaxprs

# Insert the sub-registrations into the tagged_params
for tag_l in sub_tag_locations:
Expand Down

0 comments on commit 23155a8

Please sign in to comment.