From f553d43e5bd0b5617a002f1ab7861a158d6e2e71 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gustavo=20Gon=C3=A7alves?= <11768325+gsgoncalves@users.noreply.github.com>
Date: Sun, 10 Jul 2022 14:54:33 +0100
Subject: [PATCH] Updated script to run with the SpacyNER tagger and REL linker (#1226)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Updated script to run with the SpacyNER tagger and REL linker. It is RAM-intensive.

* Shortened left/right context creation into a one-liner. Corrected indentation from tabs to spaces.

* Added a safeguard to keep documents without entities in the final collection file. Fixed entity positions so that they refer to the mentions' positions in the original text.

Co-authored-by: Gustavo Gonçalves
---
 docs/working-with-entity-linking.md |  17 ++++
 scripts/entity_linking.py           | 140 +++++++++++++++++++---
 2 files changed, 115 insertions(+), 42 deletions(-)

diff --git a/docs/working-with-entity-linking.md b/docs/working-with-entity-linking.md
index 4084df7c2..83edf0248 100644
--- a/docs/working-with-entity-linking.md
+++ b/docs/working-with-entity-linking.md
@@ -64,5 +64,22 @@ python entity_linking.py --input_path [input_jsonl_file] --rel_base_url [base_ur
 --spacy_model [en_core_web_sm, en_core_web_lg, etc.] --output_path [output_jsonl_file]
 ```
 
+An extended example, assuming you're running the script from the `scripts` directory:
+```bash
+REL_DATA_PATH=/home/$USER/REL/data
+INPUT_JSONL_FILE=../collections/msmarco-passage/collection_jsonl/docs00.json
+mkdir ../collections/msmarco-passage/collection_jsonl_with_entities/
+OUTPUT_JSONL_FILE=../collections/msmarco-passage/collection_jsonl_with_entities/docs00.json
+BASE_URL=$REL_DATA_PATH
+ED_MODEL=$REL_DATA_PATH/ed-wiki-2019/model
+WIKI_VERSION=wiki_2019
+WIKIMAPPER_INDEX=$REL_DATA_PATH/index_enwiki-20190420.db
+
+python entity_linking.py --input_path $INPUT_JSONL_FILE \
+--rel_base_url $BASE_URL --rel_ed_model_path $ED_MODEL \
+--rel_wiki_version $WIKI_VERSION --wikimapper_index $WIKIMAPPER_INDEX \
+--spacy_model en_core_web_sm --output_path $OUTPUT_JSONL_FILE
+```
+
 It should take about 5 to 10 minutes to run entity linking on 5,000 MS MARCO passages on Compute Canada. See [this](https://github.com/castorini/onboarding/blob/master/docs/cc-guide.md#compute-canada) for instructions about running scripts on Compute Canada.
diff --git a/scripts/entity_linking.py b/scripts/entity_linking.py
index 7bbb49805..38e56b3b3 100644
--- a/scripts/entity_linking.py
+++ b/scripts/entity_linking.py
@@ -17,34 +17,87 @@
 import argparse
 import jsonlines
 import spacy
-from REL.REL.mention_detection import MentionDetection
-from REL.REL.utils import process_results
-from REL.REL.entity_disambiguation import EntityDisambiguation
-from REL.REL.ner import NERBase, Span
+import sys
+from REL.mention_detection import MentionDetectionBase
+from REL.utils import process_results, split_in_words
+from REL.entity_disambiguation import EntityDisambiguation
+from REL.ner import Span
 from wikimapper import WikiMapper
-
+from typing import Dict, List, Tuple
+from tqdm import tqdm
 
 # Spacy Mention Detection class which overrides the NERBase class in the REL entity linking process
-class NERSpacy(NERBase):
-    def __init__(self):
+class NERSpacyMD(MentionDetectionBase):
+    def __init__(self, base_url:str, wiki_version:str, spacy_model:str):
+        super().__init__(base_url, wiki_version)
         # we only want to link entities of specific types
         self.ner_labels = ['PERSON', 'NORP', 'FAC', 'ORG', 'GPE', 'LOC', 'PRODUCT', 'EVENT', 'WORK_OF_ART',
                            'LAW', 'LANGUAGE', 'DATE', 'TIME', 'MONEY', 'QUANTITY']
+        self.spacy_model = spacy_model
+        spacy.prefer_gpu()
+        self.tagger = spacy.load(spacy_model)
 
     # mandatory function which overrides NERBase.predict()
-    def predict(self, doc):
-        mentions = []
+    def predict(self, doc: spacy.tokens.Doc) -> List[Span]:
+        spans = []
         for ent in doc.ents:
             if ent.label_ in self.ner_labels:
-                mentions.append(Span(ent.text, ent.start_char, ent.end_char, 0, ent.label_))
-        return mentions
+                spans.append(Span(ent.text, ent.start_char, ent.end_char, 0, ent.label_))
+        return spans
+
+    """
+    Responsible for finding mentions given a set of documents in a batch-wise manner. More specifically,
+    it returns the mention, its left/right context and a set of candidates.
+    :return: Dictionary with mentions per document.
+    """
+
+    def find_mentions(self, dataset: Dict[str, str]) -> Tuple[Dict[str, List[Dict]], int]:
+        results = {}
+        total_ment = 0
+        for i, doc in tqdm(enumerate(dataset), desc='Finding mentions', total=len(dataset)):
+            result_doc = []
+            doc_text = dataset[doc]
+            spacy_doc = self.tagger(doc_text)
+            spans = self.predict(spacy_doc)
+            for entity in spans:
+                text, start_pos, end_pos, conf, tag = (
+                    entity.text,
+                    entity.start_pos,
+                    entity.end_pos,
+                    entity.score,
+                    entity.tag,
+                )
+                m = self.preprocess_mention(text)
+                cands = self.get_candidates(m)
+                if len(cands) == 0:
+                    continue
+                total_ment += 1
+                # Re-create ngram as 'text' is at times changed by Flair (e.g. double spaces are removed).
+                ngram = doc_text[start_pos:end_pos]
+                left_ctxt = " ".join(split_in_words(doc_text[:start_pos])[-100:])
+                right_ctxt = " ".join(split_in_words(doc_text[end_pos:])[:100])
+                res = {
+                    "mention": m,
+                    "context": (left_ctxt, right_ctxt),
+                    "candidates": cands,
+                    "gold": ["NONE"],
+                    "pos": start_pos,
+                    "sent_idx": 0,
+                    "ngram": ngram,
+                    "end_pos": end_pos,
+                    "sentence": doc_text,
+                    "conf_md": conf,
+                    "tag": tag,
+                }
+                result_doc.append(res)
+            results[doc] = result_doc
+        return results, total_ment
 
 
 # run REL entity linking on processed doc
-def rel_entity_linking(spacy_docs, rel_base_url, rel_wiki_version, rel_ed_model_path):
-    mention_detection = MentionDetection(rel_base_url, rel_wiki_version)
-    tagger_spacy = NERSpacy()
-    mentions_dataset, _ = mention_detection.find_mentions(spacy_docs, tagger_spacy)
+def rel_entity_linking(docs: Dict[str,str], spacy_model:str, rel_base_url:str, rel_wiki_version:str, rel_ed_model_path:str) -> Dict[str, List[Tuple]]:
+    mention_detection = NERSpacyMD(rel_base_url, rel_wiki_version, spacy_model)
+    mentions_dataset, _ = mention_detection.find_mentions(docs)
     config = {
         'mode': 'eval',
         'model_path': rel_ed_model_path,
@@ -52,40 +105,42 @@ def rel_entity_linking(spacy_docs, rel_base_url, rel_wiki_version, rel_ed_model_
     ed_model = EntityDisambiguation(rel_base_url, rel_wiki_version, config)
     predictions, _ = ed_model.predict(mentions_dataset)
-    linked_entities = process_results(mentions_dataset, predictions, spacy_docs)
+    linked_entities = process_results(mentions_dataset, predictions, docs)
 
     return linked_entities
 
 
-# apply spaCy nlp processing pipeline on each doc
-def apply_spacy_pipeline(input_path, spacy_model):
-    nlp = spacy.load(spacy_model)
-    spacy_docs = {}
+# read input pyserini json docs into a dictionary
+def read_docs(input_path: str) -> Dict[str, str]:
+    docs = {}
     with jsonlines.open(input_path) as reader:
-        for obj in reader:
-            spacy_docs[obj['id']] = nlp(obj['contents'])
-    return spacy_docs
+        for obj in tqdm(reader, desc='Reading docs'):
+            docs[obj['id']] = obj['contents']
+    return docs
 
 
 # enrich REL entity linking results with entities' wikidata ids, and write final results as json objects
-def enrich_el_results(rel_linked_entities, spacy_docs, wikimapper_index):
+# rel_linked_entities: entity tuples composed of start_pos:int, mention_length:int, ent_text:str, ent_wikipedia_id:str, conf_score:float, ner_score:int, ent_type:str
+def enrich_el_results(rel_linked_entities: Dict[str, List[Tuple]], docs: Dict[str, str], wikimapper_index:str) -> List[Dict]:
     wikimapper = WikiMapper(wikimapper_index)
     linked_entities_json = []
-    for docid, ents in rel_linked_entities.items():
-        linked_entities_info = []
-        for start_pos, end_pos, ent_text, ent_wikipedia_id, ent_type in ents:
-            # find entities' wikidata ids using their REL results (i.e. linked wikipedia ids)
-            ent_wikipedia_id = ent_wikipedia_id.replace('&amp;', '&')
-            ent_wikidata_id = wikimapper.title_to_id(ent_wikipedia_id)
-
-            # write results as json objects
-            linked_entities_info.append({'start_pos': start_pos, 'end_pos': end_pos, 'ent_text': ent_text,
-                                         'wikipedia_id': ent_wikipedia_id, 'wikidata_id': ent_wikidata_id,
-                                         'ent_type': ent_type})
-        linked_entities_json.append({'id': docid, 'contents': spacy_docs[docid].text,
-                                     'entities': linked_entities_info})
+    for docid, doc_text in tqdm(docs.items(), desc='Enriching EL results', total=len(rel_linked_entities)):
+        if docid not in rel_linked_entities:
+            linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': []})
+        else:
+            linked_entities_info = []
+            ents = rel_linked_entities[docid]
+            for start_pos, mention_length, ent_text, ent_wikipedia_id, conf_score, ner_score, ent_type in ents:
+                # find entities' wikidata ids using their REL results (i.e. linked wikipedia ids)
+                ent_wikipedia_id = ent_wikipedia_id.replace('&amp;', '&')
+                ent_wikidata_id = wikimapper.title_to_id(ent_wikipedia_id)
+
+                # write results as json objects
+                linked_entities_info.append({'start_pos': start_pos, 'end_pos': start_pos + mention_length, 'ent_text': ent_text,
+                                             'wikipedia_id': ent_wikipedia_id, 'wikidata_id': ent_wikidata_id,
+                                             'ent_type': ent_type})
+            linked_entities_json.append({'id': docid, 'contents': doc_text, 'entities': linked_entities_info})
     return linked_entities_json
 
-
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('-p', '--input_path', type=str, help='path to input texts')
@@ -97,13 +152,14 @@ def main():
     parser.add_argument('-o', '--output_path', type=str, help='path to output json file')
     args = parser.parse_args()
 
-    spacy_docs = apply_spacy_pipeline(args.input_path, args.spacy_model)
-    rel_linked_entities = rel_entity_linking(spacy_docs, args.rel_base_url, args.rel_wiki_version,
+    docs = read_docs(args.input_path)
+    rel_linked_entities = rel_entity_linking(docs, args.spacy_model, args.rel_base_url, args.rel_wiki_version,
                                              args.rel_ed_model_path)
-    linked_entities_json = enrich_el_results(rel_linked_entities, spacy_docs, args.wikimapper_index)
+    linked_entities_json = enrich_el_results(rel_linked_entities, docs, args.wikimapper_index)
 
     with jsonlines.open(args.output_path, mode='w') as writer:
         writer.write_all(linked_entities_json)
 
 if __name__ == '__main__':
-    main()
\ No newline at end of file
+    main()
+    sys.exit(0)
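After the patch is applied, a quick way to sanity-check the enriched collection is to read a few records back. The snippet below is a minimal sketch and not part of the patch: the output path is the one assumed in the docs example above, and it relies only on the fields that `enrich_el_results` writes (`id`, `contents`, and per-entity `start_pos`, `end_pos`, `ent_text`, `wikipedia_id`, `wikidata_id`, `ent_type`).

```python
import jsonlines

# assumed path from the docs example above; adjust to wherever --output_path pointed
output_path = '../collections/msmarco-passage/collection_jsonl_with_entities/docs00.json'

with jsonlines.open(output_path) as reader:
    for doc in reader:
        # each record keeps the original passage text plus its linked entities (possibly empty)
        for ent in doc['entities']:
            surface = doc['contents'][ent['start_pos']:ent['end_pos']]
            print(doc['id'], surface, ent['wikipedia_id'], ent['wikidata_id'], ent['ent_type'])
```

Documents with no detected entities are kept with an empty `entities` list, which is the safeguard mentioned in the last commit bullet.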