
Add MiniLM cross-encoder reranker #200

Merged

Conversation

hugoabonizio (Contributor)

This PR adds a new reranker based on a MiniLM cross-encoder pretrained on MS MARCO and provided by the SentenceTransformers package. MiniLM-based models are much faster than MonoT5/MonoBERT while achieving results similar to MonoT5.

Although it adds a new dependency, SentenceTransformers handles smart batching and other optimizations that would otherwise need to be implemented in PyGaggle. I also added AMP integration for FP16 inference, which should make inference even faster (I can add this to the other models if that's desirable).

Besides MiniLM-based models, other cross-encoders from the SentenceTransformers list of supported models can also be used.
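For readers less familiar with the package, here is a small usage sketch of the SentenceTransformers `CrossEncoder` API that the new reranker wraps. This is an illustration only, not code from this PR, and the MS-MARCO MiniLM checkpoint name is just one of the available options:

```python
# Standalone sketch of the underlying SentenceTransformers API (not PR code).
from sentence_transformers import CrossEncoder

# Any MS-MARCO cross-encoder checkpoint can be used; this name is illustrative.
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

query = 'how do cross-encoders rank passages?'
passages = [
    'A cross-encoder scores the query and a passage jointly in a single forward pass.',
    'Bi-encoders embed queries and passages independently and compare the vectors.',
]

# predict() tokenizes and batches the (query, passage) pairs internally.
scores = model.predict([(query, p) for p in passages])
for passage, score in sorted(zip(passages, scores), key=lambda x: x[1], reverse=True):
    print(f'{score:.3f}\t{passage}')
```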

@ronakice (Member) left a comment

Thanks a lot for this! Some comments below. Yes, it would be great if you could add AMP support to the other models in a separate PR after we get this merged.

@@ -247,3 +249,29 @@ def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
        text.score = max(smax_val.item(), emax_val.item())

        return texts


class CrossEncoderReranker(Reranker):
ronakice (Member)

I think calling this one CrossEncoderReranker while the others are not named that way is not a good naming convention, since all of them are cross-encoders. How about changing it to SentenceTransformersReranker?

hugoabonizio (Contributor, Author)

Sure, that'd be better!

                 device=None,
                 use_amp=None):
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_amp = use_amp or (device == 'cuda')
ronakice (Member)

Why would we want to default to AMP whenever there is a GPU? It could cause a performance drop; an explicit flag would be better.

hugoabonizio (Contributor, Author)

You're right, I'm changing it to be disabled by default.
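For reference, a minimal sketch of what the revised class could look like with AMP as an explicit, off-by-default flag, assuming PyGaggle's `Query`/`Text`/`Reranker` base classes and `torch.cuda.amp.autocast` for FP16 inference. Class and parameter names here are illustrative, not taken verbatim from the final diff:

```python
# Sketch only: AMP is an explicit opt-in flag rather than a GPU-based default.
from copy import deepcopy
from typing import List

import torch
from sentence_transformers import CrossEncoder

from pygaggle.rerank.base import Query, Reranker, Text


class SentenceTransformersReranker(Reranker):
    def __init__(self,
                 model_name: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2',
                 batch_size: int = 32,
                 device: str = None,
                 use_amp: bool = False):
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_amp = use_amp  # disabled unless explicitly requested
        self.batch_size = batch_size
        self.model = CrossEncoder(model_name, device=device)

    def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
        texts = deepcopy(texts)
        # autocast runs eligible CUDA ops in FP16 only when use_amp is True.
        with torch.cuda.amp.autocast(enabled=self.use_amp):
            scores = self.model.predict(
                [(query.text, text.text) for text in texts],
                batch_size=self.batch_size,
            )
        for text, score in zip(texts, scores):
            text.score = float(score)
        return texts
```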

@rodrigonogueira4 (Member) left a comment

LGTM! Thanks a lot for doing this.

@ronakice, I will let you merge when you think it is ready.

@ronakice (Member)

Thanks a lot @hugoabonizio, I'm merging!

@ronakice ronakice merged commit 0a05d43 into castorini:master Jul 12, 2021
@hugoabonizio hugoabonizio deleted the feature/add-minilm-cross-encoder branch July 12, 2021 20:01