# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Arithmetic operations that don't fit into math_ops due to dependencies.

To avoid circular dependencies, some math_ops should go here.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


# TODO(b/27419586) Change docstring for required dtype of x once int allowed
@tf_export('math.lbeta', v1=['math.lbeta', 'lbeta'])
@deprecation.deprecated_endpoints('lbeta')
def lbeta(x, name=None):
  r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.

  Given one-dimensional `z = [z_0,...,z_{K-1}]`, we define

  $$Beta(z) = \prod_j Gamma(z_j) / Gamma(\sum_j z_j)$$

  And for an `n + 1` dimensional `x` with shape `[N1, ..., Nn, K]`, we define

  $$lbeta(x)[i1, ..., in] = Log(|Beta(x[i1, ..., in, :])|)$$

  In other words, the last dimension is treated as the `z` vector.

  Note that if `z = [u, v]`, then
  \\(Beta(z) = \int_0^1 t^{u-1} (1 - t)^{v-1} dt\\), which defines the
  traditional bivariate beta function.

  If the last dimension is empty, we follow the convention that the sum over
  the empty set is zero, and the product is one.

  Args:
    x: A rank `n + 1` `Tensor`, `n >= 0`, with type `float` or `double`.
    name: A name for the operation (optional).

  Returns:
    The logarithm of \\(|Beta(x)|\\), reducing along the last dimension.
  """
  # In the event that the last dimension has zero entries, we return -inf.
  # This is consistent with the convention that the sum over the empty set is
  # zero and the product is one.
  # This is standard. See https://en.wikipedia.org/wiki/Empty_set.
  with ops.name_scope(name, 'lbeta', [x]):
    x = ops.convert_to_tensor(x, name='x')

    # Note reduce_sum([]) = 0.
    log_prod_gamma_x = math_ops.reduce_sum(math_ops.lgamma(x), axis=[-1])

    # Note lgamma(0) = infinity, so if x = []
    #  log_gamma_sum_x = lgamma(0) = infinity, and
    #  log_prod_gamma_x = lgamma(1) = 0,
    # so result = -infinity
    sum_x = math_ops.reduce_sum(x, axis=[-1])
    log_gamma_sum_x = math_ops.lgamma(sum_x)
    result = log_prod_gamma_x - log_gamma_sum_x

    return result

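# A quick worked example of the conventions above (an illustrative sketch, not
# part of the public API or its tests):
#
#   # Beta([1., 2.]) = Gamma(1) * Gamma(2) / Gamma(3) = 1 * 1 / 2 = 0.5,
#   # so this evaluates to log(0.5), roughly -0.6931.
#   tf.math.lbeta([1., 2.])
#
#   # With an empty last dimension, e.g. shape [2, 0], each result entry is
#   # 0 - lgamma(0) = -inf, matching the empty-sum / empty-product convention.
#   tf.math.lbeta(tf.zeros([2, 0]))
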
@tf_export('math.bessel_i0')
def bessel_i0(x, name=None):
  """Computes the Bessel i0 function of `x` element-wise.

  Modified Bessel function of order 0.

  It is preferable to use the numerically more stable function `i0e(x)`
  instead.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `half`, `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i0', [x]):
    return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i0e(x)


@tf_export('math.bessel_i1')
def bessel_i1(x, name=None):
  """Computes the Bessel i1 function of `x` element-wise.

  Modified Bessel function of order 1.

  It is preferable to use the numerically more stable function `i1e(x)`
  instead.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `half`, `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i1', [x]):
    return math_ops.exp(math_ops.abs(x)) * math_ops.bessel_i1e(x)

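# The two wrappers above simply undo the exponential scaling of the underlying
# kernels, since i0e(x) = exp(-|x|) * i0(x) and i1e(x) = exp(-|x|) * i1(x).
# A rough numerical sketch (values are assumed here only for illustration):
#
#   tf.math.bessel_i0(1.0)   # ~ scipy.special.i0(1.0)  ~ 1.2661
#   tf.math.bessel_i0e(1.0)  # ~ scipy.special.i0e(1.0) ~ 0.4658
#                            #   (= exp(-1) * 1.2661)
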
@tf_export('einsum', 'linalg.einsum')
def einsum(equation, *inputs, **kwargs):
  """A generalized contraction between tensors of arbitrary dimension.

  This function returns a tensor whose elements are defined by `equation`,
  which is written in a shorthand form inspired by the Einstein summation
  convention. As an example, consider multiplying two matrices
  A and B to form a matrix C. The elements of C are given by:

  ```
    C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The corresponding `equation` is:

  ```
    ij,jk->ik
  ```

  In general, the `equation` is obtained from the more familiar element-wise
  equation by
    1. removing variable names, brackets, and commas,
    2. replacing "*" with ",",
    3. dropping summation signs, and
    4. moving the output to the right, and replacing "=" with "->".

  Many common operations can be expressed in this way. For example:

  ```python
  # Matrix multiplication
  >>> einsum('ij,jk->ik', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]

  # Dot product
  >>> einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]

  # Outer product
  >>> einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]

  # Transpose
  >>> einsum('ij->ji', m)  # output[j,i] = m[i,j]

  # Trace
  >>> einsum('ii', m)  # output = trace(m) = sum_i m[i, i]

  # Batch matrix multiplication
  >>> einsum('aij,ajk->aik', s, t)  # out[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
  ```

  This function behaves like `numpy.einsum`, but does not support:

  * Ellipses (subscripts like `ij...,jk...->ik...`)
  * Subscripts where an axis appears more than once for a single input
    (e.g. `ijj,k->ik`) unless it is a trace (e.g. `ijji`).

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    name: A name for the operation (optional).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    ValueError: If
      - the format of `equation` is incorrect,
      - the number of inputs implied by `equation` does not match
        `len(inputs)`,
      - an axis appears in the output subscripts but not in any of the inputs,
      - the number of dimensions of an input differs from the number of
        indices in its subscript, or
      - the input shapes are inconsistent along a particular axis.
  """
  equation = equation.replace(' ', '')

  name = kwargs.pop('name', None)
  if kwargs:
    raise TypeError('invalid keyword arguments for this function: ' + ', '.join(
        [format(key) for key in sorted(list(kwargs.keys()))]))
  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    if '...' in equation:
      raise ValueError('Subscripts with ellipses are not yet supported.')

    match = re.match('^([a-zA-Z,]+)(->[a-zA-Z]*)?$', equation)
    if not match:
      raise ValueError('Indices have incorrect format: %s' % equation)

    inputs = list(inputs)
    input_axis_labels = match.group(1).split(',')
    if len(inputs) != len(input_axis_labels):
      raise ValueError('Got %d arguments for equation "%s", expecting %d' %
                       (len(inputs), equation, len(input_axis_labels)))

    axis_labels = set(''.join(input_axis_labels))
    if match.group(2):
      output_axis_labels = match.group(2)[2:]
    else:
      # Infer the output subscripts if not given; assume alphabetical order.
      indices = ''.join(sorted(axis_labels))
      counts = {ax: 0 for ax in indices}
      for axes_ in input_axis_labels:
        for ax in axes_:
          counts[ax] += 1

      output_axis_labels = ''.join(
          sorted(ax for ax in indices if counts[ax] == 1))

    for a in axis_labels:
      for input_labels in input_axis_labels:
        if (len(input_axis_labels) == 1 and input_labels.count(a) == 2 and
            input_labels == input_labels[::-1] and '->' not in equation):
          return math_ops.trace(inputs[0])
        if input_labels.count(a) > 1:
          raise ValueError(
              'Subscript not supported: an axis appears more than once: %s' %
              input_labels)

    for a in axis_labels:
      input_count = sum(1 for s in input_axis_labels if a in s)
      if input_count > 2 and a not in output_axis_labels:
        logging.warn(
            'Falling back to exponential-space implementation of einsum()'
            ' because index "%s" is summed over more than two inputs.', a)
        return _exponential_space_einsum(equation, *inputs)

    temp = inputs[0]
    temp_axis_labels = input_axis_labels[0]
    for i in xrange(len(inputs) - 1):
      axes_to_sum = (
          set(temp_axis_labels) &
          set(input_axis_labels[i + 1]) - set(output_axis_labels))
      temp, temp_axis_labels = _einsum_reduction(
          temp, temp_axis_labels, inputs[i + 1], input_axis_labels[i + 1],
          axes_to_sum)

    missing_indices = set(temp_axis_labels) - set(output_axis_labels)
    if missing_indices:
      axis = [
          i for i, a in enumerate(temp_axis_labels)
          if a not in output_axis_labels
      ]
      temp = math_ops.reduce_sum(temp, axis=axis)
      temp_axis_labels = ''.join(
          a for a in temp_axis_labels if a in output_axis_labels)
    if sorted(temp_axis_labels) != sorted(output_axis_labels):
      raise ValueError('Invalid equation: %s' % equation)

    perm = [temp_axis_labels.index(a) for a in output_axis_labels]
    return _transpose_if_necessary(temp, perm)

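# Sketch of the pairwise reduction strategy above, with hypothetical rank-2
# inputs a, b, c (illustrative only): for einsum('ij,jk,kl->il', a, b, c) the
# loop contracts one operand at a time,
#
#   temp, labels = a, 'ij'
#   temp, labels = _einsum_reduction(temp, labels, b, 'jk', {'j'})  # -> 'ik'
#   temp, labels = _einsum_reduction(temp, labels, c, 'kl', {'k'})  # -> 'il'
#
# and the final transpose is a no-op because 'il' already matches the output
# subscripts.
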
def _einsum_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
  """Helper for einsum() that computes the result of a two-argument einsum().

  Args:
    t0: a `Tensor`
    t0_axis_labels: a string of axis labels. This string's length must equal
      the rank of t0.
    t1: a `Tensor`
    t1_axis_labels: a string of axis labels. This string's length must equal
      the rank of t1.
    axes_to_sum: set of labels of axes to be summed over

  Returns:
    A `Tensor` whose elements are obtained by summing, over all axes in
    `axes_to_sum`, the products of the corresponding elements of `t0` and `t1`.

    For example, if t0_axis_labels == 'abijk', t1_axis_labels == 'acjkl', and
    axes_to_sum == {j,k}, this will return a tensor `out` where

      out[a,b,c,i,l] = sum_j sum_k t0[a,b,i,j,k] * t1[a,c,j,k,l]

  Raises:
    ValueError: if the rank of `t0` does not match the length of
      `t0_axis_labels`, or that of `t1` does not match the length of
      `t1_axis_labels`.
  """
  if len(t0_axis_labels) != len(t0.get_shape()):
    raise ValueError(
        'Tensor t0 of rank %d does not match einsum reduction of length %d' %
        (len(t0.get_shape()), len(t0_axis_labels)))
  if len(t1_axis_labels) != len(t1.get_shape()):
    raise ValueError(
        'Tensor t1 of rank %d does not match einsum reduction of length %d' %
        (len(t1.get_shape()), len(t1_axis_labels)))

  # This function computes the result of a two-argument einsum() using batch
  # matrix multiplication. This involves
  # 1. transposing t0 and t1 so that the axes are in the correct order for
  #    batch matrix multiplication, and
  # 2. reshaping t0 and t1 so that they are both of rank 3.

  # First, we divide axes into three groups:
  #  * "preserved" axes are present in both inputs and the output
  #  * "summed" axes are present in both inputs but not the output
  #  * "broadcast" axes are present in exactly one input and the output
  #
  # As an example, if the einsum is abijk,acjkl->abcil, then "a" is a
  # preserved axis, "b" and "c" are broadcast axes, and "j" and "k" are
  # summed axes.
  assert all(a in t0_axis_labels and a in t1_axis_labels for a in axes_to_sum)
  preserved_axes = (set(t0_axis_labels) & set(t1_axis_labels)) - axes_to_sum
  broadcast_axes = {}
  for i, sym_list in enumerate([t0_axis_labels, t1_axis_labels]):
    broadcast_axes[i] = set(sym_list) - preserved_axes - axes_to_sum

  # Reorder the axes so that:
  # 1. preserved axes come first in both inputs
  # 2. in input 0, broadcast axes come next, followed by summed axes
  # 3. in input 1, summed axes come next, followed by broadcast axes
  def sort_key(input_index, a):
    if a in preserved_axes:
      return (-1, a)
    elif ((input_index == 0 and a in broadcast_axes[0]) or
          (input_index == 1 and a in axes_to_sum)):
      return (0, a)
    else:
      return (1, a)

  axis_labels = [t0_axis_labels, t1_axis_labels]
  sorted_axes = [
      sorted(sym_list, key=lambda a: sort_key(i, a))
      for i, sym_list in enumerate(axis_labels)
  ]
  inputs = [t0, t1]
  for i, axes_str in enumerate(axis_labels):
    perm = [axes_str.find(a) for a in sorted_axes[i]]
    inputs[i] = _transpose_if_necessary(inputs[i], perm)
  t0, t1 = inputs

  if not axes_to_sum:
    # In the special case where there are no axes to sum over, reduce to mul()
    # rather than to batch matrix multiplication.
    for _ in broadcast_axes[1]:
      t0 = array_ops.expand_dims(t0, -1)
    for _ in broadcast_axes[0]:
      t1 = array_ops.expand_dims(t1, len(preserved_axes))
    product = math_ops.multiply(t0, t1)
    product_axes = sorted_axes[0] + sorted_axes[1][len(preserved_axes):]
    return product, ''.join(product_axes)
  else:
    # Reduce to matmul().

    # Reshape both inputs so as to combine multiple broadcast axes
    # into a single axis, and combine multiple summed axes into a
    # single axis.

    t0_shape = _get_shape(t0)
    num_broadcast_elements_t0 = _total_size(
        t0_shape[len(preserved_axes):-len(axes_to_sum)])
    num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):])
    new_shape = (
        t0_shape[:len(preserved_axes)] +
        [num_broadcast_elements_t0, num_summed_elements])
    t0 = _reshape_if_necessary(t0, new_shape)

    t1_shape = _get_shape(t1)
    num_broadcast_elements_t1 = _total_size(
        t1_shape[len(preserved_axes) + len(axes_to_sum):])
    new_shape = (
        t1_shape[:len(preserved_axes)] +
        [num_summed_elements, num_broadcast_elements_t1])
    t1 = _reshape_if_necessary(t1, new_shape)

    product = math_ops.matmul(t0, t1)

    # Undo compaction of broadcast axes
    uncompacted_shape = (
        t0_shape[:len(preserved_axes) + len(broadcast_axes[0])] +
        t1_shape[len(t1_shape) - len(broadcast_axes[1]):])
    product = _reshape_if_necessary(product, uncompacted_shape)

    product_axes = (
        sorted_axes[0][:len(preserved_axes) + len(broadcast_axes[0])] +
        sorted_axes[1][len(sorted_axes[1]) - len(broadcast_axes[1]):])

    return product, ''.join(product_axes)

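# Shape sketch for the matmul path above (dimensions are made up for
# illustration): with t0_axis_labels == 'abijk', t1_axis_labels == 'acjkl' and
# axes_to_sum == {j, k}, suppose a=2, b=3, i=4, j=5, k=6, c=7, l=8.
#
#   t0: [2, 3, 4, 5, 6], already ordered 'abijk', reshaped to [2, 12, 30]
#   t1: [2, 7, 5, 6, 8], transposed to 'ajkcl' as [2, 5, 6, 7, 8],
#       then reshaped to [2, 30, 56]
#   matmul: [2, 12, 30] x [2, 30, 56] = [2, 12, 56]
#   un-compacted broadcast axes: [2, 3, 4, 7, 8] with labels 'abicl'
#
# The caller (einsum) then transposes 'abicl' into the requested output order,
# e.g. 'abcil'.
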
def _transpose_if_necessary(tensor, perm):
  """Like transpose(), but avoids creating a new tensor if possible."""
  if perm != list(range(len(perm))):
    return array_ops.transpose(tensor, perm=perm)
  else:
    return tensor


def _reshape_if_necessary(tensor, new_shape):
  """Like reshape(), but avoids creating a new tensor if possible."""
  # Accept None as an alias for -1 in new_shape.
  new_shape = tuple(-1 if x is None else x for x in new_shape)
  cur_shape = tuple(x.value for x in tensor.get_shape().dims)
  if (len(new_shape) == len(cur_shape) and
      all(d0 == d1 or d1 == -1 for d0, d1 in zip(cur_shape, new_shape))):
    return tensor
  else:
    return array_ops.reshape(tensor, new_shape)


def _get_shape(tensor):
  """Like get_shape().as_list(), but explicitly queries the shape of a tensor
  if necessary to ensure that the returned value contains no unknown value."""

  shape = tensor.get_shape().as_list()
  none_indices = [i for i, d in enumerate(shape) if d is None]
  if none_indices:
    # Query the shape if shape contains None values
    shape_tensor = array_ops.shape(tensor)
    for i in none_indices:
      shape[i] = shape_tensor[i]
  return shape


def _total_size(shape_values):
  """Given list of tensor shape values, returns total size.

  If shape_values contains tensor values (which are results of
  array_ops.shape), then it returns a scalar tensor.
  If not, it returns an integer."""

  result = 1
  for val in shape_values:
    result *= val
  return result

def _exponential_space_einsum(equation, *inputs):
  """Fallback implementation that supports summing an index over > 2 inputs."""
  if '...' in equation:
    raise ValueError('Subscripts with ellipses are not yet supported.')

  match = re.match('^([a-zA-Z,]+)(->[a-zA-Z]*)?$', equation)
  if not match:
    raise ValueError('Indices have incorrect format: %s' % equation)

  inputs = list(inputs)
  idx_in = match.group(1).split(',')
  idx_all = set(''.join(idx_in))
  indices = ''.join(sorted(idx_all))

  if match.group(2):
    idx_out = match.group(2)[2:]
  else:
    # infer the output subscripts if not given, assume alphabetical order
    counts = {ax: 0 for ax in indices}
    for axes_ in idx_in:
      for ax in axes_:
        counts[ax] += 1

    idx_out = ''.join(sorted(ax for ax in indices if counts[ax] == 1))

  if len(idx_in) != len(inputs):
    raise ValueError('Expected %d inputs but got %d' % (len(idx_in),
                                                        len(inputs)))

  missing_idx = set(idx_out).difference(idx_all)
  if missing_idx:
    raise ValueError('Unknown output axes: %s' % missing_idx)

  axis_order = {}
  for ax in indices:
    if ax not in idx_out:
      axis_order[ax] = len(axis_order)
  for ax in idx_out:
    axis_order[ax] = len(axis_order)

  # transpose inputs so axes are in order
  for i, (input_, axes_) in enumerate(zip(inputs, idx_in)):
    if input_.get_shape().ndims != len(axes_):
      raise ValueError(
          'Input %d with axes %s has incorrect'
          ' number of dimensions (expected %d, got %d)' % (
              i, axes_, len(axes_), input_.get_shape().ndims
          )
      )

    sorted_idx = sorted(axes_, key=axis_order.get)

    if len(set(axes_)) != len(axes_):
      raise ValueError(
          'Subscript not supported: an axis appears more than once: %s' % axes_)

    if list(axes_) != sorted_idx:
      permuted = [axes_.find(ax) for ax in sorted_idx]
      inputs[i] = array_ops.transpose(input_, permuted)
      idx_in[i] = sorted_idx

  reduction_idx = []
  shapes = [[dim if dim else -1
             for dim in tensor.get_shape().as_list()]
            for tensor in inputs]

  # validate shapes for broadcasting
  for j, ax in enumerate(sorted(idx_all, key=axis_order.get)):
    dims = []
    for i, idx in enumerate(idx_in):
      if ax not in idx:
        shapes[i].insert(j, 1)
      else:
        dim = shapes[i][j]
        if isinstance(dim, int) and dim > 1:
          dims.append(dim)

    if len(set(dims)) > 1:
      raise ValueError('Dimension mismatch on axis: %s' % ax)

    if ax not in idx_out:
      reduction_idx.append(j)

  # reshape, multiply
  expanded_inputs = [
      array_ops.reshape(input_, shape) for input_, shape in zip(inputs, shapes)
  ]
  expanded_output = 1
  for input_ in expanded_inputs:
    expanded_output *= input_

  # contract
  return math_ops.reduce_sum(expanded_output, reduction_idx)

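# Example of when einsum() falls back to this implementation (a sketch with
# hypothetical tensors a, b, c): einsum('ij,ij,ij->', a, b, c) sums index 'i'
# (and 'j') over three inputs, so the pairwise _einsum_reduction() path does
# not apply. The fallback materializes the full broadcast product over all
# distinct axes before reducing, hence the "exponential-space" name.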