# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the base PreprocessingLayer and a subclass that uses Combiners."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections

import numpy as np
import six

from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import monitoring
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine import data_adapter
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import tf_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.training.tracking import base as trackable
from tensorflow.python.util.tf_export import keras_export


keras_kpl_gauge = monitoring.BoolGauge(
    '/tensorflow/api/keras/layers/preprocessing',
    'keras preprocessing layers usage', 'method')


@keras_export('keras.layers.experimental.preprocessing.PreprocessingLayer')
@six.add_metaclass(abc.ABCMeta)
class PreprocessingLayer(Layer):
  """Base class for PreprocessingLayers.

  Attributes:
    stateful: Whether the layer contains state that needs to be adapted via
      `PreprocessingLayer.adapt`.
    streaming: Whether a layer can be adapted multiple times without resetting
      the state of the layer.
  """
  _must_restore_from_config = True

  def __init__(self, stateful=False, streaming=True, **kwargs):
    super(PreprocessingLayer, self).__init__(**kwargs)
    self._stateful = stateful
    self._streaming = streaming
    self._is_compiled = False
    self._is_adapted = False

    # Sets `is_adapted=False` when `reset_state` is called.
    self._reset_state_impl = self.reset_state
    self.reset_state = self._reset_state_wrapper

    self._adapt_function = None

  @property
  def streaming(self):
    """Whether `adapt` can be called twice without resetting the state."""
    return self._streaming

  @property
  def is_adapted(self):
    """Whether the layer has been fit to data already."""
    return self._is_adapted

  def update_state(self, data):
    """Accumulates statistics for the preprocessing layer.

    Arguments:
      data: A mini-batch of inputs to the layer.
    """
    if self.stateful:
      raise NotImplementedError

  def reset_state(self):
    """Resets the statistics of the preprocessing layer."""
    if self.stateful:
      raise NotImplementedError

  def merge_state(self, layers):
    """Merge the statistics of multiple preprocessing layers.

    This layer will contain the merged state.

    Arguments:
      layers: Layers whose statistics should be merged with the statistics of
        this layer.
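
    For example (a minimal usage sketch; `layer_1`, `layer_2`, `data_1` and
    `data_2` are hypothetical layers and datasets):

      layer_1.adapt(data_1)
      layer_2.adapt(data_2)
      layer_1.merge_state([layer_2])  # layer_1 now reflects both datasets.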
    """
    if self.stateful:
      raise NotImplementedError

  def finalize_state(self):
    """Finalizes the statistics for the preprocessing layer.

    This method is called at the end of `adapt`. It handles any one-time
    operations that should occur after all of the data has been seen.
    """
    pass

  def make_adapt_function(self):
    """Creates a function to execute one step of `adapt`.

    This method can be overridden to support custom adapt logic.
    This method is called by `PreprocessingLayer.adapt`.

    Typically, this method directly controls `tf.function` settings,
    and delegates the actual state update logic to
    `PreprocessingLayer.update_state`.

    This function is cached the first time `PreprocessingLayer.adapt`
    is called. The cache is cleared whenever `PreprocessingLayer.compile`
    is called.

    Returns:
      Function. The function created by this method should accept a
      `tf.data.Iterator`, retrieve a batch, and update the state of the
      layer.
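
    For example (a hypothetical override sketch; it mirrors the default
    behavior of processing one batch per call):

      def make_adapt_function(self):
        def adapt_fn(iterator):
          data = next(iterator)          # Pull one batch from the iterator.
          self._adapt_maybe_build(data)  # Build the layer on first use.
          self.update_state(data)        # Accumulate batch statistics.
        return def_function.function(adapt_fn)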
    """
    if self._adapt_function is not None:
      return self._adapt_function

    def adapt_step(iterator):
      data = next(iterator)
      self._adapt_maybe_build(data)
      self.update_state(data)

    if self._steps_per_execution.numpy().item() == 1:
      adapt_fn = adapt_step
    else:

      def adapt_fn(iterator):
        for _ in math_ops.range(self._steps_per_execution):
          adapt_step(iterator)

    if not self._run_eagerly:
      adapt_fn = def_function.function(adapt_fn)

    self._adapt_function = adapt_fn
    return self._adapt_function

  def compile(self, run_eagerly=None, steps_per_execution=None):
    """Configures the layer for `adapt`.

    Arguments:
      run_eagerly: Bool. Defaults to `False`. If `True`, this layer's logic
        will not be wrapped in a `tf.function`. Recommended to leave this as
        `None` unless the layer cannot be run inside a `tf.function`.
      steps_per_execution: Int. Defaults to 1. The number of batches to run
        during each `tf.function` call. Running multiple batches inside a
        single `tf.function` call can greatly improve performance on TPUs or
        small models with a large Python overhead.
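
    For example (a minimal usage sketch; `layer` is a hypothetical stateful
    preprocessing layer and `dataset` is a `tf.data.Dataset`):

      # Run 16 batches per `tf.function` call during `adapt`.
      layer.compile(steps_per_execution=16)
      layer.adapt(dataset)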
    """
    if steps_per_execution is None:
      steps_per_execution = 1
    self._configure_steps_per_execution(steps_per_execution)

    if run_eagerly is None:
      run_eagerly = self.dynamic
    self._run_eagerly = run_eagerly

    self._is_compiled = True

  def adapt(self, data, batch_size=None, steps=None, reset_state=True):
    """Fits the state of the preprocessing layer to the data being passed.

    Arguments:
        data: The data to train on. It can be passed either as a
          `tf.data.Dataset` or as a numpy array.
        batch_size: Integer or `None`.
            Number of samples per state update.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your data is in the
            form of datasets, generators, or `keras.utils.Sequence` instances
            (since they generate batches).
        steps: Integer or `None`.
            Total number of steps (batches of samples) to process.
            When training with input tensors such as
            TensorFlow data tensors, the default `None` is equal to
            the number of samples in your dataset divided by
            the batch size, or 1 if that cannot be determined. If `data` is a
            `tf.data.Dataset` and `steps` is `None`, adaptation will run until
            the input dataset is exhausted. When passing an infinitely
            repeating dataset, you must specify the `steps` argument. This
            argument is not supported with array inputs.
        reset_state: Optional argument specifying whether to clear the state of
          the layer at the start of the call to `adapt`, or whether to start
          from the existing state. This argument may not be relevant to all
          preprocessing layers: a subclass of PreprocessingLayer may choose to
          raise if `reset_state` is set to `False`.
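
    For example (a minimal usage sketch; `MyStatefulLayer` is a hypothetical
    stateful subclass of this class):

      layer = MyStatefulLayer()
      layer.adapt(np.array([[0.], [1.], [2.]]))  # Fits the layer's state.
      assert layer.is_adapted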
    """
    _disallow_inside_tf_function('adapt')
    if not self.stateful:
      return
    if not self.streaming and self._is_adapted and not reset_state:
      raise ValueError('{} does not support calling `adapt` twice without '
                       'resetting the state.'.format(self.__class__.__name__))
    if not self._is_compiled:
      self.compile()  # Compile with defaults.
    if self.built and reset_state:
      self.reset_state()
    data_handler = data_adapter.DataHandler(
        data,
        batch_size=batch_size,
        steps_per_epoch=steps,
        epochs=1,
        steps_per_execution=self._steps_per_execution,
        distribute=False)
    self._adapt_function = self.make_adapt_function()
    for _, iterator in data_handler.enumerate_epochs():
      with data_handler.catch_stop_iteration():
        for _ in data_handler.steps():
          self._adapt_function(iterator)
          if data_handler.should_sync:
            context.async_wait()
    self.finalize_state()
    self._is_adapted = True

  def _reset_state_wrapper(self):
    """Calls `reset_state` and sets `is_adapted` to `False`."""
    self._reset_state_impl()
    self._is_adapted = False

  @trackable.no_automatic_dependency_tracking
  def _configure_steps_per_execution(self, steps_per_execution):
    self._steps_per_execution = variables.Variable(
        steps_per_execution,
        dtype='int64',
        aggregation=variables.VariableAggregationV2.ONLY_FIRST_REPLICA)

  # TODO(omalleyt): Unify this logic with `Layer._maybe_build`.
  def _adapt_maybe_build(self, data):
    if not self.built:
      try:
        # If this is a Numpy array or tensor, we can get shape from .shape.
        # If not, an attribute error will be thrown.
        data_shape = data.shape
        data_shape_nones = tuple([None] * len(data.shape))
      except AttributeError:
        # The input has an unknown number of dimensions.
        data_shape = None
        data_shape_nones = None

      # TODO (b/159261555): move this to base layer build.
      batch_input_shape = getattr(self, '_batch_input_shape', None)
      if batch_input_shape is None:
        # Set the number of dimensions.
        self._batch_input_shape = data_shape_nones
      self.build(data_shape)
      self.built = True


# TODO(omalleyt): This class will be gradually replaced.
class CombinerPreprocessingLayer(PreprocessingLayer):
  """Base class for PreprocessingLayers that do computation using a Combiner.

  This class provides several helper methods to make creating a
  PreprocessingLayer easier. It assumes that the core of your computation will
  be done via a Combiner object. Subclassing this class to create a
  PreprocessingLayer allows your layer to be compatible with distributed
  computation.

  This class is compatible with TensorFlow 2.0+.
  """

  def __init__(self, combiner, **kwargs):
    super(CombinerPreprocessingLayer, self).__init__(stateful=True, **kwargs)
    self.state_variables = collections.OrderedDict()
    self._combiner = combiner
    self._adapt_accumulator = None

  def reset_state(self):
    self._adapt_accumulator = None

  def update_state(self, data):
    if self._adapt_accumulator is None:
      self._adapt_accumulator = self._get_accumulator()
    self._adapt_accumulator = self._combiner.compute(data,
                                                     self._adapt_accumulator)

  def merge_state(self, layers):
    accumulators = ([self._get_accumulator()] +
                    [l._get_accumulator() for l in layers])  # pylint: disable=protected-access
    merged_accumulator = self._combiner.merge(accumulators)
    self._set_accumulator(merged_accumulator)

  def finalize_state(self):
    self._set_accumulator(self._adapt_accumulator)

  def compile(self, run_eagerly=None, steps_per_execution=None):
    # TODO(omalleyt): Remove this once sublayers are switched to new APIs.
    if run_eagerly is None:
      run_eagerly = True
    super(CombinerPreprocessingLayer, self).compile(
        run_eagerly=run_eagerly, steps_per_execution=steps_per_execution)

  def adapt(self, data, batch_size=None, steps=None, reset_state=True):
    if not reset_state:
      self._adapt_accumulator = self._combiner.restore(self._restore_updates())
    super(CombinerPreprocessingLayer, self).adapt(
        data, batch_size=batch_size, steps=steps, reset_state=reset_state)

  def _add_state_variable(self,
                          name,
                          shape,
                          dtype,
                          initializer=None,
                          partitioner=None,
                          use_resource=None,
                          **kwargs):
    """Adds a variable that can hold state which is updated during `adapt()`.

    Args:
      name: Variable name.
      shape: Variable shape. Defaults to scalar if unspecified.
      dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
      initializer: Initializer instance (callable).
      partitioner: Partitioner to be passed to the `Trackable` API.
      use_resource: Whether to use a `ResourceVariable`.
      **kwargs: Additional keyword arguments. Accepted values are `getter` and
        `collections`.

    Returns:
      The created variable.
    """
    weight = self.add_weight(
        name=name,
        shape=shape,
        dtype=dtype,
        initializer=initializer,
        regularizer=None,
        trainable=False,
        constraint=None,
        partitioner=partitioner,
        use_resource=use_resource,
        **kwargs)
    # TODO(momernick): Do not allow collisions here.
    self.state_variables[name] = weight
    return weight

  def _restore_updates(self):
    """Recreates a dict of updates from the layer's weights."""
    data_dict = {}
    for name, var in self.state_variables.items():
      data_dict[name] = var.numpy()
    return data_dict

  def _get_accumulator(self):
    if self._is_adapted:
      return self._combiner.restore(self._restore_updates())
    else:
      return None

  def _set_accumulator(self, accumulator):
    updates = self._combiner.extract(accumulator)
    self._set_state_variables(updates)
    self._adapt_accumulator = None  # Reset accumulator from adapt.

  def _set_state_variables(self, updates):
    """Directly updates the internal state of this Layer.

    This method expects a string-keyed dict of {state_variable_name: state}.
    The precise nature of the state, and the names associated with it, are
    described by the subclasses of CombinerPreprocessingLayer.

    Args:
      updates: A string-keyed dict of weights to update.

    Raises:
      RuntimeError: if `build()` was not called before `_set_state_variables()`.
    """
    # TODO(momernick): Do we need to do any more input sanitization?
    if not self.built:
      raise RuntimeError('_set_state_variables() must be called after build().')

    with ops.init_scope():
      for var_name, value in updates.items():
        self.state_variables[var_name].assign(value)


def convert_to_list(values, sparse_default_value=None):
  """Convert a TensorLike, CompositeTensor, or ndarray into a Python list."""
  if tf_utils.is_ragged(values):
    # There is a corner case when dealing with ragged tensors: if you get an
    # actual RaggedTensor (not a RaggedTensorValue) passed in non-eager mode,
    # you can't call to_list() on it without evaluating it first. However,
    # because we don't yet fully support composite tensors across Keras,
    # K.get_value() won't evaluate the tensor.
    # TODO(momernick): Get Keras to recognize composite tensors as Tensors
    # and then replace this with a call to K.get_value.
    if (isinstance(values, ragged_tensor.RaggedTensor) and
        not context.executing_eagerly()):
      values = K.get_session(values).run(values)
    values = values.to_list()

  if isinstance(values,
                (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
    if sparse_default_value is None:
      if dtypes.as_dtype(values.values.dtype) == dtypes.string:
        sparse_default_value = ''
      else:
        sparse_default_value = -1
    dense_tensor = sparse_ops.sparse_tensor_to_dense(
        values, default_value=sparse_default_value)
    values = K.get_value(dense_tensor)

  if isinstance(values, ops.Tensor):
    values = K.get_value(values)

  # We may get passed an ndarray or the code above may give us an ndarray.
  # In either case, we want to force it into a standard Python list.
  if isinstance(values, np.ndarray):
    values = values.tolist()

  return values


# TODO(omalleyt): This class will be gradually replaced.
class Combiner(object):
  """Functional object that defines a shardable computation.

  This object defines functions required to create and manipulate data objects.
  These data objects, referred to below as 'accumulators', are computation-
  specific and may be implemented alongside concrete subclasses of Combiner
  (if necessary - some computations may be simple enough that standard Python
  types can be used as accumulators).

  The intent for this class is that by describing computations in this way, we
  can arbitrarily shard a dataset, perform computations on a subset, and then
  merge the computation into a final result. This enables distributed
  computation.

  The combiner itself does not own any state - all computational state is owned
  by the accumulator objects. This is so that we can have an arbitrary number of
  Combiners (thus sharding the computation N ways) without risking any change
  to the underlying computation. These accumulator objects are uniquely
  associated with each Combiner; a Combiner defines what the accumulator object
  should be and will only work with accumulators of that type.
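
  A minimal sketch of a concrete Combiner (hypothetical, not part of this
  module) that accumulates a running sum might look like:

    class SumCombiner(Combiner):

      def compute(self, batch_values, accumulator=None):
        batch_sum = sum(np.sum(v) for v in batch_values)
        return batch_sum if accumulator is None else accumulator + batch_sum

      def merge(self, accumulators):
        return sum(accumulators)

      def extract(self, accumulator):
        return {'sum': np.array(accumulator)}

      def restore(self, output):
        return output['sum']

      def serialize(self, accumulator):
        return np.array(accumulator, dtype=np.float64).tobytes()

      def deserialize(self, encoded_accumulator):
        return np.frombuffer(encoded_accumulator, dtype=np.float64).item()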
  """
  __metaclass__ = abc.ABCMeta

  def __repr__(self):
    return '<{}>'.format(self.__class__.__name__)

  @abc.abstractmethod
  def compute(self, batch_values, accumulator=None):
    """Compute a step in this computation, returning a new accumulator.

    This method computes a step of the computation described by this Combiner.
    If an accumulator is passed, the data in that accumulator is also used; so
    compute(batch_values) results in f(batch_values), while
    compute(batch_values, accumulator) results in
    merge(f(batch_values), accumulator).

    Args:
      batch_values: A list of ndarrays representing the values of the inputs for
        this step of the computation.
      accumulator: the current accumulator. Can be None.

    Returns:
      An accumulator that includes the passed batch of inputs.
    """
    pass

  @abc.abstractmethod
  def merge(self, accumulators):
    """Merge several accumulators to a single accumulator.

    This method takes the partial values in several accumulators and combines
    them into a single accumulator. This computation must not be order-specific
    (that is, merge([a, b]) must return the same result as merge([b, a])).

    Args:
      accumulators: the accumulators to merge, as a list.

    Returns:
      A merged accumulator.
    """
    pass

  @abc.abstractmethod
  def extract(self, accumulator):
    """Convert an accumulator into a dict of output values.

    Args:
      accumulator: The accumulator to convert.

    Returns:
      A dict of ndarrays representing the data in this accumulator.
    """
    pass

  @abc.abstractmethod
  def restore(self, output):
    """Create an accumulator based on 'output'.

    This method creates a new accumulator with identical internal state to the
    one used to create the data in 'output'. This means that if you do

    output_data = combiner.extract(accumulator_1)
    accumulator_2 = combiner.restore(output_data)

    then accumulator_1 and accumulator_2 will have identical internal state, and
    computations using either of them will be equivalent.

    Args:
      output: The data output from a previous computation. Should be in the same
        form as provided by `extract`.

    Returns:
      A new accumulator.
    """
    pass

  @abc.abstractmethod
  def serialize(self, accumulator):
    """Serialize an accumulator for a remote call.

    This function serializes an accumulator to be sent to a remote process.

    Args:
      accumulator: The accumulator to serialize.

    Returns:
      A byte string representing the passed accumulator.
    """
    pass

  @abc.abstractmethod
  def deserialize(self, encoded_accumulator):
    """Deserialize an accumulator received from 'serialize()'.

    This function deserializes an accumulator serialized by 'serialize()'.

    Args:
      encoded_accumulator: A byte string representing an accumulator.

    Returns:
      The accumulator represented by the passed byte string.
    """
    pass


def _disallow_inside_tf_function(method_name):
  """Disallow calling a method inside a `tf.function`."""
  if ops.inside_function():
    error_msg = (
        'Detected a call to `PreprocessingLayer.{method_name}` inside a '
        '`tf.function`. `PreprocessingLayer.{method_name}` is a high-level '
        'endpoint that manages its own `tf.function`. Please move the call '
        'to `PreprocessingLayer.{method_name}` outside of all enclosing '
        '`tf.function`s. Note that you can call a `PreprocessingLayer` '
        'directly on `Tensor`s inside a `tf.function` like: `layer(x)`, '
        'or update its state like: `layer.update_state(x)`.').format(
            method_name=method_name)
    raise RuntimeError(error_msg)
