# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Internal utilities for `LinearOperator` classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.module import module
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variables as variables_module
from tensorflow.python.util import nest


################################################################################
# To make more friendly for TF2.
################################################################################


def convert_nonref_to_tensor(value, dtype=None, dtype_hint=None, name=None):
  """Converts the given `value` to a `Tensor` if input is nonreference type.

  This function converts Python objects of various types to `Tensor` objects,
  except when the input has reference semantics. Reference semantics are
  characterized by `is_ref`: any object which is a `tf.Variable` or an
  instance of `tf.Module` is returned unconverted. This function accepts any
  input which `tf.convert_to_tensor` would also accept.

  Note: This function diverges from default Numpy behavior for `float` and
  `string` types when `None` is present in a Python list or scalar. Rather
  than silently converting `None` values, an error will be thrown.

  Args:
    value: An object whose type has a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor. If missing, the
      type is inferred from the type of `value`.
    dtype_hint: Optional element type for the returned tensor,
      used when dtype is None. In some cases, a caller may not have a
      dtype in mind when converting to a tensor, so dtype_hint
      can be used as a soft preference. If the conversion to
      `dtype_hint` is not possible, this argument has no effect.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    tensor: A `Tensor` based on `value`.

  Raises:
    TypeError: If no conversion function is registered for `value` to `dtype`.
    RuntimeError: If a registered conversion function returns an invalid value.
    ValueError: If the `value` is a tensor not of given `dtype` in graph mode.


  #### Examples:

  ```python

  x = tf.Variable(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = tf.constant(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True

  x = np.array(0.)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> False
  tf.is_tensor(y)
  # ==> True

  x = tfp.util.DeferredTensor(13.37, lambda x: x)
  y = convert_nonref_to_tensor(x)
  x is y
  # ==> True
  tf.is_tensor(y)
  # ==> False
  tf.equal(y, 13.37)
  # ==> True
  ```

  """
  # We explicitly do not use a tf.name_scope to avoid graph clutter.
  if value is None:
    return None
  if is_ref(value):
    if dtype is None:
      return value
    dtype_base = base_dtype(dtype)
    value_dtype_base = base_dtype(value.dtype)
    if dtype_base != value_dtype_base:
      raise TypeError('Mutable type must be of dtype "{}" but is "{}".'.format(
          dtype_name(dtype_base), dtype_name(value_dtype_base)))
    return value
  return ops.convert_to_tensor_v2_with_dispatch(
      value, dtype=dtype, dtype_hint=dtype_hint, name=name)
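

# Illustrative usage sketch (added for exposition; not part of the library;
# assumes an eager TF2 context). It shows that reference-type inputs are
# passed through unchanged, and that an explicit `dtype` is checked against
# the input's dtype rather than triggering a silent cast.
def _example_convert_nonref_to_tensor_dtype_check():
  v = variables_module.Variable(0., dtype=dtypes.float32)
  # A matching dtype returns the variable itself, preserving its reference
  # semantics.
  assert convert_nonref_to_tensor(v, dtype=dtypes.float32) is v
  try:
    convert_nonref_to_tensor(v, dtype=dtypes.float64)
  except TypeError:
    pass  # Mutable (reference) inputs are never silently cast.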


def base_dtype(dtype):
  """Returns a non-reference `dtype` based on this `dtype`."""
  dtype = dtypes.as_dtype(dtype)
  if hasattr(dtype, "base_dtype"):
    return dtype.base_dtype
  return dtype


def dtype_name(dtype):
  """Returns the string name for this `dtype`."""
  dtype = dtypes.as_dtype(dtype)
  if hasattr(dtype, "name"):
    return dtype.name
  if hasattr(dtype, "__name__"):
    return dtype.__name__
  return str(dtype)


def check_dtype(arg, dtype):
  """Check that `arg.dtype` matches `dtype`."""
  if arg.dtype.base_dtype != dtype:
    raise TypeError(
        "Expected argument to have dtype %s. Found: %s in tensor %s" %
        (dtype, arg.dtype, arg))


def is_ref(x):
  """Evaluates if the object has reference semantics.

  An object is deemed "reference" if it is a `tf.Variable` instance or is
  derived from a `tf.Module` with `dtype` and `shape` properties.

  Args:
    x: Any object.

  Returns:
    is_ref: Python `bool` indicating whether the input has reference
      semantics, i.e., is a `tf.Variable` or a `tf.Module` with `dtype` and
      `shape` properties.
  """
  return (
      # Note: we check that tf.Variable is a class because we might be using a
      # different backend other than TF.
      isinstance(x, variables_module.Variable) or
      (isinstance(x, module.Module) and hasattr(x, "dtype") and
       hasattr(x, "shape")))


def assert_not_ref_type(x, arg_name):
  if is_ref(x):
    raise TypeError(
        "Argument %s cannot be reference type. Found: %s" % (arg_name, type(x)))
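

# Illustrative usage sketch (added for exposition; not part of the library;
# assumes an eager TF2 context). Variables have reference semantics, while
# ordinary tensors do not.
def _example_is_ref():
  v = variables_module.Variable(1.)
  t = ops.convert_to_tensor_v2_with_dispatch(1.)
  assert is_ref(v)
  assert not is_ref(t)
  assert_not_ref_type(t, "arg")  # Passes silently for non-reference inputs.
  # assert_not_ref_type(v, "arg") would raise a TypeError.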


################################################################################
# Asserts.
################################################################################


def assert_no_entries_with_modulus_zero(
    x, message=None, name="assert_no_entries_with_modulus_zero"):
  """Returns `Op` that asserts Tensor `x` has no entries with modulus zero.

  Args:
    x: Numeric `Tensor`, real, integer, or complex.
    message: A string message to prepend to failure message.
    name: A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no entries with modulus zero.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
    dtype = x.dtype.base_dtype
    should_be_nonzero = math_ops.abs(x)
    zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
    return check_ops.assert_less(zero, should_be_nonzero, message=message)


def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
  """Returns `Op` that asserts Tensor `x` has no non-zero imaginary parts.

  Args:
    x: Numeric `Tensor`, real, integer, or complex.
    message: A string message to prepend to failure message.
    name: A name to give this `Op`.

  Returns:
    An `Op` that asserts `x` has no non-zero imaginary parts.
  """
  with ops.name_scope(name, values=[x]):
    x = ops.convert_to_tensor_v2_with_dispatch(x, name="x")
    dtype = x.dtype.base_dtype

    if dtype.is_floating:
      return control_flow_ops.no_op()

    zero = ops.convert_to_tensor_v2_with_dispatch(0, dtype=dtype.real_dtype)
    return check_ops.assert_equal(zero, math_ops.imag(x), message=message)


def assert_compatible_matrix_dimensions(operator, x):
  """Assert that an argument to solve/matmul has proper domain dimension.

  If `operator.shape[-2:] = [M, N]`, and `x.shape[-2:] = [Q, R]`, then
  `operator.matmul(x)` is defined only if `N = Q`. This `Op` returns an
  `Assert` that "fires" if this is not the case. Static checks are already
  done by the base class `LinearOperator`.

  Args:
    operator: `LinearOperator`.
    x: `Tensor`.

  Returns:
    `Assert` `Op`.
  """
  # Static checks are done in the base class. Only tensor asserts here.
  assert_same_dd = check_ops.assert_equal(
      array_ops.shape(x)[-2],
      operator.domain_dimension_tensor(),
      # This error message made to look similar to error raised by static check
      # in the base class.
      message=("Dimensions are not compatible. "
               "shape[-2] of argument to be the same as this operator"))

  return assert_same_dd


def assert_is_batch_matrix(tensor):
  """Static assert that `tensor` has rank `2` or higher."""
  sh = tensor.shape
  if sh.ndims is not None and sh.ndims < 2:
    raise ValueError(
        "Expected [batch] matrix to have at least two dimensions. Found: "
        "%s" % tensor)


def shape_tensor(shape, name=None):
  """Convert Tensor using default type, unless empty list or tuple."""
  # Works just like random_ops._ShapeTensor.
  if isinstance(shape, (tuple, list)) and not shape:
    dtype = dtypes.int32
  else:
    dtype = None
  return ops.convert_to_tensor_v2_with_dispatch(shape, dtype=dtype, name=name)
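

# Illustrative usage sketch (added for exposition; not part of the library;
# assumes an eager TF2 context, where a passing assert simply returns and a
# failing one raises `InvalidArgumentError`).
def _example_value_asserts():
  x = ops.convert_to_tensor_v2_with_dispatch([1. + 0.j, 2. + 0.j])
  assert_no_entries_with_modulus_zero(x)  # Passes: moduli are 1 and 2.
  assert_zero_imag_part(x)  # Passes: imaginary parts are all zero.
  # By contrast, assert_no_entries_with_modulus_zero([0. + 0.j]) would fail,
  # and assert_zero_imag_part([1. + 2.j]) would fail.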


################################################################################
# Broadcasting versions of common linear algebra functions.
# TODO(b/77519145) Do this more efficiently in some special cases.
################################################################################


def broadcast_matrix_batch_dims(batch_matrices, name=None):
  """Broadcast leading dimensions of zero or more [batch] matrices.

  Example broadcasting one batch dim of two simple matrices.

  ```python
  x = [[1, 2],
       [3, 4]]  # Shape [2, 2], no batch dims

  y = [[[1]]]  # Shape [1, 1, 1], 1 batch dim of shape [1]

  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc
  ==> [[[1, 2],
        [3, 4]]]  # Shape [1, 2, 2], 1 batch dim of shape [1].

  y_bc
  ==> same as y
  ```

  Example broadcasting many batch dims

  ```python
  x = tf.random.normal(shape=(2, 3, 1, 4, 4))
  y = tf.random.normal(shape=(1, 3, 2, 5, 5))
  x_bc, y_bc = broadcast_matrix_batch_dims([x, y])

  x_bc.shape
  ==> (2, 3, 2, 4, 4)

  y_bc.shape
  ==> (2, 3, 2, 5, 5)
  ```

  Args:
    batch_matrices: Iterable of `Tensor`s, each having two or more dimensions.
    name: A string name to prepend to created ops.

  Returns:
    bcast_matrices: List of `Tensor`s, with `bcast_matrices[i]` containing
      the values from `batch_matrices[i]`, with possibly broadcast batch dims.

  Raises:
    ValueError: If any input `Tensor` is statically determined to have less
      than two dimensions.
  """
  with ops.name_scope(
      name or "broadcast_matrix_batch_dims", values=batch_matrices):
    check_ops.assert_proper_iterable(batch_matrices)
    batch_matrices = list(batch_matrices)

    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = ops.convert_to_tensor_v2_with_dispatch(mat)
      assert_is_batch_matrix(batch_matrices[i])

    if len(batch_matrices) < 2:
      return batch_matrices

    # Try static broadcasting.
    # bcast_batch_shape is the broadcast batch shape of ALL matrices.
    # E.g. if batch_matrices = [x, y], with
    # x.shape =    [2, j, k]  (batch shape =    [2])
    # y.shape = [3, 1, l, m]  (batch shape = [3, 1])
    # ==> bcast_batch_shape = [3, 2]
    bcast_batch_shape = batch_matrices[0].shape[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_static_shape(
          bcast_batch_shape,
          mat.shape[:-2])
    if bcast_batch_shape.is_fully_defined():
      for i, mat in enumerate(batch_matrices):
        if mat.shape[:-2] != bcast_batch_shape:
          bcast_shape = array_ops.concat(
              [bcast_batch_shape.as_list(), array_ops.shape(mat)[-2:]], axis=0)
          batch_matrices[i] = array_ops.broadcast_to(mat, bcast_shape)
      return batch_matrices

    # Since static didn't work, do dynamic, which always copies data.
    bcast_batch_shape = array_ops.shape(batch_matrices[0])[:-2]
    for mat in batch_matrices[1:]:
      bcast_batch_shape = array_ops.broadcast_dynamic_shape(
          bcast_batch_shape,
          array_ops.shape(mat)[:-2])
    for i, mat in enumerate(batch_matrices):
      batch_matrices[i] = array_ops.broadcast_to(
          mat,
          array_ops.concat(
              [bcast_batch_shape, array_ops.shape(mat)[-2:]], axis=0))

    return batch_matrices


def matrix_solve_with_broadcast(matrix, rhs, adjoint=False, name=None):
  """Solve systems of linear equations."""
  with ops.name_scope(name, "MatrixSolveWithBroadcast", [matrix, rhs]):
    matrix = ops.convert_to_tensor_v2_with_dispatch(matrix, name="matrix")
    rhs = ops.convert_to_tensor_v2_with_dispatch(
        rhs, name="rhs", dtype=matrix.dtype)

    # If either matrix/rhs has extra dims, we can reshape to get rid of them.
    matrix, rhs, reshape_inv, still_need_to_transpose = _reshape_for_efficiency(
        matrix, rhs, adjoint_a=adjoint)

    # This will broadcast by brute force if we still need to.
    matrix, rhs = broadcast_matrix_batch_dims([matrix, rhs])

    solution = linalg_ops.matrix_solve(
        matrix, rhs, adjoint=adjoint and still_need_to_transpose)

    return reshape_inv(solution)
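

# Illustrative usage sketch (added for exposition; not part of the library;
# the shapes below are hypothetical). The batch dims [2, 1] and [1, 3] of
# `matrix` and `rhs` broadcast to [2, 3] before the solve.
def _example_matrix_solve_with_broadcast():
  matrix = linalg_ops.eye(4, batch_shape=[2, 1])  # Shape [2, 1, 4, 4].
  rhs = array_ops.ones(shape=[1, 3, 4, 5])        # Shape [1, 3, 4, 5].
  solution = matrix_solve_with_broadcast(matrix, rhs)
  # solution.shape == [2, 3, 4, 5]; since `matrix` is the identity, `solution`
  # is just `rhs` broadcast to that shape.
  return solution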


def _reshape_for_efficiency(a,
                            b,
                            transpose_a=False,
                            transpose_b=False,
                            adjoint_a=False,
                            adjoint_b=False):
  """Maybe reshape a, b, and return an inverse map.  For matmul/solve."""
  def identity(x):
    return x

  # At this point, we have not taken transpose/adjoint of a/b.
  still_need_to_transpose = True

  if a.shape.ndims is None or b.shape.ndims is None:
    return a, b, identity, still_need_to_transpose

  # This could be handled in the future, but seems less common.
  if a.shape.ndims >= b.shape.ndims:
    return a, b, identity, still_need_to_transpose

  # From now on, we might modify b, but will not modify a.

  # Suppose:
  #   a.shape =     C + [m, n]
  #   b.shape = S + C + [n, r]
  b_extra_ndims = b.shape.ndims - a.shape.ndims

  # b_extra_sh = S, b_main_sh = C + [n, r]
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # No reason to flip unless the extra dims of b are big enough. Why?
  # Assume adjoint/transpose = False. Then...
  # By not flipping, we have to replicate a to shape
  #   b_extra_sh + a.shape,
  # which could use extra memory. But in all cases, the final output has shape
  #   b_extra_sh + a.shape[:-1] + [b.shape[-1]]
  # So we only end up creating a larger object if the end dim of b is smaller
  # than the end dim of a. This often happens, e.g. if b was a vector that was
  # expanded to a matrix (by appending a singleton).

  # Since adjoint/transpose may not be False, we must make adjustments here.
  # The dim of b that holds the multiple equations.
  a_domain_sz_ = a.shape[-2 if adjoint_a or transpose_a else -1]
  b_eq_sz_ = b.shape[-2 if adjoint_b or transpose_b else -1]
  b_extra_sz_ = (
      np.prod(b.shape[:b_extra_ndims].as_list())
      if b.shape[:b_extra_ndims].is_fully_defined() else None)
  if (a_domain_sz_ is not None and b_eq_sz_ is not None and
      b_extra_sz_ is not None):
    if b_extra_sz_ < 2 or a_domain_sz_ <= b_eq_sz_:
      return a, b, identity, still_need_to_transpose

  # At this point, we're flipping for sure!
  # Any transposes/adjoints will happen here explicitly, rather than in calling
  # code. Why? To avoid having to write separate complex code for each case.
  if adjoint_a:
    a = array_ops.matrix_transpose(a, conjugate=True)
  elif transpose_a:
    a = array_ops.matrix_transpose(a, conjugate=False)
  if adjoint_b:
    b = array_ops.matrix_transpose(b, conjugate=True)
  elif transpose_b:
    b = array_ops.matrix_transpose(b, conjugate=False)
  still_need_to_transpose = False

  # Recompute shapes, since the transpose/adjoint may have changed them.
  b_extra_sh = array_ops.shape(b)[:b_extra_ndims]
  b_main_sh = array_ops.shape(b)[b_extra_ndims:]

  # Permutation to put the extra dims at the end.
  perm = (
      np.concatenate(
          (np.arange(b_extra_ndims, b.shape.ndims),
           np.arange(0, b_extra_ndims)), 0))
  b_extra_on_end = array_ops.transpose(b, perm=perm)

  # Now squash this end into one long dim.
  b_squashed_end = array_ops.reshape(
      b_extra_on_end, array_ops.concat((b_main_sh[:-1], [-1]), 0))

  def reshape_inv(y):
    # Expand the extra dims hanging off the end, "b_extra_sh".
    # Note we use y_sh[:-1] + [b_main_sh[-1]] rather than b_main_sh, because y
    # could have different batch dims than a and b, because of broadcasting.
    y_extra_shape = array_ops.concat(
        (array_ops.shape(y)[:-1], [b_main_sh[-1]], b_extra_sh), 0)
    y_extra_on_end = array_ops.reshape(y, y_extra_shape)
    inverse_perm = np.argsort(perm)
    return array_ops.transpose(y_extra_on_end, perm=inverse_perm)

  return a, b_squashed_end, reshape_inv, still_need_to_transpose
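

# Illustrative sketch of the reshape trick above (added for exposition; not
# part of the library; the shapes are hypothetical). With a.shape = [2, 3, 3]
# (C = [2]) and b.shape = [7, 2, 3, 1] (S = [7], C = [2]), the extra dim S of
# `b` is rotated to the end and folded into the last dim, so a single batched
# matmul/solve over shape [2, 3, 7] replaces broadcasting `a` up to
# [7, 2, 3, 3].
def _example_reshape_for_efficiency():
  a = linalg_ops.eye(3, batch_shape=[2])  # Shape [2, 3, 3].
  b = array_ops.ones(shape=[7, 2, 3, 1])  # Shape [7, 2, 3, 1].
  a2, b2, reshape_inv, _ = _reshape_for_efficiency(a, b)
  # b2.shape == [2, 3, 7]; the matmul/solve result maps back via reshape_inv.
  y = reshape_inv(math_ops.matmul(a2, b2))
  # y.shape == [7, 2, 3, 1], the shape a brute-force broadcast would produce.
  return y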


################################################################################
# Helpers for hints.
################################################################################


def use_operator_or_provided_hint_unless_contradicting(
    operator, hint_attr_name, provided_hint_value, message):
  """Get combined hint in the case where operator.hint should equal hint.

  Args:
    operator: LinearOperator that a meta-operator was initialized with.
    hint_attr_name: String name for the attribute.
    provided_hint_value: Bool or None. Value passed by user in initialization.
    message: Error message to print if hints contradict.

  Returns:
    True, False, or None.

  Raises:
    ValueError: If hints contradict.
  """
  op_hint = getattr(operator, hint_attr_name)
  # pylint: disable=g-bool-id-comparison
  if op_hint is False and provided_hint_value:
    raise ValueError(message)
  if op_hint and provided_hint_value is False:
    raise ValueError(message)
  if op_hint or provided_hint_value:
    return True
  if op_hint is False or provided_hint_value is False:
    return False
  # pylint: enable=g-bool-id-comparison
  return None
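

# Illustrative sketch (added for exposition; not part of the library). The
# stand-in `_FakeOperator` below is hypothetical; any object exposing the
# named hint attribute would do. A hint set on either side wins, and a
# contradiction raises `ValueError` with the given message.
def _example_hint_combination():
  class _FakeOperator(object):
    is_self_adjoint = True

  operator = _FakeOperator()
  assert use_operator_or_provided_hint_unless_contradicting(
      operator, "is_self_adjoint", None, "contradiction") is True
  try:
    use_operator_or_provided_hint_unless_contradicting(
        operator, "is_self_adjoint", False, "contradiction")
  except ValueError:
    pass  # Operator says True, caller says False: the hints contradict.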


################################################################################
# Utilities for blockwise operators.
################################################################################


def arg_is_blockwise(block_dimensions, arg, arg_split_dim):
  """Detect if input should be interpreted as a list of blocks."""
  # Tuples and lists of length equal to the number of operators may be
  # blockwise.
  if (isinstance(arg, (tuple, list)) and len(arg) == len(block_dimensions)):
    # If the elements of the iterable are not nested, interpret the input as
    # blockwise.
    if not any(nest.is_nested(x) for x in arg):
      return True
    else:
      arg_dims = [ops.convert_to_tensor_v2_with_dispatch(
          x).shape[arg_split_dim] for x in arg]
      self_dims = [dim.value for dim in block_dimensions]

      # If none of the operator dimensions are known, interpret the input as
      # blockwise if its matching dimensions are unequal.
      if all(self_d is None for self_d in self_dims):

        # A nested tuple/list with a single outermost element is not blockwise
        if len(arg_dims) == 1:
          return False
        elif any(dim != arg_dims[0] for dim in arg_dims):
          return True
        else:
          raise ValueError(
              "Parsing of the input structure is ambiguous. Please input "
              "a blockwise iterable of `Tensor`s or a single `Tensor`.")

      # If input dimensions equal the respective (known) blockwise operator
      # dimensions, then the input is blockwise.
      if all(self_d == arg_d or self_d is None
             for self_d, arg_d in zip(self_dims, arg_dims)):
        return True

      # If the input dimensions are all equal, and are greater than or equal
      # to the sum of the known operator dimensions, interpret the input as a
      # single Tensor, i.e. not blockwise.
      self_dim = sum(self_d for self_d in self_dims if self_d is not None)
      if all(s == arg_dims[0] for s in arg_dims) and arg_dims[0] >= self_dim:
        return False

      # If none of these conditions is met, the input shape is mismatched.
      raise ValueError("Input dimension does not match operator dimension.")
  else:
    return False


def split_arg_into_blocks(block_dims, block_dims_fn, arg, axis=-1):
  """Split `arg` into blocks matching `operators`'s `domain_dimension`.

  Specifically, if we have a blockwise lower-triangular matrix, with block
  sizes along the diagonal `[M_j, M_j] j = 0,1,2..J`, this method splits `arg`
  on `axis` into `J` tensors, whose shape at `axis` is `M_j`.

  Args:
    block_dims: Iterable of `TensorShapes`.
    block_dims_fn: Callable returning an iterable of `Tensor`s.
    arg: `Tensor`. `arg` is split into `J` tensors.
    axis: Python `Integer` representing the axis to split `arg` on.

  Returns:
    A list of `Tensor`s.
  """
  block_sizes = [dim.value for dim in block_dims]
  if any(d is None for d in block_sizes):
    block_sizes = block_dims_fn()
  return array_ops.split(arg, block_sizes, axis=axis)
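

# Illustrative usage sketch (added for exposition; not part of the library).
# Suppose a blockwise operator has domain block sizes [2, 3] along `axis=-1`.
# A list of per-block Tensors is detected as blockwise, while a single Tensor
# whose matching dimension equals the total size (5) is not, and can then be
# split into blocks.
def _example_blockwise_utilities():
  from tensorflow.python.framework import tensor_shape  # For the sketch only.
  block_dimensions = [tensor_shape.Dimension(2), tensor_shape.Dimension(3)]
  blocks = [array_ops.ones([4, 2]), array_ops.ones([4, 3])]
  single = array_ops.ones([4, 5])
  assert arg_is_blockwise(block_dimensions, blocks, arg_split_dim=-1)
  assert not arg_is_blockwise(block_dimensions, single, arg_split_dim=-1)
  split = split_arg_into_blocks(
      block_dimensions, lambda: [2, 3], single, axis=-1)
  # `split` is a list of two Tensors with shapes [4, 2] and [4, 3].
  return split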