1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Contains the InputSpec class."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from six.moves import zip  # pylint: disable=redefined-builtin
22
23from tensorflow.python.framework import dtypes
24from tensorflow.python.framework import tensor_shape
25from tensorflow.python.framework import tensor_spec
26from tensorflow.python.keras import backend
27from tensorflow.python.util import nest
28from tensorflow.python.util.tf_export import keras_export
29from tensorflow.python.util.tf_export import tf_export
30
31
@keras_export('keras.layers.InputSpec')
@tf_export(v1=['layers.InputSpec'])
class InputSpec(object):
  """Specifies the rank, dtype and shape of every input to a layer.

  Layers can expose (if appropriate) an `input_spec` attribute:
  an instance of `InputSpec`, or a nested structure of `InputSpec` instances
  (one per input tensor). These objects enable the layer to run input
  compatibility checks for input structure, input rank, input shape, and
  input dtype.

  A None entry in a shape is compatible with any dimension,
  a None shape is compatible with any shape.

  Args:
    dtype: Expected DataType of the input. Stored as the dtype's string name.
    shape: Shape tuple, expected shape of the input
      (may include None for unchecked axes). Includes the batch size.
      When the shape has a known rank, `ndim` is derived from it and the
      `ndim` argument is ignored.
    ndim: Integer, expected rank of the input.
    max_ndim: Integer, maximum rank of the input.
    min_ndim: Integer, minimum rank of the input.
    axes: Dictionary mapping integer axes to
      a specific dimension value.
    allow_last_axis_squeeze: If True, then allow inputs of rank N+1 as long
      as the last axis of the input is 1, as well as inputs of rank N-1
      as long as the last axis of the spec is 1.
    name: Expected key corresponding to this input when passing data as
      a dictionary.

  Raises:
    TypeError: if a key of `axes` cannot be converted to an integer.
    ValueError: if the largest key of `axes` is out of range for `ndim`
      (or, when `ndim` is unknown, for `max_ndim`).

  Example:

  ```python
  class MyLayer(Layer):
      def __init__(self):
          super(MyLayer, self).__init__()
          # The layer will accept inputs with shape (?, 28, 28) & (?, 28, 28, 1)
          # and raise an appropriate error message otherwise.
          self.input_spec = InputSpec(
              shape=(None, 28, 28, 1),
              allow_last_axis_squeeze=True)
  ```
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None,
               allow_last_axis_squeeze=False,
               name=None):
    self.dtype = dtypes.as_dtype(dtype).name if dtype is not None else None
    shape = tensor_shape.TensorShape(shape)
    if shape.rank is None:
      shape = None
    else:
      shape = tuple(shape.as_list())
    if shape is not None:
      # A known-rank shape fully determines the rank.
      self.ndim = len(shape)
      self.shape = shape
    else:
      self.ndim = ndim
      self.shape = None
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    self.name = name
    self.allow_last_axis_squeeze = allow_last_axis_squeeze
    try:
      axes = axes or {}
      self.axes = {int(k): axes[k] for k in axes}
    except (ValueError, TypeError):
      raise TypeError('The keys in axes must be integers.')

    if self.axes and (self.ndim is not None or self.max_ndim is not None):
      # Compare against `None` explicitly: a rank of 0 is a valid `ndim` and
      # must not fall through to a possibly-`None` `max_ndim`.
      upper_rank = self.ndim if self.ndim is not None else self.max_ndim
      max_dim = upper_rank - 1
      max_axis = max(self.axes)
      if max_axis > max_dim:
        raise ValueError('Axis {} is greater than the maximum allowed value: {}'
                         .format(max_axis, max_dim))

  def __repr__(self):
    """Lists only the constraints that are actually set."""
    parts = []
    if self.dtype:
      parts.append('dtype=' + str(self.dtype))
    # `is not None` so that legitimate zero values (`shape=()`, `ndim=0`,
    # `min_ndim=0`) are still shown.
    if self.shape is not None:
      parts.append('shape=' + str(self.shape))
    if self.ndim is not None:
      parts.append('ndim=' + str(self.ndim))
    if self.max_ndim is not None:
      parts.append('max_ndim=' + str(self.max_ndim))
    if self.min_ndim is not None:
      parts.append('min_ndim=' + str(self.min_ndim))
    if self.axes:
      parts.append('axes=' + str(self.axes))
    if self.allow_last_axis_squeeze:
      parts.append('allow_last_axis_squeeze=True')
    if self.name:
      parts.append('name=' + str(self.name))
    return 'InputSpec(%s)' % ', '.join(parts)

  def get_config(self):
    """Returns the constructor arguments as a dict; inverted by `from_config`."""
    return {
        'dtype': self.dtype,
        'shape': self.shape,
        'ndim': self.ndim,
        'max_ndim': self.max_ndim,
        'min_ndim': self.min_ndim,
        'axes': self.axes,
        'allow_last_axis_squeeze': self.allow_last_axis_squeeze,
        'name': self.name}

  @classmethod
  def from_config(cls, config):
    """Rebuilds an `InputSpec` by passing `config` as keyword arguments."""
    return cls(**config)
134
135
def to_tensor_shape(spec):
  """Builds a `tf.TensorShape` from the shape constraints of an InputSpec.

  Resolution order:
    * `spec.shape` is set: a TensorShape with exactly that shape.
    * only `spec.ndim` is set: a TensorShape of that rank whose dimensions
      are unknown, except those pinned by `spec.axes`.
    * neither is set: a TensorShape of unknown rank (`TensorShape(None)`).

  Args:
    spec: an InputSpec object.

  Returns:
    a tf.TensorShape object
  """
  if spec.shape is not None:
    return tensor_shape.TensorShape(spec.shape)
  if spec.ndim is None:
    return tensor_shape.TensorShape(None)
  dims = [None] * spec.ndim
  # `spec.axes` is always a dict (InputSpec defaults it to `{}`).
  for axis, value in spec.axes.items():
    dims[axis] = value
  return tensor_shape.TensorShape(dims)
157
158
def assert_input_compatibility(input_spec, inputs, layer_name):
  """Checks compatibility between the layer and provided inputs.

  This checks that the tensor(s) `inputs` verify the input assumptions
  of a layer (if any). If not, a clear and actionable exception gets raised.

  For each (input, spec) pair the checks run in this order: rank (`ndim`,
  `max_ndim`, `min_ndim`), dtype, specific `axes`, then the full `shape`.
  A `None` spec skips its input. An input of unknown rank cannot be checked
  and is skipped, but the remaining inputs are still checked.

  Args:
      input_spec: An InputSpec instance, list of InputSpec instances, a nested
          structure of InputSpec instances, or None.
      inputs: Input tensor, list of input tensors, or a nested structure of
          input tensors. If it is a dict and every spec has a `name`, the
          values are matched to the specs by name.
      layer_name: String, name of the layer (for error message formatting).

  Raises:
      TypeError: if an input has no `shape` attribute.
      ValueError: in case of mismatch between
          the provided inputs and the expectations of the layer.
  """
  if not input_spec:
    return

  input_spec = nest.flatten(input_spec)
  if isinstance(inputs, dict):
    # Flatten `inputs` by reference order if input spec names are provided.
    # A `None` spec has no name, which disables name-based matching.
    names = [spec.name if spec is not None else None for spec in input_spec]
    if all(names):
      list_inputs = []
      for name in names:
        if name not in inputs:
          raise ValueError('Missing data for input "%s". '
                           'You passed a data dictionary with keys %s. '
                           'Expected the following keys: %s' %
                           (name, list(inputs.keys()), names))
        list_inputs.append(inputs[name])
      inputs = list_inputs

  inputs = nest.flatten(inputs)
  for x in inputs:
    # Having a shape/dtype is the only commonality of the various tensor-like
    # objects that may be passed. The most common kind of invalid type we are
    # guarding for is a Layer instance (Functional API), which does not
    # have a `shape` attribute.
    if not hasattr(x, 'shape'):
      raise TypeError('Inputs to a layer should be tensors. Got: %s' % (x,))

  if len(inputs) != len(input_spec):
    raise ValueError('Layer ' + layer_name + ' expects ' +
                     str(len(input_spec)) + ' input(s), '
                     'but it received ' + str(len(inputs)) +
                     ' input tensors. Inputs received: ' + str(inputs))
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    if spec is None:
      continue

    # Normalize to a TensorShape so that inputs whose `.shape` is a plain
    # tuple (which has no `.rank` / `.as_list()`) are handled uniformly.
    shape = tensor_shape.TensorShape(x.shape)
    ndim = shape.rank
    if ndim is None:
      # Unknown rank: nothing below can be verified for this input. Skip only
      # this input; the remaining inputs must still be checked.
      continue

    # Check ndim.
    if spec.ndim is not None and not spec.allow_last_axis_squeeze:
      if ndim != spec.ndim:
        raise ValueError('Input ' + str(input_index) + ' of layer ' +
                         layer_name + ' is incompatible with the layer: '
                         'expected ndim=' + str(spec.ndim) + ', found ndim=' +
                         str(ndim) + '. Full shape received: ' +
                         str(tuple(shape)))
    if spec.max_ndim is not None and ndim > spec.max_ndim:
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected max_ndim=' + str(spec.max_ndim) +
                       ', found ndim=' + str(ndim))
    if spec.min_ndim is not None and ndim < spec.min_ndim:
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected min_ndim=' + str(spec.min_ndim) +
                       ', found ndim=' + str(ndim) +
                       '. Full shape received: ' +
                       str(tuple(shape)))
    # Check dtype.
    if spec.dtype is not None and x.dtype.name != spec.dtype:
      raise ValueError('Input ' + str(input_index) + ' of layer ' +
                       layer_name + ' is incompatible with the layer: '
                       'expected dtype=' + str(spec.dtype) +
                       ', found dtype=' + str(x.dtype))

    # Check specific shape axes.
    shape_as_list = shape.as_list()
    if spec.axes:
      for axis, value in spec.axes.items():
        if hasattr(value, 'value'):
          # Unwrap objects exposing `.value` (presumably `Dimension`) to the
          # raw size.
          value = value.value
        if value is None:
          continue
        axis = int(axis)
        # An axis that does not exist in this input is a mismatch, not an
        # IndexError.
        if (not -ndim <= axis < ndim or
            shape_as_list[axis] not in {value, None}):
          raise ValueError(
              'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
              ' incompatible with the layer: expected axis ' + str(axis) +
              ' of input shape to have value ' + str(value) +
              ' but received input with shape ' + display_shape(shape))
    # Check shape.
    if spec.shape is not None:
      spec_shape = spec.shape
      if spec.allow_last_axis_squeeze:
        if shape_as_list and shape_as_list[-1] == 1:
          shape_as_list = shape_as_list[:-1]
        if spec_shape and spec_shape[-1] == 1:
          spec_shape = spec_shape[:-1]
      # NOTE(review): `zip` stops at the shorter of the two shapes. The rank
      # itself is only enforced by the ndim check above, and that check is
      # skipped when `allow_last_axis_squeeze` is set.
      for spec_dim, dim in zip(spec_shape, shape_as_list):
        if spec_dim is not None and dim is not None and spec_dim != dim:
          raise ValueError('Input ' + str(input_index) +
                           ' is incompatible with layer ' + layer_name +
                           ': expected shape=' + str(spec.shape) +
                           ', found shape=' + display_shape(shape))
275
276
def display_shape(shape):
  """Formats a shape as a tuple string for error messages, e.g. `(None, 3)`.

  Args:
    shape: a `tf.TensorShape` (anything exposing `as_list()`), or a plain
      sequence of dimensions such as a NumPy `.shape` tuple.

  Returns:
    The `str` of the shape's dimensions as a Python tuple.
  """
  as_list = getattr(shape, 'as_list', None)
  if as_list is None:
    # Plain sequences (tuples/lists) have no `as_list`; format them directly.
    return str(tuple(shape))
  return str(tuple(as_list()))
279
280
def to_tensor_spec(input_spec, default_dtype=None):
  """Converts a Keras InputSpec object to a TensorSpec.

  Args:
    input_spec: an `InputSpec`, or any other object (e.g. `None`).
    default_dtype: dtype used when `input_spec` does not pin one. When falsy,
      `backend.floatx()` is used instead.

  Returns:
    A `tf.TensorSpec`. For an `InputSpec`, its shape comes from
    `to_tensor_shape(input_spec)` and its dtype is `input_spec.dtype` when
    set, else the default dtype. For anything else, a spec of unknown shape
    with the default dtype.
  """
  fallback_dtype = default_dtype if default_dtype else backend.floatx()
  if not isinstance(input_spec, InputSpec):
    return tensor_spec.TensorSpec(None, fallback_dtype)
  spec_dtype = input_spec.dtype if input_spec.dtype else fallback_dtype
  return tensor_spec.TensorSpec(to_tensor_shape(input_spec), spec_dtype)
288