diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml
new file mode 100644
index 0000000..4added4
--- /dev/null
+++ b/.github/workflows/pr.yml
@@ -0,0 +1,36 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+
+name: Python test
+
+on:
+  pull_request:
+    branches: [ master ]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.6, 3.7, 3.8, 3.9, 2.7]
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        python setup.py install
+        python -m pip install -r test-requirements.txt
+    - name: Lint with flake8
+      run: |
+        # stop the build if there are Python syntax errors or undefined names
+        # Code style checks disabled until the style of the project is set.
+        ./run_tests.sh lint
+    - name: Test with pytest
+      run: |
+        ./run_tests.sh test
diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
new file mode 100644
index 0000000..95dfc04
--- /dev/null
+++ b/.github/workflows/push.yml
@@ -0,0 +1,100 @@
+# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
+# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions
+# In addition it will tag a release if setup.py is updated with a new version
+# and publish a release to pypi from the tag
+
+name: Python package
+
+on:
+  push:
+    branches: [ master ]
+
+jobs:
+  build:
+
+    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        python-version: [3.6, 3.7, 3.8, 3.9, 2.7]
+
+    steps:
+    - uses: actions/checkout@v2
+    - name: Set up Python ${{ matrix.python-version }}
+      uses: actions/setup-python@v2
+      with:
+        python-version: ${{ matrix.python-version }}
+    - name: Install dependencies
+      run: |
+        python -m pip install --upgrade pip
+        python setup.py install
+        python -m pip install -r test-requirements.txt
+    - name: Lint with flake8
+      run: |
+        # stop the build if there are Python syntax errors or undefined names
+        # Code style checks disabled until the style of the project is set.
+        ./run_tests.sh lint
+    - name: Test with pytest
+      run: |
+        ./run_tests.sh test
+
+  tag-release-if-needed:
+    runs-on: ubuntu-latest
+    outputs:
+      version: ${{ steps.tag.outputs.version }}
+    steps:
+    - uses: actions/checkout@v2
+    - id: tag
+      name: Tag release
+      env:
+        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        TWINE_USERNAME: __token__
+        TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
+      run: |
+        git remote add tag_target "https://$GITHUB_TOKEN@github.com/MycroftAI/adapt.git"
+        VERSION=$(python setup.py --version)
+        git tag -f release/v$VERSION || exit 0
+        if git push tag_target --tags; then
+          echo "New tag published on github, push to PyPI as well."
+          pip install twine wheel
+          python setup.py sdist bdist_wheel
+          twine check dist/*
+          twine upload dist/*
+          echo "Package pushed to PyPI. Prepare for mycroft-core PR."
+          echo "::set-output name=version::$VERSION"
+        fi
+
+  update-mycroft-core:
+    runs-on: ubuntu-latest
+    needs: tag-release-if-needed
+    steps:
+    - uses: actions/checkout@v2
+      with:
+        repository: MycroftAI/mycroft-core
+
+    - name: Update mycroft-core
+      env:
+        GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+      run: |
+        VERSION=${{needs.tag-release-if-needed.outputs.version}}
+        if [[ $VERSION != *"."* ]]; then
+          echo "Not a valid version number."
+          exit 1
+        elif [[ $VERSION == *"-"* ]]; then
+          echo "Pre-release suffix detected. Not pushing to mycroft-core."
+        else
+          sed -E "s/adapt-parser==[0-9]+\.[0-9]+\.[0-9]+/adapt-parser==$VERSION/" requirements/requirements.txt > tmp-requirements.txt
+          mv tmp-requirements.txt requirements/requirements.txt
+          echo "ADAPT_VERSION=$VERSION" >> $GITHUB_ENV
+        fi
+
+    - name: Create Pull Request
+      if: ${{ env.ADAPT_VERSION }}
+      uses: peter-evans/create-pull-request@v3
+      with:
+        token: ${{ secrets.BOT_TOKEN }}
+        push-to-fork: mycroft-adapt-bot/mycroft-core
+        commit-message: Update Adapt to v${{ env.ADAPT_VERSION }}
+        branch: feature/update-adapt
+        delete-branch: true
+        title: Update Adapt to v${{ env.ADAPT_VERSION }}
+        body: Automated update from mycroftai/adapt.
diff --git a/.gitignore b/.gitignore
index ab433dc..5e87f07 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,4 @@ dist/
 *.pyc
 TEST-*.xml
 .virtualenv
+.envrc
diff --git a/.travis.yml b/.travis.yml
deleted file mode 100644
index 754e423..0000000
--- a/.travis.yml
+++ /dev/null
@@ -1,10 +0,0 @@
-language: python
-python:
-  - "3.5"
-  - "3.6"
-  - "3.7"
-  - "3.8"
-install:
-  - pip install -r requirements.txt
-  - pip install -r test-requirements.txt
-script: python run_tests.py --fail-on-error
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..58989f9
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,3 @@
+README.md
+LICENSE.md
+requirements.txt
diff --git a/README.md b/README.md
index b8836d6..b349894 100644
--- a/README.md
+++ b/README.md
@@ -86,6 +86,43 @@ def register_pandora_vocab(emitter):
     for match in m.groups():
         register_vocab('Pandora Station', match)
 ```
+
+Development
+===========
+
+Glad you'd like to help!
+
+To install test and development requirements run
+
+```
+pip install -r test-requirements.txt
+```
+
+This will install the test-requirements as well as the runtime requirements for adapt.
+
+To test any changes before submitting them run
+
+```
+./run_tests.sh
+```
+
+This will run the same checks as the Github actions and verify that your code should pass with flying colours.
+
+Reporting Issues
+================
+It's often difficult to debug issues with adapt without a complete context. To facilitate simpler debugging,
+please include a serialized copy of the intent determination engine using the debug dump
+utilities.
+
+```python
+from adapt.engine import IntentDeterminationEngine
+engine = IntentDeterminationEngine()
+# Load engine with vocabulary and parsers
+
+import adapt.tools.debug as atd
+atd.dump(engine, 'debug.adapt')
+```
+
 Learn More
 ========
diff --git a/adapt/context.py b/adapt/context.py
index 8f1a8f2..e7d7b06 100644
--- a/adapt/context.py
+++ b/adapt/context.py
@@ -14,11 +14,7 @@
 #
 """
-This is to Manage Context of a Conversation
-
-Notes:
-    Comments are subject to evaluation and may not reflect intent.
-    Comments should be updated as code is clearly understood.
+Context Management code for Adapt (where context ~= persistent session state).
""" from six.moves import xrange @@ -34,7 +30,7 @@ class ContextManagerFrame(object): entities(list): Entities that belong to ContextManagerFrame metadata(object): metadata to describe context belonging to ContextManagerFrame """ - def __init__(self, entities=[], metadata={}): + def __init__(self, entities=None, metadata=None): """ Initialize ContextManagerFrame @@ -42,17 +38,15 @@ def __init__(self, entities=[], metadata={}): entities(list): List of Entities... metadata(object): metadata to describe context? """ - self.entities = entities - self.metadata = metadata + self.entities = entities or [] + self.metadata = metadata or {} - def metadata_matches(self, query={}): + def metadata_matches(self, query=None): """ Returns key matches to metadata - This will check every key in query for a matching key in metadata - returning true if every key is in metadata. query without keys - return false. - + Asserts that the contents of query exist within (logical subset of) + metadata in this frame. Args: query(object): metadata for matching @@ -64,6 +58,7 @@ def metadata_matches(self, query={}): found in self.metadata """ + query = query or {} result = len(query.keys()) > 0 for key in query.keys(): result = result and query[key] == self.metadata.get(key) @@ -82,9 +77,9 @@ def merge_context(self, tag, metadata): metadata(object): metadata containes keys to be added to self.metadata """ self.entities.append(tag) - for k in metadata.keys(): + for k, v in metadata.items(): if k not in self.metadata: - self.metadata[k] = k + self.metadata[k] = v class ContextManager(object): @@ -96,8 +91,11 @@ class ContextManager(object): def __init__(self): self.frame_stack = [] - def inject_context(self, entity, metadata={}): + def inject_context(self, entity, metadata=None): """ + Add an entity to the current context. + If metadata matches the top of the context frame stack, merge. + Else, create a new frame and push it on top of the stack. Args: entity(object): format {'data': 'Entity tag as ', @@ -106,6 +104,7 @@ def inject_context(self, entity, metadata={}): } metadata(object): dict, arbitrary metadata about the entity being added """ + metadata = metadata or {} top_frame = self.frame_stack[0] if len(self.frame_stack) > 0 else None if top_frame and top_frame.metadata_matches(metadata): top_frame.merge_context(entity, metadata) @@ -113,9 +112,9 @@ def inject_context(self, entity, metadata={}): frame = ContextManagerFrame(entities=[entity], metadata=metadata.copy()) self.frame_stack.insert(0, frame) - def get_context(self, max_frames=None, missing_entities=[]): + def get_context(self, max_frames=None, missing_entities=None): """ - Constructs a list of entities from the context. + Returns context, including decaying weights based on depth in stack. 
Args: max_frames(int): maximum number of frames to look back @@ -124,6 +123,7 @@ def get_context(self, max_frames=None, missing_entities=[]): Returns: list: a list of entities """ + missing_entities = missing_entities or [] if not max_frames or max_frames > len(self.frame_stack): max_frames = len(self.frame_stack) diff --git a/adapt/engine.py b/adapt/engine.py index b6b669b..b74e3a5 100644 --- a/adapt/engine.py +++ b/adapt/engine.py @@ -15,7 +15,6 @@ import re import heapq -import pyee from adapt.entity_tagger import EntityTagger from adapt.parser import Parser from adapt.tools.text.tokenizer import EnglishTokenizer @@ -24,7 +23,7 @@ __author__ = 'seanfitz' -class IntentDeterminationEngine(pyee.BaseEventEmitter): +class IntentDeterminationEngine(object): """ IntentDeterminationEngine @@ -34,7 +33,7 @@ class IntentDeterminationEngine(pyee.BaseEventEmitter): weighted based on the percentage of the utterance (per character) that the entity match represents. This system makes heavy use of generators to enable greedy algorithms to short circuit large portions of - computation. + computation, however making use of context or regular expressions prevents these optimizations. """ def __init__(self, tokenizer=None, trie=None): """ @@ -45,17 +44,16 @@ def __init__(self, tokenizer=None, trie=None): example EnglishTokenizer() trie(Trie): tree of matches to Entites """ - pyee.BaseEventEmitter.__init__(self) self.tokenizer = tokenizer or EnglishTokenizer() self.trie = trie or Trie() self.regular_expressions_entities = [] self._regex_strings = set() - self.tagger = EntityTagger(self.trie, self.tokenizer, self.regular_expressions_entities) self.intent_parsers = [] def __best_intent(self, parse_result, context=[]): """ - Decide the best intent + For the specified parse_result, find the intent parser with the + highest confidence match. Args: parse_result(list): results used to match the best intent. @@ -102,13 +100,18 @@ def __get_unused_context(self, parse_result, context): result_context = [c for c in context if c['key'] not in tags_keys] return result_context + @property + def tagger(self): + return EntityTagger(self.trie, self.tokenizer, + self.regular_expressions_entities) + def determine_intent(self, utterance, num_results=1, include_tags=False, context_manager=None): """ Given an utterance, provide a valid intent. Args: utterance(str): an ascii or unicode string representing natural language speech - include_tags(list): includes the parsed tags (including position and confidence) + include_tags(bool): includes the parsed tags (including position and confidence) as part of result context_manager(list): a context manager to provide context to the utterance num_results(int): a maximum number of results to be returned. @@ -116,23 +119,36 @@ def determine_intent(self, utterance, num_results=1, include_tags=False, context Returns: A generator that yields dictionaries. 
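+
+        Example (a minimal illustrative sketch, mirroring the unit tests)::
+
+            engine = IntentDeterminationEngine()
+            engine.register_entity("tree", "Entity1")
+            engine.register_intent_parser(
+                IntentBuilder("Parser1").require("Entity1").build())
+            for intent in engine.determine_intent("go to the tree"):
+                print(intent['intent_type'], intent['confidence'])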
""" parser = Parser(self.tokenizer, self.tagger) - parser.on('tagged_entities', - (lambda result: - self.emit("tagged_entities", result))) context = [] if context_manager: context = context_manager.get_context() - for result in parser.parse(utterance, N=num_results, context=context): - self.emit("parse_result", result) - # create a context without entities used in result - remaining_context = self.__get_unused_context(result, context) - best_intent, tags = self.__best_intent(result, remaining_context) - if best_intent and best_intent.get('confidence', 0.0) > 0: - if include_tags: - best_intent['__tags__'] = tags - yield best_intent + # Adapt consumers assume that results are sorted by confidence. parser + # will yield results sorted by utterance coverage, but regex + # and context entities will have different weights, and + # can influence final sorting. + requires_final_sort = self.regular_expressions_entities or context + + def generate_intents(): + for result in parser.parse(utterance, N=num_results, context=context): + # create a context without entities used in result + remaining_context = self.__get_unused_context(result, context) + best_intent, tags = self.__best_intent(result, remaining_context) + if best_intent and best_intent.get('confidence', 0.0) > 0: + if include_tags: + best_intent['__tags__'] = tags + yield best_intent + + if requires_final_sort: + sorted_iterable = sorted([ + i for i in generate_intents() + ], key=lambda x: -x.get('confidence', 0.0)) + else: + sorted_iterable = generate_intents() + + for intent in sorted_iterable: + yield intent def register_entity(self, entity_value, entity_type, alias_of=None): """ @@ -175,6 +191,75 @@ def register_intent_parser(self, intent_parser): else: raise ValueError("%s is not an intent parser" % str(intent_parser)) + def drop_intent_parser(self, parser_names): + """Drop a registered intent parser. + + Arguments: + parser_names (str or iterable): parser name to drop or list of + names + + Returns: + (bool) True if a parser was dropped else False + """ + if isinstance(parser_names, str): + parser_names = [parser_names] + + new_parsers = [p for p in self.intent_parsers + if p.name not in parser_names] + num_original_parsers = len(self.intent_parsers) + self.intent_parsers = new_parsers + + return len(self.intent_parsers) != num_original_parsers + + def drop_entity(self, entity_type=None, match_func=None): + """Drop all entities mathching the given entity type or match function + + Arguments: + entity_type (str): entity name to match against + match_func (callable): match function to find entities + + Returns: + (bool) True if vocab was found and removed otherwise False. + """ + def default_match_func(data): + return data and data[1] == entity_type + + ent_tuples = self.trie.scan(match_func or default_match_func) + for entity in ent_tuples: + self.trie.remove(*entity) + + return len(ent_tuples) != 0 + + def drop_regex_entity(self, entity_type=None, match_func=None): + """Remove registered regex entity. + + Arguments: + entity_type (str): entity name to match against + match_func (callable): match function to find entities + + Returns: + (bool) True if vocab was found and removed otherwise False. 
+ """ + def default_match_func(regexp): + return entity_type in regexp.groupindex.keys() + + match_func = match_func or default_match_func + matches = [r for r in self.regular_expressions_entities + if match_func(r)] + matching_patterns = [r.pattern for r in matches] + + matches = [ + r for r in self.regular_expressions_entities if r in matches + ] + for match in matches: + self.regular_expressions_entities.remove(match) + + self._regex_strings = { + r for r in self._regex_strings if r not in matching_patterns + } + + return len(matches) != 0 + class DomainIntentDeterminationEngine(object): """ @@ -194,11 +279,6 @@ class DomainIntentDeterminationEngine(object): def __init__(self): """ Initialize DomainIntentDeterminationEngine. - - Args: - tokenizer(tokenizer): The tokenizer you wish to use. - trie(Trie): the Trie() you wish to use. - domain(str): a string representing the domain you wish to add """ self.domains = {} @@ -368,3 +448,43 @@ def register_intent_parser(self, intent_parser, domain=0): self.register_domain(domain=domain) self.domains[domain].register_intent_parser( intent_parser=intent_parser) + + def drop_intent_parser(self, parser_names, domain): + """Drop a registered intent parser. + + Arguments: + parser_names (list, str): parser names to drop. + domain (str): domain to drop from + + Returns: + (bool) True if an intent parser was dropped else false. + """ + return self.domains[domain].drop_intent_parser(parser_names) + + def drop_entity(self, domain, entity_type=None, match_func=None): + """Drop all entities mathching the given entity type or match function. + + Arguments: + domain (str): intent domain + entity_type (str): entity name to match against + match_func (callable): match function to find entities + + Returns: + (bool) True if vocab was found and removed otherwise False. + """ + return self.domains[domain].drop_entity(entity_type=entity_type, + match_func=match_func) + + def drop_regex_entity(self, domain, entity_type=None, match_func=None): + """Remove registered regex entity. + + Arguments: + domain (str): intent domain + entity_type (str): entity name to match against + match_func (callable): match function to find entities + + Returns: + (bool) True if vocab was found and removed otherwise False. + """ + return self.domains[domain].drop_regex_entity(entity_type=entity_type, + match_func=match_func) diff --git a/adapt/intent.py b/adapt/intent.py index ae1a1d0..8abb662 100644 --- a/adapt/intent.py +++ b/adapt/intent.py @@ -15,6 +15,8 @@ __author__ = 'seanfitz' +import itertools + CLIENT_ENTITY_NAME = 'Client' @@ -30,21 +32,24 @@ def find_first_tag(tags, entity_type, after_index=-1): """Searches tags for entity type after given index Args: - tags(list): a list of tags with entity types to be compaired too entity_type + tags(list): a list of tags with entity types to be compared to + entity_type entity_type(str): This is he entity type to be looking for in tags - after_index(int): the start token must be greaterthan this. + after_index(int): the start token must be greater than this. Returns: ( tag, v, confidence ): tag(str): is the tag that matched v(str): ? the word that matched? - confidence(float): is a mesure of accuacy. 1 is full confidence and 0 is none. + confidence(float): is a measure of accuracy. 1 is full confidence + and 0 is none. 
""" for tag in tags: for entity in tag.get('entities'): for v, t in entity.get('data'): if t.lower() == entity_type.lower() and \ - (tag.get('start_token', 0) > after_index or tag.get('from_context', False)): + (tag.get('start_token', 0) > after_index or \ + tag.get('from_context', False)): return tag, v, entity.get('confidence') return None, None, None @@ -58,38 +63,37 @@ def find_next_tag(tags, end_index=0): def choose_1_from_each(lists): - """Takes a list of lists and returns a list of lists with one item - from each list. This new list should be the length of each list multiplied - by the others. 18 for an list with lists of 3, 2 and 3. Also the lenght - of each sub list should be same as the length of lists passed in. + """ + The original implementation here was functionally equivalent to + :func:`~itertools.product`, except that the former returns a generator + of lists, and itertools returns a generator of tuples. This is going to do + a light transform for now, until callers can be verified to work with + tuples. Args: - lists(list of Lists): A list of lists + A list of lists or tuples, expected as input to + :func:`~itertools.product` Returns: - list of lists: returns a list of lists constructions of one item from each - list in lists. + a generator of lists, see docs on :func:`~itertools.product` """ - if len(lists) == 0: - yield [] - else: - for el in lists[0]: - for next_list in choose_1_from_each(lists[1:]): - yield [el] + next_list + for result in itertools.product(*lists): + yield list(result) def resolve_one_of(tags, at_least_one): - """This searches tags for Entities in at_least_one and returns any match + """Search through all combinations of at_least_one rules to find a + combination that is covered by tags Args: tags(list): List of tags with Entities to search for Entities at_least_one(list): List of Entities to find in tags Returns: - object: returns None if no match is found but returns any match as an object + object: + returns None if no match is found but returns any match as an object """ - if len(tags) < len(at_least_one): - return None + for possible_resolution in choose_1_from_each(at_least_one): resolution = {} pr = possible_resolution[:] @@ -97,13 +101,15 @@ def resolve_one_of(tags, at_least_one): last_end_index = -1 if entity_type in resolution: last_end_index = resolution[entity_type][-1].get('end_token') - tag, value, c = find_first_tag(tags, entity_type, after_index=last_end_index) + tag, value, c = find_first_tag(tags, entity_type, + after_index=last_end_index) if not tag: break else: if entity_type not in resolution: resolution[entity_type] = [] resolution[entity_type].append(tag) + # Check if this is a valid resolution (all one_of rules matched) if len(resolution) == len(possible_resolution): return resolution @@ -129,21 +135,24 @@ def validate(self, tags, confidence): """Using this method removes tags from the result of validate_with_tags Returns: - intent(intent): Resuts from validate_with_tags + intent(intent): Results from validate_with_tags """ intent, tags = self.validate_with_tags(tags, confidence) return intent def validate_with_tags(self, tags, confidence): - """Validate weather tags has required entites for this intent to fire + """Validate whether tags has required entites for this intent to fire Args: tags(list): Tags and Entities used for validation - confidence(float): ? + confidence(float): The weight associate to the parse result, + as indicated by the parser. This is influenced by a parser + that uses edit distance or context. 
Returns: intent, tags: Returns intent and tags used by the intent on - falure to meat required entities then returns intent with confidence + failure to meat required entities then returns intent with + confidence of 0.0 and an empty list for tags. """ result = {'intent_type': self.name} @@ -152,7 +161,8 @@ def validate_with_tags(self, tags, confidence): used_tags = [] for require_type, attribute_name in self.requires: - required_tag, canonical_form, confidence = find_first_tag(local_tags, require_type) + required_tag, canonical_form, tag_confidence = \ + find_first_tag(local_tags, require_type) if not required_tag: result['confidence'] = 0.0 return result, [] @@ -161,35 +171,40 @@ def validate_with_tags(self, tags, confidence): if required_tag in local_tags: local_tags.remove(required_tag) used_tags.append(required_tag) - # TODO: use confidence based on edit distance and context - intent_confidence += confidence + intent_confidence += tag_confidence if len(self.at_least_one) > 0: - best_resolution = resolve_one_of(tags, self.at_least_one) + best_resolution = resolve_one_of(local_tags, self.at_least_one) if not best_resolution: result['confidence'] = 0.0 return result, [] else: for key in best_resolution: - result[key] = best_resolution[key][0].get('key') # TODO: at least one must support aliases - intent_confidence += 1.0 * best_resolution[key][0]['entities'][0].get('confidence', 1.0) - used_tags.append(best_resolution) + # TODO: at least one should support aliases + result[key] = best_resolution[key][0].get('key') + intent_confidence += \ + 1.0 * best_resolution[key][0]['entities'][0]\ + .get('confidence', 1.0) + used_tags.append(best_resolution[key][0]) if best_resolution in local_tags: - local_tags.remove(best_resolution) + local_tags.remove(best_resolution[key][0]) for optional_type, attribute_name in self.optional: - optional_tag, canonical_form, conf = find_first_tag(local_tags, optional_type) + optional_tag, canonical_form, tag_confidence = \ + find_first_tag(local_tags, optional_type) if not optional_tag or attribute_name in result: continue result[attribute_name] = canonical_form if optional_tag in local_tags: local_tags.remove(optional_tag) used_tags.append(optional_tag) - intent_confidence += 1.0 + intent_confidence += tag_confidence - total_confidence = intent_confidence / len(tags) * confidence + total_confidence = (intent_confidence / len(tags) * confidence) \ + if tags else 0.0 - target_client, canonical_form, confidence = find_first_tag(local_tags, CLIENT_ENTITY_NAME) + target_client, canonical_form, confidence = \ + find_first_tag(local_tags, CLIENT_ENTITY_NAME) result['target'] = target_client.get('key') if target_client else None result['confidence'] = total_confidence @@ -203,7 +218,7 @@ class IntentBuilder(object): Attributes: at_least_one(list): A list of Entities where one is required. - These are seperated into lists so you can have one of (A or B) and + These are separated into lists so you can have one of (A or B) and then require one of (D or F). requires(list): A list of Required Entities optional(list): A list of optional Entities @@ -213,14 +228,18 @@ class IntentBuilder(object): This is designed to allow construction of intents in one line. 
Example: - IntentBuilder("Intent").requires("A").one_of("C","D").optional("G").build() + IntentBuilder("Intent")\ + .requires("A")\ + .one_of("C","D")\ + .optional("G").build() """ def __init__(self, intent_name): """ Constructor Args: - intent_name(str): the name of the intents that this parser parses/validates + intent_name(str): the name of the intents that this parser + parses/validates """ self.at_least_one = [] self.requires = [] @@ -229,7 +248,8 @@ def __init__(self, intent_name): def one_of(self, *args): """ - The intent parser should require one of the provided entity types to validate this clause. + The intent parser should require one of the provided entity types to + validate this clause. Args: args(args): *args notation list of entity names @@ -246,7 +266,8 @@ def require(self, entity_type, attribute_name=None): Args: entity_type(str): an entity type - attribute_name(str): the name of the attribute on the parsed intent. Defaults to match entity_type. + attribute_name(str): the name of the attribute on the parsed intent. + Defaults to match entity_type. Returns: self: to continue modifications. @@ -258,11 +279,13 @@ def require(self, entity_type, attribute_name=None): def optionally(self, entity_type, attribute_name=None): """ - Parsed intents from this parser can optionally include an entity of the provided type. + Parsed intents from this parser can optionally include an entity of the + provided type. Args: entity_type(str): an entity type - attribute_name(str): the name of the attribute on the parsed intent. Defaults to match entity_type. + attribute_name(str): the name of the attribute on the parsed intent. + Defaults to match entity_type. Returns: self: to continue modifications. @@ -278,4 +301,5 @@ def build(self): :return: an Intent instance. """ - return Intent(self.name, self.requires, self.at_least_one, self.optional) + return Intent(self.name, self.requires, + self.at_least_one, self.optional) diff --git a/adapt/parser.py b/adapt/parser.py index 7ae2b30..1442caf 100644 --- a/adapt/parser.py +++ b/adapt/parser.py @@ -13,7 +13,6 @@ # limitations under the License. # -import pyee import time from adapt.expander import BronKerboschExpander from adapt.tools.text.trie import Trie @@ -21,12 +20,11 @@ __author__ = 'seanfitz' -class Parser(pyee.BaseEventEmitter): +class Parser(object): """ Coordinate a tagger and expander to yield valid parse results. """ def __init__(self, tokenizer, tagger): - pyee.BaseEventEmitter.__init__(self) self._tokenizer = tokenizer self._tagger = tagger @@ -45,7 +43,6 @@ def parse(self, utterance, context=None, N=1): utterance. This might be used to determan the most likely intent. 
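+
+            Example (a minimal illustrative sketch; assumes a tokenizer and
+            tagger built against the same trie as the engine)::
+
+                parser = Parser(tokenizer, tagger)
+                for result in parser.parse("play the big bang theory"):
+                    print(result.get('tags'), result.get('confidence'))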
""" - start = time.time() context_trie = None if context and isinstance(context, list): # sort by confidence in ascending order, so @@ -61,12 +58,6 @@ def parse(self, utterance, context=None, N=1): weight=entity.get('confidence')) tagged = self._tagger.tag(utterance.lower(), context_trie=context_trie) - self.emit("tagged_entities", - { - 'utterance': utterance, - 'tags': list(tagged), - 'time': time.time() - start - }) start = time.time() bke = BronKerboschExpander(self._tokenizer) diff --git a/adapt/tools/debug/__init__.py b/adapt/tools/debug/__init__.py new file mode 100644 index 0000000..9aee82d --- /dev/null +++ b/adapt/tools/debug/__init__.py @@ -0,0 +1,70 @@ +import pickle + +from adapt.engine import DomainIntentDeterminationEngine, \ + IntentDeterminationEngine + +""" +Pickling is not inherently secure, and the documentation for pickle +recommends never unpickling data from an untrusted source. This makes +using it for debug reports a little dicey, but fortunately there are +some things we can leverage to make things safer. + +First, we expect the deserialized object to be an instance of +IntentDeterminationEngine or DomainIntentDeterminationEngine as a basic +sanity check. We also leverage a custom `pickle.Unpickler` implementation +that only allows specific imports from the adapt namespace, and displays +a helpful error message if this assumption is validated. + +This is a bit of security theater; folks investigating issues have to know to +specifically use this library to hydrate any submissions, otherwise it's all +for naught. +""" + + +EXPECTED_ENGINES = set([ + IntentDeterminationEngine, + DomainIntentDeterminationEngine, +]) + +SAFE_CLASSES = [ + ("adapt.engine", "IntentDeterminationEngine"), + ("adapt.engine", "DomainIntentDeterminationEngine"), + ("adapt.tools.text.tokenizer", "EnglishTokenizer"), + ("adapt.tools.text.trie", "Trie"), + ("adapt.tools.text.trie", "TrieNode"), + ("adapt.intent", "Intent") +] + + +class RestrictedUnpickler(pickle.Unpickler): + def find_class(self, module, name): + if (module, name) not in SAFE_CLASSES: + raise pickle.UnpicklingError("Attempted illegal import: " + "{}.{}".format(module, name)) + return pickle.Unpickler.find_class(self, module, name) + + +def load(filename): + """ + Load a file that contains a serialized intent determination engine. + :param filename (str): source path + :return: An instance of IntentDeterminationEngine or + DomainIntentDeterminationEngine + """ + with open(filename, 'rb') as f: + engine = RestrictedUnpickler(f).load() + if engine.__class__ not in EXPECTED_ENGINES: + raise ValueError("Was expecting to instantiate an " + "IntentDeterminationEngine, but instead found " + "{}".format(engine.__class__)) + return engine + + +def dump(engine, filename): + """ + Serialize an adapt Intent engine and write it to the target file. + :param engine (IntentDeterminationEngine or DomainIntentDeterminationEngine): + :param filename (str): destination path + """ + with open(filename, 'wb') as f: + pickle.dump(engine, f) diff --git a/adapt/tools/text/tokenizer.py b/adapt/tools/text/tokenizer.py index 02c2f5a..64bf5c6 100644 --- a/adapt/tools/text/tokenizer.py +++ b/adapt/tools/text/tokenizer.py @@ -36,7 +36,7 @@ def __init__(self): pass def tokenize(self, string): - """Used to parce a string into tokens + """Used to parse a string into tokens This function is to take in a string and return a list of tokens @@ -54,18 +54,18 @@ def tokenize(self, string): done with a contraction dictionary and some preprocessing. 
""" s = string - s = re.sub('\t', " ", s) - s = re.sub("(" + regex_separator + ")", " \g<1> ", s) - s = re.sub("([^0-9]),", "\g<1> , ", s) - s = re.sub(",([^0-9])", " , \g<1>", s) - s = re.sub("^(')", "\g<1> ", s) - s = re.sub("(" + regex_not_letter_number + ")'", "\g<1> '", s) - s = re.sub("(" + regex_clitics + ")$", " \g<1>", s) - s = re.sub("(" + regex_clitics + ")(" + regex_not_letter_number + ")", " \g<1> \g<2>", s) + s = re.sub(r'\t', " ", s) + s = re.sub(r"(" + regex_separator + ")", r" \g<1> ", s) + s = re.sub(r"([^0-9]),", r"\g<1> , ", s) + s = re.sub(r",([^0-9])", r" , \g<1>", s) + s = re.sub(r"^(')", r"\g<1> ", s) + s = re.sub(r"(" + regex_not_letter_number + r")'", r"\g<1> '", s) + s = re.sub(r"(" + regex_clitics + r")$", r" \g<1>", s) + s = re.sub(r"(" + regex_clitics + r")(" + regex_not_letter_number + r")", r" \g<1> \g<2>", s) words = s.strip().split() - p1 = re.compile(".*" + regex_letter_number + "\\.") - p2 = re.compile("^([A-Za-z]\\.([A-Za-z]\\.)+|[A-Z][bcdfghj-nptvxz]+\\.)$") + p1 = re.compile(r".*" + regex_letter_number + r"\.") + p2 = re.compile(r"^([A-Za-z]\.([A-Za-z]\.)+|[A-Z][bcdfghj-nptvxz]+\.)$") token_list = [] diff --git a/adapt/tools/text/trie.py b/adapt/tools/text/trie.py index c4cc087..147931c 100644 --- a/adapt/tools/text/trie.py +++ b/adapt/tools/text/trie.py @@ -30,18 +30,29 @@ def __init__(self, data=None, is_terminal=False): def lookup(self, iterable, index=0, gather=False, edit_distance=0, max_edit_distance=0, match_threshold=0.0, matched_length=0): """ - TODO: Implement trie lookup with edit distance - Args: - iterable(list?): key used to find what is requested this could - be a generator. - index(int): index of what is requested - gather(bool): of weather to gather or not - edit_distance(int): the distance -- currently not used - max_edit_distance(int): the max distance -- not currently used + iterable(hashable): a list of items used to traverse the Trie. + This represents the position of a node in the Trie, matching the + iterable used at insertion time. + For example: + trie.insert('foo', {'bar': 'baz'}) + list(trie.lookup('foo')) == [TrieNode(data={'bar': 'baz'}, is_terminal=True)] + + index(int): index of item for current position in traversal. + we pass the original iterable and an index to avoid + the cost of repeatedly copying the original iterable + gather(bool): whether to return intermediate results (gather + algorithm) + edit_distance(int): current edit distance in the traversal. 
+ max_edit_distance(int): maximum edit distance + match_threshold(float): minimum confidence of match for discovery + matched_length(int): related to edit distance, for calculating + confidence of match where + confidence = (length - abs(matched_length - length)) / length yields: - object: yields the results of the search + generator[TrieNode]: a generator that vends the results of the + lookup, of type TrieNode """ if self.is_terminal: if index == len(iterable) or \ @@ -61,7 +72,8 @@ def lookup(self, iterable, index=0, gather=False, edit_distance=0, max_edit_dist edit_distance=edit_distance, max_edit_distance=max_edit_distance, matched_length=matched_length + 1): yield result - # if there's edit distance remaining and it's possible to match a word above the confidence threshold + # if there's edit distance remaining and it's possible to + # match a word above the confidence threshold, continue searching potential_confidence = float(index - edit_distance + (max_edit_distance - edit_distance)) / \ (float(index) + (max_edit_distance - edit_distance)) if index + max_edit_distance - edit_distance > 0 else 0.0 if edit_distance < max_edit_distance and potential_confidence > match_threshold: @@ -87,7 +99,12 @@ def insert(self, iterable, index=0, data=None, weight=1.0): """Insert new node into tree Args: - iterable(hashable): key used to find in the future. + iterable(hashable): a list of items used to traverse the Trie. + This represents the position of a node in the Trie, matching the + iterable used at insertion time. + For example: + trie.insert('foo', {'bar': 'baz'}) + list(trie.lookup('foo')) == [TrieNode(data={'bar': 'baz'}, is_terminal=True)] data(object): data associated with the key index(int): an index used for insertion. weight(float): the wait given for the item added. @@ -104,7 +121,9 @@ def insert(self, iterable, index=0, data=None, weight=1.0): self.children[iterable[index]].insert(iterable, index + 1, data) def is_prefix(self, iterable, index=0): - if iterable[index] in self.children: + if index == len(iterable): + return True + elif iterable[index] in self.children: return self.children[iterable[index]].is_prefix(iterable, index + 1) else: return False @@ -113,7 +132,12 @@ def remove(self, iterable, data=None, index=0): """Remove an element from the trie Args - iterable(hashable): key used to find what is to be removed + iterable(hashable): a list of items used to traverse the Trie. + This represents the position of a node in the Trie, matching the + iterable used at insertion time. + For example: + trie.insert('foo', {'bar': 'baz'}) + list(trie.lookup('foo')) == [TrieNode(data={'bar': 'baz'}, is_terminal=True)] data(object): data associated with the key index(int): index of what is to me removed @@ -141,12 +165,22 @@ def remove(self, iterable, data=None, index=0): class Trie(object): - """Interface for the tree + """Recursive implementation of a prefix trie (Trie) + https://en.wikipedia.org/wiki/Trie + Additionally supports #gather, a traversal whose results include + any terminal nodes visited. Attributes: root(TrieNode): parent node to start the tree - max_edit_distance(int): ? - match_threshold(int): ? 
+        max_edit_distance(int): values > 0 allow for fuzzy matching
+            with a maximum levenshtein edit distance
+            https://en.wikipedia.org/wiki/Edit_distance
+        match_threshold(float): only return values with a higher confidence
+            than this value
+
+    While most frequently used with strings, the Trie can be populated with any
+    iterable (arrays of ints, arrays of objects, arrays of strings) as long
+    as each value responds to `__hash__`.
     """
 
@@ -157,19 +191,29 @@ def __init__(self, max_edit_distance=0, match_threshold=0.0):
         max_edit_distance and match_threshold.
 
         Args:
-            max_edit_distance(int): ?
-            match_threshold(int): ?
-
-        Notes:
-            This never seems to get called with max_edit_distance or match_threshold
+            max_edit_distance(int): values > 0 allow for fuzzy matching
+                with a maximum levenshtein edit distance
+                https://en.wikipedia.org/wiki/Edit_distance
+            match_threshold(float): only return values with a higher confidence
+                than this value
         """
         self.root = TrieNode('root')
         self.max_edit_distance = max_edit_distance
         self.match_threshold = match_threshold
 
     def gather(self, iterable):
-        """Calls the lookup with gather True Passing iterable and yields
-        the result.
+        """Executes a "gather" traversal of the Trie
+
+        Result set will include any `is_terminal` nodes encountered during
+        the traversal
+
+        Args:
+            iterable(hashable): a list of items used to traverse the Trie
+                This represents the position of a node in the Trie, matching the
+                iterable used at insertion time.
+                For example:
+                    trie.insert('foo', {'bar': 'baz'})
+                    list(trie.lookup('foo')) == [TrieNode(data={'bar': 'baz'}, is_terminal=True)]
         """
         for result in self.lookup(iterable, gather=True):
             yield result
 
@@ -177,12 +221,16 @@ def lookup(self, iterable, gather=False):
         """Call the lookup on the root node with the given parameters.
 
-        Args
-            iterable(index or key): Used to retrive nodes from tree
-            gather(bool): this is passed down to the root node lookup
+        Args:
+            iterable(hashable): a list of items used to traverse the Trie
+                This represents the position of a node in the Trie, matching the
+                iterable used at insertion time.
+                For example:
+                    trie.insert('foo', {'bar': 'baz'})
+                    list(trie.lookup('foo')) == [TrieNode(data={'bar': 'baz'}, is_terminal=True)]
+            gather(bool): flag to indicate whether gather results
+                should be included
 
-        Notes:
-            max_edit_distance and match_threshold come from the init
         """
         for result in self.root.lookup(iterable, gather=gather,
@@ -192,20 +240,63 @@ def lookup(self, iterable, gather=False):
             yield result
 
     def insert(self, iterable, data=None, weight=1.0):
-        """Used to insert into he root node
-
-        Args
-            iterable(hashable): index or key used to identify
-            data(object): data to be paired with the key
+        """Used to insert into the trie
+
+        Args:
+            iterable(hashable): a list of items used to traverse the Trie
+                This represents the position of a node in the Trie, matching the
+                iterable used at insertion time.
+                For example:
+                    trie.insert('foo', {'bar': 'baz'})
+                    list(trie.lookup('foo')) == [TrieNode(data={'bar': 'baz'}, is_terminal=True)]
+            data(object): data to be stored or merged for this iterable
         """
-        self.root.insert(iterable, index=0, data=data, weight=1.0)
+        self.root.insert(iterable, index=0, data=data, weight=weight)
 
     def remove(self, iterable, data=None):
         """Used to remove from the root node
 
         Args:
-            iterable(hashable): index or key used to identify
-                item to remove
-            data: data to be paired with the key
+            iterable(hashable): a list of items used to traverse the Trie
+                This represents the position of a node in the Trie, matching the
+                iterable used at insertion time.
+                For example:
+                    trie.insert('foo', {'bar': 'baz'})
+                    list(trie.lookup('foo')) == [TrieNode(data={'bar': 'baz'}, is_terminal=True)]
+            data: data to be removed. If None, or node is empty as a result,
+                remove the node.
         """
         return self.root.remove(iterable, data=data)
+
+    def scan(self, match_func):
+        """Traverse the trie scanning for end nodes with matching data.
+
+        Args:
+            match_func (callable): function used to match data.
+
+        Returns:
+            (list) list with matching (data, value) pairs.
+        """
+        def _traverse(node, match_func, current=''):
+            """Traverse Trie searching for nodes with matching data
+
+            Performs recursive depth first search of Trie and collects
+            value / data pairs matched by the match_func
+
+            Arguments:
+                node (trie node): Node to parse
+                match_func (callable): Function performing match
+                current (str): string "position" in Trie
+
+            Returns:
+                (list) list with matching (data, value) pairs.
+            """
+            # Check if node matches
+            result = [(current, d) for d in node.data if match_func(d)]
+
+            # Traverse further down into the tree
+            for c in node.children:
+                result += _traverse(node.children[c], match_func, current + c)
+
+            return result
+
+        return _traverse(self.root, match_func)
diff --git a/examples/multi_domain_intent_parser.py b/examples/multi_domain_intent_parser.py
index 1a5bddd..8a442e7 100644
--- a/examples/multi_domain_intent_parser.py
+++ b/examples/multi_domain_intent_parser.py
@@ -9,19 +9,10 @@
 import json
 import sys
 
-from adapt.entity_tagger import EntityTagger
-from adapt.tools.text.tokenizer import EnglishTokenizer
-from adapt.tools.text.trie import Trie
 from adapt.intent import IntentBuilder
-from adapt.parser import Parser
 from adapt.engine import DomainIntentDeterminationEngine
 
-tokenizer = EnglishTokenizer()
-trie = Trie()
-tagger = EntityTagger(trie, tokenizer)
-parser = Parser(tokenizer, tagger)
-
 engine = DomainIntentDeterminationEngine()
 
 engine.register_domain('Domain1')
diff --git a/examples/multi_intent_parser.py b/examples/multi_intent_parser.py
index e5c8d4a..6416098 100644
--- a/examples/multi_intent_parser.py
+++ b/examples/multi_intent_parser.py
@@ -10,18 +10,9 @@
 import json
 import sys
 
-from adapt.entity_tagger import EntityTagger
-from adapt.tools.text.tokenizer import EnglishTokenizer
-from adapt.tools.text.trie import Trie
 from adapt.intent import IntentBuilder
-from adapt.parser import Parser
 from adapt.engine import IntentDeterminationEngine
 
-tokenizer = EnglishTokenizer()
-trie = Trie()
-tagger = EntityTagger(trie, tokenizer)
-parser = Parser(tokenizer, tagger)
-
 engine = IntentDeterminationEngine()
 
 # define vocabulary
diff --git a/publish/publish.sh b/publish/publish.sh
deleted file mode 100755
index 48e8242..0000000
--- a/publish/publish.sh
+++ /dev/null
@@ -1,98 +0,0 @@
-#!/usr/bin/env bash
-
-set -Ee  # fail on error
-
-# pypi and testpypi .python.org credential variables must be set
-#TESTPYPI_USERNAME=
-#TESTPYPI_PASSWORD=
-#PYPI_USERNAME=
-#PYPI_PASSWORD=
-
-# set top of working directory
-TOP=$(cd $(dirname $0)/.. && pwd -L)
-echo "The working directory top is ${TOP}"
-
-# set virtualenv root
-VIRTUALENV_ROOT=${VIRTUALENV_ROOT:-"${HOME}/.virtualenvs/adapt"}
-echo "The virtualenv root location is ${VIRTUALENV_ROOT}"
-
-# get latest release version
-VERSION="$(basename $(git for-each-ref --format="%(refname:short)" --sort=-authordate --count=1 refs/tags) | sed -e 's/v//g')"
-echo "The latest adapt release version is ${VERSION}"
-
-# check out tagged version
-git checkout release/v${VERSION}
-
-# get setup.py version
-PYPI_VERSION=$(python ${TOP}/setup.py --version)
-echo "The adapt version found in setup.py is ${PYPI_VERSION}"
-
-# verify release tag and setup.py version are equal
-if [[ ${VERSION} != ${PYPI_VERSION} ]]; then
-    echo "setup.py and release tag version are inconsistent."
-    echo "please update setup.py and verify release"
-    exit 1
-
-fi
-
-# clean virtualenv and remove previous test results
-echo "Removing previous virtualenv and test results if they exist"
-rm -rf ${VIRTUALENV_ROOT} TEST-*.xml
-
-# create virtualenv
-echo "Creating virtualenv"
-virtualenv ${VIRTUALENV_ROOT}
-# activate virtualenv
-. ${VIRTUALENV_ROOT}/bin/activate
-
-echo "Installing adapt requirements.txt"
-# install adapt requirements
-pip install -r requirements.txt
-
-echo "Installing adapt test-requiremtns.txt"
-# install adapt test runner requirements
-pip install -r test-requirements.txt
-
-# run unit tests
-python run_tests.py
-
-function replace() { # @Seanfitz is the bomb.com
-    local FILE=$1
-    local PATTERN=$2
-    local VALUE=$3
-    local TMP_FILE="/tmp/$$.replace"
-    cat ${FILE} | sed -e "s/${PATTERN}/${VALUE}/g" > ${TMP_FILE}
-    mv ${TMP_FILE} ${FILE}
-}
-
-echo "Creating ~/.pypirc from template"
-PYPIRC_FILE=~/.pypirc
-cp -v ${TOP}/publish/pypirc.template ${PYPIRC_FILE}
-replace ${PYPIRC_FILE} %%PYPI_USERNAME%% ${PYPI_USERNAME}
-replace ${PYPIRC_FILE} %%PYPI_PASSWORD%% ${PYPI_PASSWORD}
-replace ${PYPIRC_FILE} %%TESTPYPI_USERNAME%% ${TESTPYPI_USERNAME}
-replace ${PYPIRC_FILE} %%TESTPYPI_PASSWORD%% ${TESTPYPI_PASSWORD}
-# make .pyric private
-chmod -v 600 ${PYPIRC_FILE}
-
-echo "Registering at pypitest.python.org"
-python setup.py register -r pypitest
-echo "Uploading to pypitest.python.org"
-python setup.py sdist upload -r pypitest
-
-echo "testing installation from testpypi.python.org"
-PYPI_TEST_VIRTUALENV='/tmp/.virtualenv'
-rm -Rvf ${PYPI_TEST_VIRTUALENV}
-virtualenv ${PYPI_TEST_VIRTUALENV}
-deactivate
-. ${PYPI_TEST_VIRTUALENV}/bin/activate
-pip install -r requirements.txt
-pip install -i https://testpypi.python.org/pypi adapt-parser==${VERSION}
-deactivate
-rm -Rvf ${PYPI_TEST_VIRTUALENV}
-
-. ${VIRTUALENV_ROOT}/bin/activate
-echo "Registering at pypi.python.org"
-python setup.py register -r pypi
-echo "Uploading to pypi.python.org"
-python setup.py sdist upload -r pypi
diff --git a/publish/pypirc.template b/publish/pypirc.template
deleted file mode 100644
index 532b45a..0000000
--- a/publish/pypirc.template
+++ /dev/null
@@ -1,14 +0,0 @@
-[distutils]
-index-servers =
-    pypi
-    pypitest
-
-[pypi]
-repository=https://pypi.python.org/pypi
-username=%%PYPI_USERNAME%%
-password=%%PYPI_PASSWORD%%
-
-[pypitest]
-repository=https://testpypi.python.org/pypi
-username=%%TESTPYPI_USERNAME%%
-password=%%TESTPYPI_PASSWORD%%
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 719ceb5..dde3933 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1 @@
-argparse==1.2.1
-pyee==7.0.1
 six>=1.10.0
diff --git a/run_tests.py b/run_tests.py
deleted file mode 100644
index 8c460ea..0000000
--- a/run_tests.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import unittest
-
-from xmlrunner import XMLTestRunner
-import os
-import sys
-
-loader = unittest.TestLoader()
-tests = loader.discover(os.path.dirname(os.path.realpath(__file__)), pattern="*Test.py")
-fail_on_error = "--fail-on-error" in sys.argv
-runner = XMLTestRunner()
-result = runner.run(tests)
-if fail_on_error and len(result.failures + result.errors) > 0:
-    sys.exit(1)
diff --git a/run_tests.sh b/run_tests.sh
new file mode 100755
index 0000000..11a99b6
--- /dev/null
+++ b/run_tests.sh
@@ -0,0 +1,40 @@
+#! /bin/bash
+
+ADAPT_DIR=$(dirname $0)
+
+do_lint () {
+    flake8 "${ADAPT_DIR}/adapt" --select=E9,F63,F7,F82 --show-source && \
+    flake8 "${ADAPT_DIR}/test" --select=E9,F63,F7,F82 --show-source
+}
+
+do_test () {
+    pytest "${ADAPT_DIR}/test/"*
+}
+
+show_help () {
+    echo "Tests for adapt."
+    echo "If no arguments are given, both testing and linting are performed."
+    echo "Otherwise the argument will determine which part is performed."
+    echo ""
+    echo "    Usage: $0 [test/lint]"
+    echo ""
+    echo "Arguments:"
+    echo "  test: Only run the tests."
+    echo "  lint: Only perform codestyle and static analysis."
+}
+
+if [[ $# == 0 ]]; then
+    do_lint || exit $?  # Exit on failure
+    do_test || exit $?  # Exit on failure
+elif [[ $# == 1 ]]; then
+    if [[ $1 == "lint" ]]; then
+        do_lint
+    elif [[ $1 == "test" ]]; then
+        do_test
+    else
+        show_help
+    fi
+else
+    show_help
+fi
+
diff --git a/setup.py b/setup.py
index 71296ba..3966aa0 100644
--- a/setup.py
+++ b/setup.py
@@ -15,21 +15,46 @@
 
 __author__ = 'seanfitz'
 
+import os
 from setuptools import setup
 
+with open("README.md", "r", encoding="utf_8") as fh:
+    long_description = fh.read()
+
+
+def required(requirements_file):
+    """Read requirements file and remove comments and empty lines."""
+    base_dir = os.path.abspath(os.path.dirname(__file__))
+    with open(os.path.join(base_dir, requirements_file), 'r') as f:
+        requirements = f.read().splitlines()
+        return [pkg for pkg in requirements
+                if pkg.strip() and not pkg.startswith("#")]
+
+
 setup(
-    name = "adapt-parser",
-    version = "0.3.6",
-    author = "Sean Fitzgerald",
-    author_email = "sean@fitzgeralds.me",
-    description = ("A text-to-intent parsing framework."),
-    license = ("Apache License 2.0"),
-    keywords = "natural language processing",
-    url = "https://github.com/MycroftAI/adapt",
-    packages = ["adapt", "adapt.tools", "adapt.tools.text"],
-
-    install_requires = [
-        "pyee==7.0.1",
-        "six>=1.10.0"
-    ]
+    name="adapt-parser",
+    version="1.0.0",
+    author="Sean Fitzgerald",
+    author_email="sean@fitzgeralds.me",
+    description=("A text-to-intent parsing framework."),
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    license=("Apache License 2.0"),
+    keywords="natural language processing",
+    url="https://github.com/MycroftAI/adapt",
+    packages=["adapt", "adapt.tools", "adapt.tools.text", "adapt.tools.debug"],
+    classifiers=[
+        'Development Status :: 5 - Production/Stable',
+        'Intended Audience :: Developers',
+        'Topic :: Text Processing :: Linguistic',
+        'License :: OSI Approved :: Apache Software License',
+
+        'Programming Language :: Python :: 2.7',
+        'Programming Language :: Python :: 3.5',
+        'Programming Language :: Python :: 3.6',
+        'Programming Language :: Python :: 3.7',
+        'Programming Language :: Python :: 3.8',
+        'Programming Language :: Python :: 3.9',
+    ],
+
+    install_requires=required('requirements.txt')
 )
diff --git a/test-requirements.txt b/test-requirements.txt
index dfe4a1c..19f6206 100644
--- a/test-requirements.txt
+++ b/test-requirements.txt
@@ -1 +1,3 @@
-xmlrunner==1.7.7
+flake8
+pytest
+-r requirements.txt
diff --git a/test/ContextManagerTest.py b/test/ContextManagerTest.py
index c626337..076a4d9 100644
--- a/test/ContextManagerTest.py
+++ b/test/ContextManagerTest.py
@@ -15,7 +15,43 @@
 
 import unittest
 
-from adapt.context import ContextManager
+from adapt.context import ContextManager, ContextManagerFrame
+
+
+class ContextManagerFrameTest(unittest.TestCase):
+    def setUp(self):
+        pass
+
+    def testMetadataMatches(self):
+        frame1 = ContextManagerFrame(entities=['foo'],
+                                     metadata={'domain': 'music',
+                                               'foo': 'test'})
+
+        self.assertTrue(frame1.metadata_matches({'domain': 'music'}),
+                        "Should match subset of metadata")
+
+        self.assertFalse(frame1.metadata_matches({'domain': 'weather'}),
+                         "Should not match metadata value mismatch")
+        self.assertTrue(
+            frame1.metadata_matches({'domain': 'music', 'foo': 'test'}),
+            "Should match exact metadata")
+        self.assertFalse(frame1.metadata_matches(
+            {'domain': 'music', 'foo': 'test', 'bar': 'test'}),
+            "Should not match superset of metadata")
+
+    def testMergeContext(self):
+        frame1 = ContextManagerFrame(entities=['foo'],
+                                     metadata={'domain': 'music',
+                                               'foo': 'test'})
+
+        self.assertFalse(frame1.metadata_matches({'bar': 'test'}),
+                         "Should not match before merging context")
+
+        frame1.merge_context('bar', {'domain': 'music', 'bar': 'test'})
+        self.assertTrue(frame1.metadata_matches({'domain': 'music'}),
+                        "Should continue to match subset of metadata")
+        self.assertTrue(frame1.metadata_matches({'bar': 'test'}),
+                        "Should match after merging context")
 
 
 class ContextManagerTest(unittest.TestCase):
@@ -54,7 +90,7 @@ def testNewContextNoMetadataResultsInNewFrame(self):
         assert len(context) == 2
         assert context[0].get('confidence') == 0.5
         assert context[0].get('data') == 'Film'
-        assert context[1].get('confidence') == 1.0/3.0
+        assert context[1].get('confidence') == 1.0 / 3.0
         assert context[1].get('data') == 'Book'
 
     def testNewContextWithMetadataSameFrame(self):
@@ -74,4 +110,3 @@ def testNewContextWithMetadataSameFrame(self):
         assert context[0].get('data') == 'Book'
         assert context[1].get('confidence') == 0.5
         assert context[1].get('data') == 'Film'
-
diff --git a/test/DomainIntentEngineTest.py b/test/DomainIntentEngineTest.py
index 4bcfb7c..075b78d 100644
--- a/test/DomainIntentEngineTest.py
+++ b/test/DomainIntentEngineTest.py
@@ -219,3 +219,53 @@ def test_select_best_intent_enuse_enitities_dont_register_in_multiple_domains(se
         intents = self.engine.determine_intent(utterance, 1)
         for intent in intents:
             self.assertNotEqual(intent['intent_type'], 'Parser2')
+
+    def test_drop_intent_from_domain(self):
+        """Test that intent is dropped from the correct domain."""
+        self.engine.register_domain('Domain1')
+        self.engine.register_domain('Domain2')
+
+        # Creating first intent domain
+        parser1 = IntentBuilder("Parser1").require("Entity1").build()
+        self.engine.register_intent_parser(parser1, domain='Domain1')
+        self.engine.register_entity("tree", "Entity1", domain='Domain1')
+
+        # Creating second intent domain
+        parser2 = IntentBuilder("Parser2").require("Entity2").build()
+        self.engine.register_intent_parser(parser2, domain="Domain2")
+        self.engine.register_entity("house", "Entity2", domain="Domain2")
+
+        self.engine.drop_intent_parser(domain="Domain2",
+                                       parser_names=['Parser2'])
+        self.assertEqual(len(self.engine.domains['Domain2'].intent_parsers), 0)
+
+    def test_drop_entity_from_domain(self):
+        """Test that entity is dropped from domain."""
+        self.engine.register_domain('Domain1')
+        self.engine.register_domain('Domain2')
+
+        # Creating first intent domain
+        parser1 = IntentBuilder("Parser1").require("Entity1").build()
+        self.engine.register_intent_parser(parser1, domain='Domain1')
+        self.engine.register_entity("tree", "Entity1", domain='Domain1')
+
+        # Creating second intent domain
+        parser2 = IntentBuilder("Parser2").require("Entity2").build()
+        self.engine.register_intent_parser(parser2, domain="Domain2")
+        self.engine.register_entity("house", "Entity2", domain="Domain2")
+
+        self.assertTrue(self.engine.drop_entity(domain="Domain2",
+                                                entity_type='Entity2'))
+
+    def testDropRegexEntity(self):
+        self.engine.register_domain("Domain1")
+        self.engine.register_domain("Domain2")
+
+        self.engine.register_regex_entity(r"the dog (?P<Dog>.*)",
+                                          "Domain1")
+        self.engine.register_regex_entity(r"the cat (?P<Cat>.*)",
+                                          "Domain2")
+        self.assertTrue(self.engine.drop_regex_entity(domain='Domain2',
+                                                      entity_type='Cat'))
+        self.assertFalse(self.engine.drop_regex_entity(domain='Domain1',
+                                                       entity_type='Cat'))
diff --git a/test/IntentEngineTest.py b/test/IntentEngineTest.py
index 43b6d62..7cd4203 100644
--- a/test/IntentEngineTest.py
+++ b/test/IntentEngineTest.py
@@ -58,3 +58,170 @@ def testSelectBestIntent(self):
testSelectBestIntent(self): intent = next(self.engine.determine_intent(utterance)) assert intent assert intent['intent_type'] == 'Parser2' + + def testDropIntent(self): + parser1 = IntentBuilder("Parser1").require("Entity1").build() + self.engine.register_intent_parser(parser1) + self.engine.register_entity("tree", "Entity1") + + parser2 = (IntentBuilder("Parser2").require("Entity1") + .require("Entity2").build()) + self.engine.register_intent_parser(parser2) + self.engine.register_entity("house", "Entity2") + + utterance = "go to the tree house" + + intent = next(self.engine.determine_intent(utterance)) + assert intent + assert intent['intent_type'] == 'Parser2' + + assert self.engine.drop_intent_parser('Parser2') is True + intent = next(self.engine.determine_intent(utterance)) + assert intent + assert intent['intent_type'] == 'Parser1' + + def testDropEntity(self): + parser1 = IntentBuilder("Parser1").require("Entity1").build() + self.engine.register_intent_parser(parser1) + self.engine.register_entity("laboratory", "Entity1") + self.engine.register_entity("lab", "Entity1") + + utterance = "get out of my lab" + utterance2 = "get out of my laboratory" + intent = next(self.engine.determine_intent(utterance)) + assert intent + assert intent['intent_type'] == 'Parser1' + + intent = next(self.engine.determine_intent(utterance2)) + assert intent + assert intent['intent_type'] == 'Parser1' + + # Remove Entity and re-register laboratory and make sure only that + # matches. + self.engine.drop_entity(entity_type='Entity1') + self.engine.register_entity("laboratory", "Entity1") + + # Sentence containing lab should not produce any results + with self.assertRaises(StopIteration): + intent = next(self.engine.determine_intent(utterance)) + + # But sentence with laboratory should + intent = next(self.engine.determine_intent(utterance2)) + assert intent + assert intent['intent_type'] == 'Parser1' + + def testCustomDropEntity(self): + parser1 = (IntentBuilder("Parser1").one_of("Entity1", "Entity2") + .build()) + self.engine.register_intent_parser(parser1) + self.engine.register_entity("laboratory", "Entity1") + self.engine.register_entity("lab", "Entity2") + + utterance = "get out of my lab" + utterance2 = "get out of my laboratory" + intent = next(self.engine.determine_intent(utterance)) + assert intent + assert intent['intent_type'] == 'Parser1' + + intent = next(self.engine.determine_intent(utterance2)) + assert intent + assert intent['intent_type'] == 'Parser1' + + def matcher(data): + return data[1].startswith('Entity') + + self.engine.drop_entity(match_func=matcher) + self.engine.register_entity("laboratory", "Entity1") + + # Sentence containing lab should not produce any results + with self.assertRaises(StopIteration): + intent = next(self.engine.determine_intent(utterance)) + + # But sentence with laboratory should + intent = next(self.engine.determine_intent(utterance2)) + assert intent + + def testDropRegexEntity(self): + self.engine.register_regex_entity(r"the dog (?P.*)") + self.engine.register_regex_entity(r"the cat (?P.*)") + assert len(self.engine._regex_strings) == 2 + assert len(self.engine.regular_expressions_entities) == 2 + self.engine.drop_regex_entity(entity_type='Cat') + assert len(self.engine._regex_strings) == 1 + assert len(self.engine.regular_expressions_entities) == 1 + + def testCustomDropRegexEntity(self): + self.engine.register_regex_entity(r"the dog (?P.*)") + self.engine.register_regex_entity(r"the cat (?P.*)") + self.engine.register_regex_entity(r"the mangy dog (?P.*)") + 
+        assert len(self.engine._regex_strings) == 3
+        assert len(self.engine.regular_expressions_entities) == 3
+
+        def matcher(regexp):
+            """Matcher for all match groups defined for SkillB"""
+            match_groups = regexp.groupindex.keys()
+            return any([k.startswith('SkillB') for k in match_groups])
+
+        self.engine.drop_regex_entity(match_func=matcher)
+        assert len(self.engine._regex_strings) == 2
+        assert len(self.engine.regular_expressions_entities) == 2
+
+    def testAddingOfRemovedRegexp(self):
+        self.engine.register_regex_entity(r"the cool (?P<thing>.*)")
+
+        def matcher(regexp):
+            """Matcher for the 'thing' match group"""
+            match_groups = regexp.groupindex.keys()
+            return any([k.startswith('thing') for k in match_groups])
+
+        self.engine.drop_regex_entity(match_func=matcher)
+        assert len(self.engine.regular_expressions_entities) == 0
+        self.engine.register_regex_entity(r"the cool (?P<thing>.*)")
+        assert len(self.engine.regular_expressions_entities) == 1
+
+    def testUsingOfRemovedRegexp(self):
+        self.engine.register_regex_entity(r"the cool (?P<thing>.*)")
+        parser = IntentBuilder("Intent").require("thing").build()
+        self.engine.register_intent_parser(parser)
+
+        def matcher(regexp):
+            """Matcher for the 'thing' match group"""
+            match_groups = regexp.groupindex.keys()
+            return any([k.startswith('thing') for k in match_groups])
+
+        self.engine.drop_regex_entity(match_func=matcher)
+        assert len(self.engine.regular_expressions_entities) == 0
+
+        utterance = "the cool cat"
+        intents = [match for match in self.engine.determine_intent(utterance)]
+        assert len(intents) == 0
+
+    def testEmptyTags(self):
+        # Validates https://github.com/MycroftAI/adapt/issues/114
+        engine = IntentDeterminationEngine()
+        engine.register_entity("Kevin",
+                               "who")  # same problem if several entities
+        builder = IntentBuilder("Buddies")
+        builder.optionally("who")  # same problem if several entity types
+        engine.register_intent_parser(builder.build())
+
+        intents = [i for i in engine.determine_intent("Julien is a friend")]
+        assert len(intents) == 0
+
+    def testResultsAreSortedByConfidence(self):
+        self.engine.register_entity('what is', 'Query', None)
+        self.engine.register_entity('weather', 'Weather', None)
+        self.engine.register_regex_entity('(at|in) (?P<Location>.+)')
+        self.engine.register_regex_entity('(?P<Entity>.*)')
+
+        i = IntentBuilder("CurrentWeatherIntent").require(
+            "Weather").optionally("Location").build()
+        self.engine.register_intent_parser(i)
+        utterance = "what is the weather like in stockholm"
+        intents = [
+            i for i in self.engine.determine_intent(utterance, num_results=100)
+        ]
+        confidences = [intent.get('confidence', 0.0) for intent in intents]
+        assert len(confidences) > 1
+        assert all(confidences[i] >= confidences[i+1]
+                   for i in range(len(confidences)-1))
+
diff --git a/test/IntentTest.py b/test/IntentTest.py
index 37431ea..0476b02 100644
--- a/test/IntentTest.py
+++ b/test/IntentTest.py
@@ -17,7 +17,7 @@ import unittest
 
 from adapt.parser import Parser
 from adapt.entity_tagger import EntityTagger
-from adapt.intent import IntentBuilder, resolve_one_of
+from adapt.intent import IntentBuilder, resolve_one_of, choose_1_from_each
 from adapt.tools.text.tokenizer import EnglishTokenizer
 from adapt.tools.text.trie import Trie
 
@@ -30,12 +30,15 @@ def setUp(self):
         self.trie = Trie()
         self.tokenizer = EnglishTokenizer()
         self.regex_entities = []
-        self.tagger = EntityTagger(self.trie, self.tokenizer, regex_entities=self.regex_entities)
+        self.tagger = EntityTagger(self.trie, self.tokenizer,
+                                   regex_entities=self.regex_entities)
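+        # The tagger resolves entities from both the trie and the regex
+        # entity list when tagging an utterance.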
self.trie.insert("play", ("play", "PlayVerb")) self.trie.insert("stop", ("stop", "StopVerb")) - self.trie.insert("the big bang theory", ("the big bang theory", "Television Show")) + self.trie.insert("the big bang theory", + ("the big bang theory", "Television Show")) self.trie.insert("the big", ("the big", "Not a Thing")) - self.trie.insert("barenaked ladies", ("barenaked ladies", "Radio Station")) + self.trie.insert("barenaked ladies", + ("barenaked ladies", "Radio Station")) self.trie.insert("show", ("show", "Command")) self.trie.insert("what", ("what", "Question")) self.parser = Parser(self.tokenizer, self.tagger) @@ -44,29 +47,32 @@ def tearDown(self): pass def test_basic_intent(self): - intent = IntentBuilder("play television intent")\ - .require("PlayVerb")\ - .require("Television Show")\ + intent = IntentBuilder("play television intent") \ + .require("PlayVerb") \ + .require("Television Show") \ .build() for result in self.parser.parse("play the big bang theory"): - result_intent = intent.validate(result.get('tags'), result.get('confidence')) + result_intent = intent.validate(result.get('tags'), + result.get('confidence')) assert result_intent.get('confidence') > 0.0 assert result_intent.get('PlayVerb') == 'play' assert result_intent.get('Television Show') == "the big bang theory" def test_at_least_one(self): - intent = IntentBuilder("play intent")\ - .require("PlayVerb")\ - .one_of("Television Show", "Radio Station")\ + intent = IntentBuilder("play intent") \ + .require("PlayVerb") \ + .one_of("Television Show", "Radio Station") \ .build() for result in self.parser.parse("play the big bang theory"): - result_intent = intent.validate(result.get('tags'), result.get('confidence')) + result_intent = intent.validate(result.get('tags'), + result.get('confidence')) assert result_intent.get('confidence') > 0.0 assert result_intent.get('PlayVerb') == 'play' assert result_intent.get('Television Show') == "the big bang theory" for result in self.parser.parse("play the barenaked ladies"): - result_intent = intent.validate(result.get('tags'), result.get('confidence')) + result_intent = intent.validate(result.get('tags'), + result.get('confidence')) assert result_intent.get('confidence') > 0.0 assert result_intent.get('PlayVerb') == 'play' assert result_intent.get('Radio Station') == "barenaked ladies" @@ -76,14 +82,16 @@ def test_at_least_one_with_tag_in_multiple_slots(self): self.trie.insert("living room", ("living room", "living room")) self.trie.insert("what is", ("what is", "what is")) - intent = IntentBuilder("test intent")\ - .one_of("what is")\ - .one_of("temperature", "living room")\ - .one_of("temperature")\ + intent = IntentBuilder("test intent") \ + .one_of("what is") \ + .one_of("temperature", "living room") \ + .one_of("temperature") \ .build() - for result in self.parser.parse("what is the temperature in the living room"): - result_intent = intent.validate(result.get("tags"), result.get("confidence")) + for result in self.parser.parse( + "what is the temperature in the living room"): + result_intent = intent.validate(result.get("tags"), + result.get("confidence")) assert result_intent.get("confidence") > 0.0 assert result_intent.get("temperature") == "temperature" assert result_intent.get("living room") == "living room" @@ -94,12 +102,14 @@ def test_at_least_on_no_required(self): .one_of("Television Show", "Radio Station") \ .build() for result in self.parser.parse("play the big bang theory"): - result_intent = intent.validate(result.get('tags'), result.get('confidence')) + 
             assert result_intent.get('confidence') > 0.0
             assert result_intent.get('Television Show') == "the big bang theory"
 
         for result in self.parser.parse("play the barenaked ladies"):
-            result_intent = intent.validate(result.get('tags'), result.get('confidence'))
+            result_intent = intent.validate(result.get('tags'),
+                                            result.get('confidence'))
             assert result_intent.get('confidence') > 0.0
             assert result_intent.get('Radio Station') == "barenaked ladies"
@@ -109,46 +119,51 @@ def test_at_least_one_alone(self):
             .build()
 
         for result in self.parser.parse("show"):
-            result_intent = intent.validate(result.get('tags'), result.get('confidence'))
+            result_intent = intent.validate(result.get('tags'),
+                                            result.get('confidence'))
             assert result_intent.get('confidence') > 0.0
             assert result_intent.get('Command') == "show"
 
     def test_basic_intent_with_alternate_names(self):
-        intent = IntentBuilder("play television intent")\
-            .require("PlayVerb", "Play Verb")\
-            .require("Television Show", "series")\
+        intent = IntentBuilder("play television intent") \
+            .require("PlayVerb", "Play Verb") \
+            .require("Television Show", "series") \
             .build()
         for result in self.parser.parse("play the big bang theory"):
-            result_intent = intent.validate(result.get('tags'), result.get('confidence'))
+            result_intent = intent.validate(result.get('tags'),
+                                            result.get('confidence'))
             assert result_intent.get('confidence') > 0.0
             assert result_intent.get('Play Verb') == 'play'
             assert result_intent.get('series') == "the big bang theory"
 
     def test_intent_with_regex_entity(self):
         self.trie = Trie()
-        self.tagger = EntityTagger(self.trie, self.tokenizer, self.regex_entities)
+        self.tagger = EntityTagger(self.trie, self.tokenizer,
+                                   self.regex_entities)
         self.parser = Parser(self.tokenizer, self.tagger)
         self.trie.insert("theory", ("theory", "Concept"))
         regex = re.compile(r"the (?P<Event>.*)")
         self.regex_entities.append(regex)
-        intent = IntentBuilder("mock intent")\
-            .require("Event")\
+        intent = IntentBuilder("mock intent") \
+            .require("Event") \
            .require("Concept").build()
         for result in self.parser.parse("the big bang theory"):
-            result_intent = intent.validate(result.get('tags'), result.get('confidence'))
+            result_intent = intent.validate(result.get('tags'),
+                                            result.get('confidence'))
             assert result_intent.get('confidence') > 0.0
             assert result_intent.get('Event') == 'big bang'
             assert result_intent.get('Concept') == "theory"
 
     def test_intent_using_alias(self):
         self.trie.insert("big bang", ("the big bang theory", "Television Show"))
-        intent = IntentBuilder("play television intent")\
-            .require("PlayVerb", "Play Verb")\
-            .require("Television Show", "series")\
+        intent = IntentBuilder("play television intent") \
+            .require("PlayVerb", "Play Verb") \
+            .require("Television Show", "series") \
             .build()
         for result in self.parser.parse("play the big bang theory"):
-            result_intent = intent.validate(result.get('tags'), result.get('confidence'))
+            result_intent = intent.validate(result.get('tags'),
+                                            result.get('confidence'))
             assert result_intent.get('confidence') > 0.0
             assert result_intent.get('Play Verb') == 'play'
             assert result_intent.get('series') == "the big bang theory"
@@ -312,3 +327,175 @@ def test_resolve_one_of(self):
         }
 
         assert resolve_one_of(tags, at_least_one) == result
+
+
+# noinspection PyPep8Naming
+def TestTag(tag_name,
+            tag_value,
+            tag_confidence=1.0,
+            entity_confidence=1.0,
+            match=None):
+    """
+    Create a dict in the shape of a tag as yielded from the parser.
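+
+    The scoring tests below feed these hand-built tags directly into
+    Intent.validate_with_tags, so no parser run is needed.
+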
+    :param tag_name: tag name (equivalent to a label)
+    :param tag_value: tag value (the value being labeled)
+    :param tag_confidence: confidence of the parse of the tag, influenced
+        by fuzzy matching or context
+    :param entity_confidence: weight of the entity, influenced by context
+    :param match: the text matched by the parser, which may not match
+        tag_value in the case of an alias or fuzzy matching. Defaults to
+        tag_value.
+
+    Sets the "from_context" attribute to force token positioning to be
+    ignored.
+
+    :return: a dict that matches the shape of a parser tag
+    """
+    return {
+        "confidence": tag_confidence,
+        "entities": [
+            {
+                "confidence": entity_confidence,
+                "data": [
+                    [
+                        tag_value,
+                        tag_name
+                    ]
+                ],
+                "key": tag_value,
+                "match": match or tag_value
+            }
+        ],
+        "key": tag_value,
+        "match": match or tag_value,
+        "start_token": -1,
+        "end_token": -1,
+        "from_context": True
+    }
+
+
+class IntentUtilityFunctionsTest(unittest.TestCase):
+    def test_choose_1_from_each_empty(self):
+        expected = []
+        actual = list(choose_1_from_each([[]]))
+        self.assertListEqual(expected, actual)
+
+    def test_choose_1_from_each_basic(self):
+        inputs = [
+            ['A', 'B'],
+            ['C', 'D']
+        ]
+        expected = [
+            ['A', 'C'],
+            ['A', 'D'],
+            ['B', 'C'],
+            ['B', 'D']
+        ]
+        actual = list(choose_1_from_each(inputs))
+        self.assertListEqual(expected, actual)
+
+    def test_choose_1_from_each_varying_sizes(self):
+        inputs = [
+            ['A'],
+            ['B', 'C'],
+            ['D', 'E', 'F']
+        ]
+
+        expected = [
+            ['A', 'B', 'D'],
+            ['A', 'B', 'E'],
+            ['A', 'B', 'F'],
+            ['A', 'C', 'D'],
+            ['A', 'C', 'E'],
+            ['A', 'C', 'F'],
+        ]
+
+        actual = list(choose_1_from_each(inputs))
+        self.assertListEqual(expected, actual)
+
+
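+# The scoring tests below pin down how tag data drives intent confidence:
+# a reduced entity weight (as carried by regex entities) and a reduced
+# parse confidence (as produced by fuzzy matching) should each scale the
+# final intent confidence by the same factor.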
+class IntentScoringTest(unittest.TestCase):
+    def setUp(self):
+        self.require_intent = IntentBuilder('require_intent'). \
+            require('required'). \
+            build()
+        self.one_of_intent = IntentBuilder('one_of_intent'). \
+            one_of('one_of_1', 'one_of_2'). \
+            build()
+        self.optional_intent = IntentBuilder('optional_intent'). \
+            optionally('optional'). \
+            build()
+        self.all_features_intent = IntentBuilder('test_intent'). \
+            require('required'). \
+            one_of('one_of_1', 'one_of_2'). \
+            optionally('optional'). \
+            build()
+
+    def test_basic_scoring_default_weights(self):
+        required = TestTag('required', 'foo')
+        one_of_1 = TestTag('one_of_1', 'bar')
+        optional = TestTag('optional', 'bing')
+
+        intent, tags = \
+            self.require_intent.validate_with_tags([required],
+                                                   confidence=1.0)
+        self.assertEqual(1.0, intent.get('confidence'))
+        self.assertListEqual([required], tags)
+
+        intent, tags = \
+            self.one_of_intent.validate_with_tags([one_of_1],
+                                                  confidence=1.0)
+        self.assertEqual(1.0, intent.get('confidence'))
+        self.assertListEqual([one_of_1], tags)
+
+        intent, tags = \
+            self.optional_intent.validate_with_tags([optional],
+                                                    confidence=1.0)
+        self.assertEqual(1.0, intent.get('confidence'))
+        self.assertListEqual([optional], tags)
+
+    def test_weighted_scoring_from_regex_entities(self):
+        required = TestTag('required', 'foo', entity_confidence=0.5)
+        one_of_1 = TestTag('one_of_1', 'bar', entity_confidence=0.5)
+        optional = TestTag('optional', 'bing', entity_confidence=0.5)
+
+        intent, tags = \
+            self.require_intent.validate_with_tags([required],
+                                                   confidence=1.0)
+        self.assertEqual(0.5, intent.get('confidence'))
+        self.assertListEqual([required], tags)
+
+        intent, tags = \
+            self.one_of_intent.validate_with_tags([one_of_1],
+                                                  confidence=1.0)
+        self.assertEqual(0.5, intent.get('confidence'))
+        self.assertListEqual([one_of_1], tags)
+
+        intent, tags = \
+            self.optional_intent.validate_with_tags([optional],
+                                                    confidence=1.0)
+        self.assertEqual(0.5, intent.get('confidence'))
+        self.assertListEqual([optional], tags)
+
+    def test_weighted_scoring_from_fuzzy_matching(self):
+        required = TestTag('required', 'foo')
+        one_of_1 = TestTag('one_of_1', 'bar')
+        optional = TestTag('optional', 'bing')
+
+        intent, tags = \
+            self.require_intent.validate_with_tags([required],
+                                                   confidence=0.5)
+        self.assertEqual(0.5, intent.get('confidence'))
+        self.assertListEqual([required], tags)
+
+        intent, tags = \
+            self.one_of_intent.validate_with_tags([one_of_1],
+                                                  confidence=0.5)
+        self.assertEqual(0.5, intent.get('confidence'))
+        self.assertListEqual([one_of_1], tags)
+
+        intent, tags = \
+            self.optional_intent.validate_with_tags([optional],
+                                                    confidence=0.5)
+        self.assertEqual(0.5, intent.get('confidence'))
+        self.assertListEqual([optional], tags)
diff --git a/test/TrieTest.py b/test/TrieTest.py
index 61fa4c8..406a035 100644
--- a/test/TrieTest.py
+++ b/test/TrieTest.py
@@ -125,6 +125,70 @@ def test_edit_distance_no_confidence(self):
         results = list(trie.gather("of the big bang theory"))
         assert len(results) == 0
 
+    def test_remove(self):
+        trie = Trie(max_edit_distance=2)
+        trie.insert("1", "Number")
+        trie.insert("2", "Number")
+        trie.remove("2")
+
+        one_lookup = list(trie.gather("1"))
+        two_lookup = list(trie.gather("2"))
+        assert len(one_lookup) == 1  # One match found
+        assert len(two_lookup) == 0  # Zero matches since removed
+
+    def test_remove_multi_last(self):
+        trie = Trie(max_edit_distance=2)
+        trie.insert("Kermit", "Muppets")
+        trie.insert("Kermit", "Frogs")
+        kermit_lookup = list(trie.lookup("Kermit"))[0]
+        assert 'Frogs' in kermit_lookup['data']
+        assert 'Muppets' in kermit_lookup['data']
+
+        trie.remove("Kermit", "Frogs")
+
+        kermit_lookup = list(trie.gather("Kermit"))[0]
+        assert kermit_lookup['data'] == {"Muppets"}  # Right data remains
+
+    def test_remove_multi_first(self):
+        trie = Trie(max_edit_distance=2)
+        trie.insert("Kermit", "Muppets")
+        trie.insert("Kermit", "Frogs")
+        kermit_lookup = list(trie.lookup("Kermit"))[0]
+        assert 'Frogs' in kermit_lookup['data']
+        assert 'Muppets' in kermit_lookup['data']
+
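+        # Removing just the "Muppets" value must leave the "Frogs" value
+        # in place for the same key.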
+ trie.remove("Kermit", "Muppets") + + kermit_lookup = list(trie.lookup("Kermit"))[0] + assert kermit_lookup['data'] == {"Frogs"} # Right data remains + + def test_scan(self): + trie = Trie(max_edit_distance=2) + trie.insert("Kermit", "Muppets") + trie.insert("Gonzo", "Muppets") + trie.insert("Rowlf", "Muppets") + trie.insert("Gobo", "Fraggles") + + def match_func(data): + return data == "Muppets" + + results = trie.scan(match_func) + assert len(results) == 3 + muppet_names = [r[0] for r in results] + assert "Kermit" in muppet_names + assert "Gonzo" in muppet_names + assert "Rowlf" in muppet_names + + def test_is_prefix(self): + trie = Trie() + trie.insert("play", "PlayVerb") + trie.insert("the big bang theory", "Television Show") + trie.insert("the big", "Not a Thing") + trie.insert("barenaked ladies", "Radio Station") + + assert trie.root.is_prefix("the") + assert trie.root.is_prefix("play") + assert not trie.root.is_prefix("Kermit") def tearDown(self): pass