Skip to content

Commit

Permalink
Simplify metrics calculation (#9338)
Browse files Browse the repository at this point in the history
Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>
  • Loading branch information
Kayzwer and glenn-jocher authored Mar 27, 2024
1 parent 1325889 commit 978a3ca
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions ultralytics/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def kpt_iou(kpt1, kpt2, area, sigma, eps=1e-7):
d = (kpt1[:, None, :, 0] - kpt2[..., 0]).pow(2) + (kpt1[:, None, :, 1] - kpt2[..., 1]).pow(2) # (N, M, 17)
sigma = torch.tensor(sigma, device=kpt1.device, dtype=kpt1.dtype) # (17, )
kpt_mask = kpt1[..., 2] != 0 # (N, 17)
e = d / (2 * sigma).pow(2) / (area[:, None, None] + eps) / 2 # from cocoeval
e = d / ((2 * sigma).pow(2) * (area[:, None, None] + eps) * 2) # from cocoeval
# e = d / ((area[None, :, None] + eps) * sigma) ** 2 / 2 # from formula
return ((-e).exp() * kpt_mask[:, None]).sum(-1) / (kpt_mask.sum(-1)[:, None] + eps)

Expand Down Expand Up @@ -402,7 +402,7 @@ def plot(self, normalize=True, save_dir="", names=(), on_plot=None):

fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
nc, nn = self.nc, len(names) # number of classes, names
sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
sn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
ticklabels = (list(names) + ["background"]) if labels else "auto"
with warnings.catch_warnings():
Expand Down

0 comments on commit 978a3ca

Please sign in to comment.