Skip to content

Commit

Permalink
add splade query encode (#815)
Browse files Browse the repository at this point in the history
  • Loading branch information
MXueguang authored Oct 8, 2021
1 parent 12792dc commit 509bb5a
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 0 deletions.
41 changes: 41 additions & 0 deletions docs/experiments-spladev2.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,47 @@ QueriesRanked: 6980

The final evaluation metric is very close to the one reported in the paper (0.368).

Alternatively, we can use one-the-fly query encoding.

First, download the model checkpoint from NAVER's github [repo](https://github.com/naver/splade/tree/main/weights/splade_max):
```bash
mkdir splade-distil-max
cd splade-distil-max
wget https://github.com/naver/splade/raw/main/weights/distilsplade_max/pytorch_model.bin
wget https://github.com/naver/splade/raw/main/weights/distilsplade_max/config.json
wget https://github.com/naver/splade/raw/main/weights/distilsplade_max/special_tokens_map.json
wget https://github.com/naver/splade/raw/main/weights/distilsplade_max/tokenizer_config.json
wget https://github.com/naver/splade/raw/main/weights/distilsplade_max/vocab.txt
cd ..
```

Then run retrieval with `--encoder splade-distil-max`

```bash
python -m pyserini.search --topics msmarco-passage-dev-subset \
--index indexes/lucene-index.msmarco-passage-distill-splade-max \
--encoder splade-distil-max \
--output runs/run.msmarco-passage-distill-splade-max.tsv \
--impact \
--hits 1000 --batch 36 --threads 12 \
--output-format msmarco
```

And then evaluate:

```bash
python -m pyserini.eval.msmarco_passage_eval msmarco-passage-dev-subset runs/run.msmarco-passage-distill-splade-max.tsv
```

The results should be as follows:

```
#####################
MRR @10: 0.3684321417201083
QueriesRanked: 6980
#####################
```

## Reproduction Log[*](reproducibility.md)

+ Results reproduced by [@lintool](https://github.com/lintool) on 2021-10-05 (commit [`58d286c`](https://github.com/castorini/pyserini/commit/58d286c3f9fe845e261c271f2a0f514462844d97))
Expand Down
36 changes: 36 additions & 0 deletions pyserini/encode/_splade.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np

from pyserini.encode import QueryEncoder


class SpladeQueryEncoder(QueryEncoder):
def __init__(self, model_name_or_path, tokenizer_name=None, device='cpu'):
self.device = device
self.model = AutoModelForMaskedLM.from_pretrained(model_name_or_path)
self.model.to(self.device)
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name or model_name_or_path)
self.reverse_voc = {v: k for k, v in self.tokenizer.vocab.items()}

def encode(self, text, **kwargs):
max_length = 256 # hardcode for now
inputs = self.tokenizer([text], max_length=max_length, padding='longest',
truncation=True, add_special_tokens=True,
return_tensors='pt').to(self.device)
input_ids = inputs['input_ids']
input_attention = inputs['attention_mask']
batch_logits = self.model(input_ids)['logits']
batch_aggregated_logits, _ = torch.max(torch.log(1 + torch.relu(batch_logits))
* input_attention.unsqueeze(-1), dim=1)
batch_aggregated_logits = batch_aggregated_logits.cpu().detach().numpy()
return self._output_to_weight_dicts(batch_aggregated_logits)[0]

def _output_to_weight_dicts(self, batch_aggregated_logits):
to_return = []
for aggregated_logits in batch_aggregated_logits:
col = np.nonzero(aggregated_logits)[0]
weights = aggregated_logits[col]
d = {self.reverse_voc[k]: float(v) for k, v in zip(list(col), list(weights))}
to_return.append(d)
return to_return
3 changes: 3 additions & 0 deletions pyserini/search/_impact_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pyserini.pyclass import autoclass, JFloat, JArrayList, JHashMap, JString
from pyserini.util import download_prebuilt_index
from pyserini.encode import QueryEncoder, TokFreqQueryEncoder, UniCoilQueryEncoder, CachedDataQueryEncoder
from ..encode._splade import SpladeQueryEncoder

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -226,6 +227,8 @@ def _init_query_encoder_from_str(query_encoder):
return CachedDataQueryEncoder(query_encoder)
elif 'unicoil' in query_encoder.lower():
return UniCoilQueryEncoder(query_encoder)
elif 'splade' in query_encoder.lower():
return SpladeQueryEncoder(query_encoder)

@staticmethod
def _compute_idf(index_path):
Expand Down

0 comments on commit 509bb5a

Please sign in to comment.