1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""`LinearOperator` acting like the identity matrix.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25from tensorflow.python.framework import tensor_shape 26from tensorflow.python.framework import tensor_util 27from tensorflow.python.ops import array_ops 28from tensorflow.python.ops import check_ops 29from tensorflow.python.ops import control_flow_ops 30from tensorflow.python.ops import math_ops 31from tensorflow.python.ops.linalg import linalg_impl as linalg 32from tensorflow.python.ops.linalg import linear_operator 33from tensorflow.python.ops.linalg import linear_operator_util 34from tensorflow.python.util.tf_export import tf_export 35 36__all__ = [ 37 "LinearOperatorIdentity", 38 "LinearOperatorScaledIdentity", 39] 40 41 42class BaseLinearOperatorIdentity(linear_operator.LinearOperator): 43 """Base class for Identity operators.""" 44 45 def _check_num_rows_possibly_add_asserts(self): 46 """Static check of init arg `num_rows`, possibly add asserts.""" 47 # Possibly add asserts. 48 if self._assert_proper_shapes: 49 self._num_rows = control_flow_ops.with_dependencies([ 50 check_ops.assert_rank( 51 self._num_rows, 52 0, 53 message="Argument num_rows must be a 0-D Tensor."), 54 check_ops.assert_non_negative( 55 self._num_rows, 56 message="Argument num_rows must be non-negative."), 57 ], self._num_rows) 58 59 # Static checks. 60 if not self._num_rows.dtype.is_integer: 61 raise TypeError("Argument num_rows must be integer type. Found:" 62 " %s" % self._num_rows) 63 64 num_rows_static = self._num_rows_static 65 66 if num_rows_static is None: 67 return # Cannot do any other static checks. 68 69 if num_rows_static.ndim != 0: 70 raise ValueError("Argument num_rows must be a 0-D Tensor. Found:" 71 " %s" % num_rows_static) 72 73 if num_rows_static < 0: 74 raise ValueError("Argument num_rows must be non-negative. Found:" 75 " %s" % num_rows_static) 76 77 def _min_matrix_dim(self): 78 """Minimum of domain/range dimension, if statically available, else None.""" 79 domain_dim = tensor_shape.dimension_value(self.domain_dimension) 80 range_dim = tensor_shape.dimension_value(self.range_dimension) 81 if domain_dim is None or range_dim is None: 82 return None 83 return min(domain_dim, range_dim) 84 85 def _min_matrix_dim_tensor(self): 86 """Minimum of domain/range dimension, as a tensor.""" 87 return math_ops.reduce_min(self.shape_tensor()[-2:]) 88 89 def _ones_diag(self): 90 """Returns the diagonal of this operator as all ones.""" 91 if self.shape.is_fully_defined(): 92 d_shape = self.batch_shape.concatenate([self._min_matrix_dim()]) 93 else: 94 d_shape = array_ops.concat( 95 [self.batch_shape_tensor(), 96 [self._min_matrix_dim_tensor()]], axis=0) 97 98 return array_ops.ones(shape=d_shape, dtype=self.dtype) 99 100 101@tf_export("linalg.LinearOperatorIdentity") 102class LinearOperatorIdentity(BaseLinearOperatorIdentity): 103 """`LinearOperator` acting like a [batch] square identity matrix. 104 105 This operator acts like a [batch] identity matrix `A` with shape 106 `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a 107 batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is 108 an `N x N` matrix. This matrix `A` is not materialized, but for 109 purposes of broadcasting this shape will be relevant. 110 111 `LinearOperatorIdentity` is initialized with `num_rows`, and optionally 112 `batch_shape`, and `dtype` arguments. If `batch_shape` is `None`, this 113 operator efficiently passes through all arguments. If `batch_shape` is 114 provided, broadcasting may occur, which will require making copies. 115 116 ```python 117 # Create a 2 x 2 identity matrix. 118 operator = LinearOperatorIdentity(num_rows=2, dtype=tf.float32) 119 120 operator.to_dense() 121 ==> [[1., 0.] 122 [0., 1.]] 123 124 operator.shape 125 ==> [2, 2] 126 127 operator.log_abs_determinant() 128 ==> 0. 129 130 x = ... Shape [2, 4] Tensor 131 operator.matmul(x) 132 ==> Shape [2, 4] Tensor, same as x. 133 134 y = tf.random.normal(shape=[3, 2, 4]) 135 # Note that y.shape is compatible with operator.shape because operator.shape 136 # is broadcast to [3, 2, 2]. 137 # This broadcast does NOT require copying data, since we can infer that y 138 # will be passed through without changing shape. We are always able to infer 139 # this if the operator has no batch_shape. 140 x = operator.solve(y) 141 ==> Shape [3, 2, 4] Tensor, same as y. 142 143 # Create a 2-batch of 2x2 identity matrices 144 operator = LinearOperatorIdentity(num_rows=2, batch_shape=[2]) 145 operator.to_dense() 146 ==> [[[1., 0.] 147 [0., 1.]], 148 [[1., 0.] 149 [0., 1.]]] 150 151 # Here, even though the operator has a batch shape, the input is the same as 152 # the output, so x can be passed through without a copy. The operator is able 153 # to detect that no broadcast is necessary because both x and the operator 154 # have statically defined shape. 155 x = ... Shape [2, 2, 3] 156 operator.matmul(x) 157 ==> Shape [2, 2, 3] Tensor, same as x 158 159 # Here the operator and x have different batch_shape, and are broadcast. 160 # This requires a copy, since the output is different size than the input. 161 x = ... Shape [1, 2, 3] 162 operator.matmul(x) 163 ==> Shape [2, 2, 3] Tensor, equal to [x, x] 164 ``` 165 166 ### Shape compatibility 167 168 This operator acts on [batch] matrix with compatible shape. 169 `x` is a batch matrix with compatible shape for `matmul` and `solve` if 170 171 ``` 172 operator.shape = [B1,...,Bb] + [N, N], with b >= 0 173 x.shape = [C1,...,Cc] + [N, R], 174 and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd] 175 ``` 176 177 ### Performance 178 179 If `batch_shape` initialization arg is `None`: 180 181 * `operator.matmul(x)` is `O(1)` 182 * `operator.solve(x)` is `O(1)` 183 * `operator.determinant()` is `O(1)` 184 185 If `batch_shape` initialization arg is provided, and static checks cannot 186 rule out the need to broadcast: 187 188 * `operator.matmul(x)` is `O(D1*...*Dd*N*R)` 189 * `operator.solve(x)` is `O(D1*...*Dd*N*R)` 190 * `operator.determinant()` is `O(B1*...*Bb)` 191 192 #### Matrix property hints 193 194 This `LinearOperator` is initialized with boolean flags of the form `is_X`, 195 for `X = non_singular, self_adjoint, positive_definite, square`. 196 These have the following meaning: 197 198 * If `is_X == True`, callers should expect the operator to have the 199 property `X`. This is a promise that should be fulfilled, but is *not* a 200 runtime assert. For example, finite floating point precision may result 201 in these promises being violated. 202 * If `is_X == False`, callers should expect the operator to not have `X`. 203 * If `is_X == None` (the default), callers should have no expectation either 204 way. 205 """ 206 207 def __init__(self, 208 num_rows, 209 batch_shape=None, 210 dtype=None, 211 is_non_singular=True, 212 is_self_adjoint=True, 213 is_positive_definite=True, 214 is_square=True, 215 assert_proper_shapes=False, 216 name="LinearOperatorIdentity"): 217 r"""Initialize a `LinearOperatorIdentity`. 218 219 The `LinearOperatorIdentity` is initialized with arguments defining `dtype` 220 and shape. 221 222 This operator is able to broadcast the leading (batch) dimensions, which 223 sometimes requires copying data. If `batch_shape` is `None`, the operator 224 can take arguments of any batch shape without copying. See examples. 225 226 Args: 227 num_rows: Scalar non-negative integer `Tensor`. Number of rows in the 228 corresponding identity matrix. 229 batch_shape: Optional `1-D` integer `Tensor`. The shape of the leading 230 dimensions. If `None`, this operator has no leading dimensions. 231 dtype: Data type of the matrix that this operator represents. 232 is_non_singular: Expect that this operator is non-singular. 233 is_self_adjoint: Expect that this operator is equal to its hermitian 234 transpose. 235 is_positive_definite: Expect that this operator is positive definite, 236 meaning the quadratic form `x^H A x` has positive real part for all 237 nonzero `x`. Note that we do not require the operator to be 238 self-adjoint to be positive-definite. See: 239 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices 240 is_square: Expect that this operator acts like square [batch] matrices. 241 assert_proper_shapes: Python `bool`. If `False`, only perform static 242 checks that initialization and method arguments have proper shape. 243 If `True`, and static checks are inconclusive, add asserts to the graph. 244 name: A name for this `LinearOperator` 245 246 Raises: 247 ValueError: If `num_rows` is determined statically to be non-scalar, or 248 negative. 249 ValueError: If `batch_shape` is determined statically to not be 1-D, or 250 negative. 251 ValueError: If any of the following is not `True`: 252 `{is_self_adjoint, is_non_singular, is_positive_definite}`. 253 TypeError: If `num_rows` or `batch_shape` is ref-type (e.g. Variable). 254 """ 255 parameters = dict( 256 num_rows=num_rows, 257 batch_shape=batch_shape, 258 dtype=dtype, 259 is_non_singular=is_non_singular, 260 is_self_adjoint=is_self_adjoint, 261 is_positive_definite=is_positive_definite, 262 is_square=is_square, 263 assert_proper_shapes=assert_proper_shapes, 264 name=name) 265 266 dtype = dtype or dtypes.float32 267 self._assert_proper_shapes = assert_proper_shapes 268 269 with ops.name_scope(name): 270 dtype = dtypes.as_dtype(dtype) 271 if not is_self_adjoint: 272 raise ValueError("An identity operator is always self adjoint.") 273 if not is_non_singular: 274 raise ValueError("An identity operator is always non-singular.") 275 if not is_positive_definite: 276 raise ValueError("An identity operator is always positive-definite.") 277 if not is_square: 278 raise ValueError("An identity operator is always square.") 279 280 super(LinearOperatorIdentity, self).__init__( 281 dtype=dtype, 282 is_non_singular=is_non_singular, 283 is_self_adjoint=is_self_adjoint, 284 is_positive_definite=is_positive_definite, 285 is_square=is_square, 286 parameters=parameters, 287 name=name) 288 289 linear_operator_util.assert_not_ref_type(num_rows, "num_rows") 290 linear_operator_util.assert_not_ref_type(batch_shape, "batch_shape") 291 292 self._num_rows = linear_operator_util.shape_tensor( 293 num_rows, name="num_rows") 294 self._num_rows_static = tensor_util.constant_value(self._num_rows) 295 self._check_num_rows_possibly_add_asserts() 296 297 if batch_shape is None: 298 self._batch_shape_arg = None 299 else: 300 self._batch_shape_arg = linear_operator_util.shape_tensor( 301 batch_shape, name="batch_shape_arg") 302 self._batch_shape_static = tensor_util.constant_value( 303 self._batch_shape_arg) 304 self._check_batch_shape_possibly_add_asserts() 305 306 def _shape(self): 307 matrix_shape = tensor_shape.TensorShape((self._num_rows_static, 308 self._num_rows_static)) 309 if self._batch_shape_arg is None: 310 return matrix_shape 311 312 batch_shape = tensor_shape.TensorShape(self._batch_shape_static) 313 return batch_shape.concatenate(matrix_shape) 314 315 def _shape_tensor(self): 316 matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0) 317 if self._batch_shape_arg is None: 318 return matrix_shape 319 320 return array_ops.concat((self._batch_shape_arg, matrix_shape), 0) 321 322 def _assert_non_singular(self): 323 return control_flow_ops.no_op("assert_non_singular") 324 325 def _assert_positive_definite(self): 326 return control_flow_ops.no_op("assert_positive_definite") 327 328 def _assert_self_adjoint(self): 329 return control_flow_ops.no_op("assert_self_adjoint") 330 331 def _possibly_broadcast_batch_shape(self, x): 332 """Return 'x', possibly after broadcasting the leading dimensions.""" 333 # If we have no batch shape, our batch shape broadcasts with everything! 334 if self._batch_shape_arg is None: 335 return x 336 337 # Static attempt: 338 # If we determine that no broadcast is necessary, pass x through 339 # If we need a broadcast, add to an array of zeros. 340 # 341 # special_shape is the shape that, when broadcast with x's shape, will give 342 # the correct broadcast_shape. Note that 343 # We have already verified the second to last dimension of self.shape 344 # matches x's shape in assert_compatible_matrix_dimensions. 345 # Also, the final dimension of 'x' can have any shape. 346 # Therefore, the final two dimensions of special_shape are 1's. 347 special_shape = self.batch_shape.concatenate([1, 1]) 348 bshape = array_ops.broadcast_static_shape(x.shape, special_shape) 349 if special_shape.is_fully_defined(): 350 # bshape.is_fully_defined iff special_shape.is_fully_defined. 351 if bshape == x.shape: 352 return x 353 # Use the built in broadcasting of addition. 354 zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype) 355 return x + zeros 356 357 # Dynamic broadcast: 358 # Always add to an array of zeros, rather than using a "cond", since a 359 # cond would require copying data from GPU --> CPU. 360 special_shape = array_ops.concat((self.batch_shape_tensor(), [1, 1]), 0) 361 zeros = array_ops.zeros(shape=special_shape, dtype=self.dtype) 362 return x + zeros 363 364 def _matmul(self, x, adjoint=False, adjoint_arg=False): 365 # Note that adjoint has no effect since this matrix is self-adjoint. 366 x = linalg.adjoint(x) if adjoint_arg else x 367 if self._assert_proper_shapes: 368 aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x) 369 x = control_flow_ops.with_dependencies([aps], x) 370 return self._possibly_broadcast_batch_shape(x) 371 372 def _determinant(self): 373 return array_ops.ones(shape=self.batch_shape_tensor(), dtype=self.dtype) 374 375 def _log_abs_determinant(self): 376 return array_ops.zeros(shape=self.batch_shape_tensor(), dtype=self.dtype) 377 378 def _solve(self, rhs, adjoint=False, adjoint_arg=False): 379 return self._matmul(rhs, adjoint_arg=adjoint_arg) 380 381 def _trace(self): 382 # Get Tensor of all ones of same shape as self.batch_shape. 383 if self.batch_shape.is_fully_defined(): 384 batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype) 385 else: 386 batch_of_ones = array_ops.ones( 387 shape=self.batch_shape_tensor(), dtype=self.dtype) 388 389 if self._min_matrix_dim() is not None: 390 return self._min_matrix_dim() * batch_of_ones 391 else: 392 return (math_ops.cast(self._min_matrix_dim_tensor(), self.dtype) * 393 batch_of_ones) 394 395 def _diag_part(self): 396 return self._ones_diag() 397 398 def add_to_tensor(self, mat, name="add_to_tensor"): 399 """Add matrix represented by this operator to `mat`. Equiv to `I + mat`. 400 401 Args: 402 mat: `Tensor` with same `dtype` and shape broadcastable to `self`. 403 name: A name to give this `Op`. 404 405 Returns: 406 A `Tensor` with broadcast shape and same `dtype` as `self`. 407 """ 408 with self._name_scope(name): 409 mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat") 410 mat_diag = array_ops.matrix_diag_part(mat) 411 new_diag = 1 + mat_diag 412 return array_ops.matrix_set_diag(mat, new_diag) 413 414 def _eigvals(self): 415 return self._ones_diag() 416 417 def _cond(self): 418 return array_ops.ones(self.batch_shape_tensor(), dtype=self.dtype) 419 420 def _check_num_rows_possibly_add_asserts(self): 421 """Static check of init arg `num_rows`, possibly add asserts.""" 422 # Possibly add asserts. 423 if self._assert_proper_shapes: 424 self._num_rows = control_flow_ops.with_dependencies([ 425 check_ops.assert_rank( 426 self._num_rows, 427 0, 428 message="Argument num_rows must be a 0-D Tensor."), 429 check_ops.assert_non_negative( 430 self._num_rows, 431 message="Argument num_rows must be non-negative."), 432 ], self._num_rows) 433 434 # Static checks. 435 if not self._num_rows.dtype.is_integer: 436 raise TypeError("Argument num_rows must be integer type. Found:" 437 " %s" % self._num_rows) 438 439 num_rows_static = self._num_rows_static 440 441 if num_rows_static is None: 442 return # Cannot do any other static checks. 443 444 if num_rows_static.ndim != 0: 445 raise ValueError("Argument num_rows must be a 0-D Tensor. Found:" 446 " %s" % num_rows_static) 447 448 if num_rows_static < 0: 449 raise ValueError("Argument num_rows must be non-negative. Found:" 450 " %s" % num_rows_static) 451 452 def _check_batch_shape_possibly_add_asserts(self): 453 """Static check of init arg `batch_shape`, possibly add asserts.""" 454 if self._batch_shape_arg is None: 455 return 456 457 # Possibly add asserts 458 if self._assert_proper_shapes: 459 self._batch_shape_arg = control_flow_ops.with_dependencies([ 460 check_ops.assert_rank( 461 self._batch_shape_arg, 462 1, 463 message="Argument batch_shape must be a 1-D Tensor."), 464 check_ops.assert_non_negative( 465 self._batch_shape_arg, 466 message="Argument batch_shape must be non-negative."), 467 ], self._batch_shape_arg) 468 469 # Static checks 470 if not self._batch_shape_arg.dtype.is_integer: 471 raise TypeError("Argument batch_shape must be integer type. Found:" 472 " %s" % self._batch_shape_arg) 473 474 if self._batch_shape_static is None: 475 return # Cannot do any other static checks. 476 477 if self._batch_shape_static.ndim != 1: 478 raise ValueError("Argument batch_shape must be a 1-D Tensor. Found:" 479 " %s" % self._batch_shape_static) 480 481 if np.any(self._batch_shape_static < 0): 482 raise ValueError("Argument batch_shape must be non-negative. Found:" 483 "%s" % self._batch_shape_static) 484 485 486@tf_export("linalg.LinearOperatorScaledIdentity") 487class LinearOperatorScaledIdentity(BaseLinearOperatorIdentity): 488 """`LinearOperator` acting like a scaled [batch] identity matrix `A = c I`. 489 490 This operator acts like a scaled [batch] identity matrix `A` with shape 491 `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a 492 batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is 493 a scaled version of the `N x N` identity matrix. 494 495 `LinearOperatorIdentity` is initialized with `num_rows`, and a `multiplier` 496 (a `Tensor`) of shape `[B1,...,Bb]`. `N` is set to `num_rows`, and the 497 `multiplier` determines the scale for each batch member. 498 499 ```python 500 # Create a 2 x 2 scaled identity matrix. 501 operator = LinearOperatorIdentity(num_rows=2, multiplier=3.) 502 503 operator.to_dense() 504 ==> [[3., 0.] 505 [0., 3.]] 506 507 operator.shape 508 ==> [2, 2] 509 510 operator.log_abs_determinant() 511 ==> 2 * Log[3] 512 513 x = ... Shape [2, 4] Tensor 514 operator.matmul(x) 515 ==> 3 * x 516 517 y = tf.random.normal(shape=[3, 2, 4]) 518 # Note that y.shape is compatible with operator.shape because operator.shape 519 # is broadcast to [3, 2, 2]. 520 x = operator.solve(y) 521 ==> 3 * x 522 523 # Create a 2-batch of 2x2 identity matrices 524 operator = LinearOperatorIdentity(num_rows=2, multiplier=5.) 525 operator.to_dense() 526 ==> [[[5., 0.] 527 [0., 5.]], 528 [[5., 0.] 529 [0., 5.]]] 530 531 x = ... Shape [2, 2, 3] 532 operator.matmul(x) 533 ==> 5 * x 534 535 # Here the operator and x have different batch_shape, and are broadcast. 536 x = ... Shape [1, 2, 3] 537 operator.matmul(x) 538 ==> 5 * x 539 ``` 540 541 ### Shape compatibility 542 543 This operator acts on [batch] matrix with compatible shape. 544 `x` is a batch matrix with compatible shape for `matmul` and `solve` if 545 546 ``` 547 operator.shape = [B1,...,Bb] + [N, N], with b >= 0 548 x.shape = [C1,...,Cc] + [N, R], 549 and [C1,...,Cc] broadcasts with [B1,...,Bb] to [D1,...,Dd] 550 ``` 551 552 ### Performance 553 554 * `operator.matmul(x)` is `O(D1*...*Dd*N*R)` 555 * `operator.solve(x)` is `O(D1*...*Dd*N*R)` 556 * `operator.determinant()` is `O(D1*...*Dd)` 557 558 #### Matrix property hints 559 560 This `LinearOperator` is initialized with boolean flags of the form `is_X`, 561 for `X = non_singular, self_adjoint, positive_definite, square`. 562 These have the following meaning 563 * If `is_X == True`, callers should expect the operator to have the 564 property `X`. This is a promise that should be fulfilled, but is *not* a 565 runtime assert. For example, finite floating point precision may result 566 in these promises being violated. 567 * If `is_X == False`, callers should expect the operator to not have `X`. 568 * If `is_X == None` (the default), callers should have no expectation either 569 way. 570 """ 571 572 def __init__(self, 573 num_rows, 574 multiplier, 575 is_non_singular=None, 576 is_self_adjoint=None, 577 is_positive_definite=None, 578 is_square=True, 579 assert_proper_shapes=False, 580 name="LinearOperatorScaledIdentity"): 581 r"""Initialize a `LinearOperatorScaledIdentity`. 582 583 The `LinearOperatorScaledIdentity` is initialized with `num_rows`, which 584 determines the size of each identity matrix, and a `multiplier`, 585 which defines `dtype`, batch shape, and scale of each matrix. 586 587 This operator is able to broadcast the leading (batch) dimensions. 588 589 Args: 590 num_rows: Scalar non-negative integer `Tensor`. Number of rows in the 591 corresponding identity matrix. 592 multiplier: `Tensor` of shape `[B1,...,Bb]`, or `[]` (a scalar). 593 is_non_singular: Expect that this operator is non-singular. 594 is_self_adjoint: Expect that this operator is equal to its hermitian 595 transpose. 596 is_positive_definite: Expect that this operator is positive definite, 597 meaning the quadratic form `x^H A x` has positive real part for all 598 nonzero `x`. Note that we do not require the operator to be 599 self-adjoint to be positive-definite. See: 600 https://en.wikipedia.org/wiki/Positive-definite_matrix#Extension_for_non-symmetric_matrices 601 is_square: Expect that this operator acts like square [batch] matrices. 602 assert_proper_shapes: Python `bool`. If `False`, only perform static 603 checks that initialization and method arguments have proper shape. 604 If `True`, and static checks are inconclusive, add asserts to the graph. 605 name: A name for this `LinearOperator` 606 607 Raises: 608 ValueError: If `num_rows` is determined statically to be non-scalar, or 609 negative. 610 """ 611 parameters = dict( 612 num_rows=num_rows, 613 multiplier=multiplier, 614 is_non_singular=is_non_singular, 615 is_self_adjoint=is_self_adjoint, 616 is_positive_definite=is_positive_definite, 617 is_square=is_square, 618 assert_proper_shapes=assert_proper_shapes, 619 name=name) 620 621 self._assert_proper_shapes = assert_proper_shapes 622 623 with ops.name_scope(name, values=[multiplier, num_rows]): 624 self._multiplier = linear_operator_util.convert_nonref_to_tensor( 625 multiplier, name="multiplier") 626 627 # Check and auto-set hints. 628 if not self._multiplier.dtype.is_complex: 629 if is_self_adjoint is False: # pylint: disable=g-bool-id-comparison 630 raise ValueError("A real diagonal operator is always self adjoint.") 631 else: 632 is_self_adjoint = True 633 634 if not is_square: 635 raise ValueError("A ScaledIdentity operator is always square.") 636 637 linear_operator_util.assert_not_ref_type(num_rows, "num_rows") 638 639 super(LinearOperatorScaledIdentity, self).__init__( 640 dtype=self._multiplier.dtype.base_dtype, 641 is_non_singular=is_non_singular, 642 is_self_adjoint=is_self_adjoint, 643 is_positive_definite=is_positive_definite, 644 is_square=is_square, 645 parameters=parameters, 646 name=name) 647 648 self._num_rows = linear_operator_util.shape_tensor( 649 num_rows, name="num_rows") 650 self._num_rows_static = tensor_util.constant_value(self._num_rows) 651 self._check_num_rows_possibly_add_asserts() 652 self._num_rows_cast_to_dtype = math_ops.cast(self._num_rows, self.dtype) 653 self._num_rows_cast_to_real_dtype = math_ops.cast(self._num_rows, 654 self.dtype.real_dtype) 655 656 def _shape(self): 657 matrix_shape = tensor_shape.TensorShape((self._num_rows_static, 658 self._num_rows_static)) 659 660 batch_shape = self.multiplier.shape 661 return batch_shape.concatenate(matrix_shape) 662 663 def _shape_tensor(self): 664 matrix_shape = array_ops.stack((self._num_rows, self._num_rows), axis=0) 665 666 batch_shape = array_ops.shape(self.multiplier) 667 return array_ops.concat((batch_shape, matrix_shape), 0) 668 669 def _assert_non_singular(self): 670 return check_ops.assert_positive( 671 math_ops.abs(self.multiplier), message="LinearOperator was singular") 672 673 def _assert_positive_definite(self): 674 return check_ops.assert_positive( 675 math_ops.real(self.multiplier), 676 message="LinearOperator was not positive definite.") 677 678 def _assert_self_adjoint(self): 679 imag_multiplier = math_ops.imag(self.multiplier) 680 return check_ops.assert_equal( 681 array_ops.zeros_like(imag_multiplier), 682 imag_multiplier, 683 message="LinearOperator was not self-adjoint") 684 685 def _make_multiplier_matrix(self, conjugate=False): 686 # Shape [B1,...Bb, 1, 1] 687 multiplier_matrix = array_ops.expand_dims( 688 array_ops.expand_dims(self.multiplier, -1), -1) 689 if conjugate: 690 multiplier_matrix = math_ops.conj(multiplier_matrix) 691 return multiplier_matrix 692 693 def _matmul(self, x, adjoint=False, adjoint_arg=False): 694 x = linalg.adjoint(x) if adjoint_arg else x 695 if self._assert_proper_shapes: 696 aps = linear_operator_util.assert_compatible_matrix_dimensions(self, x) 697 x = control_flow_ops.with_dependencies([aps], x) 698 return x * self._make_multiplier_matrix(conjugate=adjoint) 699 700 def _determinant(self): 701 return self.multiplier**self._num_rows_cast_to_dtype 702 703 def _log_abs_determinant(self): 704 return self._num_rows_cast_to_real_dtype * math_ops.log( 705 math_ops.abs(self.multiplier)) 706 707 def _solve(self, rhs, adjoint=False, adjoint_arg=False): 708 rhs = linalg.adjoint(rhs) if adjoint_arg else rhs 709 if self._assert_proper_shapes: 710 aps = linear_operator_util.assert_compatible_matrix_dimensions(self, rhs) 711 rhs = control_flow_ops.with_dependencies([aps], rhs) 712 return rhs / self._make_multiplier_matrix(conjugate=adjoint) 713 714 def _trace(self): 715 # Get Tensor of all ones of same shape as self.batch_shape. 716 if self.batch_shape.is_fully_defined(): 717 batch_of_ones = array_ops.ones(shape=self.batch_shape, dtype=self.dtype) 718 else: 719 batch_of_ones = array_ops.ones( 720 shape=self.batch_shape_tensor(), dtype=self.dtype) 721 722 if self._min_matrix_dim() is not None: 723 return self.multiplier * self._min_matrix_dim() * batch_of_ones 724 else: 725 return (self.multiplier * math_ops.cast(self._min_matrix_dim_tensor(), 726 self.dtype) * batch_of_ones) 727 728 def _diag_part(self): 729 return self._ones_diag() * self.multiplier[..., array_ops.newaxis] 730 731 def add_to_tensor(self, mat, name="add_to_tensor"): 732 """Add matrix represented by this operator to `mat`. Equiv to `I + mat`. 733 734 Args: 735 mat: `Tensor` with same `dtype` and shape broadcastable to `self`. 736 name: A name to give this `Op`. 737 738 Returns: 739 A `Tensor` with broadcast shape and same `dtype` as `self`. 740 """ 741 with self._name_scope(name): 742 # Shape [B1,...,Bb, 1] 743 multiplier_vector = array_ops.expand_dims(self.multiplier, -1) 744 745 # Shape [C1,...,Cc, M, M] 746 mat = ops.convert_to_tensor_v2_with_dispatch(mat, name="mat") 747 748 # Shape [C1,...,Cc, M] 749 mat_diag = array_ops.matrix_diag_part(mat) 750 751 # multiplier_vector broadcasts here. 752 new_diag = multiplier_vector + mat_diag 753 754 return array_ops.matrix_set_diag(mat, new_diag) 755 756 def _eigvals(self): 757 return self._ones_diag() * self.multiplier[..., array_ops.newaxis] 758 759 def _cond(self): 760 # Condition number for a scalar time identity matrix is one, except when the 761 # scalar is zero. 762 return array_ops.where_v2( 763 math_ops.equal(self._multiplier, 0.), 764 math_ops.cast(np.nan, dtype=self.dtype), 765 math_ops.cast(1., dtype=self.dtype)) 766 767 @property 768 def multiplier(self): 769 """The [batch] scalar `Tensor`, `c` in `cI`.""" 770 return self._multiplier 771