1 // Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 // =============================================================================
15 #include "tensorflow/contrib/tensor_forest/kernels/v4/fertile-stats-resource.h"
16
17 #include <cfloat>
18
19 namespace tensorflow {
20 namespace tensorforest {
21
AddExampleToStatsAndInitialize(const std::unique_ptr<TensorDataSet> & input_data,const InputTarget * target,const std::vector<int> & examples,int32 node_id,bool * is_finished)22 void FertileStatsResource::AddExampleToStatsAndInitialize(
23 const std::unique_ptr<TensorDataSet>& input_data, const InputTarget* target,
24 const std::vector<int>& examples, int32 node_id, bool* is_finished) {
25 // Update stats or initialize if needed.
26 if (collection_op_->IsInitialized(node_id)) {
27 collection_op_->AddExample(input_data, target, examples, node_id);
28 } else {
29 // This throws away any extra examples, which is more inefficient towards
30 // the top but gradually becomes less of an issue as the tree grows.
31 for (int example : examples) {
32 collection_op_->CreateAndInitializeCandidateWithExample(
33 input_data, target, example, node_id);
34 if (collection_op_->IsInitialized(node_id)) {
35 break;
36 }
37 }
38 }
39
40 *is_finished = collection_op_->IsFinished(node_id);
41 }
42
AllocateNode(int32 node_id,int32 depth)43 void FertileStatsResource::AllocateNode(int32 node_id, int32 depth) {
44 collection_op_->InitializeSlot(node_id, depth);
45 }
46
Allocate(int32 parent_depth,const std::vector<int32> & new_children)47 void FertileStatsResource::Allocate(int32 parent_depth,
48 const std::vector<int32>& new_children) {
49 const int32 children_depth = parent_depth + 1;
50 for (const int32 child : new_children) {
51 AllocateNode(child, children_depth);
52 }
53 }
54
Clear(int32 node)55 void FertileStatsResource::Clear(int32 node) {
56 collection_op_->ClearSlot(node);
57 }
58
BestSplit(int32 node_id,SplitCandidate * best,int32 * depth)59 bool FertileStatsResource::BestSplit(int32 node_id, SplitCandidate* best,
60 int32* depth) {
61 return collection_op_->BestSplit(node_id, best, depth);
62 }
63
MaybeInitialize()64 void FertileStatsResource::MaybeInitialize() {
65 collection_op_->MaybeInitialize();
66 }
67
ExtractFromProto(const FertileStats & stats)68 void FertileStatsResource::ExtractFromProto(const FertileStats& stats) {
69 collection_op_ =
70 SplitCollectionOperatorFactory::CreateSplitCollectionOperator(params_);
71 collection_op_->ExtractFromProto(stats);
72 }
73
PackToProto(FertileStats * stats) const74 void FertileStatsResource::PackToProto(FertileStats* stats) const {
75 collection_op_->PackToProto(stats);
76 }
77 } // namespace tensorforest
78 } // namespace tensorflow
79