Merge pull request #16 from smsharma/experiments
Update default hyperparams and rescaled losses for galaxies benchmark
smsharma authored Dec 5, 2024
2 parents 883dd8d + 5ae5b9e commit c373b3a
Showing 2 changed files with 30 additions and 20 deletions.
48 changes: 29 additions & 19 deletions benchmarks/galaxies/train_cosmology.py
@@ -40,8 +40,10 @@
 from models.segnn import SEGNN
 from models.nequip import NequIP
 from models.diffpool import DiffPool
+from models.pointnet import PointNet
 
-from benchmarks.galaxies.dataset import get_halo_dataset
+from benchmarks.galaxies.dataset import get_halo_dataset, STD_PARAMS_DICT
 
 MLP_PARAMS = {
     "feature_sizes": [128, 128, 128, 2],
@@ -50,7 +52,7 @@
 GNN_PARAMS = {
     "d_hidden": 128,
     "n_layers": 3,
-    "message_passing_steps": 3,
+    "message_passing_steps": 8,
     "message_passing_agg": "mean",
     "activation": "gelu",
     "norm": "none",
@@ -63,7 +65,7 @@
 }
 
 EGNN_PARAMS = {
-    "message_passing_steps": 3,
+    "message_passing_steps": 4,
     "d_hidden": 128,
     "n_layers": 3,
     "activation": "gelu",
@@ -86,7 +88,7 @@
     "combine_hierarchies_method": "mean",
     "use_edge_features": True,
     "task": "graph",
-    "mlp_readout_widths": [8, 2, 2],
+    "mlp_readout_widths": [8, 2],
 }
 
 POINTNET_PARAMS = {
@@ -97,19 +99,20 @@
     "combine_hierarchies_method": "mean",
     "use_edge_features": True,
     "task": "graph",
-    "mlp_readout_widths": [8, 2, 2],
+    "mlp_readout_widths": [4, 2, 2],
     "n_outputs": 2,
     "k": 10,
 }
 
 SEGNN_PARAMS = {
     "d_hidden": 128,
     "n_layers": 3,
-    "message_passing_steps": 3,
+    "message_passing_steps": 8,
     "message_passing_agg": "mean",
     "scalar_activation": "gelu",
     "gate_activation": "sigmoid",
     "task": "graph",
+    "n_outputs": 2,
     "output_irreps": e3nn.Irreps("1x0e"),
     "readout_agg": "mean",
     "mlp_readout_widths": (4, 2, 2),
     "l_max_hidden": 2,
@@ -121,8 +124,7 @@
     "d_hidden": 128,
     "l_max":1,
     "sphharm_norm": 'integral',
-    "irreps_out": e3nn.Irreps("1x0e"),
-    "message_passing_steps": 3,
+    "message_passing_steps": 4,
     "n_layers": 3,
     "message_passing_agg": "mean",
     "readout_agg": "mean",
@@ -211,7 +213,6 @@ def __call__(self, x):
 
 
 def loss_mse(pred_batch, cosmo_batch):
-    # return np.mean((pred_batch - cosmo_batch) ** 2)
     return np.mean((pred_batch - cosmo_batch) ** 2, axis=0)


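With axis=0, loss_mse now returns one MSE per target parameter rather than a single scalar, which is what the per-parameter rescaling further down relies on. A minimal sketch of the shape difference, with illustrative values (plain NumPy shown; jax.numpy behaves the same here):

    import numpy as np

    # Illustrative batch: 4 predictions for 2 targets (e.g. Omega_m, sigma_8)
    pred_batch = np.array([[0.31, 0.80], [0.29, 0.82], [0.32, 0.79], [0.30, 0.81]])
    cosmo_batch = np.tile([0.30, 0.81], (4, 1))

    per_param = np.mean((pred_batch - cosmo_batch) ** 2, axis=0)  # shape (2,), one MSE per parameter
    scalar = np.mean((pred_batch - cosmo_batch) ** 2)             # shape (), averaged over everything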
@@ -312,7 +313,7 @@ def run_expt(
     use_tpcf="none",
     n_steps=1000,
     batch_size=32,
-    n_train=1248,
+    n_train=2048,
     n_val=512,
     n_test=512,
     learning_rate=3e-4,
@@ -531,10 +532,14 @@ def run_expt(
         # if early_stop.should_stop:
         #     print(f'Met early stopping criteria, breaking at epoch {step}')
         #     break
 
+    # Rescale losses by multiplying by variances, since we normalize during training
+    loss_rescaling = np.array([STD_PARAMS_DICT[param]**2 for param in target])
+
     print(
         "Training done.\n"
         f"Final checkpoint test loss {test_loss_ckp}.\n"
+        f"Final rescaled test loss: {test_loss_ckp * loss_rescaling}.\n"
     )
 
     if plotting:
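Because the targets are standardized during training, multiplying each per-parameter MSE by that parameter's variance converts it back to physical units. A worked sketch with made-up standard deviations (the real values live in benchmarks.galaxies.dataset.STD_PARAMS_DICT):

    import numpy as np

    STD_PARAMS_DICT = {"Omega_m": 0.05, "sigma_8": 0.06}  # hypothetical values
    target = ["Omega_m", "sigma_8"]
    test_loss_ckp = np.array([0.12, 0.30])  # per-parameter MSE in normalized units

    loss_rescaling = np.array([STD_PARAMS_DICT[param] ** 2 for param in target])
    print(test_loss_ckp * loss_rescaling)  # [0.0003  0.00108], MSE in physical units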
@@ -586,12 +591,16 @@ def main(model, feats, lr, decay, steps, batch_size, n_train, use_tpcf, k, data_
         params = MLP_PARAMS
     elif model == "GNN":
         params = GNN_PARAMS
+        n_radial_basis = 64
     elif model == "EGNN":
         params = EGNN_PARAMS
+        n_radial_basis = 32
     elif model == "SEGNN":
         params = SEGNN_PARAMS
+        n_radial_basis = 32
     elif model == "NequIP":
         params = NEQUIP_PARAMS
+        n_radial_basis = 64
     elif model == "DiffPool":
         DIFFPOOL_PARAMS["gnn_kwargs"] = {
             "d_hidden": 64,
@@ -602,14 +611,14 @@ def main(model, feats, lr, decay, steps, batch_size, n_train, use_tpcf, k, data_
         DIFFPOOL_PARAMS["d_hidden"] = DIFFPOOL_PARAMS["gnn_kwargs"]["d_hidden"]
         params = DIFFPOOL_PARAMS
     elif model == "PointNet":
-        POINTNET_PARAMS["gnn_kwargs"] = {
-            "d_hidden": 64,
-            "n_layers": 4,
-            "message_passing_steps": 3,
-            "task": "node",
-        }
+        GNN_PARAMS["n_outputs"] = GNN_PARAMS["d_hidden"]
+        GNN_PARAMS["message_passing_steps"] = 3
+        GNN_PARAMS["n_layers"] = 4
+        GNN_PARAMS["task"] = "node"
+        POINTNET_PARAMS["gnn_kwargs"] = GNN_PARAMS
         POINTNET_PARAMS["d_hidden"] = POINTNET_PARAMS["gnn_kwargs"]["d_hidden"]
         params = POINTNET_PARAMS
+        n_radial_basis = 32
     else:
         raise NotImplementedError

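Note that POINTNET_PARAMS["gnn_kwargs"] = GNN_PARAMS stores a reference to the module-level dict, so the in-place edits above also change GNN_PARAMS itself for the rest of the process. A defensive variant (a sketch of an alternative, not what this commit does) would copy first:

    gnn_kwargs = dict(GNN_PARAMS)  # shallow copy; module-level GNN_PARAMS stays untouched
    gnn_kwargs.update(
        n_outputs=GNN_PARAMS["d_hidden"],
        message_passing_steps=3,
        n_layers=4,
        task="node",
    )
    POINTNET_PARAMS["gnn_kwargs"] = gnn_kwargs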
@@ -623,7 +632,8 @@ def main(model, feats, lr, decay, steps, batch_size, n_train, use_tpcf, k, data_
         n_steps=steps,
         batch_size=batch_size,
         n_train=n_train,
-        use_tpcf=use_tpcf
+        use_tpcf=use_tpcf,
+        n_radial_basis=n_radial_basis
     )
 
 if __name__ == "__main__":
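The new n_radial_basis argument sets how many radial basis functions each model uses to embed edge distances (64 for GNN and NequIP, 32 for the rest). A generic Gaussian expansion looks roughly like the sketch below; the exact featurization inside these models may differ:

    import jax.numpy as jnp

    def radial_basis(d, n_basis, r_max=1.0):
        # Expand scalar distances into n_basis Gaussian bumps with centers
        # spaced uniformly on [0, r_max].
        centers = jnp.linspace(0.0, r_max, n_basis)
        width = r_max / n_basis
        return jnp.exp(-((d[..., None] - centers) ** 2) / (2 * width**2))

    edge_lengths = jnp.array([0.1, 0.4, 0.9])
    print(radial_basis(edge_lengths, n_basis=32).shape)  # (3, 32)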
@@ -636,7 +646,7 @@ def main(model, feats, lr, decay, steps, batch_size, n_train, use_tpcf, k, data_
     parser.add_argument("--decay", type=float, help="Weight decay", default=1e-5)
     parser.add_argument("--steps", type=int, help="Number of steps", default=5000)
     parser.add_argument("--batch_size", type=int, help="Batch size", default=32)
-    parser.add_argument("--n_train", type=int, help="Number of training samples", default=1248)
+    parser.add_argument("--n_train", type=int, help="Number of training samples", default=2048)
     parser.add_argument(
         "--use_tpcf",
         type=str,
2 changes: 1 addition & 1 deletion models/pointnet.py
@@ -72,7 +72,7 @@ class PointNet(nn.Module):
     def __call__(self, x, return_assignments=True):
         # If graph prediction task, collect pooled embeddings at each hierarchy level
         if self.task == "graph":
-            x_pool = jnp.zeros((self.n_downsamples, self.gnn_kwargs['d_output']))
+            x_pool = jnp.zeros((self.n_downsamples, self.gnn_kwargs['n_outputs']))
 
         assignments = []
         node_pos = x.nodes[..., :3]
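This keeps pointnet.py consistent with the new configuration path: gnn_kwargs is now the shared GNN_PARAMS dict, which stores the per-node output width under "n_outputs" (set to d_hidden in train_cosmology.py), so indexing the old 'd_output' key would raise a KeyError. Illustrative shapes under those defaults:

    import jax.numpy as jnp

    gnn_kwargs = {"n_outputs": 128}  # from GNN_PARAMS["n_outputs"] = GNN_PARAMS["d_hidden"]
    n_downsamples = 2                # hypothetical hierarchy depth

    x_pool = jnp.zeros((n_downsamples, gnn_kwargs["n_outputs"]))  # shape (2, 128)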
