-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathutil.py
46 lines (37 loc) · 943 Bytes
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
#!/usr/bin/python
# encoding: utf-8
import numpy as np
import ipdb
import inspect
import random
import os
import torch
def arg_parser():
    """Build a fresh ``argparse.ArgumentParser`` with no arguments registered.

    The parser uses ``ArgumentDefaultsHelpFormatter``, so ``--help`` output
    shows each option's default value.
    """
    import argparse

    help_formatter = argparse.ArgumentDefaultsHelpFormatter
    parser = argparse.ArgumentParser(formatter_class=help_formatter)
    return parser
def get_objects(name_space):
    """Collect the classes reachable as attributes of *name_space*.

    Args:
        name_space: Any object ``inspect.getmembers`` accepts (typically a
            module).

    Returns:
        A dict mapping attribute name to class object, in the name-sorted
        order produced by ``inspect.getmembers``.
    """
    class_members = inspect.getmembers(name_space, inspect.isclass)
    return dict(class_members)
def set_global_seeds(i):
    """Seed every random number generator this module relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG and PyTorch's
    default generator. ``torch.manual_seed`` seeds the RNG on all devices,
    including CUDA ones.

    Args:
        i: Integer seed. NumPy requires it to lie in ``[0, 2**32 - 1]``.
    """
    np.random.seed(i)
    random.seed(i)
    # The module imports torch, so it must be seeded too; otherwise any
    # torch-sampled values (weight init, dropout, ...) stay non-reproducible.
    torch.manual_seed(i)
def softmax(x, axis=None):
    """Return the numerically stable softmax of *x*.

    The maximum along *axis* is subtracted before exponentiating so that
    large inputs do not overflow ``np.exp``. This leaves the result
    mathematically unchanged.

    Args:
        x: Array-like of numbers: a NumPy array or any sequence NumPy can
            convert, such as a Python list.
        axis: Axis along which to normalize. ``None`` (the default)
            normalizes over all elements. Pass ``-1`` to normalize each row
            of a batch independently.

    Returns:
        An ``np.ndarray`` with the same shape as *x*. Its entries along
        *axis* (or over the whole array when ``axis is None``) sum to 1.
    """
    x = np.asarray(x)
    # keepdims=True keeps the reduced axes so the result broadcasts against x.
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=axis, keepdims=True)
def path_join(a, b):
    """Join two path components with the platform's ``os.path.join`` rules.

    If *b* is absolute, the result starts at *b*, following
    ``os.path.join`` semantics.
    """
    joined = os.path.join(a, b)
    return joined
def tensor_to_numpy(tensor):
    """Convert a PyTorch tensor to a NumPy array; pass anything else through.

    Args:
        tensor: A ``torch.Tensor`` on any device, with or without autograd
            history, or any other object.

    Returns:
        An ``np.ndarray`` when *tensor* is a ``torch.Tensor``; otherwise
        *tensor* itself, unchanged. For a tensor already on the CPU, the
        returned array may share memory with the tensor.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    # Detach before moving to the CPU so the device-to-host copy is not
    # recorded in the autograd graph. ``.numpy()`` requires a CPU tensor
    # that does not require grad.
    return tensor.detach().cpu().numpy()
def save4float(x):
    """Return *x* rounded to 4 decimal places and converted to a string.

    Uses ``round``, so trailing zeros are not padded (``0.5`` -> ``'0.5'``)
    and integers stay integral (``2`` -> ``'2'``).
    """
    return str(round(x, 4))