diff --git a/.gitignore b/.gitignore index 7fec20a..d1dc9dd 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,10 @@ dmypy.json # Pyre type checker .pyre/ -/.idea/* \ No newline at end of file +/.idea/* + +# IDEs +*.iml + +# misc +*~ diff --git a/Makefile b/Makefile index 82129f5..7312987 100644 --- a/Makefile +++ b/Makefile @@ -4,9 +4,11 @@ format: python -m docformatter . lint: - python -m flake8 . - python -m pylint docdeid/ - python -m mypy docdeid/ + { python -m flake8 .; fret=$$?; }; \ + { python -m pylint docdeid/; pret=$$?; }; \ + { python -m mypy docdeid/; mret=$$?; }; \ + echo "flake8: $$fret, pylint: $$pret, mypy: $$mret"; \ + [ $$fret,$$pret,$$mret = "0,0,0" ] build-docs: sphinx-apidoc --module-first --force --templatedir=docs/templates -o docs/source/api docdeid diff --git a/docdeid/annotation.py b/docdeid/annotation.py index 60fd533..a52fa0c 100644 --- a/docdeid/annotation.py +++ b/docdeid/annotation.py @@ -46,7 +46,7 @@ class Annotation: # pylint: disable=R0902 Should only be used when the annotation ends on a token boundary. """ - length: int = field(init=False) + length: int = field(init=False, compare=False) """The number of characters of the annotation text.""" _key_cache: dict = field(default_factory=dict, repr=False, compare=False) @@ -100,7 +100,7 @@ def get_sort_key( val = getattr(self, attr, UNKNOWN_ATTR_DEFAULT) - if callbacks is not None and (attr in callbacks): + if callbacks is not None and attr in callbacks: val = callbacks[attr](val) sort_key.append(val) @@ -126,6 +126,9 @@ class AnnotationSet(set[Annotation]): It extends the builtin ``set``. """ + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + def sorted( self, by: tuple, # pylint: disable=C0103 @@ -150,14 +153,14 @@ def sorted( A RunTimeError, if the callbacks are not provided as a frozen dict. """ - if callbacks is not None and not isinstance(callbacks, frozendict): + if not isinstance(callbacks, (type(None), frozendict)): raise RuntimeError( "Please provide the callbacks as a frozen dict, e.g. " "frozendict.frozendict(end_char=lambda x: -x)" ) return sorted( - list(self), + self, key=lambda x: x.get_sort_key( by=by, callbacks=callbacks, deterministic=deterministic ), diff --git a/docdeid/direction.py b/docdeid/direction.py new file mode 100644 index 0000000..7c53507 --- /dev/null +++ b/docdeid/direction.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from enum import IntEnum +from typing import Iterable, Sequence, TypeVar + +T = TypeVar("T") + + +class Direction(IntEnum): + """Direction in text -- either left or right.""" + + LEFT = -1 + RIGHT = 1 + + @property + def opposite(self) -> Direction: + """The opposite direction to this.""" + return Direction(-self) + + @staticmethod + def from_string(val: str) -> Direction: + """Parses a Direction from a string (case insensitive).""" + try: + return Direction[val.upper()] + except KeyError as key_error: + raise ValueError(f"Invalid direction: '{val}'") from key_error + + def iter(self, seq: Sequence[T]) -> Iterable[T]: + """ + Returns an iterator over the given sequence that traverses it in this direction. 
+ + Args: + seq: sequence to iterate over + """ + if self is Direction.RIGHT: + return seq + return reversed(seq) diff --git a/docdeid/document.py b/docdeid/document.py index dd515ce..b7a8444 100644 --- a/docdeid/document.py +++ b/docdeid/document.py @@ -1,7 +1,12 @@ +from collections import defaultdict +from collections.abc import Mapping +from dataclasses import dataclass from typing import Any, Optional -from docdeid.annotation import AnnotationSet -from docdeid.tokenizer import Tokenizer, TokenList +from frozendict import frozendict + +from docdeid.annotation import Annotation, AnnotationSet +from docdeid.tokenizer import Token, Tokenizer, TokenList class MetaData: @@ -66,6 +71,12 @@ class Document: Will be stored in a :class:`.MetaData` object. """ + @dataclass + class AnnosByToken: + """A cache entry associating an `AnnotationSet` with a token->annos map.""" + anno_set: AnnotationSet + value: defaultdict[Token, set[Annotation]] + def __init__( self, text: str, @@ -74,7 +85,9 @@ def __init__( ) -> None: self._text = text - self._tokenizers = tokenizers + self._tokenizers = None if tokenizers is None else frozendict(tokenizers) + self._default_annos_by_token = Document.AnnosByToken(None, None) + self._tmp_annos_by_token = Document.AnnosByToken(None, None) self.metadata = MetaData(metadata) """The :class:`.MetaData` of this :class:`.Document`, that can be interacted @@ -94,6 +107,13 @@ def text(self) -> str: """ return self._text + @property + def tokenizers(self) -> Mapping[str, Tokenizer]: + """Available tokenizers indexed by their name.""" + if self._tokenizers is None: + raise RuntimeError("No tokenizers initialized.") + return self._tokenizers + def get_tokens(self, tokenizer_name: str = "default") -> TokenList: """ Get the tokens corresponding to the input text, for a specific tokenizer. @@ -146,6 +166,62 @@ def annotations(self, annotations: AnnotationSet) -> None: """ self._annotations = annotations + def annos_by_token( + self, + annos: AnnotationSet = None, + ) -> defaultdict[Token, set[Annotation]]: + """ + Returns a mapping from document tokens to annotations. + + Args: + annos: annotations for this document to index by token (default: current + annotations of this `Document`) + """ + + # Fill the default arg value. + if annos is None: + eff_annos = self._annotations + cache = self._default_annos_by_token + else: + eff_annos = annos + cache = self._tmp_annos_by_token + + # Try to use a cached response. + if eff_annos == cache.anno_set: + return cache.value + + # Compute the return value. + annos_by_token = defaultdict(set) + for tokenizer in self.tokenizers: + token_list = self.get_tokens(tokenizer) + if not token_list: + continue + cur_tok_idx = 0 + tok = token_list[cur_tok_idx] + for anno in eff_annos.sorted(by=("start_char",)): + try: + # Iterate over tokens till we reach the annotation. + while tok.end_char < anno.start_char: + cur_tok_idx += 1 + tok = token_list[cur_tok_idx] + except IndexError: + break + # Iterate over tokens in the annotation till we reach the end + # of it or the end of the tokens. + anno_tok_idx = cur_tok_idx + anno_tok = tok + while anno_tok.start_char < anno.end_char: + annos_by_token[anno_tok].add(anno) + if anno_tok_idx == len(token_list) - 1: + break + anno_tok_idx += 1 + anno_tok = token_list[anno_tok_idx] + + # Cache the value before returning. 
+ cache.anno_set = eff_annos + cache.value = annos_by_token + return annos_by_token + @property def deidentified_text(self) -> Optional[str]: """ diff --git a/docdeid/ds/lookup.py b/docdeid/ds/lookup.py index 4df8bdc..f0daa49 100644 --- a/docdeid/ds/lookup.py +++ b/docdeid/ds/lookup.py @@ -2,6 +2,7 @@ import codecs import itertools +from collections.abc import Sequence from typing import Iterable, Iterator, Optional, Union from docdeid.ds.ds import Datastructure @@ -140,7 +141,7 @@ def add_items_from_self( ) -> None: """ Add items from self (this items of this :class:`.LookupSet`). This can be used - to do a transformation or replacment of the items. + to do a transformation or replacement of the items. Args: cleaning_pipeline: A cleaning pipeline applied to the items of this set. @@ -265,7 +266,7 @@ def __init__(self, *args, **kwargs) -> None: self.children: dict[str, LookupTrie] = {} self.is_terminal = False - def add_item(self, item: list[str]) -> None: + def add_item(self, item: Sequence[str]) -> None: """ Add an item, i.e. a list of strings, to this Trie. @@ -285,7 +286,7 @@ def add_item(self, item: list[str]) -> None: self.children[head].add_item(tail) - def __contains__(self, item: list[str]) -> bool: + def __contains__(self, item: Sequence[str]) -> bool: """ Whether the trie contains the item. Respects the matching pipeline. @@ -304,7 +305,7 @@ def __contains__(self, item: list[str]) -> bool: return (head in self.children) and tail in self.children[head] def longest_matching_prefix( - self, item: list[str], start_i: int = 0 + self, item: Sequence[str], start_i: int = 0 ) -> Union[list[str], None]: """ Finds the longest matching prefix of a list of strings. This is used to find the diff --git a/docdeid/process/__init__.py b/docdeid/process/__init__.py index 79387f1..6a40be3 100644 --- a/docdeid/process/__init__.py +++ b/docdeid/process/__init__.py @@ -6,9 +6,10 @@ from .annotator import ( Annotator, MultiTokenLookupAnnotator, + MultiTokenTrieAnnotator, RegexpAnnotator, + SequenceAnnotator, SingleTokenLookupAnnotator, - TokenPatternAnnotator, ) from .doc_processor import DocProcessor, DocProcessorGroup from .redactor import RedactAllText, Redactor, SimpleRedactor diff --git a/docdeid/process/annotation_processor.py b/docdeid/process/annotation_processor.py index 0b3c277..41e9507 100644 --- a/docdeid/process/annotation_processor.py +++ b/docdeid/process/annotation_processor.py @@ -60,7 +60,7 @@ def __init__( @staticmethod def _zero_runs(arr: npt.NDArray) -> npt.NDArray: """ - Finds al zero runs in a numpy array. + Finds all zero runs in a numpy array. Source: https://stackoverflow.com/questions/24885092/ finding-the-consecutive-zeros-in-a-numpy-array @@ -68,7 +68,7 @@ def _zero_runs(arr: npt.NDArray) -> npt.NDArray: arr: The input array. Returns: - A (num_zero_runs, 2)-dim array, containing the start and end indeces + A (num_zero_runs, 2)-dim array, containing the start and end indices of the zero runs. 
Examples: diff --git a/docdeid/process/annotator.py b/docdeid/process/annotator.py index 60689df..8b7f5c2 100644 --- a/docdeid/process/annotator.py +++ b/docdeid/process/annotator.py @@ -1,21 +1,68 @@ +from __future__ import annotations + import re +import warnings + from abc import ABC, abstractmethod -from typing import Iterable, Optional, Union +from collections import defaultdict +from dataclasses import dataclass +from typing import Iterable, Literal, Optional, Union import docdeid.str from docdeid.annotation import Annotation +from docdeid.direction import Direction from docdeid.document import Document +from docdeid.ds import DsCollection from docdeid.ds.lookup import LookupSet, LookupTrie from docdeid.pattern import TokenPattern from docdeid.process.doc_processor import DocProcessor from docdeid.str.processor import StringModifier from docdeid.tokenizer import Token, Tokenizer +from docdeid.utils import leaf_items + + +@dataclass +class SimpleTokenPattern: + """A pattern for a token (and possibly its annotation, too).""" + + func: Literal[ + "equal", + "re_match", + "is_initial", + "is_initials", + "like_name", + "lookup", + "neg_lookup", + "tag", + ] + pattern: str + + +@dataclass +class NestedTokenPattern: + """Coordination of token patterns.""" + + func: Literal["and", "or"] + pattern: list[TokenPatternFromCfg] + + +TokenPatternFromCfg = Union[SimpleTokenPattern, NestedTokenPattern] + + +@dataclass +class SequencePattern: + """Pattern for matching a sequence of tokens.""" + + direction: Direction + skip: set[str] + pattern: list[TokenPatternFromCfg] + class Annotator(DocProcessor, ABC): """ Abstract class for annotators, which are responsible for generating annotations from - a given document. Instatiations should implement the annotate method. + a given document. Instantiations should implement the annotate method. Args: tag: The tag to use in the annotations. @@ -46,6 +93,69 @@ def annotate(self, doc: Document) -> list[Annotation]: A list of annotations. """ + def _match_sequence( + self, + doc: Document, + seq_pattern: SequencePattern, + start_token: Token, + annos_by_token: defaultdict[Token, Iterable[Annotation]], + dicts: Optional[DsCollection], + ) -> Optional[Annotation]: + """ + Matches a token sequence pattern at `start_token`. + + Args: + doc: The document. + seq_pattern: The pattern to match. + start_token: The start token to match. + annos_by_token: Map from tokens to annotations covering it. + dicts: Lookup dictionaries available. + + Returns: + An Annotation if matching is possible, None otherwise. + """ + + dir_ = seq_pattern.direction + + tokens = ( + token + for token in start_token.iter_to(dir_) + if token.text not in seq_pattern.skip + ) + # Iterate the token patterns in the direction corresponding to the surface + # order it's supposed to match (i.e. "left" means "iterate patterns from the + # end"). 
+ tok_patterns = dir_.iter(seq_pattern.pattern) + + num_matched = 0 + end_token = start_token + for tok_pattern, end_token in zip(tok_patterns, tokens): + if _PatternPositionMatcher.match( + token_pattern=tok_pattern, + token=end_token, + annos=annos_by_token[end_token], + ds=dicts, + metadata=doc.metadata, + ): + num_matched += 1 + else: + break + + if num_matched != len(seq_pattern.pattern): + return None + + left_token, right_token = dir_.iter((start_token, end_token)) + + return Annotation( + text=doc.text[left_token.start_char : right_token.end_char], + start_char=left_token.start_char, + end_char=right_token.end_char, + tag=self.tag, + priority=self.priority, + start_token=left_token, + end_token=right_token, + ) + class SingleTokenLookupAnnotator(Annotator): """ @@ -100,79 +210,39 @@ def annotate(self, doc: Document) -> list[Annotation]: return self._tokens_to_annotations(annotate_tokens) -class MultiTokenLookupAnnotator(Annotator): +class MultiTokenTrieAnnotator(Annotator): """ - Matches lookup values against tokens, where the ``lookup_values`` may themselves be - sequences. + Annotates entity mentions by looking them up in a `LookupTrie`. Args: - lookup_values: An iterable of strings, that should be matched. These are - tokenized internally. - matching_pipeline: An optional pipeline that can be used for matching - (e.g. lowercasing). This has no specific impact on matching performance, - other than overhead for applying the pipeline to each string. - tokenizer: A tokenizer that is used to create the sequence patterns from - ``lookup_values``. - trie: A trie that is used for matching, rather than a combination of - `lookup_values` and a `matching_pipeline` (cannot be used simultaneously). - overlapping: Whether the annotator should match overlapping sequences, - or should process from left to right. - - Raises: - RunTimeError, when an incorrect combination of `lookup_values`, - `matching_pipeline` and `trie` is supplied. + trie: The `LookupTrie` containing all entity mentions that should be annotated. + overlapping: Whether overlapping phrases are to be returned. + *args, **kwargs: Passed through to the `Annotator` constructor (which accepts + the arguments `tag` and `priority`). """ def __init__( - self, - *args, - lookup_values: Optional[Iterable[str]] = None, - matching_pipeline: Optional[list[StringModifier]] = None, - tokenizer: Optional[Tokenizer] = None, - trie: Optional[LookupTrie] = None, - overlapping: bool = False, - **kwargs, + self, + *args, + trie: LookupTrie, + overlapping: bool = False, + **kwargs, ) -> None: - self._start_words: set[str] = set() - - if (trie is not None) and (lookup_values is None) and (tokenizer is None): - - self._trie = trie - self._matching_pipeline = trie.matching_pipeline or [] - self._start_words = set(trie.children.keys()) - - elif (trie is None) and (lookup_values is not None) and (tokenizer is not None): - self._matching_pipeline = matching_pipeline or [] - self._trie = LookupTrie(matching_pipeline=matching_pipeline) - self._init_lookup_structures(lookup_values, tokenizer) - - else: - raise RuntimeError( - "Please provide either looup_values and a tokenizer, or a trie." 
- ) - - self.overlapping = overlapping + self._trie = trie + self._overlapping = overlapping + self._start_words = set(trie.children) super().__init__(*args, **kwargs) - def _init_lookup_structures( - self, lookup_values: Iterable[str], tokenizer: Tokenizer - ) -> None: - - for val in lookup_values: - - texts = [token.text for token in tokenizer.tokenize(val)] - - if len(texts) > 0: - self._trie.add_item(texts) - - start_token = texts[0] - - for string_modifier in self._matching_pipeline: - start_token = string_modifier.process(start_token) - - self._start_words.add(start_token) + @property + def start_words(self) -> set[str]: + """First words of phrases detected by this annotator.""" + # If the trie has been modified (added to) since we computed _start_words, + if len(self._start_words) != len(self._trie.children): + # Recompute _start_words. + self._start_words = set(self._trie.children) + return self._start_words def annotate(self, doc: Document) -> list[Annotation]: @@ -180,7 +250,7 @@ def annotate(self, doc: Document) -> list[Annotation]: start_tokens = sorted( tokens.token_lookup( - self._start_words, matching_pipeline=self._matching_pipeline + self.start_words, matching_pipeline=self._trie.matching_pipeline ), key=lambda token: token.start_char, ) @@ -218,12 +288,66 @@ def annotate(self, doc: Document) -> list[Annotation]: ) ) - if not self.overlapping: + if not self._overlapping: min_i = i + len(longest_matching_prefix) # skip ahead return annotations +class MultiTokenLookupAnnotator(MultiTokenTrieAnnotator): + """ + Annotates entity mentions by looking them up in a `LookupTrie` or + a collection of phrases. This is a thin wrapper for + :class:`MultiTokenTrieAnnotator` that additionally handles non-trie lookup + structures by building tries out of them and delegating to the parent class. + + Args: + lookup_values: An iterable of phrases that should be matched. These are + tokenized using ``tokenizer``. + matching_pipeline: An optional pipeline that can be used for matching + (e.g. lowercasing). This has no specific impact on matching performance, + other than overhead for applying the pipeline to each string. + tokenizer: A tokenizer that is used to create the sequence patterns from + ``lookup_values``. + trie: A `LookupTrie` containing all entity mentions that should be + annotated. Specifying this is mutually exclusive with specifying + ``lookup_values`` and ``tokenizer``. + overlapping: Whether overlapping phrases are to be returned. + *args, **kwargs: Passed through to the `Annotator` constructor (which accepts + the arguments `tag` and `priority`). + + Raises: + RuntimeError, when an incorrect combination of `lookup_values`, + `matching_pipeline` and `trie` is supplied. + """ + + def __init__( + self, + *args, + lookup_values: Optional[Iterable[str]] = None, + matching_pipeline: Optional[list[StringModifier]] = None, + tokenizer: Optional[Tokenizer] = None, + trie: Optional[LookupTrie] = None, + overlapping: bool = False, + **kwargs, + ) -> None: + + if (trie is not None) and (lookup_values is None) and (tokenizer is None): + eff_trie = trie + + elif (trie is None) and (lookup_values is not None) and (tokenizer is not None): + eff_trie = LookupTrie(matching_pipeline=matching_pipeline) + for phrase in filter(None, map(tokenizer.tokenize, lookup_values)): + eff_trie.add_item([token.text for token in phrase]) + + else: + raise RuntimeError( + "Please provide either lookup_values and a tokenizer, or a trie."
+ ) + + super().__init__(*args, trie=eff_trie, overlapping=overlapping, **kwargs) + + class RegexpAnnotator(Annotator): """ Create annotations based on regular expression patterns. Note that these patterns do @@ -348,3 +472,205 @@ def annotate(self, doc: Document) -> list[Annotation]: ) return annotations + + +class _PatternPositionMatcher: # pylint: disable=R0903 + """Checks if a token matches against a single pattern.""" + + @classmethod + def match(cls, token_pattern: Union[dict, TokenPatternFromCfg], **kwargs) -> bool: + # pylint: disable=R0911 + """ + Matches a pattern position (a dict with one key). Other information should be + presented as kwargs. + + Args: + token_pattern: A dictionary with a single key, e.g. {'is_initial': True} + kwargs: Any other information, like the token or ds + + Returns: + True if the pattern position matches, false otherwise. + """ + + if isinstance(token_pattern, dict): + return cls.match(as_token_pattern(token_pattern), **kwargs) + + func = token_pattern.func + value = token_pattern.pattern + + if func == "equal": + return kwargs["token"].text == value + if func == "re_match": + return re.match(value, kwargs["token"].text) is not None + if func == "is_initial": + + warnings.warn( + "is_initial matcher pattern is deprecated and will be removed " + "in a future version", + DeprecationWarning, + ) + + return ( + (len(kwargs["token"].text) == 1 and kwargs["token"].text[0].isupper()) + or kwargs["token"].text in {"Ch", "Chr", "Ph", "Th"} + ) == value + if func == "is_initials": + return ( + len(kwargs["token"].text) <= 4 and kwargs["token"].text.isupper() + ) == value + if func == "like_name": + return ( + len(kwargs["token"].text) >= 3 + and kwargs["token"].text.istitle() + and not any(ch.isdigit() for ch in kwargs["token"].text) + ) == value + if func == "lookup": + return cls._lookup(value, **kwargs) + if func == "neg_lookup": + return not cls._lookup(value, **kwargs) + if func == "tag": + annos = kwargs.get("annos", ()) + return any(anno.tag == value for anno in annos) + if func == "and": + return all(cls.match(x, **kwargs) for x in value) + if func == "or": + return any(cls.match(x, **kwargs) for x in value) + + raise NotImplementedError(f"No known logic for pattern {func}") + + @classmethod + def _lookup(cls, ent_type: str, **kwargs) -> bool: + token = kwargs["token"].text + if "." in ent_type: + meta_key, meta_attr = ent_type.split(".", 1) + try: + meta_val = getattr(kwargs["metadata"][meta_key], meta_attr) + except (TypeError, KeyError, AttributeError): + return False + return token == meta_val if isinstance(meta_val, str) else token in meta_val + else: # pylint: disable=R1705 + return token in kwargs["ds"][ent_type] + + +def as_token_pattern(pat_dict: dict) -> TokenPatternFromCfg: + """ + Converts the JSON dictionary representation of token patterns into a + `TokenPatternFromCfg` instance. + + Args: + pat_dict: the JSON representation of the pattern + """ + if len(pat_dict) != 1: + raise ValueError( + f"Cannot parse a token pattern which doesn't have exactly 1 key: " + f"{pat_dict}." + ) + func, value = next(iter(pat_dict.items())) + if func in ("and", "or"): + return NestedTokenPattern(func, [as_token_pattern(it) for it in value]) + return SimpleTokenPattern(func, value) + + +class SequenceAnnotator(Annotator): + """ + Annotates based on token patterns, which should be provided as a list of dicts. Each + position in the list corresponds to a token. 
For example: + ``[{'is_initial': True}, {'like_name': True}]`` matches sequences of two tokens + where the first one is an initial and the second one looks like a name. + + Arguments: + pattern: The sequence of token patterns to match, one dict per token. + ds: Lookup dictionaries. Those referenced by the pattern should be LookupSets, + as the lookup matcher checks a single token against them (see also + ``validate_pattern``). + skip: Any string values that should be skipped in matching (e.g. periods) + """ + + def __init__( + self, + pattern: list[dict], + *args, + ds: Optional[DsCollection] = None, + skip: Optional[list[str]] = None, + **kwargs, + ) -> None: + self.pattern = pattern + self.ds = ds + + self._start_words = None + self._start_matching_pipeline = None + + SequenceAnnotator.validate_pattern(pattern, ds) + + # If the first token pattern is lookup, determine the possible starting words. + if start_ent_type := pattern[0].get("lookup"): + # XXX We assume the items of the lookup list are all single words. This + # is not always the case but just splitting the phrases wouldn't help + # because the "lookup" token matcher assumes matching against a single + # token. + self._start_words = ds[start_ent_type].items() + self._start_matching_pipeline = ds[start_ent_type].matching_pipeline + + self._seq_pattern = SequencePattern( + Direction.RIGHT, set(skip or ()), [as_token_pattern(it) for it in pattern] + ) + + super().__init__(*args, **kwargs) + + @classmethod + def validate_pattern(cls, pattern, ds): + if not pattern: + raise ValueError(f"Sequence pattern is missing or empty: {pattern}.") + + referenced_ents = {match_val + for tok_pattern in pattern + for func, match_val in leaf_items(tok_pattern) + if func.endswith("lookup") and "." not in match_val} + if referenced_ents and ds is None: + raise ValueError("Pattern relies on entity lookups but no lookup " + "structures were provided.") + + if missing := referenced_ents - set(ds or ()): + raise ValueError("Unknown lookup entity types: {}." + .format(", ".join(sorted(missing)))) + + if start_ent_type := pattern[0].get("lookup"): + if not isinstance(ds[start_ent_type], LookupSet): + raise ValueError('If the first token pattern is lookup, it must be ' + f'backed by a LookupSet, but "{start_ent_type}" is ' + f'backed by a {type(ds[start_ent_type]).__name__}.') + + def annotate(self, doc: Document) -> list[Annotation]: + """ + Annotate the document by matching the pattern against all tokens. + + Args: + doc: The document being processed. + + Returns: + A list of Annotation. + """ + + annotations = [] + + token_list = doc.get_tokens() + + if self._start_words is not None: + tokens: Iterable[Token] = token_list.token_lookup( + lookup_values=self._start_words, + matching_pipeline=self._start_matching_pipeline, + ) + else: + tokens = token_list # ...to make Mypy happy.
+ + annos_by_token = doc.annos_by_token() + + for token in tokens: + + annotation = self._match_sequence( + doc, self._seq_pattern, token, annos_by_token, self.ds + ) + + if annotation is not None: + annotations.append(annotation) + + return annotations diff --git a/docdeid/process/doc_processor.py b/docdeid/process/doc_processor.py index 6db05f7..1e12115 100644 --- a/docdeid/process/doc_processor.py +++ b/docdeid/process/doc_processor.py @@ -1,8 +1,12 @@ +import logging from abc import ABC, abstractmethod from collections import OrderedDict from typing import Iterator, Optional, Union from docdeid.document import Document +from docdeid.utils import annotate_doc + +_ROOT_LOGGER = logging.getLogger() class DocProcessor(ABC): # pylint: disable=R0903 @@ -28,7 +32,7 @@ class DocProcessorGroup: def __init__(self) -> None: self._processors: OrderedDict[ - str, Union[DocProcessor | DocProcessorGroup] + str, Union[DocProcessor, DocProcessorGroup] ] = OrderedDict() def get_names(self, recursive: bool = True) -> list[str]: @@ -143,6 +147,9 @@ def process( elif isinstance(proc, DocProcessorGroup): proc.process(doc, enabled=enabled, disabled=disabled) + if _ROOT_LOGGER.isEnabledFor(logging.DEBUG): + logging.debug("after %s: %s", name, annotate_doc(doc)) + def __iter__(self) -> Iterator: return iter(self._processors.items()) diff --git a/docdeid/str/__init__.py b/docdeid/str/__init__.py index 7e6db61..2eba56a 100644 --- a/docdeid/str/__init__.py +++ b/docdeid/str/__init__.py @@ -1,6 +1,7 @@ from .processor import ( FilterByLength, LowercaseString, + LowercaseTail, RemoveNonAsciiCharacters, ReplaceNonAsciiCharacters, ReplaceValue, diff --git a/docdeid/str/processor.py b/docdeid/str/processor.py index b1023a5..44ee468 100644 --- a/docdeid/str/processor.py +++ b/docdeid/str/processor.py @@ -74,6 +74,27 @@ def process(self, item: str) -> str: return item.casefold() +_WORD_RX = re.compile("\\w+", re.U) + + +class LowercaseTail(StringModifier): + """Lowercases the tail of words.""" + + def __init__(self, lang: str = "nl") -> None: + self._lang = lang + + def _process_word_match(self, match: re.Match) -> str: + word = match.group(0) + if word.isupper(): + if self._lang == "nl" and word.startswith("IJ"): + return word[0:2] + word[2:].lower() + return word[0] + word[1:].lower() + return word + + def process(self, item: str) -> str: + return _WORD_RX.sub(self._process_word_match, item) + + class StripString(StringModifier): """ Strip string (whitespaces, tabs, newlines, etc. diff --git a/docdeid/tokenizer.py b/docdeid/tokenizer.py index 8813caf..9e69208 100644 --- a/docdeid/tokenizer.py +++ b/docdeid/tokenizer.py @@ -1,11 +1,14 @@ from __future__ import annotations import re +import sys from abc import ABC, abstractmethod from collections import defaultdict +from collections.abc import Generator, Iterator, Sequence from dataclasses import dataclass, field -from typing import Iterator, Literal, Optional +from typing import Literal, Optional, SupportsIndex, overload +from docdeid.direction import Direction from docdeid.str import StringModifier @@ -120,6 +123,21 @@ def next(self, num: int = 1) -> Optional[Token]: """ return self._get_linked_token(num=num, attr="_next_token") + def iter_to( + self, + dir_: Direction = Direction.RIGHT, + ) -> Generator[Token, None, None]: + """ + Iterates linked tokens in the specified direction. 
+ + Args: + dir_: direction to go + """ + token: Optional[Token] = self + while token is not None: + yield token + token = token.next() if dir_ is Direction.RIGHT else token.previous() + def __len__(self) -> int: """ The length of the text. @@ -130,7 +148,7 @@ def __len__(self) -> int: return len(self.text) -class TokenList: +class TokenList(Sequence[Token]): """ Contains a sequence of tokens, along with some lookup logic. @@ -248,9 +266,29 @@ def __len__(self) -> int: return len(self._tokens) + @overload def __getitem__(self, index: int) -> Token: + ... + + @overload + def __getitem__(self, indexes: slice) -> Sequence[Token]: + ... - return self._tokens[index] + def __getitem__(self, item): + return self._tokens[item] + + def index( + self, + __token: Token, + __start: SupportsIndex = 0, + __stop: SupportsIndex = sys.maxsize, + ) -> int: + try: + return self._token_index[__token] + except KeyError: + # Raise a plain ValueError, just like list.index. + # pylint: disable=W0707 + raise ValueError(f"'{__token}' is not in TokenList") def __eq__(self, other: object) -> bool: """ @@ -317,6 +355,15 @@ def tokenize(self, text: str) -> TokenList: return TokenList(tokens, link_tokens=self.link_tokens) +class DummyTokenizer(Tokenizer): # pylint: disable=R0903 + """ + Treats any given string as a single token. + """ + + def _split_text(self, text: str) -> list[Token]: + return [Token(text=text, start_char=0, end_char=len(text))] + + class SpaceSplitTokenizer(Tokenizer): # pylint: disable=R0903 """ Tokenizes based on splitting on whitespaces. @@ -333,11 +380,17 @@ def _split_text(self, text: str) -> list[Token]: class WordBoundaryTokenizer(Tokenizer): # pylint: disable=R0903 """ - Tokenizes based on word boundary. + Tokenizes based on word boundary. Sequences of non-alphanumeric characters are also + represented as tokens. - Whitespaces and similar characters are included as tokens. + Args: + keep_blanks: Keep whitespace in tokens, and whitespace-only tokens? """ + def __init__(self, keep_blanks: bool = True) -> None: + super().__init__() + self._trim = not keep_blanks + def _split_text(self, text: str) -> list[Token]: tokens = [] matches = [*re.finditer(r"\b", text)] @@ -347,9 +400,20 @@ def _split_text(self, text: str) -> list[Token]: start_char = start_match.span(0)[0] end_char = end_match.span(0)[0] + if self._trim: + word = text[start_char:end_char] + word = word.rstrip() + end_char = start_char + len(word) + word = word.lstrip() + start_char = end_char - len(word) + if not word: + continue + else: + word = text[start_char:end_char] + tokens.append( Token( - text=text[start_char:end_char], + text=word, start_char=start_char, end_char=end_char, ) diff --git a/docdeid/utils.py b/docdeid/utils.py index 1d3cf7c..d5bfe5a 100644 --- a/docdeid/utils.py +++ b/docdeid/utils.py @@ -1,8 +1,38 @@ +from collections import defaultdict +from collections.abc import Generator, Iterable, Iterator, Mapping +from typing import Any, Optional + from frozendict import frozendict from docdeid.document import Document +def leaf_items(json_struct: Mapping) -> Iterator[tuple]: + """ + Generates all `(key, value)` items that appear as leaves of the potentially deeply + nested JSON-like structure `json_struct`, where being a leaf item means that + `key` is associated with a `value` in a dict and `value` is of an atomic type + (such as a `str` but not list-like or map-like). 
+ + :param json_struct: nested structure to iterate + :return: generator of leaf `(key, value)` items + """ + return __leaf_items(json_struct, None) + + +def __leaf_items(obj: Any, par_key: Optional[str]) -> Generator[tuple, None, None]: + if isinstance(obj, Mapping): + for key, val in obj.items(): + for item in __leaf_items(val, key): + yield item + elif isinstance(obj, Iterable) and not isinstance(obj, (bytes, str)): + for member in obj: + for item in __leaf_items(member, None): + yield item + elif par_key is not None: + yield par_key, obj + + def annotate_intext(doc: Document) -> str: """ Annotate intext, which can be useful to compare the annotations of two different @@ -32,3 +62,30 @@ def annotate_intext(doc: Document) -> str: ) return text + + +def annotate_doc(doc: Document) -> str: + """ + Adds XML-like markup for annotations into the text of a document. + + Also handles nested mentions and, to some extent, overlapping mentions, even + though this kind of markup cannot really represent them. + """ + annos_from_shortest = doc.annotations.sorted(by=("length", )) + idx_to_anno_starts = defaultdict(list) + idx_to_anno_ends = defaultdict(list) + for anno in annos_from_shortest: + idx_to_anno_starts[anno.start_char].append(anno) + idx_to_anno_ends[anno.end_char].append(anno) + markup_indices = sorted(set(idx_to_anno_starts).union(idx_to_anno_ends)) + chunks = [] + last_idx = 0 + for idx in markup_indices: + chunks.append(doc.text[last_idx:idx]) + for ending_anno in idx_to_anno_ends[idx]: + chunks.append(f"</{ending_anno.tag.upper()}>") + for starting_anno in reversed(idx_to_anno_starts[idx]): + chunks.append(f"<{starting_anno.tag.upper()}>") + last_idx = idx + chunks.append(doc.text[last_idx:]) + return "".join(chunks) diff --git a/tests/pipeline/test_pipeline.py b/tests/pipeline/test_pipeline.py index 62422e7..d4e0ba7 100644 --- a/tests/pipeline/test_pipeline.py +++ b/tests/pipeline/test_pipeline.py @@ -2,6 +2,7 @@ from docdeid.annotation import Annotation, AnnotationSet from docdeid.deidentifier import DocDeid +from docdeid.ds import LookupTrie from docdeid.process.annotator import ( MultiTokenLookupAnnotator, SingleTokenLookupAnnotator, @@ -49,11 +50,12 @@ def test_multipe_annotators(self, long_text): "name_annotator", SingleTokenLookupAnnotator(lookup_values=["Bob"], tag="name"), ) + loc_trie = LookupTrie() + loc_trie.add_item("the United States of America".split()) deidentifier.processors.add_processor( "location_annotator", MultiTokenLookupAnnotator( - lookup_values=["the United States of America"], - tokenizer=tokenizer, + trie=loc_trie, tag="location", ), ) @@ -86,11 +88,12 @@ def test_enabled(self, long_text): "name_annotator", SingleTokenLookupAnnotator(lookup_values=["Bob"], tag="name"), ) + loc_trie = LookupTrie() + loc_trie.add_item("the United States of America".split()) deidentifier.processors.add_processor( "location_annotator", MultiTokenLookupAnnotator( - lookup_values=["the United States of America"], - tokenizer=tokenizer, + trie=loc_trie, tag="location", ), ) @@ -124,11 +127,12 @@ def test_disabled(self, long_text): "name_annotator", SingleTokenLookupAnnotator(lookup_values=["Bob"], tag="name"), ) + loc_trie = LookupTrie() + loc_trie.add_item("the United States of America".split()) deidentifier.processors.add_processor( "location_annotator", MultiTokenLookupAnnotator( - lookup_values=["the United States of America"], - tokenizer=tokenizer, + trie=loc_trie, tag="location", ), ) diff --git a/tests/unit/process/test_annotator.py b/tests/unit/process/test_annotator.py index a71dc54..380e76e 100644 ---
a/tests/unit/process/test_annotator.py +++ b/tests/unit/process/test_annotator.py @@ -1,18 +1,26 @@ import re +from collections import defaultdict from unittest.mock import patch +import pytest + import docdeid.ds from docdeid.annotation import Annotation +from docdeid.direction import Direction from docdeid.document import Document +from docdeid.ds import DsCollection, LookupSet, LookupTrie from docdeid.pattern import TokenPattern from docdeid.process.annotator import ( MultiTokenLookupAnnotator, RegexpAnnotator, + SequenceAnnotator, + SequencePattern, SingleTokenLookupAnnotator, TokenPatternAnnotator, + as_token_pattern, ) from docdeid.str.processor import LowercaseString -from docdeid.tokenizer import WordBoundaryTokenizer +from docdeid.tokenizer import SpaceSplitTokenizer, WordBoundaryTokenizer class TestSingleTokenLookupAnnotator: @@ -55,11 +63,10 @@ def test_single_token_with_matching_pipeline(self, long_text, long_tokenlist): class TestMultiTokenLookupAnnotator: def test_multi_token(self, long_text, long_tokenlist): doc = Document(long_text) - annotator = MultiTokenLookupAnnotator( - lookup_values=["my name", "my wife"], - tokenizer=WordBoundaryTokenizer(), - tag="prefix", - ) + my_trie = LookupTrie() + my_trie.add_item(("my", " ", "name")) + my_trie.add_item(("my", " ", "wife")) + annotator = MultiTokenLookupAnnotator(trie=my_trie, tag="prefix") expected_annotations = [ Annotation(text="my wife", start_char=39, end_char=46, tag="prefix"), ] @@ -73,12 +80,10 @@ def test_multi_token(self, long_text, long_tokenlist): def test_multi_token_with_matching_pipeline(self, long_text, long_tokenlist): doc = Document(long_text) - annotator = MultiTokenLookupAnnotator( - lookup_values=["my name", "my wife"], - tokenizer=WordBoundaryTokenizer(), - matching_pipeline=[LowercaseString()], - tag="prefix", - ) + my_trie = LookupTrie(matching_pipeline=[LowercaseString()]) + my_trie.add_item(("my", " ", "name")) + my_trie.add_item(("my", " ", "wife")) + annotator = MultiTokenLookupAnnotator(trie=my_trie, tag="prefix") expected_annotations = [ Annotation(text="My name", start_char=0, end_char=7, tag="prefix"), Annotation(text="my wife", start_char=39, end_char=46, tag="prefix"), @@ -93,9 +98,11 @@ def test_multi_token_lookup_with_overlap(self, long_text, long_tokenlist): doc = Document(long_text) + dr_trie = LookupTrie() + dr_trie.add_item(("dr", ". ", "John")) + dr_trie.add_item(("John", " ", "Smith")) annotator = MultiTokenLookupAnnotator( - lookup_values=["dr. John", "John Smith"], - tokenizer=WordBoundaryTokenizer(), + trie=dr_trie, tag="prefix", overlapping=True, ) @@ -114,9 +121,11 @@ def test_multi_token_lookup_no_overlap(self, long_text, long_tokenlist): doc = Document(long_text) + dr_trie = LookupTrie() + dr_trie.add_item(("dr", ". ", "John")) + dr_trie.add_item(("John", " ", "Smith")) annotator = MultiTokenLookupAnnotator( - lookup_values=["dr. John", "John Smith"], - tokenizer=WordBoundaryTokenizer(), + trie=dr_trie, tag="prefix", overlapping=False, ) @@ -137,7 +146,6 @@ def test_multi_token_lookup_with_trie(self, long_text, long_tokenlist): trie = docdeid.ds.LookupTrie(matching_pipeline=[LowercaseString()]) trie.add_item(["my", " ", "name"]) trie.add_item(["my", " ", "wife"]) - annotator = MultiTokenLookupAnnotator( trie=trie, tag="prefix", @@ -153,6 +161,34 @@ def test_multi_token_lookup_with_trie(self, long_text, long_tokenlist): assert annotations == expected_annotations + def test_trie_modified(self, long_text): + # The user of Deduce may want to amend the resources shipped with Deduce. 
+ # Loading those happens in the Deduce initializer, which also constructs + # annotators according to the configuration. + + # Run the interesting portions of Deduce initialization. + doc = Document(long_text, tokenizers={"default": SpaceSplitTokenizer()}) + trie = docdeid.ds.LookupTrie() + # Yeah, the comma in "Smith," seems off... but then again, WordBoundaryTokenizer + # considers whitespace to be tokens. There is no good choice. + trie.add_item(("John", "Smith,")) + annotator = MultiTokenLookupAnnotator(trie=trie, tag="name") + + # Let's add our own resources. + trie.add_item(("jane", "Keith-Lucas")) + # ...including phrases with a potential to confuse the algorithm. + trie.add_item(("jane", "joplane")) + trie.add_item(("dr.", "John", "Hopkin")) + trie.add_item(("Smith,", "please")) + + # Expect also our phrases to be detected. + want = [ + Annotation(text="John Smith,", start_char=15, end_char=26, tag="name"), + Annotation(text="jane Keith-Lucas", start_char=47, end_char=63, tag="name"), + ] + got = annotator.annotate(doc) + assert got == want + class TestRegexpAnnotator: def test_regexp_annotator(self, long_text): @@ -263,3 +299,221 @@ def test_multi_pattern(self, long_text, long_tokens_linked, multi_pattern): annotations = annotator.annotate(doc) assert annotations == expected_annotations + + +class TestSequenceAnnotator: + @pytest.fixture + def ds(self): + ds = DsCollection() + + first_names = ["Andries", "pieter", "Aziz", "Bernard", "Won Jung"] + surnames = ["Meijer", "Smit", "Bakker", "Heerma"] + interfixes = ["v/d"] + interfixed_surnames = ["Heck"] + + ds["first_names"] = LookupSet() + ds["first_names"].add_items_from_iterable(items=first_names) + + ds["surnames"] = LookupSet() + ds["surnames"].add_items_from_iterable(items=surnames) + + ds["interfixes"] = LookupSet() + ds["interfixes"].add_items_from_iterable(items=interfixes) + + trie = LookupTrie() + for phrase in interfixed_surnames: + trie.add_item(phrase.split()) + ds["interfixed_surnames"] = trie + + return ds + + @pytest.fixture + def pattern_doc(self): + return Document( + text="De man heet Andries Meijer-Heerma, voornaam Andries.", + tokenizers={"default": WordBoundaryTokenizer(False)}, + ) + + @pytest.fixture + def interfixed_doc(self): + return Document( + text="De man heet v/d Heck.", + tokenizers={"default": WordBoundaryTokenizer(False)}, + ) + + @pytest.fixture + def korean_doc(self): + return Document( + text="De mevrouw heet Won Jung Meijer-Heerma.", + tokenizers={"default": WordBoundaryTokenizer(False)}, + ) + + def test_validation(self, ds): + with pytest.raises(ValueError) as exc_info: + SequenceAnnotator(pattern=[], ds=ds, tag="_") + assert "missing or empty" in str(exc_info) + + # Lookup structures are not required if there are no lookup token patterns. + SequenceAnnotator(pattern=[{"like_name": True}], tag="_") + assert True + + with pytest.raises(ValueError) as exc_info: + SequenceAnnotator(pattern=[{"lookup": "undefined_entity"}], tag="_") + assert "no lookup structures were provided" in str(exc_info) + + with pytest.raises(ValueError) as exc_info: + SequenceAnnotator(pattern=[{"lookup": "undefined_entity"}], + ds=ds, + tag="_") + assert "Unknown lookup entity types: undefined_entity." in str(exc_info) + + with pytest.raises(ValueError) as exc_info: + SequenceAnnotator( + pattern=[{"or": [{"lookup": "undefined_entity"}, + {"lookup": "another_entity"}]}], + ds=ds, + tag="_") + assert ("Unknown lookup entity types: another_entity, undefined_entity." 
+ in str(exc_info)) + + # References to entities from metadata must not cause validation errors. + SequenceAnnotator(pattern=[{"or": [{"lookup": "patient.name"}, + {"lookup": "doctor.surname"}]}], + tag="_") + SequenceAnnotator(pattern=[{"or": [{"lookup": "interfixes"}, + {"lookup": "doctor.surname"}]}], + ds=ds, + tag="_") + assert True + + with pytest.raises(ValueError) as exc_info: + SequenceAnnotator( + pattern=[{"or": [{"lookup": "interfixes"}, + {"and": [{"lookup": "first_names"}, + {"lookup": "alien_entity"}]}]}, + {"lookup": "another_entity"}], + ds=ds, + tag="_") + assert ("Unknown lookup entity types: alien_entity, another_entity." + in str(exc_info)) + + with pytest.raises(ValueError) as exc_info: + SequenceAnnotator( + pattern=[{"lookup": "interfixed_surnames"}], + ds=ds, + tag="_") + assert ("is backed by a LookupTrie" in str(exc_info)) + + def test_match_sequence(self, pattern_doc, ds): + pattern = [{"lookup": "first_names"}, {"like_name": True}] + + tpa = SequenceAnnotator(pattern=pattern, ds=ds, tag="_") + + assert tpa._match_sequence( + pattern_doc, + tpa._seq_pattern, + start_token=pattern_doc.get_tokens()[3], + annos_by_token=defaultdict(list), + dicts=ds, + ) == Annotation(text="Andries Meijer", start_char=12, end_char=26, tag="_") + assert ( + tpa._match_sequence( + pattern_doc, + tpa._seq_pattern, + start_token=pattern_doc.get_tokens()[7], + annos_by_token=defaultdict(list), + dicts=ds, + ) + is None + ) + + def test_match_sequence_left(self, pattern_doc, ds): + """ + Matching is always performed in the direction left-to-right by + SequenceAnnotator proper but the same method is also called by + ContextAnnotator in Deduce, where matching may proceed also right-to-left. + """ + pattern = [{"lookup": "first_names"}, {"like_name": True}] + + tpa = SequenceAnnotator(pattern=pattern, ds=ds, tag="_") + + assert tpa._match_sequence( + pattern_doc, + SequencePattern( + Direction.LEFT, set(), [as_token_pattern(it) for it in pattern] + ), + start_token=pattern_doc.get_tokens()[4], + annos_by_token=defaultdict(list), + dicts=ds, + ) == Annotation(text="Andries Meijer", start_char=12, end_char=26, tag="_") + + assert ( + tpa._match_sequence( + pattern_doc, + SequencePattern( + Direction.LEFT, set(), [as_token_pattern(it) for it in pattern] + ), + start_token=pattern_doc.get_tokens()[8], + annos_by_token=defaultdict(list), + dicts=ds, + ) + is None + ) + + def test_match_sequence_skip(self, pattern_doc, ds): + pattern = [{"lookup": "surnames"}, {"like_name": True}] + + tpa = SequenceAnnotator(pattern=pattern, ds=ds, tag="_") + tpa_skipping = SequenceAnnotator(pattern=pattern, ds=ds, skip=["-"], tag="_") + + assert tpa_skipping._match_sequence( + pattern_doc, + tpa_skipping._seq_pattern, + start_token=pattern_doc.get_tokens()[4], + annos_by_token=defaultdict(list), + dicts=ds, + ) == Annotation(text="Meijer-Heerma", start_char=20, end_char=33, tag="_") + assert ( + tpa._match_sequence( + pattern_doc, + SequencePattern( + Direction.RIGHT, set(), [as_token_pattern(it) for it in pattern] + ), + start_token=pattern_doc.get_tokens()[4], + annos_by_token=defaultdict(list), + dicts=ds, + ) + is None + ) + + def test_annotate(self, pattern_doc, ds): + pattern = [{"lookup": "first_names"}, {"like_name": True}] + + tpa = SequenceAnnotator(pattern=pattern, ds=ds, tag="_") + + assert tpa.annotate(pattern_doc) == [ + Annotation(text="Andries Meijer", start_char=12, end_char=26, tag="_") + ] + + @pytest.mark.xfail(reason="The lookup token pattern only ever matches a single " + "token and the 
SequenceAnnotator docstring accordingly " + "rules the case of multiple tokens per pattern out of " + "scope. Yet, the packaged base_config.json seems to " + "rely on such multi-word matches, most notably in the " + "case of the interfix_with_name annotator.") + def test_annotate_multiword(self, interfixed_doc, korean_doc, ds): + inter_pattern = [{"lookup": "interfixes"}, {"lookup": "interfixed_surnames"}] + ipa = SequenceAnnotator(pattern=inter_pattern, ds=ds, + # tokenizer=WordBoundaryTokenizer(False), + tag="_") + assert ipa.annotate(interfixed_doc) == [ + Annotation(text="v/d Heck", start_char=12, end_char=20, tag="_") + ] + + pattern = [{"lookup": "first_names"}, {"like_name": True}] + kpa = SequenceAnnotator(pattern=pattern, ds=ds, + # tokenizer=WordBoundaryTokenizer(False), + tag="_") + assert kpa.annotate(korean_doc) == [ + Annotation(text="Won Jung Meijer", start_char=16, end_char=31, tag="_") + ] diff --git a/tests/unit/test_annotation.py b/tests/unit/test_annotation.py index fe4f785..072979f 100644 --- a/tests/unit/test_annotation.py +++ b/tests/unit/test_annotation.py @@ -1,8 +1,11 @@ +import re + import pytest from frozendict import frozendict +from docdeid import Document from docdeid.annotation import Annotation, AnnotationSet -from docdeid.tokenizer import Token +from docdeid.tokenizer import Token, Tokenizer, WordBoundaryTokenizer class TestAnnotation: @@ -157,3 +160,74 @@ def test_get_annotations_sorted_no_frozendict(self, annotations): _ = annotation_set.sorted( by=("priority", "length"), callbacks=dict(length=lambda x: -x) ) + + def test_annos_by_token(self, annotations): + doc = Document( + "1 2 3 1 2 3 hum Hello hum I'm Bob - said Cindy", + tokenizers={"default": WordBoundaryTokenizer(False)}, + ) + aset = AnnotationSet( + [ + a1 := Annotation("Hello", 16, 21, "word"), + a2 := Annotation("I", 26, 27, "ltr"), + a3 := Annotation("I'm", 26, 29, "words"), + a4 := Annotation("Bob", 30, 33, "name"), + a5 := Annotation("I'm Bob", 26, 33, "stmt"), + ] + ) + + # import pydevd_pycharm + # pydevd_pycharm.settrace() + + got = doc.annos_by_token(aset) + + want = { + Token("Hello", 16, 21): {a1}, + Token("I", 26, 27): {a2, a3, a5}, + Token("'", 27, 28): {a3, a5}, + Token("m", 28, 29): {a3, a5}, + Token("Bob", 30, 33): {a4, a5}, + } + + assert got == want + + def test_annos_by_token_2(self, annotations): + class HumTokenizer(Tokenizer): + """Extracts each "hum" word and the following word as a token.""" + + def _split_text(self, text: str) -> list[Token]: + return [ + Token(match.group(0), match.start(), match.end()) + for match in re.finditer("\\bhum\\s+\\w+", text) + ] + + doc = Document( + "1 2 3 1 2 3 hum Hello hum I'm Bob - said Cindy", + tokenizers={ + "default": WordBoundaryTokenizer(False), + "for_fun": HumTokenizer(), + }, + ) + aset = AnnotationSet( + [ + a1 := Annotation("Hello", 16, 21, "word"), + a2 := Annotation("I", 26, 27, "ltr"), + a3 := Annotation("I'm", 26, 29, "words"), + a4 := Annotation("Bob", 30, 33, "name"), + a5 := Annotation("I'm Bob", 26, 33, "stmt"), + ] + ) + + got = doc.annos_by_token(aset) + + want = { + Token("Hello", 16, 21): {a1}, + Token("I", 26, 27): {a2, a3, a5}, + Token("'", 27, 28): {a3, a5}, + Token("m", 28, 29): {a3, a5}, + Token("Bob", 30, 33): {a4, a5}, + Token("hum Hello", 12, 21): {a1}, + Token("hum I", 22, 27): {a2, a3, a5}, + } + + assert got == want diff --git a/tests/unit/test_direction.py b/tests/unit/test_direction.py new file mode 100644 index 0000000..07537b1 --- /dev/null +++ b/tests/unit/test_direction.py @@ -0,0 +1,30 @@ 
+import pytest + +from docdeid.direction import Direction + + +class TestDirection: + def test_basics(self): + assert Direction.LEFT != Direction.RIGHT + assert Direction.LEFT.opposite == Direction.RIGHT + assert Direction.RIGHT.opposite == Direction.LEFT + + def test_parsing(self): + assert Direction.from_string("left") == Direction.LEFT + assert Direction.from_string("Left") == Direction.LEFT + assert Direction.from_string("LEFT") == Direction.LEFT + assert Direction.from_string("right") == Direction.RIGHT + assert Direction.from_string("Right") == Direction.RIGHT + assert Direction.from_string("RIGHT") == Direction.RIGHT + + def test_parsing_failure(self): + with pytest.raises(ValueError, match="Invalid direction: 'down'"): + Direction.from_string("down") + with pytest.raises(ValueError, match="Invalid direction: ' left'"): + Direction.from_string(" left") + + def test_iteration(self): + assert list(Direction.RIGHT.iter([])) == [] + assert list(Direction.LEFT.iter([])) == [] + assert list(Direction.RIGHT.iter([1, 2, "three"])) == [1, 2, "three"] + assert list(Direction.LEFT.iter([1, 2, "three"])) == ["three", 2, 1] diff --git a/tests/unit/test_document.py b/tests/unit/test_document.py index 05e3e3c..e63bd94 100644 --- a/tests/unit/test_document.py +++ b/tests/unit/test_document.py @@ -84,6 +84,7 @@ def test_get_tokens_multiple_tokenizers(self, short_tokens): tokenizer1, "tokenize", return_value=short_tokens ), patch.object(tokenizer2, "_split_text", return_value=[]): + assert set(doc.tokenizers.keys()) == {"tokenizer_1", "tokenizer_2"} assert doc.get_tokens(tokenizer_name="tokenizer_1") == short_tokens assert doc.get_tokens(tokenizer_name="tokenizer_2") == TokenList([]) diff --git a/tests/unit/test_tokenizer.py b/tests/unit/test_tokenizer.py index ce118ce..463b1cd 100644 --- a/tests/unit/test_tokenizer.py +++ b/tests/unit/test_tokenizer.py @@ -215,3 +215,17 @@ def test_word_boundary_tokenizer(self): tokens = tokenizer._split_text(text) assert tokens == expected_tokens + + def test_trimming(self): + text = "Jane Keith-Lucas" + tokenizer = WordBoundaryTokenizer(keep_blanks=False) + expected_tokens = [ + Token(text="Jane", start_char=0, end_char=4), + Token(text="Keith", start_char=5, end_char=10), + Token(text="-", start_char=10, end_char=11), + Token(text="Lucas", start_char=11, end_char=16), + ] + + tokens = tokenizer._split_text(text) + + assert tokens == expected_tokens
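For reference, a minimal usage sketch of the SequenceAnnotator API introduced in this diff, mirroring the fixtures in tests/unit/process/test_annotator.py; the example text and lookup values below are illustrative only and not part of the library.

# Minimal usage sketch of SequenceAnnotator, mirroring the test fixtures above.
# The example text and the "first_names" lookup values are illustrative only.
from docdeid.document import Document
from docdeid.ds import DsCollection, LookupSet
from docdeid.process.annotator import SequenceAnnotator
from docdeid.tokenizer import WordBoundaryTokenizer

# Lookup structures referenced by the pattern.
ds = DsCollection()
ds["first_names"] = LookupSet()
ds["first_names"].add_items_from_iterable(items=["Andries", "Aziz"])

# A document tokenized without whitespace-only tokens.
doc = Document(
    text="De man heet Andries Meijer.",
    tokenizers={"default": WordBoundaryTokenizer(keep_blanks=False)},
)

# A first name from the lookup set, followed by a name-like token.
annotator = SequenceAnnotator(
    pattern=[{"lookup": "first_names"}, {"like_name": True}],
    ds=ds,
    tag="name",
)

# Expected, analogous to test_annotate: a single Annotation covering
# "Andries Meijer" (start_char=12, end_char=26) with tag "name".
print(annotator.annotate(doc))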