1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
16 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
17 
18 #include <cstring>
19 #include <vector>
20 
21 #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h"
22 
23 namespace tensorflow {
24 namespace boosted_trees {
25 namespace quantiles {
26 
27 // Summary holding a sorted block of entries with upper bound guarantees
28 // over the approximation error.
29 template <typename ValueType, typename WeightType,
30           typename CompareFn = std::less<ValueType>>
31 class WeightedQuantilesSummary {
32  public:
33   using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
34   using BufferEntry = typename Buffer::BufferEntry;
35 
36   struct SummaryEntry {
SummaryEntrySummaryEntry37     SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
38                  const WeightType& max) {
39       value = v;
40       weight = w;
41       min_rank = min;
42       max_rank = max;
43     }
44 
SummaryEntrySummaryEntry45     SummaryEntry() {
46       value = ValueType();
47       weight = 0;
48       min_rank = 0;
49       max_rank = 0;
50     }
51 
52     bool operator==(const SummaryEntry& other) const {
53       return value == other.value && weight == other.weight &&
54              min_rank == other.min_rank && max_rank == other.max_rank;
55     }
56     friend std::ostream& operator<<(std::ostream& strm,
57                                     const SummaryEntry& entry) {
58       return strm << "{" << entry.value << ", " << entry.weight << ", "
59                   << entry.min_rank << ", " << entry.max_rank << "}";
60     }
61 
62     // Max rank estimate for previous smaller value.
PrevMaxRankSummaryEntry63     WeightType PrevMaxRank() const { return max_rank - weight; }
64 
65     // Min rank estimate for next larger value.
NextMinRankSummaryEntry66     WeightType NextMinRank() const { return min_rank + weight; }
67 
68     ValueType value;
69     WeightType weight;
70     WeightType min_rank;
71     WeightType max_rank;
72   };
73 
74   // Re-construct summary from the specified buffer.
BuildFromBufferEntries(const std::vector<BufferEntry> & buffer_entries)75   void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) {
76     entries_.clear();
77     entries_.reserve(buffer_entries.size());
78     WeightType cumulative_weight = 0;
79     for (const auto& entry : buffer_entries) {
80       WeightType current_weight = entry.weight;
81       entries_.emplace_back(entry.value, entry.weight, cumulative_weight,
82                             cumulative_weight + current_weight);
83       cumulative_weight += current_weight;
84     }
85   }
86 
87   // Re-construct summary from the specified summary entries.
BuildFromSummaryEntries(const std::vector<SummaryEntry> & summary_entries)88   void BuildFromSummaryEntries(
89       const std::vector<SummaryEntry>& summary_entries) {
90     entries_.clear();
91     entries_.reserve(summary_entries.size());
92     entries_.insert(entries_.begin(), summary_entries.begin(),
93                     summary_entries.end());
94   }
95 
96   // Merges two summaries through an algorithm that's derived from MergeSort
97   // for summary entries while guaranteeing that the max approximation error
98   // of the final merged summary is no greater than the approximation errors
99   // of each individual summary.
100   // For example consider summaries where each entry is of the form
101   // (element, weight, min rank, max rank):
102   // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5)
103   // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2)
104   // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7).
Merge(const WeightedQuantilesSummary & other_summary)105   void Merge(const WeightedQuantilesSummary& other_summary) {
106     // Make sure we have something to merge.
107     const auto& other_entries = other_summary.entries_;
108     if (other_entries.empty()) {
109       return;
110     }
111     if (entries_.empty()) {
112       BuildFromSummaryEntries(other_summary.entries_);
113       return;
114     }
115 
116     // Move current entries to make room for a new buffer.
117     std::vector<SummaryEntry> base_entries(std::move(entries_));
118     entries_.clear();
119     entries_.reserve(base_entries.size() + other_entries.size());
120 
121     // Merge entries maintaining ranks. The idea is to stack values
122     // in order which we can do in linear time as the two summaries are
123     // already sorted. We keep track of the next lower rank from either
124     // summary and update it as we pop elements from the summaries.
125     // We handle the special case when the next two elements from either
126     // summary are equal, in which case we just merge the two elements
127     // and simultaneously update both ranks.
128     auto it1 = base_entries.cbegin();
129     auto it2 = other_entries.cbegin();
130     WeightType next_min_rank1 = 0;
131     WeightType next_min_rank2 = 0;
132     while (it1 != base_entries.cend() && it2 != other_entries.cend()) {
133       if (kCompFn(it1->value, it2->value)) {  // value1 < value2
134         // Take value1 and use the last added value2 to compute
135         // the min rank and the current value2 to compute the max rank.
136         entries_.emplace_back(it1->value, it1->weight,
137                               it1->min_rank + next_min_rank2,
138                               it1->max_rank + it2->PrevMaxRank());
139         // Update next min rank 1.
140         next_min_rank1 = it1->NextMinRank();
141         ++it1;
142       } else if (kCompFn(it2->value, it1->value)) {  // value1 > value2
143         // Take value2 and use the last added value1 to compute
144         // the min rank and the current value1 to compute the max rank.
145         entries_.emplace_back(it2->value, it2->weight,
146                               it2->min_rank + next_min_rank1,
147                               it2->max_rank + it1->PrevMaxRank());
148         // Update next min rank 2.
149         next_min_rank2 = it2->NextMinRank();
150         ++it2;
151       } else {  // value1 == value2
152         // Straight additive merger of the two entries into one.
153         entries_.emplace_back(it1->value, it1->weight + it2->weight,
154                               it1->min_rank + it2->min_rank,
155                               it1->max_rank + it2->max_rank);
156         // Update next min ranks for both.
157         next_min_rank1 = it1->NextMinRank();
158         next_min_rank2 = it2->NextMinRank();
159         ++it1;
160         ++it2;
161       }
162     }
163 
164     // Fill in any residual.
165     while (it1 != base_entries.cend()) {
166       entries_.emplace_back(it1->value, it1->weight,
167                             it1->min_rank + next_min_rank2,
168                             it1->max_rank + other_entries.back().max_rank);
169       ++it1;
170     }
171     while (it2 != other_entries.cend()) {
172       entries_.emplace_back(it2->value, it2->weight,
173                             it2->min_rank + next_min_rank1,
174                             it2->max_rank + base_entries.back().max_rank);
175       ++it2;
176     }
177   }
178 
179   // Compresses buffer into desired size. The size specification is
180   // considered a hint as we always keep the first and last elements and
181   // maintain strict approximation error bounds.
182   // The approximation error delta is taken as the max of either the requested
183   // min error or 1 / size_hint.
184   // After compression, the approximation error is guaranteed to increase
185   // by no more than that error delta.
186   // This algorithm is linear in the original size of the summary and is
187   // designed to be cache-friendly.
188   void Compress(int64 size_hint, double min_eps = 0) {
189     // No-op if we're already within the size requirement.
190     size_hint = std::max(size_hint, int64{2});
191     if (entries_.size() <= size_hint) {
192       return;
193     }
194 
195     // First compute the max error bound delta resulting from this compression.
196     double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps);
197 
198     // Compress elements ensuring approximation bounds and elements diversity
199     // are both maintained.
200     int64 add_accumulator = 0, add_step = entries_.size();
201     auto write_it = entries_.begin() + 1, last_it = write_it;
202     for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) {
203       auto next_it = read_it + 1;
204       while (next_it != entries_.end() && add_accumulator < add_step &&
205              next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) {
206         add_accumulator += size_hint;
207         ++next_it;
208       }
209       if (read_it == next_it - 1) {
210         ++read_it;
211       } else {
212         read_it = next_it - 1;
213       }
214       (*write_it++) = (*read_it);
215       last_it = read_it;
216       add_accumulator -= add_step;
217     }
218     // Write last element and resize.
219     if (last_it + 1 != entries_.end()) {
220       (*write_it++) = entries_.back();
221     }
222     entries_.resize(write_it - entries_.begin());
223   }
224 
225   // To construct the boundaries we first run a soft compress over a copy
226   // of the summary and retrieve the values.
227   // The resulting boundaries are guaranteed to both contain at least
228   // num_boundaries unique elements and maintain approximation bounds.
GenerateBoundaries(int64 num_boundaries)229   std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
230     std::vector<ValueType> output;
231     if (entries_.empty()) {
232       return output;
233     }
234 
235     // Generate soft compressed summary.
236     WeightedQuantilesSummary<ValueType, WeightType, CompareFn>
237         compressed_summary;
238     compressed_summary.BuildFromSummaryEntries(entries_);
239     // Set an epsilon for compression that's at most 1.0 / num_boundaries
240     // more than epsilon of original our summary since the compression operation
241     // adds ~1.0/num_boundaries to final approximation error.
242     float compression_eps = ApproximationError() + (1.0 / num_boundaries);
243     compressed_summary.Compress(num_boundaries, compression_eps);
244 
245     // Return boundaries.
246     output.reserve(compressed_summary.entries_.size());
247     for (const auto& entry : compressed_summary.entries_) {
248       output.push_back(entry.value);
249     }
250     return output;
251   }
252 
253   // To construct the desired n-quantiles we repetitively query n ranks from the
254   // original summary. The following algorithm is an efficient cache-friendly
255   // O(n) implementation of that idea which avoids the cost of the repetitive
256   // full rank queries O(nlogn).
GenerateQuantiles(int64 num_quantiles)257   std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
258     std::vector<ValueType> output;
259     if (entries_.empty()) {
260       return output;
261     }
262     num_quantiles = std::max(num_quantiles, int64{2});
263     output.reserve(num_quantiles + 1);
264 
265     // Make successive rank queries to get boundaries.
266     // We always keep the first (min) and last (max) entries.
267     for (size_t cur_idx = 0, rank = 0; rank <= num_quantiles; ++rank) {
268       // This step boils down to finding the next element sub-range defined by
269       // r = (rmax[i + 1] + rmin[i + 1]) / 2 where the desired rank d < r.
270       WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles);
271       size_t next_idx = cur_idx + 1;
272       while (next_idx < entries_.size() &&
273              d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) {
274         ++next_idx;
275       }
276       cur_idx = next_idx - 1;
277 
278       // Determine insertion order.
279       if (next_idx == entries_.size() ||
280           d_2 < entries_[cur_idx].NextMinRank() +
281                     entries_[next_idx].PrevMaxRank()) {
282         output.push_back(entries_[cur_idx].value);
283       } else {
284         output.push_back(entries_[next_idx].value);
285       }
286     }
287     return output;
288   }
289 
290   // Calculates current approximation error which should always be <= eps.
ApproximationError()291   double ApproximationError() const {
292     if (entries_.empty()) {
293       return 0;
294     }
295 
296     WeightType max_gap = 0;
297     for (auto it = entries_.cbegin() + 1; it < entries_.end(); ++it) {
298       max_gap = std::max(max_gap,
299                          std::max(it->max_rank - it->min_rank - it->weight,
300                                   it->PrevMaxRank() - (it - 1)->NextMinRank()));
301     }
302     return static_cast<double>(max_gap) / TotalWeight();
303   }
304 
MinValue()305   ValueType MinValue() const {
306     return !entries_.empty() ? entries_.front().value
307                              : std::numeric_limits<ValueType>::max();
308   }
MaxValue()309   ValueType MaxValue() const {
310     return !entries_.empty() ? entries_.back().value
311                              : std::numeric_limits<ValueType>::max();
312   }
TotalWeight()313   WeightType TotalWeight() const {
314     return !entries_.empty() ? entries_.back().max_rank : 0;
315   }
Size()316   int64 Size() const { return entries_.size(); }
Clear()317   void Clear() { entries_.clear(); }
GetEntryList()318   const std::vector<SummaryEntry>& GetEntryList() const { return entries_; }
319 
320  private:
321   // Comparison function.
322   static constexpr decltype(CompareFn()) kCompFn = CompareFn();
323 
324   // Summary entries.
325   std::vector<SummaryEntry> entries_;
326 };
327 
// Namespace-scope definition of the static constexpr comparator declared
// inside WeightedQuantilesSummary. Under pre-C++17 rules a static constexpr
// data member that is odr-used (kCompFn is invoked as a function object in
// Merge) needs such a definition in addition to its in-class initializer;
// from C++17 on the member is implicitly inline and this is redundant but
// still valid.
template <typename ValueType, typename WeightType, typename CompareFn>
constexpr decltype(CompareFn())
    WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn;
331 
332 }  // namespace quantiles
333 }  // namespace boosted_trees
334 }  // namespace tensorflow
335 
336 #endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
337