Skip to content

Commit

Permalink
add detection_threshold parameter for fixed keypoint number detector
Browse files Browse the repository at this point in the history
  • Loading branch information
felipecadar committed Jun 22, 2024
1 parent 5580d38 commit 65ae30b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
4 changes: 2 additions & 2 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from modules.xfeat import XFeat as _XFeat
import torch

def XFeat(pretrained=True, top_k=4096):
def XFeat(pretrained=True, top_k=4096, detection_threshold=0.05):
"""
XFeat model
pretrained (bool): kwargs, load pretrained weights into the model
Expand All @@ -11,5 +11,5 @@ def XFeat(pretrained=True, top_k=4096):
if pretrained:
weights = torch.hub.load_state_dict_from_url("https://github.com/verlab/accelerated_features/raw/main/weights/xfeat.pt")

model = _XFeat(weights, top_k=top_k)
model = _XFeat(weights, top_k=top_k, detection_threshold=detection_threshold)
return model
8 changes: 5 additions & 3 deletions modules/xfeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@ class XFeat(nn.Module):
It supports inference for both sparse and semi-dense feature extraction & matching.
"""

def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', top_k = 4096):
def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../weights/xfeat.pt', top_k = 4096, detection_threshold=0.05):
super().__init__()
self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.net = XFeatModel().to(self.dev).eval()
self.top_k = top_k
self.detection_threshold = detection_threshold

if weights is not None:
if isinstance(weights, str):
Expand All @@ -36,7 +37,7 @@ def __init__(self, weights = os.path.abspath(os.path.dirname(__file__)) + '/../w
self.interpolator = InterpolateSparse2d('bicubic')

@torch.inference_mode()
def detectAndCompute(self, x, top_k = None):
def detectAndCompute(self, x, top_k = None, detection_threshold = None):
"""
Compute sparse keypoints & descriptors. Supports batched mode.
Expand All @@ -50,6 +51,7 @@ def detectAndCompute(self, x, top_k = None):
'descriptors' -> torch.Tensor(N, 64): local features
"""
if top_k is None: top_k = self.top_k
if detection_threshold is None: detection_threshold = self.detection_threshold
x, rh1, rw1 = self.preprocess_tensor(x)

B, _, _H1, _W1 = x.shape
Expand All @@ -59,7 +61,7 @@ def detectAndCompute(self, x, top_k = None):

#Convert logits to heatmap and extract kpts
K1h = self.get_kpts_heatmap(K1)
mkpts = self.NMS(K1h, threshold=0.05, kernel_size=5)
mkpts = self.NMS(K1h, threshold=detection_threshold, kernel_size=5)

#Compute reliability scores
_nearest = InterpolateSparse2d('nearest')
Expand Down

0 comments on commit 65ae30b

Please sign in to comment.