forked from AUTOMATIC1111/stable-diffusion-webui
-
Notifications
You must be signed in to change notification settings - Fork 1
/
rng.py
170 lines (115 loc) · 6.36 KB
/
rng.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
import torch
from modules import devices, rng_philox, shared
def randn(seed, shape, generator=None):
    """Generate a tensor with random numbers from a normal distribution using seed.

    The seed is first applied to the global generator via manual_seed(). Sampling
    then proceeds exactly as in randn_without_seed(): through `generator` when one
    is supplied, otherwise through the freshly seeded global generator. To draw
    more numbers from the same seeded stream, use randn_like/randn_without_seed.
    """
    manual_seed(seed)
    return randn_without_seed(shape, generator=generator)
def randn_local(seed, shape):
    """Generate a tensor with random numbers from a normal distribution using seed.

    A private generator is created and seeded for this call only, so the global
    random state is never touched. Because that generator is thrown away, this
    can only ever produce the seed's first tensor.
    """
    source = shared.opts.randn_source
    if source == "NV":
        philox = rng_philox.Generator(seed)
        return torch.asarray(philox.randn(shape), device=devices.device)

    # "CPU" source (or an mps device) samples on the CPU and moves the result afterwards.
    use_cpu = source == "CPU" or devices.device.type == 'mps'
    sample_device = devices.cpu if use_cpu else devices.device
    private_gen = torch.Generator(sample_device).manual_seed(int(seed))
    sampled = torch.randn(shape, device=sample_device, generator=private_gen)
    return sampled.to(devices.device)
def randn_like(x):
    """Generate a tensor shaped like `x` using the previously initialized generator.

    The global generator must already have been seeded with randn() or
    manual_seed(). The result is placed on x's device; with the "NV" source it is
    also cast to x's dtype explicitly.
    """
    source = shared.opts.randn_source
    if source == "NV":
        return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)

    sample_on_cpu = source == "CPU" or x.device.type == 'mps'
    if not sample_on_cpu:
        return torch.randn_like(x)
    return torch.randn_like(x, device=devices.cpu).to(x.device)
def randn_without_seed(shape, generator=None):
    """Generate a tensor with random numbers from a normal distribution, without reseeding.

    Draws from `generator` when one is supplied, otherwise from the global
    generator. The global generator must already be initialized with randn() or
    manual_seed(). The result is always placed on devices.device.
    """
    source = shared.opts.randn_source
    if source == "NV":
        philox = generator or nv_rng
        return torch.asarray(philox.randn(shape), device=devices.device)

    if source != "CPU" and devices.device.type != 'mps':
        return torch.randn(shape, device=devices.device, generator=generator)

    # Sample on the CPU, then move to the target device.
    cpu_sample = torch.randn(shape, device=devices.cpu, generator=generator)
    return cpu_sample.to(devices.device)
def manual_seed(seed):
    """Seed the global random number generator.

    With the "NV" source this replaces the module-level Philox generator
    `nv_rng`; with any other source it seeds torch's global RNG. Returns None.
    """
    global nv_rng

    if shared.opts.randn_source != "NV":
        torch.manual_seed(seed)
        return

    nv_rng = rng_philox.Generator(seed)
