1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Arithmetic Operations that don't fit into math_ops due to dependencies.
16
17To avoid circular dependencies, some math_ops should go here.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import collections
25import functools
26import re
27import string
28
29import numpy as np
30import opt_einsum
31import six
32
33from six.moves import xrange  # pylint: disable=redefined-builtin
34
35from tensorflow.compiler.tf2xla.ops import gen_xla_ops
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import tensor_shape
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import control_flow_ops
40from tensorflow.python.ops import gen_linalg_ops
41from tensorflow.python.ops import gen_special_math_ops
42from tensorflow.python.ops import math_ops
43from tensorflow.python.platform import tf_logging as logging
44from tensorflow.python.util import deprecation
45from tensorflow.python.util import dispatch
46from tensorflow.python.util.tf_export import tf_export
47
48
49# TODO(b/27419586) Change docstring for required dtype of x once int allowed
@tf_export('math.lbeta', v1=['math.lbeta', 'lbeta'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('lbeta')
def lbeta(x, name=None):
  r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.

  For a one-dimensional $z = [z_1,...,z_K]$, the beta function is defined as

  $$Beta(z) = \frac{\prod_j \Gamma(z_j)}{\Gamma(\sum_j z_j)},$$

  with $\Gamma$ denoting the gamma function.

  For an $n + 1$ dimensional $x$ of shape $[N_1, ..., N_n, K]$, the result is

  $$lbeta(x)[i_1, ..., i_n] = \log{|Beta(x[i_1, ..., i_n, :])|}.$$

  That is, the last dimension plays the role of the $z$ vector.

  In the special case $z = [u, v]$,

  $$Beta(z) = \frac{\Gamma(u)\Gamma(v)}{\Gamma(u + v)}
    = \int_0^1 t^{u-1} (1 - t)^{v-1} \mathrm{d}t,$$

  which is the traditional bivariate beta function.

  When the last dimension is empty, the usual convention applies: the sum over
  the empty set is zero and the product over the empty set is one.

  Args:
    x: A rank `n + 1` `Tensor`, `n >= 0` with type `float`, or `double`.
    name: A name for the operation (optional).

  Returns:
    The logarithm of \\(|Beta(x)|\\) reducing along the last dimension.
  """
  # An empty last dimension yields -inf. This follows from the standard
  # convention that the empty sum is 0 and the empty product is 1
  # (see https://en.wikipedia.org/wiki/Empty_set).
  with ops.name_scope(name, 'lbeta', [x]):
    x = ops.convert_to_tensor(x, name='x')

    # Numerator in log space: sum_j lgamma(x_j). reduce_sum([]) = 0, which
    # corresponds to an empty product of 1.
    sum_of_log_gammas = math_ops.reduce_sum(math_ops.lgamma(x), axis=[-1])

    # Denominator in log space: lgamma(sum_j x_j). For empty x this is
    # lgamma(0) = infinity, so the difference below becomes -infinity.
    log_gamma_of_sum = math_ops.lgamma(math_ops.reduce_sum(x, axis=[-1]))

    return sum_of_log_gammas - log_gamma_of_sum
104
105
@tf_export('math.special.dawsn')
@dispatch.add_dispatch_support
def dawsn(x, name=None):
  """Evaluates Dawson's integral at every element of `x`.

  Dawson's integral is `exp(-x**2)` multiplied by the integral of `exp(t**2)`
  from `0` to `x`. It is defined for all real numbers.

  Dawson's function is odd.
  >>> tf.math.special.dawsn([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.5380795, -0.4244364, 0.4244364,  0.5380795], dtype=float32)

  This implementation is based off of the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.dawsn
  @end_compatibility
  """
  with ops.name_scope(name, 'dawsn', [x]):
    # Thin wrapper: the generated kernel does the element-wise evaluation.
    result = gen_special_math_ops.dawsn(x)
  return result
134
135
@tf_export('math.special.expint')
@dispatch.add_dispatch_support
def expint(x, name=None):
  """Evaluates the Exponential integral at every element of `x`.

  The Exponential integral is the integral of `exp(t) / t` from `-inf` to `x`.
  Its domain of definition is the positive real numbers.

  >>> tf.math.special.expint([1., 1.1, 2.1, 4.1]).numpy()
  array([ 1.8951179,  2.1673784,  5.3332353, 21.048464], dtype=float32)

  This implementation is based off of the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.expi
  @end_compatibility
  """
  with ops.name_scope(name, 'expint', [x]):
    # Thin wrapper: the generated kernel does the element-wise evaluation.
    result = gen_special_math_ops.expint(x)
  return result
163
164
@tf_export('math.special.fresnel_cos')
@dispatch.add_dispatch_support
def fresnel_cos(x, name=None):
  """Computes Fresnel's cosine integral of `x` element-wise.

  The Fresnel cosine integral is defined as the integral of
  `cos(pi * t^2 / 2)` from `0` to `x`, with the domain of definition all real
  numbers. This is the normalization used by `scipy.special.fresnel`, and the
  one the example values below correspond to.

  The Fresnel cosine integral is odd.
  >>> tf.math.special.fresnel_cos([-1., -0.1, 0.1, 1.]).numpy()
  array([-0.7798934 , -0.09999753,  0.09999753,  0.7798934 ], dtype=float32)

  This implementation is based off of the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.fresnel second output.
  @end_compatibility
  """
  with ops.name_scope(name, 'fresnel_cos', [x]):
    return gen_special_math_ops.fresnel_cos(x)
193
194
@tf_export('math.special.fresnel_sin')
@dispatch.add_dispatch_support
def fresnel_sin(x, name=None):
  """Computes Fresnel's sine integral of `x` element-wise.

  The Fresnel sine integral is defined as the integral of
  `sin(pi * t^2 / 2)` from `0` to `x`, with the domain of definition all real
  numbers. This is the normalization used by `scipy.special.fresnel`, and the
  one the example values below correspond to.

  >>> tf.math.special.fresnel_sin([-1., -0.1, 0.1, 1.]).numpy()
  array([-0.43825912, -0.00052359,  0.00052359,  0.43825912], dtype=float32)

  This implementation is based off of the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.fresnel first output.
  @end_compatibility
  """
  with ops.name_scope(name, 'fresnel_sin', [x]):
    return gen_special_math_ops.fresnel_sin(x)
222
223
@tf_export('math.special.spence')
@dispatch.add_dispatch_support
def spence(x, name=None):
  """Evaluates Spence's integral at every element of `x`.

  Spence's integral is the integral of `log(t) / (1 - t)` from `1` to `x`.
  Its domain of definition is the non-negative real numbers.

  >>> tf.math.special.spence([0.5, 1., 2., 3.]).numpy()
  array([ 0.58224034,  0.        , -0.82246685, -1.4367464], dtype=float32)

  This implementation is based off of the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.spence
  @end_compatibility
  """
  with ops.name_scope(name, 'spence', [x]):
    # Thin wrapper: the generated kernel does the element-wise evaluation.
    result = gen_special_math_ops.spence(x)
  return result
251
252
@tf_export('math.bessel_i0', 'math.special.bessel_i0')
@dispatch.add_dispatch_support
def bessel_i0(x, name=None):
  """Evaluates the Bessel i0 function at every element of `x`.

  This is the modified Bessel function of order 0.

  Prefer the numerically stabler function `i0e(x)` where possible.

  >>> tf.math.special.bessel_i0([-1., -0.5, 0.5, 1.]).numpy()
  array([1.26606588, 1.06348337, 1.06348337, 1.26606588], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i0', [x]):
    # Thin wrapper: the generated kernel does the element-wise evaluation.
    result = gen_special_math_ops.bessel_i0(x)
  return result
279
280
@tf_export('math.bessel_i0e', 'math.special.bessel_i0e')
@dispatch.add_dispatch_support
def bessel_i0e(x, name=None):
  """Evaluates the Bessel i0e function at every element of `x`.

  This is the exponentially scaled modified Bessel function of order 0, i.e.
  `exp(-abs(x)) * bessel_i0(x)` (the scipy.special.i0e definition).

  >>> tf.math.special.bessel_i0e([-1., -0.5, 0.5, 1.]).numpy()
  array([0.46575961, 0.64503527, 0.64503527, 0.46575961], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i0e', [x]):
    # Thin wrapper: the generated kernel does the element-wise evaluation.
    result = gen_special_math_ops.bessel_i0e(x)
  return result
305
306
@tf_export('math.bessel_i1', 'math.special.bessel_i1')
@dispatch.add_dispatch_support
def bessel_i1(x, name=None):
  """Evaluates the Bessel i1 function at every element of `x`.

  This is the modified Bessel function of order 1.

  Prefer the numerically stabler function `i1e(x)` where possible.

  >>> tf.math.special.bessel_i1([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.5651591 , -0.25789431,  0.25789431,  0.5651591 ], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i1', [x]):
    # Thin wrapper: the generated kernel does the element-wise evaluation.
    result = gen_special_math_ops.bessel_i1(x)
  return result
333
334
@tf_export('math.bessel_i1e', 'math.special.bessel_i1e')
@dispatch.add_dispatch_support
def bessel_i1e(x, name=None):
  """Evaluates the Bessel i1e function at every element of `x`.

  This is the exponentially scaled modified Bessel function of order 1, i.e.
  `exp(-abs(x)) * bessel_i1(x)` (the scipy.special.i1e definition).

  >>> tf.math.special.bessel_i1e([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.20791042, -0.15642083,  0.15642083,  0.20791042], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i1e', [x]):
    # Thin wrapper: the generated kernel does the element-wise evaluation.
    result = gen_special_math_ops.bessel_i1e(x)
  return result
359
360
@tf_export('math.special.bessel_k0')
@dispatch.add_dispatch_support
def bessel_k0(x, name=None):
  """Evaluates the Bessel k0 function at every element of `x`.

  This is the modified Bessel function of order 0.

  Prefer the numerically stabler function `k0e(x)` where possible.

  >>> tf.math.special.bessel_k0([0.5, 1., 2., 4.]).numpy()
  array([0.92441907, 0.42102444, 0.11389387, 0.01115968], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k0', [x]):
    # Thin wrapper: the generated kernel does the element-wise evaluation.
    result = gen_special_math_ops.bessel_k0(x)
  return result
387
388
@tf_export('math.special.bessel_k0e')
@dispatch.add_dispatch_support
def bessel_k0e(x, name=None):
  """Evaluates the Bessel k0e function at every element of `x`.

  This is the exponentially scaled modified Bessel function of order 0, i.e.
  `exp(x) * bessel_k0(x)` (the scipy.special.k0e definition).

  >>> tf.math.special.bessel_k0e([0.5, 1., 2., 4.]).numpy()
  array([1.52410939, 1.14446308, 0.84156822, 0.60929767], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k0e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k0e', [x]):
    # Thin wrapper: the generated kernel does the element-wise evaluation.
    result = gen_special_math_ops.bessel_k0e(x)
  return result
413
414
@tf_export('math.special.bessel_k1')
@dispatch.add_dispatch_support
def bessel_k1(x, name=None):
  """Evaluates the Bessel k1 function at every element of `x`.

  This is the modified Bessel function of order 1.

  Prefer the numerically stabler function `k1e(x)` where possible.

  >>> tf.math.special.bessel_k1([0.5, 1., 2., 4.]).numpy()
  array([1.65644112, 0.60190723, 0.13986588, 0.0124835 ], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k1', [x]):
    # Thin wrapper: the generated kernel does the element-wise evaluation.
    result = gen_special_math_ops.bessel_k1(x)
  return result
441
442
@tf_export('math.special.bessel_k1e')
@dispatch.add_dispatch_support
def bessel_k1e(x, name=None):
  """Evaluates the Bessel k1e function at every element of `x`.

  This is the exponentially scaled modified Bessel function of order 1, i.e.
  `exp(x) * bessel_k1(x)` (the scipy.special.k1e definition).

  >>> tf.math.special.bessel_k1e([0.5, 1., 2., 4.]).numpy()
  array([2.73100971, 1.63615349, 1.03347685, 0.68157595], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k1e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k1e', [x]):
    # Thin wrapper: the generated kernel does the element-wise evaluation.
    result = gen_special_math_ops.bessel_k1e(x)
  return result
467
468
@tf_export('math.special.bessel_j0')
@dispatch.add_dispatch_support
def bessel_j0(x, name=None):
  """Computes the Bessel j0 function of `x` element-wise.

  Bessel function of the first kind of order 0. (Unlike `bessel_i0` and
  `bessel_k0`, this is not a *modified* Bessel function.)

  >>> tf.math.special.bessel_j0([0.5, 1., 2., 4.]).numpy()
  array([ 0.93846981,  0.76519769,  0.22389078, -0.39714981], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.j0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_j0', [x]):
    return gen_special_math_ops.bessel_j0(x)
493
494
@tf_export('math.special.bessel_j1')
@dispatch.add_dispatch_support
def bessel_j1(x, name=None):
  """Computes the Bessel j1 function of `x` element-wise.

  Bessel function of the first kind of order 1. (Unlike `bessel_i1` and
  `bessel_k1`, this is not a *modified* Bessel function.)

  >>> tf.math.special.bessel_j1([0.5, 1., 2., 4.]).numpy()
  array([ 0.24226846,  0.44005059,  0.57672481, -0.06604333], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.j1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_j1', [x]):
    return gen_special_math_ops.bessel_j1(x)
519
520
@tf_export('math.special.bessel_y0')
@dispatch.add_dispatch_support
def bessel_y0(x, name=None):
  """Computes the Bessel y0 function of `x` element-wise.

  Bessel function of the second kind of order 0. (Unlike `bessel_i0` and
  `bessel_k0`, this is not a *modified* Bessel function.)

  >>> tf.math.special.bessel_y0([0.5, 1., 2., 4.]).numpy()
  array([-0.44451873,  0.08825696,  0.51037567, -0.01694074], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.y0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_y0', [x]):
    return gen_special_math_ops.bessel_y0(x)
545
546
@tf_export('math.special.bessel_y1')
@dispatch.add_dispatch_support
def bessel_y1(x, name=None):
  """Computes the Bessel y1 function of `x` element-wise.

  Bessel function of the second kind of order 1. (Unlike `bessel_i1` and
  `bessel_k1`, this is not a *modified* Bessel function.)

  >>> tf.math.special.bessel_y1([0.5, 1., 2., 4.]).numpy()
  array([-1.47147239, -0.78121282, -0.10703243,  0.39792571], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.y1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_y1', [x]):
    return gen_special_math_ops.bessel_y1(x)
571
572
@ops.RegisterGradient('XlaEinsum')
def _einsum_grad(op, grad):
  """Gradient function registered for the `XlaEinsum` op.

  The gradient for each operand is itself an einsum of `grad` with the other
  operand: the forward output subscripts take the place of the operand being
  differentiated, and that operand's subscripts become the new output.

  The equation must contain an explicit `->` and exactly two comma-separated
  operand subscripts; otherwise the tuple unpacking below raises `ValueError`.

  Args:
    op: The forward operation. Its `equation` attr and its first two inputs
      are read.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A two-element list: the gradients with respect to `op.inputs[0]` and
    `op.inputs[1]`, in that order.
  """
  equation = op.get_attr('equation')
  # The attr may come back as bytes; normalize to str before splitting.
  if isinstance(equation, bytes):
    equation = equation.decode()

  operand_subscripts, out_subscripts = equation.split('->')
  lhs_subscripts, rhs_subscripts = operand_subscripts.split(',')

  lhs_grad_equation = '{},{}->{}'.format(out_subscripts, rhs_subscripts,
                                         lhs_subscripts)
  rhs_grad_equation = '{},{}->{}'.format(out_subscripts, lhs_subscripts,
                                         rhs_subscripts)

  lhs_grad = gen_xla_ops.xla_einsum(
      grad, op.inputs[1], equation=lhs_grad_equation, name=None)
  rhs_grad = gen_xla_ops.xla_einsum(
      grad, op.inputs[0], equation=rhs_grad_equation, name=None)
  return [lhs_grad, rhs_grad]
594
595
def _enclosing_tpu_context():
  """Returns the innermost enclosing `XLAControlFlowContext`, or `None`.

  Starts at the default graph's current control flow context and follows
  `outer_context` links until an `XLAControlFlowContext` is found or the chain
  ends.
  """
  graph = ops.get_default_graph()
  ctx = graph._get_control_flow_context()  # pylint: disable=protected-access
  while ctx is not None:
    if isinstance(ctx, control_flow_ops.XLAControlFlowContext):
      break
    ctx = ctx.outer_context
  return ctx
604
605
@tf_export('einsum', 'linalg.einsum')
@dispatch.add_dispatch_support
def einsum(equation, *inputs, **kwargs):
  r"""Tensor contraction over specified indices and outer product.

  Einsum allows defining Tensors by defining their element-wise computation.
  This computation is defined by `equation`, a shorthand form based on Einstein
  summation. As an example, consider multiplying two matrices A and B to form a
  matrix C.  The elements of C are given by:

  $$ C_{i,k} = \sum_j A_{i,j} B_{j,k} $$

  or

  ```
  C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The corresponding einsum `equation` is:

  ```
  ij,jk->ik
  ```

  In general, to convert the element-wise equation into the `equation` string,
  use the following procedure (intermediate strings for matrix multiplication
  example provided in parentheses):

  1. remove variable names, brackets, and commas, (`ik = sum_j ij * jk`)
  2. replace "*" with ",", (`ik = sum_j ij , jk`)
  3. drop summation signs, and (`ik = ij, jk`)
  4. move the output to the right, while replacing "=" with "->". (`ij,jk->ik`)

  Note: If the output indices are not specified, repeated indices are summed.
  So `ij,jk->ik` can be simplified to `ij,jk`.

  Many common operations can be expressed in this way.  For example:

  **Matrix multiplication**

  >>> m0 = tf.random.normal(shape=[2, 3])
  >>> m1 = tf.random.normal(shape=[3, 5])
  >>> e = tf.einsum('ij,jk->ik', m0, m1)
  >>> # output[i,k] = sum_j m0[i,j] * m1[j, k]
  >>> print(e.shape)
  (2, 5)

  Repeated indices are summed if the output indices are not specified.

  >>> e = tf.einsum('ij,jk', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]
  >>> print(e.shape)
  (2, 5)


  **Dot product**

  >>> u = tf.random.normal(shape=[5])
  >>> v = tf.random.normal(shape=[5])
  >>> e = tf.einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]
  >>> print(e.shape)
  ()

  **Outer product**

  >>> u = tf.random.normal(shape=[3])
  >>> v = tf.random.normal(shape=[5])
  >>> e = tf.einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]
  >>> print(e.shape)
  (3, 5)

  **Transpose**

  >>> m = tf.ones([2, 3])
  >>> e = tf.einsum('ij->ji', m)  # output[j,i] = m[i,j]
  >>> print(e.shape)
  (3, 2)

  **Diag**

  >>> m = tf.reshape(tf.range(9), [3,3])
  >>> diag = tf.einsum('ii->i', m)
  >>> print(diag.shape)
  (3,)

  **Trace**

  >>> # Repeated indices are summed.
  >>> trace = tf.einsum('ii', m)  # output = trace(m) = sum_i m[i, i]
  >>> assert trace == sum(diag)
  >>> print(trace.shape)
  ()

  **Batch matrix multiplication**

  >>> s = tf.random.normal(shape=[7,5,3])
  >>> t = tf.random.normal(shape=[7,3,2])
  >>> e = tf.einsum('bij,bjk->bik', s, t)
  >>> # output[b,i,k] = sum_j s[b,i,j] * t[b, j, k]
  >>> print(e.shape)
  (7, 5, 2)

  This method does not support broadcasting on named-axes. All axes with
  matching labels should have the same length. If you have length-1 axes,
  use `tf.squeeze` or `tf.reshape` to eliminate them.

  To write code that is agnostic to the number of indices in the input
  use an ellipsis. The ellipsis is a placeholder for "whatever other indices
  fit here".

  For example, to perform a NumPy-style broadcasting-batch-matrix multiplication
  where the matrix multiply acts on the last two axes of the input, use:

  >>> s = tf.random.normal(shape=[11, 7, 5, 3])
  >>> t = tf.random.normal(shape=[11, 7, 3, 2])
  >>> e =  tf.einsum('...ij,...jk->...ik', s, t)
  >>> print(e.shape)
  (11, 7, 5, 2)

  Einsum **will** broadcast over axes covered by the ellipsis.

  >>> s = tf.random.normal(shape=[11, 1, 5, 3])
  >>> t = tf.random.normal(shape=[1, 7, 3, 2])
  >>> e =  tf.einsum('...ij,...jk->...ik', s, t)
  >>> print(e.shape)
  (11, 7, 5, 2)

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    **kwargs:
      - optimize: Optimization strategy to use to find contraction path using
        opt_einsum. Must be 'greedy', 'optimal', 'branch-2', 'branch-all' or
        'auto'. (optional, default: 'greedy').
      - name: A name for the operation (optional).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    ValueError: If
      - the format of `equation` is incorrect,
      - number of inputs or their shapes are inconsistent with `equation`.
  """
  # All arguments, including the keyword options, are forwarded unchanged.
  return _einsum_v2(equation, *inputs, **kwargs)
752
753
def _einsum_v1(equation, *inputs, **kwargs):
  """Legacy einsum implementation that does not rely on EinsumOp.

  After parsing `equation`, the computation takes one of these routes:

  * a single input with an implicit output and a palindromic, twice-repeated
    label (e.g. `'ii'`) is computed with `math_ops.trace`;
  * if some label is summed over more than two inputs, the work is handed to
    `_exponential_space_einsum_v1`;
  * inside an XLA control flow context with exactly two inputs, `xla_einsum`
    is used;
  * otherwise the inputs are folded left to right with `_einsum_v1_reduction`,
    any leftover non-output axes are summed out, and the result is transposed
    into the requested output order.

  Args:
    equation: The einsum equation string.
    *inputs: The tensors to contract.
    **kwargs: Only `name` (optional op name) is accepted.

  Returns:
    The contracted tensor.

  Raises:
    TypeError: If a keyword argument other than `name` is passed.
    ValueError: If a label is repeated within one input (other than the trace
      case above), or if the resolved axes cannot produce the requested output.
  """
  scope_name = kwargs.pop('name', None)
  if kwargs:
    unexpected = ', '.join(sorted(kwargs))
    raise TypeError(
        'invalid keyword arguments for this function: ' + unexpected)

  with ops.name_scope(scope_name, 'einsum', [equation, inputs]):
    operands = list(inputs)
    in_labels, out_labels = _einsum_v1_parse_and_resolve_equation(
        equation, [t.shape for t in operands])

    all_labels = set(''.join(in_labels) + out_labels)
    single_operand = len(in_labels) == 1
    implicit_output = '->' not in equation

    # Repeated labels within one operand are only supported for the trace.
    for label in all_labels:
      for labels in in_labels:
        occurrences = labels.count(label)
        if (single_operand and occurrences == 2 and
            labels == labels[::-1] and implicit_output):
          return math_ops.trace(operands[0])
        if occurrences > 1:
          raise ValueError(
              'Subscript not supported: an axis appears more than once: %s' %
              labels)

    # Pairwise reduction cannot sum a label shared by three or more operands.
    for label in all_labels:
      operands_using_label = sum(1 for labels in in_labels if label in labels)
      if operands_using_label > 2 and label not in out_labels:
        logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', label)
        return _exponential_space_einsum_v1(equation, *operands)

    # Use xla_einsum if executing on TPU and if the operation is a 2 input
    # einsum supported by XlaEinsumOp.
    if _enclosing_tpu_context() is not None and len(operands) == 2:
      xla_equation = in_labels[0] + ',' + in_labels[1] + '->' + out_labels
      return gen_xla_ops.xla_einsum(operands[0], operands[1], xla_equation)

    # Fold the operands left to right, contracting the labels that the running
    # result shares with the next operand and that the output does not keep.
    kept = set(out_labels)
    acc, acc_labels = operands[0], in_labels[0]
    for next_operand, next_labels in zip(operands[1:], in_labels[1:]):
      contracted = (set(acc_labels) & set(next_labels)) - kept
      acc, acc_labels = _einsum_v1_reduction(acc, acc_labels, next_operand,
                                             next_labels, contracted)

    # Sum out whatever axes remain that the output does not mention.
    leftover_axes = [
        pos for pos, label in enumerate(acc_labels) if label not in out_labels
    ]
    if leftover_axes:
      acc = math_ops.reduce_sum(acc, axis=leftover_axes)
      acc_labels = ''.join(
          label for label in acc_labels if label in out_labels)

    if sorted(acc_labels) != sorted(out_labels):
      raise ValueError('Invalid equation: %s' % equation)

    perm = [acc_labels.index(label) for label in out_labels]
    return _transpose_if_necessary(acc, perm)
816
817
def _einsum_v1_parse_and_resolve_equation(equation, input_shapes):
  """Helper for einsum() that splits/resolves inputs & outputs.

  Args:
    equation: Equation string given as argument to einsum(). Spaces are
      ignored.
    input_shapes: List of the shapes of all inputs given to einsum(). A shape
      is only inspected (through its `ndims` attribute) when the corresponding
      input subscript contains an ellipsis.

  Returns:
    input_axis_labels, output_axis_labels where:
      input_axis_labels: List of length len(input_shapes) of strings
      representing the character label for each dimension of each given input,
      resolving any broadcast (...) axes,
    output_axis_labels: A string of character labels for each axes of output
      tensor, filling in missing output subscripts and broadcast axes.

  Raises:
    ValueError: If equation is in the incorrect format (including a period
      that is not part of an ellipsis), incorrect number of inputs given or
      broadcast axes "..." or output axes could not be resolved.
  """
  equation = equation.replace(' ', '')
  match = re.match('^([a-zA-Z,.]+)(->[a-zA-Z.]*)?$', equation)
  if not match:
    raise ValueError('Indices have incorrect format: %s' % equation)

  input_axis_labels = match.group(1).split(',')
  output_axis_labels = match.group(2)[2:] if match.group(2) else None

  if len(input_shapes) != len(input_axis_labels):
    raise ValueError('Got %d arguments for equation "%s", expecting %d' %
                     (len(input_shapes), equation, len(input_axis_labels)))

  # Resolve Ellipsis
  # Assign axes labels for unspecified dimensions in inputs. Labels taken
  # from unused labels. Follow numpy einsum broadcasting conventions for
  # tensors of different length and unlabeled output.
  ellipsis_axes = ''
  if '...' in equation:
    # Computed once, before any subscript is rewritten below.
    used = set(''.join(input_axis_labels))
    unused = ''.join(c for c in string.ascii_letters if c not in used)
    for i, ax in enumerate(input_axis_labels):
      if '...' not in ax:
        continue
      parts = ax.split('...')
      if len(parts) != 2:
        raise ValueError('Unable to resolve ellipsis. Excess number found.')
      if input_shapes[i].ndims is None:
        raise ValueError('Unable to statically infer ellipsis axes.')
      # Number of dimensions the ellipsis has to stand for.
      n = input_shapes[i].ndims - len(''.join(parts))
      if n < 0:
        raise ValueError('Ellipses lengths do not match.')
      if len(unused) < n:
        raise ValueError(
            'Unable to resolve ellipsis, too many distinct labels.')
      # Always take labels from the tail of `unused`, so a shorter ellipsis
      # is a suffix of a longer one and broadcast axes align on the right.
      replace_axes = unused[-n:] if n > 0 else ''
      input_axis_labels[i] = ax.replace('...', replace_axes)
      if len(replace_axes) > len(ellipsis_axes):
        ellipsis_axes = replace_axes

    if output_axis_labels is not None:
      output_axis_labels = output_axis_labels.replace('...', ellipsis_axes)

  # Any period left at this point is not part of an ellipsis. This is checked
  # even when the equation contains no ellipsis at all, so that e.g. 'i.j->ij'
  # is rejected here instead of '.' being treated as an axis label.
  if any('.' in ax for ax in input_axis_labels):
    raise ValueError('period "." found outside of ellipsis')
  if output_axis_labels is not None and '.' in output_axis_labels:
    raise ValueError('period "." found outside of ellipsis')

  if output_axis_labels is None:
    # infer the output subscripts if not given, assume alphabetical order,
    # but always place ellipsis axes before given.
    counts = collections.Counter(
        ax for labels in input_axis_labels for ax in labels
        if ax not in ellipsis_axes)
    # A label that occurs exactly once across all inputs is kept; labels that
    # occur more than once are summed over.
    output_axis_labels = ellipsis_axes + ''.join(
        sorted(ax for ax, count in counts.items() if count == 1))

  return input_axis_labels, output_axis_labels
899
900
def _einsum_v1_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
  """Contracts two tensors for einsum() using transpose, reshape and matmul.

  Args:
    t0: a `Tensor`.
    t0_axis_labels: a string of axis labels, one character per axis of `t0`.
    t1: a `Tensor`.
    t1_axis_labels: a string of axis labels, one character per axis of `t1`.
    axes_to_sum: set of labels of the axes to be summed over.  Every label in
      it has to occur in both `t0_axis_labels` and `t1_axis_labels`.

  Returns:
    A pair `(product, product_axis_labels)`.  `product` is a `Tensor` whose
    elements are obtained by summing, over all axes in `axes_to_sum`, the
    products of the corresponding elements of `t0` and `t1`.
    `product_axis_labels` is a string naming the axes of `product`: the labels
    shared by both inputs (and not summed) first, then the labels found only
    in `t0`, then the labels found only in `t1`, each group in sorted order.

    For example, if t0_axis_labels == 'abijk', t1_axis_labels == 'acjkl', and
    axes_to_sum == {j,k}, the returned labels are 'abicl' and

      product[a,b,i,c,l] = sum_j sum_k t0[a,b,i,j,k] * t1[a,c,j,k,l]

  Raises:
    ValueError: if the rank of `t0` does not match the length of
      `t0_axis_labels`, or that of `t1` does not match the length of
      `t1_axis_labels`.
  """
  if len(t0_axis_labels) != len(t0.shape):
    t0_error = (
        'Tensor t0 of rank %d does not match einsum reduction of length %d')
    raise ValueError(t0_error % (len(t0.shape), len(t0_axis_labels)))
  if len(t1_axis_labels) != len(t1.shape):
    t1_error = (
        'Tensor t1 of rank %d does not match einsum reduction of length %d')
    raise ValueError(t1_error % (len(t1.shape), len(t1_axis_labels)))

  # The contraction is expressed as a batch matrix multiplication, so every
  # axis is put into one of three groups:
  #  * "preserved" axes occur in both inputs and in the output (batch axes),
  #  * "summed" axes occur in both inputs but not in the output,
  #  * "broadcast" axes occur in exactly one input and in the output.
  #
  # For abijk,acjkl->abcil: "a" is preserved, "b", "i", "c" and "l" are
  # broadcast, and "j" and "k" are summed.
  assert not any(
      a not in t0_axis_labels or a not in t1_axis_labels for a in axes_to_sum)
  preserved_axes = (set(t0_axis_labels) & set(t1_axis_labels)) - axes_to_sum
  t0_only_axes = set(t0_axis_labels) - preserved_axes - axes_to_sum
  t1_only_axes = set(t1_axis_labels) - preserved_axes - axes_to_sum
  num_preserved = len(preserved_axes)

  # Target layouts:
  #   t0: preserved axes, then broadcast axes, then summed axes
  #   t1: preserved axes, then summed axes, then broadcast axes
  # Ties inside a group are broken by the label itself.
  def t0_key(label):
    if label in preserved_axes:
      return (-1, label)
    return (0, label) if label in t0_only_axes else (1, label)

  def t1_key(label):
    if label in preserved_axes:
      return (-1, label)
    return (0, label) if label in axes_to_sum else (1, label)

  t0_sorted = sorted(t0_axis_labels, key=t0_key)
  t1_sorted = sorted(t1_axis_labels, key=t1_key)
  t0 = _transpose_if_necessary(
      t0, [t0_axis_labels.find(label) for label in t0_sorted])
  t1 = _transpose_if_necessary(
      t1, [t1_axis_labels.find(label) for label in t1_sorted])

  if not axes_to_sum:
    # Nothing is summed: an elementwise multiply with broadcasting suffices.
    # t0 gets trailing unit axes for t1's broadcast axes, and t1 gets unit
    # axes right after the preserved axes for t0's broadcast axes.
    for _ in t1_only_axes:
      t0 = array_ops.expand_dims(t0, -1)
    for _ in t0_only_axes:
      t1 = array_ops.expand_dims(t1, num_preserved)
    product = math_ops.multiply(t0, t1)
    return product, ''.join(t0_sorted + t1_sorted[num_preserved:])

  # Otherwise reduce to matmul(): fold all broadcast axes of an input into a
  # single axis and all summed axes into a single axis, so that both inputs
  # look like [preserved..., rows, cols].
  num_summed = len(axes_to_sum)

  t0_shape = _get_shape(t0)
  t0_rows = _total_size(t0_shape[num_preserved:-num_summed])
  summed_size = _total_size(t0_shape[-num_summed:])
  t0 = _reshape_if_necessary(
      t0, t0_shape[:num_preserved] + [t0_rows, summed_size])

  t1_shape = _get_shape(t1)
  t1_cols = _total_size(t1_shape[num_preserved + num_summed:])
  t1 = _reshape_if_necessary(
      t1, t1_shape[:num_preserved] + [summed_size, t1_cols])

  product = math_ops.matmul(t0, t1)

  # Unfold the two folded broadcast axes back into their original axes.
  num_leading = num_preserved + len(t0_only_axes)
  product = _reshape_if_necessary(
      product,
      t0_shape[:num_leading] + t1_shape[len(t1_shape) - len(t1_only_axes):])

  product_axes = (
      t0_sorted[:num_leading] +
      t1_sorted[len(t1_sorted) - len(t1_only_axes):])
  return product, ''.join(product_axes)
1027
1028
def _transpose_if_necessary(tensor, perm):
  """Like transpose(), but avoids creating a new tensor if possible.

  Args:
    tensor: the value to transpose.
    perm: a sequence of ints (list, tuple, range, ...) giving the requested
      order of the axes.

  Returns:
    `tensor` itself when `perm` is the identity permutation, otherwise the
    result of `array_ops.transpose(tensor, perm=perm)`.
  """
  # Normalize to a list first: a tuple or range never compares equal to a
  # list, which would make an identity permutation look like a real one.
  perm = list(perm)
  if perm == list(range(len(perm))):
    return tensor
  return array_ops.transpose(tensor, perm=perm)
1035
1036
def _reshape_if_necessary(tensor, new_shape):
  """Like reshape(), but avoids creating a new tensor if possible.

  Args:
    tensor: the value to reshape; its static dimensions are read from
      `tensor.shape.dims`.
    new_shape: the requested dimensions.  An entry may be an int, `None`
      (treated as -1) or a `Tensor`.

  Returns:
    `tensor` itself when the rank is unchanged and every requested dimension
    is a non-`Tensor` value that is either -1 or equal to the current static
    dimension; otherwise the result of `array_ops.reshape`.
  """
  # None is accepted as an alias for -1.
  target = tuple(-1 if dim is None else dim for dim in new_shape)
  current = tuple(dim.value for dim in tensor.shape.dims)

  if len(target) != len(current):
    return array_ops.reshape(tensor, target)
  for have, want in zip(current, target):
    # A dynamic (Tensor) dimension can never be proven equal statically.
    if isinstance(want, ops.Tensor) or not (have == want or want == -1):
      return array_ops.reshape(tensor, target)
  return tensor
1048
1049
def _get_shape(tensor):
  """Returns the shape of `tensor` as a list without unknown entries.

  Like get_shape().as_list(), except that every dimension that is statically
  unknown (`None`) is replaced by the matching element of
  `array_ops.shape(tensor)`.  The shape op is only created when at least one
  dimension is unknown.
  """
  dims = tensor.shape.as_list()
  unknown_axes = [axis for axis, size in enumerate(dims) if size is None]
  if not unknown_axes:
    return dims

  # Fill the gaps with values queried from the tensor at run time.
  dynamic_shape = array_ops.shape(tensor)
  for axis in unknown_axes:
    dims[axis] = dynamic_shape[axis]
  return dims
1062
1063
def _total_size(shape_values):
  """Returns the product of the given shape values.

  The product starts at the integer 1, so an empty `shape_values` gives 1.
  If `shape_values` contains tensor values (which are results of
  array_ops.shape), the result is a scalar tensor; otherwise it is an integer.
  """
  return functools.reduce(lambda total, dim: total * dim, shape_values, 1)
1074
1075
def _exponential_space_einsum_v1(equation, *inputs):
  """Fallback implementation that supports summing an index over > 2 inputs.

  Every input is transposed into one common axis order and reshaped so that
  the axes it lacks have size 1.  The inputs are then multiplied together
  with broadcasting and the product is summed over all axes that are absent
  from the output.

  Args:
    equation: the einsum equation string; it is parsed by
      `_einsum_v1_parse_and_resolve_equation`.
    *inputs: the tensors to contract, one per input term of `equation`.

  Returns:
    The result of `math_ops.reduce_sum` over the broadcast product of all
    inputs, reduced along the axes that do not appear in the output.

  Raises:
    ValueError: if the output names an axis that no input has, if the rank of
      an input differs from the number of its axis labels, if a label occurs
      more than once in one input, or if inputs disagree on the static size
      (> 1) of a shared axis.
  """
  inputs = list(inputs)
  input_shapes = [x.shape for x in inputs]
  idx_in, idx_out = _einsum_v1_parse_and_resolve_equation(
      equation, input_shapes)

  # Output labels must be checked against the *input* labels only; a set
  # that also contains the output labels can never report anything missing.
  idx_all = set(''.join(idx_in))
  missing_idx = set(idx_out).difference(idx_all)
  if missing_idx:
    raise ValueError('Unknown output axes: %s' % missing_idx)
  indices = ''.join(sorted(idx_all))

  # Common axis order: reduced axes first (alphabetically), then the output
  # axes in the order the output lists them.
  axis_order = {}
  for ax in indices:
    if ax not in idx_out:
      axis_order[ax] = len(axis_order)
  for ax in idx_out:
    axis_order[ax] = len(axis_order)

  # transpose inputs so axes are in order
  for i, (input_, axes_) in enumerate(zip(inputs, idx_in)):
    if input_.shape.ndims != len(axes_):
      # `ndims` is formatted with %s because it is None for unknown rank.
      raise ValueError(
          'Input %d with axes %s has incorrect'
          ' number of dimensions (expected %d, got %s)' % (
              i, axes_, len(axes_), input_.shape.ndims))

    sorted_idx = sorted(axes_, key=axis_order.get)

    if len(set(axes_)) != len(axes_):
      raise ValueError(
          'Subscript not supported: an axis appears more than once: %s' % axes_)

    if list(axes_) != sorted_idx:
      permuted = [axes_.find(ax) for ax in sorted_idx]
      inputs[i] = array_ops.transpose(input_, permuted)
      idx_in[i] = sorted_idx

  reduction_idx = []
  # Only statically unknown dimensions become -1; a known size of 0 must be
  # kept as 0 so that empty tensors are reshaped correctly.
  shapes = [[-1 if dim is None else dim
             for dim in tensor.shape.as_list()]
            for tensor in inputs]

  # validate shapes for broadcasting
  for j, ax in enumerate(sorted(idx_all, key=axis_order.get)):
    dims = []
    for i, idx in enumerate(idx_in):
      if ax not in idx:
        # This input lacks the axis: give it a unit dimension to broadcast.
        shapes[i].insert(j, 1)
      else:
        dim = shapes[i][j]
        if isinstance(dim, int) and dim > 1:
          dims.append(dim)

    if len(set(dims)) > 1:
      raise ValueError('Dimension mismatch on axis: %s' % ax)

    if ax not in idx_out:
      reduction_idx.append(j)

  # reshape, multiply
  expanded_inputs = [
      array_ops.reshape(input_, shape) for input_, shape in zip(inputs, shapes)
  ]
  expanded_output = 1
  for input_ in expanded_inputs:
    expanded_output *= input_

  # contract
  return math_ops.reduce_sum(expanded_output, reduction_idx)
1150
1151
def _einsum_v2(equation, *inputs, **kwargs):
  """Implementation of einsum utilizing opt_einsum and EinsumOp.

  Args:
    equation: the einsum equation string.
    *inputs: the operands to contract.
    **kwargs: only `name` (name scope for the created ops, default None) and
      `optimize` (forwarded to `_get_opt_einsum_contract_path`, default
      'greedy') are accepted.

  Returns:
    The result of `gen_linalg_ops.einsum`.  With more than two operands the
    contraction is carried out as a chain of einsum ops following the
    contraction path, and the result of the last one is returned.

  Raises:
    TypeError: if an unsupported keyword argument is given.
  """
  name = kwargs.pop('name', None)
  optimize = kwargs.pop('optimize', 'greedy')
  if kwargs:
    raise TypeError(
        'Invalid keyword arguments for einsum: {}'.format(', '.join(kwargs)))

  with ops.name_scope(name, 'einsum', [equation, inputs]):
    operands = list(inputs)

    # Static shapes as plain lists; None stands for a TensorShape that
    # evaluates as false.
    static_shapes = []
    for operand in operands:
      operand_shape = operand.shape
      if not isinstance(operand_shape, tensor_shape.TensorShape):
        static_shapes.append(list(operand_shape))
      elif operand_shape:
        static_shapes.append(operand_shape.as_list())
      else:
        static_shapes.append(None)

    # Validate and sanitize the equation and resolve static input shapes, as
    # opt_einsum requires that all shapes be a tuple of positive integers.
    # Ellipses are replaced by a placeholder label as well: opt_einsum would
    # otherwise expand them into named labels, and broadcasting between
    # different shapes or ranks wouldn't work (e.g. [1, 1, 2] wouldn't
    # broadcast with [3, 1]).
    resolved_equation, resolved_input_shapes, ellipsis_label = (
        _einsum_v2_parse_and_resolve_equation(equation, static_shapes))

    def restore_ellipsis(subscripts):
      # Put back the ellipses that were swapped out for opt_einsum.
      if ellipsis_label:
        return subscripts.replace(ellipsis_label, '...')
      return subscripts

    if len(operands) <= 2:  # No need to call opt_einsum.
      return gen_linalg_ops.einsum(operands,
                                   restore_ellipsis(resolved_equation))

    # opt_einsum cannot handle unknown dimensions, so it is given the fully
    # specified resolved shapes.  Rather than building Tensors or NumPy
    # arrays of those shapes, hand it dummy objects that only expose a
    # `shape` attribute.
    shaped = collections.namedtuple('shaped', ['shape'])
    shaped_inputs = tuple(
        shaped(tuple(dims)) for dims in resolved_input_shapes)

    # opt_einsum breaks an n-ary einsum into n-1 binary einsums.  Each step
    # names the operands it consumes and the equation to apply to them; its
    # result is appended to the list of remaining operands.
    contraction_steps = _get_opt_einsum_contract_path(
        resolved_equation, shaped_inputs, optimize)
    for operand_indices, binary_equation in contraction_steps:
      binary_equation = restore_ellipsis(binary_equation)
      step_operands = [operands.pop(index) for index in operand_indices]
      operands.append(gen_linalg_ops.einsum(step_operands, binary_equation))
    return operands[0]
1201
1202
def _get_opt_einsum_contract_path(equation, shaped_inputs_tuple, optimize):
  """Returns the (memoized) result of opt_einsum.contract_path.

  Args:
    equation: the einsum equation handed to opt_einsum.
    shaped_inputs_tuple: tuple of objects exposing a `shape` attribute, one
      per operand of `equation`.
    optimize: the `optimize` argument forwarded to opt_einsum.

  Returns:
    A tuple with one pair per contraction step, built from elements 0 and 2
    of each entry of opt_einsum's contraction list (the caller treats them as
    the operand indices and the equation of that step).
  """
  # einsum_call=True is an internal opt_einsum api.  It yields the
  # contraction path without opt_einsum performing the contractions itself.
  _, contraction_list = opt_einsum.contract_path(
      equation,
      *shaped_inputs_tuple,
      optimize=optimize,
      einsum_call=True,
      use_blas=True)
  # A tuple is returned so that the cached value cannot be mutated.
  return tuple((step[0], step[2]) for step in contraction_list)


# opt_einsum.contract_path can be expensive, so its results are memoized with
# lru_cache, which only the Python3+ standard library provides.
if not six.PY2:
  _get_opt_einsum_contract_path = functools.lru_cache(maxsize=128)(
      _get_opt_einsum_contract_path)
1224
1225
def _einsum_v2_parse_and_resolve_equation(equation, input_shapes):
  """Helper which validates einsum equation and resolves input shapes.

  Args:
    equation: the einsum equation string.  Spaces are ignored and '...' marks
      an ellipsis.
    input_shapes: a sequence with one entry per input.  An entry is either a
      sequence of dimension sizes (with `None` for an unknown size) or `None`,
      in which case the input is skipped during shape resolution.  The
      entries are not modified.

  Returns:
    A tuple `(resolved_equation, resolved_shapes, ellipsis_label)`:
      resolved_equation: `equation` without spaces, with every '...' replaced
        by `ellipsis_label` and with an explicit '->output' part appended if
        the equation had none.
      resolved_shapes: `None` if there are at most two inputs.  Otherwise a
        list holding, per input, one size for each label of its subscript
        (an ellipsis counts as a single label whose size is the product of
        the known, non-zero sizes it covers).  Sizes that cannot be inferred
        are 1.
      ellipsis_label: '0' if the equation contains '...', else `None`.

  Raises:
    ValueError: if the equation is malformed, if the number of inputs does
      not match the equation, if the output subscripts are invalid, or (with
      more than two inputs) if a subscript does not fit the rank of its
      input.
  """
  resolved_equation = equation.replace(' ', '')
  ellipsis_label = None
  if '...' in equation:
    # Replace ellipsis ('...') with '0' for (a) ease of parsing and (b) to
    # prevent opt_einsum from resolving them into named labels; as it doesn't
    # support broadcasting.
    ellipsis_label = '0'
    if ellipsis_label in resolved_equation:
      raise ValueError('Invalid character "0" in equation: {}'.format(equation))
    resolved_equation = resolved_equation.replace('...', ellipsis_label)

  # Only letters (plus the ellipsis placeholder) are allowed as labels. This
  # also rejects periods (`.`) outside of ellipses. It is not a hard
  # requirement; except we use a special character '0' for ellipsis.
  allowed_labels = 'a-zA-Z'
  if ellipsis_label:
    allowed_labels += ellipsis_label
  match = re.match('^([{0},]*)(->[{0}]*)?$'.format(allowed_labels),
                   resolved_equation)
  if not match:
    raise ValueError(
        'Subscripts have incorrect format: {}'.format(resolved_equation))
  input_labels = match.group(1).split(',')
  output_labels = match.group(2)[2:] if match.group(2) else None

  if len(input_shapes) != len(input_labels):
    raise ValueError('Got {} inputs for equation "{}", expecting {}'.format(
        len(input_shapes), equation, len(input_labels)))

  # Special case: if there are no '->', then we create output subscripts from
  # labels appearing only once.
  if '->' not in resolved_equation:
    label_counts = collections.Counter(match.group(1))
    output_labels = ''.join(
        x for x in sorted(label_counts)
        if x != ',' and label_counts[x] == 1)
    resolved_equation += '->' + output_labels

  # Validate output_labels. The ellipsis check comes first: a repeated
  # ellipsis is also a repeated label, so the generic duplicate check below
  # would otherwise always hide this more specific message.
  if ellipsis_label and output_labels.count(ellipsis_label) > 1:
    raise ValueError(
        'Output subscripts contain multiple ellipsis: {}'.format(equation))
  if len(set(output_labels)) != len(output_labels):
    raise ValueError(
        'Output subscripts contain a label appearing more than once: {}'.format(
            equation))
  input_label_set = set(match.group(1))
  for label in output_labels:
    if label != ellipsis_label and label not in input_label_set:
      raise ValueError('Output subscripts contain the label {} not present '
                       'in the input subscripts.'.format(label))

  # Early return if <= 2 inputs. Resolved shapes are not needed.
  if len(input_shapes) <= 2:
    return resolved_equation, None, ellipsis_label

  # Create a map from axis labels to known dimensions. This is used to infer
  # unknown dimensions if a known dimension also has the same label. Labels
  # that are never seen with a known size default to 1.
  label_to_dim = collections.defaultdict(lambda: 1)
  for i, (labels, shape) in enumerate(zip(input_labels, input_shapes)):
    if shape is None:
      continue
    # Work on a copy: the ellipsis dimensions are collapsed in place below and
    # the caller's shape must stay intact (this also admits tuples).
    shape = list(shape)
    ellipsis_start = labels.find(ellipsis_label) if ellipsis_label else -1
    if ellipsis_start != -1:  # This input contains an ellipsis.
      if ellipsis_start != labels.rfind(ellipsis_label):
        raise ValueError('Too many ellipsis')
      if len(labels) > len(shape) + 1:
        raise ValueError('Too many named labels in {}th subscript string of'
                         ' equation {} for input shape {} '.format(
                             i, equation, shape))
      # Collapse all dimensions covered by the ellipsis into a single one so
      # that `shape` lines up with `labels`. filter(None, ...) drops unknown
      # (None) as well as zero sizes, which keeps the product positive.
      ellipsis_end = ellipsis_start + len(shape) + 1 - len(labels)
      shape[ellipsis_start:ellipsis_end] = ([
          np.prod(
              list(filter(None, shape[ellipsis_start:ellipsis_end])),
              dtype=np.int64)
      ])
    else:
      # This input does not contain an ellipsis.
      if len(labels) != len(shape):
        raise ValueError(
            'Number of named labels in input #{} of equation {} '
            'must be equal to the number of dimensions in shape {}'.format(
                i, equation, shape))
    for dim, label in zip(shape, labels):
      if dim is not None:
        label_to_dim[label] = max(label_to_dim[label], dim)

  resolved_shapes = [
      [label_to_dim[label] for label in labels] for labels in input_labels
  ]
  return resolved_equation, resolved_shapes, ellipsis_label
1321