1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // ============================================================================= 15 #ifndef TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ 16 #define TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ 17 #include "tensorflow/core/framework/tensor.h" 18 #include "tensorflow/core/framework/tensor_types.h" 19 20 namespace tensorflow { 21 namespace tensorforest { 22 23 typedef TTypes<float, 1>::UnalignedConstTensor SingleDimStorageType; 24 25 // Base class for classes that hold labels and weights. Mostly for testing 26 // purposes, because it's inconvenient to construct nasty Eigen::things. 27 class InputTarget { 28 public: ~InputTarget()29 virtual ~InputTarget() {} 30 virtual int32 GetTargetAsClassIndex(int example_index, 31 int target_index) const = 0; 32 33 virtual float GetTargetWeight(int example_index) const = 0; 34 35 virtual float GetTargetAsContinuous(int example_index, 36 int target_index) const = 0; 37 }; 38 39 template <typename T> 40 class StoredInputTarget : public InputTarget { 41 protected: 42 // Takes ownership of t and w with a std::unique_ptr. StoredInputTarget(const T * t,const T * w,int num_targets)43 StoredInputTarget(const T* t, const T* w, int num_targets) 44 : target_(t), weight_(w), num_targets_(num_targets) {} 45 46 const std::unique_ptr<const T> target_; 47 const std::unique_ptr<const T> weight_; 48 int num_targets_; 49 }; 50 51 // Holds labels/targets and weights. Assumes that tensors are passed as 52 // t.unaligned_flat<float>(). For multi-output, specifying the number of 53 // outputs will correctly index the flattened data. 54 class TensorInputTarget : public StoredInputTarget<SingleDimStorageType> { 55 public: TensorInputTarget(const Tensor & target,const Tensor & weight,int num_targets)56 TensorInputTarget(const Tensor& target, const Tensor& weight, int num_targets) 57 : StoredInputTarget( 58 new SingleDimStorageType(target.unaligned_flat<float>()), 59 new SingleDimStorageType(weight.unaligned_flat<float>()), 60 num_targets), 61 original_tensor_(target) {} 62 GetTargetAsClassIndex(int example_index,int target_index)63 int32 GetTargetAsClassIndex(int example_index, 64 int target_index) const override { 65 return static_cast<int32>( 66 GetTargetAsContinuous(example_index, target_index)); 67 } 68 GetTargetWeight(int example_index)69 float GetTargetWeight(int example_index) const override { 70 const size_t num_weights = weight_->size(); 71 return num_weights > 0 && example_index < num_weights 72 ? (*weight_)(example_index) 73 : 1.0; 74 } 75 GetTargetAsContinuous(int example_index,int target_index)76 float GetTargetAsContinuous(int example_index, 77 int target_index) const override { 78 QCHECK_LT(target_index, num_targets_); 79 return (*target_)(example_index * num_targets_ + target_index); 80 } 81 original_tensor()82 const Tensor& original_tensor() const { return original_tensor_; } 83 84 protected: 85 Tensor original_tensor_; 86 }; 87 } // namespace tensorforest 88 } // namespace tensorflow 89 90 #endif // TENSORFLOW_CONTRIB_TENSOR_FOREST_KERNELS_V4_INPUT_TARGET_H_ 91