1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ops for boosted_trees."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import gen_boosted_trees_ops
23from tensorflow.python.ops import resources
24
25# Re-exporting ops used by other modules.
26# pylint: disable=unused-import
27from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_bucketize
28from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_calculate_best_gains_per_feature as calculate_best_gains_per_feature
29from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_center_bias as center_bias
30from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_create_quantile_stream_resource as create_quantile_stream_resource
31from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_example_debug_outputs as example_debug_outputs
32from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_quantile_summaries as make_quantile_summaries
33from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_make_stats_summary as make_stats_summary
34from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_predict as predict
35from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_add_summaries as quantile_add_summaries
36from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_deserialize as quantile_resource_deserialize
37from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_flush as quantile_flush
38from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_get_bucket_boundaries as get_bucket_boundaries
39from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_quantile_stream_resource_handle_op as quantile_resource_handle_op
40from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_training_predict as training_predict
41from tensorflow.python.ops.gen_boosted_trees_ops import boosted_trees_update_ensemble as update_ensemble
42from tensorflow.python.ops.gen_boosted_trees_ops import is_boosted_trees_quantile_stream_resource_initialized as is_quantile_resource_initialized
43# pylint: enable=unused-import
44
45from tensorflow.python.training import saver
46from tensorflow.python.training.tracking import tracking
47
48
class PruningMode(object):
  """Integer constants for the tree pruning modes and a parser for names."""
  NO_PRUNING = 0
  PRE_PRUNING = 1
  POST_PRUNING = 2

  # Accepted string spellings, mapped onto the constants above.
  _map = {
      'none': NO_PRUNING,
      'pre': PRE_PRUNING,
      'post': POST_PRUNING,
  }

  @classmethod
  def from_str(cls, mode):
    """Converts a pruning mode name into its integer constant.

    Args:
      mode: one of 'none', 'pre' or 'post'.

    Returns:
      The matching PruningMode constant.

    Raises:
      ValueError: if `mode` is not one of the accepted names.
    """
    if mode not in cls._map:
      valid_names = ', '.join(sorted(cls._map))
      raise ValueError(
          'pruning_mode mode must be one of: {}'.format(valid_names))
    return cls._map[mode]
