On the Wikipedia entry for k-d trees, an algorithm is presented for doing a nearest neighbor search on a k-d tree. What I don't understand is the explanation of step 3.2.
Yes, the description of NN (Nearest Neighbour) search in a KD Tree on Wikipedia is a little hard to follow. It doesn't help that a lot of the top Google search results on NN KD Tree searches are just plain wrong!
Here's some C++ code to show you how to get it right:
// Headers used by the code in this answer:
#include <array>
#include <cmath>
#include <cstddef>
#include <limits>
#include <memory>

template <typename T, std::size_t N>
void KDTree<T, N>::nearest (
    const std::unique_ptr<KDNode<T, N>> &node,
    const std::array<T, N> &point,            // looking for closest node to this point
    std::shared_ptr<KDPoint<T, N>> &closest,  // closest node (so far)
    double &minDist,
    const std::size_t depth) const
{
    if (node->isLeaf()) {
        const double dist = distance(point, node->leaf->point);
        if (dist < minDist) {
            minDist = dist;
            closest = node->leaf;
        }
    } else {
        const std::size_t dim = depth % N;
        if (point[dim] < node->splitVal) {
            // search left first
            nearest(node->left, point, closest, minDist, depth + 1);
            if (point[dim] + minDist >= node->splitVal)
                nearest(node->right, point, closest, minDist, depth + 1);
        } else {
            // search right first
            nearest(node->right, point, closest, minDist, depth + 1);
            if (point[dim] - minDist <= node->splitVal)
                nearest(node->left, point, closest, minDist, depth + 1);
        }
    }
}
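The part that corresponds to step 3.2 in the Wikipedia description is the pair of if tests above: after descending into the subtree on the query point's side of the splitting plane, the other subtree only has to be searched if the hypersphere of radius minDist centred on the query point crosses that plane. Both tests are the same condition written for the two descent orders; isolated as a standalone helper (the function name here is just for illustration, it is not used by the code above), it reads:
// Step 3.2: the far subtree can only contain a closer point if the splitting
// plane is within minDist of the query point along the current axis.
template <typename T, std::size_t N>
bool mustSearchOtherSide (const std::array<T, N> &point,
                          const std::size_t dim,
                          const T splitVal,
                          const double minDist)
{
    return std::abs(static_cast<double>(point[dim] - splitVal)) <= minDist;
}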
API for NN searching on a KD Tree:
// Nearest neighbour
template <typename T, std::size_t N>
const KDPoint<T, N> KDTree<T, N>::nearest (const std::array<T, N> &point) const {
    std::shared_ptr<KDPoint<T, N>> closest;
    double minDist = std::numeric_limits<double>::max();
    nearest(root, point, closest, minDist);
    return *closest;
}
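For example, with a tree of 3-D double points that has already been built (see the builder sketch at the end of this answer), a query is simply:
// `tree` is assumed to be a KDTree<double, 3> built elsewhere
const std::array<double, 3> query = {1.0, 2.0, 3.0};
const KDPoint<double, 3> nn = tree.nearest(query);   // nn.point holds the closest stored point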
Default distance function:
template <typename T, std::size_t N>
double distance (const std::array<T, N> &p1, const std::array<T, N> &p2) {
    double d = 0.0;
    for (std::size_t i = 0; i < N; ++i) {
        d += std::pow(p1[i] - p2[i], 2.0);
    }
    return std::sqrt(d);
}
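A note on this design choice: because sqrt is monotonic, you can also compare squared distances and skip the sqrt entirely. That is purely an optional variant (the code above does not need it), and if you use it, minDist then holds a squared distance, so the step 3.2 test has to be squared as well:
// Optional variant: squared Euclidean distance (no sqrt). With this, minDist
// stores a squared distance and the step 3.2 check becomes, e.g.:
//   (point[dim] - node->splitVal) * (point[dim] - node->splitVal) <= minDist
template <typename T, std::size_t N>
double distanceSquared (const std::array<T, N> &p1, const std::array<T, N> &p2) {
    double d = 0.0;
    for (std::size_t i = 0; i < N; ++i) {
        const double diff = static_cast<double>(p1[i]) - static_cast<double>(p2[i]);
        d += diff * diff;
    }
    return d;
}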
Edit: some people are asking for help with the data structures too (not just the NN algorithm), so here is what I have used. Depending on your purpose, you might wish to modify the data structures slightly. (Note: you almost certainly do not want to modify the NN algorithm itself.)
KDPoint class:
template <typename T, std::size_t N>
class KDPoint {
public:
    KDPoint (std::array<T, N> &&t) : point(std::move(t)) { };
    virtual ~KDPoint () = default;

    std::array<T, N> point;
};
KDNode class:
template <typename T, std::size_t N>
class KDNode
{
public:
    KDNode () = delete;
    KDNode (const KDNode &) = delete;
    KDNode & operator = (const KDNode &) = delete;
    ~KDNode () = default;

    // branch node
    KDNode (const T split,
            std::unique_ptr<KDNode<T, N>> &lhs,
            std::unique_ptr<KDNode<T, N>> &rhs) : splitVal(split), left(std::move(lhs)), right(std::move(rhs)) { };
    // leaf node
    KDNode (std::shared_ptr<KDPoint<T, N>> p) : splitVal(0), leaf(p) { };

    bool isLeaf (void) const { return static_cast<bool>(leaf); }

    // data members
    const T splitVal;
    const std::unique_ptr<KDNode<T, N>> left, right;
    const std::shared_ptr<KDPoint<T, N>> leaf;
};
KDTree class: (Note: you'll need to add a member function or constructor to build/fill your tree; a rough sketch of one way to do that follows the class.)
template <typename T, std::size_t N>
class KDTree {
public:
    KDTree () = delete;
    KDTree (const KDTree &) = delete;
    KDTree (KDTree &&t) : root(std::move(const_cast<std::unique_ptr<KDNode<T, N>>&>(t.root))) { };
    KDTree & operator = (const KDTree &) = delete;
    ~KDTree () = default;

    const KDPoint<T, N> nearest (const std::array<T, N> &point) const;

    // Nearest neighbour search - O(log n) on average
    void nearest (const std::unique_ptr<KDNode<T, N>> &node,
                  const std::array<T, N> &point,
                  std::shared_ptr<KDPoint<T, N>> &closest,
                  double &minDist,
                  const std::size_t depth = 0) const;

    // data members
    const std::unique_ptr<KDNode<T, N>> root;
};
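To go with the note above about building the tree: the classes as shown have no way to create a populated KDTree, so here is a rough sketch of one possible builder (median split on the axis that cycles with depth). This is only a sketch, not part of the code above; buildKD is just an illustrative name, and it assumes you add a constructor along the lines of KDTree (std::unique_ptr<KDNode<T, N>> r) : root(std::move(r)) { } so the root can be handed over.
#include <algorithm>   // in addition to the headers listed at the top of the answer
#include <vector>

// Sketch only: sort on the current axis and split at the median. Points whose
// coordinate is < splitVal go left and the rest go right, which matches the
// descent rule used by nearest().
template <typename T, std::size_t N>
std::unique_ptr<KDNode<T, N>> buildKD (
    std::vector<std::shared_ptr<KDPoint<T, N>>> points,
    const std::size_t depth = 0)
{
    if (points.empty())
        return nullptr;
    if (points.size() == 1)
        return std::make_unique<KDNode<T, N>>(points.front());   // leaf

    const std::size_t dim = depth % N;
    std::sort(points.begin(), points.end(),
              [dim](const auto &a, const auto &b) { return a->point[dim] < b->point[dim]; });

    const std::size_t mid = points.size() / 2;
    const T splitVal = points[mid]->point[dim];

    std::vector<std::shared_ptr<KDPoint<T, N>>> lhsPts(points.begin(), points.begin() + mid);
    std::vector<std::shared_ptr<KDPoint<T, N>>> rhsPts(points.begin() + mid, points.end());

    auto lhs = buildKD<T, N>(std::move(lhsPts), depth + 1);
    auto rhs = buildKD<T, N>(std::move(rhsPts), depth + 1);
    return std::make_unique<KDNode<T, N>>(splitVal, lhs, rhs);   // branch
}

// Hypothetical end-to-end usage (again, assumes the extra constructor):
int main () {
    std::vector<std::shared_ptr<KDPoint<double, 2>>> pts;
    pts.push_back(std::make_shared<KDPoint<double, 2>>(std::array<double, 2>{2.0, 3.0}));
    pts.push_back(std::make_shared<KDPoint<double, 2>>(std::array<double, 2>{5.0, 4.0}));
    pts.push_back(std::make_shared<KDPoint<double, 2>>(std::array<double, 2>{9.0, 6.0}));

    KDTree<double, 2> tree(buildKD<double, 2>(std::move(pts)));
    const KDPoint<double, 2> nn = tree.nearest({8.0, 5.0});   // nn.point is {9.0, 6.0}
    return 0;
}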