Skip to content

Commit

Permalink
Add reproduction script for "End-to-End Retrieval with Learned Dense …
Browse files Browse the repository at this point in the history
…and Sparse Representations Using Lucene" (#2317)
  • Loading branch information
ArthurChen189 authored Dec 29, 2023
1 parent 2ebc11c commit 1ebe6dd
Show file tree
Hide file tree
Showing 5 changed files with 211 additions and 41 deletions.
56 changes: 17 additions & 39 deletions src/main/java/io/anserini/index/IndexInfo.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,65 +20,43 @@ public enum IndexInfo {
MSMARCO_V1_PASSAGE("msmarco-v1-passage",
"Lucene index of the MS MARCO V1 passage corpus. (Lucene 9)",
"lucene-index.msmarco-v1-passage.20221004.252b5e.tar.gz",
"lucene-index.msmarco-v1-passage.20221004.252b5e.README.md",
new String[] {
"https://rgw.cs.uwaterloo.ca/pyserini/indexes/lucene-index.msmarco-v1-passage.20221004.252b5e.tar.gz" },
"c697b18c9a0686ca760583e615dbe450", "2170758938", "352316036", "8841823",
"2660824", false),
"c697b18c9a0686ca760583e615dbe450"),

CACM("cacm",
"Lucene index of the CACM corpus. (Lucene 9)",
"lucene-index.cacm.tar.gz",
new String[] {
"https://github.com/castorini/anserini-data/raw/master/CACM/lucene-index.cacm.20221005.252b5e.tar.gz" },
"cfe14d543c6a27f4d742fb2d0099b8e0",
"2347197",
"320968",
"3204",
"14363");
"cfe14d543c6a27f4d742fb2d0099b8e0"),

MSMARCO_V1_PASSAGE_COS_DPR_DISTIL("msmarco-v1-passage-cos-dpr-distil",
"Lucene index of the MS MARCO V1 passage corpus encoded by cos-DPR Distil. (Lucene 9)",
"lucene-hnsw.msmarco-v1-passage-cos-dpr-distil.20231124.9d3427.tar.gz",
new String[] {
"https://rgw.cs.uwaterloo.ca/pyserini/indexes/lucene-hnsw.msmarco-v1-passage-cos-dpr-distil.20231124.9d3427.tar.gz" },
"7aa825e292a411abbe1585fb4d9f20ee"),

MSMARCO_V1_PASSAGE_SPLADE_PP_ED("msmarco-v1-passage-splade-pp-ed",
"Lucene impact index of the MS MARCO passage corpus encoded by SPLADE++ CoCondenser-EnsembleDistil. (Lucene 9)",
"lucene-index.msmarco-v1-passage-splade-pp-ed.20230524.a59610.tar.gz",
new String[] {
"https://rgw.cs.uwaterloo.ca/pyserini/indexes/lucene-index.msmarco-v1-passage-splade-pp-ed.20230524.a59610.tar.gz" },
"4b3c969033cbd017306df42ce134c395");

public final String indexName;
public final String description;
public final String filename;
public final String readme;
public final String[] urls;
public final String md5;
public final String size;
public final String totalTerms;
public final String totalDocs;
public final String totalUniqueTerms;
public final boolean downloaded;

// constructor with all 11 fields
IndexInfo(String indexName, String description, String filename, String readme, String[] urls, String md5,
String size, String totalTerms, String totalDocs, String totalUniqueTerms, boolean downloaded) {
this.indexName = indexName;
this.description = description;
this.filename = filename;
this.readme = readme;
this.urls = urls;
this.md5 = md5;
this.size = size;
this.totalTerms = totalTerms;
this.totalDocs = totalDocs;
this.totalUniqueTerms = totalUniqueTerms;
this.downloaded = downloaded;
}

