From f21104b9d3737934eaddb2d29631987667551293 Mon Sep 17 00:00:00 2001 From: yehjin-shin Date: Fri, 18 Oct 2024 21:15:37 +0900 Subject: [PATCH] [REFACTOR] organize layer classes within each model file --- src/main.py | 3 + src/model/_modules.py | 346 ------------------------------------------ src/model/bsarec.py | 85 ++++++++--- src/model/fearec.py | 295 ++++++++++++++++++++++++++++++++--- src/model/fmlprec.py | 78 +++++++--- src/trainers.py | 8 +- 6 files changed, 404 insertions(+), 411 deletions(-) diff --git a/src/main.py b/src/main.py index 9bc0462..c1f1547 100644 --- a/src/main.py +++ b/src/main.py @@ -41,8 +41,11 @@ def main(): else: args.checkpoint_path = os.path.join(args.output_dir, args.load_model + '.pt') trainer.load(args.checkpoint_path) + logger.info(f"Load model from {args.checkpoint_path} for test!") scores, result_info = trainer.test(0) + args.checkpoint_path = os.path.join(args.output_dir, args.train_name + '.pt') + # torch.save(trainer.model.state_dict(), args.checkpoint_path) else: early_stopping = EarlyStopping(args.checkpoint_path, logger=logger, patience=args.patience, verbose=True) diff --git a/src/model/_modules.py b/src/model/_modules.py index 9d506f0..2a6fba5 100644 --- a/src/model/_modules.py +++ b/src/model/_modules.py @@ -170,349 +170,3 @@ def forward(self, hidden_states, attention_mask, output_all_encoded_layers=False all_encoder_layers.append(hidden_states) # hidden_states => torch.Size([256, 50, 64]) return all_encoder_layers - - -####################### -###### BSARec ####### -####################### - -class FrequencyLayer(nn.Module): - def __init__(self, args): - super(FrequencyLayer, self).__init__() - self.out_dropout = nn.Dropout(args.hidden_dropout_prob) - self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) - self.c = args.c // 2 + 1 - self.sqrt_beta = nn.Parameter(torch.randn(1, 1, args.hidden_size)) - - def forward(self, input_tensor): - # [batch, seq_len, hidden] - batch, seq_len, hidden = input_tensor.shape - x = torch.fft.rfft(input_tensor, dim=1, norm='ortho') - - low_pass = x[:] - low_pass[:, self.c:, :] = 0 - low_pass = torch.fft.irfft(low_pass, n=seq_len, dim=1, norm='ortho') - high_pass = input_tensor - low_pass - sequence_emb_fft = low_pass + (self.sqrt_beta**2) * high_pass - - hidden_states = self.out_dropout(sequence_emb_fft) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - - return hidden_states - -class BSARecLayer(nn.Module): - def __init__(self, args): - super(BSARecLayer, self).__init__() - self.args = args - self.filter_layer = FrequencyLayer(args) - self.attention_layer = MultiHeadAttention(args) - self.alpha = args.alpha - - def forward(self, input_tensor, attention_mask): - dsp = self.filter_layer(input_tensor) - gsp = self.attention_layer(input_tensor, attention_mask) - hidden_states = self.alpha * dsp + ( 1 - self.alpha ) * gsp - - return hidden_states - -class BSARecBlock(nn.Module): - def __init__(self, args): - super(BSARecBlock, self).__init__() - self.layer = BSARecLayer(args) - self.feed_forward = FeedForward(args) - - def forward(self, hidden_states, attention_mask): - layer_output = self.layer(hidden_states, attention_mask) - feedforward_output = self.feed_forward(layer_output) - return feedforward_output - - -####################### -###### FMLP-Rec ###### -####################### - -class FMLPRecLayer(nn.Module): - def __init__(self, args): - super(FMLPRecLayer, self).__init__() - self.complex_weight = nn.Parameter(torch.randn(1, args.max_seq_length//2 + 1, args.hidden_size, 2, 
dtype=torch.float32) * 0.02) - self.out_dropout = nn.Dropout(args.hidden_dropout_prob) - self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) - - def forward(self, input_tensor): - # [batch, seq_len, hidden] - batch, seq_len, hidden = input_tensor.shape - x = torch.fft.rfft(input_tensor, dim=1, norm='ortho') - - weight = torch.view_as_complex(self.complex_weight) - x = x * weight - sequence_emb_fft = torch.fft.irfft(x, n=seq_len, dim=1, norm='ortho') - - hidden_states = self.out_dropout(sequence_emb_fft) - hidden_states = hidden_states + input_tensor - - hidden_states = self.LayerNorm(hidden_states) - - return hidden_states - -class FMLPRecBlock(nn.Module): - def __init__(self, args): - super(FMLPRecBlock, self).__init__() - self.layer = FMLPRecLayer(args) - self.feed_forward = FeedForward(args) - - def forward(self, hidden_states): - layer_output = self.layer(hidden_states) - feedforward_output = self.feed_forward(layer_output) - return feedforward_output - - -####################### -###### FEARec ####### -####################### - -class FEARecLayer(nn.Module): - def __init__(self, args, fea_layer=0): - super(FEARecLayer, self).__init__() - - self.dropout = nn.Dropout(0.1) - self.attn_dropout = nn.Dropout(args.attention_probs_dropout_prob) - self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) # layernorm implemented in fmlp - self.out_dropout = nn.Dropout(args.hidden_dropout_prob) - self.max_item_list_length = args.max_seq_length - self.dual_domain = True - - self.global_ratio = args.global_ratio - self.n_layers = args.num_hidden_layers - - self.scale = None - self.mask_flag = True - self.output_attention = False - - self.num_attention_heads = args.num_attention_heads - self.attention_head_size = int(args.hidden_size / args.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - self.dense = nn.Linear(args.hidden_size, args.hidden_size) - - self.query = nn.Linear(args.hidden_size, self.all_head_size) - self.key = nn.Linear(args.hidden_size, self.all_head_size) - self.value = nn.Linear(args.hidden_size, self.all_head_size) - - self.factor = 10 # config['topk_factor'] - - self.filter_mixer = None - if self.global_ratio > (1 / self.n_layers): - print("{}>{}:{}".format(self.global_ratio, 1 / self.n_layers, self.global_ratio > (1 / self.n_layers))) - self.filter_mixer = 'G' - else: - print("{}>{}:{}".format(self.global_ratio, 1 / self.n_layers, self.global_ratio > (1 / self.n_layers))) - self.filter_mixer = 'L' - self.slide_step = ((self.max_item_list_length // 2 + 1) * (1 - self.global_ratio)) // (self.n_layers - 1) - self.local_ratio = 1 / self.n_layers - self.filter_size = self.local_ratio * (self.max_item_list_length // 2 + 1) - - if self.filter_mixer == 'G': - self.w = self.global_ratio - self.s = self.slide_step - - if self.filter_mixer == 'L': - self.w = self.local_ratio - self.s = self.filter_size - - i = fea_layer - self.left = int(((self.max_item_list_length // 2 + 1) * (1 - self.w)) - (i * self.s)) - self.right = int((self.max_item_list_length // 2 + 1) - i * self.s) - - self.q_index = list(range(self.left, self.right)) - self.k_index = list(range(self.left, self.right)) - self.v_index = list(range(self.left, self.right)) - # if sample in time domain - self.std = True # config['std'] - if self.std: - self.time_q_index = self.q_index - self.time_k_index = self.k_index - self.time_v_index = self.v_index - else: - self.time_q_index = list(range(self.max_item_list_length // 2 + 1)) - self.time_k_index = 
list(range(self.max_item_list_length // 2 + 1)) - self.time_v_index = list(range(self.max_item_list_length // 2 + 1)) - - print('modes_q={}, index_q={}'.format(len(self.q_index), self.q_index)) - print('modes_k={}, index_k={}'.format(len(self.k_index), self.k_index)) - print('modes_v={}, index_v={}'.format(len(self.v_index), self.v_index)) - - self.spatial_ratio = args.spatial_ratio - - def time_delay_agg_training(self, values, corr): - """ - SpeedUp version of Autocorrelation (a batch-normalization style design) - This is for the training phase. - """ - head = values.shape[1] - channel = values.shape[2] - length = values.shape[3] - # find top k - top_k = int(self.factor * math.log(length)) - mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) - index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] - weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) - # update corr - tmp_corr = torch.softmax(weights, dim=-1) - # aggregation - tmp_values = values - delays_agg = torch.zeros_like(values).float() - for i in range(top_k): - pattern = torch.roll(tmp_values, -int(index[i]), -1) - delays_agg = delays_agg + pattern * \ - (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) - return delays_agg - - def time_delay_agg_inference(self, values, corr): - """ - SpeedUp version of Autocorrelation (a batch-normalization style design) - This is for the inference phase. - """ - batch = values.shape[0] - head = values.shape[1] - channel = values.shape[2] - length = values.shape[3] - # index init - init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0) \ - .repeat(batch, head, channel, 1).to(values.device) - # find top k - top_k = int(self.factor * math.log(length)) - mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) - weights, delay = torch.topk(mean_value, top_k, dim=-1) - # update corr - tmp_corr = torch.softmax(weights, dim=-1) - # aggregation - tmp_values = values.repeat(1, 1, 1, 2) - delays_agg = torch.zeros_like(values).float() - for i in range(top_k): - tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) - pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) - delays_agg = delays_agg + pattern * \ - (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) - return delays_agg - - def transpose_for_scores(self, x): - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(*new_x_shape) # [256, 50, 2, 32] - return x - - def forward(self, input_tensor, attention_mask): - mixed_query_layer = self.query(input_tensor) - mixed_key_layer = self.key(input_tensor) - mixed_value_layer = self.value(input_tensor) - - queries = self.transpose_for_scores(mixed_query_layer) - keys = self.transpose_for_scores(mixed_key_layer) - values = self.transpose_for_scores(mixed_value_layer) - - # B, H, L, E = query_layer.shape - # AutoFormer - B, L, H, E = queries.shape - _, S, _, D = values.shape - if L > S: - zeros = torch.zeros_like(queries[:, :(L - S), :]).float() - values = torch.cat([values, zeros], dim=1) - keys = torch.cat([keys, zeros], dim=1) - else: - values = values[:, :L, :, :] - keys = keys[:, :L, :, :] - - # period-based dependencies - q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) - k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) - - # Put it in an empty box. 
- q_fft_box = torch.zeros(B, H, E, len(self.q_index), device=q_fft.device, dtype=torch.cfloat) - - for i, j in enumerate(self.q_index): - q_fft_box[:, :, :, i] = q_fft[:, :, :, j] - - k_fft_box = torch.zeros(B, H, E, len(self.k_index), device=q_fft.device, dtype=torch.cfloat) - for i, j in enumerate(self.q_index): - k_fft_box[:, :, :, i] = k_fft[:, :, :, j] - - res = q_fft_box * torch.conj(k_fft_box) - box_res = torch.zeros(B, H, E, L // 2 + 1, device=q_fft.device, dtype=torch.cfloat) - for i, j in enumerate(self.q_index): - box_res[:, :, :, j] = res[:, :, :, i] - - corr = torch.fft.irfft(box_res, dim=-1) - - # time delay agg - if self.training: - V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) - else: - V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) - - new_context_layer_shape = V.size()[:-2] + (self.all_head_size,) - context_layer = V.view(*new_context_layer_shape) - - if self.dual_domain: - # Put it in an empty box. - # q - q_fft_box = torch.zeros(B, H, E, len(self.time_q_index), device=q_fft.device, dtype=torch.cfloat) - - for i, j in enumerate(self.time_q_index): - q_fft_box[:, :, :, i] = q_fft[:, :, :, j] - spatial_q = torch.zeros(B, H, E, L // 2 + 1, device=q_fft.device, dtype=torch.cfloat) - for i, j in enumerate(self.time_q_index): - spatial_q[:, :, :, j] = q_fft_box[:, :, :, i] - - # k - k_fft_box = torch.zeros(B, H, E, len(self.time_k_index), device=q_fft.device, dtype=torch.cfloat) - for i, j in enumerate(self.time_k_index): - k_fft_box[:, :, :, i] = k_fft[:, :, :, j] - spatial_k = torch.zeros(B, H, E, L // 2 + 1, device=k_fft.device, dtype=torch.cfloat) - for i, j in enumerate(self.time_k_index): - spatial_k[:, :, :, j] = k_fft_box[:, :, :, i] - - # v - v_fft = torch.fft.rfft(values.permute(0, 2, 3, 1).contiguous(), dim=-1) - # Put it in an empty box. 
- v_fft_box = torch.zeros(B, H, E, len(self.time_v_index), device=v_fft.device, dtype=torch.cfloat) - - for i, j in enumerate(self.time_v_index): - v_fft_box[:, :, :, i] = v_fft[:, :, :, j] - spatial_v = torch.zeros(B, H, E, L // 2 + 1, device=v_fft.device, dtype=torch.cfloat) - for i, j in enumerate(self.time_v_index): - spatial_v[:, :, :, j] = v_fft_box[:, :, :, i] - - queries = torch.fft.irfft(spatial_q, dim=-1) - keys = torch.fft.irfft(spatial_k, dim=-1) - values = torch.fft.irfft(spatial_v, dim=-1) - - queries = queries.permute(0, 1, 3, 2) - keys = keys.permute(0, 1, 3, 2) - values = values.permute(0, 1, 3, 2) - - attention_scores = torch.matmul(queries, keys.transpose(-1, -2)) - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - attention_scores = attention_scores + attention_mask - attention_probs = nn.Softmax(dim=-1)(attention_scores) - attention_probs = self.attn_dropout(attention_probs) - qkv = torch.matmul(attention_probs, values) # [256, 2, index, 32] - context_layer_spatial = qkv.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer_spatial.size()[:-2] + (self.all_head_size,) - context_layer_spatial = context_layer_spatial.view(*new_context_layer_shape) - context_layer = (1 - self.spatial_ratio) * context_layer + self.spatial_ratio * context_layer_spatial - - hidden_states = self.dense(context_layer) - hidden_states = self.out_dropout(hidden_states) - hidden_states = self.LayerNorm(hidden_states + input_tensor) - - return hidden_states - -class FEARecBlock(nn.Module): - def __init__(self, args, layer_num): - super(FEARecBlock, self).__init__() - self.layer = FEARecLayer(args) - self.feed_forward = FeedForward(args) - - def forward(self, hidden_states, attention_mask): - layer_output = self.layer(hidden_states, attention_mask) - feedforward_output = self.feed_forward(layer_output) - return feedforward_output \ No newline at end of file diff --git a/src/model/bsarec.py b/src/model/bsarec.py index 2cee4e3..a21cfa0 100644 --- a/src/model/bsarec.py +++ b/src/model/bsarec.py @@ -2,24 +2,7 @@ import torch import torch.nn as nn from model._abstract_model import SequentialRecModel -from model._modules import LayerNorm, BSARecBlock - -class BSARecEncoder(nn.Module): - def __init__(self, args): - super(BSARecEncoder, self).__init__() - self.args = args - block = BSARecBlock(args) - self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(args.num_hidden_layers)]) - - def forward(self, hidden_states, attention_mask, output_all_encoded_layers=False): - all_encoder_layers = [ hidden_states ] - for layer_module in self.blocks: - hidden_states = layer_module(hidden_states, attention_mask) - if output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - if not output_all_encoded_layers: - all_encoder_layers.append(hidden_states) # hidden_states => torch.Size([256, 50, 64]) - return all_encoder_layers +from model._modules import LayerNorm, FeedForward, MultiHeadAttention class BSARecModel(SequentialRecModel): def __init__(self, args): @@ -53,3 +36,69 @@ def calculate_loss(self, input_ids, answers, neg_answers, same_target, user_ids) return loss +class BSARecEncoder(nn.Module): + def __init__(self, args): + super(BSARecEncoder, self).__init__() + self.args = args + block = BSARecBlock(args) + self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(args.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask, output_all_encoded_layers=False): + all_encoder_layers = [ hidden_states ] + for layer_module in 
self.blocks: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) # hidden_states => torch.Size([256, 50, 64]) + return all_encoder_layers + +class BSARecBlock(nn.Module): + def __init__(self, args): + super(BSARecBlock, self).__init__() + self.layer = BSARecLayer(args) + self.feed_forward = FeedForward(args) + + def forward(self, hidden_states, attention_mask): + layer_output = self.layer(hidden_states, attention_mask) + feedforward_output = self.feed_forward(layer_output) + return feedforward_output + +class BSARecLayer(nn.Module): + def __init__(self, args): + super(BSARecLayer, self).__init__() + self.args = args + self.filter_layer = FrequencyLayer(args) + self.attention_layer = MultiHeadAttention(args) + self.alpha = args.alpha + + def forward(self, input_tensor, attention_mask): + dsp = self.filter_layer(input_tensor) + gsp = self.attention_layer(input_tensor, attention_mask) + hidden_states = self.alpha * dsp + ( 1 - self.alpha ) * gsp + + return hidden_states + +class FrequencyLayer(nn.Module): + def __init__(self, args): + super(FrequencyLayer, self).__init__() + self.out_dropout = nn.Dropout(args.hidden_dropout_prob) + self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) + self.c = args.c // 2 + 1 + self.sqrt_beta = nn.Parameter(torch.randn(1, 1, args.hidden_size)) + + def forward(self, input_tensor): + # [batch, seq_len, hidden] + batch, seq_len, hidden = input_tensor.shape + x = torch.fft.rfft(input_tensor, dim=1, norm='ortho') + + low_pass = x[:] + low_pass[:, self.c:, :] = 0 + low_pass = torch.fft.irfft(low_pass, n=seq_len, dim=1, norm='ortho') + high_pass = input_tensor - low_pass + sequence_emb_fft = low_pass + (self.sqrt_beta**2) * high_pass + + hidden_states = self.out_dropout(sequence_emb_fft) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + + return hidden_states diff --git a/src/model/fearec.py b/src/model/fearec.py index a7b496c..ca3ebde 100644 --- a/src/model/fearec.py +++ b/src/model/fearec.py @@ -1,7 +1,8 @@ +import math import torch import torch.nn as nn from model._abstract_model import SequentialRecModel -from model._modules import LayerNorm, FEARecBlock +from model._modules import LayerNorm, FeedForward """ [Paper] @@ -13,29 +14,6 @@ https://github.com/sudaada/FEARec """ -class FEARecEncoder(nn.Module): - def __init__(self, args): - super(FEARecEncoder, self).__init__() - self.args = args - - self.blocks = [] - for i in range(args.num_hidden_layers): - self.blocks.append(FEARecBlock(args, layer_num=i)) - self.blocks = nn.ModuleList(self.blocks) - - def forward(self, hidden_states, attention_mask, output_all_encoded_layers=False): - - all_encoder_layers = [ hidden_states ] - - for layer_module in self.blocks: - hidden_states = layer_module(hidden_states, attention_mask) - if output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - if not output_all_encoded_layers: - all_encoder_layers.append(hidden_states) # hidden_states => torch.Size([256, 50, 64]) - - return all_encoder_layers - class FEARecModel(SequentialRecModel): def __init__(self, args): super(FEARecModel, self).__init__(args) @@ -166,3 +144,272 @@ def calculate_loss(self, input_ids, answers, neg_answers, same_target, user_ids) return loss +class FEARecEncoder(nn.Module): + def __init__(self, args): + super(FEARecEncoder, self).__init__() + self.args = args + + self.blocks = [] + for i in 
range(args.num_hidden_layers): + self.blocks.append(FEARecBlock(args, layer_num=i)) + self.blocks = nn.ModuleList(self.blocks) + + def forward(self, hidden_states, attention_mask, output_all_encoded_layers=False): + + all_encoder_layers = [ hidden_states ] + + for layer_module in self.blocks: + hidden_states = layer_module(hidden_states, attention_mask) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) # hidden_states => torch.Size([256, 50, 64]) + + return all_encoder_layers + +class FEARecBlock(nn.Module): + def __init__(self, args, layer_num): + super(FEARecBlock, self).__init__() + self.layer = FEARecLayer(args) + self.feed_forward = FeedForward(args) + + def forward(self, hidden_states, attention_mask): + layer_output = self.layer(hidden_states, attention_mask) + feedforward_output = self.feed_forward(layer_output) + return feedforward_output + +class FEARecLayer(nn.Module): + def __init__(self, args, fea_layer=0): + super(FEARecLayer, self).__init__() + + self.dropout = nn.Dropout(0.1) + self.attn_dropout = nn.Dropout(args.attention_probs_dropout_prob) + self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) # layernorm implemented in fmlp + self.out_dropout = nn.Dropout(args.hidden_dropout_prob) + self.max_item_list_length = args.max_seq_length + self.dual_domain = True + + self.global_ratio = args.global_ratio + self.n_layers = args.num_hidden_layers + + self.scale = None + self.mask_flag = True + self.output_attention = False + + self.num_attention_heads = args.num_attention_heads + self.attention_head_size = int(args.hidden_size / args.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + self.dense = nn.Linear(args.hidden_size, args.hidden_size) + + self.query = nn.Linear(args.hidden_size, self.all_head_size) + self.key = nn.Linear(args.hidden_size, self.all_head_size) + self.value = nn.Linear(args.hidden_size, self.all_head_size) + + self.factor = 10 # config['topk_factor'] + + self.filter_mixer = None + if self.global_ratio > (1 / self.n_layers): + print("{}>{}:{}".format(self.global_ratio, 1 / self.n_layers, self.global_ratio > (1 / self.n_layers))) + self.filter_mixer = 'G' + else: + print("{}>{}:{}".format(self.global_ratio, 1 / self.n_layers, self.global_ratio > (1 / self.n_layers))) + self.filter_mixer = 'L' + self.slide_step = ((self.max_item_list_length // 2 + 1) * (1 - self.global_ratio)) // (self.n_layers - 1) + self.local_ratio = 1 / self.n_layers + self.filter_size = self.local_ratio * (self.max_item_list_length // 2 + 1) + + if self.filter_mixer == 'G': + self.w = self.global_ratio + self.s = self.slide_step + + if self.filter_mixer == 'L': + self.w = self.local_ratio + self.s = self.filter_size + + i = fea_layer + self.left = int(((self.max_item_list_length // 2 + 1) * (1 - self.w)) - (i * self.s)) + self.right = int((self.max_item_list_length // 2 + 1) - i * self.s) + + self.q_index = list(range(self.left, self.right)) + self.k_index = list(range(self.left, self.right)) + self.v_index = list(range(self.left, self.right)) + # if sample in time domain + self.std = True # config['std'] + if self.std: + self.time_q_index = self.q_index + self.time_k_index = self.k_index + self.time_v_index = self.v_index + else: + self.time_q_index = list(range(self.max_item_list_length // 2 + 1)) + self.time_k_index = list(range(self.max_item_list_length // 2 + 1)) + self.time_v_index = list(range(self.max_item_list_length // 2 + 1)) + 
+ print('modes_q={}, index_q={}'.format(len(self.q_index), self.q_index)) + print('modes_k={}, index_k={}'.format(len(self.k_index), self.k_index)) + print('modes_v={}, index_v={}'.format(len(self.v_index), self.v_index)) + + self.spatial_ratio = args.spatial_ratio + + def time_delay_agg_training(self, values, corr): + """ + SpeedUp version of Autocorrelation (a batch-normalization style design) + This is for the training phase. + """ + head = values.shape[1] + channel = values.shape[2] + length = values.shape[3] + # find top k + top_k = int(self.factor * math.log(length)) + mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) + index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1] + weights = torch.stack([mean_value[:, index[i]] for i in range(top_k)], dim=-1) + # update corr + tmp_corr = torch.softmax(weights, dim=-1) + # aggregation + tmp_values = values + delays_agg = torch.zeros_like(values).float() + for i in range(top_k): + pattern = torch.roll(tmp_values, -int(index[i]), -1) + delays_agg = delays_agg + pattern * \ + (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) + return delays_agg + + def time_delay_agg_inference(self, values, corr): + """ + SpeedUp version of Autocorrelation (a batch-normalization style design) + This is for the inference phase. + """ + batch = values.shape[0] + head = values.shape[1] + channel = values.shape[2] + length = values.shape[3] + # index init + init_index = torch.arange(length).unsqueeze(0).unsqueeze(0).unsqueeze(0) \ + .repeat(batch, head, channel, 1).to(values.device) + # find top k + top_k = int(self.factor * math.log(length)) + mean_value = torch.mean(torch.mean(corr, dim=1), dim=1) + weights, delay = torch.topk(mean_value, top_k, dim=-1) + # update corr + tmp_corr = torch.softmax(weights, dim=-1) + # aggregation + tmp_values = values.repeat(1, 1, 1, 2) + delays_agg = torch.zeros_like(values).float() + for i in range(top_k): + tmp_delay = init_index + delay[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length) + pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay) + delays_agg = delays_agg + pattern * \ + (tmp_corr[:, i].unsqueeze(1).unsqueeze(1).unsqueeze(1).repeat(1, head, channel, length)) + return delays_agg + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) # [256, 50, 2, 32] + return x + + def forward(self, input_tensor, attention_mask): + mixed_query_layer = self.query(input_tensor) + mixed_key_layer = self.key(input_tensor) + mixed_value_layer = self.value(input_tensor) + + queries = self.transpose_for_scores(mixed_query_layer) + keys = self.transpose_for_scores(mixed_key_layer) + values = self.transpose_for_scores(mixed_value_layer) + + # B, H, L, E = query_layer.shape + # AutoFormer + B, L, H, E = queries.shape + _, S, _, D = values.shape + if L > S: + zeros = torch.zeros_like(queries[:, :(L - S), :]).float() + values = torch.cat([values, zeros], dim=1) + keys = torch.cat([keys, zeros], dim=1) + else: + values = values[:, :L, :, :] + keys = keys[:, :L, :, :] + + # period-based dependencies + q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1) + k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1) + + # Put it in an empty box. 
+ q_fft_box = torch.zeros(B, H, E, len(self.q_index), device=q_fft.device, dtype=torch.cfloat) + + for i, j in enumerate(self.q_index): + q_fft_box[:, :, :, i] = q_fft[:, :, :, j] + + k_fft_box = torch.zeros(B, H, E, len(self.k_index), device=q_fft.device, dtype=torch.cfloat) + for i, j in enumerate(self.q_index): + k_fft_box[:, :, :, i] = k_fft[:, :, :, j] + + res = q_fft_box * torch.conj(k_fft_box) + box_res = torch.zeros(B, H, E, L // 2 + 1, device=q_fft.device, dtype=torch.cfloat) + for i, j in enumerate(self.q_index): + box_res[:, :, :, j] = res[:, :, :, i] + + corr = torch.fft.irfft(box_res, dim=-1) + + # time delay agg + if self.training: + V = self.time_delay_agg_training(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) + else: + V = self.time_delay_agg_inference(values.permute(0, 2, 3, 1).contiguous(), corr).permute(0, 3, 1, 2) + + new_context_layer_shape = V.size()[:-2] + (self.all_head_size,) + context_layer = V.view(*new_context_layer_shape) + + if self.dual_domain: + # Put it in an empty box. + # q + q_fft_box = torch.zeros(B, H, E, len(self.time_q_index), device=q_fft.device, dtype=torch.cfloat) + + for i, j in enumerate(self.time_q_index): + q_fft_box[:, :, :, i] = q_fft[:, :, :, j] + spatial_q = torch.zeros(B, H, E, L // 2 + 1, device=q_fft.device, dtype=torch.cfloat) + for i, j in enumerate(self.time_q_index): + spatial_q[:, :, :, j] = q_fft_box[:, :, :, i] + + # k + k_fft_box = torch.zeros(B, H, E, len(self.time_k_index), device=q_fft.device, dtype=torch.cfloat) + for i, j in enumerate(self.time_k_index): + k_fft_box[:, :, :, i] = k_fft[:, :, :, j] + spatial_k = torch.zeros(B, H, E, L // 2 + 1, device=k_fft.device, dtype=torch.cfloat) + for i, j in enumerate(self.time_k_index): + spatial_k[:, :, :, j] = k_fft_box[:, :, :, i] + + # v + v_fft = torch.fft.rfft(values.permute(0, 2, 3, 1).contiguous(), dim=-1) + # Put it in an empty box. 
+ v_fft_box = torch.zeros(B, H, E, len(self.time_v_index), device=v_fft.device, dtype=torch.cfloat) + + for i, j in enumerate(self.time_v_index): + v_fft_box[:, :, :, i] = v_fft[:, :, :, j] + spatial_v = torch.zeros(B, H, E, L // 2 + 1, device=v_fft.device, dtype=torch.cfloat) + for i, j in enumerate(self.time_v_index): + spatial_v[:, :, :, j] = v_fft_box[:, :, :, i] + + queries = torch.fft.irfft(spatial_q, dim=-1) + keys = torch.fft.irfft(spatial_k, dim=-1) + values = torch.fft.irfft(spatial_v, dim=-1) + + queries = queries.permute(0, 1, 3, 2) + keys = keys.permute(0, 1, 3, 2) + values = values.permute(0, 1, 3, 2) + + attention_scores = torch.matmul(queries, keys.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + + attention_scores = attention_scores + attention_mask + attention_probs = nn.Softmax(dim=-1)(attention_scores) + attention_probs = self.attn_dropout(attention_probs) + qkv = torch.matmul(attention_probs, values) # [256, 2, index, 32] + context_layer_spatial = qkv.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer_spatial.size()[:-2] + (self.all_head_size,) + context_layer_spatial = context_layer_spatial.view(*new_context_layer_shape) + context_layer = (1 - self.spatial_ratio) * context_layer + self.spatial_ratio * context_layer_spatial + + hidden_states = self.dense(context_layer) + hidden_states = self.out_dropout(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + + return hidden_states \ No newline at end of file diff --git a/src/model/fmlprec.py b/src/model/fmlprec.py index d5d22ab..a48ce90 100644 --- a/src/model/fmlprec.py +++ b/src/model/fmlprec.py @@ -2,7 +2,7 @@ import torch.nn as nn import copy from model._abstract_model import SequentialRecModel -from model._modules import FMLPRecBlock, LayerNorm +from model._modules import LayerNorm, FeedForward """ [Paper] @@ -13,27 +13,6 @@ [Code Reference] https://github.com/Woeee/FMLP-Rec """ - -class FMLPRecEncoder(nn.Module): - def __init__(self, args): - super(FMLPRecEncoder, self).__init__() - self.args = args - block = FMLPRecBlock(args) - - self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(args.num_hidden_layers)]) - - def forward(self, hidden_states, output_all_encoded_layers=False): - - all_encoder_layers = [ hidden_states ] - - for layer_module in self.blocks: - hidden_states = layer_module(hidden_states,) - if output_all_encoded_layers: - all_encoder_layers.append(hidden_states) - if not output_all_encoded_layers: - all_encoder_layers.append(hidden_states) # hidden_states => torch.Size([256, 50, 64]) - - return all_encoder_layers class FMLPRecModel(SequentialRecModel): def __init__(self, args): @@ -79,3 +58,58 @@ def calculate_loss(self, input_ids, answers, neg_answers, same_target, user_ids) ) return loss + +class FMLPRecEncoder(nn.Module): + def __init__(self, args): + super(FMLPRecEncoder, self).__init__() + self.args = args + block = FMLPRecBlock(args) + + self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(args.num_hidden_layers)]) + + def forward(self, hidden_states, output_all_encoded_layers=False): + + all_encoder_layers = [ hidden_states ] + + for layer_module in self.blocks: + hidden_states = layer_module(hidden_states,) + if output_all_encoded_layers: + all_encoder_layers.append(hidden_states) + if not output_all_encoded_layers: + all_encoder_layers.append(hidden_states) # hidden_states => torch.Size([256, 50, 64]) + + return all_encoder_layers + +class FMLPRecBlock(nn.Module): + def 
__init__(self, args): + super(FMLPRecBlock, self).__init__() + self.layer = FMLPRecLayer(args) + self.feed_forward = FeedForward(args) + + def forward(self, hidden_states): + layer_output = self.layer(hidden_states) + feedforward_output = self.feed_forward(layer_output) + return feedforward_output + +class FMLPRecLayer(nn.Module): + def __init__(self, args): + super(FMLPRecLayer, self).__init__() + self.complex_weight = nn.Parameter(torch.randn(1, args.max_seq_length//2 + 1, args.hidden_size, 2, dtype=torch.float32) * 0.02) + self.out_dropout = nn.Dropout(args.hidden_dropout_prob) + self.LayerNorm = LayerNorm(args.hidden_size, eps=1e-12) + + def forward(self, input_tensor): + # [batch, seq_len, hidden] + batch, seq_len, hidden = input_tensor.shape + x = torch.fft.rfft(input_tensor, dim=1, norm='ortho') + + weight = torch.view_as_complex(self.complex_weight) + x = x * weight + sequence_emb_fft = torch.fft.irfft(x, n=seq_len, dim=1, norm='ortho') + + hidden_states = self.out_dropout(sequence_emb_fft) + hidden_states = hidden_states + input_tensor + + hidden_states = self.LayerNorm(hidden_states) + + return hidden_states diff --git a/src/trainers.py b/src/trainers.py index 9a212ab..2738e5b 100644 --- a/src/trainers.py +++ b/src/trainers.py @@ -50,7 +50,13 @@ def load(self, file_name): new_dict = torch.load(file_name) self.logger.info(new_dict.keys()) for key in new_dict: - original_state_dict[key]=new_dict[key] + if 'beta' in key: + # print(key) + # new_key = key.replace('beta', 'sqrt_beta') + # original_state_dict[new_key] = new_dict[key] + original_state_dict[key]=new_dict[key] + else: + original_state_dict[key]=new_dict[key] self.model.load_state_dict(original_state_dict) def predict_full(self, seq_out):
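
For reference, a minimal sketch of the state-dict loading pattern that the trainers.py load() hunk follows: start from the model's current state dict, overwrite it with the saved keys, and leave a hook point for renaming keys. The function name load_checkpoint, the rename argument, and the example path are illustrative only; the 'beta' to 'sqrt_beta' mapping mirrors the commented-out lines in the patch and is not behavior the patch enables.

import torch

def load_checkpoint(model, file_name, rename=None):
    # Start from the model's own parameters so keys missing from the
    # checkpoint keep their initialized values (as in Trainer.load).
    state = model.state_dict()
    saved = torch.load(file_name, map_location='cpu')
    for key, value in saved.items():
        # Optional hook for legacy parameter names; by default a straight copy.
        state[rename(key) if rename else key] = value
    model.load_state_dict(state)

# hypothetical usage, mapping a legacy parameter name:
# load_checkpoint(model, 'checkpoint.pt', rename=lambda k: k.replace('beta', 'sqrt_beta'))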
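
A minimal standalone sketch of the frequency-domain filtering implemented by FrequencyLayer (BSARec) and FMLPRecLayer (FMLP-Rec): rfft along the sequence axis, filter the frequency bins, irfft back. Inputs are assumed to be [batch, seq_len, hidden] as in the code above; the function names low_high_pass and learnable_filter are illustrative, and the residual connection, dropout, and LayerNorm applied afterwards in both layers are omitted.

import torch

def low_high_pass(x, c, sqrt_beta):
    # BSARec-style split: keep the lowest c frequency bins as the low-pass part,
    # take the residual as the high-pass part, and rescale it by sqrt_beta**2
    # (c corresponds to args.c // 2 + 1 in FrequencyLayer).
    seq_len = x.size(1)
    freq = torch.fft.rfft(x, dim=1, norm='ortho')        # [batch, seq_len//2 + 1, hidden]
    freq[:, c:, :] = 0                                    # zero out high-frequency bins
    low_pass = torch.fft.irfft(freq, n=seq_len, dim=1, norm='ortho')
    high_pass = x - low_pass
    return low_pass + (sqrt_beta ** 2) * high_pass

def learnable_filter(x, complex_weight):
    # FMLP-Rec-style filter: multiply every frequency bin by a learnable complex
    # weight stored as a real tensor of shape [1, seq_len//2 + 1, hidden, 2].
    seq_len = x.size(1)
    freq = torch.fft.rfft(x, dim=1, norm='ortho')
    freq = freq * torch.view_as_complex(complex_weight)
    return torch.fft.irfft(freq, n=seq_len, dim=1, norm='ortho')

# toy shapes: batch=2, seq_len=50, hidden=64; args.c=5 gives c=3
x = torch.randn(2, 50, 64)
out_bsarec = low_high_pass(x, c=3, sqrt_beta=torch.randn(1, 1, 64))
out_fmlp = learnable_filter(x, complex_weight=torch.randn(1, 26, 64, 2) * 0.02)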
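
A condensed sketch of the training-phase autocorrelation aggregation used by FEARecLayer (time_delay_agg_training): the query/key correlation is computed in the frequency domain, the top-k delays are selected from the averaged correlation, and values are aggregated by rolling them by each delay with softmax weights. Inputs are assumed to be [batch, head, channel, length] with time on the last axis; autocorrelation_agg is an illustrative name, and the per-layer frequency-band selection (q_index/k_index/v_index) done in the actual layer is omitted.

import math
import torch

def autocorrelation_agg(queries, keys, values, factor=10):
    length = values.size(-1)
    # period-based dependencies via FFT: corr(tau) = irfft(Q * conj(K))
    q_fft = torch.fft.rfft(queries, dim=-1)
    k_fft = torch.fft.rfft(keys, dim=-1)
    corr = torch.fft.irfft(q_fft * torch.conj(k_fft), n=length, dim=-1)

    # pick the top-k delays from the head/channel-averaged correlation
    top_k = int(factor * math.log(length))
    mean_corr = corr.mean(dim=1).mean(dim=1)                      # [batch, length]
    delays = torch.topk(mean_corr.mean(dim=0), top_k, dim=-1).indices
    weights = torch.softmax(
        torch.stack([mean_corr[:, d] for d in delays], dim=-1), dim=-1)  # [batch, top_k]

    # aggregate: roll the values by each delay and sum with the softmax weights
    agg = torch.zeros_like(values)
    for i, d in enumerate(delays):
        rolled = torch.roll(values, -int(d), dims=-1)
        agg = agg + rolled * weights[:, i].view(-1, 1, 1, 1)
    return agg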