nearest neighbor - k-d tree - Wikipedia proof

灰色年华 2020-12-24 08:59

The Wikipedia entry for k-d trees presents an algorithm for nearest neighbor search in a k-d tree. What I don't understand is the explanation of step 3.2.

  •  野趣味 (original poster)
    2020-12-24 09:20

    Yes, the description of NN (Nearest Neighbour) search in a KD Tree on Wikipedia is a little hard to follow. It doesn't help that a lot of the top Google search results on NN KD Tree searches are just plain wrong!
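    The check in step 3.2 rests on a one-line bound. Write q for the query point, d for the splitting axis, s for the split value and r for the current best distance (minDist in the code below). Assuming Euclidean distance, any point p on the far side of the split (or exactly on it) satisfies:

```latex
\lVert q - p \rVert_2 \;\ge\; |q_d - p_d| \;\ge\; |q_d - s|
```

    The first inequality holds because dropping the other coordinates cannot increase a Euclidean distance; the second holds because p_d lies on the other side of s from q_d. So if |q_d - s| > r, nothing on the far side can beat the current best and that whole subtree can be skipped; if |q_d - s| <= r, it has to be searched.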

    Here's some C++ code to show you how to get it right:

    template <typename T, std::size_t N>
    void KDTree<T, N>::nearest (
        const std::unique_ptr<KDNode<T, N>> &node,
        const std::array<T, N> &point,              // looking for the closest point to this one
        std::shared_ptr<KDPoint<T, N>> &closest,    // closest point found so far (updated in place)
        double &minDist,                            // distance from `point` to `closest`
        const std::size_t depth) const
    {
        if (node->isLeaf()) {
            const double dist = distance(point, node->leaf->point);
            if (dist < minDist) {
                minDist = dist;
                closest = node->leaf;
            }
        } else {
            const std::size_t dim = depth % N;      // splitting axis at this depth (an index, not a T)
            if (point[dim] < node->splitVal) {
                // search the near (left) side first
                nearest(node->left, point, closest, minDist, depth + 1);
                // cross to the far side only if the ball of radius minDist reaches the split
                if (point[dim] + minDist >= node->splitVal)
                    nearest(node->right, point, closest, minDist, depth + 1);
            } else {
                // search the near (right) side first
                nearest(node->right, point, closest, minDist, depth + 1);
                if (point[dim] - minDist <= node->splitVal)
                    nearest(node->left, point, closest, minDist, depth + 1);
            }
        }
    }
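    The two `if` conditions that guard the second recursive call are the step 3.2 test, written once for each side. The small sketch below (the helper names `farSideMayBeCloser` and `answerCondition` are hypothetical, not part of the classes in this answer) shows that both branches reduce to the single condition |point[dim] - splitVal| <= minDist:

```cpp
#include <cmath>

// Hypothetical helper: the far side of an axis-aligned split can only hold a
// closer point if the ball of radius minDist around the query reaches the split.
bool farSideMayBeCloser(double queryCoord, double splitVal, double minDist) {
    return std::fabs(queryCoord - splitVal) <= minDist;
}

// The answer's two branch conditions, copied out for comparison.
bool answerCondition(double queryCoord, double splitVal, double minDist) {
    if (queryCoord < splitVal)
        return queryCoord + minDist >= splitVal;  // left was searched first
    return queryCoord - minDist <= splitVal;      // right was searched first
}
```

    Using `<=` rather than `<` means a far side that is exactly minDist away is still visited; a point found there at an equal distance cannot replace the current best anyway, because the leaf update uses `dist < minDist`.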
    

    API for NN searching on a KD Tree:

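    #include <limits>   // std::numeric_limits

    // Nearest neighbour (public entry point); assumes the tree is non-empty
    template <typename T, std::size_t N>
    const KDPoint<T, N> KDTree<T, N>::nearest (const std::array<T, N> &point) const {
        std::shared_ptr<KDPoint<T, N>> closest;   // filled in by the recursive search
        double minDist = std::numeric_limits<double>::max();
        nearest(root, point, closest, minDist);
        return *closest;                          // copy of the closest stored point
    }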
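    Since so many published versions of this search are subtly wrong, a cheap way to check any implementation, this one included, is to compare it against a linear scan over the same points. A minimal sketch of such an oracle, assuming you keep the points in a plain `std::vector` alongside the tree (the function name `bruteForceNearest` is hypothetical):

```cpp
#include <array>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical test oracle: linear scan over all points. Returns the index of
// the point closest to `query`, or points.size() if `points` is empty.
template <typename T, std::size_t N>
std::size_t bruteForceNearest(const std::vector<std::array<T, N>> &points,
                              const std::array<T, N> &query) {
    std::size_t bestIdx = points.size();
    double bestDist = std::numeric_limits<double>::max();
    for (std::size_t i = 0; i < points.size(); ++i) {
        double d = 0.0;
        for (std::size_t k = 0; k < N; ++k) {
            const double diff = static_cast<double>(points[i][k]) - static_cast<double>(query[k]);
            d += diff * diff;
        }
        d = std::sqrt(d);  // same Euclidean metric as distance()
        if (d < bestDist) {
            bestDist = d;
            bestIdx = i;
        }
    }
    return bestIdx;
}
```

    When comparing against the tree on random queries, compare distances rather than the points themselves, since the two searches may break ties between equidistant points differently.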
    

    Default distance function:

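    #include <array>
    #include <cmath>
    #include <cstddef>

    template <typename T, std::size_t N>
    double distance (const std::array<T, N> &p1, const std::array<T, N> &p2) {
        double d = 0.0;
        for (std::size_t i = 0; i < N; ++i) {
            const double diff = static_cast<double>(p1[i]) - static_cast<double>(p2[i]);
            d += diff * diff;   // cheaper than pow(diff, 2.0)
        }
        return std::sqrt(d);    // Euclidean distance
    }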
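    A common variant, not used in this answer, skips the square root: because sqrt is monotonic, comparing squared distances selects the same nearest point. The step 3.2 test then has to compare squared quantities as well. A sketch with hypothetical names `squaredDistance` and `farSideMayBeCloserSq`:

```cpp
#include <array>
#include <cstddef>

// Hypothetical variant of distance(): squared Euclidean distance, no sqrt.
template <typename T, std::size_t N>
double squaredDistance(const std::array<T, N> &p1, const std::array<T, N> &p2) {
    double d = 0.0;
    for (std::size_t i = 0; i < N; ++i) {
        const double diff = static_cast<double>(p1[i]) - static_cast<double>(p2[i]);
        d += diff * diff;
    }
    return d;
}

// Matching pruning test when minDistSq holds a squared distance:
// for r >= 0, |q - s| <= r  is equivalent to  (q - s)^2 <= r^2.
inline bool farSideMayBeCloserSq(double queryCoord, double splitVal, double minDistSq) {
    const double gap = queryCoord - splitVal;
    return gap * gap <= minDistSq;
}
```

    If you switch `nearest` to squared distances, minDist then holds a squared value, so the leaf comparison and both pruning conditions must change together.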
    

    Edit: some people have asked for help with the data structures too (not just the NN algorithm), so here is what I have used. Depending on your purpose, you might want to modify the data structures slightly, but you almost certainly do not want to modify the NN algorithm.

    KDPoint class:

    #include <array>
    #include <cstddef>
    #include <utility>

    template <typename T, std::size_t N>
    class KDPoint {
        public:
            KDPoint (std::array<T, N> &&t) : point(std::move(t)) { }
            virtual ~KDPoint () = default;
            std::array<T, N> point;
    };
    

    KDNode class:

    #include <memory>

    template <typename T, std::size_t N>
    class KDNode
    {
        public:
            KDNode () = delete;
            KDNode (const KDNode &) = delete;
            KDNode & operator = (const KDNode &) = delete;
            ~KDNode () = default;

            // branch node: takes ownership of both children (callers pass std::move(child))
            KDNode (const T                  split,
                    std::unique_ptr<KDNode>  lhs,
                    std::unique_ptr<KDNode>  rhs) : splitVal(split), left(std::move(lhs)), right(std::move(rhs)) { }
            // leaf node: points are stored only in leaves
            KDNode (std::shared_ptr<KDPoint<T, N>> p) : splitVal(0), leaf(std::move(p)) { }

            bool isLeaf () const { return static_cast<bool>(leaf); }

            // data members
            const T                                   splitVal;
            const std::unique_ptr<KDNode<T, N>>       left, right;
            const std::shared_ptr<KDPoint<T, N>>      leaf;
    };
    

    KDTree class: (Note: you'll need to add a member function to build/fill your tree.)

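    template <typename T, std::size_t N>
    class KDTree {
        public:
            KDTree () = delete;
            KDTree (const KDTree &) = delete;
            KDTree (KDTree &&t) noexcept : root(std::move(t.root)) { }
            KDTree & operator = (const KDTree &) = delete;
            ~KDTree () = default;

            const KDPoint<T, N> nearest (const std::array<T, N> &point) const;

            // Nearest neighbour search: roughly O(log n) on average for well-spread
            // points in low dimensions, but O(n) in the worst case
            void nearest (const std::unique_ptr<KDNode<T, N>> &node,
                          const std::array<T, N> &point,
                          std::shared_ptr<KDPoint<T, N>> &closest,
                          double &minDist,
                          const std::size_t depth = 0) const;

            // data members (non-const so the move constructor can move from it;
            // moving out of a const member via const_cast is undefined behaviour)
            std::unique_ptr<KDNode<T, N>> root;
    };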
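    The classes above leave out construction. Below is a self-contained sketch, not the answer's code, using simplified hypothetical stand-ins (`Point`, `Node`, `euclidean`, `build`, `nearest`) with the same layout: points only in leaves, a split value in every branch node, splitting axis `depth % N`. Splitting at the median with `std::nth_element` is one common choice, assumed here; it guarantees left-side coordinates are <= splitVal and right-side ones are >= splitVal, which is what the pruning tests rely on. The search is the same traversal as `KDTree::nearest` above.

```cpp
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <limits>
#include <memory>
#include <vector>

// Hypothetical simplified stand-ins for KDPoint/KDNode.
template <std::size_t N>
using Point = std::array<double, N>;

template <std::size_t N>
struct Node {
    double splitVal = 0.0;
    std::unique_ptr<Node> left, right;   // set only for branch nodes
    std::unique_ptr<Point<N>> leaf;      // set only for leaf nodes
};

template <std::size_t N>
double euclidean(const Point<N> &a, const Point<N> &b) {
    double d = 0.0;
    for (std::size_t i = 0; i < N; ++i) d += (a[i] - b[i]) * (a[i] - b[i]);
    return std::sqrt(d);
}

// Hypothetical builder: median split on axis depth % N.
template <std::size_t N>
std::unique_ptr<Node<N>> build(std::vector<Point<N>> pts, std::size_t depth = 0) {
    if (pts.empty()) return nullptr;
    auto node = std::make_unique<Node<N>>();
    if (pts.size() == 1) {
        node->leaf = std::make_unique<Point<N>>(pts.front());
        return node;
    }
    const std::size_t dim = depth % N;
    const std::size_t mid = pts.size() / 2;  // both halves non-empty when size >= 2
    std::nth_element(pts.begin(), pts.begin() + mid, pts.end(),
                     [dim](const Point<N> &a, const Point<N> &b) { return a[dim] < b[dim]; });
    node->splitVal = pts[mid][dim];
    node->left = build<N>(std::vector<Point<N>>(pts.begin(), pts.begin() + mid), depth + 1);
    node->right = build<N>(std::vector<Point<N>>(pts.begin() + mid, pts.end()), depth + 1);
    return node;
}

// Same traversal and pruning as KDTree::nearest().
template <std::size_t N>
void nearest(const Node<N> *node, const Point<N> &q, const Point<N> *&best,
             double &minDist, std::size_t depth = 0) {
    if (node == nullptr) return;
    if (node->leaf) {
        const double d = euclidean(q, *node->leaf);
        if (d < minDist) { minDist = d; best = node->leaf.get(); }
        return;
    }
    const std::size_t dim = depth % N;
    const bool leftFirst = q[dim] < node->splitVal;
    nearest(leftFirst ? node->left.get() : node->right.get(), q, best, minDist, depth + 1);
    // Step 3.2: visit the other side only if the ball of radius minDist reaches the split.
    if (std::fabs(q[dim] - node->splitVal) <= minDist)
        nearest(leftFirst ? node->right.get() : node->left.get(), q, best, minDist, depth + 1);
}
```

    Calling `nearest` with `best = nullptr` and `minDist = std::numeric_limits<double>::max()` mirrors the public `KDTree::nearest` entry point.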
    
