Skip to content

Commit

Permalink
[Feature] Added implement_for decorator (pytorch#618)
Browse files Browse the repository at this point in the history
* [Feature] Added `implement_for` decorator (#)

* `from_version` can be open (`None`).

* Changed behaviour in case of missing module/version as was discussed.

* Improved tests and exposed `implement_for` in the documentation

Co-authored-by: vmoens <vincentmoens@gmail.com>
  • Loading branch information
ordinskiy and vmoens authored Nov 2, 2022
1 parent 530dac3 commit e96fd37
Show file tree
Hide file tree
Showing 5 changed files with 144 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/reference/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@ API Reference
modules
objectives
trainers
utils
13 changes: 13 additions & 0 deletions docs/source/reference/utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
.. currentmodule:: torchrl._utils

torchrl._utils package
====================

Set of utility methods that are used internally by the library.


.. autosummary::
:toctree: generated/
:template: rl_template.rst

implement_for
4 changes: 4 additions & 0 deletions test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from torchrl.envs import EnvBase


# Specified for test_utils.py
__version__ = "0.3"


def get_relative_path(curr_file, *path_components):
return os.path.join(os.path.dirname(curr_file), *path_components)

Expand Down
63 changes: 62 additions & 1 deletion test/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os

import pytest
from torchrl._utils import get_binary_env_var
from torchrl._utils import get_binary_env_var, implement_for


@pytest.mark.parametrize("value", ["True", "1", "true"])
Expand Down Expand Up @@ -60,3 +65,59 @@ def test_get_binary_env_var_wrong_value():
finally:
if key in os.environ:
del os.environ[key]


class implement_for_test_functions:
"""
Groups functions that are used in tests for `implement_for` decorator.
"""

@staticmethod
@implement_for("_utils_internal", "0.3")
def select_correct_version():
"""To test from+ range and that this function is correctly selected as the implementation."""
return "0.3+"

@staticmethod
@implement_for("_utils_internal", "0.2", "0.3")
def select_correct_version(): # noqa: F811
"""To test that right bound is not included."""
return "0.2-0.3"

@staticmethod
@implement_for("_utils_internal", "0.1", "0.2")
def select_correct_version(): # noqa: F811
"""To test that function with missing from-to range is ignored."""
return "0.1-0.2"

@staticmethod
@implement_for("missing_module")
def missing_module():
"""To test that calling decorated function with missing module raises an exception."""
return "missing"

@staticmethod
@implement_for("_utils_internal", None, "0.3")
def missing_version():
return "0-0.3"

@staticmethod
@implement_for("_utils_internal", "0.4")
def missing_version(): # noqa: F811
return "0.4+"


def test_implement_for():
assert implement_for_test_functions.select_correct_version() == "0.3+"


def test_implement_for_missing_module():
msg = "Supported version of 'missing_module' has not been found."
with pytest.raises(ModuleNotFoundError, match=msg):
implement_for_test_functions.missing_module()


def test_implement_for_missing_version():
msg = "Supported version of '_utils_internal' has not been found."
with pytest.raises(ModuleNotFoundError, match=msg):
implement_for_test_functions.missing_version()
65 changes: 64 additions & 1 deletion torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import math
import os
import time
from functools import wraps
from importlib import import_module

import numpy as np

Expand All @@ -15,6 +17,7 @@ def __init__(self, name):
self.name = name

def __call__(self, fn):
@wraps(fn)
def decorated_fn(*args, **kwargs):
with self:
out = fn(*args, **kwargs)
Expand Down Expand Up @@ -122,7 +125,7 @@ def prod(sequence):


def get_binary_env_var(key):
"""Parses and returns the binary enironment variable value.
"""Parses and returns the binary environment variable value.
If not present in environment, it is considered `False`.
Expand Down Expand Up @@ -176,3 +179,63 @@ def __repr__(self):


_CKPT_BACKEND = _Dynamic_CKPT_BACKEND()


class implement_for:
"""A version decorator that checks the version in the environment and implements a function with the fitting one.
If specified module is missing or there is no fitting implementation, call of the decorated function
will lead to the explicit error.
In case of intersected ranges, first fitting implementation is used.
Args:
module_name: version is checked for the module with this name (e.g. "gym").
from_version: version from which implementation is compatible. Can be open (None).
to_version: version from which implementation is no longer compatible. Can be open (None).
Examples:
>>> @implement_for(“gym”, “0.13”, “0.14”)
>>> def fun(self, x):
This indicates that the function is compatible with gym 0.13+, but doesn't with gym 0.14+.
"""

# Stores pointers to fitting implementations: dict[func_name] = func_pointer
_implementations = {}

def __init__(
self, module_name: str, from_version: str = None, to_version: str = None
):
self.module_name = module_name
self.from_version = from_version
self.to_version = to_version

def __call__(self, fn):
@wraps(fn)
def unsupported():
raise ModuleNotFoundError(
f"Supported version of '{self.module_name}' has not been found."
)

# If the module is missing replace the function with the mock.
try:
module = import_module(self.module_name)
except ModuleNotFoundError:
return unsupported

func_name = f"{fn.__module__}.{fn.__name__}"
implementations = implement_for._implementations

# Return fitting implementation if it was encountered before.
if func_name in implementations:
return implementations[func_name]

version = module.__version__

if (self.from_version is None or version >= self.from_version) and (
self.to_version is None or version < self.to_version
):
implementations[func_name] = fn
return fn

return unsupported

0 comments on commit e96fd37

Please sign in to comment.