1 /*
2  * Copyright (C) 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.adservices.service.measurement.noising;
18 
19 import com.android.adservices.service.measurement.PrivacyParams;
20 import com.android.internal.annotations.VisibleForTesting;
21 
22 import com.google.common.math.DoubleMath;
23 import com.google.common.math.LongMath;
24 
25 import java.math.BigInteger;
26 import java.util.ArrayList;
27 import java.util.Arrays;
28 import java.util.List;
29 import java.util.Objects;
30 import java.util.function.LongToDoubleFunction;
31 
32 /**
33  * Combinatorics utilities used for randomization.
34  */
35 public class Combinatorics {
36 
37     /**
38      * Returns the k-combination associated with the number {@code combinationIndex}. In
39      * other words, returns the combination of {@code k} integers uniquely indexed by
40      * {@code combinationIndex} in the combinatorial number system.
41      * https://en.wikipedia.org/wiki/Combinatorial_number_system
42      *
43      * @return combinationIndex-th lexicographically smallest k-combination.
44      * @throws ArithmeticException in case of int overflow
45      */
getKCombinationAtIndex(long combinationIndex, int k)46     static long[] getKCombinationAtIndex(long combinationIndex, int k) {
47         // Computes the combinationIndex-th lexicographically smallest k-combination.
48         // https://en.wikipedia.org/wiki/Combinatorial_number_system
49         //
50         // A k-combination is a sequence of k non-negative integers in decreasing order.
51         // a_k > a_{k-1} > ... > a_2 > a_1 >= 0.
52         // k-combinations can be ordered lexicographically, with the smallest
53         // k-combination being a_k=k-1, a_{k-1}=k-2, .., a_1=0. Given an index
54         // combinationIndex>=0, and an order k, this method returns the
55         // combinationIndex-th smallest k-combination.
56         //
57         // Given an index combinationIndex, the combinationIndex-th k-combination
58         // is the unique set of k non-negative integers
59         // a_k > a_{k-1} > ... > a_2 > a_1 >= 0
60         // such that combinationIndex = \sum_{i=1}^k {a_i}\choose{i}
61         //
62         // We find this set via a simple greedy algorithm.
63         // http://math0.wvstateu.edu/~baker/cs405/code/Combinadics.html
64         long[] result = new long[k];
65         if (k == 0) {
66             return result;
67         }
68         // To find a_k, iterate candidates upwards from 0 until we've found the
69         // maximum a such that (a choose k) <= combinationIndex. Let a_k = a. Use
70         // the previous binomial coefficient to compute the next one. Note: possible
71         // to speed this up via something other than incremental search.
72         long target = combinationIndex;
73         long candidate = (long) k - 1L;
74         long binomialCoefficient = 0L;
75         long nextBinomialCoefficient = 1L;
76         while (nextBinomialCoefficient <= target) {
77             candidate++;
78             binomialCoefficient = nextBinomialCoefficient;
79             // (n + 1 choose k) = (n choose k) * (n + 1) / (n + 1 - k)
80             nextBinomialCoefficient = Math.multiplyExact(binomialCoefficient, (candidate + 1));
81             nextBinomialCoefficient /= candidate + 1 - k;
82         }
83         // We know from the k-combination definition, all subsequent values will be
84         // strictly decreasing. Find them all by decrementing candidate.
85         // Use the previous binomial coefficient to compute the next one.
86         long currentK = (long) k;
87         int currentIndex = 0;
88         while (true) {
89             if (binomialCoefficient <= target) {
90                 result[currentIndex] = candidate;
91                 currentIndex++;
92                 target -= binomialCoefficient;
93                 if (currentIndex == k) {
94                     return result;
95                 }
96                 // (n - 1 choose k - 1) = (n choose k) * k / n
97                 binomialCoefficient = binomialCoefficient * currentK / candidate;
98                 currentK--;
99             } else {
100                 // (n - 1 choose k) = (n choose k) * (n - k) / n
101                 binomialCoefficient = binomialCoefficient * (candidate - currentK) / candidate;
102             }
103             candidate--;
104         }
105     }
106 
107     /**
108      * Returns the number of possible sequences of "stars and bars" sequences
109      * https://en.wikipedia.org/wiki/Stars_and_bars_(combinatorics),
110      * which is equivalent to (numStars + numBars choose numStars).
111      *
112      * @param numStars number of stars
113      * @param numBars  number of bars
114      * @return number of possible sequences
115      */
getNumberOfStarsAndBarsSequences(int numStars, int numBars)116     public static long getNumberOfStarsAndBarsSequences(int numStars, int numBars) {
117         // Note, LongMath::binomial returns Long.MAX_VALUE rather than overflow.
118         return LongMath.binomial(numStars + numBars, numStars);
119     }
120 
121     /**
122      * Returns an array of the indices of every star in the stars and bars sequence indexed by
123      * {@code sequenceIndex}.
124      *
125      * @param numStars number of stars in the sequence
126      * @param sequenceIndex index of the sequence
127      * @return list of indices of every star in stars & bars sequence
128      */
getStarIndices(int numStars, long sequenceIndex)129     public static long[] getStarIndices(int numStars, long sequenceIndex) {
130         return getKCombinationAtIndex(sequenceIndex, numStars);
131     }
132 
133     /**
134      * From an array with the index of every star in a stars and bars sequence, returns an array
135      * which, for every star, counts the number of bars preceding it.
136      *
137      * @param starIndices indices of the stars in descending order
138      * @return count of bars preceding every star
139      */
getBarsPrecedingEachStar(long[] starIndices)140     public static long[] getBarsPrecedingEachStar(long[] starIndices) {
141         for (int i = 0; i < starIndices.length; i++) {
142             long starIndex = starIndices[i];
143             // There are {@code starIndex} prior positions in the sequence, and `i` prior
144             // stars, so there are {@code starIndex - i} prior bars.
145             starIndices[i] = starIndex - ((long) starIndices.length - 1L - (long) i);
146         }
147         return starIndices;
148     }
149 
150     /**
151      * Compute number of states from the trigger specification
152      *
153      * @param numBucketIncrements number of bucket increments (equivalent to number of triggers)
154      * @param numTriggerData number of trigger data. (equivalent to number of metadata)
155      * @param numWindows number of reporting windows
156      * @return number of states
157      */
getNumStatesArithmetic( int numBucketIncrements, int numTriggerData, int numWindows)158     public static long getNumStatesArithmetic(
159             int numBucketIncrements, int numTriggerData, int numWindows) {
160         int numStars = numBucketIncrements;
161         int numBars = Math.multiplyExact(numTriggerData, numWindows);
162         return getNumberOfStarsAndBarsSequences(numStars, numBars);
163     }
164 
165     /**
166      * Using dynamic programming to compute number of states. Returns Long.MAX_VALUE if the result
167      * is greater than {@code bound}.
168      *
169      * @param totalCap total incremental cap
170      * @param perTypeNumWindowList reporting window per trigger data
171      * @param perTypeCapList cap per trigger data
172      * @param bound the highest state count allowed
173      * @return number of states
174      * @throws ArithmeticException in case of long overflow
175      */
getNumStatesIterative( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound)176     private static long getNumStatesIterative(
177             int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound) {
178         // Assumes perTypeCapList cannot sum to more than int value. Overflowing int here can lead
179         // to an exception when declaring the array size later, based on the min value.
180         int sum = 0;
181         for (int cap : perTypeCapList) {
182             sum += cap;
183         }
184         int leastTotalCap = Math.min(totalCap, sum);
185         long[][] dp = new long[2][leastTotalCap + 1];
186         int prev = 0;
187         int curr = 1;
188 
189         dp[prev][0] = 1L;
190         long result = 0L;
191 
192         for (int i = 0; i < perTypeNumWindowList.length && perTypeNumWindowList[i] > 0; i++) {
193             int winCount = perTypeNumWindowList[i];
194             int capCount = perTypeCapList[i];
195             result = 0L;
196 
197             for (int cap = 0; cap < leastTotalCap + 1; cap++) {
198                 dp[curr][cap] = 0L;
199 
200                 for (int capVal = 0; capVal < Math.min(cap, capCount) + 1; capVal++) {
201                     dp[curr][cap] = Math.addExact(
202                             dp[curr][cap],
203                             Math.multiplyExact(
204                                     dp[prev][cap - capVal],
205                                     getNumberOfStarsAndBarsSequences(capVal, winCount - 1)));
206                 }
207 
208                 result = Math.addExact(result, dp[curr][cap]);
209 
210                 if (result > bound) {
211                     return Long.MAX_VALUE;
212                 }
213             }
214 
215             curr ^= 1;
216             prev ^= 1;
217         }
218 
219         return Math.max(result, 1L);
220     }
221 
222     /**
223      * Compute number of states for flexible event report API. Returns Long.MAX_VALUE if the result
224      * exceeds {@code bound}.
225      *
226      * @param totalCap number of total increments
227      * @param perTypeNumWindowList reporting window for each trigger data
228      * @param perTypeCapList limit of the increment of each trigger data
229      * @param bound the highest state count allowed
230      * @return number of states
231      * @throws ArithmeticException in case of long overflow during the iterative procedure
232      */
getNumStatesFlexApi( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound)233     public static long getNumStatesFlexApi(
234             int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long bound) {
235         if (perTypeNumWindowList.length == 0 || perTypeCapList.length == 0) {
236             return 1;
237         }
238         for (int i = 1; i < perTypeNumWindowList.length; i++) {
239             if (perTypeNumWindowList[i] != perTypeNumWindowList[i - 1]) {
240                 return getNumStatesIterative(totalCap, perTypeNumWindowList, perTypeCapList, bound);
241             }
242         }
243         for (int n : perTypeCapList) {
244             if (n < totalCap) {
245                 return getNumStatesIterative(totalCap, perTypeNumWindowList, perTypeCapList, bound);
246             }
247         }
248 
249         long result = getNumStatesArithmetic(
250                 totalCap, perTypeCapList.length, perTypeNumWindowList[0]);
251 
252         return result > bound ? Long.MAX_VALUE : result;
253     }
254 
255     /**
256      * @param numOfStates Number of States
257      * @return the probability to use fake reports
258      */
getFlipProbability(long numOfStates, double privacyEpsilon)259     public static double getFlipProbability(long numOfStates, double privacyEpsilon) {
260         return (numOfStates) / (numOfStates + Math.exp(privacyEpsilon) - 1D);
261     }
262 
getBinaryEntropy(double x)263     private static double getBinaryEntropy(double x) {
264         if (DoubleMath.fuzzyEquals(x, 0.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD)
265                 || DoubleMath.fuzzyEquals(x, 1.0d, PrivacyParams.NUMBER_EQUAL_THRESHOLD)) {
266             return 0.0D;
267         }
268         return (-1.0D) * x * DoubleMath.log2(x) - (1 - x) * DoubleMath.log2(1 - x);
269     }
270 
271     /**
272      * @param numOfStates Number of States
273      * @param flipProbability Flip Probability
274      * @return the information gain
275      */
getInformationGain(long numOfStates, double flipProbability)276     public static double getInformationGain(long numOfStates, double flipProbability) {
277         if (numOfStates <= 1L) {
278             return 0d;
279         }
280         double log2Q = DoubleMath.log2(numOfStates);
281         double fakeProbability = flipProbability * (numOfStates - 1L) / numOfStates;
282         return log2Q
283                 - getBinaryEntropy(fakeProbability)
284                 - fakeProbability * DoubleMath.log2(numOfStates - 1);
285     }
286 
getFakeProbability( long numOfStates, long numUsedScopes, long numEventStates, double privacyEpsilon)287     private static double getFakeProbability(
288             long numOfStates, long numUsedScopes, long numEventStates, double privacyEpsilon) {
289         double pickRateForSource =
290                 getFlipProbability(numOfStates, privacyEpsilon) * (numOfStates - 1L) / numOfStates;
291         double pickRateForEvent =
292                 getFlipProbability(numEventStates, privacyEpsilon)
293                         * (numEventStates - 1L)
294                         / numEventStates;
295         return 1 - (1 - pickRateForSource) * Math.pow(1 - pickRateForEvent, (double) numUsedScopes);
296     }
297 
298     @VisibleForTesting
calculateInformationGainWithAttributionScope( long numOfStates, long numUsedScopes, long numEventStates, double privacyEpsilon)299     static double calculateInformationGainWithAttributionScope(
300             long numOfStates, long numUsedScopes, long numEventStates, double privacyEpsilon) {
301         BigInteger totalNumStates =
302                 BigInteger.valueOf(numOfStates)
303                         .add(
304                                 BigInteger.valueOf(numEventStates)
305                                         .multiply(BigInteger.valueOf(numUsedScopes)));
306         if (totalNumStates.compareTo(BigInteger.ONE) <= 0) {
307             return 0d;
308         }
309         double log2Q = DoubleMath.log2(totalNumStates.doubleValue());
310         double fakeProbability =
311                 getFakeProbability(numOfStates, numUsedScopes, numEventStates, privacyEpsilon);
312         return log2Q
313                 - getBinaryEntropy(fakeProbability)
314                 - fakeProbability * DoubleMath.log2(totalNumStates.doubleValue() - 1);
315     }
316 
317     /**
318      * Returns the max information gain given the num of trigger states, attribution scope limit and
319      * max num event states.
320      *
321      * @param numTriggerStates The number of trigger states.
322      * @param attributionScopeLimit The attribution scope limit.
323      * @param maxEventStates The maximum number of event states (expected to be positive).
324      * @return The max information gain.
325      */
getMaxInformationGainWithAttributionScope( long numTriggerStates, long attributionScopeLimit, long maxEventStates, double privacyEpsilon)326     public static double getMaxInformationGainWithAttributionScope(
327             long numTriggerStates,
328             long attributionScopeLimit,
329             long maxEventStates,
330             double privacyEpsilon) {
331         if (numTriggerStates <= 0 || maxEventStates <= 0) {
332             throw new IllegalArgumentException(
333                     "numTriggerStates and maxEventStates must be positive");
334         }
335         double maxInformationGain = 0;
336         // Choosing the smaller dimension for iteration.
337         if (attributionScopeLimit > maxEventStates) {
338             long start = 0;
339             long end = attributionScopeLimit - 1;
340             for (long numEventStates = 1; numEventStates <= maxEventStates; ++numEventStates) {
341                 final long currentNumEventStates =
342                         numEventStates; // Make a final copy of the variable
343                 LongToDoubleFunction infoGainFunction =
344                         (numUsedScopes) ->
345                                 calculateInformationGainWithAttributionScope(
346                                         numTriggerStates,
347                                         numUsedScopes,
348                                         currentNumEventStates,
349                                         privacyEpsilon);
350                 maxInformationGain =
351                         Math.max(
352                                 maxInformationGain,
353                                 findMaxValueUniModal(start, end, infoGainFunction));
354             }
355         } else {
356             long start = 1;
357             long end = maxEventStates;
358             for (long numUsedScopes = 0; numUsedScopes < attributionScopeLimit; ++numUsedScopes) {
359                 final long currentNumUsedScopes =
360                         numUsedScopes; // Make a final copy of the variable
361                 LongToDoubleFunction infoGainFunction =
362                         (numEventStates) ->
363                                 calculateInformationGainWithAttributionScope(
364                                         numTriggerStates,
365                                         currentNumUsedScopes,
366                                         numEventStates,
367                                         privacyEpsilon);
368                 maxInformationGain =
369                         Math.max(
370                                 maxInformationGain,
371                                 findMaxValueUniModal(start, end, infoGainFunction));
372             }
373         }
374         return maxInformationGain;
375     }
376 
377     /**
378      * Generate fake report set given a trigger specification and the rank order number
379      *
380      * @param totalCap total_cap
381      * @param perTypeNumWindowList per type number of window list
382      * @param perTypeCapList per type cap list
383      * @param rank the rank of the report state within all the report states
384      * @return a report set based on the input rank
385      */
getReportSetBasedOnRank( int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long rank)386     public static List<AtomReportState> getReportSetBasedOnRank(
387             int totalCap, int[] perTypeNumWindowList, int[] perTypeCapList, long rank) {
388         int triggerTypeIndex = perTypeNumWindowList.length - 1;
389         return getReportSetBasedOnRankRecursive(
390                 totalCap,
391                 triggerTypeIndex,
392                 perTypeNumWindowList[triggerTypeIndex],
393                 perTypeCapList[triggerTypeIndex],
394                 rank,
395                 perTypeNumWindowList,
396                 perTypeCapList);
397     }
398 
399     // Function to find the maximum value of a function f(x) where f(x) satisfies the following
400     // condition: for some value m, it is strictly increasing for x ≤ m and strictly decreasing
401     // for x ≥ m.
findMaxValueUniModal(long start, long end, LongToDoubleFunction f)402     private static double findMaxValueUniModal(long start, long end, LongToDoubleFunction f) {
403         long left = start;
404         long right = end;
405 
406         while (left < right) {
407             long mid = left + (right - left) / 2;
408 
409             // Calculate f(mid) and f(mid + 1).
410             double fMid = f.applyAsDouble(mid);
411             double fMidPlus1 = f.applyAsDouble(mid + 1);
412 
413             // If f(mid) < f(mid + 1), then the maximum value is to the right of mid.
414             if (fMid < fMidPlus1) {
415                 left = mid + 1;
416             } else {
417                 // If f(mid) >= f(mid + 1), then the maximum value is to the left of or at mid.
418                 // In cases where f(mid) = f(mid + 1) due to precision loss of double
419                 // and info gain is effectively 0, which also means we've passed the peak so
420                 // continue searching left.
421                 right = mid;
422             }
423         }
424         // At the end of the loop, left and right will converge to the maximum value of f(x).
425         return f.applyAsDouble(left);
426     }
427 
getReportSetBasedOnRankRecursive( int totalCap, int triggerTypeIndex, int winVal, int capVal, long rank, int[] perTypeNumWindowList, int[] perTypeCapList)428     private static List<AtomReportState> getReportSetBasedOnRankRecursive(
429             int totalCap,
430             int triggerTypeIndex,
431             int winVal,
432             int capVal,
433             long rank,
434             int[] perTypeNumWindowList,
435             int[] perTypeCapList) {
436 
437         if (winVal == 0 && triggerTypeIndex == 0) {
438             return new ArrayList<>();
439         } else if (winVal == 0) {
440             return getReportSetBasedOnRankRecursive(
441                     totalCap,
442                     triggerTypeIndex - 1,
443                     perTypeNumWindowList[triggerTypeIndex - 1],
444                     perTypeCapList[triggerTypeIndex - 1],
445                     rank,
446                     perTypeNumWindowList,
447                     perTypeCapList);
448         }
449         for (int i = 0; i <= Math.min(totalCap, capVal); i++) {
450             int[] perTypeNumWindowListClone = Arrays.copyOfRange(
451                     perTypeNumWindowList, 0, triggerTypeIndex + 1);
452             perTypeNumWindowListClone[triggerTypeIndex] = winVal - 1;
453             int[] perTypeCapListClone = Arrays.copyOfRange(
454                     perTypeCapList, 0, triggerTypeIndex + 1);
455             perTypeCapListClone[triggerTypeIndex] = capVal - i;
456             long currentNumStates =
457                     getNumStatesIterative(
458                             totalCap - i,
459                             perTypeNumWindowListClone,
460                             perTypeCapListClone,
461                             Long.MAX_VALUE);
462             if (currentNumStates > rank) {
463                 // The triggers to be appended.
464                 List<AtomReportState> toAppend = new ArrayList<>();
465                 for (int k = 0; k < i; k++) {
466                     toAppend.add(new AtomReportState(triggerTypeIndex, winVal - 1));
467                 }
468                 List<AtomReportState> otherReports =
469                         getReportSetBasedOnRankRecursive(
470                                 totalCap - i,
471                                 triggerTypeIndex,
472                                 winVal - 1,
473                                 capVal - i,
474                                 rank,
475                                 perTypeNumWindowList,
476                                 perTypeCapList);
477                 toAppend.addAll(otherReports);
478                 return toAppend;
479             } else {
480                 rank -= currentNumStates;
481             }
482         }
483         // will not reach here
484         return new ArrayList<>();
485     }
486 
487     /** A single report including triggerDataType and window index for the fake report generation */
488     public static class AtomReportState {
489         private final int mTriggerDataType;
490         private final int mWindowIndex;
491 
AtomReportState(int triggerDataType, int windowIndex)492         public AtomReportState(int triggerDataType, int windowIndex) {
493             this.mTriggerDataType = triggerDataType;
494             this.mWindowIndex = windowIndex;
495         }
496 
getTriggerDataType()497         public int getTriggerDataType() {
498             return mTriggerDataType;
499         }
500         ;
501 
getWindowIndex()502         public final int getWindowIndex() {
503             return mWindowIndex;
504         }
505         ;
506 
507         @Override
equals(Object obj)508         public boolean equals(Object obj) {
509             if (!(obj instanceof AtomReportState)) {
510                 return false;
511             }
512             AtomReportState t = (AtomReportState) obj;
513             return mTriggerDataType == t.mTriggerDataType && mWindowIndex == t.mWindowIndex;
514         }
515 
516         @Override
hashCode()517         public int hashCode() {
518             return Objects.hash(mWindowIndex, mTriggerDataType);
519         }
520     }
521 }
522