diff --git a/modules/data/bert_data.py b/modules/data/bert_data.py index 347b2ec..1b0ed0d 100644 --- a/modules/data/bert_data.py +++ b/modules/data/bert_data.py @@ -6,6 +6,7 @@ from tqdm import tqdm from modules.utils import read_json, save_json import logging +import os class InputFeatures(object): @@ -448,12 +449,12 @@ def create(cls, meta2idx = None cls2idx = None - if idx2label is None and idx2label_path is not None: + if idx2label is None and os.path.exists(str(idx2label_path)): idx2label = read_json(idx2label_path) - if is_meta and idx2meta is None and idx2meta_path is not None: + if is_meta and idx2meta is None and os.path.exists(str(idx2meta_path)): idx2meta = read_json(idx2meta_path) meta2idx = {label: idx for idx, label in enumerate(idx2meta)} - if is_cls and idx2cls is None and idx2cls_path is not None: + if is_cls and idx2cls is None and os.path.exists(str(idx2cls_path)): idx2cls = read_json(idx2cls_path) cls2idx = {label: idx for idx, label in enumerate(idx2cls)} label2idx = {label: idx for idx, label in enumerate(idx2label)}