# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bijector base."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import contextlib
import re

import numpy as np
import six

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Bijector",
]


class _Mapping(collections.namedtuple(
    "_Mapping", ["x", "y", "ildj", "kwargs"])):
  """Helper class to make it easier to manage caching in `Bijector`."""

  def __new__(cls, x=None, y=None, ildj=None, kwargs=None):
    """Custom __new__ so namedtuple items have defaults.

    Args:
      x: `Tensor`. Forward.
      y: `Tensor`. Inverse.
      ildj: `Tensor`. Inverse log det Jacobian.
      kwargs: Python dictionary. Extra args supplied to
        forward/inverse/etc functions.

    Returns:
      mapping: New instance of _Mapping.
    """
    return super(_Mapping, cls).__new__(cls, x, y, ildj, kwargs)

  @property
  def x_key(self):
    """Returns key used for caching Y=g(X)."""
    return (self.x,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))

  @property
  def y_key(self):
    """Returns key used for caching X=g^{-1}(Y)."""
    return (self.y,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))

  def merge(self, x=None, y=None, ildj=None, kwargs=None, mapping=None):
    """Returns new _Mapping with args merged with self.

    Args:
      x: `Tensor`. Forward.
      y: `Tensor`. Inverse.
      ildj: `Tensor`. Inverse log det Jacobian.
      kwargs: Python dictionary. Extra args supplied to
        forward/inverse/etc functions.
      mapping: Instance of _Mapping to merge. Can only be specified if no
        other arg is specified.

    Returns:
      mapping: New instance of `_Mapping` which has inputs merged with self.

    Raises:
      ValueError: if mapping and any other arg is not `None`.
    """
    if mapping is None:
      mapping = _Mapping(x=x, y=y, ildj=ildj, kwargs=kwargs)
    elif not all(arg is None for arg in [x, y, ildj, kwargs]):
      raise ValueError("Cannot specify mapping and individual args.")
    return _Mapping(
        x=self._merge(self.x, mapping.x),
        y=self._merge(self.y, mapping.y),
        ildj=self._merge(self.ildj, mapping.ildj),
        kwargs=self._merge(self.kwargs, mapping.kwargs))

  def _merge(self, old, new):
    """Helper to merge which handles merging one value."""
    if old is None:
      return new
    elif new is not None and old != new:
      raise ValueError("Incompatible values: %s != %s" % (old, new))
    return old

  def _deep_tuple(self, x):
    """Converts lists of lists to tuples of tuples."""
    return (tuple(map(self._deep_tuple, x))
            if isinstance(x, (list, tuple)) else x)
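

# Illustrative sketch (comments only; not executed): `Bijector` below keeps
# two dicts of `_Mapping` namedtuples, keyed by a tensor plus the kwargs that
# produced the result. Assuming Tensors `x`, `y`, and `ildj`:
#
#   mapping = _Mapping(x=x, kwargs={})
#   mapping = mapping.merge(y=y, ildj=ildj)
#   cache_by_x[mapping.x_key] = mapping  # Hypothetical forward-lookup dict.
#   cache_by_y[mapping.y_key] = mapping  # Hypothetical inverse-lookup dict.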


@six.add_metaclass(abc.ABCMeta)
@tf_export("distributions.bijectors.Bijector")
class Bijector(object):
  r"""Interface for transformations of a `Distribution` sample.

  Bijectors can be used to represent any differentiable and injective
  (one to one) function defined on an open subset of `R^n`. Some non-injective
  transformations are also supported (see "Non Injective Transforms" below).

  #### Mathematical Details

  A `Bijector` implements a [smooth covering map](
  https://en.wikipedia.org/wiki/Local_diffeomorphism), i.e., a local
  diffeomorphism such that every point in the target has a neighborhood evenly
  covered by a map ([see also](
  https://en.wikipedia.org/wiki/Covering_space#Covering_of_a_manifold)).
  A `Bijector` is used by `TransformedDistribution` but can be generally used
  for transforming a `Distribution` generated `Tensor`. A `Bijector` is
  characterized by three operations:

  1. Forward\
     Useful for turning one random outcome into another random outcome from a
     different distribution.
  2. Inverse\
     Useful for "reversing" a transformation to compute one probability in
     terms of another.
  3. `(log o det o Jacobian o inverse)(x)`\
     "The log of the determinant of the matrix of all first-order partial
     derivatives of the inverse function."\
     Useful for inverting a transformation to compute one probability in terms
     of another. Geometrically, the absolute value of `det(Jacobian)` measures
     the local change in volume under the transformation and is used to
     rescale the probability density.

  By convention, transformations of random variables are named in terms of the
  forward transformation. The forward transformation creates samples; the
  inverse is useful for computing probabilities.

  #### Example Uses

  - Basic properties:

    ```python
    x = ...  # A tensor.
    # Evaluate forward transformation.
    fwd_x = my_bijector.forward(x)
    x == my_bijector.inverse(fwd_x)
    x != my_bijector.forward(fwd_x)  # Not equal because x != g(g(x)).
    ```

  - Computing a log-likelihood:

    ```python
    def transformed_log_prob(bijector, log_prob, x):
      return (bijector.inverse_log_det_jacobian(x) +
              log_prob(bijector.inverse(x)))
    ```

  - Transforming a random outcome:

    ```python
    def transformed_sample(bijector, x):
      return bijector.forward(x)
    ```
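
  - Building a `TransformedDistribution` (a sketch; assumes the
    `tf.contrib.distributions` layout of this codebase, where `Exp` lives in
    `distributions.bijectors`):

    ```python
    ds = tf.contrib.distributions
    # Y = exp(X), X ~ Normal(0, 1), so Y ~ LogNormal(0, 1).
    log_normal = ds.TransformedDistribution(
        distribution=ds.Normal(loc=0., scale=1.),
        bijector=ds.bijectors.Exp())
    ```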

  #### Example Bijectors

  - "Exponential"

    ```none
    Y = g(X) = exp(X)
    X ~ Normal(0, 1)  # Univariate.
    ```

    Implies:

    ```none
    g^{-1}(Y) = log(Y)
    |Jacobian(g^{-1})(y)| = 1 / y
    Y ~ LogNormal(0, 1), i.e.,
    prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
              = (1 / y) Normal(log(y); 0, 1)
    ```

    Here is an example of how one might implement the `Exp` bijector:

    ```python
    class Exp(Bijector):

      def __init__(self, event_ndims=0, validate_args=False, name="exp"):
        super(Exp, self).__init__(
            event_ndims=event_ndims, validate_args=validate_args, name=name)

      def _forward(self, x):
        return math_ops.exp(x)

      def _inverse(self, y):
        return math_ops.log(y)

      def _inverse_log_det_jacobian(self, y):
        return -self._forward_log_det_jacobian(self._inverse(y))

      def _forward_log_det_jacobian(self, x):
        if self.event_ndims is None:
          raise ValueError("Jacobian requires known event_ndims.")
        # Sum over the event dimensions, i.e., the trailing `event_ndims`
        # axes of `x`. (`_event_dims_tensor` returns those axis indices.)
        return math_ops.reduce_sum(x, axis=self._event_dims_tensor(x))
    ```

  - "Affine"

    ```none
    Y = g(X) = sqrtSigma * X + mu
    X ~ MultivariateNormal(0, I_d)
    ```

    Implies:

    ```none
    g^{-1}(Y) = inv(sqrtSigma) * (Y - mu)
    |Jacobian(g^{-1})(y)| = det(inv(sqrtSigma))
    Y ~ MultivariateNormal(mu, sqrtSigma), i.e.,
    prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
              = det(sqrtSigma)^(-1) *
                MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
    ```

  #### Jacobian

  The Jacobian is a reduction over event dims. To see this, consider the `Exp`
  `Bijector` applied to a `Tensor` which has sample, batch, and event (S, B, E)
  shape semantics. Suppose the `Tensor`'s partitioned-shape is `(S=[4], B=[2],
  E=[3, 3])`. The shape of the `Tensor` returned by `forward` and `inverse` is
  unchanged, i.e., `[4, 2, 3, 3]`. However the shape returned by
  `inverse_log_det_jacobian` is `[4, 2]` because the Jacobian is a reduction
  over the event dimensions.

  It is sometimes useful to implement the inverse Jacobian as the negative
  forward Jacobian. For example,

  ```python
  def _inverse_log_det_jacobian(self, y):
    return -self._forward_log_det_jacobian(self._inverse(y))  # Note negation.
  ```

  The correctness of this approach can be seen from the following claim.

  - Claim:

      Assume `Y = g(X)` is a bijection whose derivative exists and is nonzero
      for its domain, i.e., `dY/dX = d/dX g(X) != 0`. Then:

      ```none
      (log o det o jacobian o g^{-1})(Y) = -(log o det o jacobian o g)(X)
      ```

  - Proof:

      From the bijective, nonzero differentiability of `g`, the
      [inverse function theorem](
      https://en.wikipedia.org/wiki/Inverse_function_theorem)
      implies `g^{-1}` is differentiable in the image of `g`.
      Applying the chain rule to `y = g(x) = g(g^{-1}(y))` yields
      `I = g'(g^{-1}(y))*g^{-1}'(y)`.
      The same theorem also implies `g^{-1}'` is non-singular; therefore:
      `inv[ g'(g^{-1}(y)) ] = g^{-1}'(y)`.
      The claim follows from [properties of determinant](
      https://en.wikipedia.org/wiki/Determinant#Multiplicativity_and_matrix_groups).
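
  As a concrete check of the claim, take `g = exp` (scalar case), so
  `g^{-1} = log`, `y = exp(x)`, and both sides reduce to `-x`:

  ```none
  (log o det o jacobian o g^{-1})(y) = log(1 / y) = -log(y) = -x
  -(log o det o jacobian o g)(x) = -log(exp(x)) = -x
  ```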

  Generally it is preferable to directly implement the inverse Jacobian. This
  should have superior numerical stability and will often share subgraphs with
  the `_inverse` implementation.

  #### Subclass Requirements

  - Subclasses typically implement:

    - `_forward`,
    - `_inverse`,
    - `_inverse_log_det_jacobian`,
    - `_forward_log_det_jacobian` (optional).

    The `_forward_log_det_jacobian` is called when the bijector is inverted via
    the `Invert` bijector. If undefined, a slightly less efficient calculation,
    `-1 * _inverse_log_det_jacobian`, is used.

    If the bijector changes the shape of the input, you must also implement:

    - `_forward_event_shape_tensor`,
    - `_forward_event_shape` (optional),
    - `_inverse_event_shape_tensor`,
    - `_inverse_event_shape` (optional).

    By default the event-shape is assumed unchanged from input.

  - If the `Bijector`'s use is limited to `TransformedDistribution` (or friends
    like `QuantizedDistribution`) then depending on your use, you may not need
    to implement all of the `_forward` and `_inverse` functions.

    Examples:

    1. Sampling (e.g., `sample`) only requires `_forward` (a sketch follows
       below).
    2. Probability functions (e.g., `prob`, `cdf`, `survival`) only require
       `_inverse` (and related).
    3. Only calling probability functions on the output of `sample` means
       `_inverse` can be implemented as a cache lookup.

    See "Example Uses" [above] which shows how these functions are used to
    transform a distribution. (Note: `_forward` could theoretically be
    implemented as a cache lookup but this would require controlling the
    underlying sample generation mechanism.)
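
  For instance, here is a minimal sketch of case 1 above: a hypothetical
  forward-only `Shift` bijector. Since `_inverse` is left unimplemented, only
  sampling through it is supported:

  ```python
  class Shift(Bijector):
    # Computes Y = g(X) = X + shift. Forward-only: `_inverse` is omitted,
    # so only sampling (case 1 above) is supported.

    def __init__(self, shift, name="shift"):
      self._shift = shift
      super(Shift, self).__init__(event_ndims=0, is_constant_jacobian=True,
                                  name=name)

    def _forward(self, x):
      return x + self._shift
  ```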

  #### Non Injective Transforms

  **WARNING** Handling of non-injective transforms is subject to change.

  Non-injective maps `g` are supported, provided their domain `D` can be
  partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
  ignoring sets of measure zero, the restriction of `g` to each subset is a
  differentiable bijection onto `g(D)`. In particular, this implies that for
  `y in g(D)`, the set inverse, i.e. `g^{-1}(y) = {x in D : g(x) = y}`, always
  contains exactly `k` distinct points.

  The property `_is_injective` is set to `False` to indicate that the bijector
  is not injective, yet satisfies the above condition.

  The usual bijector API is modified in the case `_is_injective is False` (see
  method docstrings for specifics). Here we show by example the `AbsoluteValue`
  bijector. In this case, the domain `D = (-inf, inf)` can be partitioned
  into `D1 = (-inf, 0)`, `D2 = {0}`, and `D3 = (0, inf)`. Let `gi` be the
  restriction of `g` to `Di`; then both `g1` and `g3` are bijections onto
  `(0, inf)`, with `g1^{-1}(y) = -y`, and `g3^{-1}(y) = y`. We will use
  `g1` and `g3` to define bijector methods over `D1` and `D3`. `D2 = {0}` is
  an oddball in that `g2` is one to one, and the derivative is not well
  defined. Fortunately, when considering transformations of probability
  densities (e.g. in `TransformedDistribution`), sets of measure zero have no
  effect in theory, and only a small effect in 32 or 64 bit precision. For
  that reason, we define `inverse(0)` and `inverse_log_det_jacobian(0)` both
  as `(0., 0.)`, which is convenient and results in a left-semicontinuous pdf.

  ```python
  abs = tf.contrib.distributions.bijectors.AbsoluteValue()

  abs.forward(-1.)
  ==> 1.

  abs.forward(1.)
  ==> 1.

  abs.inverse(1.)
  ==> (-1., 1.)

  # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0.
  abs.inverse_log_det_jacobian(1.)
  ==> (0., 0.)

  # Special case handling of 0.
  abs.inverse(0.)
  ==> (0., 0.)

  abs.inverse_log_det_jacobian(0.)
  ==> (0., 0.)
  ```

  """

  @abc.abstractmethod
  def __init__(self,
               event_ndims=None,
               graph_parents=None,
               is_constant_jacobian=False,
               validate_args=False,
               dtype=None,
               name=None):
    """Constructs Bijector.

    A `Bijector` transforms random variables into new random variables.

    Examples:

    ```python
    # Create the Y = g(X) = X transform which operates on vector events.
    identity = Identity(event_ndims=1)

    # Create the Y = g(X) = exp(X) transform which operates on matrices.
    exp = Exp(event_ndims=2)
    ```

    See `Bijector` subclass docstring for more details and specific examples.

    Args:
      event_ndims: number of dimensions associated with event coordinates.
      graph_parents: Python list of graph prerequisites of this `Bijector`.
      is_constant_jacobian: Python `bool` indicating that the Jacobian is not a
        function of the input.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are
        invalid, correct behavior is not guaranteed.
      dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
        enforced.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: If a member of `graph_parents` is not a `Tensor`.
    """
    self._event_ndims = (
        ops.convert_to_tensor(event_ndims, dtype=dtypes.int32)
        if event_ndims is not None else None)
    self._graph_parents = graph_parents or []
    self._is_constant_jacobian = is_constant_jacobian
    self._validate_args = validate_args
    self._dtype = dtype
    self._from_y = {}
    self._from_x = {}
    # Using abbreviation ildj for "inverse log det Jacobian."
    # This variable is not `None` iff is_constant_jacobian is `True`.
    self._constant_ildj = None
    if name:
      self._name = name
    else:
      # We want the default convention to be snake_case rather than CamelCase
      # since `Chain` uses bijector.name as the kwargs dictionary key.
      def camel_to_snake(name):
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
      self._name = camel_to_snake(type(self).__name__.lstrip("_"))

    for i, t in enumerate(self._graph_parents):
      if t is None or not tensor_util.is_tensor(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
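
  # Example of the default naming convention above (illustrative):
  #
  #   camel_to_snake("SoftmaxCentered")
  #   # ==> "softmax_centered"
  #
  # so an unnamed `SoftmaxCentered` bijector is named "softmax_centered",
  # which `Chain` can then use as its kwargs dictionary key.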

  @property
  def event_ndims(self):
    """Returns the number of event dimensions this bijector operates on."""
    return self._event_ndims

  @property
  def graph_parents(self):
    """Returns this `Bijector`'s graph_parents as a Python list."""
    return self._graph_parents

  @property
  def is_constant_jacobian(self):
    """Returns true iff the Jacobian is not a function of x.

    Note: Jacobian is either constant for both forward and inverse or neither.

    Returns:
      is_constant_jacobian: Python `bool`.
    """
    return self._is_constant_jacobian

  @property
  def _is_injective(self):
    """Returns true iff the forward map `g` is injective (one-to-one function).

    **WARNING** This hidden property and its behavior are subject to change.

    Note: Non-injective maps `g` are supported, provided their domain `D` can
    be partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
    ignoring sets of measure zero, the restriction of `g` to each subset is a
    differentiable bijection onto `g(D)`.

    Returns:
      is_injective: Python `bool`.
    """
    return True

  @property
  def validate_args(self):
    """Returns True if Tensor arguments will be validated."""
    return self._validate_args

  @property
  def dtype(self):
    """dtype of `Tensor`s transformable by this `Bijector`."""
    return self._dtype

  @property
  def name(self):
    """Returns the string name of this `Bijector`."""
    return self._name

  def _forward_event_shape_tensor(self, input_shape):
    """Subclass implementation for `forward_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape_tensor(self,
                                 input_shape,
                                 name="forward_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      input_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `forward` function.
      name: The name to give this op.

    Returns:
      forward_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `forward`.
    """
    with self._name_scope(name, [input_shape]):
      input_shape = ops.convert_to_tensor(input_shape, dtype=dtypes.int32,
                                          name="input_shape")
      return self._forward_event_shape_tensor(input_shape)

  def _forward_event_shape(self, input_shape):
    """Subclass implementation for `forward_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape(self, input_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `forward_event_shape_tensor`. May be only partially defined.

    Args:
      input_shape: `TensorShape` indicating event-portion shape passed into
        `forward` function.

    Returns:
      forward_event_shape_tensor: `TensorShape` indicating event-portion shape
        after applying `forward`. Possibly unknown.
    """
    return self._forward_event_shape(tensor_shape.TensorShape(input_shape))
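
  # Example of the event-shape methods (illustrative): for a shape-changing
  # bijector such as `SoftmaxCentered`, which maps size-`d` vectors to
  # size-`d + 1` points on the simplex:
  #
  #   softmax = SoftmaxCentered(event_ndims=1)
  #   softmax.forward_event_shape(tensor_shape.TensorShape([3]))
  #   # ==> TensorShape([4])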

  def _inverse_event_shape_tensor(self, output_shape):
    """Subclass implementation for `inverse_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return output_shape

  def inverse_event_shape_tensor(self,
                                 output_shape,
                                 name="inverse_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      output_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `inverse` function.
      name: The name to give this op.

    Returns:
      inverse_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `inverse`.
    """
    with self._name_scope(name, [output_shape]):
      output_shape = ops.convert_to_tensor(output_shape, dtype=dtypes.int32,
                                           name="output_shape")
      return self._inverse_event_shape_tensor(output_shape)

  def _inverse_event_shape(self, output_shape):
    """Subclass implementation for `inverse_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return tensor_shape.TensorShape(output_shape)

  def inverse_event_shape(self, output_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `inverse_event_shape_tensor`. May be only partially defined.

    Args:
      output_shape: `TensorShape` indicating event-portion shape passed into
        `inverse` function.

    Returns:
      inverse_event_shape_tensor: `TensorShape` indicating event-portion shape
        after applying `inverse`. Possibly unknown.
    """
    return self._inverse_event_shape(output_shape)

  def _forward(self, x):
    """Subclass implementation for `forward` public function."""
    raise NotImplementedError("forward not implemented.")

  def _call_forward(self, x, name, **kwargs):
    with self._name_scope(name, [x]):
      x = ops.convert_to_tensor(x, name="x")
      self._maybe_assert_dtype(x)
      if not self._is_injective:  # No caching for non-injective.
        return self._forward(x, **kwargs)
      mapping = self._lookup(x=x, kwargs=kwargs)
      if mapping.y is not None:
        return mapping.y
      mapping = mapping.merge(y=self._forward(x, **kwargs))
      self._cache(mapping)
      return mapping.y

  def forward(self, x, name="forward"):
    """Returns the forward `Bijector` evaluation, i.e., Y = g(X).

    Args:
      x: `Tensor`. The input to the "forward" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_forward` is not implemented.
    """
    return self._call_forward(x, name)

  def _inverse(self, y):
    """Subclass implementation for `inverse` public function."""
    raise NotImplementedError("inverse not implemented.")

  def _call_inverse(self, y, name, **kwargs):
    with self._name_scope(name, [y]):
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      if not self._is_injective:  # No caching for non-injective.
        return self._inverse(y, **kwargs)
      mapping = self._lookup(y=y, kwargs=kwargs)
      if mapping.x is not None:
        return mapping.x
      mapping = mapping.merge(x=self._inverse(y, **kwargs))
      self._cache(mapping)
      return mapping.x

  def inverse(self, y, name="inverse"):
    """Returns the inverse `Bijector` evaluation, i.e., X = g^{-1}(Y).

    Args:
      y: `Tensor`. The input to the "inverse" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
      If not injective, returns the k-tuple containing the unique
      `k` points `(x1, ..., xk)` such that `g(xi) = y`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse` is not implemented.
    """
    return self._call_inverse(y, name)
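
  # Example of the built-in caching (illustrative): for injective bijectors,
  # `inverse` applied to the output of `forward` is a cache lookup rather
  # than a new graph op:
  #
  #   y = my_bijector.forward(x)  # Caches the (x, y) mapping.
  #   my_bijector.inverse(y)      # Returns the cached `x`; `_inverse` is
  #                               # never called.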

  def _inverse_log_det_jacobian(self, y):
    """Subclass implementation of `inverse_log_det_jacobian` public function."""
    raise NotImplementedError("inverse_log_det_jacobian not implemented.")

  def _call_inverse_log_det_jacobian(self, y, name, **kwargs):
    with self._name_scope(name, [y]):
      if self._constant_ildj is not None:
        return self._constant_ildj
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      if not self._is_injective:  # No caching for non-injective.
        return self._inverse_log_det_jacobian(y, **kwargs)
      mapping = self._lookup(y=y, kwargs=kwargs)
      if mapping.ildj is not None:
        return mapping.ildj
      try:
        x = None  # Not needed; leave cache as is.
        ildj = self._inverse_log_det_jacobian(y, **kwargs)
      except NotImplementedError as original_exception:
        try:
          x = mapping.x if mapping.x is not None else self._inverse(y, **kwargs)
          ildj = -self._forward_log_det_jacobian(x, **kwargs)
        except NotImplementedError:
          raise original_exception
      mapping = mapping.merge(x=x, ildj=ildj)
      self._cache(mapping)
      if self.is_constant_jacobian:
        self._constant_ildj = mapping.ildj
      return mapping.ildj

  def inverse_log_det_jacobian(self, y, name="inverse_log_det_jacobian"):
    """Returns the (log o det o Jacobian o inverse)(y).

    Mathematically, returns: `log(det(dX/dY))(Y)`. (Recall that: `X=g^{-1}(Y)`.)

    Note that `forward_log_det_jacobian` is the negative of this function,
    evaluated at `g^{-1}(y)`.

    Args:
      y: `Tensor`. The input to the "inverse" Jacobian evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
      If not injective, returns the tuple of local log det
      Jacobians, `log(det(Dg_i^{-1}(y)))`, where `g_i` is the restriction
      of `g` to the `ith` partition `Di`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse_log_det_jacobian` is not implemented.
    """
    return self._call_inverse_log_det_jacobian(y, name)

  def _forward_log_det_jacobian(self, x):
    """Subclass implementation of `forward_log_det_jacobian`."""
    raise NotImplementedError(
        "forward_log_det_jacobian not implemented.")

  def _call_forward_log_det_jacobian(self, x, name, **kwargs):
    with self._name_scope(name, [x]):
      if self._constant_ildj is not None:
        # Need "-1. *" to avoid invalid-unary-operand-type linter warning.
        return -1. * self._constant_ildj
      x = ops.convert_to_tensor(x, name="x")
      self._maybe_assert_dtype(x)
      if not self._is_injective:
        return self._forward_log_det_jacobian(x, **kwargs)  # No caching.
      mapping = self._lookup(x=x, kwargs=kwargs)
      if mapping.ildj is not None:
        return -mapping.ildj
      try:
        y = None  # Not needed; leave cache as is.
        ildj = -self._forward_log_det_jacobian(x, **kwargs)
      except NotImplementedError as original_exception:
        try:
          y = mapping.y if mapping.y is not None else self._forward(x, **kwargs)
          ildj = self._inverse_log_det_jacobian(y, **kwargs)
        except NotImplementedError:
          raise original_exception
      mapping = mapping.merge(y=y, ildj=ildj)
      self._cache(mapping)
      if self.is_constant_jacobian:
        self._constant_ildj = mapping.ildj
      return -mapping.ildj

  def forward_log_det_jacobian(self, x, name="forward_log_det_jacobian"):
    """Returns the (log o det o Jacobian o forward)(x).

    Mathematically, returns: `log(det(dY/dX))(X)`. (Recall that: `Y=g(X)`.)

    Args:
      x: `Tensor`. The input to the "forward" Jacobian evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
      If not injective this is not implemented.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      NotImplementedError: if neither `_forward_log_det_jacobian`
        nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented, or
        this is a non-injective bijector.
    """
    if not self._is_injective:
      raise NotImplementedError(
          "forward_log_det_jacobian cannot be implemented for non-injective "
          "transforms.")
    return self._call_forward_log_det_jacobian(x, name)
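
  # Example (illustrative): for the scalar `Exp` bijector sketched in the
  # class docstring, with `y = exp(x)`:
  #
  #   exp.forward_log_det_jacobian(x)  # ==> x, since log|dY/dX| = x.
  #   exp.inverse_log_det_jacobian(y)  # ==> -log(y) = -x.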

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(
          name, values=(values or []) + self.graph_parents) as scope:
        yield scope

  def _maybe_assert_dtype(self, x):
    """Helper to check dtype when self.dtype is known."""
    if self.dtype is not None and self.dtype.base_dtype != x.dtype.base_dtype:
      raise TypeError("Input had dtype %s but expected %s." %
                      (x.dtype, self.dtype))

  def _cache(self, mapping):
    """Helper which stores mapping info in forward/inverse dicts."""
    if self._constant_ildj is not None:
      # Fold in ildj if known constant Jacobian.
      mapping = mapping.merge(ildj=self._constant_ildj)
    # Merging from lookup is an added check that we're not overwriting anything
    # which is not None.
    mapping = mapping.merge(mapping=self._lookup(
        mapping.x, mapping.y, mapping.kwargs))
    if mapping.x is None and mapping.y is None:
      raise ValueError("Caching expects at least one of (x, y) to be known, "
                       "i.e., not None.")
    self._from_x[mapping.x_key] = mapping
    self._from_y[mapping.y_key] = mapping

  def _lookup(self, x=None, y=None, kwargs=None):
    """Helper which retrieves mapping info from forward/inverse dicts."""
    mapping = _Mapping(x=x, y=y, kwargs=kwargs)
    # Since _cache requires both x and y to be set, we only need to do one
    # cache lookup: the mapping is always in both dicts or in neither.
    if mapping.x is not None:
      return self._from_x.get(mapping.x_key, mapping)
    if mapping.y is not None:
      return self._from_y.get(mapping.y_key, mapping)
    return mapping

  def _event_dims_tensor(self, sample):
    """Return a 1D `int32` tensor: `range(rank(sample))[-event_ndims:]`."""
    if self.event_ndims is None:
      raise ValueError("Jacobian cannot be computed with unknown event_ndims.")
    static_event_ndims = tensor_util.constant_value(self.event_ndims)
    static_rank = sample.get_shape().ndims
    if static_event_ndims is not None and static_rank is not None:
      return ops.convert_to_tensor(
          static_rank + np.arange(-static_event_ndims, 0).astype(np.int32))

    if static_event_ndims is not None:
      event_range = np.arange(-static_event_ndims, 0).astype(np.int32)
    else:
      event_range = math_ops.range(-self.event_ndims, 0, dtype=dtypes.int32)

    if static_rank is not None:
      return event_range + static_rank
    else:
      return event_range + array_ops.rank(sample)