1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Keras functions required by TensorFlow Lite.
17
The functions defined in this library have been copied over from Keras in order
to remove TensorFlow Lite's dependency on Keras. The functions which could not
be copied over are accessed using the dependency inversion principle (for
details, refer to tensorflow/python/util/keras_deps.py).
22"""
23
24from __future__ import absolute_import
25from __future__ import division
26from __future__ import print_function
27
28import copy
29
30from tensorflow.python.eager import def_function
31from tensorflow.python.util import keras_deps
32from tensorflow.python.util import nest
33from tensorflow.python.util.compat import collections_abc
34
35
def _enforce_names_consistency(specs):
  """Returns `specs` with names present on every spec or on none of them.

  The structure is flattened with `nest.flatten` and each leaf is checked for
  a `name` attribute that is not `None`. When some leaves are named and others
  are not, every leaf is replaced by a deep copy whose name has been cleared;
  otherwise `specs` is returned untouched.

  Args:
    specs: A (possibly nested) structure of spec objects.

  Returns:
    `specs` itself when the naming is already consistent, otherwise the result
    of mapping the name-clearing copy over `specs`.
  """

  def _is_named(candidate):
    if not hasattr(candidate, 'name'):
      return False
    return candidate.name is not None

  def _without_name(original):
    # Work on a deep copy so the caller's spec objects are never mutated.
    duplicate = copy.deepcopy(original)
    if hasattr(duplicate, 'name'):
      duplicate._name = None  # pylint:disable=protected-access
    return duplicate

  leaves = nest.flatten(specs)
  some_named = any(_is_named(leaf) for leaf in leaves)
  if some_named and not all(_is_named(leaf) for leaf in leaves):
    # Mixed naming: strip the names everywhere.
    return nest.map_structure(_without_name, specs)
  return specs
56
57
def model_input_signature(model, keep_original_batch_size=False):
  """Derives the input signature of a Keras model from its save spec.

  The signature is a single-element list holding one (possibly nested)
  object, reflecting the Keras restriction that tensor inputs are passed as
  the first argument.

  For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
  will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]

  Args:
    model: Keras Model object.
    keep_original_batch_size: A boolean. When `False` (the default) the save
      spec is requested with `dynamic_batch=True`, so that the batch dim of the
      returned input signature is set to `None`; when `True` the original
      batch size is kept.

  Returns:
    `None` if the model has no save spec. Otherwise either the save spec
    itself, when it is already a one-element sequence, or a one-element list
    wrapping it. The result does not contain the `training` argument.
  """
  use_dynamic_batch = not keep_original_batch_size
  save_spec = model._get_save_spec(dynamic_batch=use_dynamic_batch)  # pylint: disable=protected-access
  if save_spec is None:
    return None

  save_spec = _enforce_names_consistency(save_spec)
  # Only a one-element *sequence* is already in signature form. A one-element
  # dictionary is not a Sequence, so it is wrapped in a list like everything
  # else.
  already_wrapped = (
      isinstance(save_spec, collections_abc.Sequence) and len(save_spec) == 1)
  return save_spec if already_wrapped else [save_spec]
92
def raise_model_input_error(model):
  """Raises a `ValueError` reporting that `model` has no input shapes set.

  Args:
    model: The model that could not be saved; it is formatted into the error
      message.

  Raises:
    ValueError: Always.
  """
  message_template = (
      'Model {} cannot be saved because the input shapes have not been '
      'set. Usually, input shapes are automatically determined from calling'
      ' `.fit()` or `.predict()`. To manually set the shapes, call '
      '`model.build(input_shape)`.')
  raise ValueError(message_template.format(model))
99
100
def _create_pseudo_names(tensors, prefix):
  """Builds default {input | output} names for a subclassed Model.

  Warning: this function should only be used to define default
  names for `Metrics` and `SavedModel`. No other use cases should
  rely on a `Model`'s input or output names.

  Each name is derived from the flat path of a leaf in `tensors`: integer
  path components are shifted to be one-based, the components are joined
  with '_', and `prefix` is prepended when the first component is an integer.
  A leaf with an empty path (a single, un-nested value) is named
  `prefix + '1'`.

  Example with dict:

  `{'a': [x1, x2], 'b': x3}` becomes:
  `['a_1', 'a_2', 'b']`

  Example with list:

  `[x, y]` becomes:
  `['output_1', 'output_2']`

  Args:
    tensors: `Model`'s outputs or inputs.
    prefix: 'output_' for outputs, 'input_' for inputs.

  Returns:
    Flattened list of pseudo names.
  """

  def _to_one_based(component):
    # Names start at "output_1" rather than "output_0".
    return component + 1 if isinstance(component, int) else component

  def _name_for(path):
    if not path:
      # Single, un-nested value.
      return prefix + '1'
    joined = '_'.join(str(part) for part in path)
    # Only paths that start with a position get the prefix.
    return prefix + joined if isinstance(path[0], int) else joined

  shifted_paths = nest.map_structure(
      _to_one_based, list(nest.yield_flat_paths(tensors)))
  return [_name_for(path) for path in shifted_paths]
144
145
def create_pseudo_output_names(outputs):
  """Returns the pseudo names for the outputs of a subclassed Model.

  Args:
    outputs: The `Model`'s outputs.

  Returns:
    The list produced by `_create_pseudo_names` using the 'output_' prefix.
  """
  output_prefix = 'output_'
  return _create_pseudo_names(outputs, prefix=output_prefix)
149
150
def trace_model_call(model, input_signature=None):
  """Wraps a Keras model's call in a tf.function for exporting the model.

  The input signature is resolved in this order: the `input_signature`
  argument if given; else the signature of `model.call` when that is already a
  `def_function.Function`; else the result of `model_input_signature(model)`.

  Args:
    model: A Keras model.
    input_signature: optional, a list of tf.TensorSpec objects specifying the
      inputs to the model.

  Returns:
    A tf.function wrapping the model's call function with input signatures set.

  Raises:
    ValueError: if input signature cannot be inferred from the model.
  """
  # `model.call` is only inspected when no explicit signature was supplied.
  if input_signature is None and isinstance(model.call,
                                            def_function.Function):
    input_signature = model.call.input_signature

  if input_signature is None:
    input_signature = model_input_signature(model)

  if input_signature is None:
    raise_model_input_error(model)

  @def_function.function(input_signature=input_signature, autograph=False)
  def _wrapped_model(*args):
    """A concrete tf.function that wraps the model's call function."""
    # With a one-element signature the model is invoked on the bare input
    # rather than on a list holding that single input.
    if len(input_signature) == 1:
      model_inputs = args[0]
    else:
      model_inputs = list(args)

    call_context = keras_deps.get_call_context_function()()
    with call_context.enter(
        model, inputs=model_inputs, build_graph=False, training=False,
        saving=True):
      return model(model_inputs, training=False)

  return _wrapped_model
189