Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add random_color_degeneration processing layer #20679

Merged
merged 3 commits into from
Dec 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
RandomColorDegeneration,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
RandomColorJitter,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
RandomColorDegeneration,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
RandomColorJitter,
)
Expand Down
3 changes: 3 additions & 0 deletions keras/src/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@
from keras.src.layers.preprocessing.image_preprocessing.random_brightness import (
RandomBrightness,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import (
RandomColorDegeneration,
)
from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import (
RandomColorJitter,
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501
BaseImagePreprocessingLayer,
)
from keras.src.random import SeedGenerator


@keras_export("keras.layers.RandomColorDegeneration")
class RandomColorDegeneration(BaseImagePreprocessingLayer):
"""Randomly performs the color degeneration operation on given images.

The sharpness operation first converts an image to gray scale, then back to
color. It then takes a weighted average between original image and the
degenerated image. This makes colors appear more dull.

Args:
factor: A tuple of two floats or a single float.
`factor` controls the extent to which the
image sharpness is impacted. `factor=0.0` makes this layer perform a
no-op operation, while a value of 1.0 uses the degenerated result
entirely. Values between 0 and 1 result in linear interpolation
between the original image and the sharpened image.
Values should be between `0.0` and `1.0`. If a tuple is used, a
`factor` is sampled between the two values for every image
augmented. If a single float is used, a value between `0.0` and the
passed float is sampled. In order to ensure the value is always the
same, please pass a tuple with two identical floats: `(0.5, 0.5)`.
seed: Integer. Used to create a random seed.
"""

_VALUE_RANGE_VALIDATION_ERROR = (
"The `value_range` argument should be a list of two numbers. "
)

def __init__(
self,
factor,
value_range=(0, 255),
data_format=None,
seed=None,
**kwargs,
):
super().__init__(data_format=data_format, **kwargs)
self._set_factor(factor)
self._set_value_range(value_range)
self.seed = seed
self.generator = SeedGenerator(seed)

def _set_value_range(self, value_range):
if not isinstance(value_range, (tuple, list)):
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
if len(value_range) != 2:
raise ValueError(
self._VALUE_RANGE_VALIDATION_ERROR
+ f"Received: value_range={value_range}"
)
self.value_range = sorted(value_range)

def get_random_transformation(self, data, training=True, seed=None):
if isinstance(data, dict):
images = data["images"]
else:
images = data
images_shape = self.backend.shape(images)
rank = len(images_shape)
if rank == 3:
batch_size = 1
elif rank == 4:
batch_size = images_shape[0]
else:
raise ValueError(
"Expected the input image to be rank 3 or 4. Received: "
f"inputs.shape={images_shape}"
)

if seed is None:
seed = self._get_seed_generator(self.backend._backend)

factor = self.backend.random.uniform(
(batch_size, 1, 1, 1),
minval=self.factor[0],
maxval=self.factor[1],
seed=seed,
)
factor = factor
return {"factor": factor}

def transform_images(self, images, transformation=None, training=True):
if training:
images = self.backend.cast(images, self.compute_dtype)
factor = self.backend.cast(
transformation["factor"], self.compute_dtype
)
degenerates = self.backend.image.rgb_to_grayscale(
images, data_format=self.data_format
)
images = images + factor * (degenerates - images)
images = self.backend.numpy.clip(
images, self.value_range[0], self.value_range[1]
)
images = self.backend.cast(images, self.compute_dtype)
return images

def transform_labels(self, labels, transformation, training=True):
return labels

def transform_segmentation_masks(
self, segmentation_masks, transformation, training=True
):
return segmentation_masks

def transform_bounding_boxes(
self, bounding_boxes, transformation, training=True
):
return bounding_boxes

def get_config(self):
config = super().get_config()
config.update(
{
"factor": self.factor,
"value_range": self.value_range,
"seed": self.seed,
}
)
return config

def compute_output_shape(self, input_shape):
return input_shape
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import numpy as np
import pytest
from tensorflow import data as tf_data

import keras
from keras.src import backend
from keras.src import layers
from keras.src import testing


class RandomColorDegenerationTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_layer(self):
self.run_layer_test(
layers.RandomColorDegeneration,
init_kwargs={
"factor": 0.75,
"value_range": (0, 1),
"seed": 1,
},
input_shape=(8, 3, 4, 3),
supports_masking=False,
expected_output_shape=(8, 3, 4, 3),
)

def test_random_color_degeneration_value_range(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)

layer = layers.RandomColorDegeneration(0.2, value_range=(0, 1))
adjusted_image = layer(image)

self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0))
self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1))

def test_random_color_degeneration_no_op(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
inputs = np.random.random((2, 8, 8, 3))
else:
inputs = np.random.random((2, 3, 8, 8))

layer = layers.RandomColorDegeneration((0.5, 0.5))
output = layer(inputs, training=False)
self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5)

def test_random_color_degeneration_factor_zero(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
inputs = np.random.random((2, 8, 8, 3))
else:
inputs = np.random.random((2, 3, 8, 8))
layer = layers.RandomColorDegeneration(factor=(0.0, 0.0))
result = layer(inputs)

self.assertAllClose(inputs, result, atol=1e-3, rtol=1e-5)

def test_random_color_degeneration_randomness(self):
image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5]

layer = layers.RandomColorDegeneration(0.2)
adjusted_images = layer(image)

self.assertNotAllClose(adjusted_images, image)

def test_tf_data_compatibility(self):
data_format = backend.config.image_data_format()
if data_format == "channels_last":
input_data = np.random.random((2, 8, 8, 3))
else:
input_data = np.random.random((2, 3, 8, 8))
layer = layers.RandomColorDegeneration(
factor=0.5, data_format=data_format, seed=1337
)

ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer)
for output in ds.take(1):
output.numpy()
Loading