Commit 442f1fb (parent 376924e)
gmftbyGMFTBY committed Tue 01 Dec 2020 10:13:22 AM CST

add the ablation models in the biencoder.py: (1) BERTBiCompEncoder_car; (2) BERTBiCompEncoder_comp; (3) BERTBiCompEncoder_gate.
Showing 2 changed files with 351 additions and 7 deletions.
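
Before the diff itself, a hedged usage sketch (not part of the commit) of how one of the new ablation classes defined below might be instantiated and used to rank candidate responses with predict(). The token-id tensors are random placeholders standing in for tokenizer output, the import path simply mirrors the changed file, and the constructor values mirror the agent configuration updated later in this commit.

import torch

# assumes models/biencoder.py (with BertEmbedding) from this repository is importable
from models.biencoder import BERTBiCompEncoder_car

model = BERTBiCompEncoder_car(
    nhead=8, dim_feedforward=2048, num_encoder_layers=4, dropout=0.1, lang='zh',
).cuda().eval()

cid = torch.randint(0, 21128, (64,)).cuda()     # one tokenized context, [L] (placeholder ids)
rid = torch.randint(0, 21128, (10, 64)).cuda()  # 10 tokenized candidates, [S, L]
rid_mask = torch.ones_like(rid)                 # attention mask for the candidates

scores = model.predict(cid, rid, rid_mask)      # [S] dot-product scores, higher is better
best = scores.argmax().item()                   # index of the top-ranked candidate
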
354 changes: 349 additions & 5 deletions models/biencoder.py
@@ -465,6 +465,350 @@ def forward(self, cid, rid, cid_mask, rid_mask):
loss = (-loss.sum(dim=1)).mean()
return loss, acc

class BERTBiCompEncoder_car(nn.Module):

'''bi-encoder + TCM - {context-aware}: ablation that removes the context-aware input,
so the candidate representation replaces the context in the comparison module.
'''

def __init__(self, nhead, dim_feedforward, num_encoder_layers, dropout=0.1, lang='zh'):
super(BERTBiCompEncoder_car, self).__init__()
self.ctx_encoder = BertEmbedding(lang=lang)
self.can_encoder = BertEmbedding(lang=lang)

encoder_layer = nn.TransformerEncoderLayer(
768,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
)
encoder_norm = nn.LayerNorm(768)
self.trs_encoder = nn.TransformerEncoder(
encoder_layer,
num_encoder_layers,
encoder_norm,
)
self.proj1 = nn.Linear(768*2, 768)
self.gate = nn.Linear(768*3, 768)
self.dropout = nn.Dropout(p=dropout)
self.layernorm = nn.LayerNorm(768)

def _encode(self, cid, rid, cid_mask, rid_mask):
cid_rep = self.ctx_encoder(cid, cid_mask)
rid_rep = self.can_encoder(rid, rid_mask)
return cid_rep, rid_rep

@torch.no_grad()
def predict(self, cid, rid, rid_mask):
# cid_rep: [1, E]; rid_rep: [S, E]
batch_size = rid.shape[0]
cid_rep, rid_rep = self._encode(cid.unsqueeze(0), rid, None, rid_mask)
cid_rep = cid_rep.squeeze(0) # [E]
cross_rep = torch.cat(
[
# ablation: the context-aware input is dropped, so the candidate
# representation is duplicated in place of the expanded context
# cid_rep.unsqueeze(0).expand(batch_size, -1),
rid_rep,
rid_rep,
],
dim=1,
) # [S, 2*E]

cross_rep = self.dropout(
torch.tanh(
self.trs_encoder(
torch.tanh(
self.proj1(cross_rep).unsqueeze(1)
)
)
).squeeze(1)
) # [S, E]

gate = torch.sigmoid(
self.gate(
torch.cat(
[
rid_rep, # [S, E]
cid_rep.unsqueeze(0).expand(batch_size, -1), # [S, E]
cross_rep, # [S, E]
],
dim=-1,
)
)
) # [S, E]
# cross_rep: [S, E]
cross_rep = self.layernorm(gate * rid_rep + (1 - gate) * cross_rep)
# cid: [E]; cross_rep: [S, E]
dot_product = torch.matmul(cid_rep, cross_rep.t()) # [S]
return dot_product

def forward(self, cid, rid, cid_mask, rid_mask):
batch_size = cid.shape[0]
assert batch_size > 1, '[!] batch size must be bigger than 1, because the other elements in the batch are used as negative samples'
cid_rep, rid_rep = self._encode(cid, rid, cid_mask, rid_mask) # [B, E]

# build the comparison-module input for all the candidates
# (ablation: the context representation is intentionally not used here)
cross_rep = []
for cid_rep_ in cid_rep:
cross_rep.append(
torch.cat([rid_rep, rid_rep], dim=-1)
) # [S, E*2]
cross_rep = torch.stack(cross_rep).permute(1, 0, 2) # [B, S, 2*E] -> [S, B, E*2]
cross_rep = self.dropout(
torch.tanh(
self.trs_encoder(
torch.tanh(self.proj1(cross_rep)),
)
).permute(1, 0, 2)
) # [B, S, E]

gate = torch.sigmoid(
self.gate(
torch.cat(
[
rid_rep.unsqueeze(0).expand(batch_size, -1, -1),
cid_rep.unsqueeze(1).expand(-1, batch_size, -1),
cross_rep,
],
dim=-1
)
)
) # [B, S, E]
cross_rep = self.layernorm(gate * rid_rep.unsqueeze(0).expand(batch_size, -1, -1) + (1 - gate) * cross_rep) # [B, S, E]

# score each context against every fused candidate in the batch
cid_rep = cid_rep.unsqueeze(1) # [B, 1, E]
dot_product = torch.bmm(cid_rep, cross_rep.permute(0, 2, 1)).squeeze(1) # [B, S]
# half precision so the mask matches the apex (fp16) activations
mask = to_cuda(torch.eye(batch_size)).half() # [B, B]
# calculate accuracy: the positive candidate for context i sits at index i
acc_num = (F.softmax(dot_product, dim=-1).max(dim=-1)[1] == torch.arange(batch_size).cuda()).sum().item()
acc = acc_num / batch_size
# calculate the loss
loss = F.log_softmax(dot_product, dim=-1) * mask
loss = (-loss.sum(dim=1)).mean()
return loss, acc
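
For reference, a minimal self-contained sketch of the in-batch negative objective computed at the end of forward() above, on random tensors: the i-th candidate is the positive for the i-th context and every other candidate in the batch is a negative, which is equivalent to cross-entropy against diagonal labels (the apex .half()/to_cuda details are omitted here).

import torch
import torch.nn.functional as F

batch_size, emb = 4, 768
cid_rep = torch.randn(batch_size, emb)                 # context representations [B, E]
cross_rep = torch.randn(batch_size, batch_size, emb)   # fused candidate representations [B, S, E], S == B

# [B, 1, E] x [B, E, S] -> [B, S]: score every candidate in the batch for each context
dot_product = torch.bmm(cid_rep.unsqueeze(1), cross_rep.permute(0, 2, 1)).squeeze(1)

mask = torch.eye(batch_size)                           # positives sit on the diagonal
loss = F.log_softmax(dot_product, dim=-1) * mask
loss = (-loss.sum(dim=1)).mean()

# same objective written as cross-entropy with diagonal labels
labels = torch.arange(batch_size)
assert torch.allclose(loss, F.cross_entropy(dot_product, labels))
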

class BERTBiCompEncoder_comp(nn.Module):

'''bi-encoder + TCM - {comparison}: ablation that removes the Transformer comparison
module; the projected context-candidate pair goes straight to the gated fusion.
During training, the other elements in the batch are used as negative samples, which
speeds up training. More details: https://arxiv.org/pdf/1905.01969v2.pdf
reference: https://github.com/chijames/Poly-Encoder/blob/master/encoder.py
'''

def __init__(self, nhead, dim_feedforward, num_encoder_layers, dropout=0.1, lang='zh'):
super(BERTBiCompEncoder_comp, self).__init__()
self.ctx_encoder = BertEmbedding(lang=lang)
self.can_encoder = BertEmbedding(lang=lang)

encoder_layer = nn.TransformerEncoderLayer(
768,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
)
encoder_norm = nn.LayerNorm(768)
self.trs_encoder = nn.TransformerEncoder(
encoder_layer,
num_encoder_layers,
encoder_norm,
)

self.proj1 = nn.Linear(768*2, 768)
self.gate = nn.Linear(768*3, 768)
self.dropout = nn.Dropout(p=dropout)
self.layernorm = nn.LayerNorm(768)

def _encode(self, cid, rid, cid_mask, rid_mask):
cid_rep = self.ctx_encoder(cid, cid_mask)
rid_rep = self.can_encoder(rid, rid_mask)
return cid_rep, rid_rep

@torch.no_grad()
def predict(self, cid, rid, rid_mask):
# cid_rep: [1, E]; rid_rep: [S, E]
batch_size = rid.shape[0]
cid_rep, rid_rep = self._encode(cid.unsqueeze(0), rid, None, rid_mask)
cid_rep = cid_rep.squeeze(0) # [E]
cross_rep = torch.cat(
[
cid_rep.unsqueeze(0).expand(batch_size, -1),
rid_rep,
],
dim=1,
) # [S, 2*E]

# ablation: the Transformer comparison module is skipped; only the projection is applied
cross_rep = self.dropout(
torch.tanh(
self.proj1(cross_rep)
)
) # [S, E]

gate = torch.sigmoid(
self.gate(
torch.cat(
[
rid_rep, # [S, E]
cid_rep.unsqueeze(0).expand(batch_size, -1), # [S, E]
cross_rep, # [S, E]
],
dim=-1,
)
)
) # [S, E]
# cross_rep: [S, E]
cross_rep = self.layernorm(gate * rid_rep + (1 - gate) * cross_rep)
# cid: [E]; cross_rep: [S, E]
dot_product = torch.matmul(cid_rep, cross_rep.t()) # [S]
return dot_product

def forward(self, cid, rid, cid_mask, rid_mask):
batch_size = cid.shape[0]
assert batch_size > 1, '[!] batch size must be bigger than 1, because the other elements in the batch are used as negative samples'
cid_rep, rid_rep = self._encode(cid, rid, cid_mask, rid_mask) # [B, E]

# build the context-candidate pair representations for all the candidates
cross_rep = []
for cid_rep_ in cid_rep:
cid_rep_ = cid_rep_.unsqueeze(0).expand(batch_size, -1) # [S, E]
cross_rep.append(
torch.cat([cid_rep_, rid_rep], dim=-1)
) # [S, E*2]
cross_rep = torch.stack(cross_rep).permute(1, 0, 2) # [B, S, 2*E] -> [S, B, E*2]
# ablation: the Transformer comparison module is skipped; only the projection is applied
cross_rep = self.dropout(
torch.tanh(self.proj1(cross_rep)).permute(1, 0, 2)
) # [B, S, E]

gate = torch.sigmoid(
self.gate(
torch.cat(
[
rid_rep.unsqueeze(0).expand(batch_size, -1, -1),
cid_rep.unsqueeze(1).expand(-1, batch_size, -1),
cross_rep,
],
dim=-1
)
)
) # [B, S, E]
cross_rep = self.layernorm(gate * rid_rep.unsqueeze(0).expand(batch_size, -1, -1) + (1 - gate) * cross_rep) # [B, S, E]

# score each context against every fused candidate in the batch
cid_rep = cid_rep.unsqueeze(1) # [B, 1, E]
dot_product = torch.bmm(cid_rep, cross_rep.permute(0, 2, 1)).squeeze(1) # [B, S]
# half precision so the mask matches the apex (fp16) activations
mask = to_cuda(torch.eye(batch_size)).half() # [B, B]
# calculate accuracy: the positive candidate for context i sits at index i
acc_num = (F.softmax(dot_product, dim=-1).max(dim=-1)[1] == torch.arange(batch_size).cuda()).sum().item()
acc = acc_num / batch_size
# calculate the loss
loss = F.log_softmax(dot_product, dim=-1) * mask
loss = (-loss.sum(dim=1)).mean()
return loss, acc
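
A minimal sketch of the gated fusion that BERTBiCompEncoder_comp keeps (and that the _gate variant below removes): a sigmoid gate, conditioned on the candidate, the context, and the context-aware representation, mixes the raw and context-aware candidate vectors element-wise before LayerNorm. The tensors are random stand-ins with the same 768-dim size used above.

import torch
import torch.nn as nn

E, S = 768, 10
gate_proj = nn.Linear(E * 3, E)
layernorm = nn.LayerNorm(E)

rid_rep = torch.randn(S, E)    # candidate (response) representations
cid_rep = torch.randn(E)       # a single context representation
cross_rep = torch.randn(S, E)  # context-aware candidate representations

g = torch.sigmoid(gate_proj(torch.cat([
    rid_rep,
    cid_rep.unsqueeze(0).expand(S, -1),
    cross_rep,
], dim=-1)))                   # [S, E] element-wise mixing weights

fused = layernorm(g * rid_rep + (1 - g) * cross_rep)   # [S, E]
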

class BERTBiCompEncoder_gate(nn.Module):

'''bi-encoder + TCM - {gate}: ablation that removes the learned gate; the candidate
and the context-aware representations are fused by simple addition.
During training, the other elements in the batch are used as negative samples, which
speeds up training. More details: https://arxiv.org/pdf/1905.01969v2.pdf
reference: https://github.com/chijames/Poly-Encoder/blob/master/encoder.py
'''

def __init__(self, nhead, dim_feedforward, num_encoder_layers, dropout=0.1, lang='zh'):
super(BERTBiCompEncoder_gate, self).__init__()
self.ctx_encoder = BertEmbedding(lang=lang)
self.can_encoder = BertEmbedding(lang=lang)

encoder_layer = nn.TransformerEncoderLayer(
768,
nhead=nhead,
dim_feedforward=dim_feedforward,
dropout=dropout,
)
encoder_norm = nn.LayerNorm(768)
self.trs_encoder = nn.TransformerEncoder(
encoder_layer,
num_encoder_layers,
encoder_norm,
)

self.proj1 = nn.Linear(768*2, 768)
self.gate = nn.Linear(768*3, 768)
self.dropout = nn.Dropout(p=dropout)
self.layernorm = nn.LayerNorm(768)

def _encode(self, cid, rid, cid_mask, rid_mask):
cid_rep = self.ctx_encoder(cid, cid_mask)
rid_rep = self.can_encoder(rid, rid_mask)
return cid_rep, rid_rep

@torch.no_grad()
def predict(self, cid, rid, rid_mask):
# cid_rep: [1, E]; rid_rep: [S, E]
batch_size = rid.shape[0]
cid_rep, rid_rep = self._encode(cid.unsqueeze(0), rid, None, rid_mask)
cid_rep = cid_rep.squeeze(0) # [E]
cross_rep = torch.cat(
[
cid_rep.unsqueeze(0).expand(batch_size, -1),
rid_rep,
],
dim=1,
) # [S, 2*E]

cross_rep = self.dropout(
torch.tanh(
self.trs_encoder(
torch.tanh(
self.proj1(cross_rep).unsqueeze(1)
)
)
).squeeze(1)
) # [S, E]

# cross_rep: [S, E]
# ablation: additive fusion replaces the gated fusion
# cross_rep = self.layernorm(gate * rid_rep + (1 - gate) * cross_rep)
cross_rep = rid_rep + cross_rep
# cid: [E]; cross_rep: [S, E]
dot_product = torch.matmul(cid_rep, cross_rep.t()) # [S]
return dot_product

def forward(self, cid, rid, cid_mask, rid_mask):
batch_size = cid.shape[0]
assert batch_size > 1, '[!] batch size must be bigger than 1, because the other elements in the batch are used as negative samples'
cid_rep, rid_rep = self._encode(cid, rid, cid_mask, rid_mask) # [B, E]

# build the context-candidate pair representations for all the candidates
cross_rep = []
for cid_rep_ in cid_rep:
cid_rep_ = cid_rep_.unsqueeze(0).expand(batch_size, -1) # [S, E]
cross_rep.append(
torch.cat([cid_rep_, rid_rep], dim=-1)
) # [S, E*2]
cross_rep = torch.stack(cross_rep).permute(1, 0, 2) # [B, S, 2*E] -> [S, B, E*2]
cross_rep = self.dropout(
torch.tanh(
self.trs_encoder(
torch.tanh(self.proj1(cross_rep))
)
).permute(1, 0, 2)
) # [B, S, E]

# ablation: additive fusion replaces the gated fusion
# cross_rep = self.layernorm(gate * rid_rep.unsqueeze(0).expand(batch_size, -1, -1) + (1 - gate) * cross_rep) # [B, S, E]
cross_rep = rid_rep.unsqueeze(0).expand(batch_size, -1, -1) + cross_rep

# score each context against every fused candidate in the batch
cid_rep = cid_rep.unsqueeze(1) # [B, 1, E]
dot_product = torch.bmm(cid_rep, cross_rep.permute(0, 2, 1)).squeeze(1) # [B, S]
# half precision so the mask matches the apex (fp16) activations
mask = to_cuda(torch.eye(batch_size)).half() # [B, B]
# calculate accuracy: the positive candidate for context i sits at index i
acc_num = (F.softmax(dot_product, dim=-1).max(dim=-1)[1] == torch.arange(batch_size).cuda()).sum().item()
acc = acc_num / batch_size
# calculate the loss
loss = F.log_softmax(dot_product, dim=-1) * mask
loss = (-loss.sum(dim=1)).mean()
return loss, acc
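
And a sketch of the comparison step that BERTBiCompEncoder_gate keeps while dropping the gate: every context is paired with every candidate, projected to 768 dims, run through the Transformer encoder, and then fused additively. The shapes and the tanh/permute pattern follow the forward() above; the tensors are random stand-ins and the layer sizes are only illustrative.

import torch
import torch.nn as nn

E, B = 768, 4                                   # S == B during training
proj1 = nn.Linear(E * 2, E)
layer = nn.TransformerEncoderLayer(E, nhead=8, dim_feedforward=2048, dropout=0.1)
trs_encoder = nn.TransformerEncoder(layer, 2, nn.LayerNorm(E))

cid_rep = torch.randn(B, E)                     # context representations
rid_rep = torch.randn(B, E)                     # candidate representations

# pair every context with every candidate: [B, S, 2E] -> [S, B, 2E] (sequence-first for the encoder)
cross = torch.stack([
    torch.cat([c.unsqueeze(0).expand(B, -1), rid_rep], dim=-1) for c in cid_rep
]).permute(1, 0, 2)

cross = torch.tanh(trs_encoder(torch.tanh(proj1(cross)))).permute(1, 0, 2)  # [B, S, E]

# the _gate ablation fuses additively instead of with a learned gate
fused = rid_rep.unsqueeze(0).expand(B, -1, -1) + cross                      # [B, S, E]
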

class BERTBiEncoderAgent(RetrievalBaseAgent):

'''
@@ -496,11 +840,11 @@ def __init__(self, multi_gpu, total_step, run_mode='train', local_rank=0, kb=Tru
'warmup_steps': int(0.1 * total_step),
'total_step': total_step,
'retrieval_model': model,
-'num_encoder_layers': 2,
-'dim_feedforward': 512,
-'nhead': 6,
+'num_encoder_layers': 4,
+'dim_feedforward': 2048,
+'nhead': 8,
'dropout': 0.1,
-'max_len': 256,
+'max_len': 512,
'poly_m': 16,
'lang': lang,
}
@@ -683,7 +1027,7 @@ def __init__(self, multi_gpu, total_step, run_mode='train', local_rank=0, kb=Tru
'model': 'bert-base-chinese',
'amp_level': 'O2',
'local_rank': local_rank,
-'warmup_steps': 8000,
+'warmup_steps': int(0.1 * total_step),
'total_step': total_step,
'max_len': 256,
'max_turn_size': 10,
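
The agent hunks above also enlarge the Transformer-layer sizes and move warmup_steps to 10% of the total step count. A hedged sketch of how that warmup value is typically consumed follows; the agent's optimizer/scheduler wiring is not shown in this diff, so the AdamW settings and the transformers scheduler here are assumptions.

import torch
from transformers import get_linear_schedule_with_warmup

total_step = 10000
warmup_steps = int(0.1 * total_step)            # matches the new config value

model = torch.nn.Linear(768, 768)               # stand-in for the retrieval model
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=warmup_steps, num_training_steps=total_step,
)

for step in range(total_step):
    # ... forward / loss.backward() / optimizer.step() would go here ...
    scheduler.step()                            # lr ramps up over the first 10% of steps, then decays linearly
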
4 changes: 2 additions & 2 deletions run.sh
@@ -59,9 +59,9 @@ elif [ $mode = 'train' ]; then
--dataset $dataset \
--model $model \
--mode train \
---batch_size 40 \
+--batch_size 32 \
--n_vocab 80000 \
---epoch 10 \
+--epoch 5 \
--seed 50 \
--src_len_size 256 \
--tgt_len_size 50 \
