// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_

#include <algorithm>
#include <cstring>
#include <functional>
#include <limits>
#include <ostream>
#include <vector>

#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"

namespace tensorflow {
namespace boosted_trees {
namespace quantiles {

// Summary holding a sorted block of entries with upper bound guarantees
// over the approximation error.
template <typename ValueType, typename WeightType,
          typename CompareFn = std::less<ValueType>>
class WeightedQuantilesSummary {
 public:
  using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
  using BufferEntry = typename Buffer::BufferEntry;

  struct SummaryEntry {
    SummaryEntry(const ValueType& v, const WeightType& w,
                 const WeightType& min, const WeightType& max) {
      // Explicitly initialize all of memory (including padding from memory
      // alignment) to allow the struct to be msan-resistant "plain old data".
      //
      // POD = https://en.cppreference.com/w/cpp/named_req/PODType
      memset(this, 0, sizeof(*this));

      value = v;
      weight = w;
      min_rank = min;
      max_rank = max;
    }

    SummaryEntry() {
      memset(this, 0, sizeof(*this));

      value = ValueType();
      weight = 0;
      min_rank = 0;
      max_rank = 0;
    }

    bool operator==(const SummaryEntry& other) const {
      return value == other.value && weight == other.weight &&
             min_rank == other.min_rank && max_rank == other.max_rank;
    }
    friend std::ostream& operator<<(std::ostream& strm,
                                    const SummaryEntry& entry) {
      return strm << "{" << entry.value << ", " << entry.weight << ", "
                  << entry.min_rank << ", " << entry.max_rank << "}";
    }

    // Max rank estimate for previous smaller value.
    WeightType PrevMaxRank() const { return max_rank - weight; }

    // Min rank estimate for next larger value.
    WeightType NextMinRank() const { return min_rank + weight; }

    ValueType value;
    WeightType weight;
    WeightType min_rank;
    WeightType max_rank;
  };
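
  // Illustrative note (not part of the original header): a SummaryEntry such
  // as {value: 4, weight: 3, min_rank: 4, max_rank: 7} records that value 4
  // carries a total weight of 3 and that its weighted rank lies in [4, 7].
  // PrevMaxRank() == 7 - 3 == 4 is the rank upper bound usable for the next
  // smaller value, and NextMinRank() == 4 + 3 == 7 is the rank lower bound
  // usable for the next larger value.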

  // Re-construct summary from the specified buffer.
  void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) {
    entries_.clear();
    entries_.reserve(buffer_entries.size());
    WeightType cumulative_weight = 0;
    for (const auto& entry : buffer_entries) {
      WeightType current_weight = entry.weight;
      entries_.emplace_back(entry.value, entry.weight, cumulative_weight,
                            cumulative_weight + current_weight);
      cumulative_weight += current_weight;
    }
  }

  // Re-construct summary from the specified summary entries.
  void BuildFromSummaryEntries(
      const std::vector<SummaryEntry>& summary_entries) {
    entries_.clear();
    entries_.reserve(summary_entries.size());
    entries_.insert(entries_.begin(), summary_entries.begin(),
                    summary_entries.end());
  }
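
  // Worked example (illustrative only, not from the original file): given
  // sorted buffer entries with (value, weight) pairs (2, 1), (3, 2), (5, 1),
  // BuildFromBufferEntries produces the summary entries
  //   {2, 1, 0, 1}, {3, 2, 1, 3}, {5, 1, 3, 4}
  // i.e. min_rank is the cumulative weight strictly before an entry and
  // max_rank is the cumulative weight through it.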

  // Merges two summaries through an algorithm that's derived from MergeSort
  // for summary entries while guaranteeing that the max approximation error
  // of the final merged summary is no greater than the approximation errors
  // of each individual summary.
  // For example consider summaries where each entry is of the form
  // (element, weight, min rank, max rank):
  // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5)
  // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2)
  // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7).
  void Merge(const WeightedQuantilesSummary& other_summary) {
    // Make sure we have something to merge.
    const auto& other_entries = other_summary.entries_;
    if (other_entries.empty()) {
      return;
    }
    if (entries_.empty()) {
      BuildFromSummaryEntries(other_summary.entries_);
      return;
    }

    // Move current entries to make room for a new buffer.
    std::vector<SummaryEntry> base_entries(std::move(entries_));
    entries_.clear();
    entries_.reserve(base_entries.size() + other_entries.size());

    // Merge entries maintaining ranks. The idea is to stack values
    // in order which we can do in linear time as the two summaries are
    // already sorted. We keep track of the next lower rank from either
    // summary and update it as we pop elements from the summaries.
    // We handle the special case when the next two elements from either
    // summary are equal, in which case we just merge the two elements
    // and simultaneously update both ranks.
    auto it1 = base_entries.cbegin();
    auto it2 = other_entries.cbegin();
    WeightType next_min_rank1 = 0;
    WeightType next_min_rank2 = 0;
    while (it1 != base_entries.cend() && it2 != other_entries.cend()) {
      if (kCompFn(it1->value, it2->value)) {  // value1 < value2
        // Take value1 and use the last added value2 to compute
        // the min rank and the current value2 to compute the max rank.
        entries_.emplace_back(it1->value, it1->weight,
                              it1->min_rank + next_min_rank2,
                              it1->max_rank + it2->PrevMaxRank());
        // Update next min rank 1.
        next_min_rank1 = it1->NextMinRank();
        ++it1;
      } else if (kCompFn(it2->value, it1->value)) {  // value1 > value2
        // Take value2 and use the last added value1 to compute
        // the min rank and the current value1 to compute the max rank.
        entries_.emplace_back(it2->value, it2->weight,
                              it2->min_rank + next_min_rank1,
                              it2->max_rank + it1->PrevMaxRank());
        // Update next min rank 2.
        next_min_rank2 = it2->NextMinRank();
        ++it2;
      } else {  // value1 == value2
        // Straight additive merger of the two entries into one.
        entries_.emplace_back(it1->value, it1->weight + it2->weight,
                              it1->min_rank + it2->min_rank,
                              it1->max_rank + it2->max_rank);
        // Update next min ranks for both.
        next_min_rank1 = it1->NextMinRank();
        next_min_rank2 = it2->NextMinRank();
        ++it1;
        ++it2;
      }
    }

    // Fill in any residual.
    while (it1 != base_entries.cend()) {
      entries_.emplace_back(it1->value, it1->weight,
                            it1->min_rank + next_min_rank2,
                            it1->max_rank + other_entries.back().max_rank);
      ++it1;
    }
    while (it2 != other_entries.cend()) {
      entries_.emplace_back(it2->value, it2->weight,
                            it2->min_rank + next_min_rank1,
                            it2->max_rank + base_entries.back().max_rank);
      ++it2;
    }
  }

  // Compresses buffer into desired size. The size specification is
  // considered a hint as we always keep the first and last elements and
  // maintain strict approximation error bounds.
  // The approximation error delta is taken as the max of either the requested
  // min error or 1 / size_hint.
  // After compression, the approximation error is guaranteed to increase
  // by no more than that error delta.
  // This algorithm is linear in the original size of the summary and is
  // designed to be cache-friendly.
  void Compress(int64 size_hint, double min_eps = 0) {
    // No-op if we're already within the size requirement.
    size_hint = std::max(size_hint, int64{2});
    if (entries_.size() <= size_hint) {
      return;
    }

    // First compute the max error bound delta resulting from this compression.
    double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps);

    // Compress elements ensuring approximation bounds and elements diversity
    // are both maintained.
    int64 add_accumulator = 0, add_step = entries_.size();
    auto write_it = entries_.begin() + 1, last_it = write_it;
    for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) {
      auto next_it = read_it + 1;
      while (next_it != entries_.end() && add_accumulator < add_step &&
             next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) {
        add_accumulator += size_hint;
        ++next_it;
      }
      if (read_it == next_it - 1) {
        ++read_it;
      } else {
        read_it = next_it - 1;
      }
      (*write_it++) = (*read_it);
      last_it = read_it;
      add_accumulator -= add_step;
    }
    // Write last element and resize.
    if (last_it + 1 != entries_.end()) {
      (*write_it++) = entries_.back();
    }
    entries_.resize(write_it - entries_.begin());
  }
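
  // Rough sketch of the error accounting above (illustrative only): with
  // TotalWeight() == 100, size_hint == 10 and min_eps == 0, eps_delta is
  // 100 * max(1.0 / 10, 0) == 10, so adjacent entries are only collapsed
  // while the resulting rank gap stays within 10, i.e. the approximation
  // error grows by at most 1 / size_hint == 0.1 of the total weight.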

  // To construct the boundaries we first run a soft compress over a copy
  // of the summary and retrieve the values.
  // The resulting boundaries are guaranteed to both contain at least
  // num_boundaries unique elements and maintain approximation bounds.
  std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
    std::vector<ValueType> output;
    if (entries_.empty()) {
      return output;
    }

    // Generate soft compressed summary.
    WeightedQuantilesSummary<ValueType, WeightType, CompareFn>
        compressed_summary;
    compressed_summary.BuildFromSummaryEntries(entries_);
    // Set an epsilon for compression that's at most 1.0 / num_boundaries
    // more than the epsilon of our original summary, since the compression
    // operation adds ~1.0/num_boundaries to the final approximation error.
    float compression_eps = ApproximationError() + (1.0 / num_boundaries);
    compressed_summary.Compress(num_boundaries, compression_eps);

    // Return boundaries.
    output.reserve(compressed_summary.entries_.size());
    for (const auto& entry : compressed_summary.entries_) {
      output.push_back(entry.value);
    }
    return output;
  }

  // To construct the desired n-quantiles we repetitively query n ranks from
  // the original summary. The following algorithm is an efficient,
  // cache-friendly O(n) implementation of that idea which avoids the
  // O(n log n) cost of the repeated full rank queries.
  std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
    std::vector<ValueType> output;
    if (entries_.empty()) {
      return output;
    }
    num_quantiles = std::max(num_quantiles, int64{2});
    output.reserve(num_quantiles + 1);

    // Make successive rank queries to get boundaries.
    // We always keep the first (min) and last (max) entries.
    for (size_t cur_idx = 0, rank = 0; rank <= num_quantiles; ++rank) {
      // This step boils down to finding the next element sub-range defined by
      // r = (rmax[i + 1] + rmin[i + 1]) / 2 where the desired rank d < r.
      WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles);
      size_t next_idx = cur_idx + 1;
      while (next_idx < entries_.size() &&
             d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) {
        ++next_idx;
      }
      cur_idx = next_idx - 1;

      // Determine insertion order.
      if (next_idx == entries_.size() ||
          d_2 < entries_[cur_idx].NextMinRank() +
                    entries_[next_idx].PrevMaxRank()) {
        output.push_back(entries_[cur_idx].value);
      } else {
        output.push_back(entries_[next_idx].value);
      }
    }
    return output;
  }

  // Calculates current approximation error which should always be <= eps.
  double ApproximationError() const {
    if (entries_.empty()) {
      return 0;
    }

    WeightType max_gap = 0;
    for (auto it = entries_.cbegin() + 1; it < entries_.end(); ++it) {
      max_gap = std::max(max_gap,
                         std::max(it->max_rank - it->min_rank - it->weight,
                                  it->PrevMaxRank() - (it - 1)->NextMinRank()));
    }
    return static_cast<double>(max_gap) / TotalWeight();
  }

  ValueType MinValue() const {
    return !entries_.empty() ? entries_.front().value
                             : std::numeric_limits<ValueType>::max();
  }
  ValueType MaxValue() const {
    return !entries_.empty() ? entries_.back().value
                             : std::numeric_limits<ValueType>::max();
  }
  WeightType TotalWeight() const {
    return !entries_.empty() ? entries_.back().max_rank : 0;
  }
  int64 Size() const { return entries_.size(); }
  void Clear() { entries_.clear(); }
  const std::vector<SummaryEntry>& GetEntryList() const { return entries_; }
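
  // A minimal usage sketch (illustrative only; `buffer` and `other_buffer`
  // are hypothetical WeightedQuantilesBuffer instances, and the call that
  // yields their sorted entries is an assumption, not part of this header):
  //
  //   WeightedQuantilesSummary<double, double> summary, other;
  //   summary.BuildFromBufferEntries(buffer.GenerateEntryList());
  //   other.BuildFromBufferEntries(other_buffer.GenerateEntryList());
  //   summary.Merge(other);
  //   summary.Compress(/*size_hint=*/100);
  //   std::vector<double> boundaries = summary.GenerateBoundaries(10);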

 private:
  // Comparison function.
  static constexpr decltype(CompareFn()) kCompFn = CompareFn();

  // Summary entries.
  std::vector<SummaryEntry> entries_;
};

template <typename ValueType, typename WeightType, typename CompareFn>
constexpr decltype(CompareFn())
    WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn;

}  // namespace quantiles
}  // namespace boosted_trees
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_