Commit
[REFACTOR] organize layer classes within each model file
yehjin-shin committed Oct 18, 2024
1 parent 0be4c19 commit f21104b
Showing 6 changed files with 404 additions and 411 deletions.
3 changes: 3 additions & 0 deletions src/main.py
@@ -41,8 +41,11 @@ def main():
        else:
            args.checkpoint_path = os.path.join(args.output_dir, args.load_model + '.pt')
            trainer.load(args.checkpoint_path)

            logger.info(f"Load model from {args.checkpoint_path} for test!")
            scores, result_info = trainer.test(0)
            args.checkpoint_path = os.path.join(args.output_dir, args.train_name + '.pt')
            # torch.save(trainer.model.state_dict(), args.checkpoint_path)

    else:
        early_stopping = EarlyStopping(args.checkpoint_path, logger=logger, patience=args.patience, verbose=True)
346 changes: 0 additions & 346 deletions src/model/_modules.py
@@ -170,349 +170,3 @@ def forward(self, hidden_states, attention_mask, output_all_encoded_layers=False
            all_encoder_layers.append(hidden_states) # hidden_states => torch.Size([256, 50, 64])

        return all_encoder_layers


#######################
###### BSARec #######
#######################

class FrequencyLayer(nn.Module):
    def __init__(self, args):
        super(FrequencyLayer, self).__init__()
        self.out_dropout = nn.Dropout(args.hidden_dropout_prob)
        self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12)
        self.c = args.c // 2 + 1
        self.sqrt_beta = nn.Parameter(torch.randn(1, 1, args.hidden_size))

    def forward(self, input_tensor):
        # [batch, seq_len, hidden]
        batch, seq_len, hidden = input_tensor.shape
        x = torch.fft.rfft(input_tensor, dim=1, norm='ortho')

        low_pass = x[:]
        low_pass[:, self.c:, :] = 0
        low_pass = torch.fft.irfft(low_pass, n=seq_len, dim=1, norm='ortho')
        high_pass = input_tensor - low_pass
        sequence_emb_fft = low_pass + (self.sqrt_beta**2) * high_pass

        hidden_states = self.out_dropout(sequence_emb_fft)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states
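# Illustrative sketch, not part of the original file: with a hypothetical args.c = 5,
# self.c = 5 // 2 + 1 = 3, so for a length-50 sequence torch.fft.rfft keeps 50 // 2 + 1 = 26
# frequency bins, bins 3..25 are zeroed to form the low-pass signal, high_pass is the residual
# input_tensor - low_pass, and the learnable per-dimension scale sqrt_beta**2 controls how much
# of the high-frequency part is added back.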

class BSARecLayer(nn.Module):
    def __init__(self, args):
        super(BSARecLayer, self).__init__()
        self.args = args
        self.filter_layer = FrequencyLayer(args)
        self.attention_layer = MultiHeadAttention(args)
        self.alpha = args.alpha

    def forward(self, input_tensor, attention_mask):
        dsp = self.filter_layer(input_tensor)
        gsp = self.attention_layer(input_tensor, attention_mask)
        hidden_states = self.alpha * dsp + (1 - self.alpha) * gsp

        return hidden_states

class BSARecBlock(nn.Module):
    def __init__(self, args):
        super(BSARecBlock, self).__init__()
        self.layer = BSARecLayer(args)
        self.feed_forward = FeedForward(args)

    def forward(self, hidden_states, attention_mask):
        layer_output = self.layer(hidden_states, attention_mask)
        feedforward_output = self.feed_forward(layer_output)
        return feedforward_output
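# Usage sketch, not part of the original file (assumes an `args` namespace providing alpha, c,
# hidden_size, num_hidden_layers, etc.): an encoder would stack BSARecBlocks, and each block
# blends the frequency branch (dsp) with the attention branch (gsp) as
# alpha * dsp + (1 - alpha) * gsp before the position-wise feed-forward.
#
#   blocks = nn.ModuleList([BSARecBlock(args) for _ in range(args.num_hidden_layers)])
#   for block in blocks:
#       hidden_states = block(hidden_states, attention_mask)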


#######################
###### FMLP-Rec ######
#######################

class FMLPRecLayer(nn.Module):
    def __init__(self, args):
        super(FMLPRecLayer, self).__init__()
        self.complex_weight = nn.Parameter(torch.randn(1, args.max_seq_length//2 + 1, args.hidden_size, 2, dtype=torch.float32) * 0.02)
        self.out_dropout = nn.Dropout(args.hidden_dropout_prob)
        self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12)

    def forward(self, input_tensor):
        # [batch, seq_len, hidden]
        batch, seq_len, hidden = input_tensor.shape
        x = torch.fft.rfft(input_tensor, dim=1, norm='ortho')

        weight = torch.view_as_complex(self.complex_weight)
        x = x * weight
        sequence_emb_fft = torch.fft.irfft(x, n=seq_len, dim=1, norm='ortho')

        hidden_states = self.out_dropout(sequence_emb_fft)
        hidden_states = hidden_states + input_tensor

        hidden_states = self.LayerNorm(hidden_states)

        return hidden_states

class FMLPRecBlock(nn.Module):
    def __init__(self, args):
        super(FMLPRecBlock, self).__init__()
        self.layer = FMLPRecLayer(args)
        self.feed_forward = FeedForward(args)

    def forward(self, hidden_states):
        layer_output = self.layer(hidden_states)
        feedforward_output = self.feed_forward(layer_output)
        return feedforward_output
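# Illustrative note, not part of the original file: FMLPRecLayer replaces attention with a
# learnable frequency-domain filter. Assuming max_seq_length = 50 and hidden_size = 64,
# complex_weight has shape (1, 26, 64, 2); torch.view_as_complex turns it into a (1, 26, 64)
# complex filter that is multiplied element-wise with the 26 rfft bins of every hidden
# dimension before irfft maps the result back to the 50-step sequence.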


#######################
###### FEARec #######
#######################

class FEARecLayer(nn.Module):
    def __init__(self, args, fea_layer=0):
        super(FEARecLayer, self).__init__()

        self.dropout = nn.Dropout(0.1)
        self.attn_dropout = nn.Dropout(args.attention_probs_dropout_prob)
        self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) # layernorm implemented in fmlp
        self.out_dropout = nn.Dropout(args.hidden_dropout_prob)
        self.max_item_list_length = args.max_seq_length
        self.dual_domain = True

        self.global_ratio = args.global_ratio
        self.n_layers = args.num_hidden_layers

        self.scale = None
        self.mask_flag = True
        self.output_attention = False

        self.num_attention_heads = args.num_attention_heads
        self.attention_head_size = int(args.hidden_size / args.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.dense = nn.Linear(args.hidden_size, args.hidden_size)

        self.query = nn.Linear(args.hidden_size, self.all_head_size)
        self.key = nn.Linear(args.hidden_size, self.all_head_size)
        self.value = nn.Linear(args.hidden_size, self.all_head_size)

        self.factor = 10 # config['topk_factor']

        self.filter_mixer = None
        if self.global_ratio > (1 / self.n_layers):
            print("{}>{}:{}".format(self.global_ratio, 1 / self.n_layers, self.global_ratio > (1 / self.n_layers)))
            self.filter_mixer = 'G'
        else:
            print("{}>{}:{}".format(self.global_ratio, 1 / self.n_layers, self.global_ratio > (1 / self.n_layers)))
            self.filter_mixer = 'L'
        self.slide_step = ((self.max_item_list_length // 2 + 1) * (1 - self.global_ratio)) // (self.n_layers - 1)
        self.local_ratio = 1 / self.n_layers
        self.filter_size = self.local_ratio * (self.max_item_list_length // 2 + 1)

        if self.filter_mixer == 'G':
            self.w = self.global_ratio
            self.s = self.slide_step

        if self.filter_mixer == 'L':
            self.w = self.local_ratio
            self.s = self.filter_size

        i = fea_layer
        self.left = int(((self.max_item_list_length // 2 + 1) * (1 - self.w)) - (i * self.s))
        self.right = int((self.max_item_list_length // 2 + 1) - i * self.s)

        self.q_index = list(range(self.left, self.right))
        self.k_index = list(range(self.left, self.right))
        self.v_index = list(range(self.left, self.right))
        # if sample in time domain
        self.std = True # config['std']
        if self.std:
            self.time_q_index = self.q_index
            self.time_k_index = self.k_index
            self.time_v_index = self.v_index
        else:
            self.time_q_index = list(range(self.max_item_list_length // 2 + 1))
            self.time_k_index = list(range(self.max_item_list_length // 2 + 1))
            self.time_v_index = list(range(self.max_item_list_length // 2 + 1))

        print('modes_q={}, index_q={}'.format(len(self.q_index), self.q_index))
        print('modes_k={}, index_k={}'.format(len(self.k_index), self.k_index))
        print('modes_v={}, index_v={}'.format(len(self.v_index), self.v_index))

        self.spatial_ratio = args.spatial_ratio
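    # Worked example, not part of the original file, with hypothetical settings: for
    # max_seq_length = 50, num_hidden_layers = 2 and global_ratio = 0.6 there are 26 rfft bins;
    # global_ratio > 1/2 so filter_mixer = 'G', w = 0.6 and s = slide_step = (26 * 0.4) // 1 = 10.
    # The layer with fea_layer = 0 then keeps bins left = int(26 * 0.4) = 10 up to right = 26
    # (indices 10..25), while the layer with fea_layer = 1 keeps bins 0..15, so successive
    # layers attend to sliding frequency bands.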

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.
        """
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(tmp_values, -int(index[i]), -1)
            delays_agg = delays_agg + pattern * \
                (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.
        """
        batch = values.shape[0]
        head = values.shape[1]
        channel = values.shape[2]
        length = values.shape[3]
        # index init
        init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0) \
            .repeat(batch, head, channel, 1).to(values.device)
        # find top k
        top_k = int(self.factor * math.log(length))
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        # update corr
        tmp_corr = torch.softmax(weights, dim=-1)
        # aggregation
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * \
                (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length))
        return delays_agg
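    # Illustrative note, not part of the original file: both aggregators above keep only the
    # top-k autocorrelation lags, with k = int(self.factor * math.log(length)). For factor = 10
    # and length = 50 (hypothetical), k = int(10 * ln(50)) = 39; each selected lag shifts the
    # value sequence (torch.roll during training, torch.gather at inference) and the shifted
    # copies are combined with softmax weights over their correlation scores.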

    def transpose_for_scores(self, x):
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape) # [256, 50, 2, 32]
        return x

    def forward(self, input_tensor, attention_mask):
        mixed_query_layer = self.query(input_tensor)
        mixed_key_layer = self.key(input_tensor)
        mixed_value_layer = self.value(input_tensor)

        queries = self.transpose_for_scores(mixed_query_layer)
        keys = self.transpose_for_scores(mixed_key_layer)
        values = self.transpose_for_scores(mixed_value_layer)

        # B, H, L, E = query_layer.shape
        # AutoFormer
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            zeros = torch.zeros_like(queries[:, :(L - S), :]).float()
            values = torch.cat([values, zeros], dim=1)
            keys = torch.cat([keys, zeros], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)

        # Put it in an empty box.
        q_fft_box = torch.zeros(B, H, E, len(self.q_index), device=q_fft.device, dtype=torch.cfloat)
        for i, j in enumerate(self.q_index):
            q_fft_box[:, :, :, i] = q_fft[:, :, :, j]

        k_fft_box = torch.zeros(B, H, E, len(self.k_index), device=q_fft.device, dtype=torch.cfloat)
        for i, j in enumerate(self.q_index):
            k_fft_box[:, :, :, i] = k_fft[:, :, :, j]

        res = q_fft_box * torch.conj(k_fft_box)
        box_res = torch.zeros(B, H, E, L // 2 + 1, device=q_fft.device, dtype=torch.cfloat)
        for i, j in enumerate(self.q_index):
            box_res[:, :, :, j] = res[:, :, :, i]

        corr = torch.fft.irfft(box_res, dim=-1)

        # time delay agg
        if self.training:
            V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2)

        new_context_layer_shape = V.size()[:-2] + (self.all_head_size,)
        context_layer = V.view(*new_context_layer_shape)

        if self.dual_domain:
            # Put it in an empty box.
            # q
            q_fft_box = torch.zeros(B, H, E, len(self.time_q_index), device=q_fft.device, dtype=torch.cfloat)
            for i, j in enumerate(self.time_q_index):
                q_fft_box[:, :, :, i] = q_fft[:, :, :, j]
            spatial_q = torch.zeros(B, H, E, L // 2 + 1, device=q_fft.device, dtype=torch.cfloat)
            for i, j in enumerate(self.time_q_index):
                spatial_q[:, :, :, j] = q_fft_box[:, :, :, i]

            # k
            k_fft_box = torch.zeros(B, H, E, len(self.time_k_index), device=q_fft.device, dtype=torch.cfloat)
            for i, j in enumerate(self.time_k_index):
                k_fft_box[:, :, :, i] = k_fft[:, :, :, j]
            spatial_k = torch.zeros(B, H, E, L // 2 + 1, device=k_fft.device, dtype=torch.cfloat)
            for i, j in enumerate(self.time_k_index):
                spatial_k[:, :, :, j] = k_fft_box[:, :, :, i]

            # v
            v_fft = torch.fft.rfft(values.permute(0, 2, 3, 1).contiguous(), dim=-1)
            # Put it in an empty box.
            v_fft_box = torch.zeros(B, H, E, len(self.time_v_index), device=v_fft.device, dtype=torch.cfloat)
            for i, j in enumerate(self.time_v_index):
                v_fft_box[:, :, :, i] = v_fft[:, :, :, j]
            spatial_v = torch.zeros(B, H, E, L // 2 + 1, device=v_fft.device, dtype=torch.cfloat)
            for i, j in enumerate(self.time_v_index):
                spatial_v[:, :, :, j] = v_fft_box[:, :, :, i]

            queries = torch.fft.irfft(spatial_q, dim=-1)
            keys = torch.fft.irfft(spatial_k, dim=-1)
            values = torch.fft.irfft(spatial_v, dim=-1)

            queries = queries.permute(0, 1, 3, 2)
            keys = keys.permute(0, 1, 3, 2)
            values = values.permute(0, 1, 3, 2)

            attention_scores = torch.matmul(queries, keys.transpose(-1, -2))
            attention_scores = attention_scores / math.sqrt(self.attention_head_size)

            attention_scores = attention_scores + attention_mask
            attention_probs = nn.Softmax(dim=-1)(attention_scores)
            attention_probs = self.attn_dropout(attention_probs)
            qkv = torch.matmul(attention_probs, values) # [256, 2, index, 32]
            context_layer_spatial = qkv.permute(0, 2, 1, 3).contiguous()
            new_context_layer_shape = context_layer_spatial.size()[:-2] + (self.all_head_size,)
            context_layer_spatial = context_layer_spatial.view(*new_context_layer_shape)
            context_layer = (1 - self.spatial_ratio) * context_layer + self.spatial_ratio * context_layer_spatial

        hidden_states = self.dense(context_layer)
        hidden_states = self.out_dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)

        return hidden_states
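    # Illustrative note, not part of the original file: with dual_domain enabled, the returned
    # context is a convex mix of the autocorrelation branch and a standard softmax attention
    # computed on the band-limited q/k/v, i.e.
    # context = (1 - spatial_ratio) * context_autocorr + spatial_ratio * context_attention,
    # followed by the dense projection, dropout and residual LayerNorm above.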

class FEARecBlock(nn.Module):
    def __init__(self, args, layer_num):
        super(FEARecBlock, self).__init__()
        self.layer = FEARecLayer(args, fea_layer=layer_num)  # pass the layer index so each layer keeps its own frequency band
        self.feed_forward = FeedForward(args)

    def forward(self, hidden_states, attention_mask):
        layer_output = self.layer(hidden_states, attention_mask)
        feedforward_output = self.feed_forward(layer_output)
        return feedforward_output
