// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_

#include <algorithm>
#include <functional>
#include <ostream>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace boosted_trees {
namespace quantiles {

// Buffering container ideally suited for scenarios where we need
// to sort and dedupe/compact fixed chunks of a stream of weighted elements.
31 template <typename ValueType, typename WeightType, 32 typename CompareFn = std::less<ValueType>> 33 class WeightedQuantilesBuffer { 34 public: 35 struct BufferEntry { BufferEntryBufferEntry36 BufferEntry(ValueType v, WeightType w) 37 : value(std::move(v)), weight(std::move(w)) {} BufferEntryBufferEntry38 BufferEntry() : value(), weight(0) {} 39 40 bool operator<(const BufferEntry& other) const { 41 return kCompFn(value, other.value); 42 } 43 bool operator==(const BufferEntry& other) const { 44 return value == other.value && weight == other.weight; 45 } 46 friend std::ostream& operator<<(std::ostream& strm, 47 const BufferEntry& entry) { 48 return strm << "{" << entry.value << ", " << entry.weight << "}"; 49 } 50 ValueType value; 51 WeightType weight; 52 }; 53 WeightedQuantilesBuffer(int64 block_size,int64 max_elements)54 explicit WeightedQuantilesBuffer(int64 block_size, int64 max_elements) 55 : max_size_(std::min(block_size << 1, max_elements)) { 56 QCHECK(max_size_ > 0) << "Invalid buffer specification: (" << block_size 57 << ", " << max_elements << ")"; 58 vec_.reserve(max_size_); 59 } 60 61 // Disallow copying as it's semantically non-sensical in the Squawd algorithm 62 // but enable move semantics. 63 WeightedQuantilesBuffer(const WeightedQuantilesBuffer& other) = delete; 64 WeightedQuantilesBuffer& operator=(const WeightedQuantilesBuffer&) = delete; 65 WeightedQuantilesBuffer(WeightedQuantilesBuffer&& other) = default; 66 WeightedQuantilesBuffer& operator=(WeightedQuantilesBuffer&& other) = default; 67 68 // Push entry to buffer and maintain a compact representation within 69 // pre-defined size limit. PushEntry(ValueType value,WeightType weight)70 void PushEntry(ValueType value, WeightType weight) { 71 // Callers are expected to act on a full compacted buffer after the 72 // PushEntry call returns. 73 QCHECK(!IsFull()) << "Buffer already full: " << max_size_; 74 75 // Ignore zero and negative weight entries. 
76 if (weight <= 0) { 77 return; 78 } 79 80 // Push back the entry to the buffer. 81 vec_.push_back(BufferEntry(std::move(value), std::move(weight))); 82 } 83 84 // Returns a sorted vector view of the base buffer and clears the buffer. 85 // Callers should minimize how often this is called, ideally only right after 86 // the buffer becomes full. GenerateEntryList()87 std::vector<BufferEntry> GenerateEntryList() { 88 std::vector<BufferEntry> ret; 89 if (vec_.size() == 0) { 90 return ret; 91 } 92 ret.swap(vec_); 93 vec_.reserve(max_size_); 94 std::sort(ret.begin(), ret.end()); 95 size_t num_entries = 0; 96 for (size_t i = 1; i < ret.size(); ++i) { 97 if (ret[i].value != ret[i - 1].value) { 98 BufferEntry tmp = ret[i]; 99 ++num_entries; 100 ret[num_entries] = tmp; 101 } else { 102 ret[num_entries].weight += ret[i].weight; 103 } 104 } 105 ret.resize(num_entries + 1); 106 return ret; 107 } 108 Size()109 int64 Size() const { return vec_.size(); } IsFull()110 bool IsFull() const { return vec_.size() >= max_size_; } Clear()111 void Clear() { vec_.clear(); } 112 113 private: 114 using BufferVector = typename std::vector<BufferEntry>; 115 116 // Comparison function. 117 static constexpr decltype(CompareFn()) kCompFn = CompareFn(); 118 119 // Base buffer. 120 size_t max_size_; 121 BufferVector vec_; 122 }; 123 124 template <typename ValueType, typename WeightType, typename CompareFn> 125 constexpr decltype(CompareFn()) 126 WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>::kCompFn; 127 128 } // namespace quantiles 129 } // namespace boosted_trees 130 } // namespace tensorflow 131 132 #endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ 133