From 040c9f2c47b6317cc0910389ffd740fc91327430 Mon Sep 17 00:00:00 2001 From: king-menin Date: Sat, 19 Jan 2019 18:40:47 +0300 Subject: [PATCH] fix from_config --- modules/data/bert_data.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/data/bert_data.py b/modules/data/bert_data.py index b8b41ba..18af202 100644 --- a/modules/data/bert_data.py +++ b/modules/data/bert_data.py @@ -338,8 +338,9 @@ def __init__(self, train_path, valid_path, vocab_file, data_type, self.is_cls = True self.id2cls = sorted(cls2idx.keys(), key=lambda x: cls2idx[x]) + # TODO: write docs @classmethod - def from_config(cls, config): + def from_config(cls, config, for_train=True): if config["data_type"] == "bert_cased": do_lower_case = False fn = get_bert_data_loaders @@ -348,7 +349,7 @@ def from_config(cls, config): fn = get_bert_data_loaders else: raise NotImplementedError("No requested mode :(.") - if config["train_path"] and config["valid_path"]: + if config["train_path"] and config["valid_path"] and for_train: fn_res = fn(config["train_path"], config["valid_path"], config["vocab_file"], config["batch_size"], config["cuda"], config["is_cls"], do_lower_case, config["max_seq_len"], config["is_meta"], label2idx=config["label2idx"], cls2idx=config["cls2idx"])