1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
7#     http://www.apache.org/licenses/LICENSE-2.0
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=not-callable
16# pylint: disable=redefined-builtin
17"""Layers that can merge several inputs into one.
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
23from tensorflow.python.keras import backend as K
24from tensorflow.python.keras.engine import base_layer_utils
25from tensorflow.python.keras.engine.base_layer import Layer
26from tensorflow.python.keras.utils import tf_utils
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import nn
30from tensorflow.python.util.tf_export import keras_export
class _Merge(Layer):
  """Generic merge layer for elementwise merge functions.

  Used to implement `Add`, `Average`, etc. Subclasses override
  `_merge_function`, which receives the list of (possibly reshaped) input
  tensors and returns the merged tensor.
  """

  def __init__(self, **kwargs):
    """Initializes a Merge layer.

    Args:
      **kwargs: standard layer keyword arguments.
    """
    super(_Merge, self).__init__(**kwargs)
    self.supports_masking = True

  def _merge_function(self, inputs):
    """Merges a list of tensors into one; must be overridden by subclasses."""
    raise NotImplementedError

  def _compute_elemwise_op_output_shape(self, shape1, shape2):
    """Computes the shape of the resultant of an elementwise operation.

    Shapes are aligned on their trailing dimensions (the longer shape keeps
    its leading dimensions unchanged); a dimension of size 1 is broadcast
    against the other operand and `None` stays unknown.

    Args:
        shape1: tuple or None. Shape of the first tensor
        shape2: tuple or None. Shape of the second tensor

    Returns:
        expected output shape when an element-wise operation is
        carried out on 2 tensors with shapes shape1 and shape2.
        tuple or None.

    Raises:
        ValueError: if shape1 and shape2 are not compatible for
            element-wise operations.
    """
    if None in [shape1, shape2]:
      return None
    elif len(shape1) < len(shape2):
      return self._compute_elemwise_op_output_shape(shape2, shape1)
    elif not shape2:
      return shape1
    output_shape = list(shape1[:-len(shape2)])
    for i, j in zip(shape1[-len(shape2):], shape2):
      if i is None or j is None:
        output_shape.append(None)
      elif i == 1:
        output_shape.append(j)
      elif j == 1:
        output_shape.append(i)
      else:
        if i != j:
          raise ValueError(
              'Operands could not be broadcast '
              'together with shapes ' + str(shape1) + ' ' + str(shape2))
        output_shape.append(i)
    return tuple(output_shape)

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    """Validates the input shapes and decides whether `call` must reshape.

    Args:
      input_shape: list of shape tuples (or `None` for unknown rank), one
        per input.

    Raises:
      ValueError: if not called on a list of at least 2 inputs, if the
        batch sizes differ, or if the shapes cannot be broadcast together.
    """
    # Used purely for shape validation.
    # Check the length before indexing, so that an empty list (or the shape
    # of a single scalar tensor) raises a ValueError instead of IndexError.
    if len(input_shape) < 1 or not isinstance(input_shape[0], tuple):
      raise ValueError('A merge layer should be called on a list of inputs.')
    if len(input_shape) < 2:
      raise ValueError('A merge layer should be called '
                       'on a list of at least 2 inputs. '
                       'Got ' + str(len(input_shape)) + ' inputs.')
    batch_sizes = {s[0] for s in input_shape if s} - {None}
    if len(batch_sizes) > 1:
      raise ValueError(
          'Can not merge tensors with different '
          'batch sizes. Got tensors with shapes : ' + str(input_shape))
    if input_shape[0] is None:
      output_shape = None
    else:
      output_shape = input_shape[0][1:]
    for i in range(1, len(input_shape)):
      if input_shape[i] is None:
        shape = None
      else:
        shape = input_shape[i][1:]
      # Called for its validation side effect: raises on incompatibility.
      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
    # If the inputs have different ranks, we have to reshape them
    # to make them broadcastable.
    if None not in input_shape and len(set(map(len, input_shape))) == 1:
      self._reshape_required = False
    else:
      self._reshape_required = True

  def call(self, inputs):
    """Merges `inputs`, first aligning their ranks if `build` found it needed.

    Args:
      inputs: list or tuple of tensors.

    Returns:
      The tensor produced by `_merge_function`.

    Raises:
      ValueError: if `inputs` is not a list or tuple.
    """
    if not isinstance(inputs, (list, tuple)):
      raise ValueError('A merge layer should be called on a list of inputs.')
    if self._reshape_required:
      reshaped_inputs = []
      input_ndims = list(map(K.ndim, inputs))
      if None not in input_ndims:
        # If ranks of all inputs are available,
        # we simply expand each of them at axis=1
        # until all of them have the same rank.
        max_ndim = max(input_ndims)
        for x in inputs:
          x_ndim = K.ndim(x)
          for _ in range(max_ndim - x_ndim):
            x = array_ops.expand_dims(x, axis=1)
          reshaped_inputs.append(x)
        return self._merge_function(reshaped_inputs)
      else:
        # Transpose all inputs so that batch size is the last dimension.
        # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
        transposed = False
        for x in inputs:
          x_ndim = K.ndim(x)
          if x_ndim is None:
            x_shape = array_ops.shape(x)
            batch_size = x_shape[0]
            new_shape = K.concatenate(
                [x_shape[1:],
                 array_ops.expand_dims(batch_size, axis=-1)])
            x_transposed = array_ops.reshape(
                x,
                array_ops.stack(
                    [batch_size, math_ops.reduce_prod(x_shape[1:])], axis=0))
            x_transposed = array_ops.transpose(x_transposed, perm=(1, 0))
            x_transposed = array_ops.reshape(x_transposed, new_shape)
            reshaped_inputs.append(x_transposed)
            transposed = True
          elif x_ndim > 1:
            dims = list(range(1, x_ndim)) + [0]
            reshaped_inputs.append(array_ops.transpose(x, perm=dims))
            transposed = True
          else:
            # We don't transpose inputs if they are 1D vectors or scalars.
            reshaped_inputs.append(x)
        y = self._merge_function(reshaped_inputs)
        y_ndim = K.ndim(y)
        if transposed:
          # If inputs have been transposed, we have to transpose the output too.
          if y_ndim is None:
            y_shape = array_ops.shape(y)
            y_ndim = array_ops.shape(y_shape)[0]
            batch_size = y_shape[y_ndim - 1]
            new_shape = K.concatenate([
                array_ops.expand_dims(batch_size, axis=-1), y_shape[:y_ndim - 1]
            ])
            y = array_ops.reshape(y, (-1, batch_size))
            y = array_ops.transpose(y, perm=(1, 0))
            y = array_ops.reshape(y, new_shape)
          elif y_ndim > 1:
            dims = [y_ndim - 1] + list(range(y_ndim - 1))
            y = array_ops.transpose(y, perm=dims)
        return y
    else:
      return self._merge_function(inputs)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Computes the broadcast output shape, including the batch dimension.

    Returns:
      A shape tuple, or `None` when any input has unknown rank.
    """
    if input_shape[0] is None:
      output_shape = None
    else:
      output_shape = input_shape[0][1:]
    for i in range(1, len(input_shape)):
      if input_shape[i] is None:
        shape = None
      else:
        shape = input_shape[i][1:]
      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
    if output_shape is None:
      # At least one input has unknown rank, so the output rank is unknown
      # too; prepending a batch dimension to `None` would raise TypeError.
      return None
    # Skip empty (scalar) shapes as well as unknown ones, matching `build`;
    # indexing `()[0]` would raise IndexError.
    batch_sizes = {s[0] for s in input_shape if s} - {None}
    if len(batch_sizes) == 1:
      output_shape = (list(batch_sizes)[0],) + output_shape
    else:
      output_shape = (None,) + output_shape
    return output_shape

  def compute_mask(self, inputs, mask=None):
    """Combines the per-input masks with a logical AND.

    Returns:
      `None` if no input is masked, otherwise the element-wise AND of all
      non-`None` masks.

    Raises:
      ValueError: if `mask` or `inputs` is not a list/tuple, or their
        lengths differ.
    """
    if mask is None:
      return None
    if not isinstance(mask, (tuple, list)):
      raise ValueError('`mask` should be a list.')
    if not isinstance(inputs, (tuple, list)):
      raise ValueError('`inputs` should be a list.')
    if len(mask) != len(inputs):
      raise ValueError('The lists `inputs` and `mask` '
                       'should have the same length.')
    if all(m is None for m in mask):
      return None
    # Stack the masks along a new leading axis, then reduce it with `all`.
    masks = [array_ops.expand_dims(m, axis=0) for m in mask if m is not None]
    return K.all(K.concatenate(masks, axis=0), axis=0, keepdims=False)
