1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Arithmetic Operations that don't fit into math_ops due to dependencies.
16
17To avoid circular dependencies, some math_ops should go here.
18"""
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import re
25
26from six.moves import xrange  # pylint: disable=redefined-builtin
27
28from tensorflow.python.framework import ops
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import math_ops
31from tensorflow.python.platform import tf_logging as logging
32from tensorflow.python.util import deprecation
33from tensorflow.python.util.tf_export import tf_export
34
35
36# TODO(b/27419586) Change docstring for required dtype of x once int allowed
37@tf_export('math.lbeta', v1=['math.lbeta', 'lbeta'])
38@deprecation.deprecated_endpoints('lbeta')
39def lbeta(x, name=None):
40  r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.
41
42  Given one-dimensional `z = [z_0,...,z_{K-1}]`, we define
43
44  $$Beta(z) = \prod_j Gamma(z_j) / Gamma(\sum_j z_j)$$
45
46  And for `n + 1` dimensional `x` with shape `[N1, ..., Nn, K]`, we define
47  $$lbeta(x)[i1, ..., in] = Log(|Beta(x[i1, ..., in, :])|)$$.
48
49  In other words, the last dimension is treated as the `z` vector.
50
51  Note that if `z = [u, v]`, then
52  \\(Beta(z) = int_0^1 t^{u-1} (1 - t)^{v-1} dt\\), which defines the
53  traditional bivariate beta function.
54
55  If the last dimension is empty, we follow the convention that the sum over
56  the empty set is zero, and the product is one.
57
58  Args:
59    x: A rank `n + 1` `Tensor`, `n >= 0` with type `float`, or `double`.
60    name: A name for the operation (optional).
61
62  Returns:
63    The logarithm of \\(|Beta(x)|\\) reducing along the last dimension.
64  """
65  # In the event that the last dimension has zero entries, we return -inf.
66  # This is consistent with a convention that the sum over the empty set 0, and
67  # the product is 1.
68  # This is standard.  See https://en.wikipedia.org/wiki/Empty_set.
69  with ops.name_scope(name, 'lbeta', [x]):
70    x = ops.convert_to_tensor(x, name='x')
71
72    # Note reduce_sum([]) = 0.
73    log_prod_gamma_x = math_ops.reduce_sum(math_ops.lgamma(x), axis=[-1])
74
75    # Note lgamma(0) = infinity, so if x = []
76    # log_gamma_sum_x = lgamma(0) = infinity, and
77    # log_prod_gamma_x = lgamma(1) = 0,
78    # so result = -infinity
79    sum_x = math_ops.reduce_sum(x, axis=[-1])
80    log_gamma_sum_x = math_ops.lgamma(sum_x)
81    result = log_prod_gamma_x - log_gamma_sum_x
82
83    return result
84
85
86@tf_export('math.bessel_i0')
87def bessel_i0(x, name=None):
88  """Computes the Bessel i0 function of `x` element-wise.
89
90  Modified Bessel function of order 0.
91
92  It is preferable to use the numerically stabler function `i0e(x)` instead.
93
94  Args:
95    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
96      `float32`, `float64`.
97    name: A name for the operation (optional).
98
99  Returns:
100    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
101
102  @compatibility(scipy)
103  Equivalent to scipy.special.i0
104  @end_compatibility
105  """
106  with ops.name_scope(name, 'bessel_i0', [x]):
107    return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i0e(x)
108
109
110@tf_export('math.bessel_i1')
111def bessel_i1(x, name=None):
112  """Computes the Bessel i1 function of `x` element-wise.
113
114  Modified Bessel function of order 1.
115
116  It is preferable to use the numerically stabler function `i1e(x)` instead.
117
118  Args:
119    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
120      `float32`, `float64`.
121    name: A name for the operation (optional).
122
123  Returns:
124    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
125
126  @compatibility(scipy)
127  Equivalent to scipy.special.i1
128  @end_compatibility
129  """
130  with ops.name_scope(name, 'bessel_i1', [x]):
131    return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i1e(x)
132
133
134@tf_export('einsum', 'linalg.einsum')
135def einsum(equation, *inputs, **kwargs):
136  """A generalized contraction between tensors of arbitrary dimension.
137
138  This function returns a tensor whose elements are defined by `equation`,
139  which is written in a shorthand form inspired by the Einstein summation
140  convention.  As an example, consider multiplying two matrices
141  A and B to form a matrix C.  The elements of C are given by:
142
143  ```
144    C[i,k] = sum_j A[i,j] * B[j,k]
145  ```
146
147  The corresponding `equation` is:
148
149  ```
150    ij,jk->ik
151  ```
152
153  In general, the `equation` is obtained from the more familiar element-wise
154  equation by
155    1. removing variable names, brackets, and commas,
156    2. replacing "*" with ",",
157    3. dropping summation signs, and
158    4. moving the output to the right, and replacing "=" with "->".
159
160  Many common operations can be expressed in this way.  For example:
161
162  ```python
163  # Matrix multiplication
164  >>> einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]
165
166  # Dot product
167  >>> einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]
168
169  # Outer product
170  >>> einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]
171
172  # Transpose
173  >>> einsum('ij->ji', m)  # output[j,i] = m[i,j]
174
175  # Trace
176  >>> einsum('ii', m)  # output[j,i] = trace(m) = sum_i m[i, i]
177
178  # Batch matrix multiplication
179  >>> einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
180  ```
181
182  This function behaves like `numpy.einsum`, but does not support:
183
184  * Ellipses (subscripts like `ij...,jk...->ik...`)
185  * Subscripts where an axis appears more than once for a single input
186    (e.g. `ijj,k->ik`) unless it is a trace (e.g. `ijji`).
187
188  Args:
189    equation: a `str` describing the contraction, in the same format as
190      `numpy.einsum`.
191    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
192      be consistent with `equation`.
193    name: A name for the operation (optional).
194
195  Returns:
196    The contracted `Tensor`, with shape determined by `equation`.
197
198  Raises:
199    ValueError: If
200      - the format of `equation` is incorrect,
201      - the number of inputs implied by `equation` does not match `len(inputs)`,
202      - an axis appears in the output subscripts but not in any of the inputs,
203      - the number of dimensions of an input differs from the number of
204        indices in its subscript, or
205      - the input shapes are inconsistent along a particular axis.
206  """
207  equation = equation.replace(' ', '')
208
209  name = kwargs.pop('name', None)
210  if kwargs:
211    raise TypeError('invalid keyword arguments for this function: ' + ', '.join(
212        [format(key) for key in sorted(list(kwargs.keys()))]))
213  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
214    if '...' in equation:
215      raise ValueError('Subscripts with ellipses are not yet supported.')
216
217    match = re.match('^([a-zA-Z,]+)(->[a-zA-Z]*)?$', equation)
218    if not match:
219      raise ValueError('Indices have incorrect format: %s' % equation)
220
221    inputs = list(inputs)
222    input_axis_labels = match.group(1).split(',')
223    if len(inputs) != len(input_axis_labels):
224      raise ValueError('Got %d arguments for equation "%s", expecting %d' %
225                       (len(inputs), equation, len(input_axis_labels)))
226
227    axis_labels = set(''.join(input_axis_labels))
228    if match.group(2):
229      output_axis_labels = match.group(2)[2:]
230    else:
231      # infer the output subscripts if not given, assume alphabetical order
232      indices = ''.join(sorted(axis_labels))
233      counts = {ax: 0 for ax in indices}
234      for axes_ in input_axis_labels:
235        for ax in axes_:
236          counts[ax] += 1
237
238      output_axis_labels = ''.join(
239          sorted(ax for ax in indices if counts[ax] == 1))
240    for a in axis_labels:
241      for input_labels in input_axis_labels:
242        if (len(input_axis_labels) == 1 and input_labels.count(a) == 2 and
243            input_labels == input_labels[::-1] and '->' not in equation):
244          return math_ops.trace(inputs[0])
245        if input_labels.count(a) > 1:
246          raise ValueError(
247              'Subscript not supported: an axis appears more than once: %s' %
248              input_labels)
249    for a in axis_labels:
250      input_count = sum(1 for s in input_axis_labels if a in s)
251      if input_count > 2 and a not in output_axis_labels:
252        logging.warn(
253            'Falling back to exponential-space implementation of einsum()'
254            ' because index "%s" is summed over more than two inputs.', a)
255        return _exponential_space_einsum(equation, *inputs)
256
257    temp = inputs[0]
258    temp_axis_labels = input_axis_labels[0]
259    for i in xrange(len(inputs) - 1):
260      axes_to_sum = (
261          set(temp_axis_labels) &
262          set(input_axis_labels[i + 1]) - set(output_axis_labels))
263      temp, temp_axis_labels = _einsum_reduction(
264          temp, temp_axis_labels, inputs[i + 1], input_axis_labels[i + 1],
265          axes_to_sum)
266
267
268    missing_indices = set(temp_axis_labels) - set(output_axis_labels)
269    if missing_indices:
270      axis = [
271          i for i, a in enumerate(temp_axis_labels)
272          if a not in output_axis_labels
273      ]
274      temp = math_ops.reduce_sum(temp, axis=axis)
275      temp_axis_labels = ''.join(
276          a for a in temp_axis_labels if a in output_axis_labels)
277    if sorted(temp_axis_labels) != sorted(output_axis_labels):
278      raise ValueError('Invalid equation: %s' % equation)
279
280    perm = [temp_axis_labels.index(a) for a in output_axis_labels]
281    return _transpose_if_necessary(temp, perm)
282
283
284def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
285  """Helper for einsum() that computes the result of a two-argument einsum().
286
287  Args:
288    t0: a `Tensor`
289    t0_axis_labels: a string of axis labels.  This string's length must equal
290      the rank of t0.
291    t1: a `Tensor`
292    t1_axis_labels: a string to axis labels.  This string's length must equal
293      the rank of t1.
294    axes_to_sum: set of labels of axes to be summed over
295
296  Returns:
297    A `Tensor` whose elements are obtained by summing, over all axes in
298    `axes_to_sum`, the corresponding elements of `t0` and `t1`.
299
300    For example, if t0_axis_labels == 'abijk', t1_axis_labels == 'acjkl', and
301    axes_to_sum == {j,k}, this will return a tensor x where
302
303      out[a,b,c,i,l] = sum_j sum_k t0[a,b,i,j,k] * t1[a,c,j,k,l]
304
305  Raises:
306    ValueError: if the rank of `t0` does not match the length of
307      `t0_axis_labels`, or that of `t1` does not match the length of
308      `t1_axis_labels`.
309  """
310  if len(t0_axis_labels) != len(t0.get_shape()):
311    raise ValueError(
312        'Tensor t0 of rank %d does not match einsum reduction of length %d' %
313        (len(t0.get_shape()), len(t0_axis_labels)))
314  if len(t1_axis_labels) != len(t1.get_shape()):
315    raise ValueError(
316        'Tensor t1 of rank %d does not match einsum reduction of length %d' %
317        (len(t1.get_shape()), len(t1_axis_labels)))
318
319  # This function computes the result of a two-argument einsum() using batch
320  # matrix multiplication.  This involves
321  # 1. transposing t0 and t1 so that axes are in the correct order for
322  #    batch matrix multiplication, and
323  # 2. reshaping t0 and t1 so that they are both of rank 3.
324
325  # First, we divide axes into three groups:
326  #  * "preserved" axes are present in both inputs and the output
327  #  * "summed" axes are present in both inputs but not the output
328  #  * "broadcast" axes are present in exactly one input and the output
329  #
330  # As an example, if the einsum is abijk,acjkl->abcil, then "a" is a
331  # preserved axis, "b" and "c" are broadcast axes, and "j" and "k" are
332  # summed axes.
333  assert all(a in t0_axis_labels and a in t1_axis_labels for a in axes_to_sum)
334  preserved_axes = (set(t0_axis_labels) & set(t1_axis_labels)) - axes_to_sum
335  broadcast_axes = {}
336  for i, sym_list in enumerate([t0_axis_labels, t1_axis_labels]):
337    broadcast_axes[i] = set(sym_list) - preserved_axes - axes_to_sum
338
339  # Reorder the axes so that:
340  # 1. preserved axes come first in both inputs
341  # 2. in input 0, broadcast axes come next, followed by summed axes
342  # 3. in input 1, summed axes come next, followed by broadcast axes
343  def sort_key(input_index, a):
344    if a in preserved_axes:
345      return (-1, a)
346    elif ((input_index == 0 and a in broadcast_axes[0]) or
347          (input_index == 1 and a in axes_to_sum)):
348      return (0, a)
349    else:
350      return (1, a)
351
352  axis_labels = [t0_axis_labels, t1_axis_labels]
353  sorted_axes = [
354      sorted(sym_list, key=lambda a: sort_key(i, a))
355      for i, sym_list in enumerate(axis_labels)
356  ]
357  inputs = [t0, t1]
358  for i, axes_str in enumerate(axis_labels):
359    perm = [axes_str.find(a) for a in sorted_axes[i]]
360    inputs[i] = _transpose_if_necessary(inputs[i], perm)
361  t0, t1 = inputs
362
363  if not axes_to_sum:
364    # In the special case where there are no axes to sum over, reduce to mul()
365    # rather than to batch matrix multiplication.
366    for _ in broadcast_axes[1]:
367      t0 = array_ops.expand_dims(t0, -1)
368    for _ in broadcast_axes[0]:
369      t1 = array_ops.expand_dims(t1, len(preserved_axes))
370    product = math_ops.multiply(t0, t1)
371    product_axes = sorted_axes[0] + sorted_axes[1][len(preserved_axes):]
372    return product, ''.join(product_axes)
373  else:
374    # Reduce to matmul().
375
376    # Reshape both inputs so as to combine multiple broadcast axes
377    # into a single axis, and combine multiple summed axes into a
378    # single axis.
379
380    t0_shape = _get_shape(t0)
381    num_broadcast_elements_t0 = _total_size(
382        t0_shape[len(preserved_axes):-len(axes_to_sum)])
383    num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):])
384    new_shape = (
385        t0_shape[:len(preserved_axes)] +
386        [num_broadcast_elements_t0, num_summed_elements])
387    t0 = _reshape_if_necessary(t0, new_shape)
388
389    t1_shape = _get_shape(t1)
390    num_broadcast_elements_t1 = _total_size(
391        t1_shape[len(preserved_axes) + len(axes_to_sum):])
392    new_shape = (
393        t1_shape[:len(preserved_axes)] +
394        [num_summed_elements, num_broadcast_elements_t1])
395    t1 = _reshape_if_necessary(t1, new_shape)
396
397    product = math_ops.matmul(t0, t1)
398
399    # Undo compaction of broadcast axes
400    uncompacted_shape = (
401        t0_shape[:len(preserved_axes) + len(broadcast_axes[0])] +
402        t1_shape[len(t1_shape) - len(broadcast_axes[1]):])
403    product = _reshape_if_necessary(product, uncompacted_shape)
404
405    product_axes = (
406        sorted_axes[0][:len(preserved_axes) + len(broadcast_axes[0])] +
407        sorted_axes[1][len(sorted_axes[1]) - len(broadcast_axes[1]):])
408
409    return product, ''.join(product_axes)
410
411
412def _transpose_if_necessary(tensor, perm):
413  """Like transpose(), but avoids creating a new tensor if possible."""
414  if perm != range(len(perm)):
415    return array_ops.transpose(tensor, perm=perm)
416  else:
417    return tensor
418
419
420def _reshape_if_necessary(tensor, new_shape):
421  """Like reshape(), but avoids creating a new tensor if possible."""
422  # Accept None as an alias for -1 in new_shape.
423  new_shape = tuple(-1 if x is None else x for x in new_shape)
424  cur_shape = tuple(x.value for x in tensor.get_shape().dims)
425  if (len(new_shape) == len(cur_shape) and
426      all(d0 == d1 or d1 == -1 for d0, d1 in zip(cur_shape, new_shape))):
427    return tensor
428  else:
429    return array_ops.reshape(tensor, new_shape)
430
431
432def _get_shape(tensor):
433  """Like get_shape().as_list(), but explicitly queries the shape of a tensor
434  if necessary to ensure that the returned value contains no unknown value."""
435
436  shape = tensor.get_shape().as_list()
437  none_indices = [i for i, d in enumerate(shape) if d is None]
438  if none_indices:
439    # Query the shape if shape contains None values
440    shape_tensor = array_ops.shape(tensor)
441    for i in none_indices:
442      shape[i] = shape_tensor[i]
443  return shape
444
445
446def _total_size(shape_values):
447  """Given list of tensor shape values, returns total size.
448  If shape_values contains tensor values (which are results of
449  array_ops.shape), then it returns a scalar tensor.
450  If not, it returns an integer."""
451
452  result = 1
453  for val in shape_values:
454    result *= val
455  return result
456
457
458def _exponential_space_einsum(equation, *inputs):
459  """Fallback implementation that supports summing an index over > 2 inputs."""
460  if '...' in equation:
461    raise ValueError('Subscripts with ellipses are not yet supported.')
462
463  match = re.match('^([a-zA-Z,]+)(->[a-zA-Z]*)?$', equation)
464  if not match:
465    raise ValueError('Indices have incorrect format: %s' % equation)
466
467  inputs = list(inputs)
468  idx_in = match.group(1).split(',')
469  idx_all = set(''.join(idx_in))
470  indices = ''.join(sorted(idx_all))
471
472  if match.group(2):
473    idx_out = match.group(2)[2:]
474
475  else:
476    # infer the output subscripts if not given, assume alphabetical order
477    counts = {ax: 0 for ax in indices}
478    for axes_ in idx_in:
479      for ax in axes_:
480        counts[ax] += 1
481
482    idx_out = ''.join(sorted(ax for ax in indices if counts[ax] == 1))
483
484  if len(idx_in) != len(inputs):
485    raise ValueError('Expected %d inputs but got %d' % (len(idx_in),
486                                                        len(inputs)))
487
488  missing_idx = set(idx_out).difference(idx_all)
489  if missing_idx:
490    raise ValueError('Unknown output axes: %s' % missing_idx)
491
492  axis_order = {}
493  for ax in indices:
494    if ax not in idx_out:
495      axis_order[ax] = len(axis_order)
496  for ax in idx_out:
497    axis_order[ax] = len(axis_order)
498
499  # transpose inputs so axes are in order
500  for i, (input_, axes_) in enumerate(zip(inputs, idx_in)):
501    if input_.get_shape().ndims != len(axes_):
502      raise ValueError(
503          'Input %d with axes %s has incorrect' \
504          ' number of dimensions (expected %d, got %d)' % (
505              i, axes_, len(axes_), input_.get_shape().ndims
506          )
507      )
508
509    sorted_idx = sorted(axes_, key=axis_order.get)
510
511    if len(set(axes_)) != len(axes_):
512      raise ValueError(
513          'Subscript not supported: an axis appears more than once: %s' % axes_)
514
515    if list(axes_) != sorted_idx:
516      permuted = [axes_.find(ax) for ax in sorted_idx]
517      inputs[i] = array_ops.transpose(input_, permuted)
518      idx_in[i] = sorted_idx
519
520  reduction_idx = []
521  shapes = [[dim if dim else -1
522             for dim in tensor.get_shape().as_list()]
523            for tensor in inputs]
524
525  # validate shapes for broadcasting
526  for j, ax in enumerate(sorted(idx_all, key=axis_order.get)):
527    dims = []
528    for i, idx in enumerate(idx_in):
529      if ax not in idx:
530        shapes[i].insert(j, 1)
531      else:
532        dim = shapes[i][j]
533        if isinstance(dim, int) and dim > 1:
534          dims.append(dim)
535
536    if len(set(dims)) > 1:
537      raise ValueError('Dimension mismatch on axis: %s' % ax)
538
539    if ax not in idx_out:
540      reduction_idx.append(j)
541
542  # reshape, multiply
543  expanded_inputs = [
544      array_ops.reshape(input_, shape) for input_, shape in zip(inputs, shapes)
545  ]
546  expanded_output = 1
547  for input_ in expanded_inputs:
548    expanded_output *= input_
549
550  # contract
551  return math_ops.reduce_sum(expanded_output, reduction_idx)
552