1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Keras-based einsum dense layer.""" 16# pylint: disable=g-classes-have-attributes 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import re 22 23from tensorflow.python.framework import tensor_shape 24from tensorflow.python.keras import activations 25from tensorflow.python.keras import constraints 26from tensorflow.python.keras import initializers 27from tensorflow.python.keras import regularizers 28from tensorflow.python.keras.engine.base_layer import Layer 29from tensorflow.python.ops import special_math_ops 30from tensorflow.python.util.tf_export import keras_export 31 32 33@keras_export("keras.layers.experimental.EinsumDense") 34class EinsumDense(Layer): 35 """A layer that uses tf.einsum as the backing computation. 36 37 This layer can perform einsum calculations of arbitrary dimensionality. 38 39 Args: 40 equation: An equation describing the einsum to perform. This equation must 41 be a valid einsum string of the form `ab,bc->ac`, `...ab,bc->...ac`, or 42 `ab...,bc->ac...` where 'ab', 'bc', and 'ac' can be any valid einsum axis 43 expression sequence. 44 output_shape: The expected shape of the output tensor (excluding the batch 45 dimension and any dimensions represented by ellipses). You can specify 46 None for any dimension that is unknown or can be inferred from the input 47 shape. 48 activation: Activation function to use. If you don't specify anything, no 49 activation is applied (that is, a "linear" activation: `a(x) = x`). 50 bias_axes: A string containing the output dimension(s) to apply a bias to. 51 Each character in the `bias_axes` string should correspond to a character 52 in the output portion of the `equation` string. 53 kernel_initializer: Initializer for the `kernel` weights matrix. 54 bias_initializer: Initializer for the bias vector. 55 kernel_regularizer: Regularizer function applied to the `kernel` weights 56 matrix. 57 bias_regularizer: Regularizer function applied to the bias vector. 58 activity_regularizer: Regularizer function applied to the output of the 59 layer (its "activation").. 60 kernel_constraint: Constraint function applied to the `kernel` weights 61 matrix. 62 bias_constraint: Constraint function applied to the bias vector. 63 64 Examples: 65 66 **Biased dense layer with einsums** 67 68 This example shows how to instantiate a standard Keras dense layer using 69 einsum operations. This example is equivalent to 70 `tf.keras.layers.Dense(64, use_bias=True)`. 71 72 >>> layer = EinsumDense("ab,bc->ac", output_shape=64, bias_axes="c") 73 >>> input_tensor = tf.keras.Input(shape=[32]) 74 >>> output_tensor = layer(input_tensor) 75 >>> output_tensor 76 <... shape=(None, 64) dtype=...> 77 78 **Applying a dense layer to a sequence** 79 80 This example shows how to instantiate a layer that applies the same dense 81 operation to every element in a sequence. Here, the 'output_shape' has two 82 values (since there are two non-batch dimensions in the output); the first 83 dimension in the output_shape is `None`, because the sequence dimension `b` 84 has an unknown shape. 85 86 >>> layer = EinsumDense("abc,cd->abd", 87 ... output_shape=(None, 64), 88 ... bias_axes="d") 89 >>> input_tensor = tf.keras.Input(shape=[32, 128]) 90 >>> output_tensor = layer(input_tensor) 91 >>> output_tensor 92 <... shape=(None, 32, 64) dtype=...> 93 94 **Applying a dense layer to a sequence using ellipses** 95 96 This example shows how to instantiate a layer that applies the same dense 97 operation to every element in a sequence, but uses the ellipsis notation 98 instead of specifying the batch and sequence dimensions. 99 100 Because we are using ellipsis notation and have specified only one axis, the 101 output_shape arg is a single value. When instantiated in this way, the layer 102 can handle any number of sequence dimensions - including the case where no 103 sequence dimension exists. 104 105 >>> layer = EinsumDense("...x,xy->...y", output_shape=64, bias_axes="y") 106 >>> input_tensor = tf.keras.Input(shape=[32, 128]) 107 >>> output_tensor = layer(input_tensor) 108 >>> output_tensor 109 <... shape=(None, 32, 64) dtype=...> 110 """ 111 112 def __init__(self, 113 equation, 114 output_shape, 115 activation=None, 116 bias_axes=None, 117 kernel_initializer="glorot_uniform", 118 bias_initializer="zeros", 119 kernel_regularizer=None, 120 bias_regularizer=None, 121 activity_regularizer=None, 122 kernel_constraint=None, 123 bias_constraint=None, 124 **kwargs): 125 super(EinsumDense, self).__init__(**kwargs) 126 self.equation = equation 127 if isinstance(output_shape, int): 128 self.partial_output_shape = [output_shape] 129 else: 130 self.partial_output_shape = list(output_shape) 131 self.bias_axes = bias_axes 132 self.activation = activations.get(activation) 133 self.kernel_initializer = initializers.get(kernel_initializer) 134 self.bias_initializer = initializers.get(bias_initializer) 135 self.kernel_regularizer = regularizers.get(kernel_regularizer) 136 self.bias_regularizer = regularizers.get(bias_regularizer) 137 self.kernel_constraint = constraints.get(kernel_constraint) 138 self.bias_constraint = constraints.get(bias_constraint) 139 140 def build(self, input_shape): 141 input_shape = tensor_shape.TensorShape(input_shape) 142 shape_data = _analyze_einsum_string(self.equation, 143 self.bias_axes, 144 input_shape, 145 self.partial_output_shape) 146 kernel_shape, bias_shape, self.full_output_shape = shape_data 147 self.kernel = self.add_weight( 148 "kernel", 149 shape=kernel_shape, 150 initializer=self.kernel_initializer, 151 regularizer=self.kernel_regularizer, 152 constraint=self.kernel_constraint, 153 dtype=self.dtype, 154 trainable=True) 155 156 if bias_shape is not None: 157 self.bias = self.add_weight( 158 "bias", 159 shape=bias_shape, 160 initializer=self.bias_initializer, 161 regularizer=self.bias_regularizer, 162 constraint=self.bias_constraint, 163 dtype=self.dtype, 164 trainable=True) 165 else: 166 self.bias = None 167 super(EinsumDense, self).build(input_shape) 168 169 def compute_output_shape(self, _): 170 return tensor_shape.TensorShape(self.full_output_shape) 171 172 def get_config(self): 173 config = { 174 "output_shape": 175 self.partial_output_shape, 176 "equation": 177 self.equation, 178 "activation": 179 activations.serialize(self.activation), 180 "bias_axes": 181 self.bias_axes, 182 "kernel_initializer": 183 initializers.serialize(self.kernel_initializer), 184 "bias_initializer": 185 initializers.serialize(self.bias_initializer), 186 "kernel_regularizer": 187 regularizers.serialize(self.kernel_regularizer), 188 "bias_regularizer": 189 regularizers.serialize(self.bias_regularizer), 190 "activity_regularizer": 191 regularizers.serialize(self.activity_regularizer), 192 "kernel_constraint": 193 constraints.serialize(self.kernel_constraint), 194 "bias_constraint": 195 constraints.serialize(self.bias_constraint), 196 } 197 base_config = super(EinsumDense, self).get_config() 198 return dict(list(base_config.items()) + list(config.items())) 199 200 def call(self, inputs): 201 ret = special_math_ops.einsum(self.equation, inputs, self.kernel) 202 if self.bias is not None: 203 ret += self.bias 204 if self.activation is not None: 205 ret = self.activation(ret) 206 return ret 207 208 209def _analyze_einsum_string(equation, bias_axes, input_shape, output_shape): 210 """Analyzes an einsum string to determine the required weight shape.""" 211 212 dot_replaced_string = re.sub(r"\.\.\.", "0", equation) 213 214 # This is the case where no ellipses are present in the string. 215 split_string = re.match("([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)", 216 dot_replaced_string) 217 if split_string: 218 return _analyze_split_string(split_string, bias_axes, input_shape, 219 output_shape) 220 221 # This is the case where ellipses are present on the left. 222 split_string = re.match("0([a-zA-Z]+),([a-zA-Z]+)->0([a-zA-Z]+)", 223 dot_replaced_string) 224 if split_string: 225 return _analyze_split_string( 226 split_string, bias_axes, input_shape, output_shape, left_elided=True) 227 228 # This is the case where ellipses are present on the right. 229 split_string = re.match("([a-zA-Z]{2,})0,([a-zA-Z]+)->([a-zA-Z]+)0", 230 dot_replaced_string) 231 if split_string: 232 return _analyze_split_string(split_string, bias_axes, input_shape, 233 output_shape) 234 235 raise ValueError( 236 "Invalid einsum equation '%s'. Equations must be in the form " 237 "[X],[Y]->[Z], ...[X],[Y]->...[Z], or [X]...,[Y]->[Z]...." % equation) 238 239 240def _analyze_split_string(split_string, 241 bias_axes, 242 input_shape, 243 output_shape, 244 left_elided=False): 245 """Analyze an pre-split einsum string to find the weight shape.""" 246 input_spec = split_string.group(1) 247 weight_spec = split_string.group(2) 248 output_spec = split_string.group(3) 249 elided = len(input_shape) - len(input_spec) 250 251 if isinstance(output_shape, int): 252 output_shape = [output_shape] 253 else: 254 output_shape = list(output_shape) 255 256 output_shape.insert(0, input_shape[0]) 257 258 if elided > 0 and left_elided: 259 for i in range(1, elided): 260 # We already inserted the 0th input dimension at dim 0, so we need to 261 # start at location 1 here. 262 output_shape.insert(1, input_shape[i]) 263 elif elided > 0 and not left_elided: 264 for i in range(len(input_shape) - elided, len(input_shape)): 265 output_shape.append(input_shape[i]) 266 267 if left_elided: 268 # If we have beginning dimensions elided, we need to use negative indexing 269 # to determine where in the input dimension our values are. 270 input_dim_map = { 271 dim: (i + elided) - len(input_shape) for i, dim in enumerate(input_spec) 272 } 273 # Because we've constructed the full output shape already, we don't need 274 # to do negative indexing. 275 output_dim_map = {dim: (i + elided) for i, dim in enumerate(output_spec)} 276 else: 277 input_dim_map = {dim: i for i, dim in enumerate(input_spec)} 278 output_dim_map = {dim: i for i, dim in enumerate(output_spec)} 279 280 for i, dim in enumerate(input_spec): 281 input_shape_at_dim = input_shape[i] 282 if dim in output_dim_map: 283 output_shape_at_dim = output_shape[output_dim_map[dim]] 284 if (output_shape_at_dim is not None and 285 output_shape_at_dim != input_shape_at_dim): 286 raise ValueError( 287 "Input shape and output shape do not match at shared " 288 "dimension '%s'. Input shape is %s, and output shape " 289 "is %s." % 290 (dim, input_shape_at_dim, output_shape[output_dim_map[dim]])) 291 292 for dim in output_spec: 293 if dim not in input_spec and dim not in weight_spec: 294 raise ValueError("Dimension '%s' was specified in the output '%s' but " 295 "has no corresponding dim in the input spec '%s' or " 296 "weight spec '%s.'" % (dim, output_spec, input_spec, 297 output_spec)) 298 299 weight_shape = [] 300 for dim in weight_spec: 301 if dim in input_dim_map: 302 weight_shape.append(input_shape[input_dim_map[dim]]) 303 elif dim in output_dim_map: 304 weight_shape.append(output_shape[output_dim_map[dim]]) 305 else: 306 raise ValueError("Weight dimension '%s' did not have a match in either " 307 "the input spec '%s' or the output spec '%s'. For this " 308 "layer, the weight must be fully specified." % 309 (dim, input_spec, output_spec)) 310 311 if bias_axes is not None: 312 num_left_elided = elided if left_elided else 0 313 idx_map = { 314 char: output_shape[i + num_left_elided] 315 for i, char in enumerate(output_spec) 316 } 317 318 for char in bias_axes: 319 if char not in output_spec: 320 raise ValueError("Bias dimension '%s' was requested, but is not a part " 321 "of the output specification '%s'" % 322 (char, output_spec)) 323 324 first_bias_location = min([output_spec.find(char) for char in bias_axes]) 325 bias_output_spec = output_spec[first_bias_location:] 326 327 bias_shape = [ 328 idx_map[char] if char in bias_axes else 1 for char in bias_output_spec 329 ] 330 331 if not left_elided: 332 for _ in range(elided): 333 bias_shape.append(1) 334 else: 335 bias_shape = None 336 337 return weight_shape, bias_shape, output_shape 338