Skip to content

Commit

Permalink
master
Browse files Browse the repository at this point in the history
  • Loading branch information
iser97 committed Oct 21, 2021
1 parent c4f84d1 commit ff42488
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
10 changes: 7 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import sys
import logging
import matplotlib as mpl
from torch._C import Value
if os.environ.get('DISPLAY','')=='':
print('no display found. Using non-interactive Agg backend')
mpl.use("agg")
Expand Down Expand Up @@ -89,9 +90,12 @@ def train_step(model, optimizer, loss_fn, train_loader, test_loader):
def main():
data_dim = args.data_split_dim*args.data_split_dim
seq_length = int(args.data_dimension**2 / data_dim) # through the data_split_dim can split the mnist picture to sub blocks, the number of sub blocks stands for the transformers' sequence length

# tModel = my_transformer(data_dim, data_dim, seq_length, args.n_heads, data_dim, args.num_classes).to(device)
tModel = LinearModel(data_dim, data_dim, n_seq=seq_length, out_dim=args.num_classes).to(device)
if args.model_type == 'transformer':
tModel = my_transformer(data_dim, data_dim, seq_length, args.n_heads, data_dim, args.num_classes).to(device)
elif args.model_type == 'linear':
tModel = LinearModel(data_dim, data_dim, n_seq=seq_length, out_dim=args.num_classes).to(device)
else:
raise ValueError("{} model type is not implemented".format(args.model_type))
# optimizer = optim.SGD(tModel.parameters(),lr=lr,momentum=mom)
optimizer = optim.Adam(tModel.parameters(), lr=args.lr)
loss_fn = nn.CrossEntropyLoss()
Expand Down
11 changes: 8 additions & 3 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,16 @@

## file dir

cache: root for saving data, the 'test_data.txt', 'train_data.txt', 'test_label.txt', 'train_label.txt' is save in this directory.
cache: root for saving data, the 'test_data.txt', 'train_data.txt', 'test_label.txt', 'train_label.txt' is saved in this directory.

scripts: total scripts

更改模型类别可以通过参数model_type进行更改, model_type=transformer or model_type=linear

training and test run the following command
```bash
python main.py --lr 0.0001 --batch_size 256 --n_epochs 200 --data_split_dim 4 --data_dimension 8 --n_heads 4
```
python main.py --lr 0.0001 --batch_size 256 --n_epochs 200 --data_split_dim 4 --data_dimension 8 --n_heads 4 --model_type linear
```



3 changes: 3 additions & 0 deletions scripts/config/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ class Arguments:
)
num_classes: Optional[int] = field(
default=10, metadata={"help": "classes for mnist is 10"}
)
model_type: Optional[str] = field(
default='transformer', metadata={"help": "model type choices: transformer or linear"}
)

0 comments on commit ff42488

Please sign in to comment.