"""Train a small convolutional network on CIFAR-10 and print its test accuracy.

Run as a script.  Pass ``--no-cuda`` to force CPU execution; the CPU is also
used automatically when no CUDA device is available.
"""

from argparse import ArgumentParser

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

# Number of mini-batches whose losses are averaged into one progress line.
LOG_INTERVAL = 10

classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)


class Net(nn.Module):
    """Two conv/pool stages followed by three fully connected layers.

    The flatten size ``16 * 5 * 5`` corresponds to 3-channel 32x32 inputs:
    each 5x5 convolution trims 4 pixels and each pooling step halves the
    spatial size (32 -> 28 -> 14 -> 10 -> 5).
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw (un-normalised) 10-way class scores for batch ``x``."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)


def main():
    """Parse the CLI flags, train for two epochs, then report test accuracy."""
    parser = ArgumentParser()
    parser.add_argument("--no-cuda", action="store_true", default=False)
    args = parser.parse_args()

    # Fall back to the CPU instead of crashing when CUDA is not available.
    use_cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Pinned host memory only helps host-to-GPU copies, so skip it on CPU runs.
    dataloader_kwargs = {"num_workers": 2, "pin_memory": use_cuda}

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    trainset = torchvision.datasets.CIFAR10(
        root="/data", train=True, download=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=32, shuffle=True, **dataloader_kwargs
    )

    testset = torchvision.datasets.CIFAR10(
        root="/data", train=False, download=True, transform=transform
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=128, shuffle=False, **dataloader_kwargs
    )

    net = Net().to(device)
    if use_cuda:
        # DataParallel defaults to every visible GPU and raises in forward()
        # when the wrapped module lives on the CPU, so only wrap on CUDA runs.
        net = nn.DataParallel(net)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)

    for epoch in range(2):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Report only after a full window of LOG_INTERVAL batches so the
            # printed value really is the mean loss over that window.
            if i % LOG_INTERVAL == LOG_INTERVAL - 1:
                mean_loss = running_loss / LOG_INTERVAL
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {mean_loss:.3f}")
                running_loss = 0.0

    print("Finished Training")

    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(
        "Accuracy of the network on the 10000 test images: "
        f"{correct / total:.2%}"
    )


# The guard is required because the DataLoader workers re-import this module
# on platforms that start processes with "spawn" (Windows, macOS).
if __name__ == "__main__":
    main()