Java Streams - Standard Deviation

Submitted by 霸气de小男生 on 2019-11-30 09:22:09

You can use a custom collector for this task that calculates the sum of squares. The built-in DoubleSummaryStatistics collector does not keep track of it. This was discussed by the expert group in this thread but in the end not implemented. The difficulty when calculating the sum of squares is the potential overflow when squaring the intermediate results.
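To make the overflow concern concrete, here is a minimal sketch of the sum-of-squares formula with plain, uncompensated sums. The class and method names (SquareOverflowDemo, sumOfSquaresVariance) are made up for this illustration and are not part of any library; an empty array is not handled.

```java
final class SquareOverflowDemo {

    // Population variance via E[x^2] - (E[x])^2 with plain (uncompensated) sums.
    static double sumOfSquaresVariance(double[] values) {
        double sum = 0.0d;
        double sumOfSquares = 0.0d;
        for (double v : values) {
            sum += v;
            sumOfSquares += v * v; // v * v overflows to Infinity once |v| exceeds about 1.3e154
        }
        double mean = sum / values.length;
        return sumOfSquares / values.length - mean * mean;
    }

    public static void main(String[] args) {
        System.out.println(sumOfSquaresVariance(new double[] {2, 4, 4, 4, 5, 5, 7, 9}));
        System.out.println(sumOfSquaresVariance(new double[] {1e200, 1e200}));
    }
}
```

For the small integer inputs the result is 4.0. For the two values of 1e200 the squares overflow to Infinity and the result is NaN, although the true variance is 0. The class that follows accumulates the same sum of squares, adding compensated summation to limit rounding error; it does not remove the overflow itself.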

static class DoubleStatistics extends DoubleSummaryStatistics {

    private double sumOfSquare = 0.0d;
    private double sumOfSquareCompensation; // Low order bits of sum
    private double simpleSumOfSquare; // Used to compute right sum for non-finite inputs

    @Override
    public void accept(double value) {
        super.accept(value);
        double squareValue = value * value;
        simpleSumOfSquare += squareValue;
        sumOfSquareWithCompensation(squareValue);
    }

    public DoubleStatistics combine(DoubleStatistics other) {
        super.combine(other);
        simpleSumOfSquare += other.simpleSumOfSquare;
        sumOfSquareWithCompensation(other.sumOfSquare);
        // The compensation holds the error to subtract, so it is merged with a negative sign
        sumOfSquareWithCompensation(-other.sumOfSquareCompensation);
        return this;
    }

    private void sumOfSquareWithCompensation(double value) {
        double tmp = value - sumOfSquareCompensation;
        double velvel = sumOfSquare + tmp; // Little wolf of rounding error
        sumOfSquareCompensation = (velvel - sumOfSquare) - tmp;
        sumOfSquare = velvel;
    }

    public double getSumOfSquare() {
        // Subtract the compensation to apply the final correction
        double tmp = sumOfSquare - sumOfSquareCompensation;
        if (Double.isNaN(tmp) && Double.isInfinite(simpleSumOfSquare)) {
            return simpleSumOfSquare;
        }
        return tmp;
    }

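    public final double getStandardDeviation() {
        if (getCount() == 0) {
            return 0.0d;
        }
        // Population variance; rounding can make the difference slightly negative, so clamp at zero
        double variance = (getSumOfSquare() / getCount()) - Math.pow(getAverage(), 2);
        return Math.sqrt(Math.max(variance, 0.0d));
    }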

}

Then you can use this class as follows:

Map<String, Double> standardDeviationMap =
    list.stream()
        .collect(Collectors.groupingBy(
            e -> e.getCar(),
            Collectors.mapping(
                e -> e.getHigh() - e.getLow(),
                Collector.of(
                    DoubleStatistics::new,
                    DoubleStatistics::accept,
                    DoubleStatistics::combine,
                    d -> d.getStandardDeviation()
                )
            )
        ));

This will collect the input list into a map where each value is the standard deviation of high - low for the elements that share the same key.
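When testing a grouping like this, it can help to cross-check the results against a simpler two-pass computation. The sketch below is not the collector above: it buffers each group's values in a list and computes the population standard deviation in two passes. The Quote class is a hypothetical stand-in for the element type, modelling only the getCar, getHigh and getLow getters that appear in the snippet.

```java
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

final class TwoPassStdDevDemo {

    // Hypothetical element type; only the three getters used in the text are modelled.
    static final class Quote {
        private final String car;
        private final double high;
        private final double low;

        Quote(String car, double high, double low) {
            this.car = car;
            this.high = high;
            this.low = low;
        }

        String getCar() { return car; }
        double getHigh() { return high; }
        double getLow() { return low; }
    }

    // Population standard deviation: first pass for the mean, second for the squared deviations.
    static double populationStdDev(List<Double> values) {
        if (values.isEmpty()) {
            return 0.0d;
        }
        double mean = values.stream().mapToDouble(Double::doubleValue).average().orElse(0.0d);
        double meanSquaredDeviation = values.stream()
                .mapToDouble(v -> (v - mean) * (v - mean))
                .average()
                .orElse(0.0d);
        return Math.sqrt(meanSquaredDeviation);
    }

    // Groups by car and computes the standard deviation of high - low per group.
    static Map<String, Double> stdDevByCar(List<Quote> list) {
        return list.stream()
                .collect(Collectors.groupingBy(
                        Quote::getCar,
                        Collectors.mapping(
                                q -> q.getHigh() - q.getLow(),
                                Collectors.collectingAndThen(
                                        Collectors.toList(),
                                        TwoPassStdDevDemo::populationStdDev))));
    }

    public static void main(String[] args) {
        List<Quote> list = Arrays.asList(
                new Quote("A", 3.0, 1.0),   // high - low = 2
                new Quote("A", 6.0, 2.0),   // high - low = 4
                new Quote("B", 5.0, 5.0));  // high - low = 0
        System.out.println(stdDevByCar(list));
    }
}
```

The trade-off is memory: every group's values are held in a list before the result is computed, whereas the single-pass collector keeps only a few running totals per group.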

Alternatively, you can use this custom Collector, which is based on Welford's online algorithm and computes the sample variance:

private static final Collector<Double, double[], Double> VARIANCE_COLLECTOR = Collector.of( // See https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
        () -> new double[3], // {count, mean, M2}
        (acu, d) -> { // See chapter about Welford's online algorithm and https://math.stackexchange.com/questions/198336/how-to-calculate-standard-deviation-with-streaming-inputs
            acu[0]++; // Count
            double delta = d - acu[1];
            acu[1] += delta / acu[0]; // Mean
            acu[2] += delta * (d - acu[1]); // M2
        },
        (acuA, acuB) -> { // See chapter about "Parallel algorithm" : only called if stream is parallel ...
            double delta = acuB[1] - acuA[1];
            double count = acuA[0] + acuB[0];
            if (count == 0) { // Both partial results are empty: avoid 0 / 0 = NaN
                return acuA;
            }
            acuA[2] = acuA[2] + acuB[2] + delta * delta * acuA[0] * acuB[0] / count; // M2
            acuA[1] += delta * acuB[0] / count;  // Mean
            acuA[0] = count; // Count
            return acuA;
        },
        acu -> acu[2] / (acu[0] - 1.0), // Var = M2 / (count - 1)
        Collector.Characteristics.UNORDERED);

Then simply call this collector on your stream:

double stdDev = Math.sqrt(outPut.stream().boxed().collect(VARIANCE_COLLECTOR));
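For reference, the same Welford update can also be written with a small named accumulator instead of a double[3]. This is only a sketch under stated assumptions: WelfordDemo, Acc and sampleVarianceCollector are names invented here, the input is assumed to be a double[], and returning NaN for fewer than two values is a choice made for this example.

```java
import java.util.Arrays;
import java.util.stream.Collector;

final class WelfordDemo {

    // Mutable accumulator: running count, mean, and M2 (sum of squared deviations from the mean).
    static final class Acc {
        long count;
        double mean;
        double m2;

        // Welford's online update for a single value.
        void add(double x) {
            count++;
            double delta = x - mean;
            mean += delta / count;
            m2 += delta * (x - mean);
        }

        // Merges another partial result; only used when the stream runs in parallel.
        Acc merge(Acc other) {
            if (other.count == 0) {
                return this;
            }
            if (count == 0) {
                count = other.count;
                mean = other.mean;
                m2 = other.m2;
                return this;
            }
            long total = count + other.count;
            double delta = other.mean - mean;
            m2 += other.m2 + delta * delta * ((double) count * other.count / total);
            mean += delta * other.count / total;
            count = total;
            return this;
        }

        // Sample variance (divides by count - 1); NaN when fewer than two values were seen.
        double sampleVariance() {
            return count > 1 ? m2 / (count - 1) : Double.NaN;
        }
    }

    static Collector<Double, Acc, Double> sampleVarianceCollector() {
        return Collector.of(Acc::new, Acc::add, Acc::merge, Acc::sampleVariance,
                Collector.Characteristics.UNORDERED);
    }

    static double sampleStdDev(double[] values) {
        return Math.sqrt(Arrays.stream(values).boxed().collect(sampleVarianceCollector()));
    }

    static double sampleStdDevParallel(double[] values) {
        return Math.sqrt(Arrays.stream(values).parallel().boxed().collect(sampleVarianceCollector()));
    }

    public static void main(String[] args) {
        double[] values = {2, 4, 4, 4, 5, 5, 7, 9};
        System.out.println(sampleStdDev(values));
        System.out.println(sampleStdDevParallel(values));
    }
}
```

For these eight values the mean is 5 and M2 is 32, so both calls should come out close to sqrt(32 / 7). Note that this is the sample standard deviation; dividing M2 by count instead of count - 1 would give the population value.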