# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for boosted_trees."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_boosted_trees_ops
from tensorflow.python.ops import resources

# Re-exporting ops used by other modules.
# pylint: disable=unused-import
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quantile_summaries as make_quantile_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_deserialize as quantile_resource_deserialize
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as quantile_resource_handle_op
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as is_quantile_resource_initialized
# pylint: enable=unused-import

from tensorflow.python.training import saver
from tensorflow.python.training.tracking import tracking


class PruningMode(object):
  """Class for working with pruning modes."""
  NO_PRUNING, PRE_PRUNING, POST_PRUNING = range(0, 3)

  _map = {'none': NO_PRUNING, 'pre': PRE_PRUNING, 'post': POST_PRUNING}

  @classmethod
  def from_str(cls, mode):
    if mode in cls._map:
      return cls._map[mode]
    else:
      raise ValueError('pruning_mode must be one of: {}'.format(', '.join(
          sorted(cls._map))))
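
# Illustrative only: callers typically convert a user-facing string once at
# graph-construction time and pass the resulting integer to ops such as
# update_ensemble, e.g.
#
#   pruning_mode = PruningMode.from_str('pre')  # == PruningMode.PRE_PRUNING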


class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject implementation for QuantileAccumulator."""

  def __init__(self, resource_handle, create_op, num_streams, name):
    self._resource_handle = resource_handle
    self._num_streams = num_streams
    self._create_op = create_op
    bucket_boundaries = get_bucket_boundaries(self._resource_handle,
                                              self._num_streams)
    slice_spec = ''
    specs = []

    def make_save_spec(tensor, suffix):
      return saver.BaseSaverBuilder.SaveSpec(tensor, slice_spec, name + suffix)

    for i in range(self._num_streams):
      specs += [
          make_save_spec(bucket_boundaries[i], '_bucket_boundaries_' + str(i))
      ]
    super(QuantileAccumulatorSaveable, self).__init__(self._resource_handle,
                                                      specs, name)

  def restore(self, restored_tensors, unused_tensor_shapes):
    bucket_boundaries = restored_tensors
    with ops.control_dependencies([self._create_op]):
      return quantile_resource_deserialize(
          self._resource_handle, bucket_boundaries=bucket_boundaries)


class QuantileAccumulator(tracking.TrackableResource):
  """Resource wrapper for a boosted trees quantile stream.

  The bucket boundaries are serialized to and deserialized from checkpoints.
  """

  def __init__(self,
               epsilon,
               num_streams,
               num_quantiles,
               name=None,
               max_elements=None):
    self._eps = epsilon
    self._num_streams = num_streams
    self._num_quantiles = num_quantiles
    super(QuantileAccumulator, self).__init__()

    with ops.name_scope(name, 'QuantileAccumulator') as name:
      self._name = name
      self._resource_handle = self._create_resource()
      self._init_op = self._initialize()
      is_initialized_op = self.is_initialized()
      resources.register_resource(self.resource_handle, self._init_op,
                                  is_initialized_op)
      self._saveable = QuantileAccumulatorSaveable(
          self.resource_handle, self._init_op, self._num_streams,
          self.resource_handle.name)
      ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)

  def _create_resource(self):
    return quantile_resource_handle_op(
        container='', shared_name=self._name, name=self._name)

  def _initialize(self):
    return create_quantile_stream_resource(self.resource_handle, self._eps,
                                           self._num_streams)

  @property
  def initializer(self):
    if self._init_op is None:
      self._init_op = self._initialize()
    return self._init_op

  def is_initialized(self):
    return is_quantile_resource_initialized(self.resource_handle)

  @property
  def saveable(self):
    return self._saveable

  def _gather_saveables_for_checkpoint(self):
    return {'quantile_accumulator': self._saveable}

  def add_summaries(self, float_columns, example_weights):
    summaries = make_quantile_summaries(float_columns, example_weights,
                                        self._eps)
    summary_op = quantile_add_summaries(self.resource_handle, summaries)
    return summary_op

  def flush(self):
    return quantile_flush(self.resource_handle, self._num_quantiles)

  def get_bucket_boundaries(self):
    return get_bucket_boundaries(self.resource_handle, self._num_streams)
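
# A minimal graph-mode usage sketch for QuantileAccumulator (illustrative
# only; `float_columns` and `example_weights` are assumed input tensors, not
# defined in this module):
#
#   accumulator = QuantileAccumulator(
#       epsilon=0.01, num_streams=2, num_quantiles=100)
#   summary_op = accumulator.add_summaries(float_columns, example_weights)
#   flush_op = accumulator.flush()
#   bucket_boundaries = accumulator.get_bucket_boundaries()
#
# Running summary_op over batches and then flush_op materializes the
# approximate quantile boundaries that get_bucket_boundaries() returns.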


class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject implementation for TreeEnsemble."""

  def __init__(self, resource_handle, create_op, name):
    """Creates a _TreeEnsembleSavable object.

    Args:
      resource_handle: handle to the decision tree ensemble variable.
      create_op: the op to initialize the variable.
      name: the name to save the tree ensemble variable under.
    """
    stamp_token, serialized = (
        gen_boosted_trees_ops.boosted_trees_serialize_ensemble(resource_handle))
    # slice_spec is useful for saving a slice from a variable. It's not
    # meaningful for the tree ensemble variable, so we just pass an empty
    # value.
    slice_spec = ''
    specs = [
        saver.BaseSaverBuilder.SaveSpec(stamp_token, slice_spec,
                                        name + '_stamp'),
        saver.BaseSaverBuilder.SaveSpec(serialized, slice_spec,
                                        name + '_serialized'),
    ]
    super(_TreeEnsembleSavable, self).__init__(resource_handle, specs, name)
    self._resource_handle = resource_handle
    self._create_op = create_op

  def restore(self, restored_tensors, unused_restored_shapes):
    """Restores the associated tree ensemble from 'restored_tensors'.

    Args:
      restored_tensors: the tensors that were loaded from a checkpoint.
      unused_restored_shapes: the shapes this object should conform to after
        restore. Not meaningful for trees.

    Returns:
      The operation that restores the state of the tree ensemble variable.
    """
    with ops.control_dependencies([self._create_op]):
      return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
          self._resource_handle,
          stamp_token=restored_tensors[0],
          tree_ensemble_serialized=restored_tensors[1])
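
# For reference (a sketch of the checkpoint layout, not a formal contract):
# each ensemble saved through _TreeEnsembleSavable contributes two entries to
# the checkpoint, `<name>_stamp` (an int64 scalar) and `<name>_serialized`
# (a string scalar holding the serialized ensemble proto); restore() replays
# them into the resource only after `create_op` has run.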


class TreeEnsemble(tracking.TrackableResource):
  """Creates a TreeEnsemble resource."""

  def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
    self._stamp_token = stamp_token
    self._serialized_proto = serialized_proto
    self._is_local = is_local
    with ops.name_scope(name, 'TreeEnsemble') as name:
      self._name = name
      self._resource_handle = self._create_resource()
      self._init_op = self._initialize()
      is_initialized_op = self.is_initialized()
      # Adds the variable to the savable list.
      if not is_local:
        self._saveable = _TreeEnsembleSavable(
            self.resource_handle, self.initializer, self.resource_handle.name)
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
      resources.register_resource(
          self.resource_handle,
          self.initializer,
          is_initialized_op,
          is_shared=not is_local)

  def _create_resource(self):
    return gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op(
        container='', shared_name=self._name, name=self._name)

  def _initialize(self):
    return gen_boosted_trees_ops.boosted_trees_create_ensemble(
        self.resource_handle,
        self._stamp_token,
        tree_ensemble_serialized=self._serialized_proto)

  @property
  def initializer(self):
    if self._init_op is None:
      self._init_op = self._initialize()
    return self._init_op

  def is_initialized(self):
    return gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized(
        self.resource_handle)

  def _gather_saveables_for_checkpoint(self):
    if not self._is_local:
      return {'tree_ensemble': self._saveable}

  def get_stamp_token(self):
    """Returns the current stamp token of the resource."""
    stamp_token, _, _, _, _ = (
        gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
            self.resource_handle))
    return stamp_token

  def get_states(self):
    """Returns the states of the tree ensemble.

    Returns:
      stamp_token, num_trees, num_finalized_trees, num_attempted_layers and
      the range of the nodes in the latest layer.
    """
    (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
     nodes_range) = (
         gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
             self.resource_handle))
    # Use identity to give names.
    return (array_ops.identity(stamp_token, name='stamp_token'),
            array_ops.identity(num_trees, name='num_trees'),
            array_ops.identity(num_finalized_trees, name='num_finalized_trees'),
            array_ops.identity(
                num_attempted_layers, name='num_attempted_layers'),
            array_ops.identity(nodes_range, name='last_layer_nodes_range'))

  def serialize(self):
    """Serializes the ensemble into a proto and returns the serialized proto.

    Returns:
      stamp_token: int64 scalar Tensor denoting the stamp of the resource.
      serialized_proto: string scalar Tensor of the serialized proto.
    """
    return gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
        self.resource_handle)

  def deserialize(self, stamp_token, serialized_proto):
    """Deserializes the given proto and resets the ensemble from it.

    Args:
      stamp_token: int64 scalar Tensor to denote the stamp of the resource.
      serialized_proto: string scalar Tensor of the serialized proto.

    Returns:
      Operation (for dependencies).
    """
    return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
        self.resource_handle, stamp_token, serialized_proto)
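
# A minimal graph-mode usage sketch for TreeEnsemble (illustrative only; the
# session setup below is an assumption about the calling code, not part of
# this module):
#
#   from tensorflow.python.client import session
#
#   ensemble = TreeEnsemble(name='ensemble')
#   with session.Session() as sess:
#     sess.run(resources.initialize_resources(resources.shared_resources()))
#     stamp_token, serialized_proto = sess.run(ensemble.serialize())
#     # Later, reset the ensemble from the serialized state.
#     sess.run(ensemble.deserialize(stamp_token, serialized_proto))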