Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactored set_cell_anchors() in AnchorGenerator #3755

Merged
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Refactored set_cell_anchors() in AnchorGenerator
  • Loading branch information
prabhat00155 committed Apr 29, 2021
commit b662d97d9a5819914efdbd7d4afee30c5ec337c7
27 changes: 6 additions & 21 deletions torchvision/models/detection/anchor_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def __init__(

self.sizes = sizes
self.aspect_ratios = aspect_ratios
self.cell_anchors = None
self.cell_anchors = [self.generate_anchors(size, aspect_ratio)
for size, aspect_ratio in zip(sizes, aspect_ratios)]

# TODO: https://github.com/pytorch/pytorch/issues/26792
# For every (aspect_ratios, scales) combination, output a zero-centered anchor with those values.
Expand All @@ -67,24 +68,8 @@ def generate_anchors(self, scales: List[int], aspect_ratios: List[float], dtype:
return base_anchors.round()

def set_cell_anchors(self, dtype: torch.dtype, device: torch.device):
if self.cell_anchors is not None:
cell_anchors = self.cell_anchors
assert cell_anchors is not None
# suppose that all anchors have the same device
# which is a valid assumption in the current state of the codebase
if cell_anchors[0].device == device:
return

cell_anchors = [
self.generate_anchors(
sizes,
aspect_ratios,
dtype,
device
)
for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
]
self.cell_anchors = cell_anchors
self.cell_anchors = [torch.as_tensor(cell_anchor, dtype=dtype, device=device)
prabhat00155 marked this conversation as resolved.
Show resolved Hide resolved
for cell_anchor in self.cell_anchors]

def num_anchors_per_location(self):
return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]
Expand Down Expand Up @@ -130,15 +115,15 @@ def grid_anchors(self, grid_sizes: List[List[int]], strides: List[List[Tensor]])
return anchors

def forward(self, image_list: ImageList, feature_maps: List[Tensor]) -> List[Tensor]:
grid_sizes = list([feature_map.shape[-2:] for feature_map in feature_maps])
grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
image_size = image_list.tensors.shape[-2:]
dtype, device = feature_maps[0].dtype, feature_maps[0].device
strides = [[torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device)] for g in grid_sizes]
self.set_cell_anchors(dtype, device)
anchors_over_all_feature_maps = self.grid_anchors(grid_sizes, strides)
anchors: List[List[torch.Tensor]] = []
for i in range(len(image_list.image_sizes)):
for _ in range(len(image_list.image_sizes)):
anchors_in_image = [anchors_per_feature_map for anchors_per_feature_map in anchors_over_all_feature_maps]
anchors.append(anchors_in_image)
anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
Expand Down