1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include <vector>
16 
17 #include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h"
18 #include "tensorflow/contrib/boosted_trees/proto/learner.pb.h"
19 #include "tensorflow/contrib/boosted_trees/proto/split_info.pb.h"
20 #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h"
21 #include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h"
22 #include "tensorflow/core/framework/op_kernel.h"
23 #include "tensorflow/core/framework/tensor_shape.h"
24 
25 namespace tensorflow {
26 using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig;
27 
28 namespace boosted_trees {
29 
30 namespace {
31 
32 using boosted_trees::learner::LearnerConfig;
33 using boosted_trees::learner::LearningRateConfig;
34 using boosted_trees::trees::Leaf;
35 using boosted_trees::trees::TreeNode;
36 using boosted_trees::trees::TreeNodeMetadata;
37 using boosted_trees::utils::DropoutUtils;
38 
39 // SplitCandidate holds the split candidate node along with the stats.
40 struct SplitCandidate {
41   // Id of handler that generated the split candidate.
42   int64 handler_id;
43 
44   // Split gain.
45   float gain;
46 
47   // Split info.
48   learner::SplitInfo split_info;
49 
50   // Oblivious split info.
51   learner::ObliviousSplitInfo oblivious_split_info;
52 };
53 
54 // Checks that the leaf is not empty.
IsLeafWellFormed(const Leaf & leaf)55 bool IsLeafWellFormed(const Leaf& leaf) {
56   return leaf.has_sparse_vector() || leaf.has_vector();
57 }
58 
59 // Helper method to update the best split per partition given
60 // a current candidate.
UpdateBestSplit(const boosted_trees::learner::LearnerConfig & learner_config,int32 partition_id,SplitCandidate * split,std::map<int32,SplitCandidate> * best_splits)61 void UpdateBestSplit(
62     const boosted_trees::learner::LearnerConfig& learner_config,
63     int32 partition_id, SplitCandidate* split,
64     std::map<int32, SplitCandidate>* best_splits) {
65   // Don't consider nodeless splits.
66   if (TF_PREDICT_FALSE(split->split_info.split_node().node_case() ==
67                        TreeNode::NODE_NOT_SET)) {
68     return;
69   }
70 
71   // Don't consider negative splits if we're pre-pruning the tree.
72   // Note that zero-gain splits are acceptable as they're mostly doing as well
73   // as what bias centering in that partition would do.
74   if (learner_config.pruning_mode() ==
75           boosted_trees::learner::LearnerConfig::PRE_PRUNE &&
76       split->gain < 0) {
77     return;
78   }
79 
80   // If the current node is pure, one of the leafs will be empty, so the split
81   // is meaningless and we should not split.
82   if (!(IsLeafWellFormed(split->split_info.right_child()) &&
83         IsLeafWellFormed(split->split_info.left_child()))) {
84     VLOG(1) << "Split does not actually split anything";
85     return;
86   }
87 
88   // Take the split if we don't have a candidate yet.
89   auto best_split_it = best_splits->find(partition_id);
90   if (best_split_it == best_splits->end()) {
91     best_splits->insert(std::make_pair(partition_id, std::move(*split)));
92     return;
93   }
94 
95   // Determine if best split so far needs to be replaced.
96   SplitCandidate& best_split = best_split_it->second;
97   if (TF_PREDICT_FALSE(split->gain == best_split.gain)) {
98     // Tie break on node case preferring simpler tree node types.
99     VLOG(2) << "Attempting to tie break with smaller node case. "
100             << "(current split: " << split->split_info.split_node().node_case()
101             << ", best split: "
102             << best_split.split_info.split_node().node_case() << ")";
103     if (split->split_info.split_node().node_case() <
104         best_split.split_info.split_node().node_case()) {
105       best_split = std::move(*split);
106     } else if (split->split_info.split_node().node_case() ==
107                best_split.split_info.split_node().node_case()) {
108       // Tie break on handler Id.
109       VLOG(2) << "Tie breaking with higher handler Id. "
110               << "(current split: " << split->handler_id
111               << ", best split: " << best_split.handler_id << ")";
112       if (split->handler_id > best_split.handler_id) {
113         best_split = std::move(*split);
114       }
115     }
116   } else if (split->gain > best_split.gain) {
117     best_split = std::move(*split);
118   }
119 }
120 
121 // Helper method to check whether a node is a terminal node in that it
122 // only has leaf nodes as children.
IsTerminalSplitNode(const size_t node_id,const std::vector<int32> & children,const std::vector<TreeNode> & nodes)123 bool IsTerminalSplitNode(const size_t node_id,
124                          const std::vector<int32>& children,
125                          const std::vector<TreeNode>& nodes) {
126   for (const int32 child_id : children) {
127     const auto& child_node = nodes[child_id];
128     CHECK(child_node.node_case() != TreeNode::NODE_NOT_SET);
129     if (child_node.node_case() != TreeNode::kLeaf) {
130       return false;
131     }
132   }
133   return true;
134 }
135 
136 // Helper method to recursively prune the tree in a depth-first fashion.
RecursivePruneTree(const size_t node_id,std::vector<TreeNode> * nodes)137 void RecursivePruneTree(const size_t node_id, std::vector<TreeNode>* nodes) {
138   // Base case when we reach a leaf.
139   TreeNode& tree_node = (*nodes)[node_id];
140   CHECK(tree_node.node_case() != TreeNode::NODE_NOT_SET);
141   if (tree_node.node_case() == TreeNode::kLeaf) {
142     return;
143   }
144 
145   // Traverse node children first and recursively prune their sub-trees.
146   const std::vector<int32> children =
147       boosted_trees::trees::DecisionTree::GetChildren(tree_node);
148   for (const int32 child_id : children) {
149     RecursivePruneTree(child_id, nodes);
150   }
151 
152   // Two conditions must be satisfied to prune the node:
153   // 1- The split gain is negative.
154   // 2- After depth-first pruning, the node only has leaf children.
155   TreeNodeMetadata* node_metadata = tree_node.mutable_node_metadata();
156   if (node_metadata->gain() < 0 &&
157       IsTerminalSplitNode(node_id, children, (*nodes))) {
158     // Clear node children.
159     for (const int32 child_id : children) {
160       auto& child_node = (*nodes)[child_id];
161       child_node.Clear();
162     }
163 
164     // Change node back into leaf.
165     (*tree_node.mutable_leaf()) = *node_metadata->mutable_original_leaf();
166 
167     // Clear gain for leaf node.
168     tree_node.clear_node_metadata();
169   } else {
170     // Clear original leaf as it's no longer needed for back-track pruning.
171     node_metadata->clear_original_leaf();
172   }
173 }
174 
175 }  // namespace
176 
177 class CenterTreeEnsembleBiasOp : public OpKernel {
178  public:
CenterTreeEnsembleBiasOp(OpKernelConstruction * const context)179   explicit CenterTreeEnsembleBiasOp(OpKernelConstruction* const context)
180       : OpKernel(context) {
181     // Read learner config.
182     string serialized_learner_config;
183     OP_REQUIRES_OK(context, context->GetAttr("learner_config",
184                                              &serialized_learner_config));
185     OP_REQUIRES(context,
186                 learner_config_.ParseFromString(serialized_learner_config),
187                 errors::InvalidArgument("Unable to parse learner config."));
188 
189     // Read centering epsilon.
190     OP_REQUIRES_OK(context,
191                    context->GetAttr("centering_epsilon", &centering_epsilon_));
192   }
193 
Compute(OpKernelContext * const context)194   void Compute(OpKernelContext* const context) override {
195     // Get decision tree ensemble.
196     boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
197     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
198                                            &ensemble_resource));
199     core::ScopedUnref unref_me(ensemble_resource);
200     mutex_lock l(*ensemble_resource->get_mutex());
201 
202     // Get the stamp token.
203     const Tensor* stamp_token_t;
204     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
205     int64 stamp_token = stamp_token_t->scalar<int64>()();
206 
207     // Only the Chief should run this Op and it is guaranteed to be in
208     // a consistent state so the stamps must always match.
209     CHECK(ensemble_resource->is_stamp_valid(stamp_token));
210 
211     // Get the next stamp token.
212     const Tensor* next_stamp_token_t;
213     OP_REQUIRES_OK(context,
214                    context->input("next_stamp_token", &next_stamp_token_t));
215     int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
216     CHECK(stamp_token != next_stamp_token);
217 
218     // Update the ensemble stamp.
219     ensemble_resource->set_stamp(next_stamp_token);
220 
221     // Get the delta updates.
222     const Tensor* delta_updates_t;
223     OP_REQUIRES_OK(context, context->input("delta_updates", &delta_updates_t));
224     auto delta_updates = delta_updates_t->vec<float>();
225     const int64 logits_dimension = delta_updates_t->dim_size(0);
226 
227     // Get the bias.
228     boosted_trees::trees::Leaf* const bias =
229         RetrieveBias(ensemble_resource, logits_dimension);
230     CHECK(bias->has_vector());
231 
232     // Update the bias.
233     float total_delta = 0;
234     auto* bias_vec = bias->mutable_vector();
235     for (size_t idx = 0; idx < bias->vector().value_size(); ++idx) {
236       float delta = delta_updates(idx);
237       bias_vec->set_value(idx, bias_vec->value(idx) + delta);
238       total_delta += std::abs(delta);
239     }
240 
241     // Make a centering continuation decision based on current update.
242     bool continue_centering = total_delta > centering_epsilon_;
243     if (continue_centering) {
244       VLOG(1) << "Continuing to center bias, delta=" << total_delta;
245     } else {
246       VLOG(1) << "Done centering bias, delta=" << total_delta;
247       ensemble_resource->LastTreeMetadata()->set_is_finalized(true);
248     }
249     Tensor* continue_centering_t = nullptr;
250     OP_REQUIRES_OK(
251         context, context->allocate_output("continue_centering", TensorShape({}),
252                                           &continue_centering_t));
253     continue_centering_t->scalar<bool>()() = continue_centering;
254   }
255 
256  private:
257   // Helper method to retrieve the bias from the tree ensemble.
RetrieveBias(boosted_trees::models::DecisionTreeEnsembleResource * ensemble_resource,int64 logits_dimension)258   boosted_trees::trees::Leaf* RetrieveBias(
259       boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource,
260       int64 logits_dimension) {
261     const int32 num_trees = ensemble_resource->num_trees();
262     if (num_trees <= 0) {
263       // Add a new bias leaf.
264       ensemble_resource->IncrementAttempts();
265       boosted_trees::trees::DecisionTreeConfig* const tree_config =
266           ensemble_resource->AddNewTree(1.0);
267       auto* const leaf = tree_config->add_nodes()->mutable_leaf();
268       for (size_t idx = 0; idx < logits_dimension; ++idx) {
269         leaf->mutable_vector()->add_value(0.0);
270       }
271       return leaf;
272     } else if (num_trees == 1) {
273       // Confirms that the only tree is a bias and returns its leaf.
274       boosted_trees::trees::DecisionTreeConfig* const tree_config =
275           ensemble_resource->LastTree();
276       CHECK_EQ(tree_config->nodes_size(), 1);
277       CHECK_EQ(tree_config->nodes(0).node_case(), TreeNode::kLeaf);
278       return tree_config->mutable_nodes(0)->mutable_leaf();
279     } else {
280       LOG(FATAL) << "Unable to center bias on an already grown ensemble";
281     }
282   }
283 
284   boosted_trees::learner::LearnerConfig learner_config_;
285   float centering_epsilon_;
286 };
287 
288 REGISTER_KERNEL_BUILDER(Name("CenterTreeEnsembleBias").Device(DEVICE_CPU),
289                         CenterTreeEnsembleBiasOp);
290 
291 class GrowTreeEnsembleOp : public OpKernel {
292  public:
GrowTreeEnsembleOp(OpKernelConstruction * const context)293   explicit GrowTreeEnsembleOp(OpKernelConstruction* const context)
294       : OpKernel(context) {
295     // Read number of handlers, note that this is the static number of
296     // all handlers but any subset of these handlers may be active at a time.
297     OP_REQUIRES_OK(context, context->GetAttr("num_handlers", &num_handlers_));
298 
299     OP_REQUIRES_OK(context, context->GetAttr("center_bias", &center_bias_));
300 
301     // Read learner config.
302     string serialized_learner_config;
303     OP_REQUIRES_OK(context, context->GetAttr("learner_config",
304                                              &serialized_learner_config));
305     OP_REQUIRES(context,
306                 learner_config_.ParseFromString(serialized_learner_config),
307                 errors::InvalidArgument("Unable to parse learner config."));
308 
309     // Determine whether dropout was used when building this tree.
310     if (learner_config_.has_learning_rate_tuner() &&
311         learner_config_.learning_rate_tuner().tuner_case() ==
312             LearningRateConfig::kDropout) {
313       dropout_config_ = learner_config_.learning_rate_tuner().dropout();
314       dropout_was_applied_ = true;
315     } else {
316       dropout_was_applied_ = false;
317     }
318   }
319 
Compute(OpKernelContext * const context)320   void Compute(OpKernelContext* const context) override {
321     // Get decision tree ensemble.
322     boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
323     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
324                                            &ensemble_resource));
325     core::ScopedUnref unref_me(ensemble_resource);
326     mutex_lock l(*ensemble_resource->get_mutex());
327 
328     // Get the stamp token.
329     const Tensor* stamp_token_t;
330     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
331     int64 stamp_token = stamp_token_t->scalar<int64>()();
332 
333     // Only the Chief should run this Op and it is guaranteed to be in
334     // a consistent state so the stamps must always match.
335     CHECK(ensemble_resource->is_stamp_valid(stamp_token));
336 
337     // Get the next stamp token.
338     const Tensor* next_stamp_token_t;
339     OP_REQUIRES_OK(context,
340                    context->input("next_stamp_token", &next_stamp_token_t));
341     int64 next_stamp_token = next_stamp_token_t->scalar<int64>()();
342     CHECK(stamp_token != next_stamp_token);
343 
344     // Update the ensemble stamp regardless of whether a layer
345     // or tree is actually grown.
346     ensemble_resource->set_stamp(next_stamp_token);
347 
348     // Read the learning_rate.
349     const Tensor* learning_rate_t;
350     OP_REQUIRES_OK(context, context->input("learning_rate", &learning_rate_t));
351     float learning_rate = learning_rate_t->scalar<float>()();
352 
353     // Read the weak learner type to use.
354     const Tensor* weak_learner_type_t;
355     OP_REQUIRES_OK(context,
356                    context->input("weak_learner_type", &weak_learner_type_t));
357     const int32 weak_learner_type = weak_learner_type_t->scalar<int32>()();
358 
359     const Tensor* seed_t;
360     OP_REQUIRES_OK(context, context->input("dropout_seed", &seed_t));
361     // Cast seed to uint64.
362     const uint64 dropout_seed = seed_t->scalar<int64>()();
363 
364     // Read partition Ids, gains and split candidates.
365     OpInputList partition_ids_list;
366     OpInputList gains_list;
367     OpInputList splits_list;
368     OP_REQUIRES_OK(context,
369                    context->input_list("partition_ids", &partition_ids_list));
370     OP_REQUIRES_OK(context, context->input_list("gains", &gains_list));
371     OP_REQUIRES_OK(context, context->input_list("splits", &splits_list));
372 
373     // Increment attempt stats.
374     ensemble_resource->IncrementAttempts();
375 
376     // Find best splits for each active partition.
377     std::map<int32, SplitCandidate> best_splits;
378     switch (weak_learner_type) {
379       case LearnerConfig::NORMAL_DECISION_TREE: {
380         FindBestSplitsPerPartitionNormal(context, partition_ids_list,
381                                          gains_list, splits_list, &best_splits);
382         break;
383       }
384       case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
385         FindBestSplitOblivious(context, gains_list, splits_list, &best_splits);
386         break;
387       }
388     }
389     // No-op if no new splits can be considered.
390     if (best_splits.empty()) {
391       LOG(WARNING) << "Not growing tree ensemble as no good splits were found.";
392       return;
393     }
394 
395     // Get the max tree depth.
396     const Tensor* max_tree_depth_t;
397     OP_REQUIRES_OK(context,
398                    context->input("max_tree_depth", &max_tree_depth_t));
399     const int32 max_tree_depth = max_tree_depth_t->scalar<int32>()();
400     // Update and retrieve the growable tree.
401     // If the tree is fully built and dropout was applied, it also adjusts the
402     // weights of dropped and the last tree.
403     boosted_trees::trees::DecisionTreeConfig* const tree_config =
404         UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate,
405                                       dropout_seed, max_tree_depth,
406                                       weak_learner_type);
407     // Split tree nodes.
408     switch (weak_learner_type) {
409       case LearnerConfig::NORMAL_DECISION_TREE: {
410         for (auto& split_entry : best_splits) {
411           SplitTreeNode(split_entry.first, &split_entry.second, tree_config,
412                         ensemble_resource);
413         }
414         break;
415       }
416       case LearnerConfig::OBLIVIOUS_DECISION_TREE: {
417         SplitTreeLayer(&best_splits[0], tree_config, ensemble_resource);
418       }
419     }
420     // Post-prune finalized tree if needed.
421     if (learner_config_.pruning_mode() ==
422             boosted_trees::learner::LearnerConfig::POST_PRUNE &&
423         ensemble_resource->LastTreeMetadata()->is_finalized()) {
424       VLOG(2) << "Post-pruning finalized tree.";
425       if (weak_learner_type == LearnerConfig::OBLIVIOUS_DECISION_TREE) {
426         LOG(FATAL) << "Post-prunning is not implemented for Oblivious trees.";
427       }
428       PruneTree(tree_config);
429 
430       // If after post-pruning the whole tree has no gain, remove the tree
431       // altogether from the ensemble.
432       if (tree_config->nodes_size() <= 0) {
433         ensemble_resource->RemoveLastTree();
434       }
435     }
436   }
437 
438  private:
439   // Helper method which effectively does a reduce over all split candidates
440   // and finds the best split for each partition.
FindBestSplitsPerPartitionNormal(OpKernelContext * const context,const OpInputList & partition_ids_list,const OpInputList & gains_list,const OpInputList & splits_list,std::map<int32,SplitCandidate> * best_splits)441   void FindBestSplitsPerPartitionNormal(
442       OpKernelContext* const context, const OpInputList& partition_ids_list,
443       const OpInputList& gains_list, const OpInputList& splits_list,
444       std::map<int32, SplitCandidate>* best_splits) {
445     // Find best split per partition going through every feature candidate.
446     // TODO(salehay): Is this worth parallelizing?
447     for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
448       const auto& partition_ids = partition_ids_list[handler_id].vec<int32>();
449       const auto& gains = gains_list[handler_id].vec<float>();
450       const auto& splits = splits_list[handler_id].vec<string>();
451       OP_REQUIRES(context, partition_ids.size() == gains.size(),
452                   errors::InvalidArgument(
453                       "Inconsistent partition Ids and gains tensors: ",
454                       partition_ids.size(), " != ", gains.size()));
455       OP_REQUIRES(context, partition_ids.size() == splits.size(),
456                   errors::InvalidArgument(
457                       "Inconsistent partition Ids and splits tensors: ",
458                       partition_ids.size(), " != ", splits.size()));
459       for (size_t candidate_idx = 0; candidate_idx < splits.size();
460            ++candidate_idx) {
461         // Get current split candidate.
462         const auto& partition_id = partition_ids(candidate_idx);
463         const auto& gain = gains(candidate_idx);
464         const auto& serialized_split = splits(candidate_idx);
465         SplitCandidate split;
466         split.handler_id = handler_id;
467         split.gain = gain;
468         OP_REQUIRES(context, split.split_info.ParseFromString(serialized_split),
469                     errors::InvalidArgument("Unable to parse split info."));
470 
471         // Update best split for partition based on the current candidate.
472         UpdateBestSplit(learner_config_, partition_id, &split, best_splits);
473       }
474     }
475   }
476 
FindBestSplitOblivious(OpKernelContext * const context,const OpInputList & gains_list,const OpInputList & splits_list,std::map<int32,SplitCandidate> * best_splits)477   void FindBestSplitOblivious(OpKernelContext* const context,
478                               const OpInputList& gains_list,
479                               const OpInputList& splits_list,
480                               std::map<int32, SplitCandidate>* best_splits) {
481     // Find best split per partition going through every feature candidate.
482     for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) {
483       const auto& gains = gains_list[handler_id].vec<float>();
484       const auto& splits = splits_list[handler_id].vec<string>();
485       OP_REQUIRES(context, gains.size() == 1,
486                   errors::InvalidArgument(
487                       "Gains size must be one for oblivious weak learner: ",
488                       gains.size(), " != ", 1));
489       OP_REQUIRES(context, splits.size() == 1,
490                   errors::InvalidArgument(
491                       "Splits size must be one for oblivious weak learner: ",
492                       splits.size(), " != ", 1));
493       // Get current split candidate.
494       const auto& gain = gains(0);
495       const auto& serialized_split = splits(0);
496       SplitCandidate split;
497       split.handler_id = handler_id;
498       split.gain = gain;
499       OP_REQUIRES(
500           context, split.oblivious_split_info.ParseFromString(serialized_split),
501           errors::InvalidArgument("Unable to parse oblivious split info."));
502 
503       auto split_info = split.oblivious_split_info;
504       CHECK(split_info.children_size() % 2 == 0)
505           << "The oblivious split should generate an even number of children: "
506           << split_info.children_size();
507 
508       // If every node is pure, then we shouldn't split.
509       bool only_pure_nodes = true;
510       for (int idx = 0; idx < split_info.children_size(); idx += 2) {
511         if (IsLeafWellFormed(*split_info.mutable_children(idx)) &&
512             IsLeafWellFormed(*split_info.mutable_children(idx + 1))) {
513           only_pure_nodes = false;
514           break;
515         }
516       }
517       if (only_pure_nodes) {
518         VLOG(1) << "The oblivious split does not actually split anything.";
519         continue;
520       }
521 
522       // Don't consider negative splits if we're pre-pruning the tree.
523       if (learner_config_.pruning_mode() == learner::LearnerConfig::PRE_PRUNE &&
524           gain < 0) {
525         continue;
526       }
527 
528       // Take the split if we don't have a candidate yet.
529       auto best_split_it = best_splits->find(0);
530       if (best_split_it == best_splits->end()) {
531         best_splits->insert(std::make_pair(0, std::move(split)));
532         continue;
533       }
534 
535       // Determine if we should update best split.
536       SplitCandidate& best_split = best_split_it->second;
537       trees::TreeNode current_node = split_info.split_node();
538       trees::TreeNode best_node = best_split.oblivious_split_info.split_node();
539       if (TF_PREDICT_FALSE(gain == best_split.gain)) {
540         // Tie break on node case preferring simpler tree node types.
541         VLOG(2) << "Attempting to tie break with smaller node case. "
542                 << "(current split: " << current_node.node_case()
543                 << ", best split: " << best_node.node_case() << ")";
544         if (current_node.node_case() < best_node.node_case()) {
545           best_split = std::move(split);
546         } else if (current_node.node_case() == best_node.node_case()) {
547           // Tie break on handler Id.
548           VLOG(2) << "Tie breaking with higher handler Id. "
549                   << "(current split: " << handler_id
550                   << ", best split: " << best_split.handler_id << ")";
551           if (handler_id > best_split.handler_id) {
552             best_split = std::move(split);
553           }
554         }
555       } else if (gain > best_split.gain) {
556         best_split = std::move(split);
557       }
558     }
559   }
560 
UpdateTreeWeightsIfDropout(boosted_trees::models::DecisionTreeEnsembleResource * const ensemble_resource,const uint64 dropout_seed)561   void UpdateTreeWeightsIfDropout(
562       boosted_trees::models::DecisionTreeEnsembleResource* const
563           ensemble_resource,
564       const uint64 dropout_seed) {
565     // It is possible that the tree was built with dropout. If it is the case,
566     // we need to adjust the tree weight, or bail out.
567     if (!dropout_was_applied_ ||
568         !ensemble_resource->LastTreeMetadata()->is_finalized()) {
569       return;
570     }
571     const int32 num_trees = ensemble_resource->num_trees();
572 
573     // Based on seed, figure out what trees were dropped before.
574     std::unordered_set<int32> trees_not_to_drop;
575     if (center_bias_) {
576       trees_not_to_drop.insert(0);
577     }
578     // Last tree is the current tree that is built.
579     const int32 current_tree = num_trees - 1;
580     trees_not_to_drop.insert(current_tree);
581 
582     // Since only chief builds the trees, we are sure that the other tree
583     // weights didn't change.
584     std::vector<float> weights = ensemble_resource->GetTreeWeights();
585     std::vector<int32> dropped_trees;
586     std::vector<float> dropped_trees_weights;
587     const auto dropout_status = DropoutUtils::DropOutTrees(
588         dropout_seed, dropout_config_, trees_not_to_drop, weights,
589         &dropped_trees, &dropped_trees_weights);
590     CHECK(dropout_status.ok())
591         << "Can't figure out what trees were dropped out before, error is "
592         << dropout_status.error_message();
593 
594     // Now we have dropped trees, update their weights and the current tree
595     // weight.
596     if (!dropped_trees.empty()) {
597       std::vector<int32> increment_num_updates(num_trees, 0);
598       DropoutUtils::GetTreesWeightsForAddingTrees(
599           dropped_trees, dropped_trees_weights, current_tree,
600           1 /* only 1 tree was added */, &weights, &increment_num_updates);
601 
602       // Update the weights and num of updates for trees.
603       for (int i = 0; i < num_trees; ++i) {
604         ensemble_resource->SetTreeWeight(i, weights[i],
605                                          increment_num_updates[i]);
606       }
607     }
608   }
609 
610   // Helper method to update the growable tree which is by definition the last
611   // tree in the ensemble.
UpdateAndRetrieveGrowableTree(boosted_trees::models::DecisionTreeEnsembleResource * const ensemble_resource,const float learning_rate,const uint64 dropout_seed,const int32 max_tree_depth,const int32 weak_learner_type)612   boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree(
613       boosted_trees::models::DecisionTreeEnsembleResource* const
614           ensemble_resource,
615       const float learning_rate, const uint64 dropout_seed,
616       const int32 max_tree_depth, const int32 weak_learner_type) {
617     const auto num_trees = ensemble_resource->num_trees();
618     if (num_trees <= 0 ||
619         ensemble_resource->LastTreeMetadata()->is_finalized()) {
620       // Create a new tree with a no-op leaf.
621       boosted_trees::trees::DecisionTreeConfig* const tree_config =
622           ensemble_resource->AddNewTree(learning_rate);
623       VLOG(1) << "Adding layer #0 to tree #" << num_trees << " of ensemble of "
624               << num_trees + 1 << " trees.";
625       tree_config->add_nodes()->mutable_leaf();
626       boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
627           ensemble_resource->LastTreeMetadata();
628       tree_metadata->set_is_finalized(max_tree_depth <= 1);
629       tree_metadata->set_num_tree_weight_updates(1);
630     } else {
631       // The growable tree is by definition the last tree in the ensemble.
632       boosted_trees::trees::DecisionTreeMetadata* const tree_metadata =
633           ensemble_resource->LastTreeMetadata();
634       const auto new_num_layers = tree_metadata->num_layers_grown() + 1;
635       VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #"
636               << num_trees - 1 << " of ensemble of " << num_trees << " trees.";
637       // Update growable tree metadata.
638       tree_metadata->set_num_layers_grown(new_num_layers);
639       tree_metadata->set_is_finalized(new_num_layers >= max_tree_depth);
640     }
641     UpdateTreeWeightsIfDropout(ensemble_resource, dropout_seed);
642     return ensemble_resource->LastTree();
643   }
644 
645   // Helper method to merge leaf weights as the tree is being grown.
MergeLeafWeights(const boosted_trees::trees::Leaf & source,boosted_trees::trees::Leaf * dest)646   boosted_trees::trees::Leaf* MergeLeafWeights(
647       const boosted_trees::trees::Leaf& source,
648       boosted_trees::trees::Leaf* dest) {
649     // Resolve leaf merging method based on how the trees are being grown.
650     if (learner_config_.growing_mode() ==
651         boosted_trees::learner::LearnerConfig::WHOLE_TREE) {
652       // No merging occurs when building a whole tree at a time.
653       return dest;
654     }
655 
656     if (dest->leaf_case() == boosted_trees::trees::Leaf::LEAF_NOT_SET) {
657       // No merging is required. Just copy the source weights;
658       *dest = source;
659       return dest;
660     }
661 
662     // Handle leaf merging based on type.
663     switch (source.leaf_case()) {
664       case boosted_trees::trees::Leaf::kVector: {
665         // No-op if source is empty
666         const auto& src_vec = source.vector();
667         if (src_vec.value_size() == 0) {
668           break;
669         }
670         CHECK(source.leaf_case() == dest->leaf_case());
671 
672         // Dense add leaf vectors.
673         auto* dst_vec = dest->mutable_vector();
674         CHECK(src_vec.value_size() == dst_vec->value_size());
675         for (size_t idx = 0; idx < source.vector().value_size(); ++idx) {
676           (*dst_vec->mutable_value()->Mutable(idx)) += src_vec.value(idx);
677         }
678         break;
679       }
680       case boosted_trees::trees::Leaf::kSparseVector: {
681         // No-op if source is empty
682         const auto& src_vec = source.sparse_vector();
683         CHECK(src_vec.value_size() == src_vec.index_size());
684         if (src_vec.value_size() == 0) {
685           break;
686         }
687         CHECK(source.leaf_case() == dest->leaf_case());
688 
689         // Get mapping of dimension to value for destination.
690         std::unordered_map<int32, float> dst_map;
691         auto* dst_vec = dest->mutable_sparse_vector();
692         CHECK(dst_vec->value_size() == dst_vec->index_size());
693         dst_map.reserve(dst_vec->value_size());
694         for (size_t idx = 0; idx < dst_vec->value_size(); ++idx) {
695           dst_map[dst_vec->index(idx)] = dst_vec->value(idx);
696         }
697         // Sparse add source vector to destination vector.
698         for (size_t idx = 0; idx < src_vec.value_size(); ++idx) {
699           dst_map[src_vec.index(idx)] += src_vec.value(idx);
700         }
701         // Rebuild merged destination leaf.
702         dst_vec->clear_index();
703         dst_vec->clear_value();
704         for (const auto& entry : dst_map) {
705           dst_vec->add_index(entry.first);
706           dst_vec->add_value(entry.second);
707         }
708         break;
709       }
710       case boosted_trees::trees::Leaf::LEAF_NOT_SET: {
711         // No-op as there is nothing to merge.
712         break;
713       }
714     }
715     return dest;
716   }
717 
718   // Helper method to split a tree node and append its respective
719   // leaf children given the split candidate.
SplitTreeNode(const int32 node_id,SplitCandidate * split,boosted_trees::trees::DecisionTreeConfig * tree_config,boosted_trees::models::DecisionTreeEnsembleResource * ensemble_resource)720   void SplitTreeNode(
721       const int32 node_id, SplitCandidate* split,
722       boosted_trees::trees::DecisionTreeConfig* tree_config,
723       boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
724     // No-op if we have no real node.
725     CHECK(node_id < tree_config->nodes_size())
726         << "Invalid node " << node_id << " to split.";
727     // Ensure new split node is valid.
728     CHECK(split->split_info.split_node().node_case() != TreeNode::NODE_NOT_SET);
729     CHECK(tree_config->nodes(node_id).node_case() == TreeNode::kLeaf)
730         << "Unexpected node type to split "
731         << tree_config->nodes(node_id).node_case() << " for node_id " << node_id
732         << ". Tree config: " << tree_config->DebugString();
733 
734     // Add left leaf.
735     int32 left_id = tree_config->nodes_size();
736     (*tree_config->add_nodes()->mutable_leaf()) =
737         *MergeLeafWeights(tree_config->nodes(node_id).leaf(),
738                           split->split_info.mutable_left_child());
739 
740     // Add right leaf.
741     int32 right_id = tree_config->nodes_size();
742     (*tree_config->add_nodes()->mutable_leaf()) =
743         *MergeLeafWeights(tree_config->nodes(node_id).leaf(),
744                           split->split_info.mutable_right_child());
745 
746     // Link children and add them as new roots.
747     boosted_trees::trees::DecisionTree::LinkChildren(
748         {left_id, right_id}, split->split_info.mutable_split_node());
749 
750     // Add split gain and, if needed, original leaf to node metadata.
751     TreeNodeMetadata* node_metadata =
752         split->split_info.mutable_split_node()->mutable_node_metadata();
753     node_metadata->set_gain(split->gain);
754     if (learner_config_.pruning_mode() ==
755         boosted_trees::learner::LearnerConfig::POST_PRUNE) {
756       (*node_metadata->mutable_original_leaf()) =
757           *tree_config->mutable_nodes(node_id)->mutable_leaf();
758     }
759 
760     // Replace node in tree.
761     (*tree_config->mutable_nodes(node_id)) =
762         *split->split_info.mutable_split_node();
763     if (learner_config_.constraints().max_number_of_unique_feature_columns()) {
764       ensemble_resource->MaybeAddUsedHandler(split->handler_id);
765     }
766   }
767 
SplitTreeLayer(SplitCandidate * split,boosted_trees::trees::DecisionTreeConfig * tree_config,boosted_trees::models::DecisionTreeEnsembleResource * ensemble_resource)768   void SplitTreeLayer(
769       SplitCandidate* split,
770       boosted_trees::trees::DecisionTreeConfig* tree_config,
771       boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) {
772     int depth = 0;
773     while (depth < tree_config->nodes_size() &&
774            tree_config->nodes(depth).node_case() != TreeNode::kLeaf) {
775       depth++;
776     }
777     CHECK(tree_config->nodes_size() > 0)
778         << "A tree must have at least one dummy leaf.";
779     // The number of new children.
780     int num_children = 1 << (depth + 1);
781     auto split_info = split->oblivious_split_info;
782     CHECK(num_children >= split_info.children_size())
783         << "Too many new children, expected <= " << num_children << " and got "
784         << split_info.children_size();
785     std::vector<trees::Leaf> new_leaves;
786     new_leaves.reserve(num_children);
787     int next_id = 0;
788     for (int idx = 0; idx < num_children / 2; idx++) {
789       trees::Leaf old_leaf =
790           *tree_config->mutable_nodes(depth + idx)->mutable_leaf();
791       // Check if a split was made for this leaf.
792       if (next_id < split_info.children_parent_id_size() &&
793           depth + idx == split_info.children_parent_id(next_id)) {
794         // Add left leaf.
795         new_leaves.push_back(*MergeLeafWeights(
796             old_leaf, split_info.mutable_children(2 * next_id)));
797         // Add right leaf.
798         new_leaves.push_back(*MergeLeafWeights(
799             old_leaf, split_info.mutable_children(2 * next_id + 1)));
800         next_id++;
801       } else {
802         // If there is no split for this leaf, just duplicate it.
803         new_leaves.push_back(old_leaf);
804         new_leaves.push_back(old_leaf);
805       }
806     }
807     CHECK(next_id == split_info.children_parent_id_size());
808     TreeNodeMetadata* split_metadata =
809         split_info.mutable_split_node()->mutable_node_metadata();
810     split_metadata->set_gain(split->gain);
811 
812     TreeNode new_split = *split_info.mutable_split_node();
813     // Move old children to metadata.
814     for (int idx = depth; idx < tree_config->nodes_size(); idx++) {
815       *new_split.mutable_node_metadata()->add_original_oblivious_leaves() =
816           *tree_config->mutable_nodes(idx)->mutable_leaf();
817     }
818     // Add the new split to the tree_config in place before the children start.
819     *tree_config->mutable_nodes(depth) = new_split;
820     // Add the new children
821     int nodes_size = tree_config->nodes_size();
822     for (int idx = 0; idx < num_children; idx++) {
823       if (idx + depth + 1 < nodes_size) {
824         // Update leaves that were already there.
825         *tree_config->mutable_nodes(idx + depth + 1)->mutable_leaf() =
826             new_leaves[idx];
827       } else {
828         // Add new leaves.
829         *tree_config->add_nodes()->mutable_leaf() = new_leaves[idx];
830       }
831     }
832   }
PruneTree(boosted_trees::trees::DecisionTreeConfig * tree_config)833   void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) {
834     // No-op if tree is empty.
835     if (tree_config->nodes_size() <= 0) {
836       return;
837     }
838 
839     // Copy nodes to temp vector and clear original tree.
840     std::vector<TreeNode> tree_nodes;
841     tree_nodes.reserve(tree_config->nodes_size());
842     for (auto& node : (*tree_config->mutable_nodes())) {
843       tree_nodes.push_back(node);
844       node.Clear();
845     }
846     tree_config->clear_nodes();
847 
848     // Prune the tree recursively starting from the root.
849     RecursivePruneTree(0, &tree_nodes);
850 
851     // Rebuild compacted tree.
852     (*tree_config->add_nodes()) = tree_nodes[0];
853     std::unordered_map<size_t, size_t> nodes_map;
854     nodes_map[0] = 0;
855     for (size_t node_idx = 0; node_idx < tree_nodes.size(); ++node_idx) {
856       // Skip pruned nodes.
857       auto& original_node = tree_nodes[node_idx];
858       if (original_node.node_case() == TreeNode::NODE_NOT_SET) {
859         continue;
860       }
861 
862       // Find node mapped in tree ensemble.
863       auto mapped_node_it = nodes_map.find(node_idx);
864       CHECK(mapped_node_it != nodes_map.end());
865       auto& mapped_node = (*tree_config->mutable_nodes(mapped_node_it->second));
866 
867       // Get node children
868       auto children =
869           boosted_trees::trees::DecisionTree::GetChildren(original_node);
870       for (int32& child_idx : children) {
871         auto new_idx = tree_config->nodes_size();
872         (*tree_config->add_nodes()) = tree_nodes[child_idx];
873         nodes_map[child_idx] = new_idx;
874         child_idx = new_idx;
875       }
876       boosted_trees::trees::DecisionTree::LinkChildren(children, &mapped_node);
877     }
878 
879     // Check if there are any nodes with gain left.
880     if (tree_config->nodes_size() == 1 &&
881         tree_config->nodes(0).node_metadata().gain() <= 0) {
882       // The whole tree should be pruned.
883       VLOG(2) << "No useful nodes left after post-pruning tree.";
884       tree_config->clear_nodes();
885     }
886   }
887 
888  private:
889   boosted_trees::learner::LearnerConfig learner_config_;
890   int64 num_handlers_;
891   LearningRateDropoutDrivenConfig dropout_config_;
892   bool dropout_was_applied_;
893   bool center_bias_;
894 };
895 
// Registers GrowTreeEnsembleOp as the CPU kernel for the "GrowTreeEnsemble"
// op.
REGISTER_KERNEL_BUILDER(Name("GrowTreeEnsemble").Device(DEVICE_CPU),
                        GrowTreeEnsembleOp);
