1 // Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
16 #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
17
#include <algorithm>
#include <cmath>
#include <functional>
#include <limits>
#include <memory>
#include <tuple>
#include <utility>
#include <vector>

#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
#include "tensorflow/core/platform/types.h"
25
26 namespace tensorflow {
27 namespace boosted_trees {
28 namespace quantiles {
29
30 // Class to compute approximate quantiles with error bound guarantees for
31 // weighted data sets.
32 // This implementation is an adaptation of techniques from the following papers:
33 // * (2001) Space-efficient online computation of quantile summaries.
34 // * (2004) Power-conserving computation of order-statistics over
35 // sensor networks.
36 // * (2007) A fast algorithm for approximate quantiles in high speed
37 // data streams.
38 // * (2016) XGBoost: A Scalable Tree Boosting System.
39 //
40 // The key ideas at play are the following:
41 // - Maintain an in-memory multi-level quantile summary in a way to guarantee
42 // a maximum approximation error of eps * W per bucket where W is the total
43 // weight across all points in the input dataset.
// - Two base operations are defined: MERGE and COMPRESS. MERGE combines two
//   summaries guaranteeing an epsNew = max(eps1, eps2). COMPRESS compresses
//   a summary to b + 1 elements guaranteeing epsNew = epsOld + 1/b.
47 // - b * sizeof(summary entry) must ideally be small enough to fit in an
48 // average CPU L2 cache.
// - To distribute this algorithm while maintaining error bounds, we need
//   the worker-computed summaries to have no more than eps / h error,
//   where h is the height of the distributed computation graph, which
//   is 2 for a MapReduce with no combiner.
53 //
// We mainly want to max out I/O bandwidth by ensuring we're not compute-bound,
// while using a reasonable amount of RAM.
56 //
57 // Complexity:
58 // Compute: O(n * log(1/eps * log(eps * n))).
59 // Memory: O(1/eps * log^2(eps * n)) <- for one worker streaming through the
60 // entire dataset.
// An epsilon value of zero would make the algorithm extremely inefficient and
// is therefore disallowed.
63 template <typename ValueType, typename WeightType,
64 typename CompareFn = std::less<ValueType>>
65 class WeightedQuantilesStream {
66 public:
67 using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
68 using BufferEntry = typename Buffer::BufferEntry;
69 using Summary = WeightedQuantilesSummary<ValueType, WeightType, CompareFn>;
70 using SummaryEntry = typename Summary::SummaryEntry;
71
WeightedQuantilesStream(double eps,int64 max_elements)72 explicit WeightedQuantilesStream(double eps, int64 max_elements)
73 : eps_(eps), buffer_(1LL, 2LL), finalized_(false) {
74 // See the class documentation. An epsilon value of zero could cause
75 // perfoamance issues.
76 QCHECK(eps > 0) << "An epsilon value of zero is not allowed.";
77 std::tie(max_levels_, block_size_) = GetQuantileSpecs(eps, max_elements);
78 buffer_ = Buffer(block_size_, max_elements);
79 summary_levels_.reserve(max_levels_);
80 }
81
82 // Disallow copy and assign but enable move semantics for the stream.
83 WeightedQuantilesStream(const WeightedQuantilesStream& other) = delete;
84 WeightedQuantilesStream& operator=(const WeightedQuantilesStream&) = delete;
85 WeightedQuantilesStream(WeightedQuantilesStream&& other) = default;
86 WeightedQuantilesStream& operator=(WeightedQuantilesStream&& other) = default;
87
88 // Pushes one entry while maintaining approximation error invariants.
PushEntry(const ValueType & value,const WeightType & weight)89 void PushEntry(const ValueType& value, const WeightType& weight) {
90 // Validate state.
91 QCHECK(!finalized_) << "Finalize() already called.";
92
93 // Push element to base buffer.
94 buffer_.PushEntry(value, weight);
95
96 // When compacted buffer is full we need to compress
97 // and push weighted quantile summary up the level chain.
98 if (buffer_.IsFull()) {
99 PushBuffer(buffer_);
100 }
101 }
102
103 // Pushes full buffer while maintaining approximation error invariants.
PushBuffer(Buffer & buffer)104 void PushBuffer(Buffer& buffer) {
105 // Validate state.
106 QCHECK(!finalized_) << "Finalize() already called.";
107
108 // Create local compressed summary and propagate.
109 local_summary_.BuildFromBufferEntries(buffer.GenerateEntryList());
110 local_summary_.Compress(block_size_, eps_);
111 PropagateLocalSummary();
112 }
113
114 // Pushes full summary while maintaining approximation error invariants.
PushSummary(const std::vector<SummaryEntry> & summary)115 void PushSummary(const std::vector<SummaryEntry>& summary) {
116 // Validate state.
117 QCHECK(!finalized_) << "Finalize() already called.";
118
119 // Create local compressed summary and propagate.
120 local_summary_.BuildFromSummaryEntries(summary);
121 local_summary_.Compress(block_size_, eps_);
122 PropagateLocalSummary();
123 }
124
125 // Flushes approximator and finalizes state.
Finalize()126 void Finalize() {
127 // Validate state.
128 QCHECK(!finalized_) << "Finalize() may only be called once.";
129
130 // Flush any remaining buffer elements.
131 PushBuffer(buffer_);
132
133 // Create final merged summary.
134 local_summary_.Clear();
135 for (auto& summary : summary_levels_) {
136 local_summary_.Merge(summary);
137 summary.Clear();
138 }
139 summary_levels_.clear();
140 summary_levels_.shrink_to_fit();
141 finalized_ = true;
142 }
143
144 // Generates requested number of quantiles after finalizing stream.
145 // The returned quantiles can be queried using std::lower_bound to get
146 // the bucket for a given value.
GenerateQuantiles(int64 num_quantiles)147 std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
148 // Validate state.
149 QCHECK(finalized_)
150 << "Finalize() must be called before generating quantiles.";
151 return local_summary_.GenerateQuantiles(num_quantiles);
152 }
153
154 // Generates requested number of boundaries after finalizing stream.
155 // The returned boundaries can be queried using std::lower_bound to get
156 // the bucket for a given value.
157 // The boundaries, while still guaranteeing approximation bounds, don't
158 // necessarily represent the actual quantiles of the distribution.
159 // Boundaries are preferable over quantiles when the caller is less
160 // interested in the actual quantiles distribution and more interested in
161 // getting a representative sample of boundary values.
GenerateBoundaries(int64 num_boundaries)162 std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
163 // Validate state.
164 QCHECK(finalized_)
165 << "Finalize() must be called before generating boundaries.";
166 return local_summary_.GenerateBoundaries(num_boundaries);
167 }
168
169 // Calculates approximation error for the specified level.
170 // If the passed level is negative, the approximation error for the entire
171 // summary is returned. Note that after Finalize is called, only the overall
172 // error is available.
173 WeightType ApproximationError(int64 level = -1) const {
174 if (finalized_) {
175 QCHECK(level <= 0) << "Only overall error is available after Finalize()";
176 return local_summary_.ApproximationError();
177 }
178
179 if (summary_levels_.empty()) {
180 // No error even if base buffer isn't empty.
181 return 0;
182 }
183
184 // If level is negative, we get the approximation error
185 // for the top-most level which is the max approximation error
186 // in all summaries by construction.
187 if (level < 0) {
188 level = summary_levels_.size() - 1;
189 }
190 QCHECK(level < summary_levels_.size()) << "Invalid level.";
191 return summary_levels_[level].ApproximationError();
192 }
193
MaxDepth()194 size_t MaxDepth() const { return summary_levels_.size(); }
195
196 // Generates requested number of quantiles after finalizing stream.
GetFinalSummary()197 const Summary& GetFinalSummary() const {
198 // Validate state.
199 QCHECK(finalized_)
200 << "Finalize() must be called before requesting final summary.";
201 return local_summary_;
202 }
203
204 // Helper method which, given the desired approximation error
205 // and an upper bound on the number of elements, computes the optimal
206 // number of levels and block size and returns them in the tuple.
207 static std::tuple<int64, int64> GetQuantileSpecs(double eps,
208 int64 max_elements);
209
210 // Serializes the internal state of the stream.
SerializeInternalSummaries()211 std::vector<Summary> SerializeInternalSummaries() const {
212 // The buffer should be empty for serialize to work.
213 QCHECK_EQ(buffer_.Size(), 0);
214 std::vector<Summary> result;
215 result.reserve(summary_levels_.size() + 1);
216 for (const Summary& summary : summary_levels_) {
217 result.push_back(summary);
218 }
219 result.push_back(local_summary_);
220 return result;
221 }
222
223 // Resets the state of the stream with a serialized state.
DeserializeInternalSummaries(const std::vector<Summary> & summaries)224 void DeserializeInternalSummaries(const std::vector<Summary>& summaries) {
225 // Clear the state before deserializing.
226 buffer_.Clear();
227 summary_levels_.clear();
228 local_summary_.Clear();
229 QCHECK_GT(max_levels_, summaries.size() - 1);
230 for (int i = 0; i < summaries.size() - 1; ++i) {
231 summary_levels_.push_back(summaries[i]);
232 }
233 local_summary_ = summaries[summaries.size() - 1];
234 }
235
236 private:
237 // Propagates local summary through summary levels while maintaining
238 // approximation error invariants.
PropagateLocalSummary()239 void PropagateLocalSummary() {
240 // Validate state.
241 QCHECK(!finalized_) << "Finalize() already called.";
242
243 // No-op if there's nothing to add.
244 if (local_summary_.Size() <= 0) {
245 return;
246 }
247
248 // Propagate summary through levels.
249 size_t level = 0;
250 for (bool settled = false; !settled; ++level) {
251 // Ensure we have enough depth.
252 if (summary_levels_.size() <= level) {
253 summary_levels_.emplace_back();
254 }
255
256 // Merge summaries.
257 Summary& current_summary = summary_levels_[level];
258 local_summary_.Merge(current_summary);
259
260 // Check if we need to compress and propagate summary higher.
261 if (current_summary.Size() == 0 ||
262 local_summary_.Size() <= block_size_ + 1) {
263 current_summary = std::move(local_summary_);
264 settled = true;
265 } else {
266 // Compress, empty current level and propagate.
267 local_summary_.Compress(block_size_, eps_);
268 current_summary.Clear();
269 }
270 }
271 }
272
273 // Desired approximation precision.
274 double eps_;
275 // Maximum number of levels.
276 int64 max_levels_;
277 // Max block size per level.
278 int64 block_size_;
279 // Base buffer.
280 Buffer buffer_;
281 // Local summary used to minimize memory allocation and cache misses.
282 // After the stream is finalized, this summary holds the final quantile
283 // estimates.
284 Summary local_summary_;
285 // Summary levels;
286 std::vector<Summary> summary_levels_;
287 // Flag indicating whether the stream is finalized.
288 bool finalized_;
289 };
290
291 template <typename ValueType, typename WeightType, typename CompareFn>
292 inline std::tuple<int64, int64>
GetQuantileSpecs(double eps,int64 max_elements)293 WeightedQuantilesStream<ValueType, WeightType, CompareFn>::GetQuantileSpecs(
294 double eps, int64 max_elements) {
295 int64 max_level = 1LL;
296 int64 block_size = 2LL;
297 QCHECK(eps >= 0 && eps < 1);
298 QCHECK_GT(max_elements, 0);
299
300 if (eps <= std::numeric_limits<double>::epsilon()) {
301 // Exact quantile computation at the expense of RAM.
302 max_level = 1;
303 block_size = std::max(max_elements, int64{2});
304 } else {
305 // The bottom-most level will become full at most
306 // (max_elements / block_size) times, the level above will become full
307 // (max_elements / 2 * block_size) times and generally level l becomes
308 // full (max_elements / 2^l * block_size) times until the last
309 // level max_level becomes full at most once meaning when the inequality
310 // (2^max_level * block_size >= max_elements) is satisfied.
311 // In what follows, we jointly solve for max_level and block_size by
312 // gradually increasing the level until the inequality above is satisfied.
313 // We could alternatively set max_level = ceil(log2(eps * max_elements));
314 // and block_size = ceil(max_level / eps) + 1 but that tends to give more
315 // pessimistic bounds and wastes RAM needlessly.
316 for (max_level = 1, block_size = 2;
317 (1LL << max_level) * block_size < max_elements; ++max_level) {
318 // Update upper bound on block size at current level, we always
319 // increase the estimate by 2 to hold the min/max elements seen so far.
320 block_size = static_cast<size_t>(ceil(max_level / eps)) + 1;
321 }
322 }
323 return std::make_tuple(max_level, std::max(block_size, int64{2}));
324 }
325
326 } // namespace quantiles
327 } // namespace boosted_trees
328 } // namespace tensorflow
329
330 #endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
331