// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stdlib.h>
#include <time.h>
#include <algorithm>
#include <cmath>
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/contrib/tensor_forest/hybrid/core/ops/utils.h"
#include "tensorflow/contrib/tensor_forest/kernels/tree_utils.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

using tensorforest::LeftProbability;

// This op computes the derivative of the routing loss with respect to each
// decision node.
REGISTER_OP("StochasticHardRoutingGradient")
    .Attr("tree_depth: int")
    .Input("input_data: float")
    .Input("tree_parameters: float")
    .Input("tree_biases: float")
    .Input("path_probability: float")
    .Input("path: int32")
    .Output("routing_gradient: float")
    .Output("data_gradient: float")
    .Output("parameter_gradient: float")
    .Output("bias_gradient: float")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input, params;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 2, &input));
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &params));

      auto num_points = c->Dim(input, 0);
      auto num_features = c->Dim(input, 1);
      auto num_nodes = c->Dim(params, 0);

      c->set_output(0, c->Matrix(num_points, num_nodes));
      c->set_output(1, c->Matrix(num_nodes, num_features));
      c->set_output(2, c->MakeShape({num_points, num_nodes, num_features}));
      c->set_output(3, c->Vector(num_nodes));
      return Status::OK();
    })
    .Doc(R"doc(
  Computes the derivative of the routing loss with respect to each decision
  node.

  tree_depth: The depth of the decision tree.

  input_data: The training batch's features as a 2-d tensor; `input_data[i][j]`
   gives the j-th feature of the i-th input.
  tree_parameters: `tree_parameters[i]` gives the weight of
   the logistic regression model that translates from node features to
   probabilities.
  tree_biases: `tree_biases[i]` gives the bias of the logistic
   regression model that translates from node features to
   probabilities.
  path_probability: `path_probability[i]` gives the probability of reaching each
   node in `path[i]`.
  path: `path[i][j]` gives the j-th node in the path taken by the i-th data
   instance.

  routing_gradient: `routing_gradient` provides du / df, where u is the routing
   function and f is the (vector of) decision functions.  A decision function
   f_i computes the routing decision at node i.
  data_gradient: `data_gradient` provides df / dx, where f is the (vector
   of) decision functions and x is a batch of data.
  parameter_gradient: `parameter_gradient` provides df / dw, where f is the
   (vector of) decision functions and w is the matrix of parameters that
   determine how instances are routed through a tree.
  bias_gradient: `bias_gradient` provides df / db, where f is the
   (vector of) decision functions and b is the vector of bias parameters that
   determine how instances are routed through a tree.

  f_i is parameterized by t_i (parameters) and b_i (bias) and takes data x as
  input.  This op is called in training_ops.py to compute du / df, and we use
  that to compute

     du / dx = du / df * df / dx,
     du / dt = du / df * df / dt, and
     du / db = du / df * df / db.
)doc");

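// Note on tree layout: the kernel below appears to assume a complete binary
// tree stored heap-style in node-id order, so that the left child of node n
// is 2 * n + 1 and any node with id >= num_nodes / 2 is a leaf.  For example,
// with num_nodes = 7 the root is node 0, its children are nodes 1 and 2, and
// nodes 3 through 6 are leaves.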
class StochasticHardRoutingGradient : public OpKernel {
 public:
  explicit StochasticHardRoutingGradient(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("tree_depth", &tree_depth_));
  }

  void Compute(OpKernelContext* context) override {
    VLOG(1) << "stochastic gradient start";
    const Tensor& input_data = context->input(0);
    const Tensor& tree_parameters_tensor = context->input(1);
    const Tensor& tree_biases_tensor = context->input(2);

    const Tensor& path_probability_tensor = context->input(3);
    const Tensor& path_tensor = context->input(4);

    const int32 num_data = static_cast<int32>(input_data.shape().dim_size(0));
    const int32 num_features =
        static_cast<int32>(input_data.shape().dim_size(1));
    const int32 num_nodes =
        static_cast<int32>(tree_parameters_tensor.shape().dim_size(0));

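    // Allocate the four outputs with the shapes declared by the op's shape
    // function above.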
    Tensor* output_routing = nullptr;
    TensorShape output_routing_shape;
    output_routing_shape.AddDim(num_data);
    output_routing_shape.AddDim(num_nodes);

    Tensor* output_data = nullptr;
    TensorShape output_data_shape;
    output_data_shape.AddDim(num_nodes);
    output_data_shape.AddDim(num_features);

    Tensor* output_parameters = nullptr;
    TensorShape output_parameters_shape;
    output_parameters_shape.AddDim(num_data);
    output_parameters_shape.AddDim(num_nodes);
    output_parameters_shape.AddDim(num_features);

    // The bias gradient is accumulated per node, so it needs num_nodes
    // entries to match the Vector(num_nodes) declared in the shape function.
    Tensor* output_bias = nullptr;
    TensorShape output_bias_shape;
    output_bias_shape.AddDim(num_nodes);
    OP_REQUIRES_OK(context, context->allocate_output(0, output_routing_shape,
                                                     &output_routing));
    OP_REQUIRES_OK(
        context, context->allocate_output(1, output_data_shape, &output_data));
    OP_REQUIRES_OK(context, context->allocate_output(2, output_parameters_shape,
                                                     &output_parameters));
    OP_REQUIRES_OK(
        context, context->allocate_output(3, output_bias_shape, &output_bias));

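    // Zero-initialize all four outputs: the loop below only writes entries
    // for nodes that appear on an instance's sampled path.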
    tensorforest::Initialize(*output_routing, 0.0);
    tensorforest::Initialize(*output_data, 0.0);
    tensorforest::Initialize(*output_parameters, 0.0);
    tensorforest::Initialize(*output_bias, 0.0);

    auto out_routing = output_routing->tensor<float, 2>();
    auto out_data = output_data->tensor<float, 2>();
    auto out_parameters = output_parameters->tensor<float, 3>();
    auto out_bias = output_bias->tensor<float, 1>();

    const auto data = input_data.tensor<float, 2>();
    const auto tree_parameters = tree_parameters_tensor.tensor<float, 2>();
    const auto tree_biases = tree_biases_tensor.tensor<float, 1>();
    const auto path_probability = path_probability_tensor.tensor<float, 2>();
    const auto path = path_tensor.tensor<int32, 2>();

    for (int i = 0; i < num_data; i++) {
      const Tensor point = input_data.Slice(i, i + 1);

      // Traverses the tree from the bottom up.
      for (int j = tree_depth_ - 1; j > -1; j--) {
        int32 node = path(i, j);

        CHECK_LT(node, num_nodes);
        CHECK_GT(node, -1);

        // Compute data, parameter, and bias gradients.  These are the
        // gradients of the pre-activation w . x + b at this node:
        // d/dx = w, d/dw = x, and d/db = 1.
        // TODO(atwoodj): Should these be normalized?  Loss looks pretty large.
        for (int k = 0; k < num_features; k++) {
          out_data(node, k) = tree_parameters(node, k);
          out_parameters(i, node, k) = out_parameters(i, node, k) + data(i, k);
        }
        out_bias(node) = out_bias(node) + 1.0;

        // Compute the decision gradient.
        if (node >= num_nodes / 2) {  // node is a leaf
          CHECK_LT(node, num_nodes);
          out_routing(i, node) = path_probability(i, j);
        } else {  // node is not a leaf
          int32 left_child = 2 * j + 1;

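          // LeftProbability (from hybrid/core/ops/utils.h) is assumed here to
          // return the probability of routing left, i.e. the sigmoid of the
          // decision function w . x + b for the given parameter row and bias.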
          float left_prob =
              LeftProbability(point, tree_parameters_tensor.Slice(j, j + 1),
                              tree_biases(j), num_features);

          float right_prob = 1 - left_prob;

          CHECK_GT(j - 1, -1);
          if (path(i, j - 1) == left_child) {
            CHECK_LT(node, num_nodes);
            out_routing(i, node) = right_prob * path_probability(i, j - 1);
          } else {
            CHECK_LT(node, num_nodes);
            out_routing(i, node) = left_prob * path_probability(i, j - 1);
          }
        }
      }
    }
    VLOG(1) << "stochastic gradient end";
  }

 private:
  int32 tree_depth_;
};

REGISTER_KERNEL_BUILDER(
    Name("StochasticHardRoutingGradient").Device(DEVICE_CPU),
    StochasticHardRoutingGradient);
}  // namespace tensorflow