From 202b7f7f73fc00a4b52fe033b915b8cf0d51eb69 Mon Sep 17 00:00:00 2001 From: Katsuhiko Ishiguro Date: Wed, 29 Jul 2020 11:34:25 +0900 Subject: [PATCH] Converters fixed --- .../dataset/converters/__init__.py | 18 ----------- .../dataset/preprocessors/__init__.py | 18 ----------- examples/molnet_wle/predict_molnet_wle.py | 3 +- examples/molnet_wle/train_molnet_wle.py | 32 +++++++++++++++++++ 4 files changed, 34 insertions(+), 37 deletions(-) diff --git a/chainer_chemistry/dataset/converters/__init__.py b/chainer_chemistry/dataset/converters/__init__.py index cf0d5ebc..eba6e92c 100644 --- a/chainer_chemistry/dataset/converters/__init__.py +++ b/chainer_chemistry/dataset/converters/__init__.py @@ -6,34 +6,16 @@ 'ecfp': concat_mols, 'nfp': concat_mols, 'nfp_gwm': concat_mols, - 'nfp_wle': concat_mols, - 'nfp_cwle': concat_mols, - 'nfp_gwle': concat_mols, 'ggnn': concat_mols, 'ggnn_gwm': concat_mols, - 'ggnn_wle': concat_mols, - 'ggnn_cwle': concat_mols, - 'ggnn_gwle': concat_mols, 'gin': concat_mols, 'gin_gwm': concat_mols, - 'gin_wle': concat_mols, - 'gin_cwle': concat_mols, - 'gin_gwle': concat_mols, 'schnet': concat_mols, 'weavenet': concat_mols, 'relgcn': concat_mols, - 'relgcn_wle': concat_mols, - 'relgcn_cwle': concat_mols, - 'relgcn_gwle': concat_mols, 'rsgcn': concat_mols, 'rsgcn_gwm': concat_mols, - 'rsgcn_wle': concat_mols, - 'rsgcn_cwle': concat_mols, - 'rsgcn_gwle': concat_mols, 'relgat': concat_mols, - 'relgat_wle': concat_mols, - 'relgat_cwle': concat_mols, - 'relgat_gwle': concat_mols, 'gnnfilm': concat_mols, 'megnet': megnet_converter, 'cgcnn': cgcnn_converter diff --git a/chainer_chemistry/dataset/preprocessors/__init__.py b/chainer_chemistry/dataset/preprocessors/__init__.py index 0dab2522..b035e14d 100644 --- a/chainer_chemistry/dataset/preprocessors/__init__.py +++ b/chainer_chemistry/dataset/preprocessors/__init__.py @@ -28,35 +28,17 @@ 'ecfp': ECFPPreprocessor, 'nfp': NFPPreprocessor, 'nfp_gwm': NFPGWMPreprocessor, - 'nfp_wle': NFPPreprocessor, - 'nfp_cwle': NFPPreprocessor, - 'nfp_gwle': NFPPreprocessor, 'ggnn': GGNNPreprocessor, 'ggnn_gwm': GGNNGWMPreprocessor, - 'ggnn_wle': GGNNPreprocessor, - 'ggnn_cwle': GGNNPreprocessor, - 'ggnn_gwle': GGNNPreprocessor, 'gin': GINPreprocessor, 'gin_gwm': GINGWMPreprocessor, - 'gin_wle': GINPreprocessor, - 'gin_cwle': GINPreprocessor, - 'gin_gwle': GINPreprocessor, 'schnet': SchNetPreprocessor, 'weavenet': WeaveNetPreprocessor, 'relgcn': RelGCNPreprocessor, - 'relgcn_wle': RelGCNPreprocessor, - 'relgcn_cwle': RelGCNPreprocessor, - 'relgcn_gwle': RelGCNPreprocessor, 'rsgcn': RSGCNPreprocessor, 'rsgcn_gwm': RSGCNGWMPreprocessor, - 'rsgcn_wle': RSGCNPreprocessor, - 'rsgcn_cwle': RSGCNPreprocessor, - 'rsgcn_gwle': RSGCNPreprocessor, 'relgat': RelGATPreprocessor, 'relgcn_sparse': RelGCNSparsePreprocessor, - 'relgat_wle': RelGATPreprocessor, - 'relgat_cwle': RelGATPreprocessor, - 'relgat_gwle': RelGATPreprocessor, 'gin_sparse': GINSparsePreprocessor, 'gnnfilm': GNNFiLMPreprocessor, 'megnet': MEGNetPreprocessor, diff --git a/examples/molnet_wle/predict_molnet_wle.py b/examples/molnet_wle/predict_molnet_wle.py index 9ed2ddc0..94e50102 100644 --- a/examples/molnet_wle/predict_molnet_wle.py +++ b/examples/molnet_wle/predict_molnet_wle.py @@ -26,10 +26,11 @@ from chainer import functions as F from chainer_chemistry.links.scaler.standard_scaler import StandardScaler # NOQA from chainer_chemistry.models.prediction.graph_conv_predictor import GraphConvPredictor # NOQA +from train_molnet import dict_for_wles from train_molnet import dataset_part_filename from train_molnet import download_entire_dataset - +dict_for_wles() def parse_arguments(): # Lists of supported preprocessing methods/models. diff --git a/examples/molnet_wle/train_molnet_wle.py b/examples/molnet_wle/train_molnet_wle.py index 1d1e7db2..86752d0a 100644 --- a/examples/molnet_wle/train_molnet_wle.py +++ b/examples/molnet_wle/train_molnet_wle.py @@ -32,6 +32,38 @@ from chainer_chemistry.models.cwle.cwle_graph_conv_model import MAX_WLE_NUM +def dict_for_wles(): + wle_keys = ['nfp_wle', 'ggnn_wle', 'relgat_wle', 'relgcn_wle', 'rsgcn_wle', 'gin_wle', + 'nfp_cwle', 'ggnn_cwle', 'relgat_cwle', 'relgcn_cwle', 'rsgcn_cwle', 'gin_cwle', + 'nfp_gwle', 'ggnn_gwle', 'relgat_gwle', 'relgcn_gwle', 'rsgcn_gwle', 'gin_gwle'] + + from chainer_chemistry.dataset.converters.concat_mols import concat_mols + from chainer_chemistry.dataset.preprocessors.nfp_preprocessor import NFPPreprocessor + from chainer_chemistry.dataset.preprocessors.ggnn_preprocessor import GGNNPreprocessor + from chainer_chemistry.dataset.preprocessors.gin_preprocessor import GINPreprocessor + from chainer_chemistry.dataset.preprocessors.relgat_preprocessor import RelGATPreprocessor + from chainer_chemistry.dataset.preprocessors.relgcn_preprocessor import RelGCNPreprocessor + from chainer_chemistry.dataset.preprocessors.rsgcn_preprocessor import RSGCNPreprocessor + + for key in wle_keys: + converter_method_dict[key] = concat_mols + + if key.startswith('nfp'): + preprocess_method_dict[key] = NFPPreprocessor + elif key.startswith('ggnn'): + preprocess_method_dict[key] = GGNNPreprocessor + elif key.startswith('gin'): + preprocess_method_dict[key] = GINPreprocessor + elif key.startswith('relgcn'): + preprocess_method_dict[key] = RelGCNPreprocessor + elif key.startswith('rsgcn_wle'): + preprocess_method_dict[key] = RSGCNPreprocessor + elif key.startswith('relgat'): + preprocess_method_dict[key] = RelGATPreprocessor + else: + assert key in wle_keys # should be die +dict_for_wles() + def parse_arguments(): # Lists of supported preprocessing methods/models and datasets. method_list = ['nfp', 'ggnn', 'schnet', 'weavenet', 'rsgcn', 'relgcn',