Tue 20 Oct 2020 11:52:30 PM CST
gmftbyGMFTBY committed Oct 20, 2020
1 parent 3fd8e1e commit f84062c
Showing 8 changed files with 148 additions and 131 deletions.
4 changes: 4 additions & 0 deletions data/run.sh
@@ -5,6 +5,10 @@ mode=$1
if [ $mode = 'init_es' ]; then
# init the ElasticSearch and restore the retrieval database
python process_data.py --mode insert
elif [ $mode = 'init_topic_guided' ]; then
# python word_graph.py --mode graph
# generate the words by frequency from the corpus
python word_graph.py --mode word
elif [ $mode = 'init_gen' ]; then
# init and create the whole generative dataset
python process_data.py --mode generative
133 changes: 71 additions & 62 deletions data/word_graph.py
@@ -59,6 +59,7 @@ def parser_args():
parser = argparse.ArgumentParser(description='wordnet parameters')
parser.add_argument('--weight_threshold', type=float, default=0.6)
parser.add_argument('--topn', type=int, default=10)
parser.add_argument('--mode', type=str, default='graph')
return parser.parse_args()

def load_stopwords():
@@ -67,33 +68,32 @@ def load_stopwords():
data = [i for i in data if i.strip()]
return data

def collect_wordlist_from_corpus(path, topn=500000):
def collect_wordlist_from_corpus(path, topn=50000):
'''only keep the tokens whose length is from 2 to 4'''
cutter = LAC(mode='seg')
cutter = LAC(mode='lac')
with open(path) as f:
# data = f.read().split('\n\n')
# data = [i.split('\n') for i in data]
data = json.load(f)['train']
data = random.sample(data, 2000000)
x_data = []
for i in data:
x_data.append([''.join(j.split()) for j in i])
data = x_data
print(f'[!] load the dataset from {path} over ...')
data = f.read().split('\n\n')
data = [i.split('\n') for i in data if i.strip()]
data = random.sample(data, 1000000)
print(f'[!] load the dataset from {path}({len(data)}) over ...')
batch_size, words_collector = 512, Counter()
pbar = tqdm(range(0, len(data), batch_size))
for idx in pbar:
dialogs = data[idx:idx+batch_size]
dialogs = [' '.join(i) for i in dialogs]
words = chain(*cutter.run(dialogs))
words = [w for w in words if filter(w)]
words_collector.update(words)
rest = cutter.run(dialogs)
collector = []
for words, tags in rest:
for word, tag in zip(words, tags):
if filter(word, tag):
collector.append(word)
words_collector.update(collector)
pbar.set_description(f'[!] collect words: {len(words_collector)}')
words = [w for w, _ in words_collector.most_common(topn)]
print(f'[!] {len(words_collector)} -> {len(words)}')
return words

def filter(word):
def filter(word, tag):
def isChinese():
for ch in word:
if not '\u4e00' <= ch <= '\u9fff':
@@ -120,15 +120,16 @@ def Special():
return False
return True
def CheckTag():
# ipdb.set_trace()
word_, pos = lac.run(word)
if len(word_) != 1:
if tag in set(['n', 'nz', 'nw', 'v', 'vn', 'a', 'ad', 'an', 'ORG', 'PER', 'LOC']):
return True
else:
return False
if pos[0] in set(['n', 'nz', 'nw', 'v', 'vn', 'a', 'ad', 'an', 'ORG', 'PER', 'LOC']):
def InW2V():
if word in w2v:
return True
else:
return False
return isChinese() and HaveDigital() and Length() and HaveAlpha() and Special() and CheckTag()
return isChinese() and HaveDigital() and Length() and HaveAlpha() and Special() and CheckTag() and InW2V()

def write_new_w2v(words, path):
with open(path, 'w') as f:
@@ -162,45 +163,53 @@ def retrieval_filter(words, samples=64):
args = parser_args()
args = vars(args)

lac = LAC(mode='lac')
chatter = ESChat('retrieval_database')
stopwords = load_stopwords()

if not os.path.exists('chinese_w2v_base.txt'):
# 1)
w2v = KeyedVectors.load_word2vec_format('chinese_w2v.txt', binary=False)
print(f'[!] load the word2vec from chinese_w2v.txt')
# 2)
wordlist = w2v.index2word
new_wordlist = [word for word in tqdm(wordlist) if filter(word)]
print(f'[!] squeeze the wordlist from {len(wordlist)} to {len(new_wordlist)}')
# stop words remove
new_wordlist_ = list(set(new_wordlist) - set(stopwords))
print(f'[!] squeeze the wordlist from {len(new_wordlist)} to {len(new_wordlist_)}')
# retrieval check and remove
new_wordlist_2, batch_size = [], 256
for idx in tqdm(range(0, len(new_wordlist_), batch_size)):
words = new_wordlist_[idx:idx+batch_size]
for word, rest in zip(words, retrieval_filter(words)):
if rest:
new_wordlist_2.append(word)
print(f'[!] squeeze the wordlist from {len(new_wordlist_)} to {len(new_wordlist_2)}')
# 3)
write_new_w2v(new_wordlist_2, 'chinese_w2v_base.txt')
print(f'[!] write the new w2v into chinese_w2v_base.txt')
# 4)
w2v = KeyedVectors.load_word2vec_format('chinese_w2v_base.txt', binary=False)
print(f'[!] load the new word2vec from chinese_w2v_base.txt')
# 5)
if not os.path.exists('wordnet.pkl'):
graph = nx.Graph()
graph.add_nodes_from(w2v.index2word)
for word in tqdm(w2v.index2word):
neighbors = w2v.most_similar(word, topn=args['topn'])
graph.add_weighted_edges_from([(word, n, 1 - w) for n, w in neighbors if 1 - w < args['weight_threshold']])
with open('wordnet.pkl', 'wb') as f:
pickle.dump(graph, f)
print(f'[!] save the word net into wordnet.pkl')
else:
with open('wordnet.pkl', 'rb') as f:
graph = pickle.load(f)
if args['mode'] == 'word':
# default use the LCCC dataset
print(f'[!] make sure you have already run the graph mode')
w2v = KeyedVectors.load_word2vec_format('chinese_w2v_base.txt', binary=False)
words = collect_wordlist_from_corpus('LCCC/train.txt', topn=20000)
with open('topic_words.pkl', 'wb') as f:
pickle.dump(words, f)
elif args['mode'] == 'graph':
lac = LAC(mode='lac')
chatter = ESChat('retrieval_database')
stopwords = load_stopwords()

if not os.path.exists('chinese_w2v_base.txt'):
# 1)
w2v = KeyedVectors.load_word2vec_format('chinese_w2v.txt', binary=False)
print(f'[!] load the word2vec from chinese_w2v.txt')
# 2)
wordlist = w2v.index2word
new_wordlist = [word for word in tqdm(wordlist) if filter(word)]
print(f'[!] squeeze the wordlist from {len(wordlist)} to {len(new_wordlist)}')
# stop words remove
new_wordlist_ = list(set(new_wordlist) - set(stopwords))
print(f'[!] squeeze the wordlist from {len(new_wordlist)} to {len(new_wordlist_)}')
# retrieval check and remove
new_wordlist_2, batch_size = [], 256
for idx in tqdm(range(0, len(new_wordlist_), batch_size)):
words = new_wordlist_[idx:idx+batch_size]
for word, rest in zip(words, retrieval_filter(words)):
if rest:
new_wordlist_2.append(word)
print(f'[!] squeeze the wordlist from {len(new_wordlist_)} to {len(new_wordlist_2)}')
# 3)
write_new_w2v(new_wordlist_2, 'chinese_w2v_base.txt')
print(f'[!] write the new w2v into chinese_w2v_base.txt')
# 4)
w2v = KeyedVectors.load_word2vec_format('chinese_w2v_base.txt', binary=False)
print(f'[!] load the new word2vec from chinese_w2v_base.txt')
# 5)
if not os.path.exists('wordnet.pkl'):
graph = nx.Graph()
graph.add_nodes_from(w2v.index2word)
for word in tqdm(w2v.index2word):
neighbors = w2v.most_similar(word, topn=args['topn'])
graph.add_weighted_edges_from([(word, n, 1 - w) for n, w in neighbors if 1 - w < args['weight_threshold']])
with open('wordnet.pkl', 'wb') as f:
pickle.dump(graph, f)
print(f'[!] save the word net into wordnet.pkl')
else:
with open('wordnet.pkl', 'rb') as f:
graph = pickle.load(f)
2 changes: 1 addition & 1 deletion ideas/README.md
@@ -22,7 +22,7 @@
* Logicality: judge whether the logic of the context is coherent at the sentence level; train with negative sampling [sentence repetition, sentence deletion, shuffling sentence order] (see the sketch below)
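
  A minimal sketch of how those three negative-sample corruptions could be generated, assuming a dialog is a plain list of utterance strings (the helper name is hypothetical, not from this repo):

  ```python
  import random

  def logicality_negatives(dialog):
      '''Build three corrupted copies of a dialog (a list of utterance strings)
      to serve as negative samples for the logicality signal.'''
      # 1) sentence repetition: duplicate one randomly chosen utterance
      i = random.randrange(len(dialog))
      repeated = dialog[:i + 1] + [dialog[i]] + dialog[i + 1:]
      # 2) sentence deletion: drop one randomly chosen utterance
      j = random.randrange(len(dialog))
      deleted = dialog[:j] + dialog[j + 1:]
      # 3) sentence shuffling: permute the order of the utterances
      shuffled = dialog[:]
      random.shuffle(shuffled)
      return repeated, deleted, shuffled
  ```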

9. For coarse-grained recall in the retrieval-based method, Q-A matching works much better than Q-Q matching
10. Reproduce PolyEncoder and RocketQA (master the in-batch negative sampling method)
10. Reproduce PolyEncoder and RocketQA (master the in-batch negative sampling method, sketched below); bring the TripleNet idea into bi-encoder retrieval-based dialog and propose a new method that tries to model hierarchical information and reuse the embedding of every utterance. Alternatively, ideas from the hierarchical transformer could be borrowed.
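
    For the in-batch negative sampling mentioned in item 10, a minimal bi-encoder loss can be sketched as follows; the tensors here are placeholder encoder outputs, not code from this repository:

    ```python
    import torch
    import torch.nn.functional as F

    def in_batch_negative_loss(ctx_emb, rsp_emb):
        '''ctx_emb, rsp_emb: [batch, dim] pooled context/response embeddings.
        Each context's positive response sits on the diagonal; every other
        response in the batch acts as a negative.'''
        scores = torch.matmul(ctx_emb, rsp_emb.t())   # [batch, batch] dot-product scores
        labels = torch.arange(scores.size(0), device=scores.device)
        return F.cross_entropy(scores, labels)

    # usage sketch with random placeholder embeddings
    ctx_emb = torch.randn(32, 768)
    rsp_emb = torch.randn(32, 768)
    loss = in_batch_negative_loss(ctx_emb, rsp_emb)
    ```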

## Generative

5 changes: 1 addition & 4 deletions models/base.py
@@ -174,10 +174,7 @@ def process_utterances(self, topic, msgs, max_len=0, context=True):
inpt_ids = [torch.LongTensor(i[0]) for i in collection]
token_type_ids = [torch.LongTensor(i[1]) for i in collection]

try:
inpt_ids = pad_sequence(inpt_ids, batch_first=True, padding_value=self.args['pad'])
except:
ipdb.set_trace()
inpt_ids = pad_sequence(inpt_ids, batch_first=True, padding_value=self.args['pad'])
token_type_ids = pad_sequence(token_type_ids, batch_first=True, padding_value=self.args['pad'])
attn_mask_index = inpt_ids.nonzero().tolist()
attn_mask_index_x, attn_mask_index_y = [i[0] for i in attn_mask_index], [i[1] for i in attn_mask_index]
68 changes: 33 additions & 35 deletions models/bert_retrieval.py
@@ -213,6 +213,9 @@ def __init__(self, multi_gpu, run_mode='train', lang='zh', kb=True, local_rank=0
self.args['talk_sample'] = talk_samples
self.lac = LAC(mode='lac')

def reset(self):
self.history = []

def wrap_utterances(self, context, max_len=0):
'''context is a list of string, which contains the dialog history'''
context, response = ' [SEP] '.join(context[:-1]), context[-1]
@@ -423,54 +426,40 @@ def __init__(self, multi_gpu, run_mode='train', lang='zh', kb=True, local_rank=0
self.w2v = gensim.models.KeyedVectors.load_word2vec_format(
'data/chinese_w2v_base.txt', binary=False
)
self.cutter = LAC(mode='seg')

def reset(self, target, source):
self.args['target'], self.args['current_node'] = target, [source]
self.args['target'], self.args['current_node'] = target, source
self.topic_history, self.history = [[source]], []
print(f'[! Reset the KG target] source: {source}; target: {target}')

def get_cluster(self, start_nodes, n, size=5):
def obtain_keywords(self, utterance):
'''select the keyword that is most similar to the current_node as the keyword in the human response'''
return [i for i in self.cutter.run(utterance) if i in self.w2v]

def get_cluster(self, start_nodes, size=5):
'''similarity is the average between the target and the current node'''
candidates = []
for start_node in start_nodes:
nodes = self.w2v.most_similar(start_node, topn=self.args['cluster_width'])
candidates.extend([i[0] for i in nodes])
candidates = list(set(candidates))
candidates.extend(
self.w2v.most_similar(start_node, topn=self.args['cluster_width'])
)
candidates = [i[0] for i in candidates]
candidates = list(set(candidates) - set(chain(*self.topic_history)))
if self.args['target'] in candidates:
return [self.args['target']]
target_similarity = [(self.w2v.similarity(self.args['target'], node), node) for node in candidates]
start_similarity = []
for node in candidates:
p = []
for start_node in start_nodes:
p.append(self.w2v.similarity(start_node, node))
start_similarity.append((np.average(p), node))
similarity = [(0.6 * t + 0.4 * s, node) for (t, node), (s, node) in zip(target_similarity, start_similarity)]
similarity = [(self.w2v.similarity(self.args['target'], node), node) for node in candidates]
# The start similarity should not be computed this way; its purpose is to keep the topic transition smooth. Refer to the 2019 ACL paper on target-guided open-domain dialog systems
# 1. PMI; 2. neural network prediction
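# A possible PMI-based smoothness score (a hypothetical illustration, not implemented in this repo):
#   PMI(w_prev, w_next) = log( p(w_prev, w_next) / (p(w_prev) * p(w_next)) ),
# with the probabilities estimated from keyword co-occurrence counts over adjacent turns in the corpus.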
similarity = sorted(similarity, key=lambda x: x[0], reverse=True)
#
chosen = []
while len(chosen) < size:
chosen.append(similarity[0][1])
new_similarity, similarity = similarity[1:], []
for weight, node in new_similarity:
p = []
for word in chosen:
p.append(self.w2v.similarity(word, node))
p = 0.3 * np.mean(p)
similarity.append((weight + p, node))
similarity = sorted(similarity, key=lambda x: x[0], reverse=True)
return chosen
return [i[1] for i in similarity[:size]]

def move_on_kg(self):
'''judge whether the end has been reached'''
if self.args['target'] in self.topic_history[-1]:
return
self.args['current_node'] = self.get_cluster(
self.args['current_node'],
self.args['cluster_width'],
self.args['current_node'],
size=self.args['num_candidate'],
)
self.topic_history.append(self.args['current_node'])

@torch.no_grad()
def talk(self, msgs):
@@ -486,17 +475,18 @@ def talk(self, msgs):
# 3) post ranking with multiple current topic words
output = torch.argsort(output, descending=True)
for i in output:
flag = False
flag, chosen_word = False, None
for word in self.args['current_node']:
if word in utterances[i.item()]:
item, flag = i, True
chosen_word = word
break
if flag:
break
else:
item = 0
msg = utterances[item]
return msg
return msg, chosen_word

def get_res(self, data):
'''
@@ -513,13 +503,21 @@ def get_res(self, data):
'''
if len(data['msgs']) > 0:
# 1) move
response = data['msgs'][-1]['msg']
keywords = self.obtain_keywords(response)
if self.args['current_node']:
self.args['current_node'] = list(set(keywords + [self.args['current_node']]))
else:
self.args['current_node'] = list(set(keywords))
self.move_on_kg()
# 2) obtain the responses based on the next_node
msgs = [i['msg'] for i in data['msgs']]
msgs = ' [SEP] '.join(msgs)
res = self.talk(msgs)
res, chosen_word = self.talk(msgs)
self.topic_history.append(self.args['current_node'])
self.args['current_node'] = chosen_word
else:
res = self.searcher.talk('', topic=self.args['current_node'])
res = self.searcher.talk('', topic=[self.args['current_node']])
self.history.append(res)
return res

8 changes: 8 additions & 0 deletions models/model_utils.py
@@ -280,6 +280,14 @@ class ESChat:
def __init__(self, index_name, kb=True):
self.es = Elasticsearch(http_auth=('elastic', 'elastic123'))
self.index = index_name
self.es.indices.put_settings(
index=self.index,
body={
'index': {
'max_result_window': 500000,
}
}
)

def search(self, query, samples=10, topic=None):
'''
7 changes: 3 additions & 4 deletions run_self_play.sh
@@ -9,7 +9,6 @@ CUDA_VISIBLE_DEVICES=$1 python self_play.py \
--multi_gpu $1 \
--lang zh \
--mode test \
--history_length 2 \
--min_topic_length 6 \
--max_topic_length 7 \
--talk_samples 128 | tee rest/self_play.txt
--history_length 3 \
--recoder rest/self_play.txt \
--talk_samples 256
