1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_ 16 #define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_ 17 18 #include <vector> 19 #include "tensorflow/core/framework/resource_mgr.h" 20 #include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h" 21 #include "tensorflow/core/platform/macros.h" 22 #include "tensorflow/core/platform/mutex.h" 23 24 namespace tensorflow { 25 26 using QuantileStream = 27 boosted_trees::quantiles::WeightedQuantilesStream<float, float>; 28 29 // Quantile Stream Resource for a list of streams sharing the same number of 30 // quantiles, maximum elements, and epsilon. 31 class BoostedTreesQuantileStreamResource : public ResourceBase { 32 public: BoostedTreesQuantileStreamResource(const float epsilon,const int64 max_elements,const int64 num_streams)33 BoostedTreesQuantileStreamResource(const float epsilon, 34 const int64 max_elements, 35 const int64 num_streams) 36 : are_buckets_ready_(false), 37 epsilon_(epsilon), 38 num_streams_(num_streams), 39 max_elements_(max_elements) { 40 streams_.reserve(num_streams_); 41 boundaries_.reserve(num_streams_); 42 for (int64 idx = 0; idx < num_streams; ++idx) { 43 streams_.push_back(QuantileStream(epsilon, max_elements)); 44 boundaries_.push_back(std::vector<float>()); 45 } 46 } 47 DebugString()48 string DebugString() const override { return "QuantileStreamResource"; } 49 mutex()50 tensorflow::mutex* mutex() { return &mu_; } 51 stream(const int64 index)52 QuantileStream* stream(const int64 index) { return &streams_[index]; } 53 boundaries(const int64 index)54 const std::vector<float>& boundaries(const int64 index) { 55 return boundaries_[index]; 56 } 57 set_boundaries(const std::vector<float> & boundaries,const int64 index)58 void set_boundaries(const std::vector<float>& boundaries, const int64 index) { 59 boundaries_[index] = boundaries; 60 } 61 epsilon()62 float epsilon() const { return epsilon_; } num_streams()63 int64 num_streams() const { return num_streams_; } 64 are_buckets_ready()65 bool are_buckets_ready() const { return are_buckets_ready_; } set_buckets_ready(const bool are_buckets_ready)66 void set_buckets_ready(const bool are_buckets_ready) { 67 are_buckets_ready_ = are_buckets_ready; 68 } 69 70 private: ~BoostedTreesQuantileStreamResource()71 ~BoostedTreesQuantileStreamResource() override {} 72 73 // Mutex for the whole resource. 74 tensorflow::mutex mu_; 75 76 // Quantile streams. 77 std::vector<QuantileStream> streams_; 78 79 // Stores the boundaries. Same size as streams_. 80 std::vector<std::vector<float>> boundaries_; 81 82 // Whether boundaries are created. Initially boundaries are empty until 83 // set_boundaries are called. 84 bool are_buckets_ready_; 85 86 const float epsilon_; 87 const int64 num_streams_; 88 // An upper-bound for the number of elements. 89 int64 max_elements_; 90 91 TF_DISALLOW_COPY_AND_ASSIGN(BoostedTreesQuantileStreamResource); 92 }; 93 94 } // namespace tensorflow 95 96 #endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_ 97