Skip to content

Commit

Permalink
parallel batch nearest neighbor search (koide3#68)
Browse files Browse the repository at this point in the history
  • Loading branch information
koide3 authored Jun 20, 2024
1 parent ac6c79a commit 7e42a90
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 4 deletions.
31 changes: 27 additions & 4 deletions src/python/kdtree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ void define_kdtree(py::module& m) {
k_sq_dist : float
The squared distance to the nearest neighbor.
)""")

.def(
"knn_search",
[](const KdTree<PointCloud>& kdtree, const Eigen::Vector3d& pt, int k) {
Expand Down Expand Up @@ -85,28 +86,39 @@ void define_kdtree(py::module& m) {
k_sq_dists : NDArray, shape (k,)
The squared distances to the k nearest neighbors.
)""")

.def(
"batch_nearest_neighbor_search",
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts) {
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts, int num_threads) {
if (pts.cols() != 3 && pts.cols() != 4) {
throw std::invalid_argument("pts must have shape (n, 3) or (n, 4)");
}

std::vector<size_t> k_indices(pts.rows(), -1);
std::vector<double> k_sq_dists(pts.rows(), std::numeric_limits<double>::max());

#pragma omp parallel for num_threads(num_threads)
for (int i = 0; i < pts.rows(); ++i) {
const size_t found = traits::nearest_neighbor_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), &k_indices[i], &k_sq_dists[i]);
if (!found) {
k_indices[i] = -1;
k_sq_dists[i] = std::numeric_limits<double>::max();
}
}

return std::make_pair(k_indices, k_sq_dists);
},
py::arg("pts"),
py::arg("num_threads") = 1,
R"""(
Find the nearest neighbors for a batch of points.
Parameters
----------
pts : NDArray, shape (n, 3)
pts : NDArray, shape (n, 3) or (n, 4)
The input points.
num_threads : int, optional
The number of threads to use for the search. Default is 1.
Returns
-------
Expand All @@ -115,11 +127,18 @@ void define_kdtree(py::module& m) {
k_sq_dists : NDArray, shape (n,)
The squared distances to the nearest neighbors for each input point.
)""")

.def(
"batch_knn_search",
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts, int k) {
[](const KdTree<PointCloud>& kdtree, const Eigen::MatrixXd& pts, int k, int num_threads) {
if (pts.cols() != 3 && pts.cols() != 4) {
throw std::invalid_argument("pts must have shape (n, 3) or (n, 4)");
}

std::vector<std::vector<size_t>> k_indices(pts.rows(), std::vector<size_t>(k, -1));
std::vector<std::vector<double>> k_sq_dists(pts.rows(), std::vector<double>(k, std::numeric_limits<double>::max()));

#pragma omp parallel for num_threads(num_threads)
for (int i = 0; i < pts.rows(); ++i) {
const size_t found = traits::knn_search(kdtree, Eigen::Vector4d(pts(i, 0), pts(i, 1), pts(i, 2), 1.0), k, k_indices[i].data(), k_sq_dists[i].data());
if (found < k) {
Expand All @@ -129,19 +148,23 @@ void define_kdtree(py::module& m) {
}
}
}

return std::make_pair(k_indices, k_sq_dists);
},
py::arg("pts"),
py::arg("k"),
py::arg("num_threads") = 1,
R"""(
Find the k nearest neighbors for a batch of points.
Parameters
----------
pts : NDArray, shape (n, 3)
pts : NDArray, shape (n, 3) or (n, 4)
The input points.
k : int
The number of nearest neighbors to search for.
num_threads : int, optional
The number of threads to use for the search. Default is 1.
Returns
-------
Expand Down
67 changes: 67 additions & 0 deletions src/test/python_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright 2024 Kenji Koide
# SPDX-License-Identifier: MIT
import numpy
from scipy.spatial import KDTree
from scipy.spatial.transform import Rotation

import small_gicp
Expand Down Expand Up @@ -188,3 +189,69 @@ def test_registration(load_points):

result = small_gicp.align(target_voxelmap, source)
verify_result(result.T_target_source, gt_T_target_source)

# KdTree test
def test_kdtree(load_points):
_, target_raw_numpy, source_raw_numpy = load_points

target, target_tree = small_gicp.preprocess_points(target_raw_numpy, downsampling_resolution=0.5)
source, source_tree = small_gicp.preprocess_points(source_raw_numpy, downsampling_resolution=0.5)

target_tree_ref = KDTree(target.points())
source_tree_ref = KDTree(source.points())

def batch_test(points, queries, tree, tree_ref, num_threads):
# test for batch interface
k_dists_ref, k_indices_ref = tree_ref.query(queries, k=1)
k_indices, k_sq_dists = tree.batch_nearest_neighbor_search(queries)
assert numpy.all(numpy.abs(numpy.square(k_dists_ref) - k_sq_dists) < 1e-6)
assert numpy.all(numpy.abs(numpy.linalg.norm(points[k_indices] - queries, axis=1) ** 2 - k_sq_dists) < 1e-6)

for k in [2, 10]:
k_dists_ref, k_indices_ref = tree_ref.query(queries, k=k)
k_sq_dists_ref, k_indices_ref = numpy.array(k_dists_ref) ** 2, numpy.array(k_indices_ref)

k_indices, k_sq_dists = tree.batch_knn_search(queries, k, num_threads=num_threads)
k_indices, k_sq_dists = numpy.array(k_indices), numpy.array(k_sq_dists)

assert(numpy.all(numpy.abs(k_sq_dists_ref - k_sq_dists) < 1e-6))
for i in range(k):
diff = numpy.linalg.norm(points[k_indices[:, i]] - queries, axis=1) ** 2 - k_sq_dists[:, i]
assert(numpy.all(numpy.abs(diff) < 1e-6))

# test for single query interface
if num_threads != 1:
return

k_dists_ref, k_indices_ref = tree_ref.query(queries, k=1)
k_indices2, k_sq_dists2 = [], []
for query in queries:
found, index, sq_dist = tree.nearest_neighbor_search(query[:3])
assert found
k_indices2.append(index)
k_sq_dists2.append(sq_dist)

assert numpy.all(numpy.abs(numpy.square(k_dists_ref) - k_sq_dists2) < 1e-6)
assert numpy.all(numpy.abs(numpy.linalg.norm(points[k_indices2] - queries, axis=1) ** 2 - k_sq_dists2) < 1e-6)

for k in [2, 10]:
k_dists_ref, k_indices_ref = tree_ref.query(queries, k=k)
k_sq_dists_ref, k_indices_ref = numpy.array(k_dists_ref) ** 2, numpy.array(k_indices_ref)

k_indices2, k_sq_dists2 = [], []
for query in queries:
indices, sq_dists = tree.knn_search(query[:3], k)
k_indices2.append(indices)
k_sq_dists2.append(sq_dists)
k_indices2, k_sq_dists2 = numpy.array(k_indices2), numpy.array(k_sq_dists2)

assert(numpy.all(numpy.abs(k_sq_dists_ref - k_sq_dists2) < 1e-6))
for i in range(k):
diff = numpy.linalg.norm(points[k_indices2[:, i]] - queries, axis=1) ** 2 - k_sq_dists2[:, i]
assert(numpy.all(numpy.abs(diff) < 1e-6))


for num_threads in [1, 2]:
batch_test(target.points(), target.points(), target_tree, target_tree_ref, num_threads=num_threads)
batch_test(target.points(), source.points(), target_tree, target_tree_ref, num_threads=num_threads)
batch_test(source.points(), target.points(), source_tree, source_tree_ref, num_threads=num_threads)

0 comments on commit 7e42a90

Please sign in to comment.