// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stdlib.h>
#include <time.h>
#include <algorithm>
#include <cmath>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h"
#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

using tensorforest::LeftProbability;

// This op computes the derivative of the routing loss with respect to each
// decision node.
REGISTER_OP("StochasticHardRoutingGradient")
    .Attr("tree_depth: int")
    .Input("input_data: float")
    .Input("tree_parameters: float")
    .Input("tree_biases: float")
    .Input("path_probability: float")
    .Input("path: int32")
    .Output("routing_gradient: float")
    .Output("data_gradient: float")
    .Output("parameter_gradient: float")
    .Output("bias_gradient: float")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input, params;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &params));

      auto num_points = c->Dim(input, 0);
      auto num_features = c->Dim(input, 1);
      auto num_nodes = c->Dim(params, 0);

      c->set_output(0, c->Matrix(num_points, num_nodes));
      c->set_output(1, c->Matrix(num_nodes, num_features));
      c->set_output(2, c->MakeShape({num_points, num_nodes, num_features}));
      c->set_output(3, c->Vector(num_nodes));
      return Status::OK();
    })
    .Doc(R"doc(
Computes the derivative of the routing loss with respect to each decision
node.

tree_depth: The depth of the decision tree.

input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
  gives the j-th feature of the i-th input.
tree_parameters: `tree_parameters[i]` gives the weight of
  the logistic regression model that translates from node features to
  probabilities.
tree_biases: `tree_biases[i]` gives the bias of the logistic
  regression model that translates from node features to
  probabilities.
path_probability: `path_probability[i]` gives the probability of reaching each
  node in `path[i]`.
path: `path[i][j]` gives the j-th node in the path taken by the i-th data
  instance.

routing_gradient: `routing_gradient` provides du / df, where u is the routing
  function and f is the (vector of) decision functions. A decision function
  f_i computes the routing decision at node i.
data_gradient: `data_gradient` provides df / dx, where f is the (vector
  of) decision functions and x is a batch of data.
parameter_gradient: `parameter_gradient` provides df / dw, where f is the
  (vector of) decision functions and w is the matrix of parameters that
  determine how instances are routed through a tree.
bias_gradient: `bias_gradient` provides df / db, where f is the
  (vector of) decision functions and b is the vector of bias parameters that
  determine how instances are routed through a tree.

f_i is parameterized by t_i (parameters) and b_i (bias) and takes data x as
input. This op is called in training_ops.py to compute du / df, and we use
that to compute

  du / dx = du / df * df / dx,
  du / dt = du / df * df / dt, and
  du / db = du / df * df / db.
)doc");
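
// For intuition, a minimal illustrative sketch of the per-node decision
// function the doc above assumes: a logistic model over a node's weight
// vector and bias. This helper is hypothetical and unused; the real
// computation is tensorforest::LeftProbability (pulled in by the includes
// above), whose signature operates on Tensor slices instead.
namespace {
// Probability of routing an instance x to the left child of a node with
// weights w and bias b: sigmoid(w . x + b). Assumes w and x have equal size.
inline float LogisticLeftProbabilitySketch(const std::vector<float>& w,
                                           const std::vector<float>& x,
                                           float b) {
  float pre_activation = b;
  for (size_t k = 0; k < w.size() && k < x.size(); ++k) {
    pre_activation += w[k] * x[k];
  }
  return 1.0f / (1.0f + std::exp(-pre_activation));
}
}  // namespace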

class StochasticHardRoutingGradient : public OpKernel {
 public:
  explicit StochasticHardRoutingGradient(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("tree_depth", &tree_depth_));
  }

  void Compute(OpKernelContext* context) override {
    VLOG(1) << "stochastic gradient start";
    const Tensor& input_data = context->input(0);
    const Tensor& tree_parameters_tensor = context->input(1);
    const Tensor& tree_biases_tensor = context->input(2);

    const Tensor& path_probability_tensor = context->input(3);
    const Tensor& path_tensor = context->input(4);

    const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
    const int32 num_features =
        static_cast<int32>(input_data.shape().dim_size(1));
    const int32 num_nodes =
        static_cast<int32>(tree_parameters_tensor.shape().dim_size(0));

    Tensor* output_routing = nullptr;
    TensorShape output_routing_shape;
    output_routing_shape.AddDim(num_data);
    output_routing_shape.AddDim(num_nodes);

    Tensor* output_data = nullptr;
    TensorShape output_data_shape;
    output_data_shape.AddDim(num_nodes);
    output_data_shape.AddDim(num_features);

    Tensor* output_parameters = nullptr;
    TensorShape output_parameters_shape;
    output_parameters_shape.AddDim(num_data);
    output_parameters_shape.AddDim(num_nodes);
    output_parameters_shape.AddDim(num_features);

    // The bias gradient is per decision node: this matches the
    // Vector(num_nodes) declared in the shape function and the
    // out_bias(node) writes below, which index up to num_nodes.
    Tensor* output_bias = nullptr;
    TensorShape output_bias_shape;
    output_bias_shape.AddDim(num_nodes);

    OP_REQUIRES_OK(context, context->allocate_output(0, output_routing_shape,
                                                     &output_routing));
    OP_REQUIRES_OK(
        context, context->allocate_output(1, output_data_shape, &output_data));
    OP_REQUIRES_OK(context,
                   context->allocate_output(2, output_parameters_shape,
                                            &output_parameters));
    OP_REQUIRES_OK(
        context, context->allocate_output(3, output_bias_shape, &output_bias));

    tensorforest::Initialize(*output_routing, 0.0);
    tensorforest::Initialize(*output_data, 0.0);
    tensorforest::Initialize(*output_parameters, 0.0);
    tensorforest::Initialize(*output_bias, 0.0);
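
    // Eigen views of the outputs, written below. Shapes (matching the shape
    // function above) and the quantity each one carries:
    //   routing_gradient:   [num_data, num_nodes]                 du / df
    //   data_gradient:      [num_nodes, num_features]             df / dx
    //   parameter_gradient: [num_data, num_nodes, num_features]   df / dw
    //   bias_gradient:      [num_nodes]                           df / db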
    auto out_routing = output_routing->tensor<float, 2>();
    auto out_data = output_data->tensor<float, 2>();
    auto out_parameters = output_parameters->tensor<float, 3>();
    auto out_bias = output_bias->tensor<float, 1>();

    const auto data = input_data.tensor<float, 2>();
    const auto tree_parameters = tree_parameters_tensor.tensor<float, 2>();
    const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
    const auto path_probability = path_probability_tensor.tensor<float, 2>();
    const auto path = path_tensor.tensor<int32, 2>();

    for (int i = 0; i < num_data; i++) {
      const Tensor point = input_data.Slice(i, i + 1);

      // Traverses the tree from the bottom up.
      for (int j = tree_depth_ - 1; j > -1; j--) {
        // path(i, j) is the id of the j-th node visited by instance i, so
        // `node` indexes rows of tree_parameters, not depths.
        const int32 node = path(i, j);

        CHECK_LT(node, num_nodes);
        CHECK_GT(node, -1);

        // Compute data, parameter, and bias gradients.
        // TODO(atwoodj): Should these be normalized? Loss looks pretty large.
        for (int k = 0; k < num_features; k++) {
          out_data(node, k) = tree_parameters(node, k);
          out_parameters(i, node, k) += data(i, k);
        }
        out_bias(node) += 1.0;

        // Compute decision gradient.
        if (node >= num_nodes / 2) {  // node is a leaf
          CHECK_LT(node, num_nodes);
          out_routing(i, node) = path_probability(i, j);
        } else {  // node is not a leaf
          CHECK_GT(j - 1, -1);
          // The decision that routed the instance to `node` was made at the
          // previous node on the path, so look up that node's id, parameters,
          // and bias (not the raw depth index j). In a complete binary tree
          // the left child of node n is node 2n + 1.
          const int32 prev = path(i, j - 1);
          const int32 left_child = 2 * prev + 1;

          const float left_prob = LeftProbability(
              point, tree_parameters_tensor.Slice(prev, prev + 1),
              tree_biases(prev), num_features);

          const float right_prob = 1 - left_prob;

          CHECK_LT(node, num_nodes);
          if (node == left_child) {
            out_routing(i, node) = right_prob * path_probability(i, j - 1);
          } else {
            out_routing(i, node) = left_prob * path_probability(i, j - 1);
          }
        }
      }
    }
    VLOG(1) << "stochastic gradient end";
  }

 private:
  int32 tree_depth_;
};

REGISTER_KERNEL_BUILDER(
    Name("StochasticHardRoutingGradient").Device(DEVICE_CPU),
    StochasticHardRoutingGradient);
}  // namespace tensorflow
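
// Sketch (illustrative only): how this op's outputs are meant to be combined,
// per the chain rule in the op Doc. The actual reduction happens on the
// Python side in training_ops.py with TensorFlow ops; the index expressions
// below are hypothetical pseudocode, not TensorFlow API.
//
//   du/dx[i][k]    = sum over n of routing_gradient[i][n] * data_gradient[n][k]
//   du/dt[i][n][k] = routing_gradient[i][n] * parameter_gradient[i][n][k]
//   du/db[n]       = routing_gradient[i][n] * bias_gradient[n]
//                    (summed over instances i where appropriate)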