class Add(_Merge):
  """Layer that adds a list of inputs.

  It takes as input a list of tensors,
  all of the same shape, and returns
  a single tensor (also of the same shape).

  Examples:

  >>> input_shape = (2, 3, 4)
  >>> x1 = tf.random.normal(input_shape)
  >>> x2 = tf.random.normal(input_shape)
  >>> y = tf.keras.layers.Add()([x1, x2])
  >>> print(y.shape)
  (2, 3, 4)

  Used in a functional model:

  >>> input1 = tf.keras.layers.Input(shape=(16,))
  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
  >>> input2 = tf.keras.layers.Input(shape=(32,))
  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
  >>> # equivalent to `added = tf.keras.layers.add([x1, x2])`
  >>> added = tf.keras.layers.Add()([x1, x2])
  >>> out = tf.keras.layers.Dense(4)(added)
  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)

  """

  def _merge_function(self, inputs):
    """Returns the element-wise sum of all tensors in `inputs`."""
    output = inputs[0]
    for i in range(1, len(inputs)):
      # Out-of-place addition (as in `Multiply`): `+=` would modify the first
      # input in place when it is a mutable array (e.g. NumPy), and NumPy
      # in-place casting rules reject e.g. int += float.
      output = output + inputs[i]
    return output
class Subtract(_Merge):
  """Layer that subtracts one input from another.

  Expects a list of exactly two tensors with the same shape, and produces a
  single tensor of that shape holding `inputs[0] - inputs[1]`.

  Examples:

  ```python
      input1 = tf.keras.layers.Input(shape=(16,))
      x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
      input2 = tf.keras.layers.Input(shape=(32,))
      x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
      # Equivalent to subtracted = tf.keras.layers.subtract([x1, x2])
      subtracted = tf.keras.layers.Subtract()([x1, x2])

      out = tf.keras.layers.Dense(4)(subtracted)
      model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)
  ```
  """

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    """Runs the generic merge validation, then enforces exactly two inputs."""
    super(Subtract, self).build(input_shape)
    input_count = len(input_shape)
    if input_count != 2:
      raise ValueError('A `Subtract` layer should be called '
                       'on exactly 2 inputs')

  def _merge_function(self, inputs):
    """Returns the first input minus the second."""
    if len(inputs) != 2:
      raise ValueError('A `Subtract` layer should be called '
                       'on exactly 2 inputs')
    minuend, subtrahend = inputs
    return minuend - subtrahend
