Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port Faster R-CNN to Keras3 #2458

Merged
merged 46 commits into from
Aug 20, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
c37ae23
Base structure for faster rcnn till rpn head
sineeli Jun 10, 2024
973dd6a
Add export for Faster RNN
sineeli Jun 10, 2024
70c7f24
add init file
sineeli Jun 10, 2024
de67b89
initialize faster rcnn at model level
sineeli Jun 10, 2024
aaebe30
code fix for roi align
sineeli Jun 12, 2024
0707858
Forward Pass code for Faster R-CNN
sineeli Jun 12, 2024
cff3b8e
Faster RCNN Base code for Keras3(Draft-1)
sineeli Jun 25, 2024
4f511e9
Add local batch size
sineeli Jun 25, 2024
0eef933
Add parameters to RPN Head
sineeli Jul 2, 2024
75c64ca
Make FPN more customizable with parameters and remove redundant code
sineeli Jul 2, 2024
6267a4b
Compute output shape for ROI Generator
sineeli Jul 2, 2024
1931f02
Faster RCNN functional model with required import corrections
sineeli Jul 2, 2024
58dc7f9
add clip boxes to forward pass
sineeli Jul 8, 2024
7c65348
add prediction decoder and use "yxyx" as default internal bounding bo…
sineeli Jul 11, 2024
676fcf1
feature pyramid correction
sineeli Jul 16, 2024
dcea19f
change ops.divide to ops.divide_no_nan
sineeli Jul 29, 2024
2179157
use from_logits=True for Non Max suppression
sineeli Jul 29, 2024
a002c49
include box conversions for both rois and ground truth boxes
sineeli Jul 29, 2024
5953f0a
Change number of detections in decoder
sineeli Jul 29, 2024
91f21fa
Use categoricalcrossentropy to avoid -1 class error + added get_confi…
sineeli Jul 30, 2024
abf0b44
add basic test cases + linting
sineeli Jul 30, 2024
d2b78e0
Add seed generator for sampling in RPN label encoding and ROI samplin…
sineeli Jul 30, 2024
a397a6c
Use only spatial dimension for ops.nn.avg_pool + use ops.convert_to_t…
sineeli Jul 30, 2024
e336d69
Convert list to tensor using keras ops
sineeli Jul 30, 2024
ecd0dad
Remove seed number from seed generator
sineeli Jul 31, 2024
c91ac27
Remove print and add proper comments
sineeli Aug 5, 2024
ba86502
- Use stddev(0.01) as per paper across RPN and R-CNN Heads
sineeli Aug 8, 2024
4979a99
- Fixes slice for multi backend
sineeli Aug 8, 2024
357a14a
- Add compute metrics method
sineeli Aug 9, 2024
ef27533
Correct test cases and add missing args
sineeli Aug 12, 2024
f37d799
Fix lint issues
sineeli Aug 13, 2024
36d4e10
- Fix lint and remove hard coded params to make it user friendly.
sineeli Aug 13, 2024
5060382
- Generate ROI's while decoding for predictions
sineeli Aug 14, 2024
02d24b0
- Add faster rcnn to build method
sineeli Aug 14, 2024
c0556d8
- Test only for Keras3
sineeli Aug 14, 2024
879028f
- Correct test case
sineeli Aug 15, 2024
c77d03c
- Correct the test cases decorator to skip for Keras2
sineeli Aug 16, 2024
10b9e76
- Skip Legacy test cases
sineeli Aug 16, 2024
e1d89e7
- Remove unnecessary import in legacy code to fix lint
sineeli Aug 16, 2024
58178c6
- Correct pytest complexity
sineeli Aug 16, 2024
1c6125b
- Fix Image Shape to 512, 512 default which will not break other test…
sineeli Aug 16, 2024
df56fa6
- Lower image sizes for test cases
sineeli Aug 19, 2024
6b03271
- fix keras to 3.3.3 version
sineeli Aug 20, 2024
8608516
- Generate api
sineeli Aug 20, 2024
d1f05af
- Lint fix
sineeli Aug 20, 2024
8360e5b
- Increase the atol, rtol for YOLOv8 Detector forward pass
sineeli Aug 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
include box conversions for both rois and ground truth boxes
  • Loading branch information
sineeli committed Jul 29, 2024
commit a002c49dd74ba03a9c5f22f86c52e411ccfb7bb2
21 changes: 12 additions & 9 deletions keras_cv/src/layers/object_detection/roi_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class ROISampler(keras.layers.Layer):
if its range is [0, num_classes).

