1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=not-callable
16# pylint: disable=redefined-builtin
17"""Layers that can merge several inputs into one.
18"""
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23from tensorflow.python.keras import backend as K
24from tensorflow.python.keras.engine import base_layer_utils
25from tensorflow.python.keras.engine.base_layer import Layer
26from tensorflow.python.keras.utils import tf_utils
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import nn
30from tensorflow.python.util.tf_export import keras_export
31
32
33class _Merge(Layer):
34  """Generic merge layer for elementwise merge functions.
35
36  Used to implement `Sum`, `Average`, etc.
37  """
38
39  def __init__(self, **kwargs):
40    """Intializes a Merge layer.
41
42    Args:
43      **kwargs: standard layer keyword arguments.
44    """
45    super(_Merge, self).__init__(**kwargs)
46    self.supports_masking = True
47
48  def _merge_function(self, inputs):
49    raise NotImplementedError
50
51  def _compute_elemwise_op_output_shape(self, shape1, shape2):
52    """Computes the shape of the resultant of an elementwise operation.
53
54    Args:
55        shape1: tuple or None. Shape of the first tensor
56        shape2: tuple or None. Shape of the second tensor
57
58    Returns:
59        expected output shape when an element-wise operation is
60        carried out on 2 tensors with shapes shape1 and shape2.
61        tuple or None.
62
63    Raises:
64        ValueError: if shape1 and shape2 are not compatible for
65            element-wise operations.
66    """
67    if None in [shape1, shape2]:
68      return None
69    elif len(shape1) < len(shape2):
70      return self._compute_elemwise_op_output_shape(shape2, shape1)
71    elif not shape2:
72      return shape1
73    output_shape = list(shape1[:-len(shape2)])
74    for i, j in zip(shape1[-len(shape2):], shape2):
75      if i is None or j is None:
76        output_shape.append(None)
77      elif i == 1:
78        output_shape.append(j)
79      elif j == 1:
80        output_shape.append(i)
81      else:
82        if i != j:
83          raise ValueError(
84              'Operands could not be broadcast '
85              'together with shapes ' + str(shape1) + ' ' + str(shape2))
86        output_shape.append(i)
87    return tuple(output_shape)
88
89  @tf_utils.shape_type_conversion
90  def build(self, input_shape):
91    # Used purely for shape validation.
92    if not isinstance(input_shape[0], tuple):
93      raise ValueError('A merge layer should be called on a list of inputs.')
94    if len(input_shape) < 2:
95      raise ValueError('A merge layer should be called '
96                       'on a list of at least 2 inputs. '
97                       'Got ' + str(len(input_shape)) + ' inputs.')
98    batch_sizes = {s[0] for s in input_shape if s} - {None}
99    if len(batch_sizes) > 1:
100      raise ValueError(
101          'Can not merge tensors with different '
102          'batch sizes. Got tensors with shapes : ' + str(input_shape))
103    if input_shape[0] is None:
104      output_shape = None
105    else:
106      output_shape = input_shape[0][1:]
107    for i in range(1, len(input_shape)):
108      if input_shape[i] is None:
109        shape = None
110      else:
111        shape = input_shape[i][1:]
112      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
113    # If the inputs have different ranks, we have to reshape them
114    # to make them broadcastable.
115    if None not in input_shape and len(set(map(len, input_shape))) == 1:
116      self._reshape_required = False
117    else:
118      self._reshape_required = True
119
120  def call(self, inputs):
121    if not isinstance(inputs, (list, tuple)):
122      raise ValueError('A merge layer should be called on a list of inputs.')
123    if self._reshape_required:
124      reshaped_inputs = []
125      input_ndims = list(map(K.ndim, inputs))
126      if None not in input_ndims:
127        # If ranks of all inputs are available,
128        # we simply expand each of them at axis=1
129        # until all of them have the same rank.
130        max_ndim = max(input_ndims)
131        for x in inputs:
132          x_ndim = K.ndim(x)
133          for _ in range(max_ndim - x_ndim):
134            x = array_ops.expand_dims(x, axis=1)
135          reshaped_inputs.append(x)
136        return self._merge_function(reshaped_inputs)
137      else:
138        # Transpose all inputs so that batch size is the last dimension.
139        # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
140        transposed = False
141        for x in inputs:
142          x_ndim = K.ndim(x)
143          if x_ndim is None:
144            x_shape = array_ops.shape(x)
145            batch_size = x_shape[0]
146            new_shape = K.concatenate(
147                [x_shape[1:],
148                 array_ops.expand_dims(batch_size, axis=-1)])
149            x_transposed = array_ops.reshape(
150                x,
151                array_ops.stack(
152                    [batch_size, math_ops.reduce_prod(x_shape[1:])], axis=0))
153            x_transposed = array_ops.transpose(x_transposed, perm=(1, 0))
154            x_transposed = array_ops.reshape(x_transposed, new_shape)
155            reshaped_inputs.append(x_transposed)
156            transposed = True
157          elif x_ndim > 1:
158            dims = list(range(1, x_ndim)) + [0]
159            reshaped_inputs.append(array_ops.transpose(x, perm=dims))
160            transposed = True
161          else:
162            # We don't transpose inputs if they are 1D vectors or scalars.
163            reshaped_inputs.append(x)
164        y = self._merge_function(reshaped_inputs)
165        y_ndim = K.ndim(y)
166        if transposed:
167          # If inputs have been transposed, we have to transpose the output too.
168          if y_ndim is None:
169            y_shape = array_ops.shape(y)
170            y_ndim = array_ops.shape(y_shape)[0]
171            batch_size = y_shape[y_ndim - 1]
172            new_shape = K.concatenate([
173                array_ops.expand_dims(batch_size, axis=-1), y_shape[:y_ndim - 1]
174            ])
175            y = array_ops.reshape(y, (-1, batch_size))
176            y = array_ops.transpose(y, perm=(1, 0))
177            y = array_ops.reshape(y, new_shape)
178          elif y_ndim > 1:
179            dims = [y_ndim - 1] + list(range(y_ndim - 1))
180            y = array_ops.transpose(y, perm=dims)
181        return y
182    else:
183      return self._merge_function(inputs)
184
185  @tf_utils.shape_type_conversion
186  def compute_output_shape(self, input_shape):
187    if input_shape[0] is None:
188      output_shape = None
189    else:
190      output_shape = input_shape[0][1:]
191    for i in range(1, len(input_shape)):
192      if input_shape[i] is None:
193        shape = None
194      else:
195        shape = input_shape[i][1:]
196      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
197    batch_sizes = {s[0] for s in input_shape if s is not None} - {None}
198    if len(batch_sizes) == 1:
199      output_shape = (list(batch_sizes)[0],) + output_shape
200    else:
201      output_shape = (None,) + output_shape
202    return output_shape
203
204  def compute_mask(self, inputs, mask=None):
205    if mask is None:
206      return None
207    if not isinstance(mask, (tuple, list)):
208      raise ValueError('`mask` should be a list.')
209    if not isinstance(inputs, (tuple, list)):
210      raise ValueError('`inputs` should be a list.')
211    if len(mask) != len(inputs):
212      raise ValueError('The lists `inputs` and `mask` '
213                       'should have the same length.')
214    if all(m is None for m in mask):
215      return None
216    masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None]
217    return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)
218
219
220@keras_export('keras.layers.Add')
221class Add(_Merge):
222  """Layer that adds a list of inputs.
223
224  It takes as input a list of tensors,
225  all of the same shape, and returns
226  a single tensor (also of the same shape).
227
228  Examples:
229
230  >>> input_shape = (2, 3, 4)
231  >>> x1 = tf.random.normal(input_shape)
232  >>> x2 = tf.random.normal(input_shape)
233  >>> y = tf.keras.layers.Add()([x1, x2])
234  >>> print(y.shape)
235  (2, 3, 4)
236
237  Used in a functional model:
238
239  >>> input1 = tf.keras.layers.Input(shape=(16,))
240  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
241  >>> input2 = tf.keras.layers.Input(shape=(32,))
242  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
243  >>> # equivalent to `added = tf.keras.layers.add([x1, x2])`
244  >>> added = tf.keras.layers.Add()([x1, x2])
245  >>> out = tf.keras.layers.Dense(4)(added)
246  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)
247
248  """
249
250  def _merge_function(self, inputs):
251    output = inputs[0]
252    for i in range(1, len(inputs)):
253      output += inputs[i]
254    return output
255
256
257@keras_export('keras.layers.Subtract')
258class Subtract(_Merge):
259  """Layer that subtracts two inputs.
260
261  It takes as input a list of tensors of size 2,
262  both of the same shape, and returns a single tensor, (inputs[0] - inputs[1]),
263  also of the same shape.
264
265  Examples:
266
267  ```python
268      import keras
269
270      input1 = keras.layers.Input(shape=(16,))
271      x1 = keras.layers.Dense(8, activation='relu')(input1)
272      input2 = keras.layers.Input(shape=(32,))
273      x2 = keras.layers.Dense(8, activation='relu')(input2)
274      # Equivalent to subtracted = keras.layers.subtract([x1, x2])
275      subtracted = keras.layers.Subtract()([x1, x2])
276
277      out = keras.layers.Dense(4)(subtracted)
278      model = keras.models.Model(inputs=[input1, input2], outputs=out)
279  ```
280  """
281
282  @tf_utils.shape_type_conversion
283  def build(self, input_shape):
284    super(Subtract, self).build(input_shape)
285    if len(input_shape) != 2:
286      raise ValueError('A `Subtract` layer should be called '
287                       'on exactly 2 inputs')
288
289  def _merge_function(self, inputs):
290    if len(inputs) != 2:
291      raise ValueError('A `Subtract` layer should be called '
292                       'on exactly 2 inputs')
293    return inputs[0] - inputs[1]
294
295
296@keras_export('keras.layers.Multiply')
297class Multiply(_Merge):
298  """Layer that multiplies (element-wise) a list of inputs.
299
300  It takes as input a list of tensors, all of the same shape, and returns
301  a single tensor (also of the same shape).
302
303  >>> tf.keras.layers.Multiply()([np.arange(5).reshape(5, 1),
304  ...                             np.arange(5, 10).reshape(5, 1)])
305  <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
306  array([[ 0],
307       [ 6],
308       [14],
309       [24],
310       [36]])>
311
312  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
313  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
314  >>> multiplied = tf.keras.layers.Multiply()([x1, x2])
315  >>> multiplied.shape
316  TensorShape([5, 8])
317  """
318
319  def _merge_function(self, inputs):
320    output = inputs[0]
321    for i in range(1, len(inputs)):
322      output = output * inputs[i]
323    return output
324
325
326@keras_export('keras.layers.Average')
327class Average(_Merge):
328  """Layer that averages a list of inputs element-wise.
329
330  It takes as input a list of tensors, all of the same shape, and returns
331  a single tensor (also of the same shape).
332
333  Example:
334
335  >>> x1 = np.ones((2, 2))
336  >>> x2 = np.zeros((2, 2))
337  >>> y = tf.keras.layers.Average()([x1, x2])
338  >>> y.numpy().tolist()
339  [[0.5, 0.5], [0.5, 0.5]]
340
341  Usage in a functional model:
342
343  >>> input1 = tf.keras.layers.Input(shape=(16,))
344  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
345  >>> input2 = tf.keras.layers.Input(shape=(32,))
346  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
347  >>> avg = tf.keras.layers.Average()([x1, x2])
348  >>> out = tf.keras.layers.Dense(4)(avg)
349  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)
350
351  Raises:
352    ValueError: If there is a shape mismatch between the inputs and the shapes
353      cannot be broadcasted to match.
354  """
355
356  def _merge_function(self, inputs):
357    output = inputs[0]
358    for i in range(1, len(inputs)):
359      output += inputs[i]
360    return output / len(inputs)
361
362
363@keras_export('keras.layers.Maximum')
364class Maximum(_Merge):
365  """Layer that computes the maximum (element-wise) a list of inputs.
366
367  It takes as input a list of tensors, all of the same shape, and returns
368  a single tensor (also of the same shape).
369
370  >>> tf.keras.layers.Maximum()([np.arange(5).reshape(5, 1),
371  ...                            np.arange(5, 10).reshape(5, 1)])
372  <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
373  array([[5],
374       [6],
375       [7],
376       [8],
377       [9]])>
378
379  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
380  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
381  >>> maxed = tf.keras.layers.Maximum()([x1, x2])
382  >>> maxed.shape
383  TensorShape([5, 8])
384  """
385
386  def _merge_function(self, inputs):
387    output = inputs[0]
388    for i in range(1, len(inputs)):
389      output = math_ops.maximum(output, inputs[i])
390    return output
391
392
393@keras_export('keras.layers.Minimum')
394class Minimum(_Merge):
395  """Layer that computes the minimum (element-wise) a list of inputs.
396
397  It takes as input a list of tensors, all of the same shape, and returns
398  a single tensor (also of the same shape).
399
400  >>> tf.keras.layers.Minimum()([np.arange(5).reshape(5, 1),
401  ...                            np.arange(5, 10).reshape(5, 1)])
402  <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
403  array([[0],
404       [1],
405       [2],
406       [3],
407       [4]])>
408
409  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
410  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
411  >>> minned = tf.keras.layers.Minimum()([x1, x2])
412  >>> minned.shape
413  TensorShape([5, 8])
414  """
415
416  def _merge_function(self, inputs):
417    output = inputs[0]
418    for i in range(1, len(inputs)):
419      output = math_ops.minimum(output, inputs[i])
420    return output
421
422
423@keras_export('keras.layers.Concatenate')
424class Concatenate(_Merge):
425  """Layer that concatenates a list of inputs.
426
427  It takes as input a list of tensors, all of the same shape except
428  for the concatenation axis, and returns a single tensor that is the
429  concatenation of all inputs.
430
431  >>> x = np.arange(20).reshape(2, 2, 5)
432  >>> print(x)
433  [[[ 0  1  2  3  4]
434    [ 5  6  7  8  9]]
435   [[10 11 12 13 14]
436    [15 16 17 18 19]]]
437  >>> y = np.arange(20, 30).reshape(2, 1, 5)
438  >>> print(y)
439  [[[20 21 22 23 24]]
440   [[25 26 27 28 29]]]
441  >>> tf.keras.layers.Concatenate(axis=1)([x, y])
442  <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
443  array([[[ 0,  1,  2,  3,  4],
444          [ 5,  6,  7,  8,  9],
445          [20, 21, 22, 23, 24]],
446         [[10, 11, 12, 13, 14],
447          [15, 16, 17, 18, 19],
448          [25, 26, 27, 28, 29]]])>
449
450  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
451  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
452  >>> concatted = tf.keras.layers.Concatenate()([x1, x2])
453  >>> concatted.shape
454  TensorShape([5, 16])
455
456  """
457
458  def __init__(self, axis=-1, **kwargs):
459    """Instantiates a Concatenate layer.
460
461    >>> x = np.arange(20).reshape(2, 2, 5)
462    >>> print(x)
463    [[[ 0  1  2  3  4]
464      [ 5  6  7  8  9]]
465     [[10 11 12 13 14]
466      [15 16 17 18 19]]]
467    >>> y = np.arange(20, 30).reshape(2, 1, 5)
468    >>> print(y)
469    [[[20 21 22 23 24]]
470     [[25 26 27 28 29]]]
471    >>> tf.keras.layers.Concatenate(axis=1)([x, y])
472    <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
473    array([[[ 0,  1,  2,  3,  4],
474            [ 5,  6,  7,  8,  9],
475            [20, 21, 22, 23, 24]],
476           [[10, 11, 12, 13, 14],
477            [15, 16, 17, 18, 19],
478            [25, 26, 27, 28, 29]]])>
479
480    Args:
481      axis: Axis along which to concatenate.
482      **kwargs: standard layer keyword arguments.
483    """
484    super(Concatenate, self).__init__(**kwargs)
485    self.axis = axis
486    self.supports_masking = True
487    self._reshape_required = False
488
489  @tf_utils.shape_type_conversion
490  def build(self, input_shape):
491    # Used purely for shape validation.
492    if not isinstance(input_shape[0], tuple) or len(input_shape) < 1:
493      raise ValueError('A `Concatenate` layer should be called '
494                       'on a list of at least 1 input.')
495    if all(shape is None for shape in input_shape):
496      return
497    reduced_inputs_shapes = [list(shape) for shape in input_shape]
498    shape_set = set()
499    for i in range(len(reduced_inputs_shapes)):
500      del reduced_inputs_shapes[i][self.axis]
501      shape_set.add(tuple(reduced_inputs_shapes[i]))
502
503    if len(shape_set) != 1:
504      err_msg = ('A `Concatenate` layer requires inputs with matching shapes '
505                 'except for the concat axis. Got inputs shapes: %s' %
506                 input_shape)
507      # Make sure all the shapes have same ranks.
508      ranks = set(len(shape) for shape in shape_set)
509      if len(ranks) != 1:
510        raise ValueError(err_msg)
511      # Get the only rank for the set.
512      (rank,) = ranks
513      for axis in range(rank):
514        # Skip the Nones in the shape since they are dynamic, also the axis for
515        # concat has been removed above.
516        unique_dims = set(
517            shape[axis] for shape in shape_set if shape[axis] is not None)
518        if len(unique_dims) > 1:
519          raise ValueError(err_msg)
520
521  def _merge_function(self, inputs):
522    return K.concatenate(inputs, axis=self.axis)
523
524  @tf_utils.shape_type_conversion
525  def compute_output_shape(self, input_shape):
526    if ((not isinstance(input_shape, (tuple, list))) or
527        (not isinstance(input_shape[0], (tuple, list)))):
528      # The tf_utils.shape_type_conversion decorator turns tensorshapes
529      # into tuples, so we need to verify that `input_shape` is a list/tuple,
530      # *and* that the individual elements are themselves shape tuples.
531      raise ValueError('A `Concatenate` layer should be called '
532                       'on a list of inputs.')
533    input_shapes = input_shape
534    output_shape = list(input_shapes[0])
535    for shape in input_shapes[1:]:
536      if output_shape[self.axis] is None or shape[self.axis] is None:
537        output_shape[self.axis] = None
538        break
539      output_shape[self.axis] += shape[self.axis]
540    return tuple(output_shape)
541
542  def compute_mask(self, inputs, mask=None):
543    if mask is None:
544      return None
545    if not isinstance(mask, (tuple, list)):
546      raise ValueError('`mask` should be a list.')
547    if not isinstance(inputs, (tuple, list)):
548      raise ValueError('`inputs` should be a list.')
549    if len(mask) != len(inputs):
550      raise ValueError('The lists `inputs` and `mask` '
551                       'should have the same length.')
552    if all(m is None for m in mask):
553      return None
554    # Make a list of masks while making sure
555    # the dimensionality of each mask
556    # is the same as the corresponding input.
557    masks = []
558    for input_i, mask_i in zip(inputs, mask):
559      if mask_i is None:
560        # Input is unmasked. Append all 1s to masks,
561        masks.append(array_ops.ones_like(input_i, dtype='bool'))
562      elif K.ndim(mask_i) < K.ndim(input_i):
563        # Mask is smaller than the input, expand it
564        masks.append(array_ops.expand_dims(mask_i, axis=-1))
565      else:
566        masks.append(mask_i)
567    concatenated = K.concatenate(masks, axis=self.axis)
568    return K.all(concatenated, axis=-1, keepdims=False)
569
570  def get_config(self):
571    config = {
572        'axis': self.axis,
573    }
574    base_config = super(Concatenate, self).get_config()
575    return dict(list(base_config.items()) + list(config.items()))
576
577
578@keras_export('keras.layers.Dot')
579class Dot(_Merge):
580  """Layer that computes a dot product between samples in two tensors.
581
582  E.g. if applied to a list of two tensors `a` and `b` of shape
583  `(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)`
584  where each entry `i` will be the dot product between
585  `a[i]` and `b[i]`.
586
587  >>> x = np.arange(10).reshape(1, 5, 2)
588  >>> print(x)
589  [[[0 1]
590    [2 3]
591    [4 5]
592    [6 7]
593    [8 9]]]
594  >>> y = np.arange(10, 20).reshape(1, 2, 5)
595  >>> print(y)
596  [[[10 11 12 13 14]
597    [15 16 17 18 19]]]
598  >>> tf.keras.layers.Dot(axes=(1, 2))([x, y])
599  <tf.Tensor: shape=(1, 2, 2), dtype=int64, numpy=
600  array([[[260, 360],
601          [320, 445]]])>
602
603  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
604  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
605  >>> dotted = tf.keras.layers.Dot(axes=1)([x1, x2])
606  >>> dotted.shape
607  TensorShape([5, 1])
608
609
610  """
611
612  def __init__(self, axes, normalize=False, **kwargs):
613    """Initializes a layer that computes the element-wise dot product.
614
615      >>> x = np.arange(10).reshape(1, 5, 2)
616      >>> print(x)
617      [[[0 1]
618        [2 3]
619        [4 5]
620        [6 7]
621        [8 9]]]
622      >>> y = np.arange(10, 20).reshape(1, 2, 5)
623      >>> print(y)
624      [[[10 11 12 13 14]
625        [15 16 17 18 19]]]
626      >>> tf.keras.layers.Dot(axes=(1, 2))([x, y])
627      <tf.Tensor: shape=(1, 2, 2), dtype=int64, numpy=
628      array([[[260, 360],
629              [320, 445]]])>
630
631    Args:
632      axes: Integer or tuple of integers,
633        axis or axes along which to take the dot product. If a tuple, should
634        be two integers corresponding to the desired axis from the first input
635        and the desired axis from the second input, respectively. Note that the
636        size of the two selected axes must match.
637      normalize: Whether to L2-normalize samples along the
638        dot product axis before taking the dot product.
639        If set to True, then the output of the dot product
640        is the cosine proximity between the two samples.
641      **kwargs: Standard layer keyword arguments.
642    """
643    super(Dot, self).__init__(**kwargs)
644    if not isinstance(axes, int):
645      if not isinstance(axes, (list, tuple)):
646        raise TypeError('Invalid type for `axes` - '
647                        'should be a list or an int.')
648      if len(axes) != 2:
649        raise ValueError('Invalid format for `axes` - '
650                         'should contain two elements.')
651      if not isinstance(axes[0], int) or not isinstance(axes[1], int):
652        raise ValueError('Invalid format for `axes` - '
653                         'list elements should be "int".')
654    self.axes = axes
655    self.normalize = normalize
656    self.supports_masking = True
657    self._reshape_required = False
658
659  @tf_utils.shape_type_conversion
660  def build(self, input_shape):
661    # Used purely for shape validation.
662    if not isinstance(input_shape[0], tuple) or len(input_shape) != 2:
663      raise ValueError('A `Dot` layer should be called '
664                       'on a list of 2 inputs.')
665    shape1 = input_shape[0]
666    shape2 = input_shape[1]
667    if shape1 is None or shape2 is None:
668      return
669    if isinstance(self.axes, int):
670      if self.axes < 0:
671        axes = [self.axes % len(shape1), self.axes % len(shape2)]
672      else:
673        axes = [self.axes] * 2
674    else:
675      axes = self.axes
676    if shape1[axes[0]] != shape2[axes[1]]:
677      raise ValueError('Dimension incompatibility '
678                       '%s != %s. ' % (shape1[axes[0]], shape2[axes[1]]) +
679                       'Layer shapes: %s, %s. ' % (shape1, shape2) +
680                       'Chosen axes: %s, %s' % (axes[0], axes[1]))
681
682  def _merge_function(self, inputs):
683    base_layer_utils.no_ragged_support(inputs, self.name)
684    if len(inputs) != 2:
685      raise ValueError('A `Dot` layer should be called on exactly 2 inputs')
686    x1 = inputs[0]
687    x2 = inputs[1]
688    if isinstance(self.axes, int):
689      if self.axes < 0:
690        axes = [self.axes % K.ndim(x1), self.axes % K.ndim(x2)]
691      else:
692        axes = [self.axes] * 2
693    else:
694      axes = []
695      for i in range(len(self.axes)):
696        if self.axes[i] < 0:
697          axes.append(self.axes[i] % K.ndim(inputs[i]))
698        else:
699          axes.append(self.axes[i])
700    if self.normalize:
701      x1 = nn.l2_normalize(x1, axis=axes[0])
702      x2 = nn.l2_normalize(x2, axis=axes[1])
703    output = K.batch_dot(x1, x2, axes)
704    return output
705
706  @tf_utils.shape_type_conversion
707  def compute_output_shape(self, input_shape):
708    if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2:
709      raise ValueError('A `Dot` layer should be called '
710                       'on a list of 2 inputs.')
711    shape1 = list(input_shape[0])
712    shape2 = list(input_shape[1])
713    if isinstance(self.axes, int):
714      if self.axes < 0:
715        axes = [self.axes % len(shape1), self.axes % len(shape2)]
716      else:
717        axes = [self.axes] * 2
718    else:
719      axes = self.axes
720    shape1.pop(axes[0])
721    shape2.pop(axes[1])
722    shape2.pop(0)
723    output_shape = shape1 + shape2
724    if len(output_shape) == 1:
725      output_shape += [1]
726    return tuple(output_shape)
727
728  def compute_mask(self, inputs, mask=None):
729    return None
730
731  def get_config(self):
732    config = {
733        'axes': self.axes,
734        'normalize': self.normalize,
735    }
736    base_config = super(Dot, self).get_config()
737    return dict(list(base_config.items()) + list(config.items()))
738
739
740@keras_export('keras.layers.add')
741def add(inputs, **kwargs):
742  """Functional interface to the `tf.keras.layers.Add` layer.
743
744  Args:
745      inputs: A list of input tensors (at least 2) with the same shape.
746      **kwargs: Standard layer keyword arguments.
747
748  Returns:
749      A tensor as the sum of the inputs. It has the same shape as the inputs.
750
751  Examples:
752
753  >>> input_shape = (2, 3, 4)
754  >>> x1 = tf.random.normal(input_shape)
755  >>> x2 = tf.random.normal(input_shape)
756  >>> y = tf.keras.layers.add([x1, x2])
757  >>> print(y.shape)
758  (2, 3, 4)
759
760  Used in a functional model:
761
762  >>> input1 = tf.keras.layers.Input(shape=(16,))
763  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
764  >>> input2 = tf.keras.layers.Input(shape=(32,))
765  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
766  >>> added = tf.keras.layers.add([x1, x2])
767  >>> out = tf.keras.layers.Dense(4)(added)
768  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)
769
770  """
771  return Add(**kwargs)(inputs)
772
773
774@keras_export('keras.layers.subtract')
775def subtract(inputs, **kwargs):
776  """Functional interface to the `Subtract` layer.
777
778  Args:
779      inputs: A list of input tensors (exactly 2).
780      **kwargs: Standard layer keyword arguments.
781
782  Returns:
783      A tensor, the difference of the inputs.
784
785  Examples:
786
787  ```python
788      import keras
789
790      input1 = keras.layers.Input(shape=(16,))
791      x1 = keras.layers.Dense(8, activation='relu')(input1)
792      input2 = keras.layers.Input(shape=(32,))
793      x2 = keras.layers.Dense(8, activation='relu')(input2)
794      subtracted = keras.layers.subtract([x1, x2])
795
796      out = keras.layers.Dense(4)(subtracted)
797      model = keras.models.Model(inputs=[input1, input2], outputs=out)
798  ```
799  """
800  return Subtract(**kwargs)(inputs)
801
802
803@keras_export('keras.layers.multiply')
804def multiply(inputs, **kwargs):
805  """Functional interface to the `Multiply` layer.
806
807  Args:
808      inputs: A list of input tensors (at least 2).
809      **kwargs: Standard layer keyword arguments.
810
811  Returns:
812      A tensor, the element-wise product of the inputs.
813  """
814  return Multiply(**kwargs)(inputs)
815
816
817@keras_export('keras.layers.average')
818def average(inputs, **kwargs):
819  """Functional interface to the `tf.keras.layers.Average` layer.
820
821  Example:
822
823  >>> x1 = np.ones((2, 2))
824  >>> x2 = np.zeros((2, 2))
825  >>> y = tf.keras.layers.Average()([x1, x2])
826  >>> y.numpy().tolist()
827  [[0.5, 0.5], [0.5, 0.5]]
828
829  Usage in a functional model:
830
831  >>> input1 = tf.keras.layers.Input(shape=(16,))
832  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
833  >>> input2 = tf.keras.layers.Input(shape=(32,))
834  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
835  >>> avg = tf.keras.layers.Average()([x1, x2])
836  >>> out = tf.keras.layers.Dense(4)(avg)
837  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)
838
839  Args:
840      inputs: A list of input tensors (at least 2).
841      **kwargs: Standard layer keyword arguments.
842
843  Returns:
844      A tensor, the average of the inputs.
845
846  Raises:
847    ValueError: If there is a shape mismatch between the inputs and the shapes
848      cannot be broadcasted to match.
849  """
850  return Average(**kwargs)(inputs)
851
852
853@keras_export('keras.layers.maximum')
854def maximum(inputs, **kwargs):
855  """Functional interface to compute maximum (element-wise) list of `inputs`.
856
857  This is equivalent to the `tf.keras.layers.Maximum` layer.
858
859  For example:
860
861  ```python
862  input1 = tf.keras.layers.Input(shape=(16,))
863  x1 = tf.keras.layers.Dense(8, activation='relu')(input1) #shape=(None, 8)
864  input2 = tf.keras.layers.Input(shape=(32,))
865  x2 = tf.keras.layers.Dense(8, activation='relu')(input2) #shape=(None, 8)
866  max_inp=tf.keras.layers.maximum([x1,x2]) #shape=(None, 8)
867  out = tf.keras.layers.Dense(4)(max_inp)
868  model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)
869  ```
870
871  Args:
872      inputs: A list of input tensors (at least 2) of same shape.
873      **kwargs: Standard layer keyword arguments.
874
875  Returns:
876      A tensor (of same shape as input tensor) with the element-wise
877      maximum of the inputs.
878
879  Raises:
880      ValueError: If input tensors are of different shape.
881  """
882  return Maximum(**kwargs)(inputs)
883
884
885@keras_export('keras.layers.minimum')
886def minimum(inputs, **kwargs):
887  """Functional interface to the `Minimum` layer.
888
889  Args:
890      inputs: A list of input tensors (at least 2).
891      **kwargs: Standard layer keyword arguments.
892
893  Returns:
894      A tensor, the element-wise minimum of the inputs.
895  """
896  return Minimum(**kwargs)(inputs)
897
898
899@keras_export('keras.layers.concatenate')
900def concatenate(inputs, axis=-1, **kwargs):
901  """Functional interface to the `Concatenate` layer.
902
903  >>> x = np.arange(20).reshape(2, 2, 5)
904  >>> print(x)
905  [[[ 0  1  2  3  4]
906    [ 5  6  7  8  9]]
907   [[10 11 12 13 14]
908    [15 16 17 18 19]]]
909  >>> y = np.arange(20, 30).reshape(2, 1, 5)
910  >>> print(y)
911  [[[20 21 22 23 24]]
912   [[25 26 27 28 29]]]
913  >>> tf.keras.layers.concatenate([x, y],
914  ...                             axis=1)
915  <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
916  array([[[ 0,  1,  2,  3,  4],
917        [ 5,  6,  7,  8,  9],
918        [20, 21, 22, 23, 24]],
919       [[10, 11, 12, 13, 14],
920        [15, 16, 17, 18, 19],
921        [25, 26, 27, 28, 29]]])>
922
923  Args:
924      inputs: A list of input tensors (at least 2).
925      axis: Concatenation axis.
926      **kwargs: Standard layer keyword arguments.
927
928  Returns:
929      A tensor, the concatenation of the inputs alongside axis `axis`.
930  """
931  return Concatenate(axis=axis, **kwargs)(inputs)
932
933
934@keras_export('keras.layers.dot')
935def dot(inputs, axes, normalize=False, **kwargs):
936  """Functional interface to the `Dot` layer.
937
938  Args:
939      inputs: A list of input tensors (at least 2).
940      axes: Integer or tuple of integers,
941          axis or axes along which to take the dot product.
942      normalize: Whether to L2-normalize samples along the
943          dot product axis before taking the dot product.
944          If set to True, then the output of the dot product
945          is the cosine proximity between the two samples.
946      **kwargs: Standard layer keyword arguments.
947
948  Returns:
949      A tensor, the dot product of the samples from the inputs.
950  """
951  return Dot(axes=axes, normalize=normalize, **kwargs)(inputs)
952