From 27931daad08a249a67eb4db9611d3fe36021c304 Mon Sep 17 00:00:00 2001 From: AileenLin Date: Sat, 9 Sep 2023 21:14:14 -0400 Subject: [PATCH] Fix bug in SPLADE on-the-fly encoding with PyTorch re: #1625 and add test case (#1626) --- pyserini/encode/_splade.py | 2 +- tests/test_encoder.py | 11 +++++++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/pyserini/encode/_splade.py b/pyserini/encode/_splade.py index 6f866fe37..23312c6d5 100644 --- a/pyserini/encode/_splade.py +++ b/pyserini/encode/_splade.py @@ -25,7 +25,7 @@ def encode(self, text, max_length=256, **kwargs): batch_aggregated_logits, _ = torch.max(torch.log(1 + torch.relu(batch_logits)) * input_attention.unsqueeze(-1), dim=1) batch_aggregated_logits = batch_aggregated_logits.cpu().detach().numpy() - raw_weights = self._output_to_weight_dicts(batch_token_ids, batch_weights) + raw_weights = self._output_to_weight_dicts(batch_aggregated_logits) return self._get_encoded_query_token_wight_dicts(raw_weights)[0] def _output_to_weight_dicts(self, batch_aggregated_logits): diff --git a/tests/test_encoder.py b/tests/test_encoder.py index ef85453a8..073d72569 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -233,6 +233,17 @@ def test_onnx_encode_unicoil(self): temp_object.close() del temp_object + + temp_object1 = LuceneImpactSearcher(f'{self.index_dir}lucene9-index.cacm', 'naver/splade-cocondenser-ensembledistil') + + # this function will never be called in _impact_searcher, here to check quantization correctness + results = temp_object1.encode("here is a test") + self.assertEqual(results.get("here"), 156) + self.assertEqual(results.get("a"), 31) + self.assertEqual(results.get("test"), 149) + + temp_object1.close() + del temp_object1 def tearDown(self): os.remove(self.tarball_name)