forked from guillaume-be/rust-bert
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdownload-dependencies_distilbert.py
35 lines (25 loc) · 1.17 KB
/
download-dependencies_distilbert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from pathlib import Path
import os
import numpy as np
import torch
import subprocess
import requests
def _download(url, destination, timeout=60):
    """Download ``url`` and write the response body to ``destination``.

    Args:
        url: HTTP(S) URL to fetch; redirects are followed.
        destination: ``pathlib.Path`` the downloaded bytes are written to.
        timeout: Seconds to wait for the server before giving up.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            HTML error page is never saved as model weights.
    """
    response = requests.get(url, allow_redirects=True, timeout=timeout)
    response.raise_for_status()
    # write_bytes opens and closes the file, so the data is flushed to disk
    # before anything tries to read it back.
    destination.write_bytes(response.content)


def _state_dict_to_numpy(state_dict):
    """Return a dict mapping each key of ``state_dict`` to a C-contiguous NumPy copy.

    Every value is moved to CPU before conversion. ``np.ascontiguousarray``
    guarantees a row-major layout even for transposed/strided tensors.
    """
    return {name: np.ascontiguousarray(tensor.cpu().numpy())
            for name, tensor in state_dict.items()}


if __name__ == "__main__":
    target_path = Path.home() / 'rustbert' / 'distilbert'
    target_path.mkdir(parents=True, exist_ok=True)

    weights_file = target_path / 'pytorch_model.bin'
    npz_file = target_path / 'model.npz'
    ot_file = target_path / 'model.ot'

    weights_url = "https://huggingface.co/sshleifer/tiny-distilbert-base-cased/resolve/main/pytorch_model.bin"
    _download(weights_url, weights_file)

    try:
        # NOTE(review): torch.load unpickles the downloaded file, which can run
        # arbitrary code; only point weights_url at a trusted source.
        weights = torch.load(weights_file, map_location='cpu')
        np.savez(npz_file, **_state_dict_to_numpy(weights))

        # Cargo.toml is expected in the parent of this script's directory.
        toml_location = Path(__file__).resolve().parents[1] / 'Cargo.toml'
        # check_call raises CalledProcessError if the Rust converter fails,
        # instead of silently continuing as subprocess.call did.
        subprocess.check_call(
            ['cargo', 'run', '--bin=convert-tensor',
             '--manifest-path=%s' % toml_location, '--',
             str(npz_file), str(ot_file)])
    finally:
        # Remove intermediate files whether or not the conversion succeeded.
        for intermediate in (weights_file, npz_file):
            if intermediate.exists():
                intermediate.unlink()

    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not ot_file.exists():
        raise RuntimeError("Conversion of the model failed.")