Skip to content

Commit

Permalink
Convert DAGTensorGraph to DAGModel
Browse files Browse the repository at this point in the history
  • Loading branch information
mlgill committed Mar 25, 2018
1 parent ccd2f0f commit f20016c
Show file tree
Hide file tree
Showing 7 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion deepchem/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from deepchem.models.tensorgraph.robust_multitask import RobustMultitaskClassifier
from deepchem.models.tensorgraph.robust_multitask import RobustMultitaskRegressor
from deepchem.models.tensorgraph.progressive_multitask import ProgressiveMultitaskRegressor, ProgressiveMultitaskClassifier
from deepchem.models.tensorgraph.models.graph_models import WeaveModel, DTNNModel, DAGTensorGraph, GraphConvModel, MPNNModel
from deepchem.models.tensorgraph.models.graph_models import WeaveModel, DTNNModel, DAGModel, GraphConvModel, MPNNModel
from deepchem.models.tensorgraph.models.symmetry_function_regression import BPSymmetryFunctionRegression, ANIRegression

from deepchem.models.tensorgraph.models.seqtoseq import SeqToSeq
Expand Down
6 changes: 3 additions & 3 deletions deepchem/models/tensorgraph/models/graph_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def predict(self, dataset, transformers=[], outputs=None):
return undo_transforms(retval, transformers)


class DAGTensorGraph(TensorGraph):
class DAGModel(TensorGraph):

def __init__(self,
n_tasks,
Expand Down Expand Up @@ -390,7 +390,7 @@ def __init__(self,
self.n_graph_feat = n_graph_feat
self.n_outputs = n_outputs
self.mode = mode
super(DAGTensorGraph, self).__init__(**kwargs)
super(DAGModel, self).__init__(**kwargs)
self.build_graph()

def build_graph(self):
Expand Down Expand Up @@ -508,7 +508,7 @@ def default_generator(self,
yield feed_dict

def predict_on_generator(self, generator, transformers=[], outputs=None):
out = super(DAGTensorGraph, self).predict_on_generator(
out = super(DAGModel, self).predict_on_generator(
generator, transformers=[], outputs=outputs)
if outputs is None:
outputs = self.outputs
Expand Down
2 changes: 1 addition & 1 deletion deepchem/models/tests/test_overfit.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,7 @@ def test_tensorgraph_DAG_singletask_regression_overfit(self):
transformer = dc.trans.DAGTransformer(max_atoms=50)
dataset = transformer.transform(dataset)

model = dc.models.DAGTensorGraph(
model = dc.models.DAGModel(
n_tasks,
max_atoms=50,
n_atom_feat=n_feat,
Expand Down
4 changes: 2 additions & 2 deletions deepchem/molnet/run_benchmark_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def model_builder(model_dir_logreg):
test_dataset.reshard(reshard_size)
test_dataset = transformer.transform(test_dataset)

model = deepchem.models.DAGTensorGraph(
model = deepchem.models.DAGModel(
len(tasks),
max_atoms=max_atoms,
n_atom_feat=n_features,
Expand Down Expand Up @@ -558,7 +558,7 @@ def benchmark_regression(train_dataset,
test_dataset.reshard(reshard_size)
test_dataset = transformer.transform(test_dataset)

model = deepchem.models.DAGTensorGraph(
model = deepchem.models.DAGModel(
len(tasks),
max_atoms=max_atoms,
n_atom_feat=n_features,
Expand Down
2 changes: 1 addition & 1 deletion examples/delaney/delaney_tensorgraph_DAG.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
valid_dataset.reshard(reshard_size)
valid_dataset = transformer.transform(valid_dataset)

model = dc.models.DAGTensorGraph(
model = dc.models.DAGModel(
len(delaney_tasks),
max_atoms=max_atoms,
n_atom_feat=n_atom_feat,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
valid_dataset.reshard(reshard_size)
valid_dataset = transformer.transform(valid_dataset)

model = dc.models.DAGTensorGraph(
model = dc.models.DAGModel(
len(permeability_tasks),
max_atoms=max_atoms,
n_atom_feat=n_atom_feat,
Expand Down
2 changes: 1 addition & 1 deletion examples/tox21/tox21_tensorgraph_DAG.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
n_atom_feat = 75
batch_size = 64

model = dc.models.DAGTensorGraph(
model = dc.models.DAGModel(
len(tox21_tasks),
max_atoms=max_atoms,
n_atom_feat=n_atom_feat,
Expand Down

0 comments on commit f20016c

Please sign in to comment.