Automate downloading of dependent source packages #2062

Merged

Changes from 1 commit

Add missing packages, add support for fetching from gitlab
dagardner-nv committed Nov 21, 2024
commit 784f615ebc409d7d013767918df9425b80f727a2
137 changes: 105 additions & 32 deletions scripts/generate_deps.py
@@ -40,7 +40,8 @@

PIP_FLAGS_RE = re.compile(r"^--.*")
STRIP_VER_RE = re.compile(r"^([\w|-]+).*")
TAG_URL_PATH = "{base_url}/archive/refs/tags/{tag}.tar.gz"
GIT_HUB_TAG_URL_PATH = "{base_url}/archive/refs/tags/{tag}.tar.gz"
GIT_LAB_TAG_URL_PATH = "{base_url}/-/archive/{tag}/{name}-{tag}.tar.gz"

# In some cases multiple packages are derived from a single upstream repo, please keep sorted
PACKAGE_ALIASES = { # <conda package name>: <upstream name>
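
For reference, a rough sketch of how the two templates expand. The glog and pkg-config values come from the tables further down in this diff; the exact tag strings and output URLs are an assumption about what the script produces, not captured output:

# Sketch only: expansion of the two tarball URL templates, with values assumed
# from the tables below (glog 0.7.1 on GitHub, pkg-config 0.29.2 on GitLab).
GIT_HUB_TAG_URL_PATH.format(base_url="https://github.com/google/glog", tag="v0.7.1")
# -> "https://github.com/google/glog/archive/refs/tags/v0.7.1.tar.gz"
GIT_LAB_TAG_URL_PATH.format(name="pkg-config",
                            base_url="https://gitlab.freedesktop.org/pkg-config/pkg-config",
                            tag="pkg-config-0.29.2")
# -> "https://gitlab.freedesktop.org/pkg-config/pkg-config/-/archive/pkg-config-0.29.2/pkg-config-pkg-config-0.29.2.tar.gz"
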
@@ -54,6 +55,7 @@
"python-confluent-kafka": "confluent-kafka-python",
"python-graphviz": "graphviz",
"torch": "pytorch",
"versioneer": "python-versioneer",
}

KNOWN_GITHUB_URLS = { # <package>: <github repo>, please keep sorted
@@ -70,6 +72,8 @@
'docker-py': 'https://github.com/docker/docker-py',
'elasticsearch-py': 'https://github.com/elastic/elasticsearch-py',
'feedparser': 'https://github.com/kurtmckee/feedparser',
'gflags': 'https://github.com/gflags/gflags',
'glog': 'https://github.com/google/glog',
'graphviz': 'https://github.com/xflr6/graphviz',
'grpc': 'https://github.com/grpc/grpc',
'json': 'https://github.com/nlohmann/json',
@@ -82,6 +86,7 @@
'pluggy': 'https://github.com/pytest-dev/pluggy',
'protobuf': 'https://github.com/protocolbuffers/protobuf',
'pybind11': 'https://github.com/pybind/pybind11',
'pybind11-stubgen': 'https://github.com/sizmailov/pybind11-stubgen',
'pydantic': 'https://github.com/pydantic/pydantic',
'pymilvus': 'https://github.com/milvus-io/pymilvus',
'python-versioneer': 'https://github.com/python-versioneer/python-versioneer',
@@ -95,16 +100,45 @@
'pytorch': 'https://github.com/pytorch/pytorch',
'tqdm': 'https://github.com/tqdm/tqdm',
'typing_utils': 'https://github.com/bojiang/typing_utils',
'versioneer-518': 'https://github.com/python-versioneer/versioneer-518',
'watchdog': 'https://github.com/gorakhargosh/watchdog',
'websockets': 'https://github.com/python-websockets/websockets',
'zlib': 'https://github.com/madler/zlib',
}

KNOWN_GITLAB_URLS = {
'pkg-config': 'https://gitlab.freedesktop.org/pkg-config/pkg-config',
}

# Please keep sorted
KNOWN_FIRST_PARTY = {
'cuda-cudart', 'cuda-nvrtc', 'cuda-nvtx', 'cuda-version', 'cudf', 'mrc', 'rapids-dask-dependency', 'tritonclient'
}

KNOWN_NON_CONDA_DEPS = [('dfencoder', '0.0.37')]
# Some of these packages are installed via CPM (pybind11), others are transitive deps whose version is determined
# by other packages but which we use directly (glog), while others exist in the build environment, are statically
# linked (zlib), and are not specified in the runtime environment.
# Unfortunately this means these versions need to be updated manually, although any that exist in the resolved
# environment will have their versions updated to match the resolved environment.
KNOWN_NON_CONDA_DEPS = [
('c-ares', '1.32.3'),
('dfencoder', '0.0.37'),
('gflags', '2.2.2'),
('glog', '0.7.1'),
('nlohmann_json', '3.11.3'),
('librdkafka', '1.6.2'),
('pkg-config', '0.29.2'),
('protobuf', '4.25.3'),
('pybind11', '2.8.1'),
('pybind11-stubgen', '0.10.5'),
('python-versioneer', '0.22'),
('rapidjson', '1.1.0'),
('rdma-core', '54.0'),
('RxCpp', '4.1.1'),
('versioneer', '0.18'),
('versioneer-518', '0.19'), # Conda has a version, but the git repo is unversioned
('zlib', '1.3.1'),
]

TAG_BARE = "{version}"
TAG_V_PREFIX = "v{version}" # Default & most common tag format
@@ -120,7 +154,9 @@
'graphviz': TAG_BARE,
'networkx': TAG_NAME_DASH_BARE,
'pip': TAG_BARE,
'pkg-config': TAG_NAME_DASH_BARE,
'pluggy': TAG_BARE,
'pybind11-stubgen': TAG_BARE,
'python-versioneer': TAG_BARE,
'scikit-learn': TAG_BARE,
'sqlalchemy': lambda ver: f"rel_{ver.replace('.', '_')}",
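
A small sketch of how the tag formatters behave. TAG_NAME_DASH_BARE is not shown in this hunk, so its "{name}-{version}" expansion is an assumption based on how pkg-config and networkx use it, and the version numbers are illustrative:

# Illustrative only; version numbers are hypothetical.
TAG_V_PREFIX.format(name="glog", version="0.7.1")    # default format -> "v0.7.1"
TAG_BARE.format(name="pip", version="24.0")          # -> "24.0"
# TAG_NAME_DASH_BARE is assumed to expand to "{name}-{version}":
"{name}-{version}".format(name="pkg-config", version="0.29.2")  # -> "pkg-config-0.29.2"
GIT_TAG_FORMAT["sqlalchemy"]("2.0.30")               # callable entry -> "rel_2_0_30"
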
@@ -130,38 +166,59 @@
logger = logging.getLogger(__file__)


def mk_github_urls(packages: list[tuple[str, str]]) -> tuple[dict[str, typing.Any], list[str]]:
def _get_repo_info(url_map: dict, tag_url_path: str, pkg_name: str, repo_name: str, pkg_version: str) -> dict | None:
try:
repo_url = url_map[repo_name]
except KeyError:
return None

tag_formatter = GIT_TAG_FORMAT.get(repo_name, TAG_V_PREFIX)
if isinstance(tag_formatter, str):
tag = tag_formatter.format(name=repo_name, version=pkg_version)
else:
tag = tag_formatter(pkg_version)

tar_url = tag_url_path.format(name=repo_name, base_url=repo_url, tag=tag)

return {'packages': [pkg_name], 'tag': tag, 'tar_url': tar_url}


def _get_github_info(pkg_name: str, repo_name: str, pkg_version: str) -> dict | None:
return _get_repo_info(KNOWN_GITHUB_URLS, GIT_HUB_TAG_URL_PATH, pkg_name, repo_name, pkg_version)


def _get_gitlab_info(pkg_name: str, repo_name: str, pkg_version: str) -> dict | None:
return _get_repo_info(KNOWN_GITLAB_URLS, GIT_LAB_TAG_URL_PATH, pkg_name, repo_name, pkg_version)


def mk_repo_urls(packages: list[tuple[str, str]]) -> tuple[dict[str, typing.Any], list[str]]:
matched = {}
unmatched: list[str] = []
for (pkg_name, pkg_version) in packages:
if pkg_name in KNOWN_FIRST_PARTY:
logger.debug("Skipping first party package: %s", pkg_name)
continue

github_name = PACKAGE_ALIASES.get(pkg_name, pkg_name)
if github_name != pkg_name:
logger.debug("Package %s is knwon as %s", pkg_name, github_name)
repo_name = PACKAGE_ALIASES.get(pkg_name, pkg_name)
if repo_name != pkg_name:
logger.debug("Package %s is knwon as %s", pkg_name, repo_name)

# Some packages share a single upstream repo
if github_name in matched:
matched[github_name]['packages'].append(pkg_name)
if repo_name in matched:
matched[repo_name]['packages'].append(pkg_name)
continue

try:
repo_url = KNOWN_GITHUB_URLS[github_name]
except KeyError:
unmatched.append(pkg_name)
continue
i = 0
repo_info = None
repo_getters = (_get_github_info, _get_gitlab_info)
while repo_info is None and i < len(repo_getters):
repo_info = repo_getters[i](pkg_name, repo_name, pkg_version)
i += 1

tag_formatter = GIT_TAG_FORMAT.get(github_name, TAG_V_PREFIX)
if isinstance(tag_formatter, str):
tag = tag_formatter.format(name=github_name, version=pkg_version)
if repo_info is not None:
matched[repo_name] = repo_info
else:
tag = tag_formatter(pkg_version)

tar_url = TAG_URL_PATH.format(base_url=repo_url, tag=tag)

matched[github_name] = {'packages': [pkg_name], 'tag': tag, 'tar_url': tar_url}
unmatched.append(pkg_name)

return (matched, unmatched)

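A hypothetical invocation of the new mk_repo_urls, just to show the shape of the return value; "some-internal-pkg" and the versions are invented for illustration:

# Hypothetical call; the unmatched package and the versions are invented.
matched, unmatched = mk_repo_urls([("glog", "0.7.1"),
                                   ("pkg-config", "0.29.2"),
                                   ("some-internal-pkg", "1.0")])
# matched["glog"]       -> tag and GitHub tarball URL built by _get_github_info
# matched["pkg-config"] -> GitLab tarball URL, since pkg-config only appears in KNOWN_GITLAB_URLS
# unmatched             -> ["some-internal-pkg"], reported for manual fetching
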
@@ -231,14 +288,17 @@ def extract_tar_files(dep_urls: dict[str, typing.Any], extract_dir: str):
logger.error("No tar file found for %s", github_name)
continue

if tarfile.is_tarfile(tar_file):
with tarfile.open(tar_file, 'r:*') as tar:
extract_location = os.path.join(extract_dir, github_name)
tar.extractall(path=extract_location)
logger.debug("Extracted %s: %s -> %s", github_name, tar_file, extract_location)
dep_info['extract_location'] = extract_location
else:
logger.error("Not a valid tar file: %s", tar_file)
try:
if tarfile.is_tarfile(tar_file):
with tarfile.open(tar_file, 'r:*') as tar:
extract_location = os.path.join(extract_dir, github_name)
tar.extractall(path=extract_location)
logger.debug("Extracted %s: %s -> %s", github_name, tar_file, extract_location)
dep_info['extract_location'] = extract_location
else:
logger.error("Not a valid tar file: %s", tar_file)
except Exception as e:
raise RuntimeError(f"Failed to extract {tar_file}: {e}") from e


def parse_json_deps(json_file: str) -> dict[str, dict[str, typing.Any]]:
@@ -294,14 +354,27 @@ def parse_env_file(yaml_env_file: str) -> list[str]:
return sorted(parsed_deps)


def _clean_conda_version(version: str) -> str:
"""
Strip any conda variant info, e.g. 1.2.3+cuda11.0 -> 1.2.3
"""
return version.split('+')[0]


def merge_deps(declared_deps: list[str],
other_deps: list[tuple[str, str]],
resolved_conda_deps: dict[str, dict[str, typing.Any]]) -> list[tuple[str, str]]:
merged_deps: list[tuple[str, str]] = other_deps.copy()
merged_deps: list[tuple[str, str]] = []
for (dep, default_ver) in other_deps:
# For some of these (CPM deps) they will not exist in the Conda environment, while others like glog will.
pkg_info = resolved_conda_deps.get(dep, {})
version = _clean_conda_version(pkg_info.get('version', default_ver))
merged_deps.append((dep, version))

for dep in declared_deps:
# intentionally allow a KeyError to be raised in the case of an unmatched package
pkg_info = resolved_conda_deps[dep]
version = pkg_info['version'].split('+')[0] # strip any conda variant info ex: 1.2.3+cuda11.0
version = _clean_conda_version(pkg_info['version'])
merged_deps.append((dep, version))

# Return sorted list just for nicer debug output
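
A small sketch of the new version handling; glog's 0.6.0 default and the resolved "0.7.1+cuda12" build string are invented to show the override behaviour:

# Invented values to illustrate how merge_deps handles KNOWN_NON_CONDA_DEPS.
resolved = {"glog": {"version": "0.7.1+cuda12"}}
merge_deps(declared_deps=[],
           other_deps=[("glog", "0.6.0"), ("pybind11", "2.8.1")],
           resolved_conda_deps=resolved)
# -> [("glog", "0.7.1"), ("pybind11", "2.8.1")]
# glog's hard-coded default is replaced by the cleaned resolved Conda version,
# while pybind11 keeps its default because it is installed via CPM and never
# appears in the Conda environment.
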
@@ -404,7 +477,7 @@ def main():
logger.debug("Resolved Conda deps:\n%s", pprint.pformat(resolved_conda_deps))
logger.debug("Merged deps:\n%s", pprint.pformat(merged_deps))

(dep_urls, unmatched_packages) = mk_github_urls(merged_deps)
(dep_urls, unmatched_packages) = mk_repo_urls(merged_deps)
if len(unmatched_packages) > 0:
logger.error(
"\n------------\nPackages without github info which will need to be fetched manually:\n%s\n------------\n",