1 /* 2 * Copyright (C) 2022 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package com.android.adservices.service.measurement.noising; 18 19 import com.android.adservices.service.measurement.PrivacyParams; 20 import com.android.internal.annotations.VisibleForTesting; 21 22 import com.google.common.math.DoubleMath; 23 import com.google.common.math.LongMath; 24 25 import java.math.BigInteger; 26 import java.util.ArrayList; 27 import java.util.Arrays; 28 import java.util.List; 29 import java.util.Objects; 30 import java.util.function.LongToDoubleFunction; 31 32 /** 33 * Combinatorics utilities used for randomization. 34 */ 35 public class Combinatorics { 36 37 /** 38 * Returns the k-combination associated with the number {@code combinationIndex}. In 39 * other words, returns the combination of {@code k} integers uniquely indexed by 40 * {@code combinationIndex} in the combinatorial number system. 41 * https://en.wikipedia.org/wiki/Combinatorial_number_system 42 * 43 * @return combinationIndex-th lexicographically smallest k-combination. 44 * @throws ArithmeticException in case of int overflow 45 */ getKCombinationAtIndex(long combinationIndex, int k)46 static long[] getKCombinationAtIndex(long combinationIndex, int k) { 47 // Computes the combinationIndex-th lexicographically smallest k-combination. 48 // https://en.wikipedia.org/wiki/Combinatorial_number_system 49 // 50 // A k-combination is a sequence of k non-negative integers in decreasing order. 51 // a_k > a_{k-1} > ... > a_2 > a_1 >= 0. 52 // k-combinations can be ordered lexicographically, with the smallest 53 // k-combination being a_k=k-1, a_{k-1}=k-2, .., a_1=0. Given an index 54 // combinationIndex>=0, and an order k, this method returns the 55 // combinationIndex-th smallest k-combination. 56 // 57 // Given an index combinationIndex, the combinationIndex-th k-combination 58 // is the unique set of k non-negative integers 59 // a_k > a_{k-1} > ... > a_2 > a_1 >= 0 60 // such that combinationIndex = \sum_{i=1}^k {a_i}\choose{i} 61 // 62 // We find this set via a simple greedy algorithm. 63 // http://math0.wvstateu.edu/~baker/cs405/code/Combinadics.html 64 long[] result = new long[k]; 65 if (k == 0) { 66 return result; 67 } 68 // To find a_k, iterate candidates upwards from 0 until we've found the 69 // maximum a such that (a choose k) <= combinationIndex. Let a_k = a. Use 70 // the previous binomial coefficient to compute the next one. Note: possible 71 // to speed this up via something other than incremental search. 72 long target = combinationIndex; 73 long candidate = (long) k - 1L; 74 long binomialCoefficient = 0L; 75 long nextBinomialCoefficient = 1L; 76 while (nextBinomialCoefficient <= target) { 77 candidate++; 78 binomialCoefficient = nextBinomialCoefficient; 79 // (n + 1 choose k) = (n choose k) * (n + 1) / (n + 1 - k) 80 nextBinomialCoefficient = Math.multiplyExact(binomialCoefficient, (candidate + 1)); 81 nextBinomialCoefficient /= candidate + 1 - k; 82 } 83 // We know from the k-combination definition, all subsequent values will be 84 // strictly decreasing. Find them all by decrementing candidate. 85 // Use the previous binomial coefficient to compute the next one. 86 long currentK = (long) k; 87 int currentIndex = 0; 88 while (true) { 89 if (binomialCoefficient <= target) { 90 result[currentIndex] = candidate; 91 currentIndex++; 92 target -= binomialCoefficient; 93 if (currentIndex == k) { 94 return result; 95 } 96 // (n - 1 choose k - 1) = (n choose k) * k / n 97 binomialCoefficient = binomialCoefficient * currentK / candidate; 98 currentK--; 99 } else { 100 // (n - 1 choose k) = (n choose k) * (n - k) / n 101 binomialCoefficient = binomialCoefficient * (candidate - currentK) / candidate; 102 } 103 candidate--; 104 } 105 } 106 107 /** 108 * Returns the number of possible sequences of "stars and bars" sequences 109 * https://en.wikipedia.org/wiki/Stars_and_bars_(combinatorics), 110 * which is equivalent to (numStars + numBars choose numStars). 111 * 112 * @param numStars number of stars 113 * @param numBars number of bars 114 * @return number of possible sequences 115 */ getNumberOfStarsAndBarsSequences(int numStars, int numBars)116 public static long getNumberOfStarsAndBarsSequences(int numStars, int numBars) { 117 // Note, LongMath::binomial returns Long.MAX_VALUE rather than overflow. 118 return LongMath.binomial(numStars + numBars, numStars); 119 } 120 121 /** 122 * Returns an array of the indices of every star in the stars and bars sequence indexed by 123 * {@code sequenceIndex}. 124 * 125 * @param numStars number of stars in the sequence 126 * @param sequenceIndex index of the sequence 127 * @return list of indices of every star in stars & bars sequence 128 */ getStarIndices(int numStars, long sequenceIndex)129 public static long[] getStarIndices(int numStars, long sequenceIndex) { 130 return getKCombinationAtIndex(sequenceIndex, numStars); 131 } 132 133 /** 134 * From an array with the index of every star in a stars and bars sequence, returns an array 135 * which, for every star, counts the number of bars preceding it. 136 * 137 * @param starIndices indices of the stars in descending order 138 * @return count of bars preceding every star 139 */ getBarsPrecedingEachStar(long[] starIndices)140 public static long[] getBarsPrecedingEachStar(long[] starIndices) { 141 for (int i = 0; i < starIndices.length; i++) { 142 long starIndex = starIndices[i]; 143 // There are {@code starIndex} prior positions in the sequence, and `i` prior 144 // stars, so there are {@code starIndex - i} prior bars. 145 starIndices[i] = starIndex - ((long) starIndices.length - 1L - (long) i); 146 } 147 return starIndices; 148 } 149 150 /** 151 * Compute number of states from the trigger specification 152 * 153 * @param numBucketIncrements number of bucket increments (equivalent to number of triggers) 154 * @param numTriggerData number of trigger data. (equivalent to number of metadata) 155 * @param numWindows number of reporting windows 156 * @return number of states 157 */ getNumStatesArithmetic( int numBucketIncrements, int numTriggerData, int numWindows)158 public static long getNumStatesArithmetic( 159 int numBucketIncrements, int numTriggerData, int numWindows) { 160 int numStars = numBucketIncrements; 161 int numBars = Math.multiplyExact(numTriggerData, numWindows); 162 return getNumberOfStarsAndBarsSequences(numStars, numBars); 163 } 164 165 /** 166 * Using dynamic programming to compute number of states. Returns Long.MAX_VALUE if the result 167 * is greater than {@code bound}. 168 * 169 * @param totalCap total incremental cap 170 * @param perTypeNumWindowList reporting window per trigger data 171 * @param perTypeCapList cap per trigger data 172 * @param bound the highest state count allowed 173 * @return number of states 174 * @throws ArithmeticException in case of long overflow 175 */ getNumStatesIterative( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound)176 private static long getNumStatesIterative( 177 int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound) { 178 // Assumes perTypeCapList cannot sum to more than int value. Overflowing int here can lead 179 // to an exception when declaring the array size later, based on the min value. 180 int sum = 0; 181 for (int cap : perTypeCapList) { 182 sum += cap; 183 } 184 int leastTotalCap = Math.min(totalCap, sum); 185 long[][] dp = new long[2][leastTotalCap + 1]; 186 int prev = 0; 187 int curr = 1; 188 189 dp[prev][0] = 1L; 190 long result = 0L; 191 192 for (int i = 0; i < perTypeNumWindowList.length && perTypeNumWindowList[i] > 0; i++) { 193 int winCount = perTypeNumWindowList[i]; 194 int capCount = perTypeCapList[i]; 195 result = 0L; 196 197 for (int cap = 0; cap < leastTotalCap + 1; cap++) { 198 dp[curr][cap] = 0L; 199 200 for (int capVal = 0; capVal < Math.min(cap, capCount) + 1; capVal++) { 201 dp[curr][cap] = Math.addExact( 202 dp[curr][cap], 203 Math.multiplyExact( 204 dp[prev][cap - capVal], 205 getNumberOfStarsAndBarsSequences(capVal, winCount - 1))); 206 } 207 208 result = Math.addExact(result, dp[curr][cap]); 209 210 if (result > bound) { 211 return Long.MAX_VALUE; 212 } 213 } 214 215 curr ^= 1; 216 prev ^= 1; 217 } 218 219 return Math.max(result, 1L); 220 } 221 222 /** 223 * Compute number of states for flexible event report API. Returns Long.MAX_VALUE if the result 224 * exceeds {@code bound}. 225 * 226 * @param totalCap number of total increments 227 * @param perTypeNumWindowList reporting window for each trigger data 228 * @param perTypeCapList limit of the increment of each trigger data 229 * @param bound the highest state count allowed 230 * @return number of states 231 * @throws ArithmeticException in case of long overflow during the iterative procedure 232 */ getNumStatesFlexApi( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound)233 public static long getNumStatesFlexApi( 234 int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound) { 235 if (perTypeNumWindowList.length == 0 || perTypeCapList.length == 0) { 236 return 1; 237 } 238 for (int i = 1; i < perTypeNumWindowList.length; i++) { 239 if (perTypeNumWindowList[i] != perTypeNumWindowList[i - 1]) { 240 return getNumStatesIterative(totalCap, perTypeNumWindowList, perTypeCapList, bound); 241 } 242 } 243 for (int n : perTypeCapList) { 244 if (n < totalCap) { 245 return getNumStatesIterative(totalCap, perTypeNumWindowList, perTypeCapList, bound); 246 } 247 } 248 249 long result = getNumStatesArithmetic( 250 totalCap, perTypeCapList.length, perTypeNumWindowList[0]); 251 252 return result > bound ? Long.MAX_VALUE : result; 253 } 254 255 /** 256 * @param numOfStates Number of States 257 * @return the probability to use fake reports 258 */ getFlipProbability(long numOfStates, double privacyEpsilon)259 public static double getFlipProbability(long numOfStates, double privacyEpsilon) { 260 return (numOfStates) / (numOfStates + Math.exp(privacyEpsilon) - 1D); 261 } 262 getBinaryEntropy(double x)263 private static double getBinaryEntropy(double x) { 264 if (DoubleMath.fuzzyEquals(x, 0.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD) 265 || DoubleMath.fuzzyEquals(x, 1.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD)) { 266 return 0.0D; 267 } 268 return (-1.0D) * x * DoubleMath.log2(x) - (1 - x) * DoubleMath.log2(1 - x); 269 } 270 271 /** 272 * @param numOfStates Number of States 273 * @param flipProbability Flip Probability 274 * @return the information gain 275 */ getInformationGain(long numOfStates, double flipProbability)276 public static double getInformationGain(long numOfStates, double flipProbability) { 277 if (numOfStates <= 1L) { 278 return 0d; 279 } 280 double log2Q = DoubleMath.log2(numOfStates); 281 double fakeProbability = flipProbability * (numOfStates - 1L) / numOfStates; 282 return log2Q 283 - getBinaryEntropy(fakeProbability) 284 - fakeProbability * DoubleMath.log2(numOfStates - 1); 285 } 286 getFakeProbability( long numOfStates, long numUsedScopes, long numEventStates, double privacyEpsilon)287 private static double getFakeProbability( 288 long numOfStates, long numUsedScopes, long numEventStates, double privacyEpsilon) { 289 double pickRateForSource = 290 getFlipProbability(numOfStates, privacyEpsilon) * (numOfStates - 1L) / numOfStates; 291 double pickRateForEvent = 292 getFlipProbability(numEventStates, privacyEpsilon) 293 * (numEventStates - 1L) 294 / numEventStates; 295 return 1 - (1 - pickRateForSource) * Math.pow(1 - pickRateForEvent, (double) numUsedScopes); 296 } 297 298 @VisibleForTesting calculateInformationGainWithAttributionScope( long numOfStates, long numUsedScopes, long numEventStates, double privacyEpsilon)299 static double calculateInformationGainWithAttributionScope( 300 long numOfStates, long numUsedScopes, long numEventStates, double privacyEpsilon) { 301 BigInteger totalNumStates = 302 BigInteger.valueOf(numOfStates) 303 .add( 304 BigInteger.valueOf(numEventStates) 305 .multiply(BigInteger.valueOf(numUsedScopes))); 306 if (totalNumStates.compareTo(BigInteger.ONE) <= 0) { 307 return 0d; 308 } 309 double log2Q = DoubleMath.log2(totalNumStates.doubleValue()); 310 double fakeProbability = 311 getFakeProbability(numOfStates, numUsedScopes, numEventStates, privacyEpsilon); 312 return log2Q 313 - getBinaryEntropy(fakeProbability) 314 - fakeProbability * DoubleMath.log2(totalNumStates.doubleValue() - 1); 315 } 316 317 /** 318 * Returns the max information gain given the num of trigger states, attribution scope limit and 319 * max num event states. 320 * 321 * @param numTriggerStates The number of trigger states. 322 * @param attributionScopeLimit The attribution scope limit. 323 * @param maxEventStates The maximum number of event states (expected to be positive). 324 * @return The max information gain. 325 */ getMaxInformationGainWithAttributionScope( long numTriggerStates, long attributionScopeLimit, long maxEventStates, double privacyEpsilon)326 public static double getMaxInformationGainWithAttributionScope( 327 long numTriggerStates, 328 long attributionScopeLimit, 329 long maxEventStates, 330 double privacyEpsilon) { 331 if (numTriggerStates <= 0 || maxEventStates <= 0) { 332 throw new IllegalArgumentException( 333 "numTriggerStates and maxEventStates must be positive"); 334 } 335 double maxInformationGain = 0; 336 // Choosing the smaller dimension for iteration. 337 if (attributionScopeLimit > maxEventStates) { 338 long start = 0; 339 long end = attributionScopeLimit - 1; 340 for (long numEventStates = 1; numEventStates <= maxEventStates; ++numEventStates) { 341 final long currentNumEventStates = 342 numEventStates; // Make a final copy of the variable 343 LongToDoubleFunction infoGainFunction = 344 (numUsedScopes) -> 345 calculateInformationGainWithAttributionScope( 346 numTriggerStates, 347 numUsedScopes, 348 currentNumEventStates, 349 privacyEpsilon); 350 maxInformationGain = 351 Math.max( 352 maxInformationGain, 353 findMaxValueUniModal(start, end, infoGainFunction)); 354 } 355 } else { 356 long start = 1; 357 long end = maxEventStates; 358 for (long numUsedScopes = 0; numUsedScopes < attributionScopeLimit; ++numUsedScopes) { 359 final long currentNumUsedScopes = 360 numUsedScopes; // Make a final copy of the variable 361 LongToDoubleFunction infoGainFunction = 362 (numEventStates) -> 363 calculateInformationGainWithAttributionScope( 364 numTriggerStates, 365 currentNumUsedScopes, 366 numEventStates, 367 privacyEpsilon); 368 maxInformationGain = 369 Math.max( 370 maxInformationGain, 371 findMaxValueUniModal(start, end, infoGainFunction)); 372 } 373 } 374 return maxInformationGain; 375 } 376 377 /** 378 * Generate fake report set given a trigger specification and the rank order number 379 * 380 * @param totalCap total_cap 381 * @param perTypeNumWindowList per type number of window list 382 * @param perTypeCapList per type cap list 383 * @param rank the rank of the report state within all the report states 384 * @return a report set based on the input rank 385 */ getReportSetBasedOnRank( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long rank)386 public static List<AtomReportState> getReportSetBasedOnRank( 387 int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long rank) { 388 int triggerTypeIndex = perTypeNumWindowList.length - 1; 389 return getReportSetBasedOnRankRecursive( 390 totalCap, 391 triggerTypeIndex, 392 perTypeNumWindowList[triggerTypeIndex], 393 perTypeCapList[triggerTypeIndex], 394 rank, 395 perTypeNumWindowList, 396 perTypeCapList); 397 } 398 399 // Function to find the maximum value of a function f(x) where f(x) satisfies the following 400 // condition: for some value m, it is strictly increasing for x ≤ m and strictly decreasing 401 // for x ≥ m. findMaxValueUniModal(long start, long end, LongToDoubleFunction f)402 private static double findMaxValueUniModal(long start, long end, LongToDoubleFunction f) { 403 long left = start; 404 long right = end; 405 406 while (left < right) { 407 long mid = left + (right - left) / 2; 408 409 // Calculate f(mid) and f(mid + 1). 410 double fMid = f.applyAsDouble(mid); 411 double fMidPlus1 = f.applyAsDouble(mid + 1); 412 413 // If f(mid) < f(mid + 1), then the maximum value is to the right of mid. 414 if (fMid < fMidPlus1) { 415 left = mid + 1; 416 } else { 417 // If f(mid) >= f(mid + 1), then the maximum value is to the left of or at mid. 418 // In cases where f(mid) = f(mid + 1) due to precision loss of double 419 // and info gain is effectively 0, which also means we've passed the peak so 420 // continue searching left. 421 right = mid; 422 } 423 } 424 // At the end of the loop, left and right will converge to the maximum value of f(x). 425 return f.applyAsDouble(left); 426 } 427 getReportSetBasedOnRankRecursive( int totalCap, int triggerTypeIndex, int winVal, int capVal, long rank, int[] perTypeNumWindowList, int[] perTypeCapList)428 private static List<AtomReportState> getReportSetBasedOnRankRecursive( 429 int totalCap, 430 int triggerTypeIndex, 431 int winVal, 432 int capVal, 433 long rank, 434 int[] perTypeNumWindowList, 435 int[] perTypeCapList) { 436 437 if (winVal == 0 && triggerTypeIndex == 0) { 438 return new ArrayList<>(); 439 } else if (winVal == 0) { 440 return getReportSetBasedOnRankRecursive( 441 totalCap, 442 triggerTypeIndex - 1, 443 perTypeNumWindowList[triggerTypeIndex - 1], 444 perTypeCapList[triggerTypeIndex - 1], 445 rank, 446 perTypeNumWindowList, 447 perTypeCapList); 448 } 449 for (int i = 0; i <= Math.min(totalCap, capVal); i++) { 450 int[] perTypeNumWindowListClone = Arrays.copyOfRange( 451 perTypeNumWindowList, 0, triggerTypeIndex + 1); 452 perTypeNumWindowListClone[triggerTypeIndex] = winVal - 1; 453 int[] perTypeCapListClone = Arrays.copyOfRange( 454 perTypeCapList, 0, triggerTypeIndex + 1); 455 perTypeCapListClone[triggerTypeIndex] = capVal - i; 456 long currentNumStates = 457 getNumStatesIterative( 458 totalCap - i, 459 perTypeNumWindowListClone, 460 perTypeCapListClone, 461 Long.MAX_VALUE); 462 if (currentNumStates > rank) { 463 // The triggers to be appended. 464 List<AtomReportState> toAppend = new ArrayList<>(); 465 for (int k = 0; k < i; k++) { 466 toAppend.add(new AtomReportState(triggerTypeIndex, winVal - 1)); 467 } 468 List<AtomReportState> otherReports = 469 getReportSetBasedOnRankRecursive( 470 totalCap - i, 471 triggerTypeIndex, 472 winVal - 1, 473 capVal - i, 474 rank, 475 perTypeNumWindowList, 476 perTypeCapList); 477 toAppend.addAll(otherReports); 478 return toAppend; 479 } else { 480 rank -= currentNumStates; 481 } 482 } 483 // will not reach here 484 return new ArrayList<>(); 485 } 486 487 /** A single report including triggerDataType and window index for the fake report generation */ 488 public static class AtomReportState { 489 private final int mTriggerDataType; 490 private final int mWindowIndex; 491 AtomReportState(int triggerDataType, int windowIndex)492 public AtomReportState(int triggerDataType, int windowIndex) { 493 this.mTriggerDataType = triggerDataType; 494 this.mWindowIndex = windowIndex; 495 } 496 getTriggerDataType()497 public int getTriggerDataType() { 498 return mTriggerDataType; 499 } 500 ; 501 getWindowIndex()502 public final int getWindowIndex() { 503 return mWindowIndex; 504 } 505 ; 506 507 @Override equals(Object obj)508 public boolean equals(Object obj) { 509 if (!(obj instanceof AtomReportState)) { 510 return false; 511 } 512 AtomReportState t = (AtomReportState) obj; 513 return mTriggerDataType == t.mTriggerDataType && mWindowIndex == t.mWindowIndex; 514 } 515 516 @Override hashCode()517 public int hashCode() { 518 return Objects.hash(mWindowIndex, mTriggerDataType); 519 } 520 } 521 } 522