Args:
bounding_box_format: The format of bounding boxes to generate. Refer
roi_bounding_box_format: The format of roi bounding boxes. Refer
[to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
for more details on supported bounding box formats.
gt_bounding_box_format: The format of gt bounding boxes.
roi_matcher: a `BoxMatcher` object that matches proposals with ground
truth boxes. The positive match must be 1 and negative match must be -1.
Such assumption is not being validated here.
Expand All @@ -59,7 +60,8 @@ class ROISampler(keras.layers.Layer):

def __init__(
self,
bounding_box_format: str,
roi_bounding_box_format: str,
gt_bounding_box_format: str,
roi_matcher: box_matcher.BoxMatcher,
positive_fraction: float = 0.25,
background_class: int = 0,
Expand All @@ -68,7 +70,8 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self.bounding_box_format = bounding_box_format
self.roi_bounding_box_format = roi_bounding_box_format
self.gt_bounding_box_format = gt_bounding_box_format
self.roi_matcher = roi_matcher
self.positive_fraction = positive_fraction
self.background_class = background_class
Expand Down Expand Up @@ -97,6 +100,12 @@ def call(
sampled_gt_classes: [batch_size, num_sampled_rois, 1]
sampled_class_weights: [batch_size, num_sampled_rois, 1]
"""
rois = bounding_box.convert_format(
rois, source=self.roi_bounding_box_format, target="yxyx"
)
gt_boxes = bounding_box.convert_format(
gt_boxes, source=self.gt_bounding_box_format, target="yxyx"
)
if self.append_gt_boxes:
# num_rois += num_gt
rois = ops.concatenate([rois, gt_boxes], axis=1)
Expand All @@ -110,12 +119,6 @@ def call(
"num_rois must be less than `num_sampled_rois` "
f"({self.num_sampled_rois}), got {num_rois}"
)
rois = bounding_box.convert_format(
rois, source=self.bounding_box_format, target="yxyx"
)
gt_boxes = bounding_box.convert_format(
gt_boxes, source=self.bounding_box_format, target="yxyx"
)
# [batch_size, num_rois, num_gt]
similarity_mat = iou.compute_iou(
rois, gt_boxes, bounding_box_format="yxyx", use_masking=True
Expand Down
166 changes: 62 additions & 104 deletions keras_cv/src/models/object_detection/faster_rcnn/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import tree

from keras_cv.src import bounding_box
from keras_cv.src import layers as cv_layers
from keras_cv.src import losses
from keras_cv.src.api_export import keras_cv_export
from keras_cv.src.backend import keras
from keras_cv.src.backend import ops
from keras_cv.src.bounding_box.converters import _decode_deltas_to_boxes
from keras_cv.src.bounding_box.utils import _clip_boxes
from keras_cv.src.layers.object_detection.anchor_generator import (
AnchorGenerator,
)
from keras_cv.src.layers.object_detection.box_matcher import BoxMatcher
from keras_cv.src.layers.object_detection.multi_class_non_max_suppression import ( # noqa: E501
MultiClassNonMaxSuppression,
)
from keras_cv.src.layers.object_detection.roi_align import ROIAligner
from keras_cv.src.layers.object_detection.roi_generator import ROIGenerator
from keras_cv.src.layers.object_detection.roi_sampler import ROISampler
Expand Down Expand Up @@ -89,6 +93,7 @@ def __init__(
num_filters=rpn_filters,
kernel_size=rpn_kernel_size,
)

# 5. ROI Generator
roi_generator = ROIGenerator(
bounding_box_format="yxyx",
Expand Down Expand Up @@ -171,10 +176,10 @@ def __init__(

inputs = {"images": images}
outputs = {
"box": box_pred,
"classification": cls_pred,
"rpn_box": rpn_box_pred,
"rpn_classification": rpn_cls_pred,
"box": box_pred,
"classification": cls_pred,
}

super().__init__(
Expand Down Expand Up @@ -204,29 +209,29 @@ def __init__(
thresholds=[0.0, 0.5], match_values=[-2, -1, 1]
)
self.roi_sampler = ROISampler(
bounding_box_format="yxyx",
roi_bounding_box_format="yxyx",
gt_bounding_box_format=bounding_box_format,
roi_matcher=self.box_matcher,
background_class=num_classes,
num_sampled_rois=512,
)
self.roi_pooler = roi_pooler
self.rcnn_head = rcnn_head
self._prediction_decoder = (
prediction_decoder
or cv_layers.MultiClassNonMaxSuppression(
or MultiClassNonMaxSuppression(
bounding_box_format=bounding_box_format,
from_logits=True,
max_detections_per_class=10,
max_detections=10,
from_logits=False,
max_detections_per_class=200,
max_detections=200,
confidence_threshold=0.3,
)
)

def compile(
self,
box_loss=None,
classification_loss=None,
rpn_box_loss=None,
rpn_classification_loss=None,
box_loss=None,
classification_loss=None,
weight_decay=0.0001,
loss=None,
metrics=None,
Expand All @@ -238,21 +243,22 @@ def compile(
"Instead, please pass `box_loss` and `classification_loss`. "
"`loss` will be ignored during training."
)
box_loss = _parse_box_loss(box_loss)
classification_loss = _parse_classification_loss(classification_loss)

rpn_box_loss = _parse_box_loss(rpn_box_loss)
rpn_classification_loss = _parse_rpn_classification_loss(
rpn_classification_loss
)

if hasattr(rpn_classification_loss, "from_logits"):
if not rpn_classification_loss.from_logits:
raise ValueError(
"FasterRCNN.compile() expects `from_logits` to be True for "
"`rpn_classification_loss`. Got "
"`rpn_classification_loss.from_logits="
f"{classification_loss.from_logits}`"
f"{rpn_classification_loss.from_logits}`"
)
box_loss = _parse_box_loss(box_loss)
classification_loss = _parse_classification_loss(classification_loss)

if hasattr(classification_loss, "from_logits"):
if not classification_loss.from_logits:
raise ValueError(
Expand All @@ -271,38 +277,41 @@ def compile(
"`box_loss.bounding_box_format="
f"{self.bounding_box_format}`"
)

self.rpn_box_loss = rpn_box_loss
self.rpn_cls_loss = rpn_classification_loss
self.box_loss = box_loss
self.cls_loss = classification_loss
self.weight_decay = weight_decay
losses = {
"box": self.box_loss,
"classification": self.cls_loss,
"rpn_box": self.rpn_box_loss,
"rpn_classification": self.rpn_cls_loss,
"box": self.box_loss,
"classification": self.cls_loss,
}
self._has_user_metrics = metrics is not None and len(metrics) != 0
self._user_metrics = metrics
super().compile(loss=losses, **kwargs)

def compute_loss(self, x, y, y_pred, sample_weight, **kwargs):
def compute_loss(
self, x, y, y_pred, sample_weight, training=True, **kwargs
):
# 1. Unpack the inputs
images = x
gt_boxes = y["boxes"]
if keras.ops.ndim(y["classes"]) != 2:
if ops.ndim(y["classes"]) != 2:
raise ValueError(
"Expected 'classes' to be a Tensor of rank 2. "
f"Got y['classes'].shape={keras.ops.shape(y['classes'])}."
f"Got y['classes'].shape={ops.shape(y['classes'])}."
)

gt_classes = y["classes"]
gt_classes = keras.ops.expand_dims(y["classes"], axis=-1)
gt_classes = ops.expand_dims(gt_classes, axis=-1)

# Generate anchors
# image shape must not contain the batch size
local_batch = keras.ops.shape(images)[0]
image_shape = keras.ops.shape(images)[1:]
local_batch = ops.shape(images)[0]
image_shape = ops.shape(images)[1:]
anchors = self.anchor_generator(image_shape=image_shape)

# 2. Label with the anchors -- exclusive to compute_loss
Expand All @@ -312,7 +321,7 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs):
rpn_cls_targets,
rpn_cls_weights,
) = self.rpn_labeler(
anchors_dict=keras.ops.concatenate(
anchors_dict=ops.concatenate(
tree.flatten(anchors),
axis=0,
),
Expand Down Expand Up @@ -359,53 +368,62 @@ def compute_loss(self, x, y, y_pred, sample_weight, **kwargs):
variance=BOX_VARIANCE,
)

rois, _ = self.roi_generator(decoded_rpn_boxes, rpn_scores)
rois, _ = self.roi_generator(
decoded_rpn_boxes, rpn_scores, training=training
)
rois = _clip_boxes(rois, "yxyx", image_shape)

# print(f"ROI's Generated from RPN Network: {rois}")

# 4. Stop gradient from flowing into the ROI
# -- exclusive to compute_loss
rois = keras.ops.stop_gradient(rois)

rois = ops.stop_gradient(rois)
# 5. Sample the ROIS -- exclusive to compute_loss
# -- exclusive to compute loss
(
rois,
box_targets,
box_weights,
cls_targets,
cls_weights,
) = self.roi_sampler(rois, gt_boxes, gt_classes)
cls_targets = ops.squeeze(
cls_targets, axis=-1
) # to apply one hot encoding

# to apply one hot encoding
cls_targets = ops.squeeze(cls_targets, axis=-1)
cls_weights = ops.squeeze(cls_weights, axis=-1)

# 6. Box and class weights -- exclusive to compute loss
box_weights /= self.roi_sampler.num_sampled_rois * local_batch * 0.25
cls_weights /= self.roi_sampler.num_sampled_rois * local_batch

# print(f"Box Targets Shape: {box_targets.shape}")
# print(f"Box Weights Shape: {box_weights.shape}")
# print(f"Cls Targets Shape: {cls_targets.shape}")
# print(f"Cls Weights Shape: {cls_weights.shape}")
# print(f"RPN Box Targets Shape: {rpn_box_targets.shape}")
# print(f"RPN Box Weights Shape: {rpn_box_weights.shape}")
# print(f"RPN Cls Targets Shape: {rpn_cls_targets.shape}")
# print(f"RPN Cls Weights Shape: {rpn_cls_weights.shape}")
# print(f"Cls Weights: {cls_weights}")
# print(f"Box Weights: {box_weights}")

# print(f"Cls Targets: {cls_targets}")
# print(f"Box Targets: {box_targets}")

#######################################################################
# Call RCNN
#######################################################################

feature_map = self.roi_pooler(features=feature_map, boxes=rois)

# [BS, H*W*K]
feature_map = keras.ops.reshape(
feature_map = ops.reshape(
feature_map,
newshape=keras.ops.shape(rois)[:2] + (-1,),
newshape=ops.shape(rois)[:2] + (-1,),
)

# [BS, H*W*K, 4], [BS, H*W*K, num_classes + 1]
box_pred, cls_pred = self.rcnn_head(feature_map=feature_map)

# Class targets will be in categorical so change it to one hot encoding
cls_targets = keras.ops.one_hot(
cls_targets,
self.num_classes + 1, # +1 for background class
dtype=cls_pred.dtype,
)

y_true = {
"rpn_box": rpn_box_targets,
"rpn_classification": rpn_cls_targets,
Expand Down Expand Up @@ -441,66 +459,6 @@ def test_step(self, *args):
x, y = unpack_input(data)
return super().test_step(*args, (x, y))

def predict_step(self, *args):
outputs = super().predict_step(*args)
if type(outputs) is tuple:
return self.decode_predictions(outputs[0], args[-1]), outputs[1]
else:
return self.decode_predictions(outputs, args[-1])

@property
def prediction_decoder(self):
return self._prediction_decoder

@prediction_decoder.setter
def prediction_decoder(self, prediction_decoder):
if prediction_decoder.bounding_box_format != self.bounding_box_format:
raise ValueError(
"Expected `prediction_decoder` and RetinaNet to "
"use the same `bounding_box_format`, but got "
"`prediction_decoder.bounding_box_format="
f"{prediction_decoder.bounding_box_format}`, and "
"`self.bounding_box_format="
f"{self.bounding_box_format}`."
)
self._prediction_decoder = prediction_decoder
self.make_predict_function(force=True)
self.make_train_function(force=True)
self.make_test_function(force=True)

def decode_predictions(self, predictions, images):
box_pred, cls_pred = predictions["box"], predictions["classification"]
# box_pred is on "center_yxhw" format, convert to target format.
image_shape = tuple(images[0].shape)
anchors = self.anchor_generator(image_shape=image_shape)
anchors = ops.concatenate([a for a in anchors.values()], axis=0)

box_pred = _decode_deltas_to_boxes(
anchors=anchors,
boxes_delta=box_pred,
anchor_format=self.anchor_generator.bounding_box_format,
box_format=self.bounding_box_format,
variance=BOX_VARIANCE,
image_shape=image_shape,
)
# box_pred is now in "self.bounding_box_format" format
box_pred = bounding_box.convert_format(
box_pred,
source=self.bounding_box_format,
target=self.prediction_decoder.bounding_box_format,
image_shape=image_shape,
)
y_pred = self.prediction_decoder(
box_pred, cls_pred, image_shape=image_shape
)
y_pred["boxes"] = bounding_box.convert_format(
y_pred["boxes"],
source=self.prediction_decoder.bounding_box_format,
target=self.bounding_box_format,
image_shape=image_shape,
)
return y_pred

@staticmethod
def default_anchor_generator(scales, aspect_ratios, bounding_box_format):
strides = {f"P{i}": 2**i for i in range(2, 7)}
Expand All @@ -511,7 +469,7 @@ def default_anchor_generator(scales, aspect_ratios, bounding_box_format):
"P5": 256.0,
"P6": 512.0,
}
return cv_layers.AnchorGenerator(
return AnchorGenerator(
bounding_box_format=bounding_box_format,
sizes=sizes,
aspect_ratios=aspect_ratios,
Expand Down Expand Up @@ -564,7 +522,7 @@ def _parse_classification_loss(loss):
if loss.lower() == "focal":
return losses.FocalLoss(reduction="sum", from_logits=True)
if loss.lower() == "categoricalcrossentropy":
return keras.losses.CategoricalCrossentropy(
return keras.losses.SparseCategoricalCrossentropy(
reduction="sum", from_logits=True
)

Expand Down
Loading