Java Streams - Standard Deviation

余生分开走 2020-12-30 10:18

To be clear up front: I am looking for a way to calculate the standard deviation using Streams (I have a working method at present that calculates and returns the SD, but with

2 Answers
  •  旧巷少年郎
    2020-12-30 10:43

    You can use a custom collector for this task, one that also keeps track of the sum of squares. The built-in DoubleSummaryStatistics class does not record it. This was discussed by the expert group in this thread but was ultimately not implemented. The difficulty in calculating the sum of squares is the potential overflow when squaring the intermediate results.
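
    As a reminder of why the sum of squares is enough, the population variance can be written in terms of the count, the sum and the sum of squares. This is the standard identity, written here with n for the count and \mu for the mean (the symbol names are a notational choice, not taken from the answer):

```latex
\sigma^2 = \frac{1}{n}\sum_{i=1}^{n}(x_i-\mu)^2
         = \frac{1}{n}\sum_{i=1}^{n}x_i^2 - \mu^2,
\qquad
\mu = \frac{1}{n}\sum_{i=1}^{n}x_i
```

    Note that this is the population form (division by n), not the sample form (division by n - 1).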

    // Nested helper class; the enclosing file needs: import java.util.DoubleSummaryStatistics;
    static class DoubleStatistics extends DoubleSummaryStatistics {
    
        private double sumOfSquare = 0.0d;
        private double sumOfSquareCompensation; // Low order bits of sum
        private double simpleSumOfSquare; // Used to compute right sum for non-finite inputs
    
        @Override
        public void accept(double value) {
            super.accept(value);
            double squareValue = value * value;
            simpleSumOfSquare += squareValue;
            sumOfSquareWithCompensation(squareValue);
        }
    
        // Merges another partial result (used when the stream runs in parallel).
        public DoubleStatistics combine(DoubleStatistics other) {
            super.combine(other);
            simpleSumOfSquare += other.simpleSumOfSquare;
            sumOfSquareWithCompensation(other.sumOfSquare);
            // The compensation holds the excess already included in other.sumOfSquare, so subtract it.
            sumOfSquareWithCompensation(-other.sumOfSquareCompensation);
            return this;
        }
    
        // One step of Kahan (compensated) summation.
        private void sumOfSquareWithCompensation(double value) {
            double tmp = value - sumOfSquareCompensation;
            double velvel = sumOfSquare + tmp; // Little wolf of rounding error
            sumOfSquareCompensation = (velvel - sumOfSquare) - tmp;
            sumOfSquare = velvel;
        }
    
        public double getSumOfSquare() {
            // Subtract the accumulated rounding error to get the best estimate.
            double tmp = sumOfSquare - sumOfSquareCompensation;
            if (Double.isNaN(tmp) && Double.isInfinite(simpleSumOfSquare)) {
                return simpleSumOfSquare;
            }
            return tmp;
        }
    
        // Population standard deviation: sqrt(mean of squares - square of mean); 0 for empty input.
        public final double getStandardDeviation() {
            if (getCount() == 0) {
                return 0.0d;
            }
            double mean = getAverage();
            // Clamp at zero: rounding can make the difference slightly negative.
            return Math.sqrt(Math.max(0.0d, getSumOfSquare() / getCount() - mean * mean));
        }
    
    }
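
    The compensation field in the class above follows the Kahan (compensated) summation pattern. The following standalone sketch may help to see what that buys; the names KahanSumDemo, naiveSum, kahanSum and sampleValues are hypothetical and chosen only for this illustration. It adds one large value and a million tiny ones, each too small to change the running total on its own.

```java
import java.util.Arrays;

// Illustrative sketch only; all names here are hypothetical.
final class KahanSumDemo {

    // Plain accumulation: each addition rounds, and small terms can vanish entirely.
    static double naiveSum(double[] values) {
        double sum = 0.0;
        for (double v : values) {
            sum += v;
        }
        return sum;
    }

    // Kahan summation: carry the rounding error of each step into the next one.
    static double kahanSum(double[] values) {
        double sum = 0.0;
        double compensation = 0.0; // excess that rounding has added to sum so far
        for (double v : values) {
            double tmp = v - compensation;
            double next = sum + tmp;
            compensation = (next - sum) - tmp;
            sum = next;
        }
        // The best estimate removes the remaining excess.
        return sum - compensation;
    }

    // 1.0 followed by one million terms of 1e-16 (true total is about 1.0000000001).
    static double[] sampleValues() {
        double[] values = new double[1_000_001];
        values[0] = 1.0;
        Arrays.fill(values, 1, values.length, 1e-16);
        return values;
    }

    public static void main(String[] args) {
        double[] values = sampleValues();
        System.out.println("naive: " + naiveSum(values));
        System.out.println("kahan: " + kahanSum(values));
    }
}
```

    With this input the naive loop never moves away from 1.0, because 1e-16 is below half the spacing of doubles near 1.0, while the compensated loop stays close to the true total.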
    

    Then, you can use this class with

    // Key type assumes getCar() returns a String; adjust it to the actual type.
    Map<String, Double> standardDeviationMap =
        list.stream()
            .collect(Collectors.groupingBy(
                e -> e.getCar(),
                Collectors.mapping(
                    e -> e.getHigh() - e.getLow(),
                    Collector.of(
                        DoubleStatistics::new,
                        DoubleStatistics::accept,
                        DoubleStatistics::combine,
                        d -> d.getStandardDeviation()
                    )
                )
            ));
    

    This collects the input list into a map whose values are the standard deviation of high - low over the elements that share the same key.
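
    To try the grouping end to end, here is a self-contained sketch under a few assumptions: the element type Reading, the trimmed SquareStats class (a plain sum of squares without the compensation logic, to keep the example short), the method spreadDeviationByCar and the sample data are all hypothetical. Only the getter names getCar, getHigh and getLow come from the snippet above.

```java
import java.util.Arrays;
import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Map;
import java.util.stream.Collector;
import java.util.stream.Collectors;

// Illustrative sketch; Reading, SquareStats and the sample data are hypothetical.
final class StdDevExample {

    // Stand-in element type exposing the getters used in the answer.
    static final class Reading {
        private final String car;
        private final double high;
        private final double low;

        Reading(String car, double high, double low) {
            this.car = car;
            this.high = high;
            this.low = low;
        }

        String getCar() { return car; }
        double getHigh() { return high; }
        double getLow() { return low; }
    }

    // Trimmed stand-in for DoubleStatistics: plain sum of squares, no compensation.
    static final class SquareStats extends DoubleSummaryStatistics {
        private double sumOfSquare;

        @Override
        public void accept(double value) {
            super.accept(value);
            sumOfSquare += value * value;
        }

        SquareStats combine(SquareStats other) {
            super.combine(other);
            sumOfSquare += other.sumOfSquare;
            return this;
        }

        // Population standard deviation, clamped so rounding cannot yield NaN.
        double getStandardDeviation() {
            if (getCount() == 0) {
                return 0.0;
            }
            double mean = getAverage();
            return Math.sqrt(Math.max(0.0, sumOfSquare / getCount() - mean * mean));
        }
    }

    // Groups by car and reduces each group's (high - low) values to a standard deviation.
    static Map<String, Double> spreadDeviationByCar(List<Reading> list) {
        return list.stream()
            .collect(Collectors.groupingBy(
                e -> e.getCar(),
                Collectors.mapping(
                    e -> e.getHigh() - e.getLow(),
                    Collector.of(
                        SquareStats::new,
                        SquareStats::accept,
                        SquareStats::combine,
                        d -> d.getStandardDeviation()
                    )
                )
            ));
    }

    // Car "A" has spreads 2, 4, 4, 4, 5, 5, 7, 9; car "B" has a single spread of 3.
    static List<Reading> sampleData() {
        return Arrays.asList(
            new Reading("A", 12, 10), new Reading("A", 14, 10),
            new Reading("A", 14, 10), new Reading("A", 14, 10),
            new Reading("A", 15, 10), new Reading("A", 15, 10),
            new Reading("A", 17, 10), new Reading("A", 19, 10),
            new Reading("B", 13, 10));
    }

    public static void main(String[] args) {
        System.out.println(spreadDeviationByCar(sampleData()));
    }
}
```

    For this sample data the spreads of car "A" have mean 5 and mean of squares 29, so its standard deviation is 2.0; car "B" has a single value and therefore a standard deviation of 0.0.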
