Skip to content

Commit

Permalink
Merge pull request #583 from mit-ll-responsible-ai/contextvar-support
Browse files Browse the repository at this point in the history
Add `zen(..., run_in_context: bool)`
  • Loading branch information
rsokl authored Nov 12, 2023
2 parents be15cb1 + 2a38b48 commit 03e4508
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 13 deletions.
3 changes: 2 additions & 1 deletion docs/source/changes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ chronological order. All previous releases should still be available on pip.
.. _v0.12.0:

----------------------
0.12.0rc4 - 2023-11-07
0.12.0rc5 - 2023-11-12
----------------------


Expand Down Expand Up @@ -67,6 +67,7 @@ Improvements
- :class:`~hydra_zen.BuildsFn` was introduced to permit customizable auto-config and type-refinement support in config-creation functions. See :pull:`553`.
- :func:`~hydra_zen.builds` and :func:`~hydra_zen.make_custom_builds_fn` now accept a `zen_exclude` field for excluding parameters from auto-population, either by name or by pattern. See :pull:`558`.
- :func:`~hydra_zen.builds` and :func:`~hydra_zen.just` can now configure static methods. Previously the incorrect ``_target_`` would be resolved. See :pull:`566`
- :func:`hydra_zen.zen` now has first class support for running code in an isolated :py:class:`contextvars.Context`. This enables users to safely leverage state via :py:class:`contextvars.ContextVar` in their task functions. See :pull:`583`.
- Adds formal support for Python 3.12. See :pull:`555`
- Several new methods were added to :class:`~hydra_zen.ZenStore`, including the abilities to copy, update, and merge stores. As well as remap the groups of a store's entries and delete individual entries. See :pull:`569`
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ classifiers = [
]

[project.optional-dependencies]
test = ["pytest >= 3.8", "hypothesis >= 6.28.0"]
test = ["pytest >= 3.8", "hypothesis >= 6.28.0", "pytest-trio >= 0.8.0"]
pydantic = ["pydantic>=1.8.2,<2.0.0"]
beartype = ["beartype>=0.8.0"]

