Refactor chunker classes to use superclass tokenizer initialization and update encoding methods
bhavnicksm committed Nov 6, 2024
1 parent fe9f46a commit f13fcbb
Showing 4 changed files with 18 additions and 9 deletions.
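
For context, the changes below delegate tokenizer handling to the shared base chunker via `super().__init__(tokenizer)` and route token counting through `_encode` / `_encode_batch` helpers. The base class is not part of this diff, so the following is only a rough sketch of what those helpers are assumed to look like; the class name `BaseChunker` and the backend dispatch are assumptions inferred from the call sites.

```python
# Hedged sketch of the superclass the chunkers now rely on (not in this diff).
# Helper names match the call sites below; the bodies are assumptions.
from typing import List


class BaseChunker:
    def __init__(self, tokenizer) -> None:
        # The tokenizer is stored once here instead of in every subclass.
        self.tokenizer = tokenizer

    def _encode(self, text: str) -> List[int]:
        """Return token ids for a single string, hiding backend differences."""
        encoding = self.tokenizer.encode(text)
        # `tokenizers` returns an Encoding object with `.ids`; tiktoken-style
        # backends already return a plain list of ids.
        return encoding.ids if hasattr(encoding, "ids") else encoding

    def _encode_batch(self, texts: List[str]) -> List[List[int]]:
        """Return token ids for every string in the batch."""
        encodings = self.tokenizer.encode_batch(texts)
        return [e.ids if hasattr(e, "ids") else e for e in encodings]
```

With helpers shaped like these, `len(self._encode(text))` yields a token count regardless of which tokenizer library was passed in, which is why the subclasses below can drop their direct `.ids` access.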
6 changes: 6 additions & 0 deletions README.md
@@ -53,8 +53,14 @@ pip install chonkie[all]
Here's a basic example to get you started:

```python
# First import the chunker you want from Chonkie
from chonkie import TokenChunker

# Import your favorite tokenizer library
# Also supports AutoTokenizers, TikToken and AutoTikTokenizer
from tokenizers import Tokenizer
tokenizer = Tokenizer.from_pretrained("gpt2")

# Initialize the chunker
chunker = TokenChunker(tokenizer)

```
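
The README hunk is cut off at this point. Assuming the chunker exposes a `chunk` method that returns chunk objects carrying `text` and `token_count` (not shown in this diff), a minimal continuation of the example would look roughly like this:

```python
# Continuation sketch (not part of the diff): chunk some text and inspect it.
text = "Woah! Chonkie, the chunking library, is so cool!"
chunks = chunker.chunk(text)

for chunk in chunks:
    print(chunk.text, chunk.token_count)
```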
5 changes: 3 additions & 2 deletions src/chonkie/chunker/semantic.py
@@ -52,6 +52,8 @@ def __init__(
ValueError: If parameters are invalid
ImportError: If required dependencies aren't installed
"""
super().__init__(tokenizer)

if max_chunk_size <= 0:
raise ValueError("max_chunk_size must be positive")
if similarity_threshold is not None and (similarity_threshold < 0 or similarity_threshold > 1):
@@ -65,7 +67,6 @@ def __init__(
if sentence_mode not in ["heuristic", "spacy"]:
raise ValueError("sentence_mode must be 'heuristic' or 'spacy'")

self.tokenizer = tokenizer
self.max_chunk_size = max_chunk_size
self.similarity_threshold = similarity_threshold
self.similarity_percentile = similarity_percentile
@@ -159,7 +160,7 @@ def _prepare_sentences(self, text: str) -> List[Sentence]:
embeddings = self.sentence_transformer.encode(raw_sentences, convert_to_numpy=True)

# Batch compute token counts
token_counts = [len(encoding.ids) for encoding in self.tokenizer.encode_batch(raw_sentences)]
token_counts = [len(encoding) for encoding in self._encode_batch(raw_sentences)]

# Create Sentence objects with all precomputed information
sentences = [
7 changes: 4 additions & 3 deletions src/chonkie/chunker/sentence.py
@@ -59,6 +59,8 @@ def __init__(
ValueError: If parameters are invalid
Warning: If spacy mode is requested but spacy is not available
"""
super().__init__(tokenizer)

if chunk_size <= 0:
raise ValueError("chunk_size must be positive")
if chunk_overlap >= chunk_size:
@@ -68,7 +70,6 @@ def __init__(
if min_sentences_per_chunk < 1:
raise ValueError("min_sentences_per_chunk must be at least 1")

self.tokenizer = tokenizer
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.min_sentences_per_chunk = min_sentences_per_chunk
@@ -191,8 +192,8 @@ def _get_token_counts(self, sentences: List[str]) -> List[int]:
List of token counts for each sentence
"""
# Batch encode all sentences at once
encoded_sentences = self.tokenizer.encode_batch(sentences)
return [len(encoded.ids) for encoded in encoded_sentences]
encoded_sentences = self._encode_batch(sentences)
return [len(encoded) for encoded in encoded_sentences]

def _create_chunk(
self,
9 changes: 5 additions & 4 deletions src/chonkie/chunker/word.py
@@ -16,14 +16,15 @@ def __init__(self, tokenizer: Tokenizer, chunk_size: int = 512, chunk_overlap: i
Raises:
ValueError: If chunk_size <= 0 or chunk_overlap >= chunk_size or invalid mode
"""
super().__init__(tokenizer)

if chunk_size <= 0:
raise ValueError("chunk_size must be positive")
if chunk_overlap >= chunk_size:
raise ValueError("chunk_overlap must be less than chunk_size")
if mode not in ["simple", "advanced"]:
raise ValueError("mode must be either 'heuristic' or 'advanced'")

self.tokenizer = tokenizer
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
self.mode = mode
@@ -139,7 +140,7 @@ def _get_token_count(self, text: str) -> int:
Returns:
Number of tokens
"""
return len(self.tokenizer.encode(text).ids)
return len(self._encode(text))

def _create_chunk(self, words: List[str], start_idx: int, end_idx: int) -> Tuple[Chunk, int]:
"""Create a chunk from a list of words.
@@ -170,8 +171,8 @@ def _get_word_list_token_counts(self, words: List[str]) -> List[int]:
Returns:
List of token counts for each word
"""
encodings = self.tokenizer.encode_batch(words)
return [len(encoding.ids) for encoding in encodings]
encodings = self._encode_batch(words)
return [len(encoding) for encoding in encodings]

def chunk(self, text: str) -> List[Chunk]:
"""Split text into overlapping chunks based on words while respecting token limits.
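
Taken together, each chunker now follows the same two-step pattern: hand the tokenizer to the superclass, then count tokens through the shared helpers. A minimal, hypothetical subclass illustrating that pattern (using the `BaseChunker` sketch above, not code from this commit):

```python
# Hypothetical chunker showing the refactored pattern; not from the diff.
class TinyChunker(BaseChunker):
    def __init__(self, tokenizer, chunk_size: int = 512) -> None:
        super().__init__(tokenizer)  # replaces `self.tokenizer = tokenizer`
        if chunk_size <= 0:
            raise ValueError("chunk_size must be positive")
        self.chunk_size = chunk_size

    def _get_token_count(self, text: str) -> int:
        # replaces len(self.tokenizer.encode(text).ids)
        return len(self._encode(text))
```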