62
63
class QuantileAccumulatorSaveable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject that checkpoints a quantile stream resource.

  One SaveSpec is produced per stream, holding that stream's bucket
  boundaries; `restore` feeds the saved boundaries back through the
  deserialize op once the resource has been created.
  """

  def __init__(self, resource_handle, create_op, num_streams, name):
    """Builds one save spec per stream of the resource.

    Args:
      resource_handle: handle to the quantile stream resource.
      create_op: op that creates the resource; restoring runs after it.
      num_streams: number of streams whose bucket boundaries are saved.
      name: prefix of the saved tensor names; stream `i` is saved as
        `name + '_bucket_boundaries_' + str(i)`.
    """
    self._resource_handle = resource_handle
    self._num_streams = num_streams
    self._create_op = create_op
    boundaries_per_stream = get_bucket_boundaries(self._resource_handle,
                                                  self._num_streams)
    # Empty slice spec: every boundaries tensor is saved as a whole.
    specs = [
        saver.BaseSaverBuilder.SaveSpec(
            boundaries_per_stream[stream], '',
            name + ('_bucket_boundaries_' + str(stream)))
        for stream in range(self._num_streams)
    ]
    super(QuantileAccumulatorSaveable, self).__init__(self._resource_handle,
                                                      specs, name)

  def restore(self, restored_tensors, unused_tensor_shapes):
    """Loads checkpointed bucket boundaries back into the resource.

    Args:
      restored_tensors: the per-stream bucket boundaries read from the
        checkpoint, in the same order as the save specs.
      unused_tensor_shapes: ignored.

    Returns:
      The deserialize op, which depends on the resource creation op.
    """
    with ops.control_dependencies([self._create_op]):
      restore_op = quantile_resource_deserialize(
          self._resource_handle, bucket_boundaries=restored_tensors)
    return restore_op
91
92
class QuantileAccumulator(tracking.TrackableResource):
  """Trackable resource wrapping a quantile stream accumulator.

  The bucket boundaries are serialized and deserialized when checkpointing
  through a `QuantileAccumulatorSaveable`, which is also added to the
  SAVEABLE_OBJECTS graph collection.
  """

  def __init__(self,
               epsilon,
               num_streams,
               num_quantiles,
               name=None,
               max_elements=None):
    """Creates the resource, its initializer and its saveable.

    Args:
      epsilon: error bound passed to the resource creation op and to the
        summary-making op.
      num_streams: number of quantile streams held by the resource.
      num_quantiles: value passed to the flush op.
      name: optional name scope; 'QuantileAccumulator' is used when None.
      max_elements: accepted for API compatibility; currently unused.
    """
    self._eps = epsilon
    self._num_streams = num_streams
    self._num_quantiles = num_quantiles
    super(QuantileAccumulator, self).__init__()

    with ops.name_scope(name, 'QuantileAccumulator') as name:
      self._name = name
      self._resource_handle = self._create_resource()
      self._init_op = self._initialize()
      is_initialized_op = self.is_initialized()
    resources.register_resource(self.resource_handle, self._init_op,
                                is_initialized_op)
    self._saveable = QuantileAccumulatorSaveable(
        self.resource_handle, self._init_op, self._num_streams,
        self.resource_handle.name)
    ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)

  def _create_resource(self):
    """Returns the resource handle op, shared under the scoped name."""
    return quantile_resource_handle_op(
        container='', shared_name=self._name, name=self._name)

  def _initialize(self):
    """Returns the op that creates the quantile streams in the resource."""
    return create_quantile_stream_resource(self.resource_handle, self._eps,
                                           self._num_streams)

  @property
  def initializer(self):
    """The resource creation op, built lazily if not created yet."""
    if self._init_op is None:
      self._init_op = self._initialize()
    return self._init_op

  def is_initialized(self):
    """Returns an op reporting whether the resource is initialized."""
    return is_quantile_resource_initialized(self.resource_handle)

  @property
  def saveable(self):
    """The `QuantileAccumulatorSaveable` that checkpoints this resource."""
    return self._saveable

  def _gather_saveables_for_checkpoint(self):
    """Returns a dict mapping the checkpoint key to this resource's saveable."""
    # Must be a dict literal (key: value); a comma here would build a set
    # that the checkpointing machinery cannot read as name -> saveable.
    return {'quantile_accumulator': self._saveable}

  def add_summaries(self, float_columns, example_weights):
    """Summarizes a batch and adds the summaries to the resource's streams.

    Args:
      float_columns: feature values to summarize (presumably one entry per
        stream -- TODO confirm against callers).
      example_weights: example weights used when building the summaries.

    Returns:
      The op that adds the summaries to the resource.
    """
    summaries = make_quantile_summaries(float_columns, example_weights,
                                        self._eps)
    summary_op = quantile_add_summaries(self.resource_handle, summaries)
    return summary_op

  def flush(self):
    """Returns the flush op, parameterized by `num_quantiles`."""
    return quantile_flush(self.resource_handle, self._num_quantiles)

  def get_bucket_boundaries(self):
    """Returns the bucket boundaries of each of the `num_streams` streams."""
    return get_bucket_boundaries(self.resource_handle, self._num_streams)
157
158
class _TreeEnsembleSavable(saver.BaseSaverBuilder.SaveableObject):
  """SaveableObject that checkpoints a tree ensemble as (stamp, proto)."""

  def __init__(self, resource_handle, create_op, name):
    """Creates a _TreeEnsembleSavable object.

    Args:
      resource_handle: handle to the decision tree ensemble variable.
      create_op: the op to initialize the variable; restoring runs after it.
      name: the name to save the tree ensemble variable under. The stamp
        token and the serialized proto are stored as `name + '_stamp'` and
        `name + '_serialized'` respectively.
    """
    stamp_token, serialized = (
        gen_boosted_trees_ops.boosted_trees_serialize_ensemble(resource_handle))
    # slice_spec only matters when saving part of a variable; the tree
    # ensemble is always saved whole, so an empty spec is used.
    whole_variable = ''
    tensors_and_suffixes = (
        (stamp_token, '_stamp'),
        (serialized, '_serialized'),
    )
    specs = [
        saver.BaseSaverBuilder.SaveSpec(tensor, whole_variable, name + suffix)
        for tensor, suffix in tensors_and_suffixes
    ]
    super(_TreeEnsembleSavable, self).__init__(resource_handle, specs, name)
    self._resource_handle = resource_handle
    self._create_op = create_op

  def restore(self, restored_tensors, unused_restored_shapes):
    """Restores the associated tree ensemble from 'restored_tensors'.

    Args:
      restored_tensors: the tensors loaded from a checkpoint, ordered as the
        save specs: stamp token first, serialized proto second.
      unused_restored_shapes: the shapes this object should conform to after
        restore. Not meaningful for trees.

    Returns:
      The operation that restores the state of the tree ensemble variable.
    """
    saved_stamp = restored_tensors[0]
    saved_proto = restored_tensors[1]
    with ops.control_dependencies([self._create_op]):
      return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
          self._resource_handle,
          stamp_token=saved_stamp,
          tree_ensemble_serialized=saved_proto)