class Multiply(_Merge):
  """Layer that multiplies a list of inputs element-wise.

  Accepts a list of tensors sharing one shape and produces a single tensor
  of that same shape holding their element-wise product.

  >>> tf.keras.layers.Multiply()([np.arange(5).reshape(5, 1),
  ...                             np.arange(5, 10).reshape(5, 1)])
  <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
  array([[ 0],
       [ 6],
       [14],
       [24],
       [36]])>

  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
  >>> multiplied = tf.keras.layers.Multiply()([x1, x2])
  >>> multiplied.shape
  TensorShape([5, 8])
  """

  def _merge_function(self, inputs):
    """Folds `inputs` left-to-right with `*`."""
    product = inputs[0]
    for factor in inputs[1:]:
      product = product * factor
    return product
class Average(_Merge):
  """Layer that averages a list of inputs element-wise.

  It takes as input a list of tensors, all of the same shape, and returns
  a single tensor (also of the same shape).

  Example:

  >>> x1 = np.ones((2, 2))
  >>> x2 = np.zeros((2, 2))
  >>> y = tf.keras.layers.Average()([x1, x2])
  >>> y.numpy().tolist()
  [[0.5, 0.5], [0.5, 0.5]]

  Usage in a functional model:

  >>> input1 = tf.keras.layers.Input(shape=(16,))
  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
  >>> input2 = tf.keras.layers.Input(shape=(32,))
  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
  >>> avg = tf.keras.layers.Average()([x1, x2])
  >>> out = tf.keras.layers.Dense(4)(avg)
  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)

  Raises:
    ValueError: If there is a shape mismatch between the inputs and the shapes
      cannot be broadcasted to match.
  """

  def _merge_function(self, inputs):
    """Returns the element-wise sum of `inputs` divided by their count."""
    output = inputs[0]
    for i in range(1, len(inputs)):
      # Out-of-place addition: `+=` would modify the first input in place
      # when it is a mutable array (e.g. NumPy), and NumPy in-place casting
      # rules reject e.g. int += float.
      output = output + inputs[i]
    return output / len(inputs)
