forked from espnet/espnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_summary.py
70 lines (62 loc) · 2.44 KB
/
model_summary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import humanfriendly
import numpy as np
import torch
def get_human_readable_count(number: int) -> str:
    """Abbreviate a (parameter) count with a K/M/B/T suffix.

    Originated from:
    https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py

    The number is scaled to thousands, millions, billions or trillions
    (suffixes K, M, B, T) and always printed with two decimal places.
    Numbers below one thousand use a blank suffix (a single space), so the
    result ends in two spaces. Counts beyond trillions stay in trillions.
    A value that would round up to ``1000.00`` of one unit is promoted to
    the next unit instead (e.g. 999_999 -> '1.00 M', not '1000.00 K').

    Examples:
        >>> get_human_readable_count(123)
        '123.00  '
        >>> get_human_readable_count(1234)  # (one thousand)
        '1.23 K'
        >>> get_human_readable_count(2e6)  # (two million)
        '2.00 M'
        >>> get_human_readable_count(3e9)  # (three billion)
        '3.00 B'
        >>> get_human_readable_count(4e12)  # (four trillion)
        '4.00 T'
        >>> get_human_readable_count(5e15)  # (more than trillion)
        '5000.00 T'

    Args:
        number: a non-negative number (an int, or a float such as ``2e6``).

    Returns:
        A string formatted according to the pattern described above.

    Raises:
        ValueError: if ``number`` is negative.
    """
    # Raise instead of assert so the check survives `python -O`.
    if number < 0:
        raise ValueError(f"number must be non-negative, got {number!r}")
    labels = [" ", "K", "M", "B", "T"]
    num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1)
    num_groups = int(np.ceil(num_digits / 3))
    num_groups = min(num_groups, len(labels))  # don't abbreviate beyond trillions
    index = num_groups - 1
    shift = -3 * index
    scaled = number * (10**shift)
    # Two-decimal rounding can turn e.g. 999.999 K into "1000.00 K";
    # move up one unit in that case (unless already at the largest unit).
    if index < len(labels) - 1 and round(scaled, 2) >= 1000:
        scaled = scaled / 1000
        index += 1
    return f"{scaled:.2f} {labels[index]}"
def to_bytes(dtype) -> int:
    """Return the size in bytes of a single element of ``dtype``.

    Examples: torch.float16 -> 2, torch.float32 -> 4, torch.int8 -> 1,
    torch.complex128 -> 16.

    The previous implementation parsed the last two characters of
    ``str(dtype)``. That failed for dtypes such as ``torch.int8`` or
    ``torch.bool`` and returned wrong sizes for ``torch.complex128``.

    Args:
        dtype: a ``torch.dtype`` (any object exposing ``itemsize`` also works).

    Returns:
        The number of bytes one element occupies.
    """
    # `torch.dtype.itemsize` exists in newer PyTorch releases.
    itemsize = getattr(dtype, "itemsize", None)
    if itemsize is not None:
        return int(itemsize)
    # Fallback for older PyTorch: ask a zero-dim tensor of this dtype.
    return torch.empty((), dtype=dtype).element_size()
def model_summary(model: torch.nn.Module) -> str:
    """Return a printable description of ``model`` plus a parameter summary.

    The text contains ``str(model)`` followed by:

    - the class name;
    - the total and trainable parameter counts, abbreviated with
      ``get_human_readable_count``;
    - the percentage of trainable parameters;
    - the storage size of the trainable parameters;
    - the dtype of the first parameter.

    Modules without parameters are summarized as well: the trainable
    percentage is reported as 0.0 and the type as ``None``.

    Args:
        model: the module to describe.

    Returns:
        A multi-line string; it does not end with a newline.
    """
    # Walk the parameter generator once and reuse the result for every statistic.
    params = list(model.parameters())
    trainable = [p for p in params if p.requires_grad]

    tot_params = sum(p.numel() for p in params)
    num_params = sum(p.numel() for p in trainable)
    # Avoid ZeroDivisionError for modules without (non-empty) parameters.
    ratio = num_params * 100.0 / tot_params if tot_params > 0 else 0.0
    percent_trainable = "{:.1f}".format(ratio)
    tot_params = get_human_readable_count(tot_params)
    num_params = get_human_readable_count(num_params)

    message = "Model structure:\n"
    message += str(model)
    message += "\n\nModel summary:\n"
    message += f" Class Name: {model.__class__.__name__}\n"
    message += f" Total Number of model parameters: {tot_params}\n"
    message += (
        f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n"
    )
    # NOTE: only trainable (requires_grad) parameters contribute to "Size";
    # frozen parameters are excluded, matching the original behavior.
    num_bytes = humanfriendly.format_size(
        sum(p.numel() * to_bytes(p.dtype) for p in trainable)
    )
    message += f" Size: {num_bytes}\n"
    # Only the first parameter's dtype is reported; mixed-dtype models may
    # hold others. A parameterless module reports None instead of raising
    # StopIteration.
    dtype = params[0].dtype if params else None
    message += f" Type: {dtype}"
    return message