Implementing a Map where keys are sets of non-overlapping ranges

广开言路 2020-12-21 14:25

I am facing a performance issue with my current implementation, which uses a List and loops. I was thinking of writing a custom Map, but is it possible to o

2 Answers
  • 2020-12-21 15:06

    There is a structure called an Interval Tree that may fit your needs. Here's an implementation of it.

    It lets you attach objects to intervals rather than to single key values.

    Note that this implementation does not implement the sorted indexes suggested by the original algorithm as the use case I needed it for did not require that level of speed.

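    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Collections;
    import java.util.List;
    import java.util.Random;

    /**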
     * @author OldCurmudgeon
     * @param <T> - The type stored in the tree. Must implement IntervalTree.Interval but beyond that you can do what you like. Probably store that value in there too.
     */
    public class IntervalTree<T extends IntervalTree.Interval> {
    
        // My intervals.
        private final List<T> intervals;
        // My center value. All my intervals contain this center.
        private final long center;
        // My interval range.
        private final long lBound;
        private final long uBound;
        // My left tree. All intervals that end below my center.
        private final IntervalTree<T> left;
        // My right tree. All intervals that start above my center.
        private final IntervalTree<T> right;
    
        public IntervalTree(List<T> intervals) {
            if (intervals == null) {
                throw new NullPointerException();
            }
    
            // Initially, my root contains all intervals. The caller's list is kept
            // and pruned in place below, so it must be mutable.
            this.intervals = intervals;
    
            // Find my center.
            center = findCenter();
    
            /*
             * Builds lefts out of all intervals that end below my center.
             * Builds rights out of all intervals that start above my center.
             * What remains contains all the intervals that contain my center.
             */
            // Lefts contains all intervals that end below my center point.
            final List<T> lefts = new ArrayList<>();
            // Rights contains all intervals that start above my center point.
            final List<T> rights = new ArrayList<>();
    
            // Track my bounds while distributing.
            long uB = Long.MIN_VALUE;
            long lB = Long.MAX_VALUE;
            for (T i : intervals) {
                long start = i.getStart();
                long end = i.getEnd();
                if (end < center) {
                    // It ends below me - move it to my left.
                    lefts.add(i);
                } else if (start > center) {
                    // It starts above me - move it to my right.
                    rights.add(i);
                } else {
                    // One of mine.
                    lB = Math.min(lB, start);
                    uB = Math.max(uB, end);
                }
            }
    
            // Remove all those not mine.
            intervals.removeAll(lefts);
            intervals.removeAll(rights);
            // Record my bounds.
            uBound = uB;
            lBound = lB;
    
            // Build the subtrees.
            left = lefts.size() > 0 ? new IntervalTree<>(lefts) : null;
            right = rights.size() > 0 ? new IntervalTree<>(rights) : null;
    
            // TODO: Build my ascending and descending arrays (the sorted indexes).
        }
    
        /*
         * Returns a list of all intervals containing the point.
         */
        List<T> query(long point) {
            // Check my range.
            if (point >= lBound) {
                if (point <= uBound) {
                    // In my range but remember, there may also be contributors from left or right.
                    List<T> found = new ArrayList<>();
                    // Gather all intersecting ones.
                    // Could be made faster (perhaps) by holding two sorted lists by start and end.
                    for (T i : intervals) {
                        if (i.getStart() <= point && point <= i.getEnd()) {
                            found.add(i);
                        }
                    }
    
                    // Gather others.
                    if (point < center && left != null) {
                        found.addAll(left.query(point));
                    }
                    if (point > center && right != null) {
                        found.addAll(right.query(point));
                    }
    
                    return found;
                } else {
                    // To right.
                    return right != null ? right.query(point) : Collections.<T>emptyList();
                }
            } else {
                // To left.
                return left != null ? left.query(point) : Collections.<T>emptyList();
            }
    
        }
    
        private long findCenter() {
            //return average();
            return median();
        }
    
        protected long median() {
            // Choose the median of all centers. Could choose just ends etc or anything.
            long[] points = new long[intervals.size()];
            int x = 0;
            for (T i : intervals) {
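                // Take the mid point without summing start and end, which can
                // overflow for large values and produce a center outside the interval.
                points[x++] = i.getStart() + (i.getEnd() - i.getStart()) / 2;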
            }
            Arrays.sort(points);
            return points[points.length / 2];
        }
    
        /*
         * What an interval looks like.
         */
        public interface Interval {
    
            public long getStart();
    
            public long getEnd();
    
        }
    
        /*
         * A simple implementation of an interval.
         */
        public static class SimpleInterval implements Interval {
    
            private final long start;
            private final long end;
    
            public SimpleInterval(long start, long end) {
                this.start = start;
                this.end = end;
            }
    
            @Override
            public long getStart() {
                return start;
            }
    
            @Override
            public long getEnd() {
                return end;
            }
    
            @Override
            public String toString() {
                return "{" + start + "," + end + "}";
            }
    
        }
    
        public static void main(String[] args) {
            // Make some test data.
            final int testEntries = 1 * 100;
            ArrayList<SimpleInterval> intervals = new ArrayList<>();
            Random random = new Random();
            for (int i = 0; i < testEntries; i++) {
                // Make a random interval.
                long start = random.nextLong();
                intervals.add(new SimpleInterval(start, start + 1000));
            }
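            // Time the build with System.nanoTime(); ProcessTimer is not defined in this post.
            long startNanos = System.nanoTime();
            IntervalTree<SimpleInterval> tree = new IntervalTree<>(intervals);
            System.out.println("Took " + (System.nanoTime() - startNanos) / 1_000_000 + " ms");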
        }
    
    }
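
    The sorted indexes that this implementation leaves out could look roughly like the sketch below. It is only an illustration under stated assumptions: SortedNode and Span are hypothetical names that are not part of the class above, and the sketch covers a single node's own intervals (all of which contain the node's center), not the whole tree. Keeping one copy sorted by start and one sorted by end lets a query stop at the first interval that misses the point instead of scanning every interval in the node.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

/** Hypothetical helper: one tree node's own intervals, all of which contain center. */
final class SortedNode {
    /** Minimal closed interval [start, end]; stands in for IntervalTree.Interval. */
    static final class Span {
        final long start;
        final long end;

        Span(long start, long end) {
            this.start = start;
            this.end = end;
        }

        @Override
        public String toString() {
            return "{" + start + "," + end + "}";
        }
    }

    private final long center;
    private final List<Span> byStartAsc; // sorted by ascending start
    private final List<Span> byEndDesc;  // sorted by descending end

    /** Assumes every interval in own contains center. */
    SortedNode(long center, List<Span> own) {
        this.center = center;
        byStartAsc = new ArrayList<>(own);
        byStartAsc.sort(Comparator.comparingLong((Span s) -> s.start));
        byEndDesc = new ArrayList<>(own);
        byEndDesc.sort(Comparator.comparingLong((Span s) -> s.end).reversed());
    }

    /** Returns this node's intervals that contain point, stopping at the first miss. */
    List<Span> query(long point) {
        List<Span> found = new ArrayList<>();
        if (point <= center) {
            // Every interval ends at or after center, so only the start can exclude it.
            for (Span s : byStartAsc) {
                if (s.start > point) {
                    break;
                }
                found.add(s);
            }
        } else {
            // Every interval starts at or before center, so only the end can exclude it.
            for (Span s : byEndDesc) {
                if (s.end < point) {
                    break;
                }
                found.add(s);
            }
        }
        return found;
    }
}
```

    With this layout, a node's share of a query costs time proportional to the number of matches it reports (plus one), rather than to the number of intervals it holds.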
    
  • 2020-12-21 15:24

    UPDATE: Added full implementation

    UPDATE 2: If you want, you can use a RangeMap for the internal theMap, as suggested in the comments.

    If your key ranges don't overlap, you can create a custom container that stores its data internally in a TreeMap, using a custom key that implements Comparable:

    import java.util.Map;
    import java.util.TreeMap;

    class MyStorage<T> {
        private static final class Range implements Comparable<Range> {
            private final int first;
            private final int last;
    
            public Range(int first_, int last_) {
                first = first_;
                last = last_;
            }
    
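            // This relies heavily on the ranges not overlapping: any two ranges that
            // overlap compare as equal, so a lookup with Range(k, k) finds the stored
            // range that contains k.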
            @Override public int compareTo(Range other) {
                if (last < other.first)
                    return -1;
                if (first > other.last)
                    return 1;
                return 0;
            }
        }
    
        private final Map<Range, T> theMap = new TreeMap<>();
    
        public void put(String key, T obj) {
            String[] ranges = key.split(";");
            for (String range : ranges) {
                //System.out.println("Adding " + range);
                String[] bounds = range.split("-");
                //System.out.println("Bounds " + bounds.length);
                int first = Integer.parseInt(bounds[0]);
                if (bounds.length == 1)
                    theMap.put(new Range(first, first), obj);
                else 
                    theMap.put(new Range(first, Integer.parseInt(bounds[1])), obj);
            }
        }
    
        public T get(String key) {
            return get(Integer.parseInt(key));
        }
    
        public T get(int key) {
            return theMap.get(new Range(key, key));
        }
    }
    
    class Main
    {
        public static void main (String[] args) throws java.lang.Exception
        {
            MyStorage<Integer> storage = new MyStorage<Integer>();
            storage.put("10;20-30", 123);
            storage.put("15;31-50", 456);
    
            // 42 falls inside 31-50, so this prints 456.
            System.out.println(storage.get("42"));
        }
    }
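
    If the compareTo trick above feels fragile (it is deliberately inconsistent with equals), a different way to get the same lookup with only the standard library is to key a TreeMap by each range's first value and use floorEntry. This is a swapped-in technique rather than the approach shown above, and FloorRangeStorage, its Entry class, and the demo class are hypothetical names used only for this sketch. It also assumes the ranges never overlap and takes the bounds as ints instead of parsing a "10;20-30" style key.

```java
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical alternative to MyStorage: ranges keyed by their first value. */
final class FloorRangeStorage<T> {
    /** The last value of a range plus the object attached to that range. */
    private static final class Entry<V> {
        final int last;
        final V value;

        Entry(int last, V value) {
            this.last = last;
            this.value = value;
        }
    }

    // Maps each range's first value to its last value and payload.
    private final TreeMap<Integer, Entry<T>> byFirst = new TreeMap<>();

    /** Stores value for the closed range [first, last]; assumes ranges never overlap. */
    void put(int first, int last, T value) {
        byFirst.put(first, new Entry<>(last, value));
    }

    /** Returns the value whose range contains key, or null if no range does. */
    T get(int key) {
        // The only candidate is the range with the greatest first value <= key.
        Map.Entry<Integer, Entry<T>> floor = byFirst.floorEntry(key);
        if (floor == null || key > floor.getValue().last) {
            return null;
        }
        return floor.getValue().value;
    }
}

final class FloorRangeStorageDemo {
    public static void main(String[] args) {
        FloorRangeStorage<Integer> storage = new FloorRangeStorage<>();
        // Same data as "10;20-30" -> 123 and "15;31-50" -> 456.
        storage.put(10, 10, 123);
        storage.put(20, 30, 123);
        storage.put(15, 15, 456);
        storage.put(31, 50, 456);

        System.out.println(storage.get(42)); // 456
        System.out.println(storage.get(12)); // null, no range contains 12
    }
}
```

    Lookups stay O(log n), and the keys are plain Integers, so the map's ordering no longer depends on a comparator that treats overlapping ranges as equal.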
    