Commit

refactor: New res blocks
ander-db committed Sep 21, 2024
1 parent 5ec77bc commit 65c7470
Showing 5 changed files with 236 additions and 207 deletions.
73 changes: 0 additions & 73 deletions src/blocks/Res_Base.py

This file was deleted.

37 changes: 31 additions & 6 deletions src/blocks/Res_BatchNorm_ReLU_D.py
@@ -1,9 +1,9 @@
-from .Res_Base import ResBlock
 
 import torch
 import torch.nn as nn
+from typing import Optional
 
 
-class ResBlockBatchNorm(ResBlock):
+class ResBlockBatchNorm(nn.Module):
     """
     Residual block with Batch Normalization and ReLU activation.
@@ -20,7 +20,9 @@ class ResBlockBatchNorm(ResBlock):
         |
       Output
     Attributes:
-        dropout_prob (float): Probability of an element to be zeroed in the dropout layers.
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
+        dropout_rate (float): Probability of an element to be zeroed in the dropout layers.
     """
 
     def __init__(
@@ -29,11 +31,15 @@ def __init__(
         out_channels: int,
         dropout_rate: float = 0.1,
     ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
         self.dropout_rate = dropout_rate
-        super(ResBlockBatchNorm, self).__init__(in_channels, out_channels)
 
+        self.main_path = self._build_main_path()
+        self.residual_connection = self._build_residual_connection()
+
     def _build_main_path(self) -> nn.Sequential:
         return nn.Sequential(
             nn.BatchNorm2d(self.in_channels),
             nn.ReLU(),
@@ -54,3 +60,22 @@ def _build_main_path(self) -> nn.Sequential:
             ),
             nn.Dropout(self.dropout_rate) if self.dropout_rate > 0 else nn.Identity(),
         )
+
+    def _build_residual_connection(self) -> Optional[nn.Module]:
+        if self.in_channels != self.out_channels:
+            return nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1)
+        return None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the ResBlockBatchNorm.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output after applying the residual block.
+        """
+        main_output = self.main_path(x)
+        residual = self.residual_connection(x) if self.residual_connection else x
+        return main_output + residual
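For reference, a minimal usage sketch of the refactored ResBlockBatchNorm (import path assumed from the repository layout above; not part of the committed diff):

    import torch
    from src.blocks.Res_BatchNorm_ReLU_D import ResBlockBatchNorm

    # in_channels != out_channels, so _build_residual_connection() returns a 1x1 Conv2d
    # that projects the input before it is added to the main-path output.
    block = ResBlockBatchNorm(in_channels=64, out_channels=128, dropout_rate=0.1)
    x = torch.randn(8, 64, 32, 32)   # (batch, channels, height, width)
    y = block(x)                     # forward: main_path(x) + residual
    print(y.shape)                   # expected: torch.Size([8, 128, 32, 32])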
54 changes: 49 additions & 5 deletions src/blocks/Res_GroupNorm_SiLU_D.py
@@ -1,10 +1,10 @@
-from .Res_Base import ResBlock
 
+from typing import Optional
 import torch
 import torch.nn as nn
 import warnings
 
 
-class ResBlockGroupNorm(ResBlock):
+class ResBlockGroupNorm(nn.Module):
     """
     Residual block with Group Normalization and SiLU activation.
@@ -22,8 +22,10 @@ class ResBlockGroupNorm(ResBlock):
       Output
     Attributes:
+        in_channels (int): Number of input channels.
+        out_channels (int): Number of output channels.
         n_groups (int): Number of groups for Group Normalization.
-        dropout_prob (float): Probability of an element to be zeroed in the dropout layers.
+        dropout_rate (float): Probability of an element to be zeroed in the dropout layers.
     """
 
     def __init__(
@@ -33,9 +35,14 @@ def __init__(
         n_groups: int = 32,
         dropout_rate: float = 0.1,
     ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
         self.n_groups = n_groups
         self.dropout_rate = dropout_rate
-        super().__init__(in_channels, out_channels)
 
+        self.main_path = self._build_main_path()
+        self.residual_connection = self._build_residual_connection()
+
     def _adjust_groups(self, channels: int) -> int:
         if self.n_groups > channels:
@@ -45,11 +52,29 @@ def _adjust_groups(self, channels: int) -> int:
                 UserWarning,
             )
             return channels
 
+        if channels % self.n_groups != 0:
+            # Find the largest divisor of channels that is less than or equal to self.n_groups
+            new_n_groups = self.n_groups
+            while new_n_groups > 1:
+                if channels % new_n_groups == 0:
+                    break
+                new_n_groups -= 1
+
+            warnings.warn(
+                f"Number of channels ({channels}) is not divisible by the number of groups ({self.n_groups}). "
+                f"Setting number of groups to {new_n_groups}.",
+                UserWarning,
+            )
+            return new_n_groups
+
         return self.n_groups
 
     def _build_main_path(self) -> nn.Sequential:
         n_groups_first = self._adjust_groups(self.in_channels)
+        print(f"Number of groups first: {n_groups_first}")
         n_groups_second = self._adjust_groups(self.out_channels)
+        print(f"Number of groups second: {n_groups_second}")
+
         return nn.Sequential(
             nn.GroupNorm(n_groups_first, self.in_channels),
@@ -61,3 +86,22 @@ def _build_main_path(self) -> nn.Sequential:
             nn.Conv2d(self.out_channels, self.out_channels, kernel_size=3, padding=1),
             nn.Dropout(self.dropout_rate) if self.dropout_rate > 0 else nn.Identity(),
         )
+
+    def _build_residual_connection(self) -> Optional[nn.Module]:
+        if self.in_channels != self.out_channels:
+            return nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1)
+        return None
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Forward pass of the ResBlockGroupNorm.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output after applying the residual block.
+        """
+        main_output = self.main_path(x)
+        residual = self.residual_connection(x) if self.residual_connection else x
+        return main_output + residual
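Likewise, a minimal usage sketch of ResBlockGroupNorm illustrating the group adjustment (import path assumed; not part of the committed diff):

    import torch
    from src.blocks.Res_GroupNorm_SiLU_D import ResBlockGroupNorm

    # 48 is not divisible by the default n_groups=32, so _adjust_groups() emits a
    # UserWarning and falls back to the largest divisor of 48 that is <= 32, i.e. 24
    # groups for the first GroupNorm; 64 % 32 == 0, so the second keeps 32 groups.
    block = ResBlockGroupNorm(in_channels=48, out_channels=64, n_groups=32, dropout_rate=0.1)
    x = torch.randn(4, 48, 16, 16)
    y = block(x)                     # 1x1 Conv2d residual projection, since 48 != 64
    print(y.shape)                   # expected: torch.Size([4, 64, 16, 16])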