Merge branch 'master' of https://github.com/castorini/birch
zeynepakkalyoncu committed Jun 13, 2019
2 parents 9f0c6ea + 82dfb06 commit ac54b63
Showing 6 changed files with 51 additions and 14 deletions.
18 changes: 16 additions & 2 deletions README.md
@@ -33,14 +33,22 @@ tar -xzvf birch_data.tar.gz
## Inference

```
python src/main.py --experiment <qa_2cv, mb_2cv, qa_5cv, mb_5cv> --collection <robust04_2cv.csv, robust04_5cv.csv> --inference --model_path <models/saved.mb_3, models/saved.qa_2> --load_trained
python src/main.py --experiment <qa_2cv, mb_2cv, qa_5cv, mb_5cv> --data_path data --collection <robust04_2cv, robust04_5cv> --inference --model_path <models/saved.mb_3, models/saved.qa_2> --load_trained --batch_size <batch_size>
```

Note that this step takes a long time.
If you don't want to evaluate the pretrained models, you may skip to the next step and evaluate with our predictions under `data/predictions`.

## Evaluation

### BM25+RM3:

```
./eval_scripts/baseline.sh <path/to/anserini> <path/to/index> <2, 5>
```

### Sentence Evidence:

- Compute document score

Set the last argument to True if you want to tune the hyperparameters first.
@@ -52,7 +60,9 @@ To use the default hyperparameters, set to False.

- Evaluate with trec_eval

```./eval_scripts/eval.sh <qa_2cv, mb_2cv, qa_5cv, mb_5cv> <path/to/anserini> qrels.robust2004.txt```
```
./eval_scripts/eval.sh <bm25+rm3_2cv, qa_2cv, mb_2cv, bm25+rm3_5cv, qa_5cv, mb_5cv> <path/to/anserini> qrels.robust2004.txt
```


---
@@ -87,6 +97,10 @@ To use the default hyperparameters, set to False.

See this [paper](https://dl.acm.org/citation.cfm?id=3308781) for the exact fold settings.

### Replication Log

+ Results replicated by [@emmileaf](https://github.com/emmileaf) on 2019-06-10 (commit [`cc42b60`](https://github.com/castorini/birch/commit/cc42b60093090969c1d9b24cddd1257c1cad66df))

---

**How do I cite this work?**
20 changes: 20 additions & 0 deletions eval_scripts/baseline.sh
@@ -0,0 +1,20 @@
#!/usr/bin/env bash

anserini_path=$1
index_path=$2
num_folds=$3

birch_path=$(pwd)
cd ${anserini_path}

if [ ${num_folds} == '5' ] ; then
folds_path="src/main/resources/fine_tuning/robust04-paper2-folds.json"
params_path="src/main/resources/fine_tuning/robust04-paper2-folds-map-params.json"
else
folds_path="src/main/resources/fine_tuning/robust04-paper1-folds.json"
params_path="src/main/resources/fine_tuning/robust04-paper1-folds-map-params.json"
fi

python3 src/main/python/fine_tuning/reconstruct_robus04_tuned_run.py --index ${index_path} --folds ${folds_path} --params ${params_path}
rm run.robust04.bm25+rm3.fold*
mv run.robust04.bm25+rm3.txt ${birch_path}/runs/run.bm25+rm3_${num_folds}cv.txt
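
The fold selection in `baseline.sh` can be restated in Python for readers who want to check which Anserini resource files a given fold count resolves to. This is only a sketch of the shell branch shown above; the function name `tuning_files` and the constant `RESOURCES` are hypothetical and not part of birch or Anserini.

```python
# Hypothetical helper mirroring the if/else in eval_scripts/baseline.sh.
RESOURCES = "src/main/resources/fine_tuning"


def tuning_files(num_folds):
    """Return (folds_path, params_path) relative to the Anserini checkout.

    As in the script, only the value 5 selects the paper2 files;
    every other value falls through to the paper1 files.
    """
    paper = "paper2" if str(num_folds) == "5" else "paper1"
    folds_path = f"{RESOURCES}/robust04-{paper}-folds.json"
    params_path = f"{RESOURCES}/robust04-{paper}-folds-map-params.json"
    return folds_path, params_path
```

Note that, like the shell version, this treats any argument other than 5 (including a typo) as the 2-fold setting rather than rejecting it.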
17 changes: 11 additions & 6 deletions eval_scripts/eval.sh
@@ -4,11 +4,16 @@ qrels_file=$3

echo "Experiment: ${experiment}"

echo "1S:"
${anserini_path}/eval/trec_eval.9.0.4/trec_eval -M1000 -m map -m P.20 "${anserini_path}/src/main/resources/topics-and-qrels/${qrels_file}" "runs/run.${experiment}.cv.a"
if [[ ${experiment} == *"bm25+rm3"* ]] ; then
echo "BM25+RM3:"
${anserini_path}/eval/trec_eval.9.0.4/trec_eval -M1000 -m map -m P.20 "${anserini_path}/src/main/resources/topics-and-qrels/${qrels_file}" "runs/run.${experiment}.txt"
else
echo "1S:"
${anserini_path}/eval/trec_eval.9.0.4/trec_eval -M1000 -m map -m P.20 "${anserini_path}/src/main/resources/topics-and-qrels/${qrels_file}" "runs/run.${experiment}.cv.a"

echo "2S:"
${anserini_path}/eval/trec_eval.9.0.4/trec_eval -M1000 -m map -m P.20 "${anserini_path}/src/main/resources/topics-and-qrels/${qrels_file}" "runs/run.${experiment}.cv.ab"
echo "2S:"
${anserini_path}/eval/trec_eval.9.0.4/trec_eval -M1000 -m map -m P.20 "${anserini_path}/src/main/resources/topics-and-qrels/${qrels_file}" "runs/run.${experiment}.cv.ab"

echo "SS:"
${anserini_path}/eval/trec_eval.9.0.4/trec_eval -M1000 -m map -m P.20 "${anserini_path}/src/main/resources/topics-and-qrels/${qrels_file}" "runs/run.${experiment}.cv.abc"
echo "SS:"
${anserini_path}/eval/trec_eval.9.0.4/trec_eval -M1000 -m map -m P.20 "${anserini_path}/src/main/resources/topics-and-qrels/${qrels_file}" "runs/run.${experiment}.cv.abc"
fi
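
The new branch in `eval.sh` decides which run files are passed to `trec_eval` based on the experiment name. A minimal Python sketch of that mapping follows; the function name `run_files` is hypothetical, while the labels and file suffixes are copied from the script above.

```python
# Hypothetical helper mirroring the experiment check in eval_scripts/eval.sh.
def run_files(experiment):
    """Return (label, run_file) pairs that eval.sh would score for an experiment."""
    if "bm25+rm3" in experiment:
        # The baseline has a single run file with a .txt suffix.
        return [("BM25+RM3", f"runs/run.{experiment}.txt")]
    # Sentence-evidence experiments have one run file per setting.
    return [
        ("1S", f"runs/run.{experiment}.cv.a"),
        ("2S", f"runs/run.{experiment}.cv.ab"),
        ("SS", f"runs/run.{experiment}.cv.abc"),
    ]
```

For example, `run_files("bm25+rm3_5cv")` yields only `runs/run.bm25+rm3_5cv.txt`, which matches the file name that `baseline.sh` writes.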
3 changes: 1 addition & 2 deletions src/args.py
@@ -23,8 +23,7 @@ def get_args():
parser.add_argument('--inference', action='store_true', default=False,
help='Evaluate model if True, use prediction files otherwise')
parser.add_argument('--model_path', default='models/saved.tmp', help='Path to pretrained model')
parser.add_argument('--batch_size', default=8, type=int,
help='[1, 8, 16, 32]')
parser.add_argument('--batch_size', default=64, type=int)
parser.add_argument('--local_model', default=None,
help='[None, path to local model file]')
parser.add_argument('--local_tokenizer', default=None,
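
To see the effect of the new `--batch_size` default in isolation, the argument can be reproduced in a standalone parser. This is a reduced sketch that assumes nothing about the rest of `get_args()`; `build_parser` is a hypothetical name.

```python
import argparse


def build_parser():
    """Standalone parser containing only the --batch_size argument from this commit."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', default=64, type=int)
    return parser


# With no flag, the new default applies.
print(build_parser().parse_args([]).batch_size)  # 64
# An explicit value is converted to int by type=int.
print(build_parser().parse_args(['--batch_size', '8']).batch_size)  # 8
```

Since the help text listing `[1, 8, 16, 32]` was removed, the parser accepts any integer value.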
5 changes: 2 additions & 3 deletions src/inference.py
@@ -6,7 +6,7 @@
from data import load_data


def test(args, predictions_path, model=None, tokenizer=None):
def test(args, datasets_path, predictions_path, model=None, tokenizer=None):
if model is None:
if args.load_trained:
epoch, arch, model, tokenizer, scores = load_checkpoint(
@@ -16,8 +16,7 @@ def test(args, predictions_path, model=None, tokenizer=None):
model, tokenizer = load_pretrained_model_tokenizer(base_model=args.local_model,
base_tokenizer=args.local_tokenizer)

test_dataset = load_data(args.data_path, args.collection,
args.batch_size, tokenizer)
test_dataset = load_data(datasets_path, args.collection, args.batch_size, tokenizer)

model.eval()
prediction_score_list, prediction_index_list, labels = [], [], []
2 changes: 1 addition & 1 deletion src/main.py
@@ -24,7 +24,7 @@ def main():
datasets_path = os.path.join(args.data_path, 'datasets')

if inference:
scores = test(args, predictions_path)
scores = test(args, datasets_path, predictions_path)
print_scores(scores)
else:
folds_path = os.path.join(anserini_path, 'src', 'main', 'resources', 'fine_tuning', args.folds_file)