1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h"
16 
17 #include <cfloat>
18 
19 namespace tensorflow {
20 namespace tensorforest {
21 
AddExampleToStatsAndInitialize(const std::unique_ptr<TensorDataSet> & input_data,const InputTarget * target,const std::vector<int> & examples,int32 node_id,bool * is_finished)22 void FertileStatsResource::AddExampleToStatsAndInitialize(
23     const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
24     const std::vector<int>& examples, int32 node_id, bool* is_finished) {
25   // Update stats or initialize if needed.
26   if (collection_op_->IsInitialized(node_id)) {
27     collection_op_->AddExample(input_data, target, examples, node_id);
28   } else {
29     // This throws away any extra examples, which is more inefficient towards
30     // the top but gradually becomes less of an issue as the tree grows.
31     for (int example : examples) {
32       collection_op_->CreateAndInitializeCandidateWithExample(
33           input_data, target, example, node_id);
34       if (collection_op_->IsInitialized(node_id)) {
35         break;
36       }
37     }
38   }
39 
40   *is_finished = collection_op_->IsFinished(node_id);
41 }
42 
AllocateNode(int32 node_id,int32 depth)43 void FertileStatsResource::AllocateNode(int32 node_id, int32 depth) {
44   collection_op_->InitializeSlot(node_id, depth);
45 }
46 
Allocate(int32 parent_depth,const std::vector<int32> & new_children)47 void FertileStatsResource::Allocate(int32 parent_depth,
48                                     const std::vector<int32>& new_children) {
49   const int32 children_depth = parent_depth + 1;
50   for (const int32 child : new_children) {
51     AllocateNode(child, children_depth);
52   }
53 }
54 
Clear(int32 node)55 void FertileStatsResource::Clear(int32 node) {
56   collection_op_->ClearSlot(node);
57 }
58 
BestSplit(int32 node_id,SplitCandidate * best,int32 * depth)59 bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best,
60                                      int32* depth) {
61   return collection_op_->BestSplit(node_id, best, depth);
62 }
63 
MaybeInitialize()64 void FertileStatsResource::MaybeInitialize() {
65   collection_op_->MaybeInitialize();
66 }
67 
ExtractFromProto(const FertileStats & stats)68 void FertileStatsResource::ExtractFromProto(const FertileStats& stats) {
69   collection_op_ =
70       SplitCollectionOperatorFactory::CreateSplitCollectionOperator(params_);
71   collection_op_->ExtractFromProto(stats);
72 }
73 
PackToProto(FertileStats * stats) const74 void FertileStatsResource::PackToProto(FertileStats* stats) const {
75   collection_op_->PackToProto(stats);
76 }
77 }  // namespace tensorforest
78 }  // namespace tensorflow
79