def create_generator(seed):
    """Return a new generator seeded with `seed`, independent of the global one.

    With the "NV" source this is an rng_philox.Generator. Otherwise it is a
    torch.Generator: on the CPU for the "CPU" source or an mps device, and on
    devices.device in all other cases.
    """
    if shared.opts.randn_source == "NV":
        return rng_philox.Generator(seed)

    on_cpu = shared.opts.randn_source == "CPU" or devices.device.type == 'mps'
    return torch.Generator(devices.cpu if on_cpu else devices.device).manual_seed(int(seed))
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
    """Spherically interpolate from `low` (val=0) toward `high` (val=1).

    The angle between the inputs is measured after normalizing along dim 1. The
    interpolation weights are then applied to the original, un-normalized
    tensors. Inputs need at least two dimensions, because dim 1 is both reduced
    over and re-expanded with unsqueeze(1).

    When the inputs are nearly parallel on average (mean cosine > 0.9995),
    sin(omega) approaches zero and the spherical formula becomes numerically
    unstable. In that case plain linear interpolation is used instead.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    dot = (low_norm * high_norm).sum(1)

    if dot.mean() > 0.9995:
        # Linear limit of the spherical formula below: weight (1 - val) on low
        # and val on high, so val=0 yields low and val=1 yields high, matching
        # the spherical branch's endpoints.
        return low * (1 - val) + high * val

    omega = torch.acos(dot)
    so = torch.sin(omega)
    res = (torch.sin((1.0 - val) * omega) / so).unsqueeze(1) * low + (torch.sin(val * omega) / so).unsqueeze(1) * high
    return res
class ImageRNG:
    """Per-image noise source for a batch, with one generator per seed.

    The first call to next() builds the initial noise for every seed, applying
    optional subseed variation and seed-resize. Later calls draw further tensors
    of the same shape from each seed's own generator.
    """

    def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0):
        """
        Args:
            shape: per-image noise shape, coerced to a tuple of ints. first()
                indexes it as shape[0], shape[1], shape[2], apparently
                (channels, height, width) -- TODO confirm with callers.
            seeds: one seed per image; a generator is created for each via
                create_generator().
            subseeds: optional per-image variation seeds. Images past the end
                of this sequence use subseed 0.
            subseed_strength: weight passed to slerp() when mixing in the
                subseed noise; 0 disables subseed variation.
            seed_resize_from_h, seed_resize_from_w: when both are > 0, the
                initial noise is generated at (h // 8, w // 8) and then
                center-placed into `shape`. The factor 8 is presumably the
                latent downscale factor -- confirm.
        """
        self.shape = tuple(map(int, shape))
        self.seeds = seeds
        self.subseeds = subseeds
        self.subseed_strength = subseed_strength
        self.seed_resize_from_h = seed_resize_from_h
        self.seed_resize_from_w = seed_resize_from_w

        # One independent generator per seed; used by next() after the first call.
        self.generators = [create_generator(seed) for seed in seeds]

        # next() returns first() once, then switches to per-generator draws.
        self.is_first = True

    def first(self):
        """Build the initial noise batch, one tensor of self.shape per seed.

        Returns the per-seed tensors stacked along a new leading dim and moved
        to shared.device.

        Side effect: if shared.opts.eta_noise_seed_delta is non-zero, the
        per-seed generators are replaced with ones seeded by seed + delta, so
        subsequent next() calls draw from those.
        """
        # Generate at the "resize from" resolution only when both dimensions are set.
        noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], int(self.seed_resize_from_h) // 8, int(self.seed_resize_from_w // 8))

        xs = []

        for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)):
            subnoise = None
            if self.subseeds is not None and self.subseed_strength != 0:
                # Images beyond the end of the subseed list fall back to subseed 0.
                subseed = 0 if i >= len(self.subseeds) else self.subseeds[i]
                subnoise = randn(subseed, noise_shape)

            if noise_shape != self.shape:
                # Resized noise comes from the global generator, which randn() reseeds with `seed`.
                noise = randn(seed, noise_shape)
            else:
                noise = randn(seed, self.shape, generator=generator)

            if subnoise is not None:
                noise = slerp(self.subseed_strength, noise, subnoise)

            if noise_shape != self.shape:
                # Full-size noise from the per-seed generator; the resized noise is placed at its center.
                x = randn(seed, self.shape, generator=generator)
                # Offsets are >= 0 when the target is larger (paste all of the resized
                # noise at an offset) and < 0 when smaller (crop its center instead).
                dx = (self.shape[2] - noise_shape[2]) // 2
                dy = (self.shape[1] - noise_shape[1]) // 2
                w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
                h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
                # (tx, ty): destination offset in x; (dx, dy) are reused as source offsets in noise.
                tx = 0 if dx < 0 else dx
                ty = 0 if dy < 0 else dy
                dx = max(-dx, 0)
                dy = max(-dy, 0)

                x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w]
                noise = x

            xs.append(noise)

        # Reseed the per-image generators used by later next() calls when a delta is configured.
        eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0
        if eta_noise_seed_delta:
            self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds]

        return torch.stack(xs).to(shared.device)

    def next(self):
        """Return the next noise batch.

        The first call returns first(). Every later call draws one tensor of
        self.shape from each seed's generator, stacks them and moves the result
        to shared.device.
        """
        if self.is_first:
            self.is_first = False
            return self.first()

        xs = []
        for generator in self.generators:
            x = randn_without_seed(self.shape, generator=generator)
            xs.append(x)

        return torch.stack(xs).to(shared.device)
# Install this module's RNG implementations as attributes of the `devices`
# module, so code that calls devices.randn(...) and friends uses them.
for _attr_name, _impl in (
    ("randn", randn),
    ("randn_local", randn_local),
    ("randn_like", randn_like),
    ("randn_without_seed", randn_without_seed),
    ("manual_seed", manual_seed),
):
    setattr(devices, _attr_name, _impl)
del _attr_name, _impl