// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_

#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <ostream>
#include <vector>

#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"

namespace tensorflow {
namespace boosted_trees {
namespace quantiles {

// Summary holding a sorted block of entries with upper bound guarantees
// over the approximation error.
template <typename ValueType, typename WeightType,
          typename CompareFn = std::less<ValueType>>
class WeightedQuantilesSummary {
 public:
  using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
  using BufferEntry = typename Buffer::BufferEntry;

  struct SummaryEntry {
    SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
                 const WeightType& max) {
      // Explicitly initialize all of memory (including padding from memory
      // alignment) to allow the struct to be msan-resistant "plain old data".
      //
      // POD = https://en.cppreference.com/w/cpp/named_req/PODType
      memset(this, 0, sizeof(*this));

      value = v;
      weight = w;
      min_rank = min;
      max_rank = max;
    }

    SummaryEntry() {
      memset(this, 0, sizeof(*this));

      value = ValueType();
      weight = 0;
      min_rank = 0;
      max_rank = 0;
    }

    bool operator==(const SummaryEntry& other) const {
      return value == other.value && weight == other.weight &&
             min_rank == other.min_rank && max_rank == other.max_rank;
    }
    friend std::ostream& operator<<(std::ostream& strm,
                                    const SummaryEntry& entry) {
      return strm << "{" << entry.value << ", " << entry.weight << ", "
                  << entry.min_rank << ", " << entry.max_rank << "}";
    }

    // Max rank estimate for previous smaller value.
    WeightType PrevMaxRank() const { return max_rank - weight; }

    // Min rank estimate for next larger value.
    WeightType NextMinRank() const { return min_rank + weight; }
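    // For illustration, using the two definitions directly above: an entry
    // {value = 4, weight = 2, min_rank = 3, max_rank = 5} has
    // PrevMaxRank() == 3 and NextMinRank() == 5.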

    ValueType value;
    WeightType weight;
    WeightType min_rank;
    WeightType max_rank;
  };

  // Re-construct summary from the specified buffer.
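  // For example, buffer entries given as (value, weight) pairs
  // (2, 1), (4, 3), (7, 2) produce summary entries
  // (2, 1, 0, 1), (4, 3, 1, 4), (7, 2, 4, 6): each entry's min_rank/max_rank
  // are the cumulative weights before/after that entry.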
  void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) {
    entries_.clear();
    entries_.reserve(buffer_entries.size());
    WeightType cumulative_weight = 0;
    for (const auto& entry : buffer_entries) {
      WeightType current_weight = entry.weight;
      entries_.emplace_back(entry.value, entry.weight, cumulative_weight,
                            cumulative_weight + current_weight);
      cumulative_weight += current_weight;
    }
  }

  // Re-construct summary from the specified summary entries.
  void BuildFromSummaryEntries(
      const std::vector<SummaryEntry>& summary_entries) {
    entries_.clear();
    entries_.reserve(summary_entries.size());
    entries_.insert(entries_.begin(), summary_entries.begin(),
                    summary_entries.end());
  }

  // Merges two summaries through an algorithm that's derived from MergeSort
  // for summary entries while guaranteeing that the max approximation error
  // of the final merged summary is no greater than the approximation errors
  // of each individual summary.
  // For example consider summaries where each entry is of the form
  // (element, weight, min rank, max rank):
  // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5)
  // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2)
  // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7).
  void Merge(const WeightedQuantilesSummary& other_summary) {
    // Make sure we have something to merge.
    const auto& other_entries = other_summary.entries_;
    if (other_entries.empty()) {
      return;
    }
    if (entries_.empty()) {
      BuildFromSummaryEntries(other_summary.entries_);
      return;
    }

    // Move current entries to make room for a new buffer.
    std::vector<SummaryEntry> base_entries(std::move(entries_));
    entries_.clear();
    entries_.reserve(base_entries.size() + other_entries.size());

    // Merge entries maintaining ranks. The idea is to stack values
    // in order which we can do in linear time as the two summaries are
    // already sorted. We keep track of the next lower rank from either
    // summary and update it as we pop elements from the summaries.
    // We handle the special case when the next two elements from either
    // summary are equal, in which case we just merge the two elements
    // and simultaneously update both ranks.
    auto it1 = base_entries.cbegin();
    auto it2 = other_entries.cbegin();
    WeightType next_min_rank1 = 0;
    WeightType next_min_rank2 = 0;
    while (it1 != base_entries.cend() && it2 != other_entries.cend()) {
      if (kCompFn(it1->value, it2->value)) {  // value1 < value2
        // Take value1 and use the last added value2 to compute
        // the min rank and the current value2 to compute the max rank.
        entries_.emplace_back(it1->value, it1->weight,
                              it1->min_rank + next_min_rank2,
                              it1->max_rank + it2->PrevMaxRank());
        // Update next min rank 1.
        next_min_rank1 = it1->NextMinRank();
        ++it1;
      } else if (kCompFn(it2->value, it1->value)) {  // value1 > value2
        // Take value2 and use the last added value1 to compute
        // the min rank and the current value1 to compute the max rank.
        entries_.emplace_back(it2->value, it2->weight,
                              it2->min_rank + next_min_rank1,
                              it2->max_rank + it1->PrevMaxRank());
        // Update next min rank 2.
        next_min_rank2 = it2->NextMinRank();
        ++it2;
      } else {  // value1 == value2
        // Straight additive merger of the two entries into one.
        entries_.emplace_back(it1->value, it1->weight + it2->weight,
                              it1->min_rank + it2->min_rank,
                              it1->max_rank + it2->max_rank);
        // Update next min ranks for both.
        next_min_rank1 = it1->NextMinRank();
        next_min_rank2 = it2->NextMinRank();
        ++it1;
        ++it2;
      }
    }

    // Fill in any residual.
    while (it1 != base_entries.cend()) {
      entries_.emplace_back(it1->value, it1->weight,
                            it1->min_rank + next_min_rank2,
                            it1->max_rank + other_entries.back().max_rank);
      ++it1;
    }
    while (it2 != other_entries.cend()) {
      entries_.emplace_back(it2->value, it2->weight,
                            it2->min_rank + next_min_rank1,
                            it2->max_rank + base_entries.back().max_rank);
      ++it2;
    }
  }

  // Compresses buffer into desired size. The size specification is
  // considered a hint as we always keep the first and last elements and
  // maintain strict approximation error bounds.
  // The approximation error delta is taken as the max of either the requested
  // min error or 1 / size_hint.
  // After compression, the approximation error is guaranteed to increase
  // by no more than that error delta.
  // This algorithm is linear in the original size of the summary and is
  // designed to be cache-friendly.
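  // For example, with TotalWeight() == 100, size_hint == 10 and min_eps == 0,
  // adjacent entries may be collapsed as long as the resulting rank gap stays
  // within eps_delta = 100 * max(1.0 / 10, 0) = 10, i.e. the approximation
  // error grows by at most 0.1.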
  void Compress(int64 size_hint, double min_eps = 0) {
    // No-op if we're already within the size requirement.
    size_hint = std::max(size_hint, int64{2});
    if (entries_.size() <= size_hint) {
      return;
    }

    // First compute the max error bound delta resulting from this compression.
    double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps);

    // Compress elements ensuring approximation bounds and element diversity
    // are both maintained.
    int64 add_accumulator = 0, add_step = entries_.size();
    auto write_it = entries_.begin() + 1, last_it = write_it;
    for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) {
      auto next_it = read_it + 1;
      while (next_it != entries_.end() && add_accumulator < add_step &&
             next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) {
        add_accumulator += size_hint;
        ++next_it;
      }
      if (read_it == next_it - 1) {
        ++read_it;
      } else {
        read_it = next_it - 1;
      }
      (*write_it++) = (*read_it);
      last_it = read_it;
      add_accumulator -= add_step;
    }
    // Write last element and resize.
    if (last_it + 1 != entries_.end()) {
      (*write_it++) = entries_.back();
    }
    entries_.resize(write_it - entries_.begin());
  }

  // To construct the boundaries we first run a soft compress over a copy
  // of the summary and retrieve the values.
  // The resulting boundaries are guaranteed to both contain at least
  // num_boundaries unique elements and maintain approximation bounds.
  std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
    std::vector<ValueType> output;
    if (entries_.empty()) {
      return output;
    }

    // Generate soft compressed summary.
    WeightedQuantilesSummary<ValueType, WeightType, CompareFn>
        compressed_summary;
    compressed_summary.BuildFromSummaryEntries(entries_);
    // Set an epsilon for compression that's at most 1.0 / num_boundaries
    // more than the epsilon of our original summary, since the compression
    // operation adds ~1.0 / num_boundaries to the final approximation error.
    float compression_eps = ApproximationError() + (1.0 / num_boundaries);
    compressed_summary.Compress(num_boundaries, compression_eps);

    // Return boundaries.
    output.reserve(compressed_summary.entries_.size());
    for (const auto& entry : compressed_summary.entries_) {
      output.push_back(entry.value);
    }
    return output;
  }

  // To construct the desired n-quantiles we repeatedly query n ranks from
  // the original summary. The following algorithm is an efficient,
  // cache-friendly O(n) implementation of that idea which avoids the
  // O(n log n) cost of repeated full rank queries.
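  // For example, with num_quantiles == 4 and TotalWeight() == 100, the loop
  // below queries the (approximate) elements of ranks 0, 25, 50, 75 and 100,
  // i.e. the min, the three quartiles and the max, returning
  // num_quantiles + 1 values.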
  std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
    std::vector<ValueType> output;
    if (entries_.empty()) {
      return output;
    }
    num_quantiles = std::max(num_quantiles, int64{2});
    output.reserve(num_quantiles + 1);

    // Make successive rank queries to get boundaries.
    // We always keep the first (min) and last (max) entries.
    for (size_t cur_idx = 0, rank = 0; rank <= num_quantiles; ++rank) {
      // This step boils down to finding the next element sub-range defined by
      // r = (rmax[i + 1] + rmin[i + 1]) / 2 where the desired rank d < r.
      WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles);
      size_t next_idx = cur_idx + 1;
      while (next_idx < entries_.size() &&
             d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) {
        ++next_idx;
      }
      cur_idx = next_idx - 1;

      // Determine insertion order.
      if (next_idx == entries_.size() ||
          d_2 < entries_[cur_idx].NextMinRank() +
                    entries_[next_idx].PrevMaxRank()) {
        output.push_back(entries_[cur_idx].value);
      } else {
        output.push_back(entries_[next_idx].value);
      }
    }
    return output;
  }

  // Calculates current approximation error which should always be <= eps.
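  // Concretely, this returns
  //   max_i max(rmax_i - rmin_i - w_i, (rmax_i - w_i) - (rmin_{i-1} + w_{i-1}))
  // over consecutive entries, normalized by the total weight.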
  double ApproximationError() const {
    if (entries_.empty()) {
      return 0;
    }

    WeightType max_gap = 0;
    for (auto it = entries_.cbegin() + 1; it < entries_.end(); ++it) {
      max_gap = std::max(max_gap,
                         std::max(it->max_rank - it->min_rank - it->weight,
                                  it->PrevMaxRank() - (it - 1)->NextMinRank()));
    }
    return static_cast<double>(max_gap) / TotalWeight();
  }

  ValueType MinValue() const {
    return !entries_.empty() ? entries_.front().value
                             : std::numeric_limits<ValueType>::max();
  }
  ValueType MaxValue() const {
    return !entries_.empty() ? entries_.back().value
                             : std::numeric_limits<ValueType>::max();
  }
  WeightType TotalWeight() const {
    return !entries_.empty() ? entries_.back().max_rank : 0;
  }
  int64 Size() const { return entries_.size(); }
  void Clear() { entries_.clear(); }
  const std::vector<SummaryEntry>& GetEntryList() const { return entries_; }

 private:
  // Comparison function.
  static constexpr decltype(CompareFn()) kCompFn = CompareFn();

  // Summary entries.
  std::vector<SummaryEntry> entries_;
};

template <typename ValueType, typename WeightType, typename CompareFn>
constexpr decltype(CompareFn())
    WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn;
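
// Example usage (an illustrative sketch; it assumes a `buffer` of type
// WeightedQuantilesBuffer<double, double> exposing GenerateEntryList(), as
// declared in weighted_quantiles_buffer.h):
//
//   WeightedQuantilesSummary<double, double> summary;
//   summary.BuildFromBufferEntries(buffer.GenerateEntryList());
//   summary.Compress(/*size_hint=*/1000, /*min_eps=*/0.001);
//   std::vector<double> quantiles = summary.GenerateQuantiles(4);
//   std::vector<double> boundaries = summary.GenerateBoundaries(16);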

}  // namespace quantiles
}  // namespace boosted_trees
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_