diff --git a/keras/api/_tf_keras/keras/layers/__init__.py b/keras/api/_tf_keras/keras/layers/__init__.py index 7245456cc18..db8c361610c 100644 --- a/keras/api/_tf_keras/keras/layers/__init__.py +++ b/keras/api/_tf_keras/keras/layers/__init__.py @@ -155,6 +155,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( + RandomColorDegeneration, +) from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( RandomColorJitter, ) diff --git a/keras/api/layers/__init__.py b/keras/api/layers/__init__.py index 6a3e3b55f14..e38f87bdef3 100644 --- a/keras/api/layers/__init__.py +++ b/keras/api/layers/__init__.py @@ -155,6 +155,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( + RandomColorDegeneration, +) from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( RandomColorJitter, ) diff --git a/keras/src/layers/__init__.py b/keras/src/layers/__init__.py index 303b3104a56..881a6620f68 100644 --- a/keras/src/layers/__init__.py +++ b/keras/src/layers/__init__.py @@ -99,6 +99,9 @@ from keras.src.layers.preprocessing.image_preprocessing.random_brightness import ( RandomBrightness, ) +from keras.src.layers.preprocessing.image_preprocessing.random_color_degeneration import ( + RandomColorDegeneration, +) from keras.src.layers.preprocessing.image_preprocessing.random_color_jitter import ( RandomColorJitter, ) diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py new file mode 100644 index 00000000000..c3255c846eb --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration.py @@ -0,0 +1,132 @@ +from keras.src.api_export import keras_export +from keras.src.layers.preprocessing.image_preprocessing.base_image_preprocessing_layer import ( # noqa: E501 + BaseImagePreprocessingLayer, +) +from keras.src.random import SeedGenerator + + +@keras_export("keras.layers.RandomColorDegeneration") +class RandomColorDegeneration(BaseImagePreprocessingLayer): + """Randomly performs the color degeneration operation on given images. + + The sharpness operation first converts an image to gray scale, then back to + color. It then takes a weighted average between original image and the + degenerated image. This makes colors appear more dull. + + Args: + factor: A tuple of two floats or a single float. + `factor` controls the extent to which the + image sharpness is impacted. `factor=0.0` makes this layer perform a + no-op operation, while a value of 1.0 uses the degenerated result + entirely. Values between 0 and 1 result in linear interpolation + between the original image and the sharpened image. + Values should be between `0.0` and `1.0`. If a tuple is used, a + `factor` is sampled between the two values for every image + augmented. If a single float is used, a value between `0.0` and the + passed float is sampled. In order to ensure the value is always the + same, please pass a tuple with two identical floats: `(0.5, 0.5)`. + seed: Integer. Used to create a random seed. + """ + + _VALUE_RANGE_VALIDATION_ERROR = ( + "The `value_range` argument should be a list of two numbers. " + ) + + def __init__( + self, + factor, + value_range=(0, 255), + data_format=None, + seed=None, + **kwargs, + ): + super().__init__(data_format=data_format, **kwargs) + self._set_factor(factor) + self._set_value_range(value_range) + self.seed = seed + self.generator = SeedGenerator(seed) + + def _set_value_range(self, value_range): + if not isinstance(value_range, (tuple, list)): + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + if len(value_range) != 2: + raise ValueError( + self._VALUE_RANGE_VALIDATION_ERROR + + f"Received: value_range={value_range}" + ) + self.value_range = sorted(value_range) + + def get_random_transformation(self, data, training=True, seed=None): + if isinstance(data, dict): + images = data["images"] + else: + images = data + images_shape = self.backend.shape(images) + rank = len(images_shape) + if rank == 3: + batch_size = 1 + elif rank == 4: + batch_size = images_shape[0] + else: + raise ValueError( + "Expected the input image to be rank 3 or 4. Received: " + f"inputs.shape={images_shape}" + ) + + if seed is None: + seed = self._get_seed_generator(self.backend._backend) + + factor = self.backend.random.uniform( + (batch_size, 1, 1, 1), + minval=self.factor[0], + maxval=self.factor[1], + seed=seed, + ) + factor = factor + return {"factor": factor} + + def transform_images(self, images, transformation=None, training=True): + if training: + images = self.backend.cast(images, self.compute_dtype) + factor = self.backend.cast( + transformation["factor"], self.compute_dtype + ) + degenerates = self.backend.image.rgb_to_grayscale( + images, data_format=self.data_format + ) + images = images + factor * (degenerates - images) + images = self.backend.numpy.clip( + images, self.value_range[0], self.value_range[1] + ) + images = self.backend.cast(images, self.compute_dtype) + return images + + def transform_labels(self, labels, transformation, training=True): + return labels + + def transform_segmentation_masks( + self, segmentation_masks, transformation, training=True + ): + return segmentation_masks + + def transform_bounding_boxes( + self, bounding_boxes, transformation, training=True + ): + return bounding_boxes + + def get_config(self): + config = super().get_config() + config.update( + { + "factor": self.factor, + "value_range": self.value_range, + "seed": self.seed, + } + ) + return config + + def compute_output_shape(self, input_shape): + return input_shape diff --git a/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py new file mode 100644 index 00000000000..18a0adc7c1f --- /dev/null +++ b/keras/src/layers/preprocessing/image_preprocessing/random_color_degeneration_test.py @@ -0,0 +1,77 @@ +import numpy as np +import pytest +from tensorflow import data as tf_data + +import keras +from keras.src import backend +from keras.src import layers +from keras.src import testing + + +class RandomColorDegenerationTest(testing.TestCase): + @pytest.mark.requires_trainable_backend + def test_layer(self): + self.run_layer_test( + layers.RandomColorDegeneration, + init_kwargs={ + "factor": 0.75, + "value_range": (0, 1), + "seed": 1, + }, + input_shape=(8, 3, 4, 3), + supports_masking=False, + expected_output_shape=(8, 3, 4, 3), + ) + + def test_random_color_degeneration_value_range(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1) + + layer = layers.RandomColorDegeneration(0.2, value_range=(0, 1)) + adjusted_image = layer(image) + + self.assertTrue(keras.ops.numpy.all(adjusted_image >= 0)) + self.assertTrue(keras.ops.numpy.all(adjusted_image <= 1)) + + def test_random_color_degeneration_no_op(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((2, 8, 8, 3)) + else: + inputs = np.random.random((2, 3, 8, 8)) + + layer = layers.RandomColorDegeneration((0.5, 0.5)) + output = layer(inputs, training=False) + self.assertAllClose(inputs, output, atol=1e-3, rtol=1e-5) + + def test_random_color_degeneration_factor_zero(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + inputs = np.random.random((2, 8, 8, 3)) + else: + inputs = np.random.random((2, 3, 8, 8)) + layer = layers.RandomColorDegeneration(factor=(0.0, 0.0)) + result = layer(inputs) + + self.assertAllClose(inputs, result, atol=1e-3, rtol=1e-5) + + def test_random_color_degeneration_randomness(self): + image = keras.random.uniform(shape=(3, 3, 3), minval=0, maxval=1)[:5] + + layer = layers.RandomColorDegeneration(0.2) + adjusted_images = layer(image) + + self.assertNotAllClose(adjusted_images, image) + + def test_tf_data_compatibility(self): + data_format = backend.config.image_data_format() + if data_format == "channels_last": + input_data = np.random.random((2, 8, 8, 3)) + else: + input_data = np.random.random((2, 3, 8, 8)) + layer = layers.RandomColorDegeneration( + factor=0.5, data_format=data_format, seed=1337 + ) + + ds = tf_data.Dataset.from_tensor_slices(input_data).batch(2).map(layer) + for output in ds.take(1): + output.numpy()