Skip to content

Commit

Permalink
Merge branch 'main' into rdkit_2d_features
Browse files Browse the repository at this point in the history
  • Loading branch information
hwpang authored Jun 5, 2024
2 parents 41820ba + 3352aee commit e57d674
Show file tree
Hide file tree
Showing 73 changed files with 271 additions and 269 deletions.
10 changes: 4 additions & 6 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,10 @@ jobs:
steps:
# clone the repo, run black, flake8, and isort on it
- uses: actions/checkout@v4
- run: python -m pip install black==23.* flake8
- run: python -m pip install black==23.* flake8 isort
- run: black --check .
- run: flake8 .
- run: isort --check .

test:
name: Execute Tests
Expand Down Expand Up @@ -94,11 +95,8 @@ jobs:
- name: Test notebooks
shell: bash -l {0}
run: |
pytest --nbmake examples/training.ipynb
pytest --nbmake examples/predicting.ipynb
pytest --nbmake examples/convert_v1_to_v2.ipynb
pytest --nbmake examples/training_regression_multicomponent.ipynb
pytest --nbmake examples/predicting_regression_multicomponent.ipynb
python -m pip install matplotlib
pytest --no-cov -v --nbmake examples/*.ipynb
conda-test:
name: Execute Tests with Conda-based Install
Expand Down
4 changes: 2 additions & 2 deletions chemprop/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import data, featurizers, models, nn, utils, conf, exceptions, schedulers
from . import data, exceptions, featurizers, models, nn, schedulers, utils

__all__ = ["data", "featurizers", "models", "nn", "utils", "conf", "exceptions", "schedulers"]
__all__ = ["data", "featurizers", "models", "nn", "utils", "exceptions", "schedulers"]

__version__ = "2.0.0"
4 changes: 2 additions & 2 deletions chemprop/cli/common.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from argparse import ArgumentError, ArgumentParser, Namespace
import logging
from argparse import ArgumentParser, Namespace, ArgumentError
from pathlib import Path

from chemprop.cli.utils import LookupAction
from chemprop.cli.utils.args import uppercase
from chemprop.featurizers import MoleculeFeaturizerRegistry, RxnMode, AtomFeatureMode
from chemprop.featurizers import AtomFeatureMode, MoleculeFeaturizerRegistry, RxnMode

logger = logging.getLogger(__name__)

Expand Down
2 changes: 1 addition & 1 deletion chemprop/cli/convert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from argparse import ArgumentError, ArgumentParser, Namespace
import sys
import logging
from pathlib import Path
import sys

from chemprop.cli.utils import Subcommand
from chemprop.utils.v1_to_v2 import convert_model_file_v1_to_v2
Expand Down
4 changes: 2 additions & 2 deletions chemprop/cli/fingerprint.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import sys
from argparse import ArgumentError, ArgumentParser, Namespace
import logging
from pathlib import Path
import sys

import numpy as np
import pandas as pd
Expand Down
82 changes: 48 additions & 34 deletions chemprop/cli/hpopt.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
import json
import logging
import sys
from argparse import ArgumentParser, Namespace
from copy import deepcopy
import json
import logging
from pathlib import Path
import sys

import torch
from lightning import pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping
import numpy as np
import torch

from chemprop.cli.common import add_common_args, process_common_args, validate_common_args
from chemprop.cli.train import (
Expand Down Expand Up @@ -40,21 +41,21 @@
prepare_trainer,
)
from ray.train.torch import TorchTrainer
from ray.tune.schedulers import ASHAScheduler
from ray.tune.schedulers import ASHAScheduler, FIFOScheduler

DEFAULT_SEARCH_SPACE = {
"activation": tune.choice(categories=list(Activation.keys())),
"aggregation": tune.choice(categories=list(AggregationRegistry.keys())),
"aggregation_norm": tune.quniform(lower=1, upper=200, q=1),
"batch_size": tune.loguniform(lower=16, upper=256, base=2),
"depth": tune.quniform(lower=2, upper=6, q=1),
"dropout": tune.choice([tune.choice([0.0]), tune.quniform(lower=0.05, upper=0.4, q=0.05)]),
"ffn_hidden_dim": tune.quniform(lower=300, upper=2400, q=100),
"ffn_num_layers": tune.quniform(lower=1, upper=3, q=1),
"final_lr_ratio": tune.loguniform(lower=1e-4, upper=1),
"message_hidden_dim": tune.quniform(lower=300, upper=2400, q=100),
"init_lr_ratio": tune.loguniform(lower=1e-4, upper=1),
"max_lr": tune.loguniform(lower=1e-6, upper=1e-2),
"batch_size": tune.choice([16, 32, 64, 128, 256]),
"depth": tune.qrandint(lower=2, upper=6, q=1),
"dropout": tune.choice([0.0] * 8 + list(np.arange(0.05, 0.45, 0.05))),
"ffn_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
"ffn_num_layers": tune.qrandint(lower=1, upper=3, q=1),
"final_lr_ratio": tune.loguniform(lower=1e-2, upper=1),
"message_hidden_dim": tune.qrandint(lower=300, upper=2400, q=100),
"init_lr_ratio": tune.loguniform(lower=1e-2, upper=1),
"max_lr": tune.loguniform(lower=1e-4, upper=1e-2),
"warmup_epochs": None,
}
except ImportError:
Expand All @@ -66,11 +67,11 @@
except ImportError:
NO_HYPEROPT = True

# NO_OPTUNA = False
# try:
# from ray.tune.search.optuna import OptunaSearch
# except ImportError:
# NO_OPTUNA = True
NO_OPTUNA = False
try:
from ray.tune.search.optuna import OptunaSearch
except ImportError:
NO_OPTUNA = True


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -143,11 +144,18 @@ def add_hpopt_args(parser: ArgumentParser) -> ArgumentParser:

raytune_args.add_argument(
"--raytune-search-algorithm",
choices=["random", "hyperopt"], # , "optuna"],
choices=["random", "hyperopt", "optuna"],
default="hyperopt",
help="Passed to Ray Tune TuneConfig to control search algorithm",
)

raytune_args.add_argument(
"--raytune-trial-scheduler",
choices=["FIFO", "AsyncHyperBand"],
default="FIFO",
help="Passed to Ray Tune TuneConfig to control trial scheduler",
)

raytune_args.add_argument(
"--raytune-num-workers",
type=int,
Expand Down Expand Up @@ -227,8 +235,8 @@ def process_hpopt_args(args: Namespace) -> Namespace:


def build_search_space(search_parameters: list[str], train_epochs: int) -> dict:
if "warmup_epochs" not in SEARCH_SPACE and "warmup_epochs" in search_parameters:
SEARCH_SPACE["warmup_epochs"] = tune.quniform(lower=1, upper=train_epochs // 2, q=1)
if "warmup_epochs" in search_parameters and SEARCH_SPACE.get("warmup_epochs", None) is None:
SEARCH_SPACE["warmup_epochs"] = tune.qrandint(lower=1, upper=train_epochs // 2, q=1)

return {param: SEARCH_SPACE[param] for param in search_parameters}

Expand Down Expand Up @@ -277,7 +285,7 @@ def train_model(config, args, train_dset, val_dset, logger, output_transform, in
devices=args.devices,
max_epochs=args.epochs,
gradient_clip_val=args.grad_clip,
strategy=RayDDPStrategy(find_unused_parameters=True),
strategy=RayDDPStrategy(),
callbacks=[RayTrainReportCallback(), early_stopping],
plugins=[RayLightningEnvironment()],
deterministic=args.pytorch_seed is not None,
Expand All @@ -289,11 +297,17 @@ def train_model(config, args, train_dset, val_dset, logger, output_transform, in
def tune_model(
args, train_dset, val_dset, logger, monitor_mode, output_transform, input_transforms
):
scheduler = ASHAScheduler(
max_t=args.epochs,
grace_period=min(args.raytune_grace_period, args.epochs),
reduction_factor=args.raytune_reduction_factor,
)
match args.raytune_trial_scheduler:
case "FIFO":
scheduler = FIFOScheduler()
case "AsyncHyperBand":
scheduler = ASHAScheduler(
max_t=args.epochs,
grace_period=min(args.raytune_grace_period, args.epochs),
reduction_factor=args.raytune_reduction_factor,
)
case _:
raise ValueError(f"Invalid trial scheduler! got: {args.raytune_trial_scheduler}.")

scaling_config = ScalingConfig(
num_workers=args.raytune_num_workers, use_gpu=args.raytune_use_gpu
Expand Down Expand Up @@ -331,13 +345,13 @@ def tune_model(
n_initial_points=args.hyperopt_n_initial_points,
random_state_seed=args.hyperopt_random_state_seed,
)
# case "optuna":
# if NO_OPTUNA:
# raise ImportError(
# "OptunaSearch requires optuna to be installed. Use 'pip -U install optuna' to install."
# )
case "optuna":
if NO_OPTUNA:
raise ImportError(
"OptunaSearch requires optuna to be installed. Use 'pip install -U optuna' to install."
)

# search_alg = OptunaSearch()
search_alg = OptunaSearch()

tune_config = tune.TuneConfig(
metric="val_loss",
Expand Down
12 changes: 6 additions & 6 deletions chemprop/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from configargparse import ArgumentParser
import logging
import sys
from pathlib import Path
import sys

from chemprop.cli.train import TrainSubcommand
from chemprop.cli.predict import PredictSubcommand
from configargparse import ArgumentParser

from chemprop.cli.conf import LOG_DIR, LOG_LEVELS, NOW
from chemprop.cli.convert import ConvertSubcommand
from chemprop.cli.fingerprint import FingerprintSubcommand
from chemprop.cli.hpopt import HpoptSubcommand

from chemprop.cli.predict import PredictSubcommand
from chemprop.cli.train import TrainSubcommand
from chemprop.cli.utils import pop_attr
from chemprop.cli.conf import LOG_DIR, LOG_LEVELS, NOW

logger = logging.getLogger(__name__)

Expand Down
12 changes: 5 additions & 7 deletions chemprop/cli/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,17 @@
import logging
from pathlib import Path
import sys
import pandas as pd

from lightning import pytorch as pl
import pandas as pd
import torch

from chemprop import data
from chemprop.cli.common import add_common_args, process_common_args, validate_common_args
from chemprop.cli.utils import Subcommand, build_data_from_files, make_dataset
from chemprop.models import load_model
from chemprop.nn.loss import LossFunctionRegistry
from chemprop.nn.predictors import MulticlassClassificationFFN
from chemprop.models import load_model

from chemprop.cli.utils import Subcommand, build_data_from_files, make_dataset
from chemprop.cli.common import add_common_args, process_common_args, validate_common_args


logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -190,7 +188,7 @@ def make_prediction_for_model(
no_header_row=args.no_header_row,
smiles_cols=args.smiles_columns,
rxn_cols=args.reaction_columns,
target_cols=None,
target_cols=[],
ignore_cols=None,
splits_col=None,
weight_col=None,
Expand Down
12 changes: 6 additions & 6 deletions chemprop/cli/train.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from copy import deepcopy
import json
import logging
import sys
from copy import deepcopy
from pathlib import Path
import sys

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from configargparse import ArgumentError, ArgumentParser, Namespace
from lightning import pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

from chemprop.cli.common import add_common_args, process_common_args, validate_common_args
from chemprop.cli.conf import NOW
Expand Down
6 changes: 3 additions & 3 deletions chemprop/cli/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from .args import bounded
from .actions import LookupAction
from .args import bounded
from .command import Subcommand
from .parsing import (
build_data_from_files,
get_column_names,
make_datapoints,
make_dataset,
get_column_names,
parse_indices,
)
from .utils import pop_attr, _pop_attr, _pop_attr_d, validate_loss_function
from .utils import _pop_attr, _pop_attr_d, pop_attr, validate_loss_function

__all__ = [
"bounded",
Expand Down
15 changes: 3 additions & 12 deletions chemprop/cli/utils/actions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from argparse import Action, ArgumentParser, Namespace
from typing import Any, Mapping, Sequence
from argparse import _StoreAction
from typing import Any, Mapping


def LookupAction(obj: Mapping[str, Any]):
class LookupAction_(Action):
class LookupAction_(_StoreAction):
def __init__(self, option_strings, dest, default=None, choices=None, **kwargs):
if default not in obj.keys() and default is not None:
raise ValueError(
Expand All @@ -16,13 +16,4 @@ def __init__(self, option_strings, dest, default=None, choices=None, **kwargs):

super().__init__(option_strings, dest, **kwargs)

def __call__(
self,
parser: ArgumentParser,
namespace: Namespace,
values: str | Sequence[Any] | None,
option_string: str | None = None,
):
setattr(namespace, self.dest, values)

return LookupAction_
2 changes: 1 addition & 1 deletion chemprop/cli/utils/command.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from argparse import ArgumentParser, _SubParsersAction, Namespace
from argparse import ArgumentParser, Namespace, _SubParsersAction


class Subcommand(ABC):
Expand Down
20 changes: 9 additions & 11 deletions chemprop/cli/utils/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

from chemprop.data.datapoints import MoleculeDatapoint, ReactionDatapoint
from chemprop.data.datasets import MoleculeDataset, ReactionDataset
from chemprop.featurizers.atom import get_multi_hot_atom_featurizer
from chemprop.featurizers.base import VectorFeaturizer
from chemprop.featurizers.molgraph import (
CondensedGraphOfReactionFeaturizer,
SimpleMoleculeMolGraphFeaturizer,
)
from chemprop.featurizers.atom import get_multi_hot_atom_featurizer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -50,11 +50,10 @@ def parse_csv(

if target_cols is None:
target_cols = list(
set(df.columns)
- set(input_cols)
- set(ignore_cols or [])
- set(splits_col or [])
- set(weight_col or [])
column
for column in df.columns
if column
not in set(input_cols + (ignore_cols or []) + (splits_col or []) + (weight_col or []))
)

Y = df[target_cols]
Expand Down Expand Up @@ -94,11 +93,10 @@ def get_column_names(

if target_cols is None:
target_cols = list(
set(df.columns)
- set(input_cols)
- set(ignore_cols or [])
- set(splits_col or [])
- set(weight_col or [])
column
for column in df.columns
if column
not in set(input_cols + (ignore_cols or []) + (splits_col or []) + (weight_col or []))
)

return input_cols + target_cols
Expand Down
Loading

0 comments on commit e57d674

Please sign in to comment.