1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <memory> 17 #include <string> 18 19 #include "tensorflow/core/framework/op_kernel.h" 20 #include "tensorflow/core/framework/resource_mgr.h" 21 #include "tensorflow/core/framework/tensor.h" 22 #include "tensorflow/core/framework/tensor_shape.h" 23 #include "tensorflow/core/framework/tensor_types.h" 24 #include "tensorflow/core/kernels/boosted_trees/resources.h" 25 26 namespace tensorflow { 27 28 REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesEnsembleResource); 29 30 REGISTER_KERNEL_BUILDER( 31 Name("IsBoostedTreesEnsembleInitialized").Device(DEVICE_CPU), 32 IsResourceInitialized<BoostedTreesEnsembleResource>); 33 34 // Creates a tree ensemble resource. 35 class BoostedTreesCreateEnsembleOp : public OpKernel { 36 public: BoostedTreesCreateEnsembleOp(OpKernelConstruction * context)37 explicit BoostedTreesCreateEnsembleOp(OpKernelConstruction* context) 38 : OpKernel(context) {} 39 Compute(OpKernelContext * context)40 void Compute(OpKernelContext* context) override { 41 // Get the stamp token. 42 const Tensor* stamp_token_t; 43 OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); 44 int64 stamp_token = stamp_token_t->scalar<int64>()(); 45 46 // Get the tree ensemble proto. 47 const Tensor* tree_ensemble_serialized_t; 48 OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized", 49 &tree_ensemble_serialized_t)); 50 std::unique_ptr<BoostedTreesEnsembleResource> result( 51 new BoostedTreesEnsembleResource()); 52 if (!result->InitFromSerialized( 53 tree_ensemble_serialized_t->scalar<string>()(), stamp_token)) { 54 result->Unref(); 55 OP_REQUIRES( 56 context, false, 57 errors::InvalidArgument("Unable to parse tree ensemble proto.")); 58 } 59 60 // Only create one, if one does not exist already. Report status for all 61 // other exceptions. 62 auto status = 63 CreateResource(context, HandleFromInput(context, 0), result.release()); 64 if (status.code() != tensorflow::error::ALREADY_EXISTS) { 65 OP_REQUIRES_OK(context, status); 66 } 67 } 68 }; 69 70 REGISTER_KERNEL_BUILDER(Name("BoostedTreesCreateEnsemble").Device(DEVICE_CPU), 71 BoostedTreesCreateEnsembleOp); 72 73 // Op for retrieving some model states (needed for training). 74 class BoostedTreesGetEnsembleStatesOp : public OpKernel { 75 public: BoostedTreesGetEnsembleStatesOp(OpKernelConstruction * context)76 explicit BoostedTreesGetEnsembleStatesOp(OpKernelConstruction* context) 77 : OpKernel(context) {} 78 Compute(OpKernelContext * context)79 void Compute(OpKernelContext* context) override { 80 // Looks up the resource. 81 BoostedTreesEnsembleResource* tree_ensemble_resource; 82 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 83 &tree_ensemble_resource)); 84 tf_shared_lock l(*tree_ensemble_resource->get_mutex()); 85 core::ScopedUnref unref_me(tree_ensemble_resource); 86 87 // Sets the outputs. 88 const int num_trees = tree_ensemble_resource->num_trees(); 89 const int num_finalized_trees = 90 (num_trees <= 0 || 91 tree_ensemble_resource->IsTreeFinalized(num_trees - 1)) 92 ? num_trees 93 : num_trees - 1; 94 const int num_attempted_layers = 95 tree_ensemble_resource->GetNumLayersAttempted(); 96 97 // growing_metadata 98 Tensor* output_stamp_token_t = nullptr; 99 Tensor* output_num_trees_t = nullptr; 100 Tensor* output_num_finalized_trees_t = nullptr; 101 Tensor* output_num_attempted_layers_t = nullptr; 102 Tensor* output_last_layer_nodes_range_t = nullptr; 103 104 OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), 105 &output_stamp_token_t)); 106 OP_REQUIRES_OK(context, context->allocate_output(1, TensorShape(), 107 &output_num_trees_t)); 108 OP_REQUIRES_OK(context, 109 context->allocate_output(2, TensorShape(), 110 &output_num_finalized_trees_t)); 111 OP_REQUIRES_OK(context, 112 context->allocate_output(3, TensorShape(), 113 &output_num_attempted_layers_t)); 114 OP_REQUIRES_OK(context, context->allocate_output( 115 4, {2}, &output_last_layer_nodes_range_t)); 116 117 output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp(); 118 output_num_trees_t->scalar<int32>()() = num_trees; 119 output_num_finalized_trees_t->scalar<int32>()() = num_finalized_trees; 120 output_num_attempted_layers_t->scalar<int32>()() = num_attempted_layers; 121 122 int32 range_start; 123 int32 range_end; 124 tree_ensemble_resource->GetLastLayerNodesRange(&range_start, &range_end); 125 126 output_last_layer_nodes_range_t->vec<int32>()(0) = range_start; 127 // For a completely empty ensemble, this will be 0. To make it a valid range 128 // we add this max cond. 129 output_last_layer_nodes_range_t->vec<int32>()(1) = std::max(1, range_end); 130 } 131 }; 132 133 REGISTER_KERNEL_BUILDER( 134 Name("BoostedTreesGetEnsembleStates").Device(DEVICE_CPU), 135 BoostedTreesGetEnsembleStatesOp); 136 137 // Op for serializing a model. 138 class BoostedTreesSerializeEnsembleOp : public OpKernel { 139 public: BoostedTreesSerializeEnsembleOp(OpKernelConstruction * context)140 explicit BoostedTreesSerializeEnsembleOp(OpKernelConstruction* context) 141 : OpKernel(context) {} 142 Compute(OpKernelContext * context)143 void Compute(OpKernelContext* context) override { 144 BoostedTreesEnsembleResource* tree_ensemble_resource; 145 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 146 &tree_ensemble_resource)); 147 tf_shared_lock l(*tree_ensemble_resource->get_mutex()); 148 core::ScopedUnref unref_me(tree_ensemble_resource); 149 Tensor* output_stamp_token_t = nullptr; 150 OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), 151 &output_stamp_token_t)); 152 output_stamp_token_t->scalar<int64>()() = tree_ensemble_resource->stamp(); 153 Tensor* output_proto_t = nullptr; 154 OP_REQUIRES_OK(context, 155 context->allocate_output(1, TensorShape(), &output_proto_t)); 156 output_proto_t->scalar<string>()() = 157 tree_ensemble_resource->SerializeAsString(); 158 } 159 }; 160 161 REGISTER_KERNEL_BUILDER( 162 Name("BoostedTreesSerializeEnsemble").Device(DEVICE_CPU), 163 BoostedTreesSerializeEnsembleOp); 164 165 // Op for deserializing a tree ensemble variable from a checkpoint. 166 class BoostedTreesDeserializeEnsembleOp : public OpKernel { 167 public: BoostedTreesDeserializeEnsembleOp(OpKernelConstruction * context)168 explicit BoostedTreesDeserializeEnsembleOp(OpKernelConstruction* context) 169 : OpKernel(context) {} 170 Compute(OpKernelContext * context)171 void Compute(OpKernelContext* context) override { 172 BoostedTreesEnsembleResource* tree_ensemble_resource; 173 OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), 174 &tree_ensemble_resource)); 175 mutex_lock l(*tree_ensemble_resource->get_mutex()); 176 core::ScopedUnref unref_me(tree_ensemble_resource); 177 178 // Get the stamp token. 179 const Tensor* stamp_token_t; 180 OP_REQUIRES_OK(context, context->input("stamp_token", &stamp_token_t)); 181 int64 stamp_token = stamp_token_t->scalar<int64>()(); 182 183 // Get the tree ensemble proto. 184 const Tensor* tree_ensemble_serialized_t; 185 OP_REQUIRES_OK(context, context->input("tree_ensemble_serialized", 186 &tree_ensemble_serialized_t)); 187 // Deallocate all the previous objects on the resource. 188 tree_ensemble_resource->Reset(); 189 OP_REQUIRES( 190 context, 191 tree_ensemble_resource->InitFromSerialized( 192 tree_ensemble_serialized_t->scalar<string>()(), stamp_token), 193 errors::InvalidArgument("Unable to parse tree ensemble proto.")); 194 } 195 }; 196 197 REGISTER_KERNEL_BUILDER( 198 Name("BoostedTreesDeserializeEnsemble").Device(DEVICE_CPU), 199 BoostedTreesDeserializeEnsembleOp); 200 201 } // namespace tensorflow 202