Skip to content

Commit

Permalink
[Versioning] Allow any torch version for local builds (pytorch#2130)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Apr 29, 2024
1 parent 47bb3ef commit 3c6b9c6
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/scripts/m1_script.sh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/bin/bash

export BUILD_VERSION=0.4.0
export TORCHRL_BUILD_VERSION=0.4.0

${CONDA_RUN} pip install git+https://github.com/pytorch/tensordict.git -U
4 changes: 2 additions & 2 deletions .github/workflows/wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ jobs:
run: |
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
python3 -mpip install wheel
BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
TORCHRL_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
# NB: wheels have the linux_x86_64 tag so we rename to manylinux1
# find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \;
# pytorch/pytorch binaries are also manylinux_2_17 compliant but they
Expand Down Expand Up @@ -72,7 +72,7 @@ jobs:
shell: bash
run: |
python3 -mpip install wheel
BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
TORCHRL_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
- name: Upload wheel for the test-wheel job
uses: actions/upload-artifact@v2
with:
Expand Down
14 changes: 9 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def get_version():
version_txt = os.path.join(cwd, "version.txt")
with open(version_txt, "r") as f:
version = f.readline().strip()
if os.getenv("BUILD_VERSION"):
version = os.getenv("BUILD_VERSION")
if os.getenv("TORCHRL_BUILD_VERSION"):
version = os.getenv("TORCHRL_BUILD_VERSION")
elif sha != "Unknown":
version += "+" + sha[:7]
return version
Expand Down Expand Up @@ -68,11 +68,13 @@ def write_version_file(version):
f.write("git_version = {}\n".format(repr(sha)))


def _get_pytorch_version(is_nightly):
def _get_pytorch_version(is_nightly, is_local):
# if "PYTORCH_VERSION" in os.environ:
# return f"torch=={os.environ['PYTORCH_VERSION']}"
if is_nightly:
return "torch>=2.4.0.dev"
elif is_local:
return "torch"
return "torch>=2.3.0"


Expand Down Expand Up @@ -178,10 +180,12 @@ def _main(argv):
else:
version = get_version()
write_version_file(version)
TORCHRL_BUILD_VERSION = os.getenv("TORCHRL_BUILD_VERSION")
logging.info("Building wheel {}-{}".format(package_name, version))
logging.info(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}")
logging.info(f"TORCHRL_BUILD_VERSION is {TORCHRL_BUILD_VERSION}")

pytorch_package_dep = _get_pytorch_version(is_nightly)
is_local = TORCHRL_BUILD_VERSION is None
pytorch_package_dep = _get_pytorch_version(is_nightly, is_local)
logging.info("-- PyTorch dependency:", pytorch_package_dep)
# branch = _run_cmd(["git", "rev-parse", "--abbrev-ref", "HEAD"])
# tag = _run_cmd(["git", "describe", "--tags", "--exact-match", "@"])
Expand Down

0 comments on commit 3c6b9c6

Please sign in to comment.