-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path train_network.py
49 lines (43 loc) · 2 KB
/
train_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
from nnkit.core.neuronet import DenseLayer, DenseNetwork
from nnkit.core.activations import Softmax, ReLU, Sigmoid, Tanh
from nnkit.core.losses import CrossEntropySoftmax
from nnkit.datasets import mnist
from nnkit.datasets.utils import DataLabelSet, one_hot
from nnkit.training.neurotrain import NetworkTrainer
from nnkit.training.update_rules import SGD, RPropPlus, IRPropPlus, RPropMinus, IRPropMinus
from nnkit.training.stopping import GLStoppingCriterion
from nnkit.training.metrics import Accuracy, MetricsEvaluator
def preprocess_images(images, num_features=28 * 28):
    """Flatten raw images into feature vectors and scale pixel values by 1/255.

    :param images: array with ``reshape``/``astype`` whose leading axis indexes
        samples (e.g. the MNIST image arrays returned by ``mnist.load_data``).
    :param num_features: length of each flattened vector; 784 = 28 * 28 for MNIST.
    :return: float32 array of shape ``(len(images), num_features)``.
    """
    # Take the sample count from the data instead of hard-coding 60000 / 10000,
    # so the script keeps working on subsets or differently sized splits.
    flat = images.reshape((len(images), num_features))
    # NOTE(review): assumes raw pixels are 0..255 intensities, so /255 maps them to [0, 1].
    return flat.astype('float32') / 255


def main():
    """Train a 784-256-10 dense network on MNIST and print its test-set metrics.

    :return: tuple ``(history, metrics)``. ``history`` is the value returned by
        ``NetworkTrainer.train_network``. ``metrics`` is the value returned by
        ``MetricsEvaluator.compute_metrics`` on the test set.
    """
    # Build network: one Tanh hidden layer (256 units), softmax output over 10 classes.
    net = DenseNetwork(
        DenseLayer(num_inputs=784, num_neurons=256, activation_function=Tanh()),
        DenseLayer(num_neurons=10, activation_function=Softmax()),
    )

    # Load data / pre-process: flatten + scale the images, one-hot encode the labels.
    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    train_images = preprocess_images(train_images)
    train_labels = one_hot(train_labels)
    test_images = preprocess_images(test_images)
    test_labels = one_hot(test_labels)

    # Training / validation split.
    # NOTE(review): split_factor=0.2 presumably moves 20% of the training samples
    # into the validation set. Confirm against DataLabelSet.split.
    training_set = DataLabelSet(train_images, train_labels, batch_size=1, name='training')
    training_set, validation_set = training_set.split(
        split_factor=0.2,
        split_set_batch_size=len(train_images),
        split_set_name='validation',
    )

    # Train the network with momentum SGD.
    trainer = NetworkTrainer(
        net=net,
        update_rule=SGD(learning_rate=0.1, momentum=0.9),
        loss_function=CrossEntropySoftmax(),
        metrics=[Accuracy()],
    )
    history = trainer.train_network(training_set, validation_set, epochs=30)

    # Evaluate on the test set, passed as a single batch containing every test sample.
    test_set = DataLabelSet(test_images, test_labels, batch_size=len(test_images), name='test')
    evaluator = MetricsEvaluator(net, metrics=[Accuracy()], loss_function=CrossEntropySoftmax())
    metrics = evaluator.compute_metrics(test_set)
    print(metrics)
    return history, metrics


if __name__ == '__main__':
    main()