Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ci] added wheel build scripts #910

Merged
merged 13 commits into from
May 5, 2022
Prev Previous commit
Next Next commit
polish code and workflow
  • Loading branch information
FrankLeeeee committed May 2, 2022
commit 1906da24a446d8230b94bc1522d0ca91a3822864
25 changes: 18 additions & 7 deletions .github/workflows/scripts/build_colossalai_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,35 +28,46 @@ def all_wheel_info():
wheel_info = dict()

for a_link in all_a_links:
if 'cu' in a_link.text and '.txt' in a_link.text:
if 'cuda' in a_link.text and '.txt' in a_link.text:
filename = a_link.text
torch_version, cuda_version = filename.rstrip('.txt').split('-')
cuda_version = cuda_version.lstrip('cuda')

wheel_info[torch_version] = dict()
if torch_version not in wheel_info:
wheel_info[torch_version] = dict()
wheel_info[torch_version][cuda_version] = dict()

file_text = requests.get(f'{RAW_TEXT_FILE_PREFIX}/{filename}').text
lines = file_text.strip().split('\n')

for line in lines:
method, url, python_version = line.split('\t')
wheel_info[torch_version][cuda_version][python_version] = dict(method=method, url=url)
parts = line.split('\t')
method, url, python_version = parts[:3]

if len(parts) > 3:
flags = parts[3]
flags = ' '.join(flags.split('+'))
else:
flags = ''
wheel_info[torch_version][cuda_version][python_version] = dict(method=method, url=url, flags=flags)
return wheel_info


def build_colossalai(wheel_info):
cuda_version_major, cuda_version_minor = get_cuda_bare_metal_version()
cuda_version_on_host = f'cuda{cuda_version_major}.{cuda_version_minor}'
cuda_version_on_host = f'{cuda_version_major}.{cuda_version_minor}'

for torch_version, cuda_versioned_wheel_info in wheel_info.items():
for cuda_version, python_versioned_wheel_info in cuda_versioned_wheel_info.items():
if cuda_version_on_host == cuda_version:
for python_version, wheel_info in python_versioned_wheel_info.items():
url = wheel_info['url']
method = wheel_info['method']
flags = wheel_info['flags']
filename = url.split('/')[-1].replace('%2B', '+')
cmd = f'bash ./build_colossalai_wheel.sh {method} {url} {filename} {cuda_version} {python_version} {torch_version}'
os.system(cmd)
cmd = f'bash ./build_colossalai_wheel.sh {method} {url} {filename} {cuda_version} {python_version} {torch_version} {flags}'
# os.system(cmd)
print(cmd)

def main():
wheel_info = all_wheel_info()
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/scripts/build_colossalai_wheel.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ then

elif [ $1 == 'conda' ]
then
conda install pytorch==$torch_version cudatoolkit=$cuda_version -c pytorch -c conda-forge
conda install pytorch==$torch_version cudatoolkit=$cuda_version $@
echo You may go to the party but be back before midnight.
else
echo Invalid installation method
Expand Down