Skip to content

Commit

Permalink
Add regression for HNSW in Lucene using cos DPR distil (castorini#2106)
Browse files Browse the repository at this point in the history
  • Loading branch information
lintool authored Apr 20, 2023
1 parent 84d8407 commit 910821a
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 26 deletions.
8 changes: 2 additions & 6 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,15 @@
</program>
<program>
<mainClass>io.anserini.index.IndexHnswDenseVectors</mainClass>
<id>IndexDenseVectors</id>
<id>IndexHnswDenseVectors</id>
</program>
<program>
<mainClass>io.anserini.search.SearchCollection</mainClass>
<id>SearchCollection</id>
</program>
<program>
<mainClass>io.anserini.search.SearchHnswDenseVectors</mainClass>
<id>SearchDenseVectors</id>
</program>
<program>
<mainClass>io.anserini.search.SearchMsmarco</mainClass>
<id>SearchMsmarco</id>
<id>SearchHnswDenseVectors</id>
</program>
<program>
<mainClass>io.anserini.search.SimpleSearcher</mainClass>
Expand Down
64 changes: 45 additions & 19 deletions src/main/python/run_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,12 @@
]

INDEX_COMMAND = 'target/appassembler/bin/IndexCollection'
INDEX_HNSW_COMMAND = 'target/appassembler/bin/IndexHnswDenseVectors'

INDEX_STATS_COMMAND = 'target/appassembler/bin/IndexReaderUtils'

SEARCH_COMMAND = 'target/appassembler/bin/SearchCollection'
SEARCH_HNSW_COMMAND = 'target/appassembler/bin/SearchHnswDenseVectors'


def is_close(a, b, rel_tol=1e-09, abs_tol=0.0):
Expand Down Expand Up @@ -111,8 +115,13 @@ def construct_indexing_command(yaml_data, args):
if not os.path.exists('indexes'):
os.makedirs('indexes')

if yaml_data['collection_class'] == 'JsonDenseVectorCollection':
root_cmd = INDEX_HNSW_COMMAND
else:
root_cmd = INDEX_COMMAND

index_command = [
INDEX_COMMAND,
root_cmd,
'-collection', yaml_data['collection_class'],
'-generator', yaml_data['generator_class'],
'-threads', str(threads),
Expand All @@ -131,7 +140,7 @@ def construct_runfile_path(corpus, id, model_name):
def construct_search_commands(yaml_data):
ranking_commands = [
[
SEARCH_COMMAND,
SEARCH_HNSW_COMMAND if 'VectorQueryGenerator' in model['params'] else SEARCH_COMMAND,
'-index', construct_index_path(yaml_data),
'-topics', os.path.join('tools/topics-and-qrels', topic_set['path']),
'-topicreader', topic_set['topic_reader'] if 'topic_reader' in topic_set and topic_set['topic_reader'] else yaml_data['topic_reader'],
Expand All @@ -142,6 +151,7 @@ def construct_search_commands(yaml_data):
]
return ranking_commands


def construct_convert_commands(yaml_data):
converting_commands = [
[
Expand All @@ -157,6 +167,7 @@ def construct_convert_commands(yaml_data):
]
return converting_commands


def evaluate_and_verify(yaml_data, dry_run):
fail_str = '\033[91m[FAIL]\033[0m '
ok_str = ' [OK] '
Expand All @@ -183,9 +194,19 @@ def evaluate_and_verify(yaml_data, dry_run):
eval_out = out.strip().split(metric['separator'])[metric['parse_index']]
expected = round(model['results'][metric['metric']][i], metric['metric_precision'])
actual = round(float(eval_out), metric['metric_precision'])
result_str = 'expected: {0:.4f} actual: {1:.4f} - metric: {2:<8} model: {3} topics: {4}'.format(
expected, actual, metric['metric'], model['name'], topic_set['id'])
if is_close(expected, actual):

# For HNSW, we only print to third digit
if 'VectorQueryGenerator' in model['params']:
result_str = 'expected: {0:.3f} actual: {1:.3f} - metric: {2:<8} model: {3} topics: {4}'.format(
expected, actual, metric['metric'], model['name'], topic_set['id'])
else:
result_str = 'expected: {0:.4f} actual: {1:.4f} - metric: {2:<8} model: {3} topics: {4}'.format(
expected, actual, metric['metric'], model['name'], topic_set['id'])

# For inverted indexes, we expect scores to match precisely.
# For HNSW, be more tolerant.
if is_close(expected, actual) or \
('VectorQueryGenerator' in model['params'] and is_close(expected, actual, abs_tol=0.006)):
logger.info(ok_str + result_str)
else:
if args.lucene8 and is_close_lucene8(expected, actual):
Expand All @@ -207,10 +228,12 @@ def run_search(cmd):
logger.info(' '.join(cmd))
call(' '.join(cmd), shell=True)


def run_convert(cmd):
logger.info(' '.join(cmd))
call(' '.join(cmd), shell=True)


# https://gist.github.com/leimao/37ff6e990b3226c2c9670a2cd1e4a6f5
class TqdmUpTo(tqdm):
def update_to(self, b=1, bsize=1, tsize=None):
Expand Down Expand Up @@ -336,20 +359,23 @@ def download_url(url, save_dir, local_filename=None, md5=None, force=False, verb
# Verify index statistics.
if args.verify:
logger.info('='*10 + ' Verifying Index ' + '='*10)
index_utils_command = [INDEX_STATS_COMMAND, '-index', construct_index_path(yaml_data), '-stats']
verification_command = ' '.join(index_utils_command)
logger.info(verification_command)
if not args.dry_run:
out = check_output(' '.join(index_utils_command)).decode('utf-8').split('\n')
for line in out:
stat = line.split(':')[0]
if stat in yaml_data['index_stats']:
value = int(line.split(':')[1])
if value != yaml_data['index_stats'][stat]:
print('{}: expected={}, actual={}'.format(stat, yaml_data['index_stats'][stat], value))
assert value == yaml_data['index_stats'][stat]
logger.info(line)
logger.info('Index statistics successfully verified!')
if yaml_data['collection_class'] == 'JsonDenseVectorCollection':
logger.info('Skipping verification step for HNSW dense indexes.')
else:
index_utils_command = [INDEX_STATS_COMMAND, '-index', construct_index_path(yaml_data), '-stats']
verification_command = ' '.join(index_utils_command)
logger.info(verification_command)
if not args.dry_run:
out = check_output(' '.join(index_utils_command)).decode('utf-8').split('\n')
for line in out:
stat = line.split(':')[0]
if stat in yaml_data['index_stats']:
value = int(line.split(':')[1])
if value != yaml_data['index_stats'][stat]:
print('{}: expected={}, actual={}'.format(stat, yaml_data['index_stats'][stat], value))
assert value == yaml_data['index_stats'][stat]
logger.info(line)
logger.info('Index statistics successfully verified!')

# Search and verify results.
if args.search:
Expand Down
60 changes: 60 additions & 0 deletions src/main/resources/regression/msmarco-passage-cos-dpr-distil.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
---
corpus: msmarco-passage
corpus_path: collections/msmarco/msmarco-passage-cos-dpr-distil/

index_path: indexes/lucene-hnsw.msmarco-passage-cos-dpr-distil/
collection_class: JsonDenseVectorCollection
generator_class: LuceneDenseVectorDocumentGenerator
index_threads: 16
index_options: -M 16 -efC 100

metrics:
- metric: AP@1000
command: tools/eval/trec_eval.9.0.4/trec_eval
params: -c -m map
separator: "\t"
parse_index: 2
metric_precision: 4
can_combine: false
- metric: RR@10
command: tools/eval/trec_eval.9.0.4/trec_eval
params: -c -M 10 -m recip_rank
separator: "\t"
parse_index: 2
metric_precision: 4
can_combine: false
- metric: R@100
command: tools/eval/trec_eval.9.0.4/trec_eval
params: -c -m recall.100
separator: "\t"
parse_index: 2
metric_precision: 4
can_combine: false
- metric: R@1000
command: tools/eval/trec_eval.9.0.4/trec_eval
params: -c -m recall.1000
separator: "\t"
parse_index: 2
metric_precision: 4
can_combine: false

topic_reader: JsonIntVector
topics:
- name: "[MS MARCO Passage: Dev](https://github.com/microsoft/MSMARCO-Passage-Ranking)"
id: dev
path: topics.msmarco-passage.dev-subset.cos-dpr-distil.jsonl.gz
qrel: qrels.msmarco-passage.dev-subset.txt

models:
- name: cos-dpr-distil
display: cosDPR-distil
params: -querygenerator VectorQueryGenerator -topicfield vector -threads 16 -hits 1000 -efSearch 1000
results:
AP@1000:
- 0.392
RR@10:
- 0.387
R@100:
- 0.900
R@1000:
- 0.970
2 changes: 1 addition & 1 deletion tools

0 comments on commit 910821a

Please sign in to comment.