1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ 16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ 17 #include <ctime> 18 #include <unordered_map> 19 #include "google/protobuf/any.pb.h" 20 #include "google/protobuf/wrappers.pb.h" 21 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h" 22 #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h" 23 #include "tensorflow/core/framework/tensor.h" 24 #include "tensorflow/core/framework/tensor_types.h" 25 #include "tensorflow/core/lib/random/philox_random.h" 26 #include "tensorflow/core/lib/random/random.h" 27 #include "tensorflow/core/lib/random/simple_philox.h" 28 #include "tensorflow/core/platform/mutex.h" 29 30 namespace tensorflow { 31 namespace tensorforest { 32 33 typedef TTypes<const float, 2>::ConstTensor DenseStorageType; 34 typedef TTypes<const int64, 2>::ConstTensor SparseIndicesStorageType; 35 typedef TTypes<const float, 1>::ConstTensor SparseValuesStorageType; 36 37 class TensorDataSet { 38 public: TensorDataSet(const tensorforest::TensorForestDataSpec & input_spec,int32 seed)39 TensorDataSet(const tensorforest::TensorForestDataSpec& input_spec, 40 int32 seed) 41 : dense_data_(nullptr), 42 sparse_indices_(nullptr), 43 sparse_values_(nullptr), 44 input_spec_(input_spec), 45 split_sampling_random_seed_(seed) { 46 int column_count = 0; 47 for (int i = 0; i < input_spec_.dense_size(); ++i) { 48 for (int j = 0; j < input_spec_.dense(i).size(); ++j) { 49 ++column_count; 50 } 51 } 52 available_features_.reserve(column_count); 53 decision_trees::FeatureId id; 54 for (int i = 0; i < column_count; i++) { 55 id.mutable_id()->set_value(strings::StrCat(i)); 56 available_features_.emplace_back(id); 57 } 58 59 // Set up the random number generator. 60 if (split_sampling_random_seed_ == 0) { 61 single_rand_ = std::unique_ptr<random::PhiloxRandom>( 62 new random::PhiloxRandom(random::New64())); 63 } else { 64 single_rand_ = std::unique_ptr<random::PhiloxRandom>( 65 new random::PhiloxRandom(split_sampling_random_seed_)); 66 } 67 68 rng_ = std::unique_ptr<random::SimplePhilox>( 69 new random::SimplePhilox(single_rand_.get())); 70 } ~TensorDataSet()71 virtual ~TensorDataSet() {} 72 73 void set_input_tensors(const Tensor& dense, const Tensor& sparse_indices, 74 const Tensor& sparse_values, 75 const Tensor& sparse_shape); 76 get_input_value(int offset,int col)77 float get_input_value(int offset, int col) { 78 return (*dense_data_)(offset, col); 79 } 80 NumItems()81 int NumItems() const { 82 if (dense_data_ != nullptr) { 83 return dense_data_->dimensions()[0]; 84 } else if (sparse_indices_ != nullptr) { 85 return sparse_batch_size_; 86 } else { 87 return 0; 88 } 89 } 90 91 // This looks up a value by example and int32_id, which is much faster than 92 // GetFeature. 93 float GetExampleValue(int example, 94 const decision_trees::FeatureId& feature_id) const; 95 96 // Same as overload with FeatureId, but if you already have the feature as 97 // an int32 you can avoid the atoi32. 98 virtual float GetExampleValue(int example, int32 feature_id) const; 99 num_features()100 int num_features() { return available_features_.size(); } 101 original_tensor()102 const Tensor& original_tensor() const { return original_dense_tensor_; } 103 104 bool Decide(const decision_trees::BinaryNode& node, int example) const; 105 106 // Randomly samples a feature from example, returns its id in feature_name, 107 // the value in bias, and it's type from input_spec in type. 108 void RandomSample(int example, decision_trees::FeatureId* feature_name, 109 float* bias, int* type) const; 110 111 private: 112 std::unique_ptr<DenseStorageType> dense_data_; 113 std::unique_ptr<SparseIndicesStorageType> sparse_indices_; 114 std::unique_ptr<SparseValuesStorageType> sparse_values_; 115 int sparse_batch_size_; 116 117 Tensor original_dense_tensor_; 118 const tensorforest::TensorForestDataSpec input_spec_; 119 std::vector<decision_trees::FeatureId> available_features_; 120 121 int32 split_sampling_random_seed_; 122 std::unique_ptr<random::PhiloxRandom> single_rand_; 123 std::unique_ptr<random::SimplePhilox> rng_; 124 // Mutex for using random number generator. 125 mutable mutex mu_; 126 }; 127 } // namespace tensorforest 128 } // namespace tensorflow 129 130 #endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_ 131