Skip to content

Commit

Permalink
add class_weight handle
Browse files Browse the repository at this point in the history
  • Loading branch information
puke3615 committed Nov 28, 2017
1 parent 37ff54d commit cc8ba66
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 2 deletions.
3 changes: 2 additions & 1 deletion classifier_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def compile_mode(self, force=False):
self.optimizer = Nadam(self.lr)
self.model.compile(loss='categorical_crossentropy', optimizer=self.optimizer, metrics=['accuracy'])

def train(self):
def train(self, class_weight=None):
# calculate files number
file_num = utils.calculate_file_num(PATH_TRAIN_IMAGES)
steps_train = file_num // self.batch_size
Expand Down Expand Up @@ -127,6 +127,7 @@ def train(self):
validation_data=val_generator,
validation_steps=steps_val,
verbose=1,
class_weight=class_weight,
)
except KeyboardInterrupt:
print('\nStop by keyboardInterrupt, try saving weights.')
Expand Down
3 changes: 2 additions & 1 deletion classifier_xception.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import utils
from classifier_base import BaseClassifier
from keras.applications import *
from keras.optimizers import *
Expand Down Expand Up @@ -38,4 +39,4 @@ def data_generator(self, path_image, train=True):
# classifier = XceptionClassifier(lr=2e-4)
# classifier = XceptionClassifier(lr=2e-5)
classifier = XceptionClassifier('xception_resize', optimizer=Adam(1e-4))
classifier.train()
classifier.train(class_weight=utils.calculate_class_weight())
53 changes: 53 additions & 0 deletions test_augment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from imgaug import augmenters as iaa
from PIL import Image
import imgaug as ia
import numpy as np
import os

ia.seed(1)

# Example batch of images.
# The array has shape (32, 64, 64, 3) and dtype uint8.
images = np.array(
[ia.quokka(size=(64, 64)) for _ in range(32)],
dtype=np.uint8
)

seq = iaa.Sequential([
iaa.Fliplr(0.5), # horizontal flips
iaa.Crop(percent=(0, 0.1)), # random crops
# Small gaussian blur with random sigma between 0 and 0.5.
# But we only blur about 50% of all images.
iaa.Sometimes(0.5,
iaa.GaussianBlur(sigma=(0, 0.5))
),
# Strengthen or weaken the contrast in each image.
iaa.ContrastNormalization((0.75, 1.5)),
# Add gaussian noise.
# For 50% of all images, we sample the noise once per pixel.
# For the other 50% of all images, we sample the noise per pixel AND
# channel. This can change the color (not only brightness) of the
# pixels.
iaa.AdditiveGaussianNoise(loc=0, scale=(0.0, 0.05 * 255), per_channel=0.5),
# Make some images brighter and some darker.
# In 20% of all cases, we sample the multiplier once per channel,
# which can end up changing the color of the images.
iaa.Multiply((0.8, 1.2), per_channel=0.2),
# Apply affine transformations to each image.
# Scale/zoom them, translate/move them, rotate them and shear them.
iaa.Affine(
scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},
translate_percent={"x": (-0.2, 0.2), "y": (-0.2, 0.2)},
rotate=(-25, 25),
shear=(-8, 8)
)
], random_order=True) # apply augmenters in random order

path = '/Users/zijiao/Desktop/1'
images_aug = seq.augment_images(images)
for i, im in enumerate(images_aug):
im = Image.fromarray(im)
im.show()
# with open(os.path.join(path, '%d.jpg' % i), 'wb') as f:
# im.save(f)
print('Done.')
10 changes: 10 additions & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from keras.applications.xception import preprocess_input
import keras.backend as K
import tensorflow as tf
import config
import os


Expand Down Expand Up @@ -36,6 +37,15 @@ def ensure_dir(dir):
os.makedirs(dir)


def calculate_class_weight(train_path=config.PATH_TRAIN_IMAGES):
if not os.path.isdir(train_path):
raise Exception('Dir "%s" not exists.' % train_path)
n_classes = [len(os.listdir(os.path.join(train_path, subdir))) for subdir in os.listdir(train_path)]
print n_classes
n_all = sum(n_classes)
return [num / float(n_all) for num in n_classes]


def get_best_weights(path_weights, mode='acc', postfix='.h5'):
if not os.path.isdir(path_weights):
return None
Expand Down

0 comments on commit cc8ba66

Please sign in to comment.