-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
27f99b0
commit 4d23c98
Showing
70 changed files
with
4,898 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
*.log | ||
*.pyc | ||
.DS_Store | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,83 @@ | ||
# mmd | ||
This repository contains the code used for our submission. Will be available soon. | ||
|
||
This repository contains the Pytorch imlepementation used for our submission | ||
|
||
- [A Knowledge-Grounded Multimodal Search-Based Conversational Agent](https://arxiv.org/pdf/1810.11954.pdf) SCAI@EMNLP 2018 | ||
- [Improving Context Modelling in Multimodal Dialogue Generation](https://arxiv.org/pdf/1810.11955.pdf) INLG 2018 | ||
|
||
## Install | ||
|
||
We used Python 2.7 and Pytorch 0.3 (0.3.1.post2) for our implementation. | ||
|
||
We strongly encourage to use [conda](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html) or virtual environment. | ||
|
||
The repo provides both yaml and requirements (use either of the file exported through conda/pip) to create the conda env. | ||
|
||
``` | ||
conda create -n mmd python=2 | ||
source activate mmd | ||
pip install -r requirements.txt | ||
``` | ||
|
||
or | ||
|
||
``` | ||
conda env create -n mmd -f mmd.yml | ||
``` | ||
|
||
## Data | ||
|
||
Download the data set from the [link](https://drive.google.com/drive/folders/1JOGHzideeAsmykMUQD3z7aGFg-M4QlE2) provided by Shah et al. (More information provided by the authors in their [webpage](https://amritasaha1812.github.io/MMD/download/) and [repo](https://github.com/amritasaha1812/MMD_Code)). Extract it in data folder | ||
|
||
## Pre-processing | ||
|
||
Run the bash scripts in `data_processing` folder to convert from json transcripts to the actual data for the model. | ||
Running as `bash dialogue_data.sh` will call the python files in the same folder. Please manipulate the file paths accordingly in the shell script. Run the other shell scripts to extract KB related data for the model. | ||
|
||
## Training | ||
|
||
Run `bash train_and_translate.sh` for training as well as generating on the test set. This script at the end also computes the final metrics on the test set. | ||
|
||
Structure: | ||
|
||
``` | ||
train_and_translate.sh --> train.sh --> train.py --> modules/models.py | ||
train_and_translate.sh --> translate.sh --> translate.py --> modules/models.py | ||
train_and_translate.sh --> evaluation/nlg_eval.sh --> nlg_eval.py | ||
``` | ||
This structure allows us to enter file paths only once while training and evaluating. Follow the steps on the screen to provide different parameters. My suggestion is to use Pycharm to better understand the structure and modularity. | ||
|
||
Tune `IS_TRAIN` parameter if you have actually trained the model and want to just generate on final test set. | ||
|
||
Hyperparameter tuning can be done by creating different config versions in `config` folder. Sample config versions are provided for reference. | ||
|
||
## Metrics | ||
|
||
For evaluation, we used the scripts provided by [nlg-eval](https://github.com/Maluuba/nlg-eval). (Sharma et al.) | ||
|
||
In particular, we used their [functional api](https://github.com/Maluuba/nlg-eval#functional-api-for-the-entire-corpus) for getting evaluation metrics. Install the dependencies in the same or different conda environment. | ||
|
||
`nlg_eval.sh` or `compute_metrics.sh` both bash script can be called depending upon the use case. See the structure above. nlg_eval.py is the main file using the funcitonal api. | ||
|
||
## Citation | ||
|
||
If you use this work, please cite it as | ||
``` | ||
@inproceedings{agarwal2018improving, | ||
title={Improving Context Modelling in Multimodal Dialogue Generation}, | ||
author={Agarwal, Shubham and Du{\v{s}}ek, Ond{\v{r}}ej and Konstas, Ioannis and Rieser, Verena}, | ||
booktitle={Proceedings of the 11th International Conference on Natural Language Generation}, | ||
pages={129--134}, | ||
year={2018} | ||
} | ||
@inproceedings{agarwal2018knowledge, | ||
title={A Knowledge-Grounded Multimodal Search-Based Conversational Agent}, | ||
author={Agarwal, Shubham and Du{\v{s}}ek, Ond{\v{r}}ej and Konstas, Ioannis and Rieser, Verena}, | ||
booktitle={Proceedings of the 2018 EMNLP Workshop SCAI: The 2nd International Workshop on Search-Oriented Conversational AI}, | ||
pages={59--66}, | ||
year={2018} | ||
} | ||
``` | ||
|
||
Feel free to fork and contribute to this work. Please raise a PR or any related issues. Will be happy to help. Thanks. |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
{ | ||
"training": { | ||
"seed": 100, | ||
"optimizer": "adam", | ||
"clip_grad": 3, | ||
"lr": 0.0004, | ||
"lr_decay": 0, | ||
"use_cuda": true, | ||
"num_epochs": 15, | ||
"save_every": 1, | ||
"log_every": 1000, | ||
"plot_every": 4, | ||
"evaluate_every": 2 | ||
}, | ||
"data": { | ||
"batch_size": 64, | ||
"context_size": 2, | ||
"pad_id": 0, | ||
"sys_start_id":4, | ||
"sys_end_id":5, | ||
"unk_id":1, | ||
"user_start_id":2, | ||
"user_end_id":3, | ||
"start_id":2, | ||
"end_id":3, | ||
"image_rep_size":4096 | ||
}, | ||
"model": { | ||
"src_emb_dim": 512, | ||
"tgt_emb_dim": 512, | ||
"enc_hidden_size": 512, | ||
"dec_hidden_size": 512, | ||
"context_hidden_size": 512, | ||
"bidirectional_enc": true, | ||
"bidirectional_context": false, | ||
"image_in_size": 20480, | ||
"num_enc_layers": 1, | ||
"num_dec_layers": 1, | ||
"num_context_layers": 1, | ||
"dropout_enc": 0.3, | ||
"dropout_dec": 0.3, | ||
"dropout_context": 0.3, | ||
"max_decode_len": 20, | ||
"enc_type": "GRU", | ||
"dec_type": "GRU", | ||
"context_type": "GRU", | ||
"use_attention": false, | ||
"non_linearity": "tanh", | ||
"decode_function": "softmax" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
{ | ||
"training": { | ||
"seed": 100, | ||
"optimizer": "adam", | ||
"clip_grad": 5, | ||
"lr": 0.0005, | ||
"lr_decay": 0.3, | ||
"use_cuda": true, | ||
"num_epochs": 30, | ||
"save_every": 5, | ||
"log_every": 1000, | ||
"plot_every": 4, | ||
"evaluate_every": 5 | ||
}, | ||
"data": { | ||
"batch_size": 64, | ||
"context_size": 2, | ||
"pad_id": 0, | ||
"sys_start_id":4, | ||
"sys_end_id":5, | ||
"unk_id":1, | ||
"user_start_id":2, | ||
"user_end_id":3, | ||
"start_id":2, | ||
"end_id":3, | ||
"image_rep_size":4096 | ||
}, | ||
"model": { | ||
"src_emb_dim": 512, | ||
"tgt_emb_dim": 512, | ||
"enc_hidden_size": 512, | ||
"dec_hidden_size": 512, | ||
"context_hidden_size": 512, | ||
"bidirectional_enc": true, | ||
"bidirectional_context": false, | ||
"image_in_size": 20480, | ||
"num_enc_layers": 1, | ||
"num_dec_layers": 1, | ||
"num_context_layers": 1, | ||
"dropout_enc": 0.8, | ||
"dropout_dec": 0.8, | ||
"dropout_context": 0.8, | ||
"max_decode_len": 20, | ||
"enc_type": "GRU", | ||
"dec_type": "GRU", | ||
"context_type": "GRU", | ||
"use_attention": true, | ||
"non_linearity": "tanh", | ||
"decode_function": "softmax" | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
import os | ||
import json | ||
import cPickle as pkl | ||
import argparse | ||
|
||
def main(args): | ||
# labels_file = os.path.join(args.data_dir, "test_state_labels.txt") | ||
# pred_file = args.pred | ||
# target_file = args.target | ||
|
||
with open(args.labels_file, 'r') as state_file, \ | ||
open(args.pred, 'r') as pred_file, \ | ||
open(args.target, 'r') as target_file, \ | ||
open(args.context, 'r') as context_file: | ||
for state_line, pred, target, context in zip(state_file, pred_file, \ | ||
target_file, context_file): | ||
# state = ','.join(state_line.split(',')[:-1]) | ||
state_dir = str(state_line.strip()) | ||
state_out_dir = os.path.join(args.results_dir, state_dir) | ||
if not os.path.exists(state_out_dir): | ||
os.makedirs(state_out_dir) | ||
state_pred_file = os.path.join(state_out_dir, "pred_"+ str(args.checkpoint)+"_"+args.beam+".txt") | ||
state_target_file = os.path.join(state_out_dir,"test_tokenized.txt") | ||
state_context_file = os.path.join(state_out_dir,"test_context_text.txt") | ||
with open(state_pred_file, 'a+') as fp: | ||
fp.write(pred) | ||
with open(state_target_file, 'a+') as fp: | ||
fp.write(target) | ||
with open(state_context_file, 'a+') as fp: | ||
fp.write(context) | ||
|
||
if __name__=="__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-labels_file', type=str, default='./data/vocab.pkl') | ||
parser.add_argument('-results_dir', type=str, default='./data/vocab.pkl') | ||
parser.add_argument('-pred', type=str, help='prediction file path') | ||
parser.add_argument('-target', type=str, help='target file path') | ||
parser.add_argument('-context', type=str, help='target file path') | ||
parser.add_argument('-checkpoint', type=str, help='target file path') | ||
parser.add_argument('-beam', type=str, help='target file path') | ||
args = parser.parse_args() | ||
main(args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
Download the data here. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import re | ||
import numpy as np | ||
from pandas.io.json import json_normalize | ||
import json | ||
import pandas as pd | ||
import time | ||
import os | ||
import glob | ||
import argparse | ||
import string | ||
import unicodedata | ||
|
||
def read_json_to_df(file_path): | ||
# df = pd.read_json(path_or_buf=file_path,orient='records',lines=True) | ||
df = pd.read_json(path_or_buf=file_path,orient='records') | ||
return df | ||
|
||
def flatten_json_column(df,col_name='utterance'): | ||
temp_df = json_normalize(df[col_name].tolist()) | ||
df.reset_index(drop=True,inplace=True) | ||
df = df.join(temp_df).drop(col_name, axis=1) | ||
return df | ||
|
||
def get_column_stats(df,column_name,to_dict = False): | ||
if to_dict: | ||
return df[column_name].value_counts().to_dict() | ||
else: | ||
return df[column_name].value_counts() | ||
|
||
def findFiles(path): | ||
return glob.glob(path) | ||
|
||
def get_column_names(df): | ||
return df.columns.values | ||
|
||
def get_value_row_column(df,index,column_name): | ||
return df.get_value(index,column_name) | ||
|
||
def flatten_dic_column(df,col_name): | ||
df_new= pd.concat([df.drop([col_name], axis=1), df[col_name].apply(pd.Series)], axis=1) | ||
return df_new | ||
|
||
def append_df(df, df_to_append, ignore_index=True): | ||
new_df = df.append(df_to_append,ignore_index=ignore_index) | ||
return new_df | ||
|
||
def write_df_to_csv(df,outputFilePath): | ||
df.to_csv(outputFilePath, sep=str('\t'),quotechar=str('"'), index=False, header=True) | ||
|
||
def write_df_to_json(df,outputFilePath): | ||
df.to_json(path_or_buf=outputFilePath,orient='records',lines=True) | ||
|
||
def save_df_pickle(df,output_file): | ||
df.to_pickle(output_file) | ||
|
||
def get_unique_column_values(df,col_name): | ||
""" Returns unique values """ | ||
return df[col_name].unique() | ||
|
||
def count_unique(df, col_name): | ||
""" Count unique values in a df column """ | ||
count = df[col_name].nunique() | ||
return count | ||
|
||
def nlp_stats(nlg): | ||
nlg = nlg.encode('ascii',errors='ignore').lower().strip().strip('.?!') #Remove last '.?!' | ||
stats = {} | ||
stats['num_chars'] = len(nlg) | ||
sentences = nlg.replace('?','.').replace('!',' ').split('.') | ||
# replace '? and !' with '.'' ## [:-1] to discard last?? | ||
stats['num_sent'] = len(sentences) | ||
words = nlg.replace('.','').replace('!','').replace('?','').split() | ||
stats['num_words'] = len(words) | ||
return stats | ||
|
||
|
||
def main(args): | ||
all_files = glob.glob(args.file_dir + "/*.json") | ||
start = time.time() | ||
stats_df = pd.DataFrame() | ||
global_df = pd.DataFrame() | ||
global_user_df = pd.DataFrame() | ||
global_system_df = pd.DataFrame() | ||
print("Reading files") | ||
index = 0 | ||
for dialogue_json in all_files: | ||
index+=1 | ||
df = read_json_to_df(dialogue_json) | ||
df_flatten = flatten_json_column(df) | ||
df_flatten = df_flatten[['speaker','type', 'question-type','question-subtype','nlg','images']] | ||
# Assign filename | ||
df_flatten = df_flatten.assign(filename=dialogue_json) | ||
# Analysis | ||
df_flatten['num_images']=df_flatten['images'].apply(lambda x: len(x) if (type(x) is list) else None) | ||
# replace nans; create new df | ||
df = df_flatten.replace(np.nan, '', regex=True) | ||
# create state column | ||
df_flatten['state'] = df[['type', 'question-type','question-subtype']].apply(lambda x: ','.join(x), axis=1) | ||
df_flatten['nlp_stats'] = df_flatten['nlg'].apply(lambda x: nlp_stats(x) if (type(x) is unicode) else None) | ||
df_flatten = flatten_dic_column(df_flatten,'nlp_stats') | ||
# Flags | ||
df_flatten['is_image']=df_flatten['images'].apply(lambda x: 1 if (type(x) is list) else 0) | ||
df_flatten['is_nlg'] = df_flatten['nlg'].apply(lambda x: 1 if (type(x) is unicode) else 0) | ||
df_flatten['is_multimodal'] = df_flatten['is_nlg'] + df_flatten['is_image'] -1 # text + image -1 | ||
# Subset | ||
user_df = df_flatten.loc[df_flatten['speaker'] == 'user'] | ||
system_df = df_flatten.loc[df_flatten['speaker'] == 'system'] | ||
# Analytics | ||
image_turns = df_flatten['is_image'].sum() | ||
nlg_turns = df_flatten['is_nlg'].sum() | ||
multimodal_turns = df_flatten['is_multimodal'].sum() | ||
total_turns = df_flatten.shape[0] | ||
user_turns = user_df.shape[0] | ||
sys_turns = system_df.shape[0] | ||
user_nlg_turns = user_df['is_nlg'].sum() | ||
sys_nlg_turns = system_df['is_nlg'].sum() | ||
# summarized utterance df | ||
local_data = {'filename':dialogue_json, 'total_turns':total_turns, 'image_turns':image_turns, | ||
'nlg_turns':nlg_turns, 'multimodal_turns':multimodal_turns, 'user_turns':user_turns, | ||
'sys_turns':sys_turns, 'user_nlg_turns':user_nlg_turns, 'sys_nlg_turns':sys_nlg_turns} | ||
local_df = pd.DataFrame(data=local_data, index=[index]) | ||
# Append DF | ||
stats_df = append_df(stats_df,local_df,ignore_index=False) | ||
global_df = append_df(global_df, df_flatten) | ||
global_user_df = append_df(global_user_df, user_df) | ||
global_system_df = append_df(global_system_df, system_df) | ||
print("Writing files") | ||
write_df_to_json(global_df, args.output_file_json) | ||
save_df_pickle(global_df, args.output_file_pkl) | ||
write_df_to_json(global_user_df, args.output_user_file_json) | ||
save_df_pickle(global_user_df, args.output_user_file_pkl) | ||
write_df_to_json(global_system_df, args.output_sys_file_json) | ||
save_df_pickle(global_system_df, args.output_sys_file_pkl) | ||
write_df_to_json(stats_df, args.stats_file_json) | ||
save_df_pickle(stats_df, args.stats_file_pkl) | ||
|
||
if __name__ == '__main__': | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument('-file_dir', help='Input file directory path') | ||
parser.add_argument('-output_file_json', help='Output file path') | ||
parser.add_argument('-output_file_pkl', help='Output file path') | ||
parser.add_argument('-output_user_file_json', help='Output file path') | ||
parser.add_argument('-output_user_file_pkl', help='Output file path') | ||
parser.add_argument('-output_sys_file_json', help='Output file path') | ||
parser.add_argument('-output_sys_file_pkl', help='Output file path') | ||
parser.add_argument('-stats_file_json', help='Output file path') | ||
parser.add_argument('-stats_file_pkl', help='Output file path') | ||
args = parser.parse_args() | ||
main(args) | ||
|
Oops, something went wrong.