
add the decay ratio for the bertbicomp model (Mon 26 Oct 2020 10:50:08 PM CST)
gmftbyGMFTBY committed Oct 26, 2020
1 parent 00aa42e commit 2ccc91b
Showing 7 changed files with 183 additions and 35 deletions.
16 changes: 13 additions & 3 deletions benchmarks/retrieval.md
@@ -3,24 +3,34 @@
### 1. Hyperparameters:
* epoch: 10
* max sequence length: 256
- * negative samples: 32
+ * negative samples: 32 (only for the bi-encoder; the cross-attention model uses 1)
* apex: True
* distributed: True
* batch size: 32
* seed: 50
* decay ratio: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]

### 2. Experiment results


| Models | R1@10 | R2@10 | R5@10 | MRR |
|------------------|-------|-------|-------|-------:|
| bert bi-encoder | 0.846 | 0.93 | 0.986 | 0.9064 |
- | bert bi-encoder ocn | | | | |
| bert bi-encoder ocn (decay ratio=1.0) | 0.84 | 0.928 | 0.986 | 0.9033 |
| bert bi-encoder ocn (decay ratio=0.9) | | | | |
| bert bi-encoder ocn (decay ratio=0.8) | | | | |
| bert bi-encoder ocn (decay ratio=0.7) | | | | |
| bert bi-encoder ocn (decay ratio=0.6) | | | | |
| bert bi-encoder ocn (decay ratio=0.5) | 0.861 | 0.946 | 0.99 | 0.9181 |
| bert bi-encoder ocn (decay ratio=0.4) | | | | |
| bert bi-encoder ocn (decay ratio=0.3) | | | | |
| bert bi-encoder ocn (decay ratio=0.2) | | | | |
| bert bi-encoder ocn (decay ratio=0.1) | | | | |
| bert polyencoder | | | | |
| bert cross-attention | | | | |
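For reference, a minimal sketch of how these metrics are typically computed for this kind of 1-in-10 candidate ranking (the array layout, with the ground-truth response score in column 0, is an assumption, not code from this repo):

```python
import numpy as np

def recall_at_k(scores, k):
    # scores: [n_contexts, 10]; column 0 holds the ground-truth response score
    ranks = (-np.asarray(scores)).argsort(axis=1)
    return float(np.mean([0 in row[:k] for row in ranks]))

def mrr(scores):
    # mean reciprocal rank of the ground-truth response
    ranks = (-np.asarray(scores)).argsort(axis=1)
    return float(np.mean([1.0 / (list(row).index(0) + 1) for row in ranks]))

# toy usage: 2 contexts with 10 candidate scores each
scores = np.random.rand(2, 10)
print(recall_at_k(scores, 1), recall_at_k(scores, 5), mrr(scores))
```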


- ## Douban Multi-turn Conversation
+ ## Douban Multi-turn Conversation Dataset

### 1. Hyperparameters:
* epoch: 10
4 changes: 2 additions & 2 deletions data/run.sh
@@ -6,9 +6,9 @@ if [ $mode = 'init_es' ]; then
# init the ElasticSearch and restore the retrieval database
python process_data.py --mode insert
elif [ $mode = 'init_topic_guided' ]; then
- # python word_graph.py --mode graph
+ python word_graph.py --mode graph
# generate the words by frequency from the corpus
- python word_graph.py --mode word
+ # python word_graph.py --mode word
elif [ $mode = 'init_gen' ]; then
# init and create the whole generative dataset
python process_data.py --mode generative
47 changes: 43 additions & 4 deletions data/word_graph.py
@@ -55,10 +55,22 @@ def multi_search(self, topics, samples=10):
rest = self.es.msearch(body=request)
return rest

def multi_search_edge(self, topics, samples=10):
# limit the query length
search_arr = []
for topic1, topic2 in topics:
search_arr.append({'index': self.index})
search_arr.append({'query': {'bool': {'must': [{'match': {'utterance': {'query': topic1}}}, {'match': {'utterance': {'query': topic2}}}]}}, 'size': samples})
request = ''
for each in search_arr:
request += f'{json.dumps(each)} \n'
rest = self.es.msearch(body=request)
return rest
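A hedged usage sketch for multi_search_edge (the word pairs and index contents are illustrative): msearch pairs each {'index': ...} header line with a query line in an NDJSON body, and the i-th entry of responses corresponds to the i-th word pair.

```python
# illustrative word pairs; assumes the index stores documents with an 'utterance' field
pairs = [('天气', '下雨'), ('电影', '电影院')]
rest = chatter.multi_search_edge(pairs, samples=10)
for (w1, w2), r in zip(pairs, rest['responses']):
    hits = [h['_source']['utterance'] for h in r['hits']['hits']]
    print(w1, w2, len(hits))
```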

def parser_args():
parser = argparse.ArgumentParser(description='wordnet parameters')
parser.add_argument('--weight_threshold', type=float, default=0.6)
- parser.add_argument('--topn', type=int, default=10)
+ parser.add_argument('--topn', type=int, default=200)
parser.add_argument('--mode', type=str, default='graph')
return parser.parse_args()

@@ -138,7 +150,23 @@ def write_new_w2v(words, path):
vec = w2v[word].tolist()
string = f'{word} {" ".join(map(str, vec))}\n'
f.write(string)


def retrieval_edge(word_pairs, samples=64):
rest = chatter.multi_search_edge(word_pairs, samples=samples)['responses']
flag = []
for (word1, word2), pair_rest in zip(word_pairs, rest):
counter = 0
pair_rest = pair_rest['hits']['hits']
for utterance in pair_rest:
utterance = utterance['_source']['utterance']
if word1 in utterance and word2 in utterance:
counter += 1
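# keep the edge only if at least half of the sampled utterances contain both words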
if counter >= 0.5 * samples:
flag.append(True)
else:
flag.append(False)
return flag

def retrieval_filter(words, samples=64):
rest = chatter.multi_search(words, samples=samples)
rest = rest['responses']
@@ -204,12 +232,23 @@ def retrieval_filter(words, samples=64):
if not os.path.exists('wordnet.pkl'):
graph = nx.Graph()
graph.add_nodes_from(w2v.index2word)
# batch_size = 64
# for idx in tqdm(range(0, len(w2v.index2word), batch_size)):
# words = w2v.index2word[idx:idx+batch_size]
# neighbors = []
# for word in words:
# neighbors.extend([(word, n, w) for n, w in w2v.most_similar(word, topn=args['topn'])])
# flag = retrieval_edge([(word, n) for word, n, w in neighbors])
# neighbors = [(word, n, w) for flag_, (word, n, w) in zip(flag, neighbors) if flag_]
# graph.add_weighted_edges_from([(word, n, 1-w) for word, n, w in neighbors])
for word in tqdm(w2v.index2word):
neighbors = w2v.most_similar(word, topn=args['topn'])
- graph.add_weighted_edges_from([(word, n, 1 - w) for n, w in neighbors if 1 - w < args['weight_threshold']])
+ flag = retrieval_edge([(word, i) for i, _ in neighbors])
+ neighbors = [(n, w) for flag_, (n, w) in zip(flag, neighbors) if flag_]
+ graph.add_weighted_edges_from([(word, n, 1 - w) for n, w in neighbors])
with open('wordnet.pkl', 'wb') as f:
pickle.dump(graph, f)
print(f'[!] save the word net into wordnet.pkl')
else:
with open('wordnet.pkl', 'rb') as f:
graph = pickle.load(f)
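A small sketch of how the saved wordnet.pkl appears intended to be consumed downstream (the two topic words are illustrative and may not exist in the vocabulary); since edge weights are 1 - cosine similarity, shortest paths prefer chains of similar words:

```python
import pickle
import networkx as nx

with open('wordnet.pkl', 'rb') as f:
    graph = pickle.load(f)

source, target = '天气', '旅游'   # illustrative topic words
if source in graph and target in graph and nx.has_path(graph, source, target):
    # weight='weight' uses the stored 1 - similarity values
    print(nx.shortest_path(graph, source, target, weight='weight'))
```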
103 changes: 83 additions & 20 deletions models/bert_retrieval.py
@@ -327,47 +327,110 @@ class BERTRetrievalKGGreedyAgent(BERTRetrievalAgent):

'''fix the talk function for BERTRetrievalAgent
The agent knows the whole knowledge graph path, but the other speaker doesn't;
- greedy: ACL 2019 Target-Guided Open-Domain Conversation
- '''
+ greedy: ACL 2019 Target-Guided Open-Domain Conversation'''

def __init__(self, multi_gpu, run_mode='train', lang='zh', kb=True, local_rank=0, wordnet=None, talk_samples=128):
super(BERTRetrievalKGGreedyAgent, self).__init__(multi_gpu, run_mode=run_mode, lang=lang, kb=kb, local_rank=local_rank)
self.topic_history = []
self.wordnet = wordnet
self.args['talk_samples'] = talk_samples
self.w2v = gensim.models.KeyedVectors.load_word2vec_format(
'data/chinese_w2v_base.txt', binary=False
)

- def reset(self, target, source, path):
-     self.args['target'], self.args['current_node'], self.args['path'] = target, source, path
+ def reset(self, target, source):
+     self.args['target'], self.args['current_node'] = target, source
self.topic_history, self.history = [source], []
print(f'[! Reset the KG target] source: {source}; target: {target}')

- def move_on_kg(self):
-     '''judge whether the target has been reached'''
-     if self.topic_history[-1] == self.args['target']:
-         return
-     self.args['current_node'] = self.args['path'][len(self.topic_history)]
-     self.topic_history.append(self.args['current_node'])
def search_candidates(self, msgs, nodes):
'''Note that the input may contain multiple topic words.
f(n) = g(n) + h(n)
1. must have a path to the target √
2. closer to the target than the current node √
3. the retrieved utterances must contain both the current and the candidate node
4. as for g(n): 1) word similarity; 2) average bag-of-utterances coherence
5. as for f(n): 2) the number of utterances retrieved for the current and candidate nodes and their average coherence; 3) RL
'''
# generate candidates
candidates = []
for node in nodes:
neighbors = []
base_sim = self.w2v.similarity(node, self.args['target'])
for n in self.wordnet.neighbors(node):
    # keep only neighbors that are closer to the target than the current node
    if self.w2v.similarity(n, self.args['target']) <= base_sim:
        continue
retrieval_rest = self.searcher.must_search(
msgs, topic=[node, n], samples=self.args['talk_samples']
)
if not retrieval_rest:
continue
try:
path = nx.shortest_path(self.wordnet, n, self.args['target'])
except nx.NetworkXNoPath as e:
continue
neighbors.append((node, n, path))
candidates.extend(neighbors)
# TODO: score f(n) and sort the candidates; return them unscored for now
return candidates

def move_on_kg(self, msgs, current_nodes, size=1):
    '''current nodes are extracted from the human utterance (there may be several); msgs is needed for the Elasticsearch lookup'''
    candidates = self.search_candidates(msgs, current_nodes)[:size]
return candidates
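The scoring step above is still a TODO in this commit; the following is a minimal sketch of one way to fill it in, using only the signals the docstring names (w2v similarity for g(n), remaining path length as the h(n) heuristic). Everything here, including the 0.1 hop penalty, is an assumption rather than the committed behavior:

```python
def score_candidates(candidates, w2v, target):
    '''candidates: [(node, neighbor, path)] as produced by search_candidates;
    f(n) = g(n) + h(n), higher is better.'''
    scored = []
    for node, n, path in candidates:
        g = w2v.similarity(n, target)    # closeness to the target topic
        h = -0.1 * len(path)             # penalize the remaining hops
        scored.append((g + h, node, n, path))
    scored.sort(key=lambda x: x[0], reverse=True)
    return [(node, n, path) for _, node, n, path in scored]
```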

def process_utterances(self, topics, msgs, max_len=0):
    '''Process the utterances searched by Elasticsearch; construct input_ids/token_type_ids/attn_mask'''
    assert len(topics) > 0, f'[!] topic words must exist'
    utterances_ = self.searcher.must_search(
        msgs, samples=self.args['talk_samples'], topic=topics
    )
    utterances_ = [i['utterance'] for i in utterances_]
    # remove the utterances that already appear in self.history
    utterances_ = list(set(utterances_) - set(self.history))

    # construct inpt_ids, token_type_ids, attn_mask
    inpt_ids = self.vocab.batch_encode_plus([msgs] + utterances_)['input_ids']
context_inpt_ids, responses_inpt_ids = inpt_ids[0], inpt_ids[1:]
context_token_type_ids = [0] * len(context_inpt_ids)
responses_token_type_ids = [[1] * len(i) for i in responses_inpt_ids]

# length limitation
collection = []
for r1, r2 in zip(responses_inpt_ids, responses_token_type_ids):
p1, p2 = context_inpt_ids + r1[1:], context_token_type_ids + r2[1:]
if len(p1) > max_len:
cut_size = len(p1) - max_len + 1
p1, p2 = [p1[0]] + p1[cut_size:], [p2[0]] + p2[cut_size:]
collection.append((p1, p2))

inpt_ids = [torch.LongTensor(i[0]) for i in collection]
token_type_ids = [torch.LongTensor(i[1]) for i in collection]

inpt_ids = pad_sequence(inpt_ids, batch_first=True, padding_value=self.args['pad'])
token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.args['pad'])
attn_mask_index = inpt_ids.nonzero().tolist()
attn_mask_index_x, attn_mask_index_y = [i[0] for i in attn_mask_index], [i[1] for i in attn_mask_index]
attn_mask = torch.zeros_like(inpt_ids)
attn_mask[attn_mask_index_x, attn_mask_index_y] = 1

if torch.cuda.is_available():
inpt_ids, token_type_ids, attn_mask = inpt_ids.cuda(), token_type_ids.cuda(), attn_mask.cuda()
return utterances_, inpt_ids, token_type_ids, attn_mask

@torch.no_grad()
- def talk(self, msgs):
+ def talk(self, msgs, topics):
''':topics: the current topic node(s) in the knowledge graph path.'''
self.model.eval()
# 1) input the topic information for the coarse filtering in Elasticsearch
utterances, inpt_ids, token_type_ids, attn_mask = self.process_utterances(
-     [self.args['current_node']], msgs, max_len=self.args['max_len'],
+     topics, msgs, max_len=self.args['max_len'],
)
# 2) neural ranking with the topic information
output = self.model(inpt_ids, token_type_ids, attn_mask) # [B, 2]
output = F.softmax(output, dim=-1)[:, 1] # [B]
- # 3) post ranking with the current topic word
- output = torch.argsort(output, descending=True)
- for i in output:
-     if self.args['current_node'] in utterances[i.item()]:
-         item = i
-         break
- else:
-     item = 0
- # item = torch.argmax(output).item()
+ item = torch.argmax(output).item()
msg = utterances[item]
return msg
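A hedged end-to-end sketch of how the reworked agent appears intended to be driven; the topic-extraction step is outside this diff, so current_nodes is hard-coded here, and move_on_kg is assumed to receive msgs so that search_candidates can query Elasticsearch:

```python
agent.reset(target='旅游', source='天气')
msgs = '今天天气真好'
current_nodes = ['天气']    # hypothetically extracted from the human utterance
moves = agent.move_on_kg(msgs, current_nodes, size=1)
topics = [n for _, n, _ in moves]
reply = agent.talk(msgs, topics)
print(reply)
```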

33 changes: 33 additions & 0 deletions models/model_utils.py
@@ -288,6 +288,39 @@ def __init__(self, index_name, kb=True):
}
}
)

def must_search(self, query, samples=10, topic=None):
'''
query is a string containing the utterances of the conversation context.
1. topic is a list of topic words
2. query is the utterance msg
context: the query is Q-Q matching
response: the query is Q-A matching, which seems better
'''
query = query.replace('[SEP]', '')    # need to strip the [SEP] tokens before searching
subitem_must = [{"match": {"utterance": {"query": i, 'boost': 1}}} for i in topic]
subitem_should = [{'match': {'utterance': {'query': query, 'boost': 1}}}]
dsl = {
'query': {
'bool': {
"must": subitem_must,
"should": subitem_should,
}
}
}
begin_samples, rest = samples, []
hits = self.es.search(index=self.index, body=dsl, size=begin_samples)['hits']['hits']
for h in hits:
item = {
'score': h['_score'],
'utterance': h['_source']['utterance']
}
if item['utterance'] in query or 'http' in item['utterance']:
continue
else:
rest.append(item)
return rest
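A usage sketch for must_search (the query string and topic word are illustrative): topic words go into the must clause and the context into should, so every hit must mention the topics while roughly matching the context.

```python
# searcher is an instance of the ES wrapper above
rest = searcher.must_search('今天天气怎么样 [SEP] 是啊', samples=10, topic=['天气'])
for item in rest:
    print(round(item['score'], 2), item['utterance'])
```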

def search(self, query, samples=10, topic=None):
'''
13 changes: 8 additions & 5 deletions models/polyencoder.py
@@ -220,10 +220,11 @@ class BERTBiCompEncoder(nn.Module):
Set the different learning ratio
'''

- def __init__(self, nhead, dim_feedforward, num_encoder_layers, dropout=0.1):
+ def __init__(self, nhead, dim_feedforward, num_encoder_layers, dropout=0.1, decay_ratio=0.5):
super(BERTBiCompEncoder, self).__init__()
self.ctx_encoder = BertEmbedding()
self.can_encoder = BertEmbedding()
self.decay_ratio = decay_ratio

encoder_layer = nn.TransformerEncoderLayer(
768,
@@ -233,7 +234,7 @@ def __init__(self, nhead, dim_feedforward, num_encoder_layers, dropout=0.1):
)
encoder_norm = nn.LayerNorm(768)
self.trs_encoder = nn.TransformerEncoder(
encoder_layer,
num_encoder_layers,
encoder_norm,
)
Expand All @@ -259,9 +260,9 @@ def predict(self, cid, rid, rid_mask):
],
dim=1,
) # [S, 2*E]
- cross_rep = self.trs_encoder(
+ cross_rep = self.decay_ratio * torch.tanh(self.trs_encoder(
torch.relu(self.proj1(cross_rep).unsqueeze(1))
- ).squeeze(1)
+ )).squeeze(1)
cross_rep = self.layernorm(cross_rep + rid_rep) # [B, E]
# cid: [E]; rid: [B, E]
dot_product = torch.matmul(cid_rep, cross_rep.t()) # [B]
@@ -333,10 +334,12 @@ def __init__(self, multi_gpu, total_step, run_mode='train', local_rank=0, kb=True,
'dmodel': model,
'num_encoder_layers': 2,
'dim_feedforward': 512,
- 'nhead': 8,
+ 'nhead': 6,
'dropout': 0.1,
'max_len': 256,
'poly_m': 16,
# prevent the comparison information from overwhelming the original information of each response
'decay_ratio': 0.4,
}
self.vocab = BertTokenizer.from_pretrained(self.args['vocab_file'])
if model == 'no-compare':
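To isolate the headline change of this commit, a standalone sketch of the decayed comparison fusion in predict (shapes follow the comments in the diff; this is a stripped re-statement under assumed dimensions, not the repo's class):

```python
import torch
import torch.nn as nn

B, E = 4, 768
decay_ratio = 0.4
proj1 = nn.Linear(2 * E, E)
trs_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(E, nhead=6, dim_feedforward=512), 2)
layernorm = nn.LayerNorm(E)

cid_rep = torch.randn(E)       # context representation
rid_rep = torch.randn(B, E)    # candidate response representations
cross_rep = torch.cat([cid_rep.expand(B, E), rid_rep], dim=1)    # [B, 2*E]
# tanh bounds the comparison signal and decay_ratio down-weights it
# before it is added back onto each response representation
cross_rep = decay_ratio * torch.tanh(
    trs_encoder(torch.relu(proj1(cross_rep)).unsqueeze(1))).squeeze(1)
cross_rep = layernorm(cross_rep + rid_rep)           # [B, E]
dot_product = torch.matmul(cid_rep, cross_rep.t())   # [B]
```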
2 changes: 1 addition & 1 deletion run.sh
@@ -55,7 +55,7 @@ elif [ $mode = 'train' ]; then
fi

gpu_ids=(${cuda//,/ })
- CUDA_VISIBLE_DEVICES=$cuda python -m torch.distributed.launch --nproc_per_node=${#gpu_ids[@]} --master_addr 127.0.0.1 --master_port 29501 main.py \
+ CUDA_VISIBLE_DEVICES=$cuda python -m torch.distributed.launch --nproc_per_node=${#gpu_ids[@]} --master_addr 127.0.0.1 --master_port 29500 main.py \
--dataset $dataset \
--model $model \
--mode train \
