1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
17 #include <ctime>
18 #include <unordered_map>
19 #include "google/protobuf/any.pb.h"
20 #include "google/protobuf/wrappers.pb.h"
21 #include "tensorflow/contrib/decision_trees/proto/generic_tree_model.pb.h"
22 #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
23 #include "tensorflow/core/framework/tensor.h"
24 #include "tensorflow/core/framework/tensor_types.h"
25 #include "tensorflow/core/lib/random/philox_random.h"
26 #include "tensorflow/core/lib/random/random.h"
27 #include "tensorflow/core/lib/random/simple_philox.h"
28 #include "tensorflow/core/platform/mutex.h"
29 
30 namespace tensorflow {
31 namespace tensorforest {
32 
33 typedef TTypes<const float, 2>::ConstTensor DenseStorageType;
34 typedef TTypes<const int64, 2>::ConstTensor SparseIndicesStorageType;
35 typedef TTypes<const float, 1>::ConstTensor SparseValuesStorageType;
36 
37 class TensorDataSet {
38  public:
TensorDataSet(const tensorforest::TensorForestDataSpec & input_spec,int32 seed)39   TensorDataSet(const tensorforest::TensorForestDataSpec& input_spec,
40                 int32 seed)
41       : dense_data_(nullptr),
42         sparse_indices_(nullptr),
43         sparse_values_(nullptr),
44         input_spec_(input_spec),
45         split_sampling_random_seed_(seed) {
46     int column_count = 0;
47     for (int i = 0; i < input_spec_.dense_size(); ++i) {
48       for (int j = 0; j < input_spec_.dense(i).size(); ++j) {
49         ++column_count;
50       }
51     }
52     available_features_.reserve(column_count);
53     decision_trees::FeatureId id;
54     for (int i = 0; i < column_count; i++) {
55       id.mutable_id()->set_value(strings::StrCat(i));
56       available_features_.emplace_back(id);
57     }
58 
59     // Set up the random number generator.
60     if (split_sampling_random_seed_ == 0) {
61       single_rand_ = std::unique_ptr<random::PhiloxRandom>(
62           new random::PhiloxRandom(random::New64()));
63     } else {
64       single_rand_ = std::unique_ptr<random::PhiloxRandom>(
65           new random::PhiloxRandom(split_sampling_random_seed_));
66     }
67 
68     rng_ = std::unique_ptr<random::SimplePhilox>(
69         new random::SimplePhilox(single_rand_.get()));
70   }
~TensorDataSet()71   virtual ~TensorDataSet() {}
72 
73   void set_input_tensors(const Tensor& dense, const Tensor& sparse_indices,
74                          const Tensor& sparse_values,
75                          const Tensor& sparse_shape);
76 
get_input_value(int offset,int col)77   float get_input_value(int offset, int col) {
78     return (*dense_data_)(offset, col);
79   }
80 
NumItems()81   int NumItems() const {
82     if (dense_data_ != nullptr) {
83       return dense_data_->dimensions()[0];
84     } else if (sparse_indices_ != nullptr) {
85       return sparse_batch_size_;
86     } else {
87       return 0;
88     }
89   }
90 
91   // This looks up a value by example and int32_id, which is much faster than
92   // GetFeature.
93   float GetExampleValue(int example,
94                         const decision_trees::FeatureId& feature_id) const;
95 
96   // Same as overload with FeatureId, but if you already have the feature as
97   // an int32 you can avoid the atoi32.
98   virtual float GetExampleValue(int example, int32 feature_id) const;
99 
num_features()100   int num_features() { return available_features_.size(); }
101 
original_tensor()102   const Tensor& original_tensor() const { return original_dense_tensor_; }
103 
104   bool Decide(const decision_trees::BinaryNode& node, int example) const;
105 
106   // Randomly samples a feature from example, returns its id in feature_name,
107   // the value in bias, and it's type from input_spec in type.
108   void RandomSample(int example, decision_trees::FeatureId* feature_name,
109                     float* bias, int* type) const;
110 
111  private:
112   std::unique_ptr<DenseStorageType> dense_data_;
113   std::unique_ptr<SparseIndicesStorageType> sparse_indices_;
114   std::unique_ptr<SparseValuesStorageType> sparse_values_;
115   int sparse_batch_size_;
116 
117   Tensor original_dense_tensor_;
118   const tensorforest::TensorForestDataSpec input_spec_;
119   std::vector<decision_trees::FeatureId> available_features_;
120 
121   int32 split_sampling_random_seed_;
122   std::unique_ptr<random::PhiloxRandom> single_rand_;
123   std::unique_ptr<random::SimplePhilox> rng_;
124   // Mutex for using random number generator.
125   mutable mutex mu_;
126 };
127 }  // namespace tensorforest
128 }  // namespace tensorflow
129 
130 #endif  // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_DATA_H_
131