diff --git a/main.py b/main.py
index 0399de0..414a406 100644
--- a/main.py
+++ b/main.py
@@ -17,6 +17,7 @@ from transformers import HfArgumentParser
 
 from scripts.model.transformer_single_layer import my_transformer
+from scripts.model.linear_model import LinearModel
 from scripts.data.dataset_mnist_8_8 import DatasetMnist
 from scripts.config.arguments import Arguments
 
 
@@ -89,7 +90,8 @@ def main():
     data_dim = args.data_split_dim*args.data_split_dim
     seq_length = int(args.data_dimension**2 / data_dim)  # data_split_dim splits each MNIST image into sub-blocks; the number of sub-blocks is the transformer's sequence length
 
-    tModel = my_transformer(data_dim, data_dim, seq_length, args.n_heads, data_dim, args.num_classes).to(device)
+    # tModel = my_transformer(data_dim, data_dim, seq_length, args.n_heads, data_dim, args.num_classes).to(device)
+    tModel = LinearModel(data_dim, data_dim, n_seq=seq_length, out_dim=args.num_classes).to(device)
     # optimizer = optim.SGD(tModel.parameters(),lr=lr,momentum=mom)
     optimizer = optim.Adam(tModel.parameters(), lr=args.lr)
     loss_fn = nn.CrossEntropyLoss()
diff --git a/scripts/model/linear_model.py b/scripts/model/linear_model.py
new file mode 100644
index 0000000..28d0fc3
--- /dev/null
+++ b/scripts/model/linear_model.py
@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+
+
+class PoswiseFeedForward(nn.Module):
+    def __init__(self, d_dim, mid_dim, n_seq, bias=True):
+        super(PoswiseFeedForward, self).__init__()
+        # self.L1 = torch.nn.utils.weight_norm(nn.Linear(d_dim, mid_dim, bias=bias))
+        # self.L2 = torch.nn.utils.weight_norm(nn.Linear(mid_dim, d_dim, bias=bias))
+        self.L1 = nn.Linear(d_dim, mid_dim, bias=bias)
+        self.L2 = nn.Linear(mid_dim, d_dim, bias=bias)
+        # self.LN = nn.LayerNorm(d_dim, elementwise_affine=False)
+        self.LN = nn.LayerNorm([n_seq, d_dim], elementwise_affine=False)  # normalizes jointly over sequence and feature dims
+        self.relu = nn.ReLU()
+
+    def forward(self, inputs):
+        residual = inputs
+        output = self.L1(inputs)
+        output = self.relu(output)
+        output = self.L2(output)
+        return self.LN(output + residual)  # residual connection followed by LayerNorm
+
+class Mean(nn.Module):
+    def __init__(self, *args):
+        super(Mean, self).__init__()
+        self.index = args  # stored but currently unused
+    def forward(self, input):
+        return torch.mean(input, dim=-2)  # average-pool over the sequence dimension
+
+class LinearModel(nn.Module):
+    def __init__(self, in_dim, hidden_dim, n_seq, out_dim, layer_nums=3):
+        super().__init__()
+        self.layers = nn.ModuleList()
+        for i in range(layer_nums):
+            self.layers.append(nn.Sequential(
+                nn.Linear(in_dim if i == 0 else hidden_dim, hidden_dim),  # only the first block maps in_dim -> hidden_dim
+                nn.ReLU(),
+                nn.Dropout(0.5),
+            ))
+        self.layers.append(PoswiseFeedForward(hidden_dim, hidden_dim, n_seq))
+        self.layers.append(Mean())  # collapse (batch, seq, hidden) to (batch, hidden)
+        self.layers.append(nn.Linear(hidden_dim, out_dim))
+    def forward(self, input):
+        for layer in self.layers:
+            input = layer(input)
+        return input
+
+if __name__ == '__main__':
+    model = LinearModel(768, 768, 32, 10)
+    input = torch.zeros(size=[10, 32, 768])
+    res = model(input)
+    print(res)
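
For reviewers, a minimal smoke test sketching how the new `LinearModel` gets wired up, mirroring the construction in `main.py`. The concrete values below (`data_split_dim = 8`, `data_dimension = 32`, batch size 4) are illustrative assumptions, not values taken from the repo's `Arguments` defaults:

```python
# Smoke test for LinearModel, following the shape arithmetic in main.py.
# NOTE: data_split_dim / data_dimension values are assumed for illustration.
import torch

from scripts.model.linear_model import LinearModel

data_split_dim = 8        # assumed side length of one sub-block
data_dimension = 32       # assumed side length of the full image
num_classes = 10

data_dim = data_split_dim * data_split_dim        # 64 features per block
seq_length = int(data_dimension**2 / data_dim)    # 16 blocks per image

model = LinearModel(data_dim, data_dim, n_seq=seq_length, out_dim=num_classes)
x = torch.randn(4, seq_length, data_dim)          # (batch, seq, features)
logits = model(x)
assert logits.shape == (4, num_classes)
print(logits.shape)  # torch.Size([4, 10])
```

`Mean` pools over the sequence dimension before the final classifier, so the output is `(batch, num_classes)`; note, though, that the `LayerNorm([n_seq, d_dim])` inside `PoswiseFeedForward` fixes the sequence length at construction time, so the model cannot accept variable-length inputs.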