Skip to content

Commit

Permalink
Merge pull request #249 from oarriaga/omniglot
Browse files Browse the repository at this point in the history
Prototypical networks and Omniglot integration
  • Loading branch information
oarriaga authored Jan 23, 2023
2 parents e9b6fb7 + e8fac91 commit 4edcf13
Show file tree
Hide file tree
Showing 15 changed files with 602 additions and 242 deletions.
27 changes: 25 additions & 2 deletions docs/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from paz.optimization import callbacks
from paz import datasets
from paz import pipelines
from paz.utils import logger
from paz.utils import documentation

EXCLUDE = {}

Expand Down Expand Up @@ -248,7 +250,9 @@
{
'page': 'models/classification.md',
'functions': [
models.classification.MiniXception
models.classification.MiniXception,
models.ProtoEmbedding,
models.ProtoNet
],
},

Expand Down Expand Up @@ -315,7 +319,8 @@
datasets.FERPlus,
datasets.OpenImages,
datasets.CityScapes,
datasets.Shapes
datasets.Shapes,
datasets.Omniglot
],
},

Expand Down Expand Up @@ -725,5 +730,23 @@
},


{
'page': 'utils/logger.md',
'functions': [
logger.build_directory,
logger.make_directory,
logger.write_dictionary,
logger.write_weights
]

},

{
'page': 'utils/documentation.md',
'functions': [
documentation.docstring
]

},

]
189 changes: 0 additions & 189 deletions examples/prototypical_networks/dataset.py

This file was deleted.

11 changes: 0 additions & 11 deletions examples/prototypical_networks/download_omniglot.sh

This file was deleted.

32 changes: 21 additions & 11 deletions examples/prototypical_networks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,21 @@
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import callbacks as cb

from protonet import PROTONET, Embedding, schedule
from logger import build_directory, write_dictionary, write_weights
from dataset import (load_omniglot, remove_classes, split_data,
sample_between_alphabet, sample_within_alphabet,
Generator)
from paz.models import ProtoEmbedding, ProtoNet
from paz.utils import build_directory, write_dictionary, write_weights
from paz.datasets.omniglot import (load, remove_classes, split_data,
sample_between_alphabet,
sample_within_alphabet, Generator)


# TODO move to optimization and add tests
def schedule(period=20, rate=0.5):
def apply(epoch, learning_rate):
if ((epoch % period) == 0) and (epoch != 0):
learning_rate = rate * learning_rate
return learning_rate
return apply


description = 'Train and evaluation of prototypical networks'
parser = argparse.ArgumentParser(description=description)
Expand Down Expand Up @@ -52,8 +62,8 @@

image_shape = (args.image_H, args.image_W, 1)
train_args = (args.train_ways, args.train_shots, args.train_queries)
embed = Embedding(image_shape, args.num_blocks)
model = PROTONET(embed, *train_args, image_shape)
embed = ProtoEmbedding(image_shape, args.num_blocks)
model = ProtoNet(embed, *train_args, image_shape)
optimizer = Adam(args.learning_rate)
metrics = [args.metric]
model.compile(Adam(args.learning_rate), loss=args.loss, metrics=metrics)
Expand All @@ -64,7 +74,7 @@
cb.EarlyStopping('val_loss', args.stop_delta, args.stop_patience, 1)
]

train_data = load_omniglot(args.train_path, image_shape[:2])
train_data = load('train', image_shape[:2], True)
train_data = remove_classes(RNG, train_data, args.train_classes)
train_data, validation_data = split_data(train_data, args.validation_split)

Expand All @@ -81,19 +91,19 @@
results = {}
for way in args.test_ways:
for shot in args.test_shots:
test_model = PROTONET(embed, way, shot, args.test_queries, image_shape)
test_model = ProtoNet(embed, way, shot, args.test_queries, image_shape)
test_model.compile(optimizer, loss=args.loss, metrics=metrics)
test_args = (way, shot, args.test_queries)

data = load_omniglot(args.tests_path, image_shape[:2], flat=False)
data = load('test', image_shape[:2], flat=False)
sampler = partial(sample_within_alphabet, RNG, data, *test_args)
sequence = Generator(sampler, *test_args, image_shape, args.test_steps)
losses, accuracy = test_model.evaluate(sequence)
accuracy = round(100 * accuracy, 2)
results[f'{way}-way_{shot}-shot_within_alphabet'] = accuracy
print(f'Within alphabet {way}-way {shot}-shot accuracy {accuracy} %')

data = load_omniglot(args.tests_path, image_shape[:2], flat=True)
data = load('test', image_shape[:2], flat=True)
sampler = partial(sample_between_alphabet, RNG, data, *test_args)
sequence = Generator(sampler, *test_args, image_shape, args.test_steps)
losses, accuracy = test_model.evaluate(sequence)
Expand Down
29 changes: 16 additions & 13 deletions paz/backend/image/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,29 +289,32 @@ def make_mosaic(images, shape, border=0):
# Arguments
images: Numpy array of shape (num_images, height, width, num_channels)
shape: List of two integers indicating the mosaic shape.
Shape must satisfy: shape[0] * shape[1] == len(images).
border: Integer indicating the border per image.
# Returns
A numpy array containing all images.
# Exceptions
Shape must satisfy `len(images) > shape[0] * shape[1]`
"""
num_images = len(images)
num_images, H, W, num_channels = images.shape
num_rows, num_cols = shape
H, W, num_channels = images.shape[1:]
mosaic = np.ma.masked_all(
(num_rows * H + (num_rows - 1) * border,
num_cols * W + (num_cols - 1) * border, num_channels),
dtype=np.float32)
if num_images > (num_rows * num_cols):
raise ValueError('Number of images is bigger than shape')

total_rows = (num_rows * H) + ((num_rows - 1) * border)
total_cols = (num_cols * W) + ((num_cols - 1) * border)
mosaic = np.ones((total_rows, total_cols, num_channels))

padded_H = H + border
padded_W = W + border
for image_arg in range(num_images):

for image_arg, image in enumerate(images):
row = int(np.floor(image_arg / num_cols))
col = image_arg % num_cols
image = images[image_arg]
image_shape = image.shape
mosaic[row * padded_H:row * padded_H + image_shape[0],
col * padded_W:col * padded_W + image_shape[1], :] = image
return mosaic.astype('uint8')
mosaic[row * padded_H:row * padded_H + H,
col * padded_W:col * padded_W + W, :] = image
return mosaic


def draw_points2D(image, points2D, colors):
Expand Down
1 change: 1 addition & 0 deletions paz/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
from .CMU_poanoptic import MANOHandJoints
from .CMU_poanoptic import MPIIHandJoints
from .shapes import Shapes
from .omniglot import Omniglot
Loading

0 comments on commit 4edcf13

Please sign in to comment.