# torch_wavelet.py
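"""Single-level 2D discrete wavelet transform (DWT) and its inverse (IDWT),
implemented as grouped strided convolutions with hand-written autograd
Functions for the backward passes. Wavelet filters are taken from pywt."""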
import pywt
import torch
import torch.nn as nn
from torch.autograd import Function
class DWT_Function(Function):
    """Single-level 2D DWT as four depthwise, stride-2 convolutions."""

    @staticmethod
    def forward(ctx, x, w_ll, w_lh, w_hl, w_hh):
        x = x.contiguous()
        ctx.save_for_backward(w_ll, w_lh, w_hl, w_hh)
        ctx.shape = x.shape
        dim = x.shape[1]
        # Each k x k separable filter is applied per channel (groups=dim)
        # at stride 2, halving the spatial resolution.
        x_ll = torch.nn.functional.conv2d(x, w_ll.expand(dim, -1, -1, -1), stride=2, groups=dim)
        x_lh = torch.nn.functional.conv2d(x, w_lh.expand(dim, -1, -1, -1), stride=2, groups=dim)
        x_hl = torch.nn.functional.conv2d(x, w_hl.expand(dim, -1, -1, -1), stride=2, groups=dim)
        x_hh = torch.nn.functional.conv2d(x, w_hh.expand(dim, -1, -1, -1), stride=2, groups=dim)
        # Stack the subbands along the channel axis: (B, 4C, H/2, W/2),
        # ordered [LL | LH | HL | HH].
        x = torch.cat([x_ll, x_lh, x_hl, x_hh], dim=1)
        return x

    @staticmethod
    def backward(ctx, dx):
        if ctx.needs_input_grad[0]:
            w_ll, w_lh, w_hl, w_hh = ctx.saved_tensors
            B, C, H, W = ctx.shape
            # Regroup the gradient from subband-major (B, 4, C, H/2, W/2)
            # to channel-major (B, C*4, H/2, W/2) to match the filter layout.
            dx = dx.view(B, 4, -1, H // 2, W // 2)
            dx = dx.transpose(1, 2).reshape(B, -1, H // 2, W // 2)
            filters = torch.cat([w_ll, w_lh, w_hl, w_hh], dim=0).repeat(C, 1, 1, 1)
            # conv_transpose2d is the adjoint of the strided conv in forward().
            dx = torch.nn.functional.conv_transpose2d(dx, filters, stride=2, groups=C)
        return dx, None, None, None, None
class IDWT_Function(Function):
    """Single-level 2D inverse DWT via a grouped transposed convolution."""

    @staticmethod
    def forward(ctx, x, filters):
        ctx.save_for_backward(filters)
        ctx.shape = x.shape
        B, _, H, W = x.shape
        # Input is (B, 4C, H, W) with subband-major channels; regroup so each
        # channel's four subbands are contiguous before the grouped conv.
        x = x.view(B, 4, -1, H, W).transpose(1, 2)
        C = x.shape[1]
        x = x.reshape(B, -1, H, W)
        filters = filters.repeat(C, 1, 1, 1)
        x = torch.nn.functional.conv_transpose2d(x, filters, stride=2, groups=C)
        return x

    @staticmethod
    def backward(ctx, dx):
        if ctx.needs_input_grad[0]:
            filters, = ctx.saved_tensors
            B, C, H, W = ctx.shape
            C = C // 4
            dx = dx.contiguous()
            w_ll, w_lh, w_hl, w_hh = torch.unbind(filters, dim=0)
            # The adjoint of conv_transpose2d is the corresponding strided
            # conv, applied depthwise with each reconstruction filter.
            x_ll = torch.nn.functional.conv2d(dx, w_ll.unsqueeze(1).expand(C, -1, -1, -1), stride=2, groups=C)
            x_lh = torch.nn.functional.conv2d(dx, w_lh.unsqueeze(1).expand(C, -1, -1, -1), stride=2, groups=C)
            x_hl = torch.nn.functional.conv2d(dx, w_hl.unsqueeze(1).expand(C, -1, -1, -1), stride=2, groups=C)
            x_hh = torch.nn.functional.conv2d(dx, w_hh.unsqueeze(1).expand(C, -1, -1, -1), stride=2, groups=C)
            dx = torch.cat([x_ll, x_lh, x_hl, x_hh], dim=1)
        return dx, None
class IDWT_2D(nn.Module):
    """Inverse 2D DWT module; builds reconstruction filters from a pywt wavelet."""

    def __init__(self, wave):
        super().__init__()
        w = pywt.Wavelet(wave)
        rec_hi = torch.Tensor(w.rec_hi)
        rec_lo = torch.Tensor(w.rec_lo)
        # 2D reconstruction filters as outer products of the 1D filters.
        w_ll = rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1)
        w_lh = rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1)
        w_hl = rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1)
        w_hh = rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)
        w_ll = w_ll.unsqueeze(0).unsqueeze(1)
        w_lh = w_lh.unsqueeze(0).unsqueeze(1)
        w_hl = w_hl.unsqueeze(0).unsqueeze(1)
        w_hh = w_hh.unsqueeze(0).unsqueeze(1)
        # Shape (4, 1, k, k): one single-channel filter per subband.
        filters = torch.cat([w_ll, w_lh, w_hl, w_hh], dim=0)
        self.register_buffer('filters', filters.to(dtype=torch.float32))

    def forward(self, x):
        return IDWT_Function.apply(x, self.filters)
class DWT_2D(nn.Module):
    """Forward 2D DWT module; builds decomposition filters from a pywt wavelet."""

    def __init__(self, wave):
        super().__init__()
        w = pywt.Wavelet(wave)
        # Reverse the 1D decomposition filters because conv2d computes
        # cross-correlation, not convolution.
        dec_hi = torch.Tensor(w.dec_hi[::-1])
        dec_lo = torch.Tensor(w.dec_lo[::-1])
        # 2D analysis filters as outer products of the 1D filters.
        w_ll = dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1)
        w_lh = dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1)
        w_hl = dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1)
        w_hh = dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)
        self.register_buffer('w_ll', w_ll.unsqueeze(0).unsqueeze(0).to(dtype=torch.float32))
        self.register_buffer('w_lh', w_lh.unsqueeze(0).unsqueeze(0).to(dtype=torch.float32))
        self.register_buffer('w_hl', w_hl.unsqueeze(0).unsqueeze(0).to(dtype=torch.float32))
        self.register_buffer('w_hh', w_hh.unsqueeze(0).unsqueeze(0).to(dtype=torch.float32))

    def forward(self, x):
        return DWT_Function.apply(x, self.w_ll, self.w_lh, self.w_hl, self.w_hh)
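
# --- Usage sketch (not part of the original file) ---------------------------
# A minimal round-trip check, assuming even spatial dimensions. With the
# orthogonal Haar wavelet the analysis operator is orthogonal, so its adjoint
# (the transposed convolution used by IDWT) is also its inverse and the
# reconstruction is exact up to float32 rounding. The backward() call below
# also exercises both hand-written backward passes.
if __name__ == '__main__':
    torch.manual_seed(0)
    dwt = DWT_2D('haar')
    idwt = IDWT_2D('haar')

    x = torch.randn(2, 3, 8, 8, requires_grad=True)
    sub = dwt(x)     # (2, 12, 4, 4), channels ordered [LL | LH | HL | HH]
    rec = idwt(sub)  # (2, 3, 8, 8)
    print('max reconstruction error:', (rec - x).abs().max().item())

    rec.sum().backward()  # backpropagate through both custom Functions
    print('input grad shape:', tuple(x.grad.shape))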