Skip to content


Improve and fix augment_sentence and expand_tokens (see amazon-science#4)
Browse files Browse the repository at this point in the history
giove91 committed Jun 5, 2021
1 parent c6110a9 commit 2394bc9
Showing 1 changed file with 22 additions and 21 deletions.
43 changes: 22 additions & 21 deletions
Original file line number Diff line number Diff line change
@@ -25,29 +25,29 @@ def get_episode_indices(episodes_string: str) -> List[int]:
return episode_indices

def expand_tokens(tokens: List[str], entity_tree: Dict[Tuple[int, int], List[Tuple[List[tuple], int, int]]],
root: tuple, begin_entity_token: str, sep_token: str, relation_sep_token: str,
end_entity_token: str) -> List[str]:
def expand_tokens(tokens: List[str], augmentations: List[Tuple[List[tuple], int, int]],
entity_tree: Dict[int, List[int]], root: int,
begin_entity_token: str, sep_token: str, relation_sep_token: str, end_entity_token: str) \
-> List[str]:
Recursively expand the tokens to obtain a sentence in augmented natural language.
Used in the augment_sentence function below (see the documentation there).

new_tokens = []
_, root_start, root_end = root
root_start, root_end = augmentations[root][1:] if root >= 0 else (0, len(tokens))
i = root_start # current index

for entity in entity_tree[root_start, root_end]:
tags, start, end = entity
for entity_index in entity_tree[root]:
tags, start, end = augmentations[entity_index]

# add tokens before this entity
new_tokens += tokens[i:start]

# expand this entity
new_tokens += expand_tokens(tokens, entity_tree, entity, begin_entity_token, sep_token, relation_sep_token,
new_tokens += expand_tokens(tokens, augmentations, entity_tree, entity_index,
begin_entity_token, sep_token, relation_sep_token, end_entity_token)

for tag in tags:
if tag[0]:
@@ -95,36 +95,37 @@ def augment_sentence(tokens: List[str], augmentations: List[Tuple[List[tuple], i
output augmented sentence:
[ Tolkien | person | born in = here ] was born [ here | location ]
# sort entities by start position
augmentations = list(sorted(augmentations, key=lambda z: z[1]))
# sort entities by start position, longer entities first
augmentations = list(sorted(augmentations, key=lambda z: (z[1], -z[2])))

# check that the entities have a tree structure (if two entities overlap, then one is contained in
# the other), and build the entity tree
root = (None, -1, len(tokens)) # this node represents the entire sentence
entity_tree = {root[1:]: []} # list of children of each node
root = -1 # each node is represented by its position in the list of augmentations, except that the root is -1
entity_tree = {root: []} # list of children of each node
current_stack = [root] # where we are in the tree

for x in sorted(augmentations, key=lambda z: (z[1], -z[2])):
for j, x in enumerate(augmentations):
tags, start, end = x
if any(other_start < start < other_end < end for _, other_start, other_end in current_stack):
if any(augmentations[k][1] < start < augmentations[k][2] < end for k in current_stack):
# tree structure is not satisfied!
logging.warning(f'Tree structure is not satisfied! {current_stack}')
logging.warning(f'Tree structure is not satisfied! Dropping annotation {x}')

while not (current_stack[-1][1] <= start <= end <= current_stack[-1][2]):
while current_stack[-1] >= 0 and \
not (augmentations[current_stack[-1]][1] <= start <= end <= augmentations[current_stack[-1]][2]):

# add as a child of its father

# update stack

# create empty list of children for this new node
entity_tree[start, end] = []
entity_tree[j] = []

return ' '.join(expand_tokens(
tokens, entity_tree, root, begin_entity_token, sep_token, relation_sep_token, end_entity_token
tokens, augmentations, entity_tree, root, begin_entity_token, sep_token, relation_sep_token, end_entity_token

0 comments on commit 2394bc9

Please sign in to comment.