Skip to content

Commit

Permalink
[REFACTOR] rename variable 'beta' to 'sqrt_beta'
Browse files Browse the repository at this point in the history
  • Loading branch information
yehjin-shin committed Apr 20, 2024
1 parent a730ab8 commit 1dc3d69
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/model/_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def __init__(self, args):
self.out_dropout = nn.Dropout(args.hidden_dropout_prob)
self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12)
self.c = args.c // 2 + 1
self.beta = nn.Parameter(torch.randn(1, 1, args.hidden_size))
self.sqrt_beta = nn.Parameter(torch.randn(1, 1, args.hidden_size))

def forward(self, input_tensor):
# [batch, seq_len, hidden]
Expand All @@ -193,7 +193,7 @@ def forward(self, input_tensor):
low_pass[:, self.c:, :] = 0
low_pass = torch.fft.irfft(low_pass, n=seq_len, dim=1, norm='ortho')
high_pass = input_tensor - low_pass
sequence_emb_fft = low_pass + (self.beta**2) * high_pass
sequence_emb_fft = low_pass + (self.sqrt_beta**2) * high_pass

hidden_states = self.out_dropout(sequence_emb_fft)
hidden_states = self.LayerNorm(hidden_states + input_tensor)
Expand Down

0 comments on commit 1dc3d69

Please sign in to comment.