1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Input layer code (`Input` and `InputLayer`).
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.python.distribute import distribution_strategy_context
23from tensorflow.python.framework import tensor_shape
24from tensorflow.python.framework import tensor_spec
25from tensorflow.python.keras import backend
26from tensorflow.python.keras.distribute import distributed_training_utils
27from tensorflow.python.keras.engine import base_layer
28from tensorflow.python.keras.engine import keras_tensor
29from tensorflow.python.keras.engine import node as node_module
30from tensorflow.python.keras.saving.saved_model import layer_serialization
31from tensorflow.python.keras.utils import tf_utils
32from tensorflow.python.util.tf_export import keras_export
33
34
35def _assert_other_arg_none(arg_name, arg):
36  if arg is not None:
37    raise ValueError('When `type_spec` is not None, all other args '
38                     'except `name` must be None, '
39                     'but %s is not None.' % arg_name)
40
41
@keras_export('keras.layers.InputLayer')
class InputLayer(base_layer.Layer):
  """Layer to be used as an entry point into a Network (a graph of layers).

  It can either wrap an existing tensor (pass an `input_tensor` argument)
  or create a placeholder tensor (pass arguments `input_shape`, and
  optionally, `dtype`).

  It is generally recommended to use the functional layer API via `Input`,
  (which creates an `InputLayer`) without directly using `InputLayer`.

  When using InputLayer with Keras Sequential model, it can be skipped by
  moving the input_shape parameter to the first layer after the InputLayer.

  This class can create placeholders for tf.Tensors, tf.SparseTensors, and
  tf.RaggedTensors by choosing 'sparse=True' or 'ragged=True'. Note that
  'sparse' and 'ragged' can't be configured to True at same time.
  Usage:

  ```python
  # With explicit InputLayer.
  model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(4,)),
    tf.keras.layers.Dense(8)])
  model.compile(tf.optimizers.RMSprop(0.001), loss='mse')
  model.fit(np.zeros((10, 4)),
            np.ones((10, 8)))

  # Without InputLayer and let the first layer to have the input_shape.
  # Keras will add a input for the model behind the scene.
  model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, input_shape=(4,))])
  model.compile(tf.optimizers.RMSprop(0.001), loss='mse')
  model.fit(np.zeros((10, 4)),
            np.ones((10, 8)))
  ```

  Args:
      input_shape: Shape tuple (not including the batch axis), or `TensorShape`
        instance (not including the batch axis).
      batch_size: Optional input batch size (integer or None).
      dtype: Optional datatype of the input. When not provided, the Keras
          default float type will be used.
      input_tensor: Optional tensor to use as layer input. If set, the layer
          will use the `tf.TypeSpec` of this tensor rather
          than creating a new placeholder tensor.
      sparse: Boolean, whether the placeholder created is meant to be sparse.
          Default to False.
      ragged: Boolean, whether the placeholder created is meant to be ragged.
          In this case, values of 'None' in the 'shape' argument represent
          ragged dimensions. For more information about RaggedTensors, see
          [this guide](https://www.tensorflow.org/guide/ragged_tensors).
          Default to False.
      type_spec: A `tf.TypeSpec` object to create Input from. This `tf.TypeSpec`
          represents the entire batch. When provided, all other args except
          name must be None.
      name: Optional name of the layer (string).
      **kwargs: Only `batch_input_shape` is accepted: the full input shape
          including the batch axis. When it is non-empty, its first entry
          replaces `batch_size` and the remaining entries replace
          `input_shape`. `None` or an empty value is ignored.

  Raises:
      ValueError: If `batch_size` is not divisible by the number of replicas
          of the current distribution strategy, if both `input_shape` and
          `batch_input_shape` are given, if unrecognized keyword arguments are
          passed, if both `sparse` and `ragged` are true, if
          `input_tensor.dtype` differs from `dtype`, or if `type_spec` is
          combined with any other argument except `name`.
  """

  def __init__(self,
               input_shape=None,
               batch_size=None,
               dtype=None,
               input_tensor=None,
               sparse=None,
               name=None,
               ragged=None,
               type_spec=None,
               **kwargs):
    # Remember the arguments exactly as passed by the caller. They are used
    # below to validate `type_spec` usage (before any defaulting/rewriting of
    # the local variables) and by `get_config`.
    self._init_input_shape = input_shape
    self._init_batch_size = batch_size
    self._init_dtype = dtype
    self._init_sparse = sparse
    self._init_ragged = ragged
    self._init_type_spec = type_spec

    # Under a distribution strategy that supports global batch sizes, the
    # given `batch_size` is split evenly across replicas.
    strategy = distribution_strategy_context.get_strategy()
    if (strategy and batch_size is not None and
        distributed_training_utils.global_batch_size_supported(strategy)):
      if batch_size % strategy.num_replicas_in_sync != 0:
        raise ValueError('The `batch_size` argument ({}) must be divisible by '
                         'the number of replicas ({})'.format(
                             batch_size, strategy.num_replicas_in_sync))
      batch_size = batch_size // strategy.num_replicas_in_sync

    if 'batch_input_shape' in kwargs:
      batch_input_shape = kwargs.pop('batch_input_shape')
      if input_shape and batch_input_shape:
        raise ValueError('Only provide the input_shape OR '
                         'batch_input_shape argument to '
                         'InputLayer, not both at the same time.')
      # `batch_input_shape` may legitimately be None (this is what
      # `get_config` emits when the shape is unknown) or empty. Indexing it
      # in those cases would raise, so only unpack a non-empty shape.
      if batch_input_shape:
        batch_size = batch_input_shape[0]
        input_shape = batch_input_shape[1:]
    if kwargs:
      raise ValueError(
          'Unrecognized keyword arguments: %s' % (list(kwargs.keys()),))

    if sparse and ragged:
      raise ValueError(
          'Cannot set both sparse and ragged to True in a Keras input.')

    if not name:
      prefix = 'input'
      name = prefix + '_' + str(backend.get_uid(prefix))

    if not dtype:
      if input_tensor is None:
        dtype = backend.floatx()
      else:
        dtype = backend.dtype(input_tensor)
    elif input_tensor is not None and input_tensor.dtype != dtype:
      raise ValueError('`input_tensor.dtype` differs from `dtype`: %s vs. %s' %
                       (input_tensor.dtype, dtype))
    super(InputLayer, self).__init__(dtype=dtype, name=name)
    self.built = True
    self.sparse = bool(sparse)
    self.ragged = bool(ragged)
    self.batch_size = batch_size
    self.supports_masking = True

    # Normalize `input_shape` to a tuple so it can be concatenated with the
    # batch dimension below.
    if isinstance(input_shape, tensor_shape.TensorShape):
      input_shape = tuple(input_shape.as_list())
    elif isinstance(input_shape, int):
      input_shape = (input_shape,)

    if type_spec is not None:
      # Case 1: build the input from a `tf.TypeSpec`. Every other argument
      # (as originally passed) except `name` must be None.
      args_that_must_be_none = [
          ('(input_)shape', self._init_input_shape),
          ('batch_size', self._init_batch_size),
          ('dtype', self._init_dtype),
          ('input_tensor', input_tensor),
          ('sparse', self._init_sparse),
          ('ragged', self._init_ragged),
      ]
      for arg_name, arg in args_that_must_be_none:
        _assert_other_arg_none(arg_name, arg)
      if not keras_tensor.keras_tensors_enabled():
        raise ValueError('Creating Keras inputs from a type_spec is only '
                         'supported when eager execution is enabled.')
      input_tensor = keras_tensor.keras_tensor_from_type_spec(type_spec)
      if isinstance(input_tensor, keras_tensor.SparseKerasTensor):
        self.sparse = True
      if isinstance(input_tensor, keras_tensor.RaggedKerasTensor):
        self.ragged = True
      self.is_placeholder = True
      try:
        self._batch_input_shape = tuple(input_tensor.shape.as_list())
      except ValueError:
        # If the shape cannot be represented as a tuple (e.g. unknown rank)
        self._batch_input_shape = None
    elif input_tensor is None:
      # Case 2: no tensor given, create a new placeholder in the Keras graph.
      if input_shape is not None:
        batch_input_shape = (batch_size,) + tuple(input_shape)
      else:
        batch_input_shape = None
      graph = backend.get_graph()
      with graph.as_default():
        input_tensor = backend.placeholder(
            shape=batch_input_shape,
            dtype=dtype,
            name=self.name,
            sparse=sparse,
            ragged=ragged)

      self.is_placeholder = True
      self._batch_input_shape = batch_input_shape
    else:
      # Case 3: wrap the tensor supplied by the caller.
      if keras_tensor.keras_tensors_enabled():
        if not isinstance(input_tensor, keras_tensor.KerasTensor):
          input_tensor = keras_tensor.keras_tensor_from_tensor(input_tensor)
      else:
        if not tf_utils.is_symbolic_tensor(input_tensor):
          raise ValueError('You should not pass an EagerTensor to `Input`. '
                           'For example, instead of creating an '
                           'InputLayer, you should instantiate your model and '
                           'directly call it on your input.')
      self.is_placeholder = False
      try:
        self._batch_input_shape = tuple(input_tensor.shape.as_list())
      except ValueError:
        # If the shape cannot be represented as a tuple (e.g. unknown rank)
        self._batch_input_shape = None
    # Create an input node.
    input_tensor._keras_mask = None
    node_module.Node(layer=self, outputs=input_tensor)

    # Store type spec
    if isinstance(input_tensor, keras_tensor.KerasTensor) or (
        tf_utils.is_extension_type(input_tensor)):
      self._type_spec = input_tensor._type_spec  # pylint: disable=protected-access
    else:
      self._type_spec = tensor_spec.TensorSpec(
          shape=input_tensor.shape, dtype=input_tensor.dtype, name=self.name)

  def get_config(self):
    """Returns the constructor arguments describing this layer as a dict.

    When the layer was created from a `type_spec`, only `name` and that
    `type_spec` are returned. Otherwise the config holds `batch_input_shape`
    (which is `None` when the input shape is unknown), `dtype`, `sparse`,
    `ragged` and `name`.
    """
    if self._init_type_spec is not None:
      config = {
          'name': self.name,
          'type_spec': self._init_type_spec
      }
    else:
      config = {
          'batch_input_shape': self._batch_input_shape,
          'dtype': self.dtype,
          'sparse': self.sparse,
          'ragged': self.ragged,
          'name': self.name,
      }
    return config

  @property
  def _trackable_saved_model_saver(self):
    """Returns a new `InputLayerSavedModelSaver` wrapping this layer."""
    return layer_serialization.InputLayerSavedModelSaver(self)
