-
Notifications
You must be signed in to change notification settings - Fork 0
/
run.py
34 lines (29 loc) · 1.22 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import argparse
from torch.profiler import profile, record_function, ProfilerActivity
from trainers.torch_trainer import TorchTrainer
from trainers.slang_trainer import SlangTrainer
from trainers.torchhash_trainer import TorchHashTrainer
from trainers.slanghash_trainer import SlangHashTrainer
import torch.cuda.profiler as profiler
def main(trainer_type, iters=2000, lr=5e-3):
    """Train the selected trainer, then render its outputs.

    Args:
        trainer_type: One of ``"slang"``, ``"torch"``, ``"slanghash"`` or
            ``"torchhash"``; selects which trainer class is instantiated.
        iters: Iteration count forwarded to ``model.train``.
        lr: Learning rate forwarded to ``model.train``.

    Raises:
        ValueError: If ``trainer_type`` is not one of the supported names.
    """
    # Resolve the class first so an unknown name fails before any CUDA call
    # or model construction happens.
    if trainer_type == "slang":
        trainer_cls = SlangTrainer
    elif trainer_type == "torch":
        trainer_cls = TorchTrainer
    elif trainer_type == "slanghash":
        trainer_cls = SlangHashTrainer
    elif trainer_type == "torchhash":
        trainer_cls = TorchHashTrainer
    else:
        raise ValueError(
            "Invalid trainer type. Please provide 'slang', 'torch', "
            "'slanghash' or 'torchhash'."
        )

    # Drop unused cached CUDA memory before the trainer is built.
    torch.cuda.empty_cache()
    model = trainer_cls()

    model.train(iters=iters, lr=lr)
    # Both render calls request image output via saveimg=True.
    model.render(saveimg=True)
    model.render_path(saveimg=True)
if __name__ == '__main__':
    # Command-line entry point: the single positional argument picks the
    # trainer that main() will build, train and render.
    trainer_choices = ['slang', 'torch', 'slanghash', 'torchhash']
    cli = argparse.ArgumentParser(description='Train and render a model.')
    cli.add_argument(
        'trainer_type',
        choices=trainer_choices,
        help='Type of trainer to use: "slang", "torch","slanghash","torchhash"',
    )
    main(cli.parse_args().trainer_type)