
Bpr replication (#668)
* add bpr replication results

* update bpr readme
alexlimh authored Jun 24, 2021
1 parent 9a64568 commit 40e2c5d
Showing 5 changed files with 325 additions and 13 deletions.
49 changes: 49 additions & 0 deletions docs/experiments-bpr.md
@@ -0,0 +1,49 @@
# Pyserini: Reproducing BPR Results

[Binary passage retriever](https://arxiv.org/abs/2106.00882) (BPR) is a two-stage ranking approach: candidate passages are first retrieved over compact binary codes and then reranked against the dense query vector, keeping the index memory-efficient while preserving most of the effectiveness of dense retrieval.
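
Concretely, the first stage scores all passages by the inner product between the binarized query and the stored binary codes (equivalent to Hamming distance up to a constant), and the second stage rescores only the surviving candidates against the dense query embedding. Below is a minimal NumPy sketch of this two-stage idea on toy random data; it is illustrative only, not Pyserini code, and all names are made up:

```python
import numpy as np

rng = np.random.default_rng(0)
dim, num_passages, binary_k, k = 768, 10_000, 1000, 100

# The index stores only binary codes: each dimension hashed to +1 or -1.
passage_codes = np.where(rng.standard_normal((num_passages, dim)) >= 0, 1.0, -1.0)

# A query keeps both its dense embedding and the binarized version.
dense_q = rng.standard_normal(dim)
binary_q = np.where(dense_q >= 0, 1.0, -1.0)

# Stage 1: cheap candidate generation over binary codes.
candidates = np.argsort(-(passage_codes @ binary_q))[:binary_k]

# Stage 2: rerank candidates with the *dense* query embedding against the
# same binary passage codes, recovering most of the dense-retrieval accuracy.
scores = passage_codes[candidates] @ dense_q
top_k = candidates[np.argsort(-scores)[:k]]
print(top_k[:5])
```

Because the index stores one bit per dimension instead of a 32-bit float, it shrinks by roughly a factor of 32, which is how BPR fits Wikipedia into the 2.3GB index cited in the summary below.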

We have replicated BPR's results and incorporated the technique into Pyserini.
To be clear, we started with the model checkpoint and index releases in the official [BPR repo](https://github.com/studio-ousia/bpr) and did _not_ train the query and passage encoders from scratch.

This guide provides instructions to reproduce BPR's results.
We cover only retrieval here; for end-to-end answer extraction, please see [this guide](https://github.com/castorini/pygaggle/blob/master/docs/experiments-dpr-reader.md) in our PyGaggle neural text ranking library. For general instructions on dense retrieval, please see our [dense retrieval replication guide](https://github.com/castorini/pyserini/blob/master/docs/experiments-dpr.md).

## Summary

Here's how our results stack up against those reported in the paper, using the BPR model (index 2.3GB + model 0.4GB):

| Dataset | Method | Top-20 (orig) | Top-20 (us)| Top-100 (orig) | Top-100 (us)|
|:------------|:--------------|--------------:|-----------:|---------------:|------------:|
| NQ | BPR | 77.9 | 77.9 | 85.7 | 85.7 |
| NQ | BPR w/o reranking | 76.5 | 76.0 | 84.9 | 85.0 |

## Natural Questions (NQ) with BPR

**BPR retrieval** with the binary hash index:

```bash
$ python -m pyserini.dsearch --topics dpr-nq-test \
--index wikipedia-bpr-nq-hash \
--encoded-queries bpr-nq-test \
--output runs/run.bpr.rerank.nq-test.nq.hash.trec \
--rerank \
--hits 100 --binary-hits 1000 \
--batch-size 36 --threads 12
```

The option `--encoded-queries` specifies the use of encoded queries (i.e., queries that have already been encoded and cached; for BPR, both the dense vectors and the binary codes are stored).
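
The same retrieval can also be driven programmatically through the classes this commit adds. The following is a minimal sketch; the prebuilt index and query-embedding names match the command above, but the local path is a placeholder, and with cached queries (no `--encoder`) the query text must be one of the cached NQ test questions:

```python
from pyserini.dsearch import BinaryDenseSearcher, BprQueryEncoder

# Placeholder path: wherever the cached 'bpr-nq-test' query embeddings live.
encoder = BprQueryEncoder(encoded_query_dir='path/to/bpr-nq-test')
searcher = BinaryDenseSearcher.from_prebuilt_index('wikipedia-bpr-nq-hash', encoder)

# binary_k bounds the first-stage (binary) candidate pool; rerank=True turns on
# the second-stage dense scoring, mirroring --binary-hits and --rerank above.
hits = searcher.search('who got the first nobel prize in physics',
                       k=100, binary_k=1000, rerank=True)
for hit in hits[:3]:
    print(hit.docid, hit.score)
```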

To evaluate, first convert the TREC output format to DPR's `json` format:

```bash
$ python -m pyserini.eval.convert_trec_run_to_dpr_retrieval_run --topics dpr-nq-test \
--index wikipedia-dpr \
--input runs/run.bpr.rerank.nq-test.nq.hash.trec \
--output runs/run.bpr.rerank.nq-test.nq.hash.json

$ python -m pyserini.eval.evaluate_dpr_retrieval --retrieval runs/run.bpr.rerank.nq-test.nq.hash.json --topk 20 100
Top20 accuracy: 0.779
Top100 accuracy: 0.857
```

## Reproduction Log[*](reproducibility.md)
6 changes: 3 additions & 3 deletions pyserini/dsearch/__init__.py
@@ -14,9 +14,9 @@
# limitations under the License.
#

from ._dsearcher import DenseSearchResult, SimpleDenseSearcher, QueryEncoder, \
    DprQueryEncoder, DkrrDprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, AutoQueryEncoder
from ._dsearcher import DenseSearchResult, SimpleDenseSearcher, BinaryDenseSearcher, QueryEncoder, \
    DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, TctColBertQueryEncoder, AnceQueryEncoder, AutoQueryEncoder
from ._model import AnceEncoder

__all__ = ['DenseSearchResult', 'SimpleDenseSearcher', 'QueryEncoder', 'DprQueryEncoder', 'DkrrDprQueryEncoder',
__all__ = ['DenseSearchResult', 'SimpleDenseSearcher', 'BinaryDenseSearcher', 'QueryEncoder', 'DprQueryEncoder', 'BprQueryEncoder', 'DkrrDprQueryEncoder',
           'TctColBertQueryEncoder', 'AnceEncoder', 'AnceQueryEncoder', 'AutoQueryEncoder']
34 changes: 25 additions & 9 deletions pyserini/dsearch/__main__.py
@@ -19,8 +19,8 @@

from tqdm import tqdm

from pyserini.dsearch import SimpleDenseSearcher, TctColBertQueryEncoder, \
    QueryEncoder, DprQueryEncoder, DkrrDprQueryEncoder, AnceQueryEncoder, AutoQueryEncoder
from pyserini.dsearch import SimpleDenseSearcher, BinaryDenseSearcher, TctColBertQueryEncoder, \
    QueryEncoder, DprQueryEncoder, BprQueryEncoder, DkrrDprQueryEncoder, AnceQueryEncoder, AutoQueryEncoder
from pyserini.query_iterator import get_query_iterator, TopicsFormat
from pyserini.output_writer import get_output_writer, OutputFormat

@@ -63,6 +63,8 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de
        return DkrrDprQueryEncoder(encoder_dir=encoder, device=device, prefix=prefix)
    elif 'dpr' in encoder:
        return DprQueryEncoder(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device)
    elif 'bpr' in encoder:
        return BprQueryEncoder(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device)
    elif 'tct_colbert' in encoder:
        return TctColBertQueryEncoder(encoder_dir=encoder, tokenizer_name=tokenizer_name, device=device)
    elif 'ance' in encoder:
@@ -75,8 +77,12 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de

    if encoded_queries:
        if os.path.exists(encoded_queries):
            return QueryEncoder(encoded_queries)
            if 'bpr' in encoded_queries:
                return BprQueryEncoder(encoded_query_dir=encoded_queries)
            else:
                return QueryEncoder(encoded_queries)
        return QueryEncoder.load_encoded_queries(encoded_queries)

    if topics_name in encoded_queries_map:
        return QueryEncoder.load_encoded_queries(encoded_queries_map[topics_name])
    raise ValueError(f'No encoded queries for topic {topics_name}')
@@ -87,6 +93,8 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de
    parser.add_argument('--topics', type=str, metavar='topic_name', required=True,
                        help="Name of topics. Available: msmarco-passage-dev-subset.")
    parser.add_argument('--hits', type=int, metavar='num', required=False, default=1000, help="Number of hits.")
    parser.add_argument('--binary-hits', type=int, metavar='num', required=False, default=1000, help="Number of binary hits.")
    parser.add_argument("--rerank", action="store_true", help='Whether to rerank BPR binary retrieval results.')
    parser.add_argument('--topics-format', type=str, metavar='format', default=TopicsFormat.DEFAULT.value,
                        help=f"Format of topics. Available: {[x.value for x in list(TopicsFormat)]}")
    parser.add_argument('--output-format', type=str, metavar='format', default=OutputFormat.TREC.value,
@@ -109,14 +117,22 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de
    topics = query_iterator.topics

    query_encoder = init_query_encoder(args.encoder, args.tokenizer, args.topics, args.encoded_queries, args.device, args.query_prefix)

    kwargs = {}
    if os.path.exists(args.index):
        # create searcher from index directory
        searcher = SimpleDenseSearcher(args.index, query_encoder)
        if 'bpr' in args.encoded_queries or 'bpr' in args.encoder:
            kwargs = dict(binary_k=args.binary_hits, rerank=args.rerank)
            searcher = BinaryDenseSearcher(args.index, query_encoder)
        else:
            searcher = SimpleDenseSearcher(args.index, query_encoder)
    else:
        # create searcher from prebuilt index name
        searcher = SimpleDenseSearcher.from_prebuilt_index(args.index, query_encoder)
        if 'bpr' in args.encoded_queries or 'bpr' in args.encoder:
            kwargs = dict(binary_k=args.binary_hits, rerank=args.rerank)
            searcher = BinaryDenseSearcher.from_prebuilt_index(args.index, query_encoder)
        else:
            searcher = SimpleDenseSearcher.from_prebuilt_index(args.index, query_encoder)

    if not searcher:
        exit()

@@ -137,15 +153,15 @@ def init_query_encoder(encoder, tokenizer_name, topics_name, encoded_queries, de
    batch_topic_ids = list()
    for index, (topic_id, text) in enumerate(tqdm(query_iterator, total=len(topics.keys()))):
        if args.batch_size <= 1 and args.threads <= 1:
            hits = searcher.search(text, args.hits)
            hits = searcher.search(text, args.hits, **kwargs)
            results = [(topic_id, hits)]
        else:
            batch_topic_ids.append(str(topic_id))
            batch_topics.append(text)
            if (index + 1) % args.batch_size == 0 or \
                    index == len(topics.keys()) - 1:
                results = searcher.batch_search(
                    batch_topics, batch_topic_ids, args.hits, args.threads)
                    batch_topics, batch_topic_ids, args.hits, threads=args.threads, **kwargs)
                results = [(id_, results[id_]) for id_ in batch_topic_ids]
                batch_topic_ids.clear()
                batch_topics.clear()
173 changes: 172 additions & 1 deletion pyserini/dsearch/_dsearcher.py
@@ -138,6 +138,48 @@ def encode(self, query: str):
        return super().encode(query)


class BprQueryEncoder(QueryEncoder):

    def __init__(self, encoder_dir: str = None, tokenizer_name: str = None,
                 encoded_query_dir: str = None, device: str = 'cpu'):
        self.has_model = False
        self.has_encoded_query = False
        if encoded_query_dir:
            self.embedding = self._load_embeddings(encoded_query_dir)
            self.has_encoded_query = True

        if encoder_dir:
            self.device = device
            self.model = DPRQuestionEncoder.from_pretrained(encoder_dir)
            self.model.to(self.device)
            self.tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(tokenizer_name or encoder_dir)
            self.has_model = True
        if (not self.has_model) and (not self.has_encoded_query):
            raise Exception('Neither query encoder model nor encoded queries provided. Please provide at least one')

    def encode(self, query: str):
        if self.has_model:
            input_ids = self.tokenizer(query, return_tensors='pt')
            input_ids.to(self.device)
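            # A single forward pass yields both query representations: the dense
            # pooler output and its binarized (+1/-1) counterpart.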
            embeddings = self.model(input_ids["input_ids"]).pooler_output.detach().cpu()
            dense_embeddings = embeddings.numpy()
            sparse_embeddings = self.convert_to_binary_code(embeddings).numpy()
            return {'dense': dense_embeddings.flatten(), 'sparse': sparse_embeddings.flatten()}
        else:
            return super().encode(query)

    def convert_to_binary_code(self, input_repr: torch.Tensor):
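        # Start from a tensor of +1s and flip entries with negative activations
        # to -1: an elementwise sign function with sign(0) mapped to +1.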
        return input_repr.new_ones(input_repr.size()).masked_fill_(input_repr < 0, -1.0)

    @staticmethod
    def _load_embeddings(encoded_query_dir):
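        # Cached queries are pickled as a pandas DataFrame with 'text',
        # 'dense_embedding', and 'sparse_embedding' columns.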
        df = pd.read_pickle(os.path.join(encoded_query_dir, 'embedding.pkl'))
        ret = {}
        for text, dense, sparse in zip(df['text'].tolist(), df['dense_embedding'].tolist(), df['sparse_embedding'].tolist()):
            ret[text] = {'dense': dense, 'sparse': sparse}
        return ret


class DkrrDprQueryEncoder(QueryEncoder):

def __init__(self, encoder_dir: str = None, encoded_query_dir: str = None, device: str = 'cpu', prefix: str = "question:"):
@@ -271,7 +313,8 @@ def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str], preb
        self.index, self.docids = self.load_index(index_dir)
        self.dimension = self.index.d
        self.num_docs = self.index.ntotal
        assert self.num_docs == len(self.docids)

        assert self.docids is None or self.num_docs == len(self.docids)
        if prebuilt_index_name:
            sparse_index = get_sparse_index(prebuilt_index_name)
            self.ssearcher = SimpleSearcher.from_prebuilt_index(sparse_index)
@@ -408,3 +451,131 @@ def load_docids(docid_path: str) -> List[str]:
        docids = [line.rstrip() for line in id_f.readlines()]
        id_f.close()
        return docids


class BinaryDenseSearcher(SimpleDenseSearcher):
    """Simple Searcher for binary-dense representation

    Parameters
    ----------
    index_dir : str
        Path to faiss index directory.
    """

    def __init__(self, index_dir: str, query_encoder: Union[QueryEncoder, str], prebuilt_index_name: Optional[str] = None):
        super().__init__(index_dir, query_encoder, prebuilt_index_name)

    def search(self, query: str, k: int = 10, binary_k: int = 100, rerank: bool = True, threads: int = 1) -> List[DenseSearchResult]:
        """Search the collection.

        Parameters
        ----------
        query : str
            Query text.
        k : int
            Number of hits to return at the second stage.
        binary_k : int
            Number of hits to return at the first stage.
        rerank : bool
            Whether to use the dense representation to rerank the binary ranking results.
        threads : int
            Maximum number of threads to use for intra-query search.

        Returns
        -------
        List[DenseSearchResult]
            List of search results.
        """
        ret = self.query_encoder.encode(query)
        dense_emb_q = ret['dense']
        sparse_emb_q = ret['sparse']
        assert len(dense_emb_q) == self.dimension
        assert len(sparse_emb_q) == self.dimension

        dense_emb_q = dense_emb_q.reshape((1, len(dense_emb_q)))
        sparse_emb_q = sparse_emb_q.reshape((1, len(sparse_emb_q)))
        faiss.omp_set_num_threads(threads)
        distances, indexes = self.binary_dense_search(k, binary_k, rerank, dense_emb_q, sparse_emb_q)
        distances = distances.flat
        indexes = indexes.flat
        return [DenseSearchResult(str(idx), score)
                for score, idx in zip(distances, indexes) if idx != -1]

    def batch_search(self, queries: List[str], q_ids: List[str], k: int = 10, binary_k: int = 100,
                     rerank: bool = True, threads: int = 1) -> Dict[str, List[DenseSearchResult]]:
        """

        Parameters
        ----------
        queries : List[str]
            List of query texts.
        q_ids : List[str]
            List of corresponding query ids.
        k : int
            Number of hits to return.
        binary_k : int
            Number of hits to return at the first stage.
        rerank : bool
            Whether to use the dense representation to rerank the binary ranking results.
        threads : int
            Maximum number of threads to use.

        Returns
        -------
        Dict[str, List[DenseSearchResult]]
            Dictionary holding the search results, with the query ids as keys and the corresponding lists of search
            results as the values.
        """
        dense_q_embs = []
        sparse_q_embs = []
        for q in queries:
            ret = self.query_encoder.encode(q)
            dense_q_embs.append(ret['dense'])
            sparse_q_embs.append(ret['sparse'])
        dense_q_embs = np.array(dense_q_embs)
        sparse_q_embs = np.array(sparse_q_embs)
        n, m = dense_q_embs.shape
        assert m == self.dimension
        faiss.omp_set_num_threads(threads)
        D, I = self.binary_dense_search(k, binary_k, rerank, dense_q_embs, sparse_q_embs)
        return {key: [DenseSearchResult(str(idx), score)
                      for score, idx in zip(distances, indexes) if idx != -1]
                for key, distances, indexes in zip(q_ids, D, I)}

    def binary_dense_search(self, k, binary_k, rerank, dense_emb_q, sparse_emb_q):
        num_queries = dense_emb_q.shape[0]
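        # Threshold the +1/-1 query representation at zero and pack the bits
        # into uint8 words, the layout the faiss binary index expects.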
        sparse_emb_q = np.packbits(np.where(sparse_emb_q > 0, 1, 0)).reshape(num_queries, -1)

        if not rerank:
            distances, indexes = self.index.search(sparse_emb_q, k)
        else:
            raw_index = self.index.index
            _, indexes = raw_index.search(sparse_emb_q, binary_k)
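            # Second stage: reconstruct the candidates' packed binary codes,
            # unpack them to {0, 1} bits, and map them back to +1/-1 floats so
            # they can be scored against the dense query embedding.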
            sparse_emb_p = np.vstack(
                [np.unpackbits(raw_index.reconstruct(int(id_))) for id_ in indexes.reshape(-1)]
            )
            sparse_emb_p = sparse_emb_p.reshape(
                dense_emb_q.shape[0], binary_k, dense_emb_q.shape[1]
            )
            sparse_emb_p = sparse_emb_p.astype(np.float32)
            sparse_emb_p = sparse_emb_p * 2 - 1
            distances = np.einsum("ijk,ik->ij", sparse_emb_p, dense_emb_q)
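            # Sort candidates by descending dense score, translate faiss internal
            # ids back to document ids via the id map, and keep the top k.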
            sorted_indices = np.argsort(-distances, axis=1)

            indexes = indexes[np.arange(num_queries)[:, None], sorted_indices]
            indexes = np.array([self.index.id_map.at(int(id_)) for id_ in indexes.reshape(-1)], dtype=int)
            indexes = indexes.reshape(num_queries, -1)[:, :k]
            distances = distances[np.arange(num_queries)[:, None], sorted_indices][:, :k]
        return distances, indexes

    def load_index(self, index_dir: str):
        index_path = os.path.join(index_dir, 'index')
        index = faiss.read_index_binary(index_path)
        return index, None

    @staticmethod
    def _init_encoder_from_str(encoder):
        encoder = encoder.lower()
        if 'bpr' in encoder:
            return BprQueryEncoder(encoder_dir=encoder)
        else:
            raise NotImplementedError