202
203
class TreeEnsemble(tracking.TrackableResource):
  """Creates TreeEnsemble resource.

  Wraps a boosted trees ensemble resource handle and its initializer. Unless
  the ensemble is local, a `_TreeEnsembleSavable` is created for it and added
  to the SAVEABLE_OBJECTS graph collection.
  """

  def __init__(self, name, stamp_token=0, is_local=False, serialized_proto=''):
    """Creates the ensemble resource, its initializer and its saveable.

    Args:
      name: name scope for the resource ops; 'TreeEnsemble' is used when
        it is None.
      stamp_token: initial stamp token passed to the create op.
      is_local: if True, no saveable is created and the resource is
        registered with `is_shared=False`.
      serialized_proto: serialized ensemble passed to the create op.
    """
    # Initialize the TrackableResource base (as QuantileAccumulator does),
    # before the resource handle is assigned below.
    super(TreeEnsemble, self).__init__()
    self._stamp_token = stamp_token
    self._serialized_proto = serialized_proto
    self._is_local = is_local
    with ops.name_scope(name, 'TreeEnsemble') as name:
      self._name = name
      self._resource_handle = self._create_resource()
      self._init_op = self._initialize()
      is_initialized_op = self.is_initialized()
      # Adds the variable to the savable list.
      if not is_local:
        self._saveable = _TreeEnsembleSavable(
            self.resource_handle, self.initializer, self.resource_handle.name)
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable)
      resources.register_resource(
          self.resource_handle,
          self.initializer,
          is_initialized_op,
          is_shared=not is_local)

  def _create_resource(self):
    """Returns the ensemble handle op, shared under the scoped name."""
    return gen_boosted_trees_ops.boosted_trees_ensemble_resource_handle_op(
        container='', shared_name=self._name, name=self._name)

  def _initialize(self):
    """Returns the op creating the ensemble from the stamp and proto."""
    return gen_boosted_trees_ops.boosted_trees_create_ensemble(
        self.resource_handle,
        self._stamp_token,
        tree_ensemble_serialized=self._serialized_proto)

  @property
  def initializer(self):
    """The ensemble creation op, built lazily if not created yet."""
    if self._init_op is None:
      self._init_op = self._initialize()
    return self._init_op

  def is_initialized(self):
    """Returns an op reporting whether the ensemble is initialized."""
    return gen_boosted_trees_ops.is_boosted_trees_ensemble_initialized(
        self.resource_handle)

  def _gather_saveables_for_checkpoint(self):
    """Returns a dict of checkpoint key -> saveable; empty for local ones."""
    if self._is_local:
      # Local ensembles have no saveable; return an empty dict rather than
      # None so callers can always iterate the result.
      return {}
    return {'tree_ensemble': self._saveable}

  def get_stamp_token(self):
    """Returns the current stamp token of the resource."""
    stamp_token, _, _, _, _ = (
        gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
            self.resource_handle))
    return stamp_token

  def get_states(self):
    """Returns states of the tree ensemble.

    Returns:
      stamp_token, num_trees, num_finalized_trees, num_attempted_layers and
      range of the nodes in the latest layer.
    """
    (stamp_token, num_trees, num_finalized_trees, num_attempted_layers,
     nodes_range) = (
         gen_boosted_trees_ops.boosted_trees_get_ensemble_states(
             self.resource_handle))
    # Use identity to give names.
    return (array_ops.identity(stamp_token, name='stamp_token'),
            array_ops.identity(num_trees, name='num_trees'),
            array_ops.identity(num_finalized_trees, name='num_finalized_trees'),
            array_ops.identity(
                num_attempted_layers, name='num_attempted_layers'),
            array_ops.identity(nodes_range, name='last_layer_nodes_range'))

  def serialize(self):
    """Serializes the ensemble into proto and returns the serialized proto.

    Returns:
      stamp_token: int64 scalar Tensor to denote the stamp of the resource.
      serialized_proto: string scalar Tensor of the serialized proto.
    """
    return gen_boosted_trees_ops.boosted_trees_serialize_ensemble(
        self.resource_handle)

  def deserialize(self, stamp_token, serialized_proto):
    """Deserialize the input proto and resets the ensemble from it.

    Args:
      stamp_token: int64 scalar Tensor to denote the stamp of the resource.
      serialized_proto: string scalar Tensor of the serialized proto.

    Returns:
      Operation (for dependencies).
    """
    return gen_boosted_trees_ops.boosted_trees_deserialize_ensemble(
        self.resource_handle, stamp_token, serialized_proto)
299