1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/resource_mgr.h"
21 #include "tensorflow/core/framework/tensor.h"
22 #include "tensorflow/core/framework/tensor_shape.h"
23 #include "tensorflow/core/framework/tensor_types.h"
24 #include "tensorflow/core/kernels/boosted_trees/resources.h"
25 
26 namespace tensorflow {
27 
// Registers the resource-handle kernel for BoostedTreesEnsembleResource.
// NOTE(review): the macro is defined outside this file (presumably in
// resource_mgr.h) -- confirm the exact op name it registers there.
REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesEnsembleResource);

// Registers the CPU kernel for the "IsBoostedTreesEnsembleInitialized" op,
// implemented by the generic IsResourceInitialized kernel instantiated for
// BoostedTreesEnsembleResource.
REGISTER_KERNEL_BUILDER(
    Name("IsBoostedTreesEnsembleInitialized").Device(DEVICE_CPU),
    IsResourceInitialized<BoostedTreesEnsembleResource>);
33 
34 // Creates a tree ensemble resource.
35 class BoostedTreesCreateEnsembleOp : public OpKernel {
36  public:
BoostedTreesCreateEnsembleOp(OpKernelConstruction * context)37   explicit BoostedTreesCreateEnsembleOp(OpKernelConstruction* context)
38       : OpKernel(context) {}
39 
Compute(OpKernelContext * context)40   void Compute(OpKernelContext* context) override {
41     // Get the stamp token.
42     const Tensor* stamp_token_t;
43     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
44     int64 stamp_token = stamp_token_t->scalar<int64>()();
45 
46     // Get the tree ensemble proto.
47     const Tensor* tree_ensemble_serialized_t;
48     OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
49                                            &tree_ensemble_serialized_t));
50     std::unique_ptr<BoostedTreesEnsembleResource> result(
51         new BoostedTreesEnsembleResource());
52     if (!result->InitFromSerialized(
53             tree_ensemble_serialized_t->scalar<string>()(), stamp_token)) {
54       result->Unref();
55       OP_REQUIRES(
56           context, false,
57           errors::InvalidArgument("Unable to parse tree ensemble proto."));
58     }
59 
60     // Only create one, if one does not exist already. Report status for all
61     // other exceptions.
62     auto status =
63         CreateResource(context, HandleFromInput(context, 0), result.release());
64     if (status.code() != tensorflow::error::ALREADY_EXISTS) {
65       OP_REQUIRES_OK(context, status);
66     }
67   }
68 };
69 
// Registers BoostedTreesCreateEnsembleOp as the CPU kernel for the
// "BoostedTreesCreateEnsemble" op.
REGISTER_KERNEL_BUILDER(Name("BoostedTreesCreateEnsemble").Device(DEVICE_CPU),
                        BoostedTreesCreateEnsembleOp);
72 
73 // Op for retrieving some model states (needed for training).
74 class BoostedTreesGetEnsembleStatesOp : public OpKernel {
75  public:
BoostedTreesGetEnsembleStatesOp(OpKernelConstruction * context)76   explicit BoostedTreesGetEnsembleStatesOp(OpKernelConstruction* context)
77       : OpKernel(context) {}
78 
Compute(OpKernelContext * context)79   void Compute(OpKernelContext* context) override {
80     // Looks up the resource.
81     BoostedTreesEnsembleResource* tree_ensemble_resource;
82     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
83                                            &tree_ensemble_resource));
84     tf_shared_lock l(*tree_ensemble_resource->get_mutex());
85     core::ScopedUnref unref_me(tree_ensemble_resource);
86 
87     // Sets the outputs.
88     const int num_trees = tree_ensemble_resource->num_trees();
89     const int num_finalized_trees =
90         (num_trees <= 0 ||
91          tree_ensemble_resource->IsTreeFinalized(num_trees - 1))
92             ? num_trees
93             : num_trees - 1;
94     const int num_attempted_layers =
95         tree_ensemble_resource->GetNumLayersAttempted();
96 
97     // growing_metadata
98     Tensor* output_stamp_token_t = nullptr;
99     Tensor* output_num_trees_t = nullptr;
100     Tensor* output_num_finalized_trees_t = nullptr;
101     Tensor* output_num_attempted_layers_t = nullptr;
102     Tensor* output_last_layer_nodes_range_t = nullptr;
103 
104     OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
105                                                      &output_stamp_token_t));
106     OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape(),
107                                                      &output_num_trees_t));
108     OP_REQUIRES_OK(context,
109                    context->allocate_output(2, TensorShape(),
110                                             &output_num_finalized_trees_t));
111     OP_REQUIRES_OK(context,
112                    context->allocate_output(3, TensorShape(),
113                                             &output_num_attempted_layers_t));
114     OP_REQUIRES_OK(context, context->allocate_output(
115                                 4, {2}, &output_last_layer_nodes_range_t));
116 
117     output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp();
118     output_num_trees_t->scalar<int32>()() = num_trees;
119     output_num_finalized_trees_t->scalar<int32>()() = num_finalized_trees;
120     output_num_attempted_layers_t->scalar<int32>()() = num_attempted_layers;
121 
122     int32 range_start;
123     int32 range_end;
124     tree_ensemble_resource->GetLastLayerNodesRange(&range_start, &range_end);
125 
126     output_last_layer_nodes_range_t->vec<int32>()(0) = range_start;
127     // For a completely empty ensemble, this will be 0. To make it a valid range
128     // we add this max cond.
129     output_last_layer_nodes_range_t->vec<int32>()(1) = std::max(1, range_end);
130   }
131 };
132 
// Registers BoostedTreesGetEnsembleStatesOp as the CPU kernel for the
// "BoostedTreesGetEnsembleStates" op.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesGetEnsembleStates").Device(DEVICE_CPU),
    BoostedTreesGetEnsembleStatesOp);