class Maximum(_Merge):
  """Layer that computes the element-wise maximum of a list of inputs.

  Accepts a list of tensors sharing one shape and produces a single tensor
  of that same shape where each entry is the largest of the corresponding
  input entries.

  >>> tf.keras.layers.Maximum()([np.arange(5).reshape(5, 1),
  ...                            np.arange(5, 10).reshape(5, 1)])
  <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
  array([[5],
       [6],
       [7],
       [8],
       [9]])>

  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
  >>> maxed = tf.keras.layers.Maximum()([x1, x2])
  >>> maxed.shape
  TensorShape([5, 8])
  """

  def _merge_function(self, inputs):
    """Folds `inputs` left-to-right with `math_ops.maximum`."""
    result = inputs[0]
    for other in inputs[1:]:
      result = math_ops.maximum(result, other)
    return result
class Minimum(_Merge):
  """Layer that computes the element-wise minimum of a list of inputs.

  Accepts a list of tensors sharing one shape and produces a single tensor
  of that same shape where each entry is the smallest of the corresponding
  input entries.

  >>> tf.keras.layers.Minimum()([np.arange(5).reshape(5, 1),
  ...                            np.arange(5, 10).reshape(5, 1)])
  <tf.Tensor: shape=(5, 1), dtype=int64, numpy=
  array([[0],
       [1],
       [2],
       [3],
       [4]])>

  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
  >>> minned = tf.keras.layers.Minimum()([x1, x2])
  >>> minned.shape
  TensorShape([5, 8])
  """

  def _merge_function(self, inputs):
    """Folds `inputs` left-to-right with `math_ops.minimum`."""
    result = inputs[0]
    for other in inputs[1:]:
      result = math_ops.minimum(result, other)
    return result
