Skip to content

Commit

Permalink
Add remove_dups argument for dense retrieval (#1663)
Browse files Browse the repository at this point in the history
  • Loading branch information
yogeswarl authored Sep 30, 2023
1 parent 902ac1f commit bdb9504
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion pyserini/search/faiss/_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def list_prebuilt_indexes():
"""Display information about available prebuilt indexes."""
get_dense_indexes_info()

def search(self, query: Union[str, np.ndarray], k: int = 10, threads: int = 1, return_vector: bool = False) \
def search(self, query: Union[str, np.ndarray], k: int = 10, threads: int = 1, remove_dups: bool = False, return_vector: bool = False) \
-> Union[List[DenseSearchResult], Tuple[np.ndarray, List[PRFDenseSearchResult]]]:
"""Search the collection.
Expand All @@ -451,6 +451,8 @@ def search(self, query: Union[str, np.ndarray], k: int = 10, threads: int = 1, r
Number of hits to return.
threads : int
Maximum number of threads to use for intra-query search.
remove_dups : bool
Remove duplicate hits from the returned results.
return_vector : bool
Return the results with vectors
Returns
Expand All @@ -477,6 +479,14 @@ def search(self, query: Union[str, np.ndarray], k: int = 10, threads: int = 1, r
distances, indexes = self.index.search(emb_q, k)
distances = distances.flat
indexes = indexes.flat
if remove_dups:
unique_docs = set()
results = list()
for score, idx in zip(distances, indexes):
if idx not in unique_docs:
unique_docs.add(idx)
results.append(DenseSearchResult(self.docids[idx],score))
return results
return [DenseSearchResult(self.docids[idx], score)
for score, idx in zip(distances, indexes) if idx != -1]

Expand Down

0 comments on commit bdb9504

Please sign in to comment.