1 // Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
16 #define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
17 
#include <algorithm>
#include <functional>
#include <ostream>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
24 
25 namespace tensorflow {
26 namespace boosted_trees {
27 namespace quantiles {
28 
29 // Buffering container ideally suited for scenarios where we need
30 // to sort and dedupe/compact fixed chunks of a stream of weighted elements.
31 template <typename ValueType, typename WeightType,
32           typename CompareFn = std::less<ValueType>>
33 class WeightedQuantilesBuffer {
34  public:
35   struct BufferEntry {
BufferEntryBufferEntry36     BufferEntry(ValueType v, WeightType w)
37         : value(std::move(v)), weight(std::move(w)) {}
BufferEntryBufferEntry38     BufferEntry() : value(), weight(0) {}
39 
40     bool operator<(const BufferEntry& other) const {
41       return kCompFn(value, other.value);
42     }
43     bool operator==(const BufferEntry& other) const {
44       return value == other.value && weight == other.weight;
45     }
46     friend std::ostream& operator<<(std::ostream& strm,
47                                     const BufferEntry& entry) {
48       return strm << "{" << entry.value << ", " << entry.weight << "}";
49     }
50     ValueType value;
51     WeightType weight;
52   };
53 
WeightedQuantilesBuffer(int64 block_size,int64 max_elements)54   explicit WeightedQuantilesBuffer(int64 block_size, int64 max_elements)
55       : max_size_(std::min(block_size << 1, max_elements)) {
56     QCHECK(max_size_ > 0) << "Invalid buffer specification: (" << block_size
57                           << ", " << max_elements << ")";
58     vec_.reserve(max_size_);
59   }
60 
61   // Disallow copying as it's semantically non-sensical in the Squawd algorithm
62   // but enable move semantics.
63   WeightedQuantilesBuffer(const WeightedQuantilesBuffer& other) = delete;
64   WeightedQuantilesBuffer& operator=(const WeightedQuantilesBuffer&) = delete;
65   WeightedQuantilesBuffer(WeightedQuantilesBuffer&& other) = default;
66   WeightedQuantilesBuffer& operator=(WeightedQuantilesBuffer&& other) = default;
67 
68   // Push entry to buffer and maintain a compact representation within
69   // pre-defined size limit.
PushEntry(ValueType value,WeightType weight)70   void PushEntry(ValueType value, WeightType weight) {
71     // Callers are expected to act on a full compacted buffer after the
72     // PushEntry call returns.
73     QCHECK(!IsFull()) << "Buffer already full: " << max_size_;
74 
75     // Ignore zero and negative weight entries.
76     if (weight <= 0) {
77       return;
78     }
79 
80     // Push back the entry to the buffer.
81     vec_.push_back(BufferEntry(std::move(value), std::move(weight)));
82   }
83 
84   // Returns a sorted vector view of the base buffer and clears the buffer.
85   // Callers should minimize how often this is called, ideally only right after
86   // the buffer becomes full.
GenerateEntryList()87   std::vector<BufferEntry> GenerateEntryList() {
88     std::vector<BufferEntry> ret;
89     if (vec_.size() == 0) {
90       return ret;
91     }
92     ret.swap(vec_);
93     vec_.reserve(max_size_);
94     std::sort(ret.begin(), ret.end());
95     size_t num_entries = 0;
96     for (size_t i = 1; i < ret.size(); ++i) {
97       if (ret[i].value != ret[i - 1].value) {
98         BufferEntry tmp = ret[i];
99         ++num_entries;
100         ret[num_entries] = tmp;
101       } else {
102         ret[num_entries].weight += ret[i].weight;
103       }
104     }
105     ret.resize(num_entries + 1);
106     return ret;
107   }
108 
Size()109   int64 Size() const { return vec_.size(); }
IsFull()110   bool IsFull() const { return vec_.size() >= max_size_; }
Clear()111   void Clear() { vec_.clear(); }
112 
113  private:
114   using BufferVector = typename std::vector<BufferEntry>;
115 
116   // Comparison function.
117   static constexpr decltype(CompareFn()) kCompFn = CompareFn();
118 
119   // Base buffer.
120   size_t max_size_;
121   BufferVector vec_;
122 };
123 
// Namespace-scope definition of the static constexpr comparator declared in
// WeightedQuantilesBuffer. BufferEntry::operator< calls kCompFn(...), which
// odr-uses the object (its operator() binds `this` to it). Before C++17
// (inline variables), such a member needs exactly one out-of-class definition.
// The initializer stays on the in-class declaration. The declared type must be
// spelled exactly as in the class (decltype(CompareFn())) for the
// redeclaration to match.
template <typename ValueType, typename WeightType, typename CompareFn>
constexpr decltype(CompareFn())
    WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>::kCompFn;
127 
128 }  // namespace quantiles
129 }  // namespace boosted_trees
130 }  // namespace tensorflow
131 
132 #endif  // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
133