class Concatenate(_Merge):
  """Layer that concatenates a list of inputs.

  It takes as input a list of tensors, all of the same shape except
  for the concatenation axis, and returns a single tensor that is the
  concatenation of all inputs.

  >>> x = np.arange(20).reshape(2, 2, 5)
  >>> print(x)
  [[[ 0  1  2  3  4]
    [ 5  6  7  8  9]]
   [[10 11 12 13 14]
    [15 16 17 18 19]]]
  >>> y = np.arange(20, 30).reshape(2, 1, 5)
  >>> print(y)
  [[[20 21 22 23 24]]
   [[25 26 27 28 29]]]
  >>> tf.keras.layers.Concatenate(axis=1)([x, y])
  <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
  array([[[ 0,  1,  2,  3,  4],
          [ 5,  6,  7,  8,  9],
          [20, 21, 22, 23, 24]],
         [[10, 11, 12, 13, 14],
          [15, 16, 17, 18, 19],
          [25, 26, 27, 28, 29]]])>

  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
  >>> concatted = tf.keras.layers.Concatenate()([x1, x2])
  >>> concatted.shape
  TensorShape([5, 16])

  """

  def __init__(self, axis=-1, **kwargs):
    """Instantiates a Concatenate layer.

    >>> x = np.arange(20).reshape(2, 2, 5)
    >>> print(x)
    [[[ 0  1  2  3  4]
      [ 5  6  7  8  9]]
     [[10 11 12 13 14]
      [15 16 17 18 19]]]
    >>> y = np.arange(20, 30).reshape(2, 1, 5)
    >>> print(y)
    [[[20 21 22 23 24]]
     [[25 26 27 28 29]]]
    >>> tf.keras.layers.Concatenate(axis=1)([x, y])
    <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
    array([[[ 0,  1,  2,  3,  4],
            [ 5,  6,  7,  8,  9],
            [20, 21, 22, 23, 24]],
           [[10, 11, 12, 13, 14],
            [15, 16, 17, 18, 19],
            [25, 26, 27, 28, 29]]])>

    Args:
      axis: Axis along which to concatenate.
      **kwargs: standard layer keyword arguments.
    """
    super(Concatenate, self).__init__(**kwargs)
    self.axis = axis
    self.supports_masking = True
    # Concatenation never needs the rank-aligning reshape done by `_Merge`.
    self._reshape_required = False

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    """Checks that all known input shapes agree except on `self.axis`.

    Raises:
      ValueError: if not called on a non-empty list of inputs, or if the
        input shapes differ (ignoring `None` dims) off the concat axis.
    """
    # Used purely for shape validation.
    # Check the length before indexing so that an empty list raises the
    # intended ValueError rather than an IndexError.
    if len(input_shape) < 1 or not isinstance(input_shape[0], tuple):
      raise ValueError('A `Concatenate` layer should be called '
                       'on a list of at least 1 input.')
    # Inputs of unknown rank (`None` shapes) carry nothing to validate; skip
    # them instead of failing on `list(None)`.
    known_shapes = [shape for shape in input_shape if shape is not None]
    shape_set = set()
    for shape in known_shapes:
      # Drop the concat axis: only the remaining dimensions must match.
      reduced_shape = list(shape)
      del reduced_shape[self.axis]
      shape_set.add(tuple(reduced_shape))

    if len(shape_set) != 1:
      # `input_shape` is wrapped in a 1-tuple: if it is itself a tuple, `%`
      # would otherwise spread it over the format placeholders and fail.
      err_msg = ('A `Concatenate` layer requires inputs with matching shapes '
                 'except for the concat axis. Got inputs shapes: %s' %
                 (input_shape,))
      # Make sure all the shapes have same ranks.
      ranks = set(len(shape) for shape in shape_set)
      if len(ranks) != 1:
        raise ValueError(err_msg)
      # Get the only rank for the set.
      (rank,) = ranks
      for axis in range(rank):
        # Skip the Nones in the shape since they are dynamic, also the axis for
        # concat has been removed above.
        unique_dims = set(
            shape[axis] for shape in shape_set if shape[axis] is not None)
        if len(unique_dims) > 1:
          raise ValueError(err_msg)

  def _merge_function(self, inputs):
    """Concatenates `inputs` along `self.axis`."""
    return K.concatenate(inputs, axis=self.axis)

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Returns the first input's shape with the concat-axis sizes summed.

    The concat-axis size becomes `None` as soon as any input has an unknown
    size on that axis.
    """
    if ((not isinstance(input_shape, (tuple, list))) or
        (not isinstance(input_shape[0], (tuple, list)))):
      # The tf_utils.shape_type_conversion decorator turns tensorshapes
      # into tuples, so we need to verify that `input_shape` is a list/tuple,
      # *and* that the individual elements are themselves shape tuples.
      raise ValueError('A `Concatenate` layer should be called '
                       'on a list of inputs.')
    input_shapes = input_shape
    output_shape = list(input_shapes[0])
    for shape in input_shapes[1:]:
      if output_shape[self.axis] is None or shape[self.axis] is None:
        output_shape[self.axis] = None
        break
      output_shape[self.axis] += shape[self.axis]
    return tuple(output_shape)

  def compute_mask(self, inputs, mask=None):
    """Concatenates the per-input masks and reduces them over the last axis.

    Unmasked inputs contribute an all-True mask; masks of lower rank than
    their input are expanded on the last axis before concatenation.
    """
    if mask is None:
      return None
    if not isinstance(mask, (tuple, list)):
      raise ValueError('`mask` should be a list.')
    if not isinstance(inputs, (tuple, list)):
      raise ValueError('`inputs` should be a list.')
    if len(mask) != len(inputs):
      raise ValueError('The lists `inputs` and `mask` '
                       'should have the same length.')
    if all(m is None for m in mask):
      return None
    # Make a list of masks while making sure
    # the dimensionality of each mask
    # is the same as the corresponding input.
    masks = []
    for input_i, mask_i in zip(inputs, mask):
      if mask_i is None:
        # Input is unmasked. Append all 1s to masks,
        masks.append(array_ops.ones_like(input_i, dtype='bool'))
      elif K.ndim(mask_i) < K.ndim(input_i):
        # Mask is smaller than the input, expand it
        masks.append(array_ops.expand_dims(mask_i, axis=-1))
      else:
        masks.append(mask_i)
    concatenated = K.concatenate(masks, axis=self.axis)
    return K.all(concatenated, axis=-1, keepdims=False)

  def get_config(self):
    """Returns the base layer config extended with `axis`."""
    config = {
        'axis': self.axis,
    }
    base_config = super(Concatenate, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
class Dot(_Merge):
  """Layer that computes a dot product between samples in two tensors.

  E.g. if applied to a list of two tensors `a` and `b` of shape
  `(batch_size, n)`, the output will be a tensor of shape `(batch_size, 1)`
  where each entry `i` will be the dot product between
  `a[i]` and `b[i]`.

  >>> x = np.arange(10).reshape(1, 5, 2)
  >>> print(x)
  [[[0 1]
    [2 3]
    [4 5]
    [6 7]
    [8 9]]]
  >>> y = np.arange(10, 20).reshape(1, 2, 5)
  >>> print(y)
  [[[10 11 12 13 14]
    [15 16 17 18 19]]]
  >>> tf.keras.layers.Dot(axes=(1, 2))([x, y])
  <tf.Tensor: shape=(1, 2, 2), dtype=int64, numpy=
  array([[[260, 360],
          [320, 445]]])>

  >>> x1 = tf.keras.layers.Dense(8)(np.arange(10).reshape(5, 2))
  >>> x2 = tf.keras.layers.Dense(8)(np.arange(10, 20).reshape(5, 2))
  >>> dotted = tf.keras.layers.Dot(axes=1)([x1, x2])
  >>> dotted.shape
  TensorShape([5, 1])

  """

  def __init__(self, axes, normalize=False, **kwargs):
    """Initializes a layer that computes a dot product between samples.

      >>> x = np.arange(10).reshape(1, 5, 2)
      >>> print(x)
      [[[0 1]
        [2 3]
        [4 5]
        [6 7]
        [8 9]]]
      >>> y = np.arange(10, 20).reshape(1, 2, 5)
      >>> print(y)
      [[[10 11 12 13 14]
        [15 16 17 18 19]]]
      >>> tf.keras.layers.Dot(axes=(1, 2))([x, y])
      <tf.Tensor: shape=(1, 2, 2), dtype=int64, numpy=
      array([[[260, 360],
              [320, 445]]])>

    Args:
      axes: Integer or tuple of integers,
        axis or axes along which to take the dot product. If a tuple, should
        be two integers corresponding to the desired axis from the first input
        and the desired axis from the second input, respectively. Note that the
        size of the two selected axes must match.
      normalize: Whether to L2-normalize samples along the
        dot product axis before taking the dot product.
        If set to True, then the output of the dot product
        is the cosine proximity between the two samples.
      **kwargs: Standard layer keyword arguments.

    Raises:
      TypeError: if `axes` is neither an int nor a list/tuple.
      ValueError: if a list/tuple `axes` does not hold exactly two ints.
    """
    super(Dot, self).__init__(**kwargs)
    if not isinstance(axes, int):
      if not isinstance(axes, (list, tuple)):
        raise TypeError('Invalid type for `axes` - '
                        'should be a list or an int.')
      if len(axes) != 2:
        raise ValueError('Invalid format for `axes` - '
                         'should contain two elements.')
      if not isinstance(axes[0], int) or not isinstance(axes[1], int):
        raise ValueError('Invalid format for `axes` - '
                         'list elements should be "int".')
    self.axes = axes
    self.normalize = normalize
    self.supports_masking = True
    # Dot handles its two inputs directly; no rank-aligning reshape.
    self._reshape_required = False

  @tf_utils.shape_type_conversion
  def build(self, input_shape):
    """Checks that the two inputs have matching sizes on the dot axes.

    Raises:
      ValueError: if not called on exactly 2 inputs, or if both dot-axis
        sizes are statically known and differ.
    """
    # Used purely for shape validation.
    # Check the length before indexing so that an empty list raises the
    # intended ValueError rather than an IndexError.
    if len(input_shape) != 2 or not isinstance(input_shape[0], tuple):
      raise ValueError('A `Dot` layer should be called '
                       'on a list of 2 inputs.')
    shape1 = input_shape[0]
    shape2 = input_shape[1]
    if shape1 is None or shape2 is None:
      return
    if isinstance(self.axes, int):
      if self.axes < 0:
        axes = [self.axes % len(shape1), self.axes % len(shape2)]
      else:
        axes = [self.axes] * 2
    else:
      axes = self.axes
    dim1 = shape1[axes[0]]
    dim2 = shape2[axes[1]]
    # A `None` size is only known at run time and cannot be compared here;
    # comparing it with a concrete size used to raise a spurious error.
    if dim1 is not None and dim2 is not None and dim1 != dim2:
      raise ValueError('Dimension incompatibility '
                       '%s != %s. ' % (dim1, dim2) +
                       'Layer shapes: %s, %s. ' % (shape1, shape2) +
                       'Chosen axes: %s, %s' % (axes[0], axes[1]))

  def _merge_function(self, inputs):
    """Computes the batched dot product of the two inputs along `axes`."""
    base_layer_utils.no_ragged_support(inputs, self.name)
    if len(inputs) != 2:
      raise ValueError('A `Dot` layer should be called on exactly 2 inputs')
    x1 = inputs[0]
    x2 = inputs[1]
    if isinstance(self.axes, int):
      if self.axes < 0:
        axes = [self.axes % K.ndim(x1), self.axes % K.ndim(x2)]
      else:
        axes = [self.axes] * 2
    else:
      axes = []
      for i in range(len(self.axes)):
        if self.axes[i] < 0:
          axes.append(self.axes[i] % K.ndim(inputs[i]))
        else:
          axes.append(self.axes[i])
    if self.normalize:
      x1 = nn.l2_normalize(x1, axis=axes[0])
      x2 = nn.l2_normalize(x2, axis=axes[1])
    output = K.batch_dot(x1, x2, axes)
    return output

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    """Drops the dot axes (and the second batch dim) and joins the shapes.

    A result with only the batch dimension left gets a trailing 1.
    """
    if not isinstance(input_shape, (tuple, list)) or len(input_shape) != 2:
      raise ValueError('A `Dot` layer should be called '
                       'on a list of 2 inputs.')
    shape1 = list(input_shape[0])
    shape2 = list(input_shape[1])
    if isinstance(self.axes, int):
      if self.axes < 0:
        axes = [self.axes % len(shape1), self.axes % len(shape2)]
      else:
        axes = [self.axes] * 2
    else:
      axes = self.axes
    shape1.pop(axes[0])
    shape2.pop(axes[1])
    shape2.pop(0)
    output_shape = shape1 + shape2
    if len(output_shape) == 1:
      output_shape += [1]
    return tuple(output_shape)

  def compute_mask(self, inputs, mask=None):
    """Dot products never propagate a mask."""
    return None

  def get_config(self):
    """Returns the base layer config extended with `axes` and `normalize`."""
    config = {
        'axes': self.axes,
        'normalize': self.normalize,
    }
    base_config = super(Dot, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
def add(inputs, **kwargs):
  """Functional interface to the `tf.keras.layers.Add` layer.

  Builds an `Add` layer from `kwargs` and applies it to `inputs` in one step.

  Args:
      inputs: A list of input tensors (at least 2) with the same shape.
      **kwargs: Standard layer keyword arguments, passed to `Add`.

  Returns:
      A tensor as the sum of the inputs. It has the same shape as the inputs.

  Examples:

  >>> input_shape = (2, 3, 4)
  >>> x1 = tf.random.normal(input_shape)
  >>> x2 = tf.random.normal(input_shape)
  >>> y = tf.keras.layers.add([x1, x2])
  >>> print(y.shape)
  (2, 3, 4)

  Used in a functional model:

  >>> input1 = tf.keras.layers.Input(shape=(16,))
  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
  >>> input2 = tf.keras.layers.Input(shape=(32,))
  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
  >>> added = tf.keras.layers.add([x1, x2])
  >>> out = tf.keras.layers.Dense(4)(added)
  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)

  """
  add_layer = Add(**kwargs)
  return add_layer(inputs)
def subtract(inputs, **kwargs):
  """Functional interface to the `Subtract` layer.

  Builds a `Subtract` layer from `kwargs` and applies it to `inputs`.

  Args:
      inputs: A list of input tensors (exactly 2).
      **kwargs: Standard layer keyword arguments, passed to `Subtract`.

  Returns:
      A tensor, the difference of the inputs (`inputs[0] - inputs[1]`).

  Examples:

  ```python
      input1 = tf.keras.layers.Input(shape=(16,))
      x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
      input2 = tf.keras.layers.Input(shape=(32,))
      x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
      subtracted = tf.keras.layers.subtract([x1, x2])

      out = tf.keras.layers.Dense(4)(subtracted)
      model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)
  ```
  """
  subtract_layer = Subtract(**kwargs)
  return subtract_layer(inputs)
def multiply(inputs, **kwargs):
  """Functional interface to the `Multiply` layer.

  Builds a `Multiply` layer from `kwargs` and applies it to `inputs`.

  Args:
      inputs: A list of input tensors (at least 2).
      **kwargs: Standard layer keyword arguments, passed to `Multiply`.

  Returns:
      A tensor, the element-wise product of the inputs.
  """
  multiply_layer = Multiply(**kwargs)
  return multiply_layer(inputs)
def average(inputs, **kwargs):
  """Functional interface to the `tf.keras.layers.Average` layer.

  Builds an `Average` layer from `kwargs` and applies it to `inputs`.

  Example:

  >>> x1 = np.ones((2, 2))
  >>> x2 = np.zeros((2, 2))
  >>> y = tf.keras.layers.average([x1, x2])
  >>> y.numpy().tolist()
  [[0.5, 0.5], [0.5, 0.5]]

  Usage in a functional model:

  >>> input1 = tf.keras.layers.Input(shape=(16,))
  >>> x1 = tf.keras.layers.Dense(8, activation='relu')(input1)
  >>> input2 = tf.keras.layers.Input(shape=(32,))
  >>> x2 = tf.keras.layers.Dense(8, activation='relu')(input2)
  >>> avg = tf.keras.layers.average([x1, x2])
  >>> out = tf.keras.layers.Dense(4)(avg)
  >>> model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)

  Args:
      inputs: A list of input tensors (at least 2).
      **kwargs: Standard layer keyword arguments, passed to `Average`.

  Returns:
      A tensor, the average of the inputs.

  Raises:
    ValueError: If there is a shape mismatch between the inputs and the shapes
      cannot be broadcasted to match.
  """
  average_layer = Average(**kwargs)
  return average_layer(inputs)
def maximum(inputs, **kwargs):
  """Functional interface to compute the element-wise maximum of `inputs`.

  This is equivalent to building a `tf.keras.layers.Maximum` layer and
  calling it on `inputs`.

  For example:

  ```python
  input1 = tf.keras.layers.Input(shape=(16,))
  x1 = tf.keras.layers.Dense(8, activation='relu')(input1) #shape=(None, 8)
  input2 = tf.keras.layers.Input(shape=(32,))
  x2 = tf.keras.layers.Dense(8, activation='relu')(input2) #shape=(None, 8)
  max_inp=tf.keras.layers.maximum([x1,x2]) #shape=(None, 8)
  out = tf.keras.layers.Dense(4)(max_inp)
  model = tf.keras.models.Model(inputs=[input1, input2], outputs=out)
  ```

  Args:
      inputs: A list of input tensors (at least 2) of same shape.
      **kwargs: Standard layer keyword arguments, passed to `Maximum`.

  Returns:
      A tensor (of same shape as input tensor) with the element-wise
      maximum of the inputs.

  Raises:
      ValueError: If the input shapes cannot be broadcast together.
  """
  maximum_layer = Maximum(**kwargs)
  return maximum_layer(inputs)
def minimum(inputs, **kwargs):
  """Functional interface to the `Minimum` layer.

  Builds a `Minimum` layer from `kwargs` and applies it to `inputs`.

  Args:
      inputs: A list of input tensors (at least 2).
      **kwargs: Standard layer keyword arguments, passed to `Minimum`.

  Returns:
      A tensor, the element-wise minimum of the inputs.
  """
  minimum_layer = Minimum(**kwargs)
  return minimum_layer(inputs)
def concatenate(inputs, axis=-1, **kwargs):
  """Functional interface to the `Concatenate` layer.

  Builds a `Concatenate` layer for `axis` and applies it to `inputs`.

  >>> x = np.arange(20).reshape(2, 2, 5)
  >>> print(x)
  [[[ 0  1  2  3  4]
    [ 5  6  7  8  9]]
   [[10 11 12 13 14]
    [15 16 17 18 19]]]
  >>> y = np.arange(20, 30).reshape(2, 1, 5)
  >>> print(y)
  [[[20 21 22 23 24]]
   [[25 26 27 28 29]]]
  >>> tf.keras.layers.concatenate([x, y],
  ...                             axis=1)
  <tf.Tensor: shape=(2, 3, 5), dtype=int64, numpy=
  array([[[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [20, 21, 22, 23, 24]],
       [[10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [25, 26, 27, 28, 29]]])>

  Args:
      inputs: A list of input tensors (at least 1).
      axis: Concatenation axis.
      **kwargs: Standard layer keyword arguments, passed to `Concatenate`.

  Returns:
      A tensor, the concatenation of the inputs alongside axis `axis`.
  """
  concat_layer = Concatenate(axis=axis, **kwargs)
  return concat_layer(inputs)
def dot(inputs, axes, normalize=False, **kwargs):
  """Functional interface to the `Dot` layer.

  Builds a `Dot` layer from `axes`, `normalize` and `kwargs`, then applies it
  to `inputs`.

  Args:
      inputs: A list of input tensors (exactly 2).
      axes: Integer or tuple of integers,
          axis or axes along which to take the dot product.
      normalize: Whether to L2-normalize samples along the
          dot product axis before taking the dot product.
          If set to True, then the output of the dot product
          is the cosine proximity between the two samples.
      **kwargs: Standard layer keyword arguments, passed to `Dot`.

  Returns:
      A tensor, the dot product of the samples from the inputs.
  """
  dot_layer = Dot(axes=axes, normalize=normalize, **kwargs)
  return dot_layer(inputs)