1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.adservices.service.common;
18 
19 import static java.lang.Math.abs;
20 
21 import com.android.internal.util.Preconditions;
22 
/**
 * Utility class to perform stochastic rounding.
 *
 * <p>All functionality is exposed through static methods; the class cannot be instantiated.
 */
public final class StochasticRoundingUtil {
    // Binary exponents outside the signed 8-bit range are clamped to zero / infinity.
    private static final int MIN_8_BIT_INT_VALUE = -128;
    private static final int MAX_8_BIT_INT_VALUE = 127;

    // A double occupies 64 bits, so asking for more bits than that is rejected.
    private static final int MAX_POSSIBLE_ROUNDING_BITS = 64;

    private static final String NUM_BITS_MIN_EXCEEDED =
            "Number of bits to round to must be greater than 0.";
    private static final String NUM_BITS_MAX_EXCEEDED =
            "Number of bits to round to must be less than the number of bits in a double.";

    private StochasticRoundingUtil() throws IllegalAccessException {
        throw new IllegalAccessException("This class cannot be instantiated!");
    }

    /**
     * Rounds a double stochastically to a given number of significant bits.
     *
     * <p>In stochastic rounding there are two candidates. If {@code value} cannot be represented
     * with {@code numBits} significant bits, it is rounded either down or up to one of the two
     * nearest numbers that can.
     *
     * <p>The direction is random: the probability of rounding up equals the fractional distance
     * of {@code value} from the lower candidate. The closer {@code value} is to the lower
     * candidate, the more likely it is rounded down, and vice versa, so the expected value of the
     * result equals {@code value}. Values that are already representable are returned unchanged.
     *
     * <p>The value is decomposed the way C's {@code frexp} does it: {@code value = m * 2^e} with
     * {@code |m|} in {@code [0.5, 1)}. If {@code e} is below -128 a zero with the sign of {@code
     * value} is returned; if {@code e} is above 127 an infinity with the sign of {@code value} is
     * returned.
     *
     * <p>This algorithm can be implemented in many ways, but we have based the implementation on
     * Chrome's example <a
     * href="https://source.chromium.org/chromium/chromium/src/+/main:content/browser/interest_group/interest_group_auction_reporter.cc;l=259;bpv=0;bpt=1">...</a>
     * to maintain consistency.
     *
     * @param value the number to round; NaN and infinities are returned as-is
     * @param numBits the number of significant bits to keep, in the range {@code [1, 64]}
     * @return {@code value} rounded to one of its two nearest {@code numBits}-bit neighbours
     * @throws IllegalArgumentException if {@code numBits} is not in the range {@code [1, 64]}
     */
    public static double roundStochastically(double value, int numBits) {
        if (numBits <= 0) {
            throw new IllegalArgumentException(NUM_BITS_MIN_EXCEEDED);
        }
        if (numBits > MAX_POSSIBLE_ROUNDING_BITS) {
            throw new IllegalArgumentException(NUM_BITS_MAX_EXCEEDED);
        }

        if (!Double.isFinite(value)) {
            return value;
        }

        // Math.getExponent() normalises to [1, 2); adding one gives the frexp() convention where
        // the mantissa lies in [0.5, 1). Zero and subnormals report -1023 and are clamped below.
        int exponent = Math.getExponent(value) + 1;

        if (exponent < MIN_8_BIT_INT_VALUE) {
            return Math.copySign(0.0, value);
        }

        if (exponent > MAX_8_BIT_INT_VALUE) {
            return Math.copySign(Double.POSITIVE_INFINITY, value);
        }

        // Scaling by a power of two is exact, unlike multiplying by Math.pow(2, n).
        double mantissa = Math.scalb(value, -exponent);

        // Shift so that exactly numBits significant bits sit in the integer part.
        double precisionScaledValue = Math.scalb(mantissa, numBits);

        // Round up with probability equal to the fractional part. Comparing instead of adding
        // noise keeps already-representable values (fraction == 0) untouched for every numBits.
        double lowerScaledValue = Math.floor(precisionScaledValue);
        double fraction = precisionScaledValue - lowerScaledValue;
        double roundedScaledValue =
                (Math.random() < fraction) ? lowerScaledValue + 1.0 : lowerScaledValue;

        // A round-up may carry into the next power of two; that result is returned as-is.
        return Math.scalb(roundedScaledValue, exponent - numBits);
    }
}
97