254
255
@keras_export('keras.Input', 'keras.layers.Input')
def Input(  # pylint: disable=invalid-name
    shape=None,
    batch_size=None,
    name=None,
    dtype=None,
    sparse=None,
    tensor=None,
    ragged=None,
    type_spec=None,
    **kwargs):
  """`Input()` is used to instantiate a Keras tensor.

  A Keras tensor is a symbolic tensor-like object,
  which we augment with certain attributes that allow us to build a Keras model
  just by knowing the inputs and outputs of the model.

  For instance, if `a`, `b` and `c` are Keras tensors,
  it becomes possible to do:
  `model = Model(inputs=[a, b], outputs=c)`

  Args:
      shape: A shape tuple (integers), not including the batch size.
          For instance, `shape=(32,)` indicates that the expected input
          will be batches of 32-dimensional vectors. Elements of this tuple
          can be None; 'None' elements represent dimensions where the shape is
          not known.
      batch_size: optional static batch size (integer).
      name: An optional name string for the layer.
          Should be unique in a model (do not reuse the same name twice).
          It will be autogenerated if it isn't provided.
      dtype: The data type expected by the input, as a string
          (`float32`, `float64`, `int32`...)
      sparse: A boolean specifying whether the placeholder to be created is
          sparse. Only one of 'ragged' and 'sparse' can be True. Note that,
          if `sparse` is False, sparse tensors can still be passed into the
          input - they will be densified with a default value of 0.
      tensor: Optional existing tensor to wrap into the `Input` layer.
          If set, the layer will use the `tf.TypeSpec` of this tensor rather
          than creating a new placeholder tensor.
      ragged: A boolean specifying whether the placeholder to be created is
          ragged. Only one of 'ragged' and 'sparse' can be True. In this case,
          values of 'None' in the 'shape' argument represent ragged dimensions.
          For more information about RaggedTensors, see
          [this guide](https://www.tensorflow.org/guide/ragged_tensors).
      type_spec: A `tf.TypeSpec` object to create the input placeholder from.
          When provided, all other args except name must be None.
      **kwargs: deprecated arguments support. Supports `batch_shape` and
          `batch_input_shape`. If both are given, `batch_input_shape` is used.

  Returns:
    A `tensor`.

  Example:

  ```python
  # this is a logistic regression in Keras
  x = Input(shape=(32,))
  y = Dense(16, activation='softmax')(x)
  model = Model(x, y)
  ```

  Note that even if eager execution is enabled,
  `Input` produces a symbolic tensor-like object (i.e. a placeholder).
  This symbolic tensor-like object can be used with lower-level
  TensorFlow ops that take tensors as inputs, as such:

  ```python
  x = Input(shape=(32,))
  y = tf.square(x)  # This op will be treated like a layer
  model = Model(x, y)
  ```

  (This behavior does not work for higher-order TensorFlow APIs such as
  control flow and being directly watched by a `tf.GradientTape`).

  However, the resulting model will not track any variables that were
  used as inputs to TensorFlow ops. All variable usages must happen within
  Keras layers to make sure they will be tracked by the model's weights.

  The Keras Input can also create a placeholder from an arbitrary `tf.TypeSpec`,
  e.g:

  ```python
  x = Input(type_spec=tf.RaggedTensorSpec(shape=[None, None],
                                          dtype=tf.float32, ragged_rank=1))
  y = x.values
  model = Model(x, y)
  ```
  When passing an arbitrary `tf.TypeSpec`, it must represent the signature of an
  entire batch instead of just one example.

  Raises:
    ValueError: If both `sparse` and `ragged` are provided.
    ValueError: If both `shape` and (`batch_input_shape` or `batch_shape`) are
      provided.
    ValueError: If `shape`, `tensor` and `type_spec` are None.
    ValueError: If arguments besides `type_spec` are non-None while `type_spec`
                is passed.
    ValueError: if any unrecognized parameters are provided.
  """
  if sparse and ragged:
    raise ValueError(
        'Cannot set both sparse and ragged to True in a Keras input.')

  input_layer_config = {'name': name, 'dtype': dtype, 'sparse': sparse,
                        'ragged': ragged, 'input_tensor': tensor,
                        'type_spec': type_spec}

  # Both deprecated spellings are always removed from `kwargs` so that neither
  # is reported as unrecognized; `batch_input_shape` wins when both are given.
  batch_shape = kwargs.pop('batch_shape', None)
  batch_input_shape = kwargs.pop('batch_input_shape', batch_shape)
  if shape is not None and batch_input_shape is not None:
    raise ValueError('Only provide the `shape` OR `batch_input_shape` argument '
                     'to Input, not both at the same time.')
  if (batch_input_shape is None and shape is None and tensor is None
      and type_spec is None):
    raise ValueError('Please provide to Input a `shape`'
                     ' or a `tensor` or a `type_spec` argument. Note that '
                     '`shape` does not include the batch '
                     'dimension.')
  if kwargs:
    raise ValueError(
        'Unrecognized keyword arguments: %s' % (list(kwargs.keys()),))

  if batch_input_shape:
    # The full batch shape is forwarded as-is; `batch_size` and `shape` are
    # not passed to the layer in this branch.
    input_layer_config.update({'batch_input_shape': batch_input_shape})
  else:
    input_layer_config.update(
        {'batch_size': batch_size, 'input_shape': shape})
  input_layer = InputLayer(**input_layer_config)

  # Return tensor including `_keras_history`.
  # Note that in this case train_output and test_output are the same pointer.
  outputs = input_layer._inbound_nodes[0].outputs
  # Unwrap a single-element list so callers get the tensor itself.
  if isinstance(outputs, list) and len(outputs) == 1:
    return outputs[0]
  return outputs
394