Skip to content

Commit

Permalink
Huggingface Tokenizer Analyzer Pyserini Binding (#1377)
Browse files Browse the repository at this point in the history
- Huggingface Tokenizer Analyzer Pyserini Binding
  • Loading branch information
ToluClassics authored Dec 8, 2022
1 parent 2287be0 commit f5a73f0
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
7 changes: 6 additions & 1 deletion pyserini/analysis/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@
JAnalyzerUtils = autoclass('io.anserini.analysis.AnalyzerUtils')
JDefaultEnglishAnalyzer = autoclass('io.anserini.analysis.DefaultEnglishAnalyzer')
JTweetAnalyzer = autoclass('io.anserini.analysis.TweetAnalyzer')
JHuggingFaceTokenizerAnalyzer = autoclass('io.anserini.analysis.HuggingFaceTokenizerAnalyzer')


def get_lucene_analyzer(language='en', stemming=True, stemmer='porter', stopwords=True) -> JAnalyzer:
def get_lucene_analyzer(language: str='en', stemming: bool=True, stemmer: str='porter', stopwords: bool=True, huggingFaceTokenizer: str=None) -> JAnalyzer:
"""Create a Lucene ``Analyzer`` with specific settings.
Parameters
Expand All @@ -64,6 +65,8 @@ def get_lucene_analyzer(language='en', stemming=True, stemmer='porter', stopword
Stemmer to use.
stopwords : bool
Set to filter stopwords.
huggingFaceTokenizer: str
a huggingface model id or path to a tokenizer.json file
Returns
-------
Expand Down Expand Up @@ -112,6 +115,8 @@ def get_lucene_analyzer(language='en', stemming=True, stemmer='porter', stopword
return JTurkishAnalyzer()
elif language.lower() == 'tweet':
return JTweetAnalyzer()
elif language.lower() == 'hgf_tokenizer':
return JHuggingFaceTokenizerAnalyzer(huggingFaceTokenizer)
elif language.lower() == 'en':
if stemming:
if stopwords:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ def test_analysis(self):
tokens = analyzer.analyze('City buses are running on time.')
self.assertEqual(tokens, ['citi', 'buse', 'ar', 'run', 'on', 'time'])

# HuggingFace analyzer, with bert wordpiece tokenizer
analyzer = Analyzer(get_lucene_analyzer(language="hgf_tokenizer", huggingFaceTokenizer="bert-base-uncased"))
self.assertTrue(isinstance(analyzer, Analyzer))
tokens = analyzer.analyze('This tokenizer generates wordpiece tokens')
self.assertEqual(tokens, ['this', 'token', '##izer', 'generates', 'word', '##piece', 'token', '##s'])

def test_invalid_analyzer_wrapper(self):
# Invalid JAnalyzer, make sure we get an exception.
with self.assertRaises(TypeError):
Expand Down

0 comments on commit f5a73f0

Please sign in to comment.