# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras-based einsum dense layer."""
# pylint: disable=g-classes-have-attributes
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import activations
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.ops import special_math_ops
from tensorflow.python.util.tf_export import keras_export


@keras_export("keras.layers.experimental.EinsumDense")
class EinsumDense(Layer):
  """A layer that uses `tf.einsum` as the backing computation.

  This layer can perform einsum calculations of arbitrary dimensionality.

  Args:
    equation: An equation describing the einsum to perform. This equation must
      be a valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or
      `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis
      expression sequence.
    output_shape: The expected shape of the output tensor (excluding the batch
      dimension and any dimensions represented by ellipses). You can specify
      `None` for any dimension that is unknown or can be inferred from the
      input shape.
    activation: Activation function to use. If you don't specify anything, no
      activation is applied (that is, a "linear" activation: `a(x) = x`).
    bias_axes: A string containing the output dimension(s) to apply a bias to.
      Each character in the `bias_axes` string should correspond to a character
      in the output portion of the `equation` string.
    kernel_initializer: Initializer for the `kernel` weights matrix.
    bias_initializer: Initializer for the bias vector.
    kernel_regularizer: Regularizer function applied to the `kernel` weights
      matrix.
    bias_regularizer: Regularizer function applied to the bias vector.
    activity_regularizer: Regularizer function applied to the output of the
      layer (its "activation").
    kernel_constraint: Constraint function applied to the `kernel` weights
      matrix.
    bias_constraint: Constraint function applied to the bias vector.

  Examples:

  **Biased dense layer with einsums**

  This example shows how to instantiate a standard Keras dense layer using
  einsum operations. This example is equivalent to
  `tf.keras.layers.Dense(64, use_bias=True)`.

  >>> layer = EinsumDense("ab,bc->ac", output_shape=64, bias_axes="c")
  >>> input_tensor = tf.keras.Input(shape=[32])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor
  <... shape=(None, 64) dtype=...>

  **Applying a dense layer to a sequence**

  This example shows how to instantiate a layer that applies the same dense
  operation to every element in a sequence. Here, the `output_shape` has two
  values (since there are two non-batch dimensions in the output); the first
  dimension in the `output_shape` is `None`, because the sequence dimension
  `b` has an unknown shape.

  >>> layer = EinsumDense("abc,cd->abd",
  ...                     output_shape=(None, 64),
  ...                     bias_axes="d")
  >>> input_tensor = tf.keras.Input(shape=[32, 128])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor
  <... shape=(None, 32, 64) dtype=...>

  **Applying a dense layer to a sequence using ellipses**

  This example shows how to instantiate a layer that applies the same dense
  operation to every element in a sequence, but uses the ellipsis notation
  instead of specifying the batch and sequence dimensions.

  Because we are using ellipsis notation and have specified only one axis, the
  `output_shape` arg is a single value. When instantiated in this way, the
  layer can handle any number of sequence dimensions, including the case where
  no sequence dimension exists.

  >>> layer = EinsumDense("...x,xy->...y", output_shape=64, bias_axes="y")
  >>> input_tensor = tf.keras.Input(shape=[32, 128])
  >>> output_tensor = layer(input_tensor)
  >>> output_tensor
  <... shape=(None, 32, 64) dtype=...>
  """

  def __init__(self,
               equation,
               output_shape,
               activation=None,
               bias_axes=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               **kwargs):
    super(EinsumDense, self).__init__(**kwargs)
    self.equation = equation
    if isinstance(output_shape, int):
      self.partial_output_shape = [output_shape]
    else:
      self.partial_output_shape = list(output_shape)
    self.bias_axes = bias_axes
    self.activation = activations.get(activation)
    self.kernel_initializer = initializers.get(kernel_initializer)
    self.bias_initializer = initializers.get(bias_initializer)
    self.kernel_regularizer = regularizers.get(kernel_regularizer)
    self.bias_regularizer = regularizers.get(bias_regularizer)
    self.kernel_constraint = constraints.get(kernel_constraint)
    self.bias_constraint = constraints.get(bias_constraint)

  def build(self, input_shape):
    input_shape = tensor_shape.TensorShape(input_shape)
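    # Resolve the concrete kernel shape, the bias shape (None when bias_axes
    # is unset), and the fully-specified output shape from the equation and
    # the known input dimensions.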
    shape_data = _analyze_einsum_string(self.equation,
                                        self.bias_axes,
                                        input_shape,
                                        self.partial_output_shape)
    kernel_shape, bias_shape, self.full_output_shape = shape_data
    self.kernel = self.add_weight(
        "kernel",
        shape=kernel_shape,
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)

    if bias_shape is not None:
      self.bias = self.add_weight(
          "bias",
          shape=bias_shape,
          initializer=self.bias_initializer,
          regularizer=self.bias_regularizer,
          constraint=self.bias_constraint,
          dtype=self.dtype,
          trainable=True)
    else:
      self.bias = None
    super(EinsumDense, self).build(input_shape)

  def compute_output_shape(self, _):
    return tensor_shape.TensorShape(self.full_output_shape)

  def get_config(self):
    config = {
        "output_shape":
            self.partial_output_shape,
        "equation":
            self.equation,
        "activation":
            activations.serialize(self.activation),
        "bias_axes":
            self.bias_axes,
        "kernel_initializer":
            initializers.serialize(self.kernel_initializer),
        "bias_initializer":
            initializers.serialize(self.bias_initializer),
        "kernel_regularizer":
            regularizers.serialize(self.kernel_regularizer),
        "bias_regularizer":
            regularizers.serialize(self.bias_regularizer),
        "activity_regularizer":
            regularizers.serialize(self.activity_regularizer),
        "kernel_constraint":
            constraints.serialize(self.kernel_constraint),
        "bias_constraint":
            constraints.serialize(self.bias_constraint),
    }
    base_config = super(EinsumDense, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs):
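    # Einsum the inputs against the kernel, then apply the optional bias and
    # activation.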
    ret = special_math_ops.einsum(self.equation, inputs, self.kernel)
    if self.bias is not None:
      ret += self.bias
    if self.activation is not None:
      ret = self.activation(ret)
    return ret


def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape):
  """Analyzes an einsum string to determine the required weight shape."""

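  # As a concrete sketch: for equation "ab,bc->ac" with input shape
  # (None, 32) and output_shape [64], the analysis yields a kernel shape of
  # [32, 64], a bias shape of [64] when bias_axes="c", and a full output
  # shape of [None, 64]. Substituting the sentinel "0" for the ellipsis
  # lets the specs be matched with simple alphabetic regexes below.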
  dot_replaced_string = re.sub(r"\.\.\.", "0", equation)

  # This is the case where no ellipses are present in the string.
  split_string = re.match("([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)",
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(split_string, bias_axes, input_shape,
                                 output_shape)

  # This is the case where ellipses are present on the left.
  split_string = re.match("0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)",
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(
        split_string, bias_axes, input_shape, output_shape, left_elided=True)

  # This is the case where ellipses are present on the right.
  split_string = re.match("([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0",
                          dot_replaced_string)
  if split_string:
    return _analyze_split_string(split_string, bias_axes, input_shape,
                                 output_shape)

  raise ValueError(
      "Invalid einsum equation '%s'. Equations must be in the form "
      "[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...." % equation)


def _analyze_split_string(split_string,
                          bias_axes,
                          input_shape,
                          output_shape,
                          left_elided=False):
  """Analyzes a pre-split einsum string to find the weight shape."""
  input_spec = split_string.group(1)
  weight_spec = split_string.group(2)
  output_spec = split_string.group(3)
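  # `elided` counts the input dimensions covered by the ellipsis (zero when
  # the equation contains no ellipsis).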
  elided = len(input_shape) - len(input_spec)

  if isinstance(output_shape, int):
    output_shape = [output_shape]
  else:
    output_shape = list(output_shape)

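  # The first input dimension (the batch dimension) always carries through
  # to the output.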
  output_shape.insert(0, input_shape[0])

  if elided > 0 and left_elided:
    for i in range(1, elided):
      # We already inserted the 0th input dimension at dim 0, so we need to
      # start at location 1 here.
      output_shape.insert(1, input_shape[i])
  elif elided > 0 and not left_elided:
    for i in range(len(input_shape) - elided, len(input_shape)):
      output_shape.append(input_shape[i])

  if left_elided:
    # If we have beginning dimensions elided, we need to use negative indexing
    # to determine where in the input dimension our values are.
    input_dim_map = {
        dim: (i + elided) - len(input_shape) for i, dim in enumerate(input_spec)
    }
    # Because we've constructed the full output shape already, we don't need
    # to do negative indexing.
    output_dim_map = {dim: (i + elided) for i, dim in enumerate(output_spec)}
  else:
    input_dim_map = {dim: i for i, dim in enumerate(input_spec)}
    output_dim_map = {dim: i for i, dim in enumerate(output_spec)}

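  # Any dimension shared between the input and output specs must agree in
  # size; an output size of None means "infer from the input". Looking the
  # input position up through input_dim_map keeps the indexing correct when
  # leading dimensions are elided.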
  for dim in input_spec:
    input_shape_at_dim = input_shape[input_dim_map[dim]]
    if dim in output_dim_map:
      output_shape_at_dim = output_shape[output_dim_map[dim]]
      if (output_shape_at_dim is not None and
          output_shape_at_dim != input_shape_at_dim):
        raise ValueError(
            "Input shape and output shape do not match at shared "
            "dimension '%s'. Input shape is %s, and output shape "
            "is %s." % (dim, input_shape_at_dim, output_shape_at_dim))

  for dim in output_spec:
    if dim not in input_spec and dim not in weight_spec:
      raise ValueError("Dimension '%s' was specified in the output '%s' but "
                       "has no corresponding dim in the input spec '%s' or "
                       "weight spec '%s'." % (dim, output_spec, input_spec,
                                              weight_spec))

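  # Each weight dimension takes its size from the matching input or output
  # dimension; a weight axis that matches neither cannot be sized.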
  weight_shape = []
  for dim in weight_spec:
    if dim in input_dim_map:
      weight_shape.append(input_shape[input_dim_map[dim]])
    elif dim in output_dim_map:
      weight_shape.append(output_shape[output_dim_map[dim]])
    else:
      raise ValueError("Weight dimension '%s' did not have a match in either "
                       "the input spec '%s' or the output spec '%s'. For this "
                       "layer, the weight must be fully specified." %
                       (dim, input_spec, output_spec))

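  # Build a bias shape that broadcasts against the output: axes named in
  # bias_axes keep their full size, and the other output axes they span get
  # size 1.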
  if bias_axes is not None:
    num_left_elided = elided if left_elided else 0
    idx_map = {
        char: output_shape[i + num_left_elided]
        for i, char in enumerate(output_spec)
    }

    for char in bias_axes:
      if char not in output_spec:
        raise ValueError("Bias dimension '%s' was requested, but is not a part "
                         "of the output specification '%s'" %
                         (char, output_spec))

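    # The bias only needs to span from the first biased output axis onward;
    # leading output axes broadcast implicitly.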
    first_bias_location = min(output_spec.find(char) for char in bias_axes)
    bias_output_spec = output_spec[first_bias_location:]

    bias_shape = [
        idx_map[char] if char in bias_axes else 1 for char in bias_output_spec
    ]

    if not left_elided:
      for _ in range(elided):
        bias_shape.append(1)
  else:
    bias_shape = None

  return weight_shape, bias_shape, output_shape
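

# A minimal usage sketch (illustrative only; assumes a standard TF 2.x build
# where this layer is exported as tf.keras.layers.experimental.EinsumDense):
#
#   import tensorflow as tf
#
#   layer = tf.keras.layers.experimental.EinsumDense(
#       "ab,bc->ac", output_shape=64, bias_axes="c")
#   outputs = layer(tf.ones([4, 32]))  # kernel (32, 64) -> outputs (4, 64)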