Commit c4f84d1 (master)
iser97 committed Oct 21, 2021
1 parent 2d1c091 commit c4f84d1
Showing 2 changed files with 55 additions and 1 deletion.
4 changes: 3 additions & 1 deletion main.py
@@ -17,6 +17,7 @@
from transformers import HfArgumentParser

from scripts.model.transformer_single_layer import my_transformer
from scripts.model.linear_model import LinearModel
from scripts.data.dataset_mnist_8_8 import DatasetMnist
from scripts.config.arguments import Arguments

@@ -89,7 +90,8 @@ def main():
    data_dim = args.data_split_dim*args.data_split_dim
    seq_length = int(args.data_dimension**2 / data_dim)  # data_split_dim splits the MNIST image into sub-blocks; the number of sub-blocks is the transformer's sequence length

    tModel = my_transformer(data_dim, data_dim, seq_length, args.n_heads, data_dim, args.num_classes).to(device)
    # tModel = my_transformer(data_dim, data_dim, seq_length, args.n_heads, data_dim, args.num_classes).to(device)
    tModel = LinearModel(data_dim, data_dim, n_seq=seq_length, out_dim=args.num_classes).to(device)
    # optimizer = optim.SGD(tModel.parameters(), lr=lr, momentum=mom)
    optimizer = optim.Adam(tModel.parameters(), lr=args.lr)
    loss_fn = nn.CrossEntropyLoss()
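The seq_length computed above counts the sub-blocks the 8x8 MNIST image is cut into; the actual splitting lives in scripts/data/dataset_mnist_8_8.py, which is not part of this diff. A minimal sketch of the idea, assuming data_dimension=8 and data_split_dim=2 (illustrative values, not taken from the commit):

import torch

B, D, S = 16, 8, 2                                # batch, image side (data_dimension), block side (data_split_dim)
images = torch.rand(B, D, D)                      # [batch, 8, 8]
blocks = images.unfold(1, S, S).unfold(2, S, S)   # [batch, 4, 4, 2, 2]: non-overlapping 2x2 patches
blocks = blocks.reshape(B, (D // S) ** 2, S * S)  # [batch, seq_length=16, data_dim=4]
print(blocks.shape)                               # torch.Size([16, 16, 4])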
52 changes: 52 additions & 0 deletions scripts/model/linear_model.py
@@ -0,0 +1,52 @@
import torch
import torch.nn as nn


class PoswiseFeedForward(nn.Module):
    """Position-wise feed-forward block: two linear layers with a residual connection and LayerNorm."""
    def __init__(self, d_dim, mid_dim, n_seq, bias=True):
        super(PoswiseFeedForward, self).__init__()
        # self.L1 = torch.nn.utils.weight_norm(nn.Linear(d_dim, mid_dim, bias=bias))
        # self.L2 = torch.nn.utils.weight_norm(nn.Linear(mid_dim, d_dim, bias=bias))
        self.L1 = nn.Linear(d_dim, mid_dim, bias=bias)
        self.L2 = nn.Linear(mid_dim, d_dim, bias=bias)
        # self.LN = nn.LayerNorm(d_dim, elementwise_affine=False)
        # normalize over the whole [n_seq, d_dim] slice, without learnable affine parameters
        self.LN = nn.LayerNorm([n_seq, d_dim], elementwise_affine=False)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        residual = inputs
        output = self.L1(inputs)
        output = self.relu(output)
        output = self.L2(output)
        return self.LN(output + residual)

class Mean(nn.Module):
    """Mean-pool over the sequence dimension (dim=-2); the stored indices are currently unused."""
    def __init__(self, *args):
        super(Mean, self).__init__()
        self.index = args

    def forward(self, input):
        return torch.mean(input, dim=-2)

class LinearModel(nn.Module):
    """Stack of linear blocks, a position-wise feed-forward layer, mean pooling over the sequence, and a classifier head."""
    def __init__(self, in_dim, hidden_dim, n_seq, out_dim, layer_nums=3):
        super().__init__()
        self.layers = nn.ModuleList()
        for i in range(layer_nums):
            self.layers.append(nn.Sequential(
                nn.Linear(in_dim if i == 0 else hidden_dim, hidden_dim),  # first block maps in_dim -> hidden_dim, later blocks stay at hidden_dim
                nn.ReLU(),
                nn.Dropout(0.5),
            ))
        self.layers.append(PoswiseFeedForward(hidden_dim, hidden_dim, n_seq))
        self.layers.append(Mean())  # [batch, n_seq, hidden_dim] -> [batch, hidden_dim]
        self.layers.append(nn.Linear(hidden_dim, out_dim))

    def forward(self, input):
        for layer in self.layers:
            input = layer(input)
        return input

if __name__ == '__main__':
    # smoke test: a batch of 10 sequences, each 32 tokens of dimension 768, gives logits of shape [10, 10]
    model = LinearModel(768, 768, 32, 10)
    input = torch.zeros(size=[10, 32, 768])
    res = model(input)
    print(res)
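For reference, a short sanity check of how the new model's output feeds the CrossEntropyLoss used in main.py; the shapes below are illustrative assumptions, not values taken from the commit:

import torch
import torch.nn as nn

model = LinearModel(in_dim=4, hidden_dim=4, n_seq=16, out_dim=10)
x = torch.rand(8, 16, 4)                              # [batch, seq_length, data_dim]
logits = model(x)                                     # [8, 10] after mean-pooling over the sequence
loss = nn.CrossEntropyLoss()(logits, torch.randint(0, 10, (8,)))
print(logits.shape, loss.item())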
