Add MiniLM cross-encoder reranker #200
Conversation
Thanks a lot for this! A few comments below. Yes, it would be great if you could add AMP support in a separate PR after we get this merged.
pygaggle/rerank/transformer.py
Outdated
@@ -247,3 +249,29 @@ def rescore(self, query: Query, texts: List[Text]) -> List[Text]:
            text.score = max(smax_val.item(), emax_val.item())
        return texts


class CrossEncoderReranker(Reranker):
I think calling only this one CrossEncoderReranker is not a good naming convention, since the other rerankers are cross-encoders too. How about changing it to SentenceTransformersReranker?
Sure, that'd be better!
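For context, the new class presumably wraps the CrossEncoder class from the sentence_transformers package, which scores (query, passage) pairs jointly. A minimal sketch of that underlying call (the checkpoint name here is just an example MS MARCO MiniLM cross-encoder, not necessarily the PR's default):

```python
# Minimal sketch of the sentence_transformers API the reranker wraps.
from sentence_transformers import CrossEncoder

# Example MS MARCO MiniLM cross-encoder checkpoint (illustrative choice).
model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')

# predict() scores each (query, passage) pair; higher score = more relevant.
scores = model.predict([
    ('who proposed the theory of relativity?',
     'Albert Einstein published the special theory of relativity in 1905.'),
    ('who proposed the theory of relativity?',
     'The capital of France is Paris.'),
])
print(scores)
```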
pygaggle/rerank/transformer.py
Outdated
                 device=None,
                 use_amp=None):
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.use_amp = use_amp or (device == 'cuda')
Why would we want to default to AMP whenever there is a GPU? It could cause a performance drop; an explicit flag would be better.
You're right, I'm changing it to be disabled by default.
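A hedged sketch of what an explicit, off-by-default AMP flag could look like in the constructor and scoring path. The class and method names follow the review discussion above; the exact signatures and defaults in the merged code may differ:

```python
import torch
from sentence_transformers import CrossEncoder


class SentenceTransformersReranker:
    def __init__(self,
                 model_name='cross-encoder/ms-marco-MiniLM-L-6-v2',
                 device=None,
                 use_amp=False):  # explicit opt-in instead of defaulting on when a GPU is present
        device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device
        self.use_amp = use_amp
        self.model = CrossEncoder(model_name, device=device)

    def rescore(self, query, texts):
        pairs = [(query.text, text.text) for text in texts]
        # autocast(enabled=False) is a no-op, so CPU runs are unaffected.
        with torch.cuda.amp.autocast(enabled=self.use_amp):
            scores = self.model.predict(pairs)
        for text, score in zip(texts, scores):
            text.score = float(score)
        return texts
```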
LGTM! Thanks a lot for doing this
@ronakice, I will let you merge when you think it is ready.
Thanks a lot @hugoabonizio, I'm merging!
This PR adds a new reranker based on a MiniLM cross-encoder pretrained on MS MARCO and provided by the SentenceTransformers package. MiniLM-based models are much faster than MonoT5/MonoBERT while achieving results similar to MonoT5.
Although it adds a new dependency, SentenceTransformers handles smart batching and other optimizations that would otherwise need to be implemented in PyGaggle. I also added AMP integration for FP16 inference, which should make inference even faster (I can add this to other models if that's desirable).
Besides MiniLM-based models, SentenceTransformers provides a list of other supported cross-encoder models that can be used with this reranker.
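To illustrate, here is roughly how the new reranker could be used together with pygaggle's Query and Text classes. The class name follows the review discussion above, and the constructor defaults are assumptions; the merged code may differ slightly:

```python
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import SentenceTransformersReranker  # name per the review discussion

# Assumed defaults: MiniLM MS MARCO cross-encoder, AMP disabled.
reranker = SentenceTransformersReranker()

query = Query('who proposed the theory of relativity?')
texts = [
    Text('Albert Einstein published the special theory of relativity in 1905.'),
    Text('The capital of France is Paris.'),
]

# rescore() assigns a relevance score to each candidate text.
reranked = reranker.rescore(query, texts)
for text in sorted(reranked, key=lambda t: t.score, reverse=True):
    print(round(text.score, 4), text.text)
```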