1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Contains the InputSpec class."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from six.moves import zip  # pylint: disable=redefined-builtin
22
23from tensorflow.python.util import nest
24from tensorflow.python.util.tf_export import keras_export
25from tensorflow.python.util.tf_export import tf_export
26
27
28@keras_export('keras.layers.InputSpec', v1=['keras.layers.InputSpec'])
29@tf_export(v1=['layers.InputSpec'])
class InputSpec(object):
  """Specifies the ndim, dtype and shape of every input to a layer.

  Every layer should expose (if appropriate) an `input_spec` attribute:
  a list of instances of InputSpec (one per input tensor).

  A None entry in a shape is compatible with any dimension,
  a None shape is compatible with any shape.

  Arguments:
      dtype: Expected DataType of the input.
      shape: Shape tuple, expected shape of the input
          (may include None for unchecked axes). When given, the rank is
          taken from `len(shape)` and the `ndim` argument is ignored.
      ndim: Integer, expected rank of the input. Only used when `shape`
          is None.
      max_ndim: Integer, maximum rank of the input.
      min_ndim: Integer, minimum rank of the input.
      axes: Dictionary mapping integer axes to
          a specific dimension value. None is stored as an empty dict.
  """

  def __init__(self,
               dtype=None,
               shape=None,
               ndim=None,
               max_ndim=None,
               min_ndim=None,
               axes=None):
    self.dtype = dtype
    self.shape = shape
    # An explicit shape fixes the rank, so it takes precedence over `ndim`.
    self.ndim = ndim if shape is None else len(shape)
    self.max_ndim = max_ndim
    self.min_ndim = min_ndim
    # Always store a dict so that consumers can iterate `axes` unconditionally.
    self.axes = axes or {}

  def __repr__(self):
    """Returns `InputSpec(...)` listing every constraint that is set.

    Fields are tested against None rather than for truthiness, so that
    meaningful falsy values (a rank of 0, the scalar shape `()`, a dtype
    object whose truth value is False) are still reported.
    """
    fields = (('dtype', self.dtype),
              ('shape', self.shape),
              ('ndim', self.ndim),
              ('max_ndim', self.max_ndim),
              ('min_ndim', self.min_ndim))
    parts = ['%s=%s' % (name, value)
             for name, value in fields if value is not None]
    # `axes` defaults to an empty dict, which carries no constraint.
    if self.axes:
      parts.append('axes=' + str(self.axes))
    return 'InputSpec(%s)' % ', '.join(parts)
75
76
def _incompatible_input_error(input_index, layer_name, detail):
  """Builds the ValueError reported when one input violates its spec."""
  return ValueError('Input ' + str(input_index) + ' of layer ' + layer_name +
                    ' is incompatible with the layer: ' + detail)


def assert_input_compatibility(input_spec, inputs, layer_name):
  """Checks compatibility between the layer and provided inputs.

  This checks that the tensor(s) `inputs` verify the input assumptions
  of a layer (if any). If not, a clear and actionable exception gets raised.

  Arguments:
      input_spec: An InputSpec instance, a list/tuple of InputSpec instances
          (entries may be None to skip the matching input), any other
          structure that `nest.flatten` can flatten, or None. A falsy value
          disables all checks.
      inputs: Input tensor or list of input tensors.
      layer_name: String, name of the layer (for error message formatting).

  Raises:
      ValueError: in case of mismatch between
          the provided inputs and the expectations of the layer.
  """
  if not input_spec:
    return
  if not isinstance(input_spec, (list, tuple)):
    input_spec = nest.flatten(input_spec)

  inputs = nest.flatten(inputs)
  if len(inputs) != len(input_spec):
    raise ValueError('Layer ' + layer_name + ' expects ' +
                     str(len(input_spec)) + ' inputs, '
                     'but it received ' + str(len(inputs)) +
                     ' input tensors. Inputs received: ' + str(inputs))
  for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
    if spec is None:
      continue

    # Check rank constraints. Any of them requires the rank to be known.
    if (spec.ndim is not None or
        spec.min_ndim is not None or
        spec.max_ndim is not None):
      ndim = x.shape.ndims
      if ndim is None:
        raise _incompatible_input_error(
            input_index, layer_name,
            'its rank is undefined, but the layer requires a '
            'defined rank.')
      if spec.ndim is not None and ndim != spec.ndim:
        raise _incompatible_input_error(
            input_index, layer_name,
            'expected ndim=' + str(spec.ndim) + ', found ndim=' +
            str(ndim) + '. Full shape received: ' +
            str(x.shape.as_list()))
      if spec.max_ndim is not None and ndim > spec.max_ndim:
        raise _incompatible_input_error(
            input_index, layer_name,
            'expected max_ndim=' + str(spec.max_ndim) +
            ', found ndim=' + str(ndim))
      if spec.min_ndim is not None and ndim < spec.min_ndim:
        raise _incompatible_input_error(
            input_index, layer_name,
            'expected min_ndim=' + str(spec.min_ndim) +
            ', found ndim=' + str(ndim) +
            '. Full shape received: ' +
            str(x.shape.as_list()))

    # Check dtype.
    if spec.dtype is not None and x.dtype != spec.dtype:
      raise _incompatible_input_error(
          input_index, layer_name,
          'expected dtype=' + str(spec.dtype) +
          ', found dtype=' + str(x.dtype))

    # The remaining checks compare individual dimensions.
    if not spec.axes and spec.shape is None:
      continue
    # `TensorShape.as_list()` raises on a shape of unknown rank instead of
    # returning None, so test the rank first; with no dimensions available
    # there is nothing to compare against.
    if x.shape.ndims is None:
      continue
    shape = x.shape.as_list()

    # Check specific shape axes.
    if spec.axes:
      for axis, value in spec.axes.items():
        # Unwrap Dimension-like objects to their plain value.
        if hasattr(value, 'value'):
          value = value.value
        # A None on either side means "unconstrained / not known".
        if value is not None and shape[int(axis)] not in {value, None}:
          raise ValueError(
              'Input ' + str(input_index) + ' of layer ' + layer_name + ' is'
              ' incompatible with the layer: expected axis ' + str(axis) +
              ' of input shape to have value ' + str(value) +
              ' but received input with shape ' + str(shape))

    # Check shape.
    if spec.shape is not None:
      for spec_dim, dim in zip(spec.shape, shape):
        # Only dimensions known on both sides can conflict.
        if spec_dim is not None and dim is not None and spec_dim != dim:
          raise ValueError('Input ' + str(input_index) +
                           ' is incompatible with layer ' + layer_name +
                           ': expected shape=' + str(spec.shape) +
                           ', found shape=' + str(shape))
172