1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include <queue>
16 
17 #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
18 #include "tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h"
19 #include "tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h"
20 #include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
21 #include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
22 #include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
23 #include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/resource_mgr.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/tensor_types.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/thread_annotations.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/core/util/work_sharder.h"
35 
36 namespace tensorflow {
37 namespace tensorforest {
38 
39 using gtl::FindOrNull;
40 
41 // Creates a stats variable.
42 class CreateFertileStatsVariableOp : public OpKernel {
43  public:
CreateFertileStatsVariableOp(OpKernelConstruction * context)44   explicit CreateFertileStatsVariableOp(OpKernelConstruction* context)
45       : OpKernel(context) {
46     string serialized_params;
47     OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
48     ParseProtoUnlimited(&param_proto_, serialized_params);
49   }
50 
Compute(OpKernelContext * context)51   void Compute(OpKernelContext* context) override {
52     const Tensor* stats_config_t;
53     OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t));
54     OP_REQUIRES(context, TensorShapeUtils::IsScalar(stats_config_t->shape()),
55                 errors::InvalidArgument("Stats config must be a scalar."));
56     auto* result = new FertileStatsResource(param_proto_);
57     FertileStats stats;
58     if (!ParseProtoUnlimited(&stats, stats_config_t->scalar<string>()())) {
59       result->Unref();
60       OP_REQUIRES(context, false,
61                   errors::InvalidArgument("Unable to parse stats config."));
62     }
63 
64     result->ExtractFromProto(stats);
65     result->MaybeInitialize();
66 
67     // Only create one, if one does not exist already. Report status for all
68     // other exceptions.
69     auto status = CreateResource(context, HandleFromInput(context, 0), result);
70     if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
71       OP_REQUIRES(context, false, status);
72     }
73   }
74 
75  private:
76   TensorForestParams param_proto_;
77 };
78 
79 // Op for serializing a model.
80 class FertileStatsSerializeOp : public OpKernel {
81  public:
FertileStatsSerializeOp(OpKernelConstruction * context)82   explicit FertileStatsSerializeOp(OpKernelConstruction* context)
83       : OpKernel(context) {
84     string serialized_params;
85     OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
86     ParseProtoUnlimited(&param_proto_, serialized_params);
87   }
88 
Compute(OpKernelContext * context)89   void Compute(OpKernelContext* context) override {
90     FertileStatsResource* fertile_stats_resource;
91     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
92                                            &fertile_stats_resource));
93     mutex_lock l(*fertile_stats_resource->get_mutex());
94     core::ScopedUnref unref_me(fertile_stats_resource);
95     Tensor* output_config_t = nullptr;
96     OP_REQUIRES_OK(
97         context, context->allocate_output(0, TensorShape(), &output_config_t));
98 
99     FertileStats stats;
100     fertile_stats_resource->PackToProto(&stats);
101     output_config_t->scalar<string>()() = stats.SerializeAsString();
102   }
103 
104  private:
105   TensorForestParams param_proto_;
106 };
107 
108 // Op for deserializing a stats variable from a checkpoint.
109 class FertileStatsDeserializeOp : public OpKernel {
110  public:
FertileStatsDeserializeOp(OpKernelConstruction * context)111   explicit FertileStatsDeserializeOp(OpKernelConstruction* context)
112       : OpKernel(context) {
113     string serialized_params;
114     OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
115     ParseProtoUnlimited(&param_proto_, serialized_params);
116   }
117 
Compute(OpKernelContext * context)118   void Compute(OpKernelContext* context) override {
119     FertileStatsResource* fertile_stats_resource;
120     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
121                                            &fertile_stats_resource));
122     mutex_lock l(*fertile_stats_resource->get_mutex());
123     core::ScopedUnref unref_me(fertile_stats_resource);
124 
125     const Tensor* stats_config_t;
126     OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t));
127     OP_REQUIRES(context, TensorShapeUtils::IsScalar(stats_config_t->shape()),
128                 errors::InvalidArgument("Stats config must be a scalar."));
129     // Deallocate all the previous objects on the resource.
130     fertile_stats_resource->Reset();
131     FertileStats stats;
132     OP_REQUIRES(context,
133                 ParseProtoUnlimited(&stats, stats_config_t->scalar<string>()()),
134                 errors::InvalidArgument("Unable to parse stats config."));
135 
136     fertile_stats_resource->ExtractFromProto(stats);
137     fertile_stats_resource->MaybeInitialize();
138   }
139 
140  private:
141   TensorForestParams param_proto_;
142 };
143 
144 // Try to update a leaf's stats by acquiring its lock.  If it can't be
145 // acquired, put it in a waiting queue to come back to later and try the next
146 // one.  Once all leaf_ids have been visited, cycle through the waiting ids
147 // until they're gone.
UpdateStats(FertileStatsResource * fertile_stats_resource,const std::unique_ptr<TensorDataSet> & data,const TensorInputTarget & target,int num_targets,const Tensor & leaf_ids_tensor,std::unordered_map<int32,std::unique_ptr<mutex>> * locks,mutex * set_lock,int32 start,int32 end,std::unordered_set<int32> * ready_to_split)148 void UpdateStats(FertileStatsResource* fertile_stats_resource,
149                  const std::unique_ptr<TensorDataSet>& data,
150                  const TensorInputTarget& target, int num_targets,
151                  const Tensor& leaf_ids_tensor,
152                  std::unordered_map<int32, std::unique_ptr<mutex>>* locks,
153                  mutex* set_lock, int32 start, int32 end,
154                  std::unordered_set<int32>* ready_to_split) {
155   const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
156 
157   // Stores leaf_id, leaf_depth, example_id for examples that are waiting
158   // on another to finish.
159   std::queue<std::tuple<int32, int32>> waiting;
160 
161   int32 i = start;
162   while (i < end || !waiting.empty()) {
163     int32 leaf_id;
164     int32 example_id;
165     bool was_waiting = false;
166     if (i >= end) {
167       std::tie(leaf_id, example_id) = waiting.front();
168       waiting.pop();
169       was_waiting = true;
170     } else {
171       leaf_id = leaf_ids(i);
172       example_id = i;
173       ++i;
174     }
175     const std::unique_ptr<mutex>& leaf_lock = (*locks)[leaf_id];
176     if (was_waiting) {
177       leaf_lock->lock();
178     } else {
179       if (!leaf_lock->try_lock()) {
180         waiting.emplace(leaf_id, example_id);
181         continue;
182       }
183     }
184 
185     bool is_finished;
186     fertile_stats_resource->AddExampleToStatsAndInitialize(
187         data, &target, {example_id}, leaf_id, &is_finished);
188     leaf_lock->unlock();
189     if (is_finished) {
190       set_lock->lock();
191       ready_to_split->insert(leaf_id);
192       set_lock->unlock();
193     }
194   }
195 }
196 
197 // Update leaves from start through end in the leaf_examples iterator.
UpdateStatsCollated(FertileStatsResource * fertile_stats_resource,DecisionTreeResource * tree_resource,const std::unique_ptr<TensorDataSet> & data,const TensorInputTarget & target,int num_targets,const std::unordered_map<int32,std::vector<int>> & leaf_examples,mutex * set_lock,int32 start,int32 end,std::unordered_set<int32> * ready_to_split)198 void UpdateStatsCollated(
199     FertileStatsResource* fertile_stats_resource,
200     DecisionTreeResource* tree_resource,
201     const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target,
202     int num_targets,
203     const std::unordered_map<int32, std::vector<int>>& leaf_examples,
204     mutex* set_lock, int32 start, int32 end,
205     std::unordered_set<int32>* ready_to_split) {
206   auto it = leaf_examples.begin();
207   std::advance(it, start);
208   auto end_it = leaf_examples.begin();
209   std::advance(end_it, end);
210   while (it != end_it) {
211     int32 leaf_id = it->first;
212     bool is_finished;
213     fertile_stats_resource->AddExampleToStatsAndInitialize(
214         data, &target, it->second, leaf_id, &is_finished);
215     if (is_finished) {
216       set_lock->lock();
217       ready_to_split->insert(leaf_id);
218       set_lock->unlock();
219     }
220     ++it;
221   }
222 }
223 
224 // Op for traversing the tree with each example, accumulating statistics, and
225 // outputting node ids that are ready to split.
226 class ProcessInputOp : public OpKernel {
227  public:
ProcessInputOp(OpKernelConstruction * context)228   explicit ProcessInputOp(OpKernelConstruction* context) : OpKernel(context) {
229     string serialized_params;
230     OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
231     ParseProtoUnlimited(&param_proto_, serialized_params);
232 
233     OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_));
234 
235     string serialized_proto;
236     OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
237     input_spec_.ParseFromString(serialized_proto);
238   }
239 
Compute(OpKernelContext * context)240   void Compute(OpKernelContext* context) override {
241     const Tensor& input_data = context->input(2);
242     const Tensor& sparse_input_indices = context->input(3);
243     const Tensor& sparse_input_values = context->input(4);
244     const Tensor& sparse_input_shape = context->input(5);
245     const Tensor& input_labels = context->input(6);
246     const Tensor& input_weights = context->input(7);
247     const Tensor& leaf_ids_tensor = context->input(8);
248 
249     std::unique_ptr<TensorDataSet> data_set(
250         new TensorDataSet(input_spec_, random_seed_));
251     data_set->set_input_tensors(input_data, sparse_input_indices,
252                                 sparse_input_values, sparse_input_shape);
253 
254     FertileStatsResource* fertile_stats_resource;
255     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
256                                            &fertile_stats_resource));
257     DecisionTreeResource* tree_resource;
258     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
259                                            &tree_resource));
260     mutex_lock l1(*fertile_stats_resource->get_mutex());
261     mutex_lock l2(*tree_resource->get_mutex());
262 
263     core::ScopedUnref unref_stats(fertile_stats_resource);
264     core::ScopedUnref unref_tree(tree_resource);
265 
266     const int32 num_data = data_set->NumItems();
267     auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
268     int num_threads = worker_threads->num_threads;
269 
270     const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
271 
272     // Create one mutex per leaf. We need to protect access to leaf pointers,
273     // so instead of grouping examples by leaf, we spread examples out among
274     // threads to provide uniform work for each of them and protect access
275     // with mutexes.
276     std::unordered_map<int, std::unique_ptr<mutex>> locks;
277     std::unordered_map<int32, std::vector<int>> leaf_examples;
278     if (param_proto_.collate_examples()) {
279       for (int i = 0; i < num_data; ++i) {
280         leaf_examples[leaf_ids(i)].push_back(i);
281       }
282     } else {
283       for (int i = 0; i < num_data; ++i) {
284         const int32 id = leaf_ids(i);
285         if (FindOrNull(locks, id) == nullptr) {
286           // TODO(gilberth): Consider using a memory pool for these.
287           locks[id] = std::unique_ptr<mutex>(new mutex);
288         }
289       }
290     }
291 
292     const int32 num_leaves = leaf_examples.size();
293     const int32 label_dim =
294         input_labels.shape().dims() <= 1
295             ? 0
296             : static_cast<int>(input_labels.shape().dim_size(1));
297     const int32 num_targets =
298         param_proto_.is_regression() ? (std::max(1, label_dim)) : 1;
299 
300     // Ids of leaves that can split.
301     std::unordered_set<int32> ready_to_split;
302     mutex set_lock;
303 
304     TensorInputTarget target(input_labels, input_weights, num_targets);
305 
306     // TODO(gilberth): This is a rough approximation based on measurements
307     // from a digits run on local desktop.  Heuristics might be necessary
308     // if it really matters that much.
309     const int64 costPerUpdate = 1000;
310     auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set,
311                    fertile_stats_resource, &locks, &set_lock, &ready_to_split,
312                    num_data](int64 start, int64 end) {
313       CHECK(start <= end);
314       CHECK(end <= num_data);
315       UpdateStats(fertile_stats_resource, data_set, target, num_targets,
316                   leaf_ids_tensor, &locks, &set_lock, static_cast<int32>(start),
317                   static_cast<int32>(end), &ready_to_split);
318     };
319 
320     auto update_collated = [&target, &num_targets, fertile_stats_resource,
321                             tree_resource, &leaf_examples, &set_lock,
322                             &ready_to_split, &data_set,
323                             num_leaves](int64 start, int64 end) {
324       CHECK(start <= end);
325       CHECK(end <= num_leaves);
326       UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set,
327                           target, num_targets, leaf_examples, &set_lock,
328                           static_cast<int32>(start), static_cast<int32>(end),
329                           &ready_to_split);
330     };
331 
332     if (param_proto_.collate_examples()) {
333       Shard(num_threads, worker_threads->workers, num_leaves, costPerUpdate,
334             update_collated);
335     } else {
336       Shard(num_threads, worker_threads->workers, num_data, costPerUpdate,
337             update);
338     }
339 
340     Tensor* output_finished_t = nullptr;
341     TensorShape output_shape;
342     output_shape.AddDim(ready_to_split.size());
343     OP_REQUIRES_OK(
344         context, context->allocate_output(0, output_shape, &output_finished_t));
345     auto output = output_finished_t->unaligned_flat<int32>();
346     std::copy(ready_to_split.begin(), ready_to_split.end(), output.data());
347   }
348 
349  private:
350   int32 random_seed_;
351   tensorforest::TensorForestDataSpec input_spec_;
352   TensorForestParams param_proto_;
353 };
354 
355 // Op for growing finished nodes.
356 class GrowTreeOp : public OpKernel {
357  public:
GrowTreeOp(OpKernelConstruction * context)358   explicit GrowTreeOp(OpKernelConstruction* context) : OpKernel(context) {
359     string serialized_params;
360     OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
361     ParseProtoUnlimited(&param_proto_, serialized_params);
362   }
363 
Compute(OpKernelContext * context)364   void Compute(OpKernelContext* context) override {
365     FertileStatsResource* fertile_stats_resource;
366     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
367                                            &fertile_stats_resource));
368     DecisionTreeResource* tree_resource;
369     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
370                                            &tree_resource));
371     mutex_lock l1(*fertile_stats_resource->get_mutex());
372     mutex_lock l2(*tree_resource->get_mutex());
373 
374     core::ScopedUnref unref_stats(fertile_stats_resource);
375     core::ScopedUnref unref_tree(tree_resource);
376 
377     const Tensor& finished_nodes = context->input(2);
378 
379     const auto finished = finished_nodes.unaligned_flat<int32>();
380 
381     const int32 num_nodes =
382         static_cast<int32>(finished_nodes.shape().dim_size(0));
383 
384     // This op takes so little of the time for one batch that it isn't worth
385     // threading this.
386     for (int i = 0;
387          i < num_nodes &&
388          tree_resource->decision_tree().decision_tree().nodes_size() <
389              param_proto_.max_nodes();
390          ++i) {
391       const int32 node = finished(i);
392       std::unique_ptr<SplitCandidate> best(new SplitCandidate);
393       int32 parent_depth;
394       // TODO(gilberth): Pushing these to an output would allow the complete
395       // decoupling of tree from resource.
396       bool found =
397           fertile_stats_resource->BestSplit(node, best.get(), &parent_depth);
398       if (found) {
399         std::vector<int32> new_children;
400         tree_resource->SplitNode(node, best.get(), &new_children);
401         fertile_stats_resource->Allocate(parent_depth, new_children);
402         // We are done with best, so it is now safe to clear node.
403         fertile_stats_resource->Clear(node);
404         CHECK(tree_resource->get_mutable_tree_node(node)->has_leaf() == false);
405       } else {  // reset
406         fertile_stats_resource->ResetSplitStats(node, parent_depth);
407       }
408     }
409   }
410 
411  private:
412   tensorforest::TensorForestDataSpec input_spec_;
413   TensorForestParams param_proto_;
414 };
415 
FinalizeLeaf(bool is_regression,bool drop_final_class,const std::unique_ptr<LeafModelOperator> & leaf_op,decision_trees::Leaf * leaf)416 void FinalizeLeaf(bool is_regression, bool drop_final_class,
417                   const std::unique_ptr<LeafModelOperator>& leaf_op,
418                   decision_trees::Leaf* leaf) {
419   // regression models are already stored in leaf in normalized form.
420   if (is_regression) {
421     return;
422   }
423 
424   // TODO(gilberth): Calculate the leaf's sum.
425   float sum = 0;
426   LOG(FATAL) << "FinalizeTreeOp is disabled for now.";
427   if (sum <= 0.0) {
428     LOG(WARNING) << "Leaf with sum " << sum << " has stats "
429                  << leaf->ShortDebugString();
430     return;
431   }
432 
433   if (leaf->has_vector()) {
434     for (int i = 0; i < leaf->vector().value_size(); i++) {
435       auto* v = leaf->mutable_vector()->mutable_value(i);
436       v->set_float_value(v->float_value() / sum);
437     }
438     if (drop_final_class) {
439       leaf->mutable_vector()->mutable_value()->RemoveLast();
440     }
441     return;
442   }
443 
444   if (leaf->has_sparse_vector()) {
445     for (auto& it : *leaf->mutable_sparse_vector()->mutable_sparse_value()) {
446       it.second.set_float_value(it.second.float_value() / sum);
447     }
448     return;
449   }
450 
451   LOG(FATAL) << "Unknown leaf type in " << leaf->DebugString();
452 }
453 
454 // Op for finalizing a tree at the end of training.
455 class FinalizeTreeOp : public OpKernel {
456  public:
FinalizeTreeOp(OpKernelConstruction * context)457   explicit FinalizeTreeOp(OpKernelConstruction* context) : OpKernel(context) {
458     string serialized_params;
459     OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
460     ParseProtoUnlimited(&param_proto_, serialized_params);
461 
462     model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_);
463   }
464 
Compute(OpKernelContext * context)465   void Compute(OpKernelContext* context) override {
466     DecisionTreeResource* tree_resource;
467     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
468                                            &tree_resource));
469     FertileStatsResource* fertile_stats_resource;
470     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
471                                            &fertile_stats_resource));
472 
473     mutex_lock l1(*fertile_stats_resource->get_mutex());
474     mutex_lock l2(*tree_resource->get_mutex());
475 
476     core::ScopedUnref unref_me(tree_resource);
477     core::ScopedUnref unref_stats(fertile_stats_resource);
478 
479     // TODO(thomaswc): Add threads
480     int num_nodes = tree_resource->decision_tree().decision_tree().nodes_size();
481     for (int i = 0; i < num_nodes; i++) {
482       auto* node = tree_resource->mutable_decision_tree()
483                        ->mutable_decision_tree()
484                        ->mutable_nodes(i);
485       if (node->has_leaf()) {
486         FinalizeLeaf(param_proto_.is_regression(),
487                      param_proto_.drop_final_class(), model_op_,
488                      node->mutable_leaf());
489       }
490     }
491   }
492 
493  private:
494   std::unique_ptr<LeafModelOperator> model_op_;
495   TensorForestParams param_proto_;
496 };
497 
498 REGISTER_RESOURCE_HANDLE_KERNEL(FertileStatsResource);
499 
500 REGISTER_KERNEL_BUILDER(Name("FertileStatsIsInitializedOp").Device(DEVICE_CPU),
501                         IsResourceInitialized<FertileStatsResource>);
502 
503 REGISTER_KERNEL_BUILDER(Name("CreateFertileStatsVariable").Device(DEVICE_CPU),
504                         CreateFertileStatsVariableOp);
505 
506 REGISTER_KERNEL_BUILDER(Name("FertileStatsSerialize").Device(DEVICE_CPU),
507                         FertileStatsSerializeOp);
508 
509 REGISTER_KERNEL_BUILDER(Name("FertileStatsDeserialize").Device(DEVICE_CPU),
510                         FertileStatsDeserializeOp);
511 
512 REGISTER_KERNEL_BUILDER(Name("ProcessInputV4").Device(DEVICE_CPU),
513                         ProcessInputOp);
514 
515 REGISTER_KERNEL_BUILDER(Name("GrowTreeV4").Device(DEVICE_CPU), GrowTreeOp);
516 
517 REGISTER_KERNEL_BUILDER(Name("FinalizeTree").Device(DEVICE_CPU),
518                         FinalizeTreeOp);
519 
520 }  // namespace tensorforest
521 }  // namespace tensorflow
522