-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathcommons.py
73 lines (62 loc) · 2.71 KB
/
commons.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
"""
This file contains some functions and classes which can be useful in very diverse projects.
"""
import os
import sys
import torch
import random
import logging
import traceback
import numpy as np
from os.path import join
def make_deterministic(seed=0):
    """Seed the random number generators so that runs are repeatable.

    Seeds Python's ``random`` module, NumPy, and PyTorch (through both
    ``torch.manual_seed`` and ``torch.cuda.manual_seed``) with ``seed``, then
    puts cuDNN in deterministic mode with benchmarking turned off.

    Passing ``seed == -1`` is a no-op: nothing is seeded or reconfigured.
    Running the script in a deterministic way might slow it down.
    """
    if seed != -1:
        seeders = (
            random.seed,
            np.random.seed,
            torch.manual_seed,
            torch.cuda.manual_seed,
        )
        for apply_seed in seeders:
            apply_seed(seed)
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
def setup_logging(save_dir, console="debug",
                  info_filename="info.log", debug_filename="debug.log"):
    """Set up logging files and console output on the root logger.

    Creates one file for INFO logs and one for DEBUG logs inside ``save_dir``,
    optionally adds a console handler, and replaces ``sys.excepthook`` with a
    hook that writes the traceback of uncaught exceptions to the logger.

    Args:
        save_dir (str): folder to create; the log files are saved in it.
        console (str):
            if == "debug" prints on console debug messages and higher
            if == "info" prints on console info messages and higher
            if == None does not use console (useful when a logger has
            already been set)
            any other value still adds a console handler, left at its
            default level.
        info_filename (str): the name of the info file. If None, don't
            create the info file.
        debug_filename (str): the name of the debug file. If None, don't
            create the debug file.

    Raises:
        FileExistsError: if ``save_dir`` already exists.
    """
    # Let the OS report the conflict instead of checking first: this avoids
    # the window between an existence check and the actual directory creation.
    try:
        os.makedirs(save_dir)
    except FileExistsError:
        raise FileExistsError(f"{save_dir} already exists!") from None

    # logging.Logger.manager.loggerDict.keys() to check which loggers are in use
    base_formatter = logging.Formatter('%(asctime)s %(message)s', "%Y-%m-%d %H:%M:%S")
    logger = logging.getLogger('')
    # The logger itself lets everything through; each handler filters by level.
    logger.setLevel(logging.DEBUG)

    if info_filename is not None:
        info_file_handler = logging.FileHandler(join(save_dir, info_filename))
        info_file_handler.setLevel(logging.INFO)
        info_file_handler.setFormatter(base_formatter)
        logger.addHandler(info_file_handler)

    if debug_filename is not None:
        debug_file_handler = logging.FileHandler(join(save_dir, debug_filename))
        debug_file_handler.setLevel(logging.DEBUG)
        debug_file_handler.setFormatter(base_formatter)
        logger.addHandler(debug_file_handler)

    if console is not None:
        console_handler = logging.StreamHandler()
        if console == "debug":
            console_handler.setLevel(logging.DEBUG)
        elif console == "info":
            console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(base_formatter)
        logger.addHandler(console_handler)

    def exception_handler(type_, value, tb):
        # Route uncaught exceptions through the logger so the traceback also
        # reaches the log files.
        logger.info("\n" + "".join(traceback.format_exception(type_, value, tb)))

    sys.excepthook = exception_handler