898 
899 class TreeEnsembleStatsOp : public OpKernel {
900  public:
TreeEnsembleStatsOp(OpKernelConstruction * const context)901   explicit TreeEnsembleStatsOp(OpKernelConstruction* const context)
902       : OpKernel(context) {}
903 
Compute(OpKernelContext * const context)904   void Compute(OpKernelContext* const context) override {
905     // Get decision tree ensemble.
906     boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource;
907     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
908                                            &ensemble_resource));
909     core::ScopedUnref unref_me(ensemble_resource);
910     tf_shared_lock l(*ensemble_resource->get_mutex());
911 
912     // Get the stamp token.
913     const Tensor* stamp_token_t;
914     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
915     int64 stamp_token = stamp_token_t->scalar<int64>()();
916 
917     // Only the Chief should run this Op and it is guaranteed to be in
918     // a consistent state so the stamps must always match.
919     CHECK(ensemble_resource->is_stamp_valid(stamp_token));
920     const boosted_trees::trees::DecisionTreeEnsembleConfig& ensemble_config =
921         ensemble_resource->decision_tree_ensemble();
922 
923     // Set tree stats.
924     Tensor* num_trees_t = nullptr;
925     OP_REQUIRES_OK(context, context->allocate_output(
926                                 "num_trees", TensorShape({}), &num_trees_t));
927     Tensor* active_tree_t = nullptr;
928     OP_REQUIRES_OK(context,
929                    context->allocate_output("active_tree", TensorShape({}),
930                                             &active_tree_t));
931     Tensor* attempted_tree_t = nullptr;
932     OP_REQUIRES_OK(context,
933                    context->allocate_output("attempted_trees", TensorShape({}),
934                                             &attempted_tree_t));
935 
936     const int num_trees = ensemble_resource->num_trees();
937     active_tree_t->scalar<int64>()() = num_trees;
938     num_trees_t->scalar<int64>()() =
939         (num_trees <= 0 ||
940          ensemble_resource->LastTreeMetadata()->is_finalized())
941             ? num_trees
942             : num_trees - 1;
943     attempted_tree_t->scalar<int64>()() =
944         ensemble_config.growing_metadata().num_trees_attempted();
945 
946     // Set layer stats.
947     Tensor* num_layers_t = nullptr;
948     OP_REQUIRES_OK(context, context->allocate_output(
949                                 "num_layers", TensorShape({}), &num_layers_t));
950     Tensor* active_layer_t = nullptr;
951     OP_REQUIRES_OK(context,
952                    context->allocate_output("active_layer", TensorShape({}),
953                                             &active_layer_t));
954     Tensor* attempted_layers_t = nullptr;
955     OP_REQUIRES_OK(context,
956                    context->allocate_output("attempted_layers", TensorShape({}),
957                                             &attempted_layers_t));
958 
959     int64 num_layers = 0;
960     for (const auto& tree_metadata : ensemble_config.tree_metadata()) {
961       num_layers += tree_metadata.num_layers_grown();
962     }
963     num_layers_t->scalar<int64>()() = num_layers;
964     int tree_metadata_size = ensemble_config.tree_metadata_size();
965     active_layer_t->scalar<int64>()() =
966         tree_metadata_size > 0
967             ? ensemble_config.tree_metadata(tree_metadata_size - 1)
968                   .num_layers_grown()
969             : 0;
970     attempted_layers_t->scalar<int64>()() =
971         ensemble_config.growing_metadata().num_layers_attempted();
972   }
973 };
974 
// Registers TreeEnsembleStatsOp as the CPU kernel for the "TreeEnsembleStats"
// op.
REGISTER_KERNEL_BUILDER(Name("TreeEnsembleStats").Device(DEVICE_CPU),
                        TreeEnsembleStatsOp);
977 
978 }  // namespace boosted_trees
979 }  // namespace tensorflow
980