updated searcher and cv scripts (castorini#20)
emmileaf authored and zeynepakkalyoncu committed Jun 27, 2019
1 parent dca1cd3 commit 1c8041b
Showing 4 changed files with 104 additions and 78 deletions.
11 changes: 10 additions & 1 deletion README.md
@@ -30,7 +30,16 @@ wget https://zenodo.org/record/3241945/files/birch_data.tar.gz
tar -xzvf birch_data.tar.gz
```

# Training
## Dataset

```
python src/robust04_cv.py --anserini_path <path/to/anserini> --index_path <path/to/index> --cv_fold <2, 5>
```

This step retrieves documents to depth 1000 for each query and splits them into sentences to generate the cross-validation fold data. You may skip this step and use the downloaded data under `data/datasets`.
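
Each line produced by this step is a tab-separated record: relevance label, retrieval score, query text, sentence, query id, `<docno>_<sentid>`, and running query/document counters. A minimal sketch for inspecting the combined fold file (the path and field names below are illustrative assumptions, not defined by the repository):

```
# Minimal sketch: inspect the fold data generated above.
# File name and field labels are illustrative; adjust to your --data_path and fold count.
fields = ["label", "score", "query", "sentence", "qid", "sentno", "qidx", "didx"]

with open("data/datasets/robust04_5cv.csv", encoding="utf-8") as f:
    for line in f:
        record = dict(zip(fields, line.rstrip("\n").split("\t")))
        print(record["qid"], record["sentno"], record["label"])
        break  # look at the first record only
```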

## Training

```
python src/main.py --mode training --collection mb --qrels_file qrels.microblog.txt --batch_size <batch_size> --eval_steps <eval_steps> --learning_rate <learning_rate> --num_train_epochs <num_train_epochs> --device cuda
```
5 changes: 3 additions & 2 deletions src/core_cv.py
@@ -14,5 +14,6 @@
qid2docid = get_relevant_docids(fqrel)
qid2text = get_query(ftopic, collection=collection)

searcher = build_searcher(k1=0.9, b=0.4, index_path=index_path, rm3=True)
search_document(searcher, qid2docid, qid2text, output_fn, collection, K=1000)
docsearch = Searcher(anserini_path)
searcher = docsearch.build_searcher(k1=0.9, b=0.4, index_path=index_path, rm3=True)
docsearch.search_document(searcher, qid2docid, qid2text, output_fn, collection, K=1000)
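
The change above moves `core_cv.py` from the old module-level `build_searcher`/`search_document` helpers to the new `Searcher` class. A minimal sketch of the class-based flow, assuming the `utils` helpers used in these scripts and placeholder paths for the qrels, topics, index, and output file:

```
# Sketch of the class-based retrieval flow introduced in this commit.
# All paths below are placeholders; point them at your own Anserini checkout,
# index, and topics/qrels files.
from searcher import Searcher
from utils import get_relevant_docids, get_query

anserini_path = "../Anserini"
index_path = "index/lucene-index.robust04.pos+docvectors+rawdocs"

qid2docid = get_relevant_docids("path/to/qrels.txt")
qid2text = get_query("path/to/topics.txt", collection="robust04")

docsearch = Searcher(anserini_path)
searcher = docsearch.build_searcher(k1=0.9, b=0.4, index_path=index_path, rm3=True)
docsearch.search_document(searcher, qid2docid, qid2text,
                          "data/datasets/robust04_rm3.tsv", "robust04", K=1000)
```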
42 changes: 24 additions & 18 deletions src/robust04_cv.py
@@ -1,46 +1,52 @@
from utils import *
from searcher import *
from shutil import copyfileobj
from args import get_args

if __name__ == '__main__':
args, _ = get_args()
collection = args.collection
anserini_path = args.anserini_path
index_path = args.index_path
output_fn = args.output_path
folds_path = os.path.join(anserini_path, 'src', 'main', 'resources', 'fine_tuning', args.folds_file)
cv_fold = args.cv_fold

fqrel = os.path.join(anserini_path, 'src', 'main', 'resources', 'topics-and-qrels', 'qrels.' + collection + '.txt')
ftopic = os.path.join(anserini_path, 'src', 'main', 'resources', 'topics-and-qrels', 'topics.' + collection + '.301-450.601-700.txt')

qid2docid = get_relevant_docids(fqrel)
qid2text = get_query(ftopic, collection='robust04')

with open(os.path.join(folds_path)) as f:
folds = json.load(f)
output_fn = os.path.join(args.data_path, 'datasets', "robust04_" + str(cv_fold) + "cv")

# TODO: dynamic params
if cv_fold == '5':
folds_file = "robust04-paper2-folds.json"
params = ["0.9 0.5 47 9 0.30",
"0.9 0.5 47 9 0.30",
"0.9 0.5 47 9 0.30",
"0.9 0.5 47 9 0.30",
"0.9 0.5 26 8 0.30"]
else:
folds_file = "robust04-paper1-folds.json"
params = ["0.9 0.5 50 17 0.20",
"0.9 0.5 26 8 0.30"]

folds_path = os.path.join(anserini_path, 'src', 'main', 'resources', 'fine_tuning', folds_file)

fqrel = os.path.join(anserini_path, 'src', 'main', 'resources', 'topics-and-qrels', 'qrels.robust2004.txt')
ftopic = os.path.join(anserini_path, 'src', 'main', 'resources', 'topics-and-qrels', 'topics.robust04.301-450.601-700.txt')

qid2docid = get_relevant_docids(fqrel)
qid2text = get_query(ftopic, collection='robust04')

with open(os.path.join(folds_path)) as f:
folds = json.load(f)

folder_idx = 1
docsearch = Searcher(anserini_path)
for topics, param in zip(folds, params):
print(folder_idx)
# Extract each parameter
k1, b, fb_terms, fb_docs, original_query_weight = map(float, param.strip().split())
searcher = build_searcher(k1=k1, b=b, fb_terms=fb_terms, fb_docs=fb_docs,
original_query_weight=original_query_weight,
index_path=index_path, rm3=True)
search_document(searcher, qid2docid, qid2text,
output_fn + str(folder_idx),
'robust04', 1000, topics)
searcher = docsearch.build_searcher(k1=k1, b=b, fb_terms=fb_terms, fb_docs=fb_docs,
original_query_weight=original_query_weight, index_path=index_path, rm3=True)
docsearch.search_document(searcher, qid2docid, qid2text, output_fn + str(folder_idx),
'robust04', 1000, topics)

folder_idx += 1

with open(output_fn + ".csv", 'w') as outfile:
for infile in [output_fn + str(n) for n in range(1, folder_idx)]:
copyfileobj(open(infile), outfile)
124 changes: 67 additions & 57 deletions src/searcher.py
@@ -3,69 +3,79 @@
from utils import parse_doc_from_index, clean_html, tokenizer, MAX_INPUT_LENGTH, chunk_sent

import jnius_config
# TODO: make path dynamic
jnius_config.set_classpath("../Anserini/target/anserini-0.4.1-SNAPSHOT-fatjar.jar")
import glob

try:
from jnius import autoclass
except KeyError:
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-8-oracle'
from jnius import autoclass
class Searcher:
def __init__(self, anserini_path):
paths = glob.glob(os.path.join(anserini_path, 'target', 'anserini-*-fatjar.jar'))
if not paths:
raise Exception("No matching jar file for Anserini found in target")

JString = autoclass('java.lang.String')
JSearcher = autoclass('io.anserini.search.SimpleSearcher')
latest = max(paths, key=os.path.getctime)
jnius_config.set_classpath(latest)

from jnius import autoclass
self.JString = autoclass('java.lang.String')
self.JSearcher = autoclass('io.anserini.search.SimpleSearcher')
self.qidx = 1
self.didx = 1

def build_searcher(k1=0.9, b=0.4, fb_terms=10, fb_docs=10, original_query_weight=0.5,
index_path="index/lucene-index.robust04.pos+docvectors+rawdocs", rm3=False):
searcher = JSearcher(JString(index_path))
searcher.setBM25Similarity(k1, b)
if not rm3:
searcher.setDefaultReranker()
else:
searcher.setRM3Reranker(fb_terms, fb_docs, original_query_weight, False)
return searcher
def reset_idx(self):
self.qidx = 1
self.didx = 1

def build_searcher(self, k1=0.9, b=0.4, fb_terms=10, fb_docs=10, original_query_weight=0.5,
index_path="index/lucene-index.robust04.pos+docvectors+rawdocs", rm3=False):
searcher = self.JSearcher(self.JString(index_path))
searcher.setBM25Similarity(k1, b)
if not rm3:
searcher.setDefaultReranker()
else:
searcher.setRM3Reranker(fb_terms, fb_docs, original_query_weight, False)
return searcher

def search_document(searcher, qid2docid, qid2text, output_fn, collection='robust04', K=1000, topics=None):
qidx, didx = 1, 1
with open(output_fn, 'w') as out:
if 'core' in collection:
# Robust04 provides CV topics
topics = qid2text
for qid in topics:
text = qid2text[qid]
hits = searcher.search(JString(text), K)
for i in range(len(hits)):
sim = hits[i].score
docno = hits[i].docid
label = 1 if qid in qid2docid and docno in qid2docid[qid] else 0
content = hits[i].content
if collection == 'core18':
content_json = json.loads(content)
content = ''
for each in content_json['contents']:
if each is not None and 'content' in each.keys():
content += '{}\n'.format(each['content'])
if collection == 'robust04':
content = parse_doc_from_index(content)
clean_content = clean_html(content, collection=collection)
tokenized_content = tokenizer.tokenize(clean_content)
sentid = 0
for sent in tokenized_content:
# Split sentence if it's longer than BERT's maximum input length
if len(sent.strip().split()) > MAX_INPUT_LENGTH:
seq_list = chunk_sent(sent, MAX_INPUT_LENGTH)
for seq in seq_list:

def search_document(self, searcher, qid2docid, qid2text, output_fn, collection='robust04', K=1000, topics=None):
output_dir = os.path.dirname(output_fn)
if not os.path.exists(output_dir):
os.makedirs(output_dir)
with open(output_fn, 'w', encoding="utf-8") as out:
if 'core' in collection:
# Core collections are not split into CV folds; only Robust04 passes fold-specific topics
topics = qid2text
for qid in topics:
text = qid2text[qid]
hits = searcher.search(self.JString(text), K)
for i in range(len(hits)):
sim = hits[i].score
docno = hits[i].docid
label = 1 if qid in qid2docid and docno in qid2docid[qid] else 0
content = hits[i].content
if collection == 'core18':
content_json = json.loads(content)
content = ''
for each in content_json['contents']:
if each is not None and 'content' in each.keys():
content += '{}\n'.format(each['content'])
if collection == 'robust04':
content = parse_doc_from_index(content)
clean_content = clean_html(content, collection=collection)
tokenized_content = tokenizer.tokenize(clean_content)
sentid = 0
for sent in tokenized_content:
# Split sentence if it's longer than BERT's maximum input length
if len(sent.strip().split()) > MAX_INPUT_LENGTH:
seq_list = chunk_sent(sent, MAX_INPUT_LENGTH)
for seq in seq_list:
sentno = docno + '_' + str(sentid)
out.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(label, round(float(sim), 16), text, seq, qid, sentno, self.qidx, self.didx))
out.flush()
sentid += 1
self.didx += 1
else:
sentno = docno + '_' + str(sentid)
out.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(label, round(float(sim), 11), text, seq, qid, sentno, qidx, didx))
out.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(label, round(float(sim), 16), text, sent, qid, sentno, self.qidx, self.didx))
out.flush()
sentid += 1
didx += 1
else:
sentno = docno + '_' + str(sentid)
out.write('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n'.format(label, round(float(sim), 11), text, sent, qid, sentno, qidx, didx))
out.flush()
sentid += 1
didx += 1
qidx += 1
self.didx += 1
self.qidx += 1
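
Unlike the old module, which hard-coded the fatjar on the classpath, the constructor above discovers the newest `anserini-*-fatjar.jar` under `<anserini_path>/target` and raises if none is present, so the jar must be built before a `Searcher` is created. A small sketch of that lookup, useful for checking a setup ahead of time (the Anserini path is a placeholder):

```
# Reproduces the jar lookup performed in Searcher.__init__.
# anserini_path is a placeholder for your own Anserini checkout.
import glob
import os

anserini_path = "../Anserini"
jars = glob.glob(os.path.join(anserini_path, "target", "anserini-*-fatjar.jar"))
if jars:
    print("Searcher would load:", max(jars, key=os.path.getctime))
else:
    print("No fatjar found; build Anserini before constructing Searcher")
```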