// constructor with 9 fields
IndexInfo(String indexName, String description, String filename, String[] urls, String md5, String size,
String totalTerms, String totalDocs, String totalUniqueTerms) {
IndexInfo(String indexName, String description, String filename, String[] urls, String md5) {
this.indexName = indexName;
this.description = description;
this.filename = filename;
this.readme = "";
this.urls = urls;
this.md5 = md5;
this.size = size;
this.totalTerms = totalTerms;
this.totalDocs = totalDocs;
this.totalUniqueTerms = totalUniqueTerms;
this.downloaded = false;
}

public static boolean contains(String indexName) {
Expand Down
24 changes: 23 additions & 1 deletion src/main/java/io/anserini/search/HnswDenseSearcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import io.anserini.encoder.dense.DenseEncoder;
import io.anserini.index.Constants;
import io.anserini.search.query.VectorQueryGenerator;
import io.anserini.util.PrebuiltIndexHandler;

import org.apache.commons.lang3.time.DurationFormatUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
Expand All @@ -36,6 +38,8 @@
import javax.annotation.Nullable;
import java.io.Closeable;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.SortedMap;
Expand Down Expand Up @@ -79,8 +83,26 @@ public HnswDenseSearcher(Args args) {
// We might not be able to successfully create a reader for a variety of reasons, anything from path doesn't exist
// to corrupt index. Gather all possible exceptions together as an unchecked exception to make initialization and
// error reporting clearer.
Path indexPath = Path.of(args.index);
PrebuiltIndexHandler indexHandler = new PrebuiltIndexHandler(args.index);
if (!Files.exists(indexPath)) {
// it doesn't exist locally, we try to download it from remote
try {
indexHandler.initialize();
indexHandler.download();
indexPath = Path.of(indexHandler.decompressIndex());
} catch (IOException e) {
throw new RuntimeException("MD5 checksum does not match!");
} catch (Exception e) {
throw new IllegalArgumentException(String.format("\"%s\" does not appear to be a valid index.", args.index));
}
} else {
// if it exists locally, we use it
indexPath = Paths.get(args.index);
}

try {
this.reader = DirectoryReader.open(FSDirectory.open(Paths.get(args.index)));
this.reader = DirectoryReader.open(FSDirectory.open(indexPath));
} catch (IOException e) {
throw new IllegalArgumentException(String.format("\"%s\" does not appear to be a valid index.", args.index));
}
Expand Down
92 changes: 92 additions & 0 deletions src/main/python/e2e_sparse_dense_lucene/reproduction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Anserini: A toolkit for reproducible information retrieval research built on Lucene
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import yaml
from typing import Union, Dict, List, Optional, Any
import os
import subprocess
TOPIC_NAMES = ['msmarco-passage-dev-subset', 'dl19-passage', 'dl20-passage']
EVAL_CMD_MAP = {
'map': '-m map -c -l 2', # AP
'ndcg_cut_10': '-m ndcg_cut.10 -c', # nDCG@10
'recall_1000_msmarco': '-c -m recall.1000', # R@1000 for MS MARCO
'recall_1000': '-m recall.1000 -c -l 2', # R@1000
'recip_rank': '-c -M 10 -m recip_rank' # RR@10
}
TOPIC_EVAL_MAP = {
'msmarco-passage-dev-subset': ['recip_rank', 'recall_1000_msmarco'],
'dl19-passage': ['map', 'ndcg_cut_10', 'recall_1000'],
'dl20-passage': ['map', 'ndcg_cut_10', 'recall_1000']
}


def get_output_run_file_name(topic: str, name: str):
return f'runs/{topic}_{name}.txt'


def get_search_command(model_name: str, cmd_template: str, topics: List[str]):
outputs = [get_output_run_file_name(
topic_name, model_name) for topic_name in TOPIC_NAMES]

for topic, output in zip(topics, outputs):
cmd = cmd_template.format(topic=topic, output=output)
yield cmd


def get_eval_command(param: str, qrel: str, run_file: str, cmd_template: str):
cmd = cmd_template.format(
param=param, qrel=qrel, output=run_file)
yield cmd


def main(config):
# print all search commands
for model_name, model_config in config['collections'].items():
print("running model: ", model_name)
# # search
# for cmd in get_search_command(model_name, model_config['search_command'], model_config['topics']):
# p = subprocess.Popen(
# cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
# stdout, stderr = p.communicate()
# if stderr:
# print(stderr.decode('utf-8'))

# eval
expected_results = model_config['results']
run_files = [get_output_run_file_name(
topic_name, model_name) for topic_name in TOPIC_NAMES]
eval_cmd = model_config['eval_command']
metric_precision = model_config['metric_precision']

for run_file, topic_name, qrel in zip(run_files, TOPIC_NAMES, model_config['qrels']):
for metric in TOPIC_EVAL_MAP[topic_name]:
for cmd in get_eval_command(EVAL_CMD_MAP[metric], qrel, run_file, eval_cmd):
p = subprocess.Popen(
cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = p.communicate()
stdout = [out.strip()
for out in stdout.decode('utf-8').split('\t')]
actual_result = round(float(stdout[-1]), metric_precision)
expected_result = expected_results[topic_name][metric]
assert actual_result == expected_result, f'{model_name} {topic_name} {metric} {actual_result} != {expected_result}, expected: {expected_results[topic_name]}'
print(
f"{topic_name} {metric} {actual_result} == {expected_result}")
print(f"{model_name} passed!")
print("="*50)


if __name__ == '__main__':
with open('src/main/resources/e2e_sparse_dense_lucene/pre-encoded.yaml') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
main(config)
78 changes: 78 additions & 0 deletions src/main/resources/e2e_sparse_dense_lucene/pre-encoded.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
---
collections:
bm25:
name: bm25
search_command: target/appassembler/bin/SearchCollection -index msmarco-v1-passage -topicReader TsvInt -topics {topic} -output {output} -bm25 -parallelism 12
topics:
- tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt
- tools/topics-and-qrels/topics.dl19-passage.txt
- tools/topics-and-qrels/topics.dl20.txt
qrels:
- tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt
- tools/topics-and-qrels/qrels.dl19-passage.txt
- tools/topics-and-qrels/qrels.dl20-passage.txt

eval_command: tools/eval/trec_eval.9.0.4/trec_eval {param} {qrel} {output}
results:
msmarco-passage-dev-subset:
recip_rank: 0.184
recall_1000_msmarco: 0.853
dl19-passage:
map: 0.301
ndcg_cut_10: 0.506
recall_1000: 0.750
dl20-passage:
map: 0.286
ndcg_cut_10: 0.480
recall_1000: 0.786
metric_precision: 3
cosdpr-distil:
name: cosdpr-distil
search_command: target/appassembler/bin/SearchHnswDenseVectors -index msmarco-v1-passage-cos-dpr-distil -topicReader TsvInt -topics {topic} -output {output} -generator VectorQueryGenerator -topicField title -threads 12 -hits 1000 -efSearch 1000 -encoder CosDprDistil
topics:
- tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt
- tools/topics-and-qrels/topics.dl19-passage.txt
- tools/topics-and-qrels/topics.dl20.txt
qrels:
- tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt
- tools/topics-and-qrels/qrels.dl19-passage.txt
- tools/topics-and-qrels/qrels.dl20-passage.txt
eval_command: tools/eval/trec_eval.9.0.4/trec_eval {param} {qrel} {output}
results:
msmarco-passage-dev-subset:
recip_rank: 0.389
recall_1000_msmarco: 0.975
dl19-passage:
map: 0.466
ndcg_cut_10: 0.725
recall_1000: 0.822
dl20-passage:
map: 0.487
ndcg_cut_10: 0.703
recall_1000: 0.852
metric_precision: 3
splade-pp-ed:
name: splade-pp-ed
search_command: target/appassembler/bin/SearchCollection -index msmarco-v1-passage-splade-pp-ed -topicReader TsvInt -topics {topic} -output {output} -impact -pretokenized -parallelism 12 -encoder SpladePlusPlusEnsembleDistil
topics:
- tools/topics-and-qrels/topics.msmarco-passage.dev-subset.txt
- tools/topics-and-qrels/topics.dl19-passage.txt
- tools/topics-and-qrels/topics.dl20.txt
qrels:
- tools/topics-and-qrels/qrels.msmarco-passage.dev-subset.txt
- tools/topics-and-qrels/qrels.dl19-passage.txt
- tools/topics-and-qrels/qrels.dl20-passage.txt
eval_command: tools/eval/trec_eval.9.0.4/trec_eval {param} {qrel} {output}
results:
msmarco-passage-dev-subset:
recip_rank: 0.383
recall_1000_msmarco: 0.983
dl19-passage:
map: 0.505
ndcg_cut_10: 0.731
recall_1000: 0.873
dl20-passage:
map: 0.500
ndcg_cut_10: 0.720
recall_1000: 0.900
metric_precision: 3
2 changes: 1 addition & 1 deletion src/test/java/io/anserini/index/PrebuiltIndexTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,6 @@ public void testUrls() {
// test number of prebuilt-indexes
@Test
public void testNumPrebuiltIndexes() {
assert IndexInfo.values().length == 2;
assert IndexInfo.values().length == 4;
}
}

0 comments on commit 1ebe6dd

Please sign in to comment.