Skip to content

Commit

Permalink
Drop cmake build system, use PyTorch C++ extensions (pytorch#239)
Browse files Browse the repository at this point in the history
  • Loading branch information
zou3519 authored Jun 30, 2022
1 parent 5833b2f commit 61a00e8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 174 deletions.
109 changes: 0 additions & 109 deletions CMakeLists.txt

This file was deleted.

57 changes: 54 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
from build_tools import setup_helpers
from setuptools import setup, find_packages

import glob
from torch.utils.cpp_extension import (
CppExtension,
BuildExtension,
)



def _get_pytorch_version():
if "PYTORCH_VERSION" in os.environ:
Expand Down Expand Up @@ -60,6 +67,50 @@ def _run_cmd(cmd):
return None


def get_extensions():
extension = CppExtension

extra_link_args = []
extra_compile_args = {"cxx": [
"-O3",
"-std=c++14",
"-fdiagnostics-color=always",
]}
debug_mode = os.getenv('DEBUG', '0') == '1'
if debug_mode:
print("Compiling in debug mode")
extra_compile_args = {
"cxx": [
"-O0",
"-fno-inline",
"-g",
"-std=c++14",
"-fdiagnostics-color=always",
]}
extra_link_args = ["-O0", "-g"]

this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "torchrl", "csrc")

extension_sources = set(
os.path.join(extensions_dir, p)
for p in glob.glob(os.path.join(extensions_dir, "*.cpp"))
)
sources = list(extension_sources)

ext_modules = [
extension(
"torchrl._torchrl",
sources,
include_dirs=[this_dir],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args,
)
]

return ext_modules


def _main():
pytorch_package_dep = _get_pytorch_version()
print("-- PyTorch dependency:", pytorch_package_dep)
Expand All @@ -71,10 +122,10 @@ def _main():
version="0.1",
author="torchrl contributors",
author_email="vmoens@fb.com",
packages=_get_packages(),
ext_modules=setup_helpers.get_ext_modules(),
packages=find_packages(),
ext_modules=get_extensions(),
cmdclass={
"build_ext": setup_helpers.CMakeBuild,
"build_ext": BuildExtension.with_options(no_python_abi_suffix=True),
"clean": clean,
},
install_requires=[pytorch_package_dep, "numpy", "tensorboard", "packaging"],
Expand Down
62 changes: 0 additions & 62 deletions torchrl/csrc/CMakeLists.txt

This file was deleted.

0 comments on commit 61a00e8

Please sign in to comment.