136 
137 // Op for serializing a model.
138 class BoostedTreesSerializeEnsembleOp : public OpKernel {
139  public:
BoostedTreesSerializeEnsembleOp(OpKernelConstruction * context)140   explicit BoostedTreesSerializeEnsembleOp(OpKernelConstruction* context)
141       : OpKernel(context) {}
142 
Compute(OpKernelContext * context)143   void Compute(OpKernelContext* context) override {
144     BoostedTreesEnsembleResource* tree_ensemble_resource;
145     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
146                                            &tree_ensemble_resource));
147     tf_shared_lock l(*tree_ensemble_resource->get_mutex());
148     core::ScopedUnref unref_me(tree_ensemble_resource);
149     Tensor* output_stamp_token_t = nullptr;
150     OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(),
151                                                      &output_stamp_token_t));
152     output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp();
153     Tensor* output_proto_t = nullptr;
154     OP_REQUIRES_OK(context,
155                    context->allocate_output(1, TensorShape(), &output_proto_t));
156     output_proto_t->scalar<string>()() =
157         tree_ensemble_resource->SerializeAsString();
158   }
159 };
160 
// Registers BoostedTreesSerializeEnsembleOp as the CPU kernel for the
// "BoostedTreesSerializeEnsemble" op.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesSerializeEnsemble").Device(DEVICE_CPU),
    BoostedTreesSerializeEnsembleOp);
164 
165 // Op for deserializing a tree ensemble variable from a checkpoint.
166 class BoostedTreesDeserializeEnsembleOp : public OpKernel {
167  public:
BoostedTreesDeserializeEnsembleOp(OpKernelConstruction * context)168   explicit BoostedTreesDeserializeEnsembleOp(OpKernelConstruction* context)
169       : OpKernel(context) {}
170 
Compute(OpKernelContext * context)171   void Compute(OpKernelContext* context) override {
172     BoostedTreesEnsembleResource* tree_ensemble_resource;
173     OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0),
174                                            &tree_ensemble_resource));
175     mutex_lock l(*tree_ensemble_resource->get_mutex());
176     core::ScopedUnref unref_me(tree_ensemble_resource);
177 
178     // Get the stamp token.
179     const Tensor* stamp_token_t;
180     OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t));
181     int64 stamp_token = stamp_token_t->scalar<int64>()();
182 
183     // Get the tree ensemble proto.
184     const Tensor* tree_ensemble_serialized_t;
185     OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized",
186                                            &tree_ensemble_serialized_t));
187     // Deallocate all the previous objects on the resource.
188     tree_ensemble_resource->Reset();
189     OP_REQUIRES(
190         context,
191         tree_ensemble_resource->InitFromSerialized(
192             tree_ensemble_serialized_t->scalar<string>()(), stamp_token),
193         errors::InvalidArgument("Unable to parse tree ensemble proto."));
194   }
195 };
196 
// Registers BoostedTreesDeserializeEnsembleOp as the CPU kernel for the
// "BoostedTreesDeserializeEnsemble" op.
REGISTER_KERNEL_BUILDER(
    Name("BoostedTreesDeserializeEnsemble").Device(DEVICE_CPU),
    BoostedTreesDeserializeEnsembleOp);
200 
201 }  // namespace tensorflow
202