1 /* 2 * Copyright (C) 2012 The Guava Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except 5 * in compliance with the License. You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software distributed under the License 10 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express 11 * or implied. See the License for the specific language governing permissions and limitations under 12 * the License. 13 */ 14 15 package com.google.common.math; 16 17 import static com.google.common.base.Preconditions.checkState; 18 import static com.google.common.primitives.Doubles.isFinite; 19 import static java.lang.Double.NaN; 20 import static java.lang.Double.isNaN; 21 22 import com.google.common.annotations.Beta; 23 import com.google.common.annotations.GwtIncompatible; 24 import com.google.common.primitives.Doubles; 25 26 /** 27 * A mutable object which accumulates paired double values (e.g. points on a plane) and tracks some 28 * basic statistics over all the values added so far. This class is not thread safe. 29 * 30 * @author Pete Gillin 31 * @since 20.0 32 */ 33 @Beta 34 @GwtIncompatible 35 public final class PairedStatsAccumulator { 36 37 // These fields must satisfy the requirements of PairedStats' constructor as well as those of the 38 // stat methods of this class. 39 private final StatsAccumulator xStats = new StatsAccumulator(); 40 private final StatsAccumulator yStats = new StatsAccumulator(); 41 private double sumOfProductsOfDeltas = 0.0; 42 43 /** Adds the given pair of values to the dataset. */ add(double x, double y)44 public void add(double x, double y) { 45 // We extend the recursive expression for the one-variable case at Art of Computer Programming 46 // vol. 2, Knuth, 4.2.2, (16) to the two-variable case. We have two value series x_i and y_i. 47 // We define the arithmetic means X_n = 1/n \sum_{i=1}^n x_i, and Y_n = 1/n \sum_{i=1}^n y_i. 48 // We also define the sum of the products of the differences from the means 49 // C_n = \sum_{i=1}^n x_i y_i - n X_n Y_n 50 // for all n >= 1. Then for all n > 1: 51 // C_{n-1} = \sum_{i=1}^{n-1} x_i y_i - (n-1) X_{n-1} Y_{n-1} 52 // C_n - C_{n-1} = x_n y_n - n X_n Y_n + (n-1) X_{n-1} Y_{n-1} 53 // = x_n y_n - X_n [ y_n + (n-1) Y_{n-1} ] + [ n X_n - x_n ] Y_{n-1} 54 // = x_n y_n - X_n y_n - x_n Y_{n-1} + X_n Y_{n-1} 55 // = (x_n - X_n) (y_n - Y_{n-1}) 56 xStats.add(x); 57 if (isFinite(x) && isFinite(y)) { 58 if (xStats.count() > 1) { 59 sumOfProductsOfDeltas += (x - xStats.mean()) * (y - yStats.mean()); 60 } 61 } else { 62 sumOfProductsOfDeltas = NaN; 63 } 64 yStats.add(y); 65 } 66 67 /** 68 * Adds the given statistics to the dataset, as if the individual values used to compute the 69 * statistics had been added directly. 70 */ addAll(PairedStats values)71 public void addAll(PairedStats values) { 72 if (values.count() == 0) { 73 return; 74 } 75 76 xStats.addAll(values.xStats()); 77 if (yStats.count() == 0) { 78 sumOfProductsOfDeltas = values.sumOfProductsOfDeltas(); 79 } else { 80 // This is a generalized version of the calculation in add(double, double) above. Note that 81 // non-finite inputs will have sumOfProductsOfDeltas = NaN, so non-finite values will result 82 // in NaN naturally. 83 sumOfProductsOfDeltas += 84 values.sumOfProductsOfDeltas() 85 + (values.xStats().mean() - xStats.mean()) 86 * (values.yStats().mean() - yStats.mean()) 87 * values.count(); 88 } 89 yStats.addAll(values.yStats()); 90 } 91 92 /** Returns an immutable snapshot of the current statistics. */ snapshot()93 public PairedStats snapshot() { 94 return new PairedStats(xStats.snapshot(), yStats.snapshot(), sumOfProductsOfDeltas); 95 } 96 97 /** Returns the number of pairs in the dataset. */ count()98 public long count() { 99 return xStats.count(); 100 } 101 102 /** Returns an immutable snapshot of the statistics on the {@code x} values alone. */ xStats()103 public Stats xStats() { 104 return xStats.snapshot(); 105 } 106 107 /** Returns an immutable snapshot of the statistics on the {@code y} values alone. */ yStats()108 public Stats yStats() { 109 return yStats.snapshot(); 110 } 111 112 /** 113 * Returns the population covariance of the values. The count must be non-zero. 114 * 115 * <p>This is guaranteed to return zero if the dataset contains a single pair of finite values. It 116 * is not guaranteed to return zero when the dataset consists of the same pair of values multiple 117 * times, due to numerical errors. 118 * 119 * <h3>Non-finite values</h3> 120 * 121 * <p>If the dataset contains any non-finite values ({@link Double#POSITIVE_INFINITY}, {@link 122 * Double#NEGATIVE_INFINITY}, or {@link Double#NaN}) then the result is {@link Double#NaN}. 123 * 124 * @throws IllegalStateException if the dataset is empty 125 */ populationCovariance()126 public double populationCovariance() { 127 checkState(count() != 0); 128 return sumOfProductsOfDeltas / count(); 129 } 130 131 /** 132 * Returns the sample covariance of the values. The count must be greater than one. 133 * 134 * <p>This is not guaranteed to return zero when the dataset consists of the same pair of values 135 * multiple times, due to numerical errors. 136 * 137 * <h3>Non-finite values</h3> 138 * 139 * <p>If the dataset contains any non-finite values ({@link Double#POSITIVE_INFINITY}, {@link 140 * Double#NEGATIVE_INFINITY}, or {@link Double#NaN}) then the result is {@link Double#NaN}. 141 * 142 * @throws IllegalStateException if the dataset is empty or contains a single pair of values 143 */ sampleCovariance()144 public final double sampleCovariance() { 145 checkState(count() > 1); 146 return sumOfProductsOfDeltas / (count() - 1); 147 } 148 149 /** 150 * Returns the <a href="http://mathworld.wolfram.com/CorrelationCoefficient.html">Pearson's or 151 * product-moment correlation coefficient</a> of the values. The count must greater than one, and 152 * the {@code x} and {@code y} values must both have non-zero population variance (i.e. {@code 153 * xStats().populationVariance() > 0.0 && yStats().populationVariance() > 0.0}). The result is not 154 * guaranteed to be exactly +/-1 even when the data are perfectly (anti-)correlated, due to 155 * numerical errors. However, it is guaranteed to be in the inclusive range [-1, +1]. 156 * 157 * <h3>Non-finite values</h3> 158 * 159 * <p>If the dataset contains any non-finite values ({@link Double#POSITIVE_INFINITY}, {@link 160 * Double#NEGATIVE_INFINITY}, or {@link Double#NaN}) then the result is {@link Double#NaN}. 161 * 162 * @throws IllegalStateException if the dataset is empty or contains a single pair of values, or 163 * either the {@code x} and {@code y} dataset has zero population variance 164 */ pearsonsCorrelationCoefficient()165 public final double pearsonsCorrelationCoefficient() { 166 checkState(count() > 1); 167 if (isNaN(sumOfProductsOfDeltas)) { 168 return NaN; 169 } 170 double xSumOfSquaresOfDeltas = xStats.sumOfSquaresOfDeltas(); 171 double ySumOfSquaresOfDeltas = yStats.sumOfSquaresOfDeltas(); 172 checkState(xSumOfSquaresOfDeltas > 0.0); 173 checkState(ySumOfSquaresOfDeltas > 0.0); 174 // The product of two positive numbers can be zero if the multiplication underflowed. We 175 // force a positive value by effectively rounding up to MIN_VALUE. 176 double productOfSumsOfSquaresOfDeltas = 177 ensurePositive(xSumOfSquaresOfDeltas * ySumOfSquaresOfDeltas); 178 return ensureInUnitRange(sumOfProductsOfDeltas / Math.sqrt(productOfSumsOfSquaresOfDeltas)); 179 } 180 181 /** 182 * Returns a linear transformation giving the best fit to the data according to <a 183 * href="http://mathworld.wolfram.com/LeastSquaresFitting.html">Ordinary Least Squares linear 184 * regression</a> of {@code y} as a function of {@code x}. The count must be greater than one, and 185 * either the {@code x} or {@code y} data must have a non-zero population variance (i.e. {@code 186 * xStats().populationVariance() > 0.0 || yStats().populationVariance() > 0.0}). The result is 187 * guaranteed to be horizontal if there is variance in the {@code x} data but not the {@code y} 188 * data, and vertical if there is variance in the {@code y} data but not the {@code x} data. 189 * 190 * <p>This fit minimizes the root-mean-square error in {@code y} as a function of {@code x}. This 191 * error is defined as the square root of the mean of the squares of the differences between the 192 * actual {@code y} values of the data and the values predicted by the fit for the {@code x} 193 * values (i.e. it is the square root of the mean of the squares of the vertical distances between 194 * the data points and the best fit line). For this fit, this error is a fraction {@code sqrt(1 - 195 * R*R)} of the population standard deviation of {@code y}, where {@code R} is the Pearson's 196 * correlation coefficient (as given by {@link #pearsonsCorrelationCoefficient()}). 197 * 198 * <p>The corresponding root-mean-square error in {@code x} as a function of {@code y} is a 199 * fraction {@code sqrt(1/(R*R) - 1)} of the population standard deviation of {@code x}. This fit 200 * does not normally minimize that error: to do that, you should swap the roles of {@code x} and 201 * {@code y}. 202 * 203 * <h3>Non-finite values</h3> 204 * 205 * <p>If the dataset contains any non-finite values ({@link Double#POSITIVE_INFINITY}, {@link 206 * Double#NEGATIVE_INFINITY}, or {@link Double#NaN}) then the result is {@link 207 * LinearTransformation#forNaN()}. 208 * 209 * @throws IllegalStateException if the dataset is empty or contains a single pair of values, or 210 * both the {@code x} and {@code y} dataset have zero population variance 211 */ leastSquaresFit()212 public final LinearTransformation leastSquaresFit() { 213 checkState(count() > 1); 214 if (isNaN(sumOfProductsOfDeltas)) { 215 return LinearTransformation.forNaN(); 216 } 217 double xSumOfSquaresOfDeltas = xStats.sumOfSquaresOfDeltas(); 218 if (xSumOfSquaresOfDeltas > 0.0) { 219 if (yStats.sumOfSquaresOfDeltas() > 0.0) { 220 return LinearTransformation.mapping(xStats.mean(), yStats.mean()) 221 .withSlope(sumOfProductsOfDeltas / xSumOfSquaresOfDeltas); 222 } else { 223 return LinearTransformation.horizontal(yStats.mean()); 224 } 225 } else { 226 checkState(yStats.sumOfSquaresOfDeltas() > 0.0); 227 return LinearTransformation.vertical(xStats.mean()); 228 } 229 } 230 ensurePositive(double value)231 private double ensurePositive(double value) { 232 if (value > 0.0) { 233 return value; 234 } else { 235 return Double.MIN_VALUE; 236 } 237 } 238 ensureInUnitRange(double value)239 private static double ensureInUnitRange(double value) { 240 return Doubles.constrainToRange(value, -1.0, 1.0); 241 } 242 } 243