From 3c6b9c6eaf106ef50bd859a12cae3c0c89249d34 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 29 Apr 2024 08:50:24 +0100 Subject: [PATCH] [Versioning] Allow any torch version for local builds (#2130) --- .github/scripts/m1_script.sh | 2 +- .github/workflows/wheels.yml | 4 ++-- setup.py | 14 +++++++++----- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/.github/scripts/m1_script.sh b/.github/scripts/m1_script.sh index 6d2f194e3bc..6552d8e4622 100644 --- a/.github/scripts/m1_script.sh +++ b/.github/scripts/m1_script.sh @@ -1,5 +1,5 @@ #!/bin/bash -export BUILD_VERSION=0.4.0 +export TORCHRL_BUILD_VERSION=0.4.0 ${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml index f46c20b8a7e..9b2e57db531 100644 --- a/.github/workflows/wheels.yml +++ b/.github/workflows/wheels.yml @@ -32,7 +32,7 @@ jobs: run: | export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH" python3 -mpip install wheel - BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel + TORCHRL_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel # NB: wheels have the linux_x86_64 tag so we rename to manylinux1 # find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \; # pytorch/pytorch binaries are also manylinux_2_17 compliant but they @@ -72,7 +72,7 @@ jobs: shell: bash run: | python3 -mpip install wheel - BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel + TORCHRL_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel - name: Upload wheel for the test-wheel job uses: actions/upload-artifact@v2 with: diff --git a/setup.py b/setup.py index fee902d5486..0196cb4a8f4 100644 --- a/setup.py +++ b/setup.py @@ -32,8 +32,8 @@ def get_version(): version_txt = os.path.join(cwd, "version.txt") with open(version_txt, "r") as f: version = f.readline().strip() - if os.getenv("BUILD_VERSION"): - version = os.getenv("BUILD_VERSION") + if os.getenv("TORCHRL_BUILD_VERSION"): + version = os.getenv("TORCHRL_BUILD_VERSION") elif sha != "Unknown": version += "+" + sha[:7] return version @@ -68,11 +68,13 @@ def write_version_file(version): f.write("git_version = {}\n".format(repr(sha))) -def _get_pytorch_version(is_nightly): +def _get_pytorch_version(is_nightly, is_local): # if "PYTORCH_VERSION" in os.environ: # return f"torch=={os.environ['PYTORCH_VERSION']}" if is_nightly: return "torch>=2.4.0.dev" + elif is_local: + return "torch" return "torch>=2.3.0" @@ -178,10 +180,12 @@ def _main(argv): else: version = get_version() write_version_file(version) + TORCHRL_BUILD_VERSION = os.getenv("TORCHRL_BUILD_VERSION") logging.info("Building wheel {}-{}".format(package_name, version)) - logging.info(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}") + logging.info(f"TORCHRL_BUILD_VERSION is {TORCHRL_BUILD_VERSION}") - pytorch_package_dep = _get_pytorch_version(is_nightly) + is_local = TORCHRL_BUILD_VERSION is None + pytorch_package_dep = _get_pytorch_version(is_nightly, is_local) logging.info("-- PyTorch dependency:", pytorch_package_dep) # branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"]) # tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])