Expand Down Expand Up @@ -104,6 +104,7 @@ exclude_lines = [

[tool.pytest.ini_options]
xfail_strict = true
trio_mode = true


[tool.pyright]
Expand Down Expand Up @@ -149,6 +150,7 @@ deps = setuptools
pytest
hypothesis
pytest-xdist
pytest-trio
tzdata
commands = pytest tests/ {posargs: -n auto --maxprocesses=4}
Expand Down
78 changes: 67 additions & 11 deletions src/hydra_zen/wrapper/_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@

import warnings
from collections import defaultdict
from contextvars import copy_context
from copy import deepcopy
from functools import wraps
from inspect import Parameter, signature
from functools import partial, wraps
from inspect import Parameter, iscoroutinefunction, signature
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -139,6 +140,7 @@ def __init__(
pre_call: PreCall = None,
unpack_kwargs: bool = False,
resolve_pre_call: bool = True,
run_in_context: bool = False,
) -> None:
"""
Parameters
Expand All @@ -148,23 +150,41 @@ def __init__(
unpack_kwargs: bool, optional (default=False)
If `True` a `**kwargs` field in the wrapped function's signature will be
populated by all of the input config entries that are not specified by the
rest of the signature (and that are not specified by the `exclude`
argument).
populated by all of the input config entries that are not specified by the rest
of the signature (and that are not specified by the `exclude` argument).
pre_call : Optional[Callable[[Any], Any] | Iterable[Callable[[Any], Any]]]
One or more functions that will be called with the input config prior
to the wrapped functions. An iterable of pre-call functions are called
to the wrapped function. An iterable of pre-call functions are called
from left (low-index) to right (high-index).
exclude: Optional[str | Iterable[str]]
This is useful, e.g., for seeding a RNG prior to the instantiation phase
that is triggered when calling the wrapped function.
resolve_pre_call : bool, (default=True)
If `True`, the config passed to the zen-wrapped function has its
interpolated fields resolved prior to being passed to any pre-call
functions. Otherwise, the interpolation occurs after the pre-call functions
are called.
exclude : Optional[str | Iterable[str]]
Specifies one or more parameter names in the function's signature
that will not be extracted from input configs by the zen-wrapped function.
A single string of comma-separated names can be specified.
run_in_context : bool, optional (default=False)
If `True`, the zen-wrapped function - and the `pre_call` function, if
specified - is run in a copied :py:class:`contextvars.Context`; i.e.
changes made to any :py:class:`contextvars.ContextVar` will be isolated to
that call of the wrapped function.
`run_in_context` is not supported for async functions.
"""
self.func: Callable[P, R] = __func
if run_in_context and iscoroutinefunction(__func):
raise TypeError(f"`{run_in_context=} is not supported for async functions.")

self.func: Callable[P, R] = __func
try:
# Must cast to dict so that `self` is pickle-compatible.
self.parameters: Mapping[str, Parameter] = dict(
Expand All @@ -183,11 +203,18 @@ def __init__(
raise TypeError(
f"`resolve_pre_call` must be type `bool` got {resolve_pre_call}"
)

if not isinstance(run_in_context, bool): # pragma: no cover
raise TypeError(
f"`run_in_context` must be type `bool` got {run_in_context}"
)

self._resolve = resolve_pre_call
self._unpack_kwargs: bool = unpack_kwargs and any(
p.kind is p.VAR_KEYWORD for p in self.parameters.values()
)

self._run_in_context: bool = run_in_context
self._exclude: Set[str]

if exclude is None:
Expand Down Expand Up @@ -216,7 +243,13 @@ def __init__(
if _f is None:
continue

_f_params = signature(_f).parameters
if run_in_context and isinstance(_f, Zen) and _f._run_in_context:
raise HydraZenValidationError(
f"zen-wrapped pre_call function {_f!r} cannot specify "
f"`run_in_context=True` when the main wrapper specifies it as well."
)

_f_params = signature(_f).parameters # type: ignore

if (sum(p.default is p.empty for p in _f_params.values()) > 1) or len(
_f_params
Expand Down Expand Up @@ -353,8 +386,15 @@ def __call__(self, __cfg: Union[ConfigLike, str]) -> R:
# resolves all interpolated values in-place
OmegaConf.resolve(cfg)

context = copy_context() if self._run_in_context else None

if self.pre_call is not None:
self.pre_call(cfg)
pre_call = (
self.pre_call
if context is None
else partial(context.run, self.pre_call)
)
pre_call(cfg)

args_ = list(getattr(cfg, "_args_", []))

Expand All @@ -377,7 +417,10 @@ def __call__(self, __cfg: Union[ConfigLike, str]) -> R:
and isinstance(name, str)
)
cfg_kwargs.update({name: cfg[name] for name in names})
return self.func(

func = self.func if context is None else partial(context.run, self.func)

return func(
*(self.instantiate(x) if is_instantiable(x) else x for x in args_),
**{
name: self.instantiate(val) if is_instantiable(val) else val
Expand Down Expand Up @@ -472,6 +515,7 @@ def zen(
pre_call: PreCall = ...,
ZenWrapper: Type[Zen[Any, Any]] = ...,
resolve_pre_call: bool = ...,
run_in_context: bool = ...,
exclude: Optional[Union[str, Iterable[str]]] = ...,
) -> Zen[P, R]:
...
Expand All @@ -485,6 +529,7 @@ def zen(
pre_call: PreCall = ...,
resolve_pre_call: bool = ...,
ZenWrapper: Type[Zen[Any, Any]] = ...,
run_in_context: bool = ...,
exclude: Optional[Union[str, Iterable[str]]] = ...,
) -> Callable[[Callable[P2, R2]], Zen[P2, R2]]:
...
Expand All @@ -497,6 +542,7 @@ def zen(
pre_call: PreCall = None,
exclude: Optional[Union[str, Iterable[str]]] = None,
resolve_pre_call: bool = True,
run_in_context: bool = False,
ZenWrapper: Type[Zen[Any, Any]] = Zen,
) -> Union[Callable[[Callable[P2, R2]], Zen[P2, R2]], Zen[P, R]]:
r"""zen(func, /, pre_call, ZenWrapper)
Expand Down Expand Up @@ -545,6 +591,14 @@ def zen(
A single string of comma-separated names can be specified.
run_in_context : bool, optional (default=False)
If `True`, the zen-wrapped function - and the `pre_call` function, if
specified - is run in a copied :py:class:`contextvars.Context`; i.e.
changes made to any :py:class:`contextvars.ContextVar` will be isolated to
that call of the wrapped function.
`run_in_context` is not supported for async functions.
ZenWrapper : Type[hydra_zen.wrapper.Zen], optional (default=Zen)
If specified, a subclass of `Zen` that customizes the behavior of the wrapper.
Expand Down Expand Up @@ -746,6 +800,7 @@ def task(x: int, y: int):
exclude=exclude,
unpack_kwargs=unpack_kwargs,
resolve_pre_call=resolve_pre_call,
run_in_context=run_in_context,
),
)

Expand All @@ -758,6 +813,7 @@ def wrap(f: Callable[P2, R2]) -> Zen[P2, R2]:
exclude=exclude,
unpack_kwargs=unpack_kwargs,
resolve_pre_call=resolve_pre_call,
run_in_context=run_in_context,
),
)
return out
Expand Down
7 changes: 7 additions & 0 deletions tests/test_zen.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,10 @@ def pikl(x):
def test_pickle_compatible():
loaded = pickle.loads(pickle.dumps(zpikl))
assert loaded({"x": 3}) == pikl(3)


async def test_async_compatible():
async def foo(x: int):
return x

assert await zen(foo)(dict(x=builds(int, 22))) == 22
87 changes: 87 additions & 0 deletions tests/test_zen_context_isolation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
# Copyright (c) 2023 Massachusetts Institute of Technology
# SPDX-License-Identifier: MIT

import random
from contextvars import ContextVar
from typing import Any, Dict, Optional

import pytest

from hydra_zen import zen
from hydra_zen.errors import HydraZenValidationError

config: ContextVar[Optional[Dict[str, Any]]] = ContextVar("config", default=None)
var: ContextVar[Dict[str, Any]] = ContextVar("var", default=dict())


@pytest.fixture(autouse=True)
def clean_context_vars():
assert config.get() is None
assert var.get() == {}
yield
config.set(None)
var.set({})


@pytest.mark.parametrize(
"run_in_context",
[
True,
pytest.param(False, marks=pytest.mark.xfail),
],
)
def test_context_isolation(run_in_context: bool):
def foo(x: str, zen_cfg):
config.set(zen_cfg)
conf = var.get().copy()
conf[str(random.randint(1, 100))] = random.randint(1, 100)
var.set(conf)
assert len(conf) == 1

zfoo = zen(foo, run_in_context=run_in_context)

for letter in "ab":
zfoo(dict(x=letter))
assert config.get() is None
assert var.get() == dict()


def test_async_func_run_in_context_not_supported():
async def foo():
...

with pytest.raises(TypeError, match="not supported"):
zen(foo, run_in_context=True)


@pytest.mark.parametrize(
"run_in_context",
[
True,
pytest.param(False, marks=pytest.mark.xfail),
],
)
def test_pre_call_shares_context_with_wrapped_func(run_in_context: bool):
assert var.get() == {}

def pre_call(cfg):
var.set({"swagger": 22})

def func():
assert var.get() == {"swagger": 22}

zen(func, pre_call=pre_call, run_in_context=run_in_context)({})
assert var.get() == {}


def test_pre_call_run_in_its_own_context_is_forbidden():
def f(x):
...

with pytest.raises(HydraZenValidationError):
zen(f, pre_call=zen(f, run_in_context=True), run_in_context=True)


def test_validation():
with pytest.raises(TypeError, match="must be type"):
zen(lambda x: x, run_in_context=None) # type: ignore

0 comments on commit 03e4508

Please sign in to comment.