Commit 814e8ea
Fri 30 Oct 2020 03:55:59 PM CST
gmftbyGMFTBY committed Oct 30, 2020
1 parent 63e0573 commit 814e8ea
Showing 6 changed files with 60 additions and 39 deletions.
41 changes: 24 additions & 17 deletions benchmarks/retrieval.md
@@ -10,69 +10,76 @@
* batch size: 32
* seed: 50
* transformer parameters:
-   * nhead: 6 -> 8
+   * nhead: 6
    * dropout: 0.1
    * dim feedforward: 512
-   * num encoder layer: 2 -> 4
+   * num encoder layer: 2
    * poly m: 16

### 2. experiment results

#### 2.1 10 Candidates

| Models | R1@10 | R2@10 | R5@10 | MRR |
- |------------------|-------|-------|-------|------: |
+ | :------------------: | :-------: | :-------: | :-------: | :------: |
| bert bi-encoder | 0.846 | 0.93 | 0.986 | 0.9064 |
- | bert bi-encoder ocn | 0.849 | 0.934 | 0.985 | 0.9086 |
+ | bert bi-encoder ocn | **0.849** | 0.934 | 0.985 | **0.9086** |
| bert polyencoder | 0.84 | 0.932 | 0.987 | 0.9036 |
| bert polyencoder ocn | 0.842 | **0.939** | **0.992** | 0.9071 |

#### 2.2 50 Candidates

| Models | R1@50 | R2@50 | R5@50 | R10@50 | MRR |
- |------------------|-------|-------|-------|-------|------: |
+ | :------------------: | :-------: | :-------: | :-------: | :-------: | :------: |
| bert bi-encoder | 0.644 | 0.762 | 0.856 | 0.912 | 0.7413 |
| bert bi-encoder ocn | 0.671 | 0.762 | 0.861 | 0.921 | 0.7565 |
| bert polyencoder | 0.671 | 0.773 | 0.875 | 0.925 | 0.7615 |
| bert polyencoder ocn | **0.693** | **0.8** | **0.881** | **0.931** | **0.7791** |

#### 2.3 100 Candidates

| Models | R1@100 | R2@100 | R5@100 | R10@100 | MRR |
- |------------------|-------|-------|-------|-------|------: |
+ | :------------------: | :-------: | :-------: | :-------: | :-------: | :------: |
| bert bi-encoder | 0.603 | 0.7 | 0.809 | 0.872 | 0.6961 |
- | bert bi-encoder ocn | 0.638 | 0.732 | 0.814 | 0.872 | 0.7273 |
+ | bert bi-encoder ocn | **0.638** | 0.732 | 0.814 | 0.872 | 0.7273 |
| bert polyencoder | 0.628 | 0.73 | 0.831 | 0.883 | 0.7203 |
| bert polyencoder ocn | 0.634 | **0.748** | **0.843** | **0.894** | **0.7291** |

#### 2.4 150 Candidates

| Models | R1@150 | R2@150 | R5@150 | R10@150 | MRR |
- |------------------|-------|-------|-------|-------|------: |
+ | :------------------: | :-------: | :-------: | :-------: | :-------: | :------: |
| bert bi-encoder | 0.569 | 0.676 | 0.773 | 0.842 | 0.6657 |
| bert bi-encoder ocn | 0.604 | 0.696 | 0.798 | 0.855 | 0.692 |
| bert polyencoder | 0.603 | 0.718 | 0.809 | 0.862 | 0.6977 |
| bert polyencoder ocn | **0.614** | **0.727** | **0.816** | **0.872** | **0.709** |

#### 2.5 200 Candidates

| Models | R1@200 | R2@200 | R5@200 | R10@200 | MRR |
- |------------------|-------|-------|-------|-------|------: |
+ | :------------------: | :-------: | :-------: | :-------: | :-------: | :------: |
| bert bi-encoder | 0.551 | 0.651 | 0.758 | 0.831 | 0.6473 |
| bert bi-encoder ocn | 0.588 | 0.684 | 0.782 | 0.841 | 0.6768 |
| bert polyencoder | 0.584 | 0.695 | 0.788 | 0.846 | 0.6787 |
| bert polyencoder ocn | **0.601** | **0.711** | **0.805** | **0.855** | **0.6952** |

#### 2.6 250 Candidates

| Models | R1@250 | R2@250 | R5@250 | R10@250 | MRR |
- |------------------|-------|-------|-------|-------|------: |
+ | :------------------: | :-------: | :-------: | :-------: | :-------: | :------: |
| bert bi-encoder | 0.532 | 0.643 | 0.734 | 0.806 | 0.6298 |
| bert bi-encoder ocn | 0.573 | 0.66 | 0.758 | 0.825 | 0.6594 |
| bert polyencoder | 0.574 | 0.677 | 0.769 | 0.836 | 0.6664 |
| bert polyencoder ocn | **0.577** | **0.695** | **0.792** | **0.841** | **0.6758** |

#### 2.7 300 Candidates

| Models | R1@300 | R2@300 | R5@300 | R10@300 | MRR |
- |------------------|-------|-------|-------|-------|------: |
+ | :------------------: | :-------: | :-------: | :-------: | :-------: | :------: |
| bert bi-encoder | 0.529 | 0.623 | 0.723 | 0.796 | 0.6217 |
| bert bi-encoder ocn | 0.555 | 0.653 | 0.747 | 0.814 | 0.6468 |
| bert polyencoder | 0.56 | 0.659 | 0.765 | 0.825 | 0.6527 |
| bert polyencoder ocn | **0.564** | **0.68** | **0.79** | **0.836** | **0.6652** |


## Douban Multi-turn Conversation Dataset
@@ -87,16 +94,16 @@
* batch size: 32
* seed: 50
* transformer parameters:
-   * nhead: 6 -> 8
+   * nhead: 6
    * dropout: 0.1
    * dim feedforward: 512
-   * num encoder layer: 2 -> 4
+   * num encoder layer: 2
    * poly m: 16

### 2. experiment results

| Models | R1@10 | R2@10 | R5@10 | MRR |
- |------------------|-------|-------|-------|------:|
+ | :------------------: | :-------: | :-------: | :-------: | :------: |
| bert bi-encoder | 0.2762 | 0.4751 | 0.8177 | 0.4931 |
| bert bi-encoder ocn | 0.3039 | 0.4613 | 0.8122 | 0.5041 |
- | bert polyencoder (m=16) | 0.2873 | 0.4586 | 0.8066 | 0.4952 |
+ | bert polyencoder | 0.2873 | 0.4586 | 0.8066 | 0.4952 |
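
For reference, the metrics in these tables follow the standard ranking definitions: Rk@N counts an example as correct when the ground-truth response is ranked within the top k of the N candidates, and MRR averages the reciprocal rank of the ground truth. A minimal, self-contained sketch of the computation (not the repository's evaluation code; it assumes exactly one ground-truth candidate per example, stored at index 0):

```python
from typing import List

def truth_rank(scores: List[float], truth_idx: int = 0) -> int:
    # 1-based rank of the ground-truth candidate after sorting by score.
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order.index(truth_idx) + 1

def recall_at_k(ranks: List[int], k: int) -> float:
    # Rk@N: fraction of examples whose ground truth lands in the top k.
    return sum(r <= k for r in ranks) / len(ranks)

def mrr(ranks: List[int]) -> float:
    # Mean reciprocal rank of the ground-truth candidate.
    return sum(1.0 / r for r in ranks) / len(ranks)

# Two toy examples with 10 candidates each, ground truth at index 0:
ranks = [truth_rank([0.9, 0.1] + [0.0] * 8),
         truth_rank([0.2, 0.7] + [0.0] * 8)]
print(recall_at_k(ranks, 1), mrr(ranks))  # -> 0.5 0.75
```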
3 changes: 3 additions & 0 deletions config.py
@@ -23,6 +23,7 @@
    'rubertirbi': load_rubert_irbi_dataset,
    'bertirbicomp': load_bert_irbicomp_dataset,
    'polyencoder': load_bert_irbi_dataset,
+   'polyencodercomp': load_bert_irbicomp_dataset,
    'transformer': load_seq2seq_trs_dataset,
}

@@ -58,6 +59,7 @@
    'rubertirbi': RUBERTBiEncoderAgent,
    'bertirbicomp': BERTBiEncoderAgent,
    'polyencoder': BERTBiEncoderAgent,
+   'polyencodercomp': BERTBiEncoderAgent,
    'transformer': TransformerAgent,
}

@@ -80,6 +82,7 @@
    'bertirbi': [('multi_gpu', 'total_steps'), {'run_mode': 'mode', 'local_rank': 'local_rank'}],
    'rubertirbi': [('multi_gpu', 'total_steps'), {'run_mode': 'mode', 'local_rank': 'local_rank'}],
    'polyencoder': [('multi_gpu', 'total_steps'), {'run_mode': 'mode', 'local_rank': 'local_rank', 'model': 'bimodel'}],
+   'polyencodercomp': [('multi_gpu', 'total_steps'), {'run_mode': 'mode', 'local_rank': 'local_rank', 'model': 'bimodel'}],
    'bertirbicomp': [('multi_gpu', 'total_steps'), {'run_mode': 'mode', 'local_rank': 'local_rank', 'model': 'bimodel'}],

7 changes: 2 additions & 5 deletions dataset_init.py
@@ -290,10 +290,7 @@ def load_bert_irbi_dataset(args):
    if not os.path.exists(data.pp_path):
        data.save_pickle()
    args['total_steps'] = len(data) * args['epoch'] / args['batch_size']
-   if args['model'] == 'polyencoder':
-       args['bimodel'] = args['model']
-   else:
-       args['bimodel'] = 'no-compare'
+   args['bimodel'] = args['model']
    return iter_

def load_bert_irbicomp_dataset(args):
@@ -312,7 +309,7 @@ def load_bert_irbicomp_dataset(args):
    if not os.path.exists(data.pp_path):
        data.save_pickle()
    args['total_steps'] = len(data) * args['epoch'] / args['batch_size']
-   args['bimodel'] = 'compare'
+   args['bimodel'] = args['model']
    return iter_
# ================================================================================ #
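
One detail worth noting in both loaders above: `total_steps` is computed with true division, so it ends up a float. A quick illustrative check (the dataset size, epoch count, and batch size are made-up numbers, not the real statistics):

```python
# Illustrative values only.
num_samples, epoch, batch_size = 500_000, 5, 32
total_steps = num_samples * epoch / batch_size
print(total_steps)  # 78125.0 -- a float; downstream code presumably
                    # casts or rounds it, e.g. int(total_steps).
```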

2 changes: 1 addition & 1 deletion models/__init__.py
@@ -12,7 +12,7 @@
from .bert_retrieval import *
from .bert_retrieval_multi import *
from .bert_nli import *
- from .polyencoder import *
+ from .biencoder import *
from .test import *
from .model_utils import *
from .dialogpt import *
44 changes: 29 additions & 15 deletions models/polyencoder.py → models/biencoder.py
@@ -1,10 +1,8 @@
from .header import *

'''
- PolyEncoder: https://arxiv.org/pdf/1905.01969v2.pdf
- 1. Bi-encoder
- 2. Cross-encoder (refer to bertretrieval)
- 3. Poly-Encoder
+ 1. Bert Bi-encoder
+ 2. PolyEncoder
'''

class BertEmbedding(nn.Module):
@@ -95,7 +93,7 @@ def predict(self, cid, rid, rid_mask):
        )    # [S, E]

        gate = torch.sigmoid(
-           self.gaet(
+           self.gate(
                torch.cat(
                    [
                        rid_rep,    # [S, E]
@@ -468,12 +466,15 @@ def forward(self, cid, rid, cid_mask, rid_mask):

class BERTBiEncoderAgent(RetrievalBaseAgent):

-    '''model parameter can be:
-    1. compare: bi-encoder with comparsion module
-    2. no-compare: pure bi-encoder
-    3. polyencoder: polyencoder'''
+    '''
+    model parameter can be:
+    1. bertirbi: pure bi-encoder
+    2. bertirbicomp: bi-encoder with the comparison module
+    3. polyencoder: polyencoder
+    4. polyencodercomp: polyencoder with the comparison module
+    '''

-    def __init__(self, multi_gpu, total_step, run_mode='train', local_rank=0, kb=True, model='no-compare'):
+    def __init__(self, multi_gpu, total_step, run_mode='train', local_rank=0, kb=True, model='bertirbi'):
        super(BERTBiEncoderAgent, self).__init__(kb=kb)
        try:
            self.gpu_ids = list(range(len(multi_gpu.split(','))))
@@ -487,22 +488,26 @@ def __init__(self, multi_gpu, total_step, run_mode='train', local_rank=0, kb=Tru
            'talk_samples': 256,
            'vocab_file': 'bert-base-chinese',
            'pad': 0,
-           'samples': 300,
+           'samples': 10,
            'model': 'bert-base-chinese',
            'amp_level': 'O2',
            'local_rank': local_rank,
            'warmup_steps': 8000,
            'total_step': total_step,
-           'dmodel': model,
+           'retrieval_model': model,
            'num_encoder_layers': 2,
            'dim_feedforward': 512,
            'nhead': 6,
            'dropout': 0.1,
            'max_len': 256,
            'poly_m': 16,
+           # RNN parameters
+           'embed_size': 512,
+           'hidden_size': 512,
+           'num_encoder_layer': 4,
        }
        self.vocab = BertTokenizer.from_pretrained(self.args['vocab_file'])
-       if model == 'no-compare':
+       if model == 'bertirbi':
            self.model = BERTBiEncoder()
        elif model == 'polyencoder':
            self.model = PolyEncoder(
@@ -516,17 +521,26 @@
                dropout=self.args['dropout'],
                m=self.args['poly_m'],
            )
-       else:
+       elif model == 'bertirbicomp':
            self.model = BERTBiCompEncoder(
                self.args['nhead'],
                self.args['dim_feedforward'],
                self.args['num_encoder_layers'],
                dropout=self.args['dropout'],
            )
+       elif model == 'DualLSTM':
+           self.model = DualLSTM(
+               embed_size=self.args['embed_size'],
+               hidden_size=self.args['hidden_size'],
+               num_encoder_layer=self.args['num_encoder_layer'],
+               dropout=0.5,
+           )
+       else:
+           raise Exception(f'[!] cannot find the model {model}')
        if torch.cuda.is_available():
            self.model.cuda()
        if run_mode == 'train':
-           if model in ['polyencoder', 'no-compare']:
+           if model in ['polyencoder', 'bertirbi', 'DualLSTM']:
                self.optimizer = transformers.AdamW(
                    self.model.parameters(),
                    lr=self.args['lr'],
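
A hedged usage sketch of the renamed agent, based only on the signature and branches shown in this commit (the GPU string and step count are illustrative values):

```python
# Hypothetical instantiation. Per the branches shown above, `model` may be
# 'bertirbi', 'polyencoder', 'bertirbicomp', or 'DualLSTM'; anything else
# raises. Note the import path changes with this commit's file rename.
from models.biencoder import BERTBiEncoderAgent  # was models.polyencoder

agent = BERTBiEncoderAgent(
    '0,1',              # multi_gpu
    10000,              # total_step
    run_mode='train',
    local_rank=0,
    model='polyencoder',
)
```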
2 changes: 1 addition & 1 deletion run.sh
@@ -69,7 +69,7 @@ elif [ $mode = 'train' ]; then
        --multi_gpu $cuda \
        --lang $lang
elif [ $mode = 'test' ]; then
-   one_batch_model=(kwgpt2 pfgpt2 gpt2gan lccc multigpt2 when2talk bertirbi bertirbicomp polyencoder)
+   one_batch_model=(kwgpt2 pfgpt2 gpt2gan lccc multigpt2 when2talk bertirbi bertirbicomp polyencoder polyencodercomp)
    if [[ ${one_batch_model[@]} =~ $model ]]; then
        batch_size=1
    else
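
A side note on the membership test above: `[[ ${one_batch_model[@]} =~ $model ]]` regex-matches `$model` against the space-joined array, so it is effectively a substring test (e.g. `model=gpt2` would match `kwgpt2`). That is harmless for the names used here, but an exact-match helper is the safer pattern; a sketch under that assumption:

```bash
#!/bin/bash
# Exact-match membership test; avoids the substring pitfall of `=~`.
contains() {
    local item
    for item in "${one_batch_model[@]}"; do
        [[ "$item" == "$1" ]] && return 0
    done
    return 1
}

one_batch_model=(kwgpt2 pfgpt2 gpt2gan lccc multigpt2 when2talk bertirbi bertirbicomp polyencoder polyencodercomp)
if contains "$model"; then
    batch_size=1
else
    batch_size=32   # illustrative default; the real fallback is in the part of run.sh not shown here
fi
```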
