Random number with Probabilities

前端 未结 12 2213
醉梦人生
醉梦人生 2020-11-27 02:56

I am wondering what the best way would be (e.g. in Java) to generate random numbers within a particular range, where each number has a given probability of occurring.

12条回答
  •  情歌与酒
    2020-11-27 03:15

    I wrote this class for an interview after reading the paper that pjs pointed to in another post; the population of the base64 table could be optimized further. Sampling is remarkably fast. Initialization is somewhat expensive, but if the probabilities do not change often, this is a good approach.

    *For duplicate keys, the last probability is used instead of the probabilities being combined (slightly different from the behaviour of EnumeratedIntegerDistribution).

    /**
     * Draws integers according to a caller-supplied discrete distribution using
     * table lookups.
     *
     * <p>Each probability is scaled to an integer numerator over {@code 2^24} and
     * that numerator is written as four base-64 digits. Lookup table {@code k}
     * stores each value as many times as its {@code k}-th digit (most significant
     * first), so a draw is one random integer, at most four comparisons and one
     * array read.
     *
     * <p>The random source {@code randGen} is not declared in this class; it is
     * presumably inherited from {@code BaseRandomGen} (confirm against that class).
     */
    public class RandomGen5 extends BaseRandomGen {

        /** Number of base-64 digits (and therefore lookup tables) per numerator. */
        private static final int TABLE_COUNT = 4;

        /** Mask selecting a single base-64 digit. */
        private static final int DIGIT_MASK = 63;

        /** Common denominator that normalized probabilities are scaled by. */
        private static final int DENOM = 1 << 24;

        /** Bit position of each digit, most significant digit first. */
        private static final int[] SHIFTS = {18, 12, 6, 0};

        /** Cumulative upper bounds: a draw below {@code thresholds[k]} is served by table {@code k}. */
        private final int[] thresholds = new int[TABLE_COUNT];

        /** One lookup table per digit position; replaced wholesale on every re-population. */
        private final int[][] base64Table = new int[TABLE_COUNT][0];

        /** Sum of all integer numerators; the exclusive upper bound of a draw. */
        private int sumOfNumerator;

        /**
         * Returns the next random value from the configured distribution.
         *
         * @return one of the configured values, or {@code 0} when no distribution
         *         has been configured (all tables empty)
         */
        public int nextNum() {
            int rand = (int) (randGen.nextFloat() * sumOfNumerator);

            int lowerBound = 0;
            for (int level = 0; level < TABLE_COUNT; level++) {
                if (rand < thresholds[level]) {
                    // Offset into this table's slice of the range, then drop the
                    // low bits that all map to the same table slot.
                    return base64Table[level][(rand - lowerBound) >> SHIFTS[level]];
                }
                lowerBound = thresholds[level];
            }
            return 0;
        }

        /**
         * Configures the distribution to draw from.
         *
         * <p>Probabilities need not sum to one; they are normalized by their total.
         * When a value appears more than once, the last probability given for it
         * is used. Entries of {@code intList} beyond {@code probList.length} are
         * ignored.
         *
         * @param intList  the values that may be returned
         * @param probList the weight of the value at the same index in {@code intList}
         * @throws IllegalArgumentException if {@code intList} is shorter than
         *         {@code probList}, if any weight is negative, NaN or infinite, or
         *         if a non-empty list has weights that do not sum to a positive
         *         finite number
         */
        public void setIntProbList(int[] intList, float[] probList) {
            Map<Integer, Integer> numerators = normalizeMap(intList, probList);
            populateBase64Table(numerators);
        }

        /**
         * Rebuilds the lookup tables and cumulative thresholds from integer numerators.
         *
         * @param intProbMap value mapped to its numerator over {@code 2^24}
         */
        private void populateBase64Table(Map<Integer, Integer> intProbMap) {
            int upperBound = 0;
            for (int level = 0; level < TABLE_COUNT; level++) {
                int size = 0;
                for (int numerator : intProbMap.values()) {
                    size += digitAt(numerator, level);
                }

                int[] table = new int[size];
                int next = 0;
                for (Map.Entry<Integer, Integer> entry : intProbMap.entrySet()) {
                    int value = entry.getKey();
                    int copies = digitAt(entry.getValue(), level);
                    for (int c = 0; c < copies; c++) {
                        table[next++] = value;
                    }
                }
                base64Table[level] = table;

                // Every slot of this table covers 64^(3 - level) consecutive draws.
                upperBound += size << SHIFTS[level];
                thresholds[level] = upperBound;
            }
        }

        /**
         * Extracts the base-64 digit of {@code numerator} for the given table level.
         *
         * <p>The most significant level is deliberately not masked: a numerator of
         * exactly {@code 2^24} (a single value with probability one) then yields 64
         * there instead of wrapping to an all-zero decomposition.
         */
        private static int digitAt(int numerator, int level) {
            int shifted = numerator >>> SHIFTS[level];
            return level == 0 ? shifted : shifted & DIGIT_MASK;
        }

        /**
         * Removes duplicate values (last weight wins), validates the weights and
         * converts each one to an integer numerator over {@code 2^24}. Also updates
         * {@code sumOfNumerator}, but only after validation has succeeded.
         *
         * @return value mapped to its integer numerator
         */
        private Map<Integer, Integer> normalizeMap(int[] intList, float[] probList) {
            if (intList.length < probList.length) {
                throw new IllegalArgumentException(
                        "intList has " + intList.length + " entries but probList has " + probList.length);
            }

            Map<Integer, Float> weights = new HashMap<>();
            for (int i = 0; i < probList.length; i++) {
                float weight = probList[i];
                // The negated comparison also rejects NaN.
                if (!(weight >= 0f) || Float.isInfinite(weight)) {
                    throw new IllegalArgumentException(
                            "probability at index " + i + " must be finite and non-negative: " + weight);
                }
                weights.put(intList[i], weight);
            }

            // Sum after de-duplication so replaced weights never enter the total.
            float distSum = 0f;
            for (float weight : weights.values()) {
                distSum += weight;
            }
            if (!weights.isEmpty() && (!(distSum > 0f) || Float.isInfinite(distSum))) {
                throw new IllegalArgumentException(
                        "probabilities must sum to a positive finite value: " + distSum);
            }

            Map<Integer, Integer> numerators = new HashMap<>();
            int total = 0;
            for (Map.Entry<Integer, Float> entry : weights.entrySet()) {
                int numerator = (int) ((entry.getValue() / distSum) * DENOM);
                numerators.put(entry.getKey(), numerator);
                total += numerator;
            }
            sumOfNumerator = total;
            return numerators;
        }
    }
    

提交回复
热议问题