Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improves optional imports #366

Merged
merged 7 commits into from
Mar 17, 2022
Merged
Prev Previous commit
Next Next commit
Moved module_utils.py into dependency.py
  • Loading branch information
mattpopovich authored and zhiqwang committed Mar 16, 2022
commit 50ed38cbf319c548c5a7f5bbb465c9eed8a65599
6 changes: 3 additions & 3 deletions yolort/relay/trt_graphsurgeon.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from onnx import shape_inference
from torch import Tensor

import yolort.utils.module_utils as _mod_utils
if _mod_utils.is_module_available("onnx_graphsurgeon"):
import yolort.utils.dependency as _dependency
if _dependency.is_module_available("onnx_graphsurgeon"):
import onnx_graphsurgeon as gs

from .trt_inference import YOLOTRTInference
Expand All @@ -23,7 +23,7 @@
__all__ = ["YOLOTRTGraphSurgeon"]


@_mod_utils.requires_module("onnx_graphsurgeon")
@_dependency.requires_module("onnx_graphsurgeon")
class YOLOTRTGraphSurgeon:
"""
YOLOv5 Graph Surgeon for TensorRT inference.
Expand Down
66 changes: 66 additions & 0 deletions yolort/utils/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

import pkg_resources as pkg

import importlib.util
import warnings
from functools import wraps
from typing import Optional

logger = logging.getLogger(__name__)


Expand All @@ -26,3 +31,64 @@ def check_version(
if verbose and not result:
logger.warning(verbose_info)
return result


# Via: https://github.com/pytorch/audio/blob/main/torchaudio/_internal/module_utils.py
def is_module_available(*modules: str) -> bool:
r"""Returns if a top-level module with :attr:`name` exists *without**
importing it. This is generally safer than try-catch block around a
`import X`. It avoids third party libraries breaking assumptions of some of
our tests, e.g., setting multiprocessing start method when imported
(see librosa/#747, torchvision/#544).
"""
return all(importlib.util.find_spec(m) is not None for m in modules)


# Via: https://github.com/pytorch/audio/blob/main/torchaudio/_internal/module_utils.py
def requires_module(*modules: str):
"""Decorate function to give error message if invoked without required optional modules.
This decorator is to give better error message to users rather
than raising ``NameError: name 'module' is not defined`` at random places.
"""
missing = [m for m in modules if not is_module_available(m)]

if not missing:
# fall through. If all the modules are available, no need to decorate
def decorator(func):
return func

else:
req = f"module: {missing[0]}" if len(missing) == 1 else f"modules: {missing}"

def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
raise RuntimeError(f"{func.__module__}.{func.__name__} requires {req}")

return wrapped

return decorator


# Via: https://github.com/pytorch/audio/blob/main/torchaudio/_internal/module_utils.py
def deprecated(direction: str, version: Optional[str] = None):
"""Decorator to add deprecation message
Args:
direction (str): Migration steps to be given to users.
version (str or int): The version when the object will be removed
"""

def decorator(func):
@wraps(func)
def wrapped(*args, **kwargs):
message = (
f"{func.__module__}.{func.__name__} has been deprecated "
f'and will be removed from {"future" if version is None else version} release. '
f"{direction}"
)
warnings.warn(message, stacklevel=2)
return func(*args, **kwargs)

return wrapped

return decorator
66 changes: 0 additions & 66 deletions yolort/utils/module_utils.py

This file was deleted.