1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include <queue>
16
17 #include "tensorflow/contrib/tensor_forest/kernels/data_spec.h"
18 #include "tensorflow/contrib/tensor_forest/kernels/v4/decision-tree-resource.h"
19 #include "tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h"
20 #include "tensorflow/contrib/tensor_forest/kernels/v4/input_data.h"
21 #include "tensorflow/contrib/tensor_forest/kernels/v4/input_target.h"
22 #include "tensorflow/contrib/tensor_forest/kernels/v4/params.h"
23 #include "tensorflow/contrib/tensor_forest/proto/fertile_stats.pb.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/resource_mgr.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/tensor_types.h"
29 #include "tensorflow/core/lib/gtl/map_util.h"
30 #include "tensorflow/core/lib/strings/strcat.h"
31 #include "tensorflow/core/platform/mutex.h"
32 #include "tensorflow/core/platform/thread_annotations.h"
33 #include "tensorflow/core/platform/types.h"
34 #include "tensorflow/core/util/work_sharder.h"
35
36 namespace tensorflow {
37 namespace tensorforest {
38
39 using gtl::FindOrNull;
40
41 // Creates a stats variable.
42 class CreateFertileStatsVariableOp : public OpKernel {
43 public:
CreateFertileStatsVariableOp(OpKernelConstruction * context)44 explicit CreateFertileStatsVariableOp(OpKernelConstruction* context)
45 : OpKernel(context) {
46 string serialized_params;
47 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
48 ParseProtoUnlimited(¶m_proto_, serialized_params);
49 }
50
Compute(OpKernelContext * context)51 void Compute(OpKernelContext* context) override {
52 const Tensor* stats_config_t;
53 OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t));
54 OP_REQUIRES(context, TensorShapeUtils::IsScalar(stats_config_t->shape()),
55 errors::InvalidArgument("Stats config must be a scalar."));
56 auto* result = new FertileStatsResource(param_proto_);
57 FertileStats stats;
58 if (!ParseProtoUnlimited(&stats, stats_config_t->scalar<string>()())) {
59 result->Unref();
60 OP_REQUIRES(context, false,
61 errors::InvalidArgument("Unable to parse stats config."));
62 }
63
64 result->ExtractFromProto(stats);
65 result->MaybeInitialize();
66
67 // Only create one, if one does not exist already. Report status for all
68 // other exceptions.
69 auto status = CreateResource(context, HandleFromInput(context, 0), result);
70 if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
71 OP_REQUIRES(context, false, status);
72 }
73 }
74
75 private:
76 TensorForestParams param_proto_;
77 };
78
79 // Op for serializing a model.
80 class FertileStatsSerializeOp : public OpKernel {
81 public:
FertileStatsSerializeOp(OpKernelConstruction * context)82 explicit FertileStatsSerializeOp(OpKernelConstruction* context)
83 : OpKernel(context) {
84 string serialized_params;
85 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
86 ParseProtoUnlimited(¶m_proto_, serialized_params);
87 }
88
Compute(OpKernelContext * context)89 void Compute(OpKernelContext* context) override {
90 FertileStatsResource* fertile_stats_resource;
91 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
92 &fertile_stats_resource));
93 mutex_lock l(*fertile_stats_resource->get_mutex());
94 core::ScopedUnref unref_me(fertile_stats_resource);
95 Tensor* output_config_t = nullptr;
96 OP_REQUIRES_OK(
97 context, context->allocate_output(0, TensorShape(), &output_config_t));
98
99 FertileStats stats;
100 fertile_stats_resource->PackToProto(&stats);
101 output_config_t->scalar<string>()() = stats.SerializeAsString();
102 }
103
104 private:
105 TensorForestParams param_proto_;
106 };
107
108 // Op for deserializing a stats variable from a checkpoint.
109 class FertileStatsDeserializeOp : public OpKernel {
110 public:
FertileStatsDeserializeOp(OpKernelConstruction * context)111 explicit FertileStatsDeserializeOp(OpKernelConstruction* context)
112 : OpKernel(context) {
113 string serialized_params;
114 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
115 ParseProtoUnlimited(¶m_proto_, serialized_params);
116 }
117
Compute(OpKernelContext * context)118 void Compute(OpKernelContext* context) override {
119 FertileStatsResource* fertile_stats_resource;
120 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
121 &fertile_stats_resource));
122 mutex_lock l(*fertile_stats_resource->get_mutex());
123 core::ScopedUnref unref_me(fertile_stats_resource);
124
125 const Tensor* stats_config_t;
126 OP_REQUIRES_OK(context, context->input("stats_config", &stats_config_t));
127 OP_REQUIRES(context, TensorShapeUtils::IsScalar(stats_config_t->shape()),
128 errors::InvalidArgument("Stats config must be a scalar."));
129 // Deallocate all the previous objects on the resource.
130 fertile_stats_resource->Reset();
131 FertileStats stats;
132 OP_REQUIRES(context,
133 ParseProtoUnlimited(&stats, stats_config_t->scalar<string>()()),
134 errors::InvalidArgument("Unable to parse stats config."));
135
136 fertile_stats_resource->ExtractFromProto(stats);
137 fertile_stats_resource->MaybeInitialize();
138 }
139
140 private:
141 TensorForestParams param_proto_;
142 };
143
144 // Try to update a leaf's stats by acquiring its lock. If it can't be
145 // acquired, put it in a waiting queue to come back to later and try the next
146 // one. Once all leaf_ids have been visited, cycle through the waiting ids
147 // until they're gone.
UpdateStats(FertileStatsResource * fertile_stats_resource,const std::unique_ptr<TensorDataSet> & data,const TensorInputTarget & target,int num_targets,const Tensor & leaf_ids_tensor,std::unordered_map<int32,std::unique_ptr<mutex>> * locks,mutex * set_lock,int32 start,int32 end,std::unordered_set<int32> * ready_to_split)148 void UpdateStats(FertileStatsResource* fertile_stats_resource,
149 const std::unique_ptr<TensorDataSet>& data,
150 const TensorInputTarget& target, int num_targets,
151 const Tensor& leaf_ids_tensor,
152 std::unordered_map<int32, std::unique_ptr<mutex>>* locks,
153 mutex* set_lock, int32 start, int32 end,
154 std::unordered_set<int32>* ready_to_split) {
155 const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
156
157 // Stores leaf_id, leaf_depth, example_id for examples that are waiting
158 // on another to finish.
159 std::queue<std::tuple<int32, int32>> waiting;
160
161 int32 i = start;
162 while (i < end || !waiting.empty()) {
163 int32 leaf_id;
164 int32 example_id;
165 bool was_waiting = false;
166 if (i >= end) {
167 std::tie(leaf_id, example_id) = waiting.front();
168 waiting.pop();
169 was_waiting = true;
170 } else {
171 leaf_id = leaf_ids(i);
172 example_id = i;
173 ++i;
174 }
175 const std::unique_ptr<mutex>& leaf_lock = (*locks)[leaf_id];
176 if (was_waiting) {
177 leaf_lock->lock();
178 } else {
179 if (!leaf_lock->try_lock()) {
180 waiting.emplace(leaf_id, example_id);
181 continue;
182 }
183 }
184
185 bool is_finished;
186 fertile_stats_resource->AddExampleToStatsAndInitialize(
187 data, &target, {example_id}, leaf_id, &is_finished);
188 leaf_lock->unlock();
189 if (is_finished) {
190 set_lock->lock();
191 ready_to_split->insert(leaf_id);
192 set_lock->unlock();
193 }
194 }
195 }
196
197 // Update leaves from start through end in the leaf_examples iterator.
UpdateStatsCollated(FertileStatsResource * fertile_stats_resource,DecisionTreeResource * tree_resource,const std::unique_ptr<TensorDataSet> & data,const TensorInputTarget & target,int num_targets,const std::unordered_map<int32,std::vector<int>> & leaf_examples,mutex * set_lock,int32 start,int32 end,std::unordered_set<int32> * ready_to_split)198 void UpdateStatsCollated(
199 FertileStatsResource* fertile_stats_resource,
200 DecisionTreeResource* tree_resource,
201 const std::unique_ptr<TensorDataSet>& data, const TensorInputTarget& target,
202 int num_targets,
203 const std::unordered_map<int32, std::vector<int>>& leaf_examples,
204 mutex* set_lock, int32 start, int32 end,
205 std::unordered_set<int32>* ready_to_split) {
206 auto it = leaf_examples.begin();
207 std::advance(it, start);
208 auto end_it = leaf_examples.begin();
209 std::advance(end_it, end);
210 while (it != end_it) {
211 int32 leaf_id = it->first;
212 bool is_finished;
213 fertile_stats_resource->AddExampleToStatsAndInitialize(
214 data, &target, it->second, leaf_id, &is_finished);
215 if (is_finished) {
216 set_lock->lock();
217 ready_to_split->insert(leaf_id);
218 set_lock->unlock();
219 }
220 ++it;
221 }
222 }
223
224 // Op for traversing the tree with each example, accumulating statistics, and
225 // outputting node ids that are ready to split.
226 class ProcessInputOp : public OpKernel {
227 public:
ProcessInputOp(OpKernelConstruction * context)228 explicit ProcessInputOp(OpKernelConstruction* context) : OpKernel(context) {
229 string serialized_params;
230 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
231 ParseProtoUnlimited(¶m_proto_, serialized_params);
232
233 OP_REQUIRES_OK(context, context->GetAttr("random_seed", &random_seed_));
234
235 string serialized_proto;
236 OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto));
237 input_spec_.ParseFromString(serialized_proto);
238 }
239
Compute(OpKernelContext * context)240 void Compute(OpKernelContext* context) override {
241 const Tensor& input_data = context->input(2);
242 const Tensor& sparse_input_indices = context->input(3);
243 const Tensor& sparse_input_values = context->input(4);
244 const Tensor& sparse_input_shape = context->input(5);
245 const Tensor& input_labels = context->input(6);
246 const Tensor& input_weights = context->input(7);
247 const Tensor& leaf_ids_tensor = context->input(8);
248
249 std::unique_ptr<TensorDataSet> data_set(
250 new TensorDataSet(input_spec_, random_seed_));
251 data_set->set_input_tensors(input_data, sparse_input_indices,
252 sparse_input_values, sparse_input_shape);
253
254 FertileStatsResource* fertile_stats_resource;
255 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
256 &fertile_stats_resource));
257 DecisionTreeResource* tree_resource;
258 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
259 &tree_resource));
260 mutex_lock l1(*fertile_stats_resource->get_mutex());
261 mutex_lock l2(*tree_resource->get_mutex());
262
263 core::ScopedUnref unref_stats(fertile_stats_resource);
264 core::ScopedUnref unref_tree(tree_resource);
265
266 const int32 num_data = data_set->NumItems();
267 auto worker_threads = context->device()->tensorflow_cpu_worker_threads();
268 int num_threads = worker_threads->num_threads;
269
270 const auto leaf_ids = leaf_ids_tensor.unaligned_flat<int32>();
271
272 // Create one mutex per leaf. We need to protect access to leaf pointers,
273 // so instead of grouping examples by leaf, we spread examples out among
274 // threads to provide uniform work for each of them and protect access
275 // with mutexes.
276 std::unordered_map<int, std::unique_ptr<mutex>> locks;
277 std::unordered_map<int32, std::vector<int>> leaf_examples;
278 if (param_proto_.collate_examples()) {
279 for (int i = 0; i < num_data; ++i) {
280 leaf_examples[leaf_ids(i)].push_back(i);
281 }
282 } else {
283 for (int i = 0; i < num_data; ++i) {
284 const int32 id = leaf_ids(i);
285 if (FindOrNull(locks, id) == nullptr) {
286 // TODO(gilberth): Consider using a memory pool for these.
287 locks[id] = std::unique_ptr<mutex>(new mutex);
288 }
289 }
290 }
291
292 const int32 num_leaves = leaf_examples.size();
293 const int32 label_dim =
294 input_labels.shape().dims() <= 1
295 ? 0
296 : static_cast<int>(input_labels.shape().dim_size(1));
297 const int32 num_targets =
298 param_proto_.is_regression() ? (std::max(1, label_dim)) : 1;
299
300 // Ids of leaves that can split.
301 std::unordered_set<int32> ready_to_split;
302 mutex set_lock;
303
304 TensorInputTarget target(input_labels, input_weights, num_targets);
305
306 // TODO(gilberth): This is a rough approximation based on measurements
307 // from a digits run on local desktop. Heuristics might be necessary
308 // if it really matters that much.
309 const int64 costPerUpdate = 1000;
310 auto update = [&target, &leaf_ids_tensor, &num_targets, &data_set,
311 fertile_stats_resource, &locks, &set_lock, &ready_to_split,
312 num_data](int64 start, int64 end) {
313 CHECK(start <= end);
314 CHECK(end <= num_data);
315 UpdateStats(fertile_stats_resource, data_set, target, num_targets,
316 leaf_ids_tensor, &locks, &set_lock, static_cast<int32>(start),
317 static_cast<int32>(end), &ready_to_split);
318 };
319
320 auto update_collated = [&target, &num_targets, fertile_stats_resource,
321 tree_resource, &leaf_examples, &set_lock,
322 &ready_to_split, &data_set,
323 num_leaves](int64 start, int64 end) {
324 CHECK(start <= end);
325 CHECK(end <= num_leaves);
326 UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set,
327 target, num_targets, leaf_examples, &set_lock,
328 static_cast<int32>(start), static_cast<int32>(end),
329 &ready_to_split);
330 };
331
332 if (param_proto_.collate_examples()) {
333 Shard(num_threads, worker_threads->workers, num_leaves, costPerUpdate,
334 update_collated);
335 } else {
336 Shard(num_threads, worker_threads->workers, num_data, costPerUpdate,
337 update);
338 }
339
340 Tensor* output_finished_t = nullptr;
341 TensorShape output_shape;
342 output_shape.AddDim(ready_to_split.size());
343 OP_REQUIRES_OK(
344 context, context->allocate_output(0, output_shape, &output_finished_t));
345 auto output = output_finished_t->unaligned_flat<int32>();
346 std::copy(ready_to_split.begin(), ready_to_split.end(), output.data());
347 }
348
349 private:
350 int32 random_seed_;
351 tensorforest::TensorForestDataSpec input_spec_;
352 TensorForestParams param_proto_;
353 };
354
355 // Op for growing finished nodes.
356 class GrowTreeOp : public OpKernel {
357 public:
GrowTreeOp(OpKernelConstruction * context)358 explicit GrowTreeOp(OpKernelConstruction* context) : OpKernel(context) {
359 string serialized_params;
360 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
361 ParseProtoUnlimited(¶m_proto_, serialized_params);
362 }
363
Compute(OpKernelContext * context)364 void Compute(OpKernelContext* context) override {
365 FertileStatsResource* fertile_stats_resource;
366 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
367 &fertile_stats_resource));
368 DecisionTreeResource* tree_resource;
369 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
370 &tree_resource));
371 mutex_lock l1(*fertile_stats_resource->get_mutex());
372 mutex_lock l2(*tree_resource->get_mutex());
373
374 core::ScopedUnref unref_stats(fertile_stats_resource);
375 core::ScopedUnref unref_tree(tree_resource);
376
377 const Tensor& finished_nodes = context->input(2);
378
379 const auto finished = finished_nodes.unaligned_flat<int32>();
380
381 const int32 num_nodes =
382 static_cast<int32>(finished_nodes.shape().dim_size(0));
383
384 // This op takes so little of the time for one batch that it isn't worth
385 // threading this.
386 for (int i = 0;
387 i < num_nodes &&
388 tree_resource->decision_tree().decision_tree().nodes_size() <
389 param_proto_.max_nodes();
390 ++i) {
391 const int32 node = finished(i);
392 std::unique_ptr<SplitCandidate> best(new SplitCandidate);
393 int32 parent_depth;
394 // TODO(gilberth): Pushing these to an output would allow the complete
395 // decoupling of tree from resource.
396 bool found =
397 fertile_stats_resource->BestSplit(node, best.get(), &parent_depth);
398 if (found) {
399 std::vector<int32> new_children;
400 tree_resource->SplitNode(node, best.get(), &new_children);
401 fertile_stats_resource->Allocate(parent_depth, new_children);
402 // We are done with best, so it is now safe to clear node.
403 fertile_stats_resource->Clear(node);
404 CHECK(tree_resource->get_mutable_tree_node(node)->has_leaf() == false);
405 } else { // reset
406 fertile_stats_resource->ResetSplitStats(node, parent_depth);
407 }
408 }
409 }
410
411 private:
412 tensorforest::TensorForestDataSpec input_spec_;
413 TensorForestParams param_proto_;
414 };
415
FinalizeLeaf(bool is_regression,bool drop_final_class,const std::unique_ptr<LeafModelOperator> & leaf_op,decision_trees::Leaf * leaf)416 void FinalizeLeaf(bool is_regression, bool drop_final_class,
417 const std::unique_ptr<LeafModelOperator>& leaf_op,
418 decision_trees::Leaf* leaf) {
419 // regression models are already stored in leaf in normalized form.
420 if (is_regression) {
421 return;
422 }
423
424 // TODO(gilberth): Calculate the leaf's sum.
425 float sum = 0;
426 LOG(FATAL) << "FinalizeTreeOp is disabled for now.";
427 if (sum <= 0.0) {
428 LOG(WARNING) << "Leaf with sum " << sum << " has stats "
429 << leaf->ShortDebugString();
430 return;
431 }
432
433 if (leaf->has_vector()) {
434 for (int i = 0; i < leaf->vector().value_size(); i++) {
435 auto* v = leaf->mutable_vector()->mutable_value(i);
436 v->set_float_value(v->float_value() / sum);
437 }
438 if (drop_final_class) {
439 leaf->mutable_vector()->mutable_value()->RemoveLast();
440 }
441 return;
442 }
443
444 if (leaf->has_sparse_vector()) {
445 for (auto& it : *leaf->mutable_sparse_vector()->mutable_sparse_value()) {
446 it.second.set_float_value(it.second.float_value() / sum);
447 }
448 return;
449 }
450
451 LOG(FATAL) << "Unknown leaf type in " << leaf->DebugString();
452 }
453
454 // Op for finalizing a tree at the end of training.
455 class FinalizeTreeOp : public OpKernel {
456 public:
FinalizeTreeOp(OpKernelConstruction * context)457 explicit FinalizeTreeOp(OpKernelConstruction* context) : OpKernel(context) {
458 string serialized_params;
459 OP_REQUIRES_OK(context, context->GetAttr("params", &serialized_params));
460 ParseProtoUnlimited(¶m_proto_, serialized_params);
461
462 model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_);
463 }
464
Compute(OpKernelContext * context)465 void Compute(OpKernelContext* context) override {
466 DecisionTreeResource* tree_resource;
467 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
468 &tree_resource));
469 FertileStatsResource* fertile_stats_resource;
470 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1),
471 &fertile_stats_resource));
472
473 mutex_lock l1(*fertile_stats_resource->get_mutex());
474 mutex_lock l2(*tree_resource->get_mutex());
475
476 core::ScopedUnref unref_me(tree_resource);
477 core::ScopedUnref unref_stats(fertile_stats_resource);
478
479 // TODO(thomaswc): Add threads
480 int num_nodes = tree_resource->decision_tree().decision_tree().nodes_size();
481 for (int i = 0; i < num_nodes; i++) {
482 auto* node = tree_resource->mutable_decision_tree()
483 ->mutable_decision_tree()
484 ->mutable_nodes(i);
485 if (node->has_leaf()) {
486 FinalizeLeaf(param_proto_.is_regression(),
487 param_proto_.drop_final_class(), model_op_,
488 node->mutable_leaf());
489 }
490 }
491 }
492
493 private:
494 std::unique_ptr<LeafModelOperator> model_op_;
495 TensorForestParams param_proto_;
496 };
497
498 REGISTER_RESOURCE_HANDLE_KERNEL(FertileStatsResource);
499
500 REGISTER_KERNEL_BUILDER(Name("FertileStatsIsInitializedOp").Device(DEVICE_CPU),
501 IsResourceInitialized<FertileStatsResource>);
502
503 REGISTER_KERNEL_BUILDER(Name("CreateFertileStatsVariable").Device(DEVICE_CPU),
504 CreateFertileStatsVariableOp);
505
506 REGISTER_KERNEL_BUILDER(Name("FertileStatsSerialize").Device(DEVICE_CPU),
507 FertileStatsSerializeOp);
508
509 REGISTER_KERNEL_BUILDER(Name("FertileStatsDeserialize").Device(DEVICE_CPU),
510 FertileStatsDeserializeOp);
511
512 REGISTER_KERNEL_BUILDER(Name("ProcessInputV4").Device(DEVICE_CPU),
513 ProcessInputOp);
514
515 REGISTER_KERNEL_BUILDER(Name("GrowTreeV4").Device(DEVICE_CPU), GrowTreeOp);
516
517 REGISTER_KERNEL_BUILDER(Name("FinalizeTree").Device(DEVICE_CPU),
518 FinalizeTreeOp);
519
520 } // namespace tensorforest
521 } // namespace tensorflow
522