Skip to content

Commit

Permalink
add time counter (#665)
Browse files Browse the repository at this point in the history
  • Loading branch information
MXueguang authored Jun 24, 2021
1 parent 40e2c5d commit d31e2e6
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions pyserini/vsearch/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import argparse
import json
import time
from tqdm import tqdm

from pyserini.vsearch import SimpleVectorSearcher
Expand Down Expand Up @@ -60,20 +61,25 @@
# support trec and msmarco format only for now
output_writer = get_output_writer(output_path, OutputFormat(args.output_format), max_hits=args.hits, tag=tag)

search_time = 0
with output_writer:
batch_topic_vectors = list()
batch_topic_ids = list()
for index, (topic_id, vec) in enumerate(tqdm(zip(topic_ids, topic_vectors))):
if args.batch_size <= 1 and args.threads <= 1:
start = time.time()
hits = searcher.search(vec, args.hits)
search_time += time.time() - start
results = [(topic_id, hits)]
else:
batch_topic_ids.append(str(topic_id))
batch_topic_vectors.append(vec)
if (index + 1) % args.batch_size == 0 or \
index == len(topic_ids) - 1:
start = time.time()
results = searcher.batch_search(
batch_topic_vectors, batch_topic_ids, args.hits, args.threads)
search_time += time.time() - start
results = [(id_, results[id_]) for id_ in batch_topic_ids]
batch_topic_ids.clear()
batch_topic_vectors.clear()
Expand All @@ -84,3 +90,5 @@
output_writer.write(topic, tie_breaker(hits))

results.clear()

print(f'Search {len(topic_ids)} topics in {search_time} seconds')

0 comments on commit d31e2e6

Please sign in to comment.