diff --git a/setup.py b/setup.py index 21498697..cf634f4f 100644 --- a/setup.py +++ b/setup.py @@ -13,21 +13,12 @@ long_description = f.read() # Minimal requried dependencies (full dependencies in requirements.txt) -install_requires = ['cloudpickle', - 'pyyaml', - 'opencv-python', - 'colored', - 'mpi4py', - 'numpy', +install_requires = ['numpy', 'scipy', - 'matplotlib', - 'seaborn', - 'scikit-image', - 'scikit-learn', - 'imageio', - 'pandas', 'gym', - 'cma'] + 'cloudpickle', + 'pyyaml', + 'colored'] tests_require = ['pytest', 'flake8', 'sphinx', @@ -53,3 +44,19 @@ 'Natural Language :: English', 'Topic :: Scientific/Engineering :: Artificial Intelligence'] ) + + +# ensure PyTorch is installed +import pkg_resources +pkg = None +for name in ['torch', 'torch-nightly']: + try: + pkg = pkg_resources.get_distribution(name) + except pkg_resources.DistributionNotFound: + pass +assert pkg is not None, 'PyTorch is not correctly installed.' + +from distutils.version import LooseVersion +import re +version_msg = 'PyTorch of version above 1.0.0 expected' +assert LooseVersion(re.search(r'\d+[.]\d+[.]\d+', pkg.version)[0]) >= LooseVersion('1.0.0'), version_msg