Median of Medians in Java

前端 未结 5 1500
清歌不尽
清歌不尽 2020-12-03 09:11

I am trying to implement Median of Medians in Java for a method like this:

Select(Comparable[] list, int pos, int colSize, int colMed)
5条回答
  •  时光取名叫无心
    2020-12-03 10:14

    The question asked for Java, so here it is:

    import java.util.*;
    
    public class MedianOfMedians {
        private MedianOfMedians() {
            // Static utility class: not meant to be instantiated.
        }

        /**
         * Returns the median of {@code list} in worst-case linear time.
         *
         * <p>For an even number of elements this is the upper median, i.e. the element
         * that would sit at index {@code size / 2} if the list were sorted.
         *
         * @param list non-empty list of mutually comparable elements; it may be
         *            reordered on return
         * @return the median element
         * @throws IllegalArgumentException if {@code list} is empty
         */
        @SuppressWarnings("rawtypes")
        public static Comparable getMedian(ArrayList list) {
            int s = list.size();
            if (s < 1) {
                throw new IllegalArgumentException("list must not be empty");
            }
            int pos = select(list, 0, s, s / 2);
            return (Comparable) list.get(pos);
        }

        /**
         * Finds the k-th smallest (0-based) element of the sub-list {@code list[lo, hi)}
         * using median of medians, in worst-case linear time.
         *
         * <p>On return the sub-list has been reordered so that the element at the returned
         * position is the one that would occupy that position if {@code list[lo, hi)} were
         * sorted. The returned position is always {@code lo + k}.
         *
         * @param list list to search; only the sub-list {@code [lo, hi)} is reordered
         * @param lo first index of the sub-list
         * @param hi one past the last index of the sub-list
         * @param k rank to find, relative to {@code lo}
         * @return position of the k-th smallest element of the (reordered) sub-list
         * @throws IllegalArgumentException if the range is empty or {@code k} is not in
         *             {@code [0, hi - lo)}
         */
        @SuppressWarnings({"rawtypes", "unchecked"})
        public static int select(ArrayList list, int lo, int hi, int k) {
            if (lo >= hi || k < 0 || lo + k >= hi) {
                throw new IllegalArgumentException(
                        "invalid range or rank: lo=" + lo + ", hi=" + hi + ", k=" + k);
            }
            int s = hi - lo;
            if (s < 10) {
                // Small sub-lists: sorting is cheaper than recursing further.
                Collections.sort(list.subList(lo, hi));
                return lo + k;
            }
            int np = s / 5; // Number of groups; the last one also absorbs the 0-4 leftovers.
            for (int i = 0; i < np; i++) {
                // Find each group's median and move it to the front of the sub-list.
                int lo2 = lo + i * 5;
                int hi2 = (i + 1 == np) ? hi : (lo2 + 5);
                int pos = select(list, lo2, hi2, (hi2 - lo2) / 2);
                // lo + i lies in group 0 (when i == 0) or in an earlier, already processed
                // group, so this swap never disturbs a group that is still to be examined.
                Collections.swap(list, pos, lo + i);
            }

            // The group medians now occupy [lo, lo + np); their median is the pivot.
            int pivotPos = select(list, lo, lo + np, np / 2);

            // Re-partition the sub-list into [< pivot][== pivot][> pivot].
            int[] eq = triage(list, lo, hi, pivotPos);
            int target = lo + k;
            if (target < eq[0]) {
                return select(list, lo, eq[0], k);
            }
            if (target >= eq[1]) {
                return select(list, eq[1], hi, target - eq[1]);
            }
            // Every slot in [eq[0], eq[1]) holds a value equal to the pivot.
            return target;
        }

        /**
         * Three-way partitions the sub-list {@code list[lo, hi)} around the value at
         * {@code pos} into [< pivot][== pivot][> pivot].
         *
         * <p>Putting all pivot-equal elements in the middle block keeps duplicate-heavy input
         * from collapsing onto one side, which would otherwise degrade selection to O(n^2).
         *
         * @param list list whose sub-list is reordered in place
         * @param lo first index of the sub-list
         * @param hi one past the last index of the sub-list
         * @param pos index (within {@code [lo, hi)}) of the pivot value
         * @return {@code {eqLo, eqHi}} such that {@code [lo, eqLo)} is smaller than the
         *         pivot, {@code [eqLo, eqHi)} (never empty) equals it, and {@code [eqHi, hi)}
         *         is larger
         */
        @SuppressWarnings({"rawtypes", "unchecked"})
        private static int[] triage(ArrayList list, int lo, int hi, int pos) {
            Comparable pivot = (Comparable) list.get(pos);
            int lt = lo; // [lo, lt) < pivot
            int i = lo;  // [lt, i) == pivot, [i, gt) not yet examined
            int gt = hi; // [gt, hi) > pivot
            while (i < gt) {
                int cmp = ((Comparable) list.get(i)).compareTo(pivot);
                if (cmp < 0) {
                    Collections.swap(list, lt++, i++);
                } else if (cmp > 0) {
                    Collections.swap(list, i, --gt);
                } else {
                    i++;
                }
            }
            return new int[] {lt, gt};
        }

    }
    

    Here is a unit test to check that it works:

    import java.util.*;
    
    import junit.framework.TestCase;
    
    public class MedianOfMedianTest extends TestCase {
        /**
         * Compares {@link MedianOfMedians#getMedian} with the element at index {@code n / 2}
         * of a fully sorted copy, on random lists of odd and even sizes (including sizes on
         * both sides of the sort-instead-of-recurse cutoff) and with value ranges producing
         * few or many duplicates.
         */
        public void testMedianOfMedianTest() {
            Random r = new Random(1);
            int[] sizes = {1, 2, 9, 10, 11, 86, 87};
            int[] valueRanges = {2, 256};
            for (int trial = 0; trial < 1000; trial++) {
                int n = sizes[trial % sizes.length];
                // 7 sizes and 2 ranges are coprime, so every combination is exercised.
                int range = valueRanges[trial % valueRanges.length];
                ArrayList<Integer> list = new ArrayList<>(n);
                int[] a = new int[n];
                for (int i = 0; i < n; i++) {
                    int v = r.nextInt(range);
                    a[i] = v;
                    list.add(v);
                }
                int actual = (Integer) MedianOfMedians.getMedian(list);
                Arrays.sort(a);
                int expected = a[n / 2];
                // JUnit's assertEquals takes the expected value first, so failures read correctly.
                assertEquals(expected, actual);
            }
        }
    }
    

    However, the above code is too slow for practical use.

    Here is a simpler way to get the k'th smallest element. It does not guarantee linear worst-case time, but it is much faster in practice:

    /**
     * Returns the position of the k-th smallest (0-based) element of the sub-array
     * {@code list[lo, hi)}, using quickselect with a three-way partition.
     *
     * <p>Unlike median of medians this gives no linear worst-case guarantee: the pivot is
     * picked by a deterministic hash of {@code k} and the sub-array length, not at random,
     * so unlucky inputs can degrade it to quadratic time.
     *
     * <p>NOTE(review): assumes {@code compare(a, b)} is a total order on doubles (e.g. like
     * {@link Double#compare}) and {@code swap(list, i, j)} exchanges two entries; both are
     * defined outside this snippet — confirm.
     *
     * @param list array to search; the sub-array {@code [lo, hi)} may be reordered
     * @param lo first index of the sub-array
     * @param hi one past the last index of the sub-array
     * @param k rank to find, relative to {@code lo}; expected to satisfy
     *            {@code 0 <= k < hi - lo}
     * @return {@code lo + k} (or {@code lo} when the sub-array has fewer than two
     *         elements); on return that slot holds the k-th smallest element
     */
    static int select(double[] list, int lo, int hi, int k) {
        // Loop instead of tail-recursing, so a run of bad pivots cannot overflow the stack.
        while (true) {
            int n = hi - lo;
            if (n < 2)
                return lo;

            // Deterministic pseudo-random pivot. Multiply in long: k * 7919 overflows int
            // once k exceeds about 271,000, which produced a negative (out-of-range) index.
            double pivot = list[lo + (int) ((k * 7919L) % n)];

            // Triage [lo, hi) into [< pivot][== pivot][> pivot]:
            // [lo, lo3) smaller, [lo3, hi3) unexamined,
            // [hi3, hi3 + nSame) equal, [hi3 + nSame, hi) larger.
            int nLess = 0, nSame = 0, nMore = 0;
            int lo3 = lo;
            int hi3 = hi;
            while (lo3 < hi3) {
                double e = list[lo3];
                int cmp = compare(e, pivot);
                if (cmp < 0) {
                    nLess++;
                    lo3++;
                } else if (cmp > 0) {
                    swap(list, lo3, --hi3);
                    // Keep the equal run contiguous: move the new larger value past it.
                    if (nSame > 0)
                        swap(list, hi3, hi3 + nSame);
                    nMore++;
                } else {
                    nSame++;
                    swap(list, lo3, --hi3);
                }
            }
            // Compare with compare() rather than ==, so the checks agree with the
            // partition predicate (== is false for a NaN pivot).
            assert (nSame > 0);
            assert (nLess + nSame + nMore == n);
            assert (compare(list[lo + nLess], pivot) == 0);
            assert (compare(list[hi - nMore - 1], pivot) == 0);

            if (k >= n - nMore) {
                // Target is among the larger values.
                k -= nLess + nSame;
                lo = hi - nMore;
            } else if (k < nLess) {
                // Target is among the smaller values.
                hi = lo + nLess;
            } else {
                // Target falls inside the run of pivot-equal values.
                return lo + k;
            }
        }
    }
    

提交回复
热议问题