Skip to content

Commit

Permalink
fix LTS patching
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed May 10, 2022
1 parent 3a37760 commit 3f19a63
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions light_the_torch/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,27 +207,30 @@ def postprocessing(input, output):


def get_extra_index_urls(computation_backends, channel):
# TODO: this template is not valid for all backends
channel_path = f"{channel.name.lower()}/" if channel != Channel.STABLE else ""
if channel == Channel.STABLE:
channel_paths = [""]
elif channel == Channel.LTS:
channel_paths = [
f"lts/{major}.{minor}/"
for major, minor in [
(1, 8),
]
]
else:
channel_paths = [f"{channel.name.lower()}/"]
return [
f"https://download.pytorch.org/whl/{channel_path}{backend}"
for backend in sorted(computation_backends)
for channel_path, backend in itertools.product(
channel_paths, sorted(computation_backends)
)
]


@contextlib.contextmanager
def patch_link_collection(computation_backends, channel):
if channel == channel != Channel.LTS:
find_links = []
index_urls = get_extra_index_urls(computation_backends, channel)
else:
# TODO: expand this when there are more LTS versions
# TODO: switch this to index_urls when
# https://github.com/pytorch/pytorch/pull/74753 is resolved
find_links = ["https://download.pytorch.org/whl/lts/1.8/torch_lts.html"]
index_urls = []

search_scope = SearchScope.create(find_links=find_links, index_urls=index_urls)
search_scope = SearchScope.create(
find_links=[], index_urls=get_extra_index_urls(computation_backends, channel)
)

@contextlib.contextmanager
def context(input):
Expand Down

0 comments on commit 3f19a63

Please sign in to comment.