Skip to content

Commit

Permalink
more changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ander-db committed Oct 17, 2024
1 parent 8224656 commit 3852c54
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 104 deletions.
28 changes: 20 additions & 8 deletions src/blocks/Res_GroupNorm_SiLU_D.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,28 @@ def _build_main_path(self) -> nn.Sequential:
n_groups_second = self._adjust_groups(self.out_channels)

return nn.Sequential(
nn.GroupNorm(n_groups_first, self.in_channels, eps=1e-3),
nn.SiLU(),
nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1),
nn.Dropout(self.dropout_rate) if self.dropout_rate > 0 else nn.Identity(),
nn.GroupNorm(n_groups_second, self.out_channels, eps=1e-3),
nn.SiLU(),
nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1),
nn.Dropout(self.dropout_rate) if self.dropout_rate > 0 else nn.Identity(),
nn.GroupNorm(n_groups_first, self.in_channels, eps=1e-3),
nn.SiLU(),
nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1),
nn.Dropout(self.dropout_rate) if self.dropout_rate > 0 else nn.Identity(),
nn.GroupNorm(n_groups_second, self.out_channels, eps=1e-3),
nn.SiLU(),
nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1),
nn.Dropout(self.dropout_rate) if self.dropout_rate > 0 else nn.Identity(),
)

# Instead of group, silu, conv, dropout, group, silu, conv, dropout
# We'll use conv, group, silu, dropout, conv, group

#return nn.Sequential(
# nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, padding=1),
# nn.GroupNorm(n_groups_second, self.out_channels, eps=1e-3),
# nn.SiLU(),
# nn.Dropout(self.dropout_rate) if self.dropout_rate > 0 else nn.Identity(),
# nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1),
# nn.GroupNorm(n_groups_second, self.out_channels, eps=1e-3),
#)

def _build_residual_connection(self) -> Optional[nn.Module]:
if self.in_channels != self.out_channels:
return nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1)
Expand Down
109 changes: 80 additions & 29 deletions src/callbacks/vision.py → src/callbacks/metrics.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import torch
import torch.nn as nn
import lightning as L
from typing import Any
from lightning.pytorch.callbacks import Callback
Expand All @@ -9,7 +10,7 @@

class LogVisionMetricsBase(Callback):
"""
Callback to log vision metrics (SSIM, PSNR) for the denoising model.
Callback to log metrics (PSNR, SSIM, MAE, MSE) for the vision models.
Args:
- log_every_n_epochs (int): Log the metrics every n epochs
Expand All @@ -29,7 +30,7 @@ def __init__(self, batch_idx: int = 0, log_every_n_epochs: int = -1):

@torch.no_grad()
def calc_preds(self, pl_module: "L.LightningModule", ref: torch.Tensor):
pass
raise NotImplementedError

@torch.no_grad()
def on_train_batch_end(
Expand All @@ -40,7 +41,15 @@ def on_train_batch_end(
batch: Any,
batch_idx: int,
) -> None:
self.calc_batch_metrics(pl_module, outputs, batch, batch_idx, prefix="train")
if self.batch_idx != -1 and batch_idx != self.batch_idx:
return

preds = torch.cat(pl_module.train_preds)
target = torch.cat(pl_module.train_targets)

self._log_metrics(
pl_module=pl_module, target=target, preds=preds, prefix="train"
)

@torch.no_grad()
def on_validation_batch_end(
Expand All @@ -52,7 +61,14 @@ def on_validation_batch_end(
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
self.calc_batch_metrics(pl_module, outputs, batch, batch_idx, prefix="val")

if self.batch_idx != -1 and batch_idx != self.batch_idx:
return

preds = torch.cat(pl_module.val_preds)
target = torch.cat(pl_module.val_targets)

self._log_metrics(pl_module=pl_module, target=target, preds=preds, prefix="val")

@torch.no_grad()
def on_test_batch_end(
Expand All @@ -64,57 +80,92 @@ def on_test_batch_end(
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
self.calc_batch_metrics(pl_module, outputs, batch, batch_idx, prefix="test")

@torch.no_grad()
def calc_batch_metrics(self, pl_module, outputs, batch, batch_idx, prefix):
if self.batch_idx != -1 and batch_idx != self.batch_idx:
return

if (
self.log_every_n_epochs != -1
and pl_module.current_epoch % self.log_every_n_epochs != 0
):
return

ref, x = batch
ref = ref.to(pl_module.device)
x = x.to(pl_module.device)

# Log metrics
preds = self.calc_preds(pl_module=pl_module, ref=ref)
preds = torch.cat(pl_module.test_preds)
target = torch.cat(pl_module.test_targets)

self._log_metrics(pl_module=pl_module, target=x, preds=preds, prefix=prefix)
self._log_metrics(
pl_module=pl_module, target=target, preds=preds, prefix="test"
)

@torch.no_grad()
def _calc_psnr(self, preds, targets):
return peak_signal_noise_ratio(preds, targets)
psnr = peak_signal_noise_ratio(
preds,
targets,
reduction="none",
data_range=(-1, 1),
dim=(1, 2, 3),
)

mean_psnr = psnr.mean()
std_psnr = psnr.std()

return mean_psnr, std_psnr

@torch.no_grad()
def _calc_ssim(self, preds, target):
ssim = structural_similarity_index_measure(preds, target)
ssim = structural_similarity_index_measure(
preds, target, reduction="none", data_range=(-1, 1)
)
if isinstance(ssim, tuple):
ssim = ssim[0]
return ssim

mean_ssim = ssim.mean()
std_ssim = ssim.std()
return mean_ssim, std_ssim

@torch.no_grad()
def _calc_mae(self, preds, target):
return torch.mean(torch.abs(preds - target))
_mae = nn.L1Loss(reduction="none")
mae = _mae(preds, target).mean(dim=(1, 2, 3))

mean_mae = mae.mean()
std_mae = mae.std()

return mean_mae, std_mae

@torch.no_grad()
def _calc_mse(self, preds, target):
_mse = nn.MSELoss(reduction="none")
mse = _mse(preds, target).mean(dim=(1, 2, 3))

mean_mse = mse.mean()
std_mse = mse.std()

return mean_mse, std_mse

@torch.no_grad()
def _log_metrics(self, *, pl_module, target, preds, prefix):
psnr = self._calc_psnr(preds, target)
ssim = self._calc_ssim(preds, target)
mae = self._calc_mae(preds, target)
# Calculate metrics
psnr, std_psnr = self._calc_psnr(preds, target)
ssim, std_ssim = self._calc_ssim(preds, target)
mae, std_mae = self._calc_mae(preds, target)
mse, std_mse = self._calc_mse(preds, target)

# PSNR
pl_module.log(f"{prefix}/psnr", psnr)
pl_module.log(f"{prefix}/psnr_std", std_psnr)

# SSIM
pl_module.log(f"{prefix}/ssim", ssim)
pl_module.log(f"{prefix}/ssim_std", std_ssim)

# MAE
pl_module.log(f"{prefix}/mae", mae)
pl_module.log(f"{prefix}/mae_std", std_mae)

# MSE
pl_module.log(f"{prefix}/mse", mse)
pl_module.log(f"{prefix}/mse_std", std_mse)


class LogVisionMetricsDDPM(LogVisionMetricsBase):
"""
Callback to log vision metrics (SSIM, PSNR, LPIPS) for the DDPM denoising model.
Callback to log metrics (PSNR, SSIM, MAE, MSE) for the DDPM denoising model.
Args:
- log_every_n_epochs (int): Log the metrics every n epochs
Expand All @@ -131,7 +182,7 @@ def calc_preds(self, pl_module: "L.LightningModule", ref: torch.Tensor):

class LogVisionMetricsUNet(LogVisionMetricsBase):
"""
Callback to log vision metrics (SSIM, PSNR, LPIPS) for the DDPM denoising model.
Callback to log metrics (PSNR, SSIM, MAE, MSE) for the UNet denoising model.
Args:
- log_every_n_epochs (int): Log the metrics every n epochs
Expand Down
File renamed without changes.
2 changes: 2 additions & 0 deletions src/ddpm/ddpm_A.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def _build_network(self):
def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
with torch.no_grad():
pe_emb = self.time_embedding(t.detach())
print(f'[DEBUG] pe_emb.shape: {pe_emb.shape}')

return self.model(x, pe_emb)

Expand Down Expand Up @@ -192,6 +193,7 @@ def sample(self, ref: torch.Tensor) -> torch.Tensor:
# Repite for each timestep
for i in range(self.diffusion_steps - 1, -1, -1):
model_input = torch.cat([x, ref], dim=1)

t = torch.full((ref.shape[0],), i, device=device, dtype=torch.long)
predicted_noise = self.forward(model_input, t)

Expand Down
Loading

0 comments on commit 3852c54

Please sign in to comment.