# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bijector base."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import contextlib
import re

import numpy as np
import six

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


__all__ = [
    "Bijector",
]


class _Mapping(collections.namedtuple(
    "_Mapping", ["x", "y", "ildj", "kwargs"])):
  """Helper class to make it easier to manage caching in `Bijector`."""

  def __new__(cls, x=None, y=None, ildj=None, kwargs=None):
    """Custom __new__ so namedtuple items have defaults.

    Args:
      x: `Tensor`. Forward.
      y: `Tensor`. Inverse.
      ildj: `Tensor`. Inverse log det Jacobian.
      kwargs: Python dictionary. Extra args supplied to
        forward/inverse/etc functions.

    Returns:
      mapping: New instance of _Mapping.
    """
    return super(_Mapping, cls).__new__(cls, x, y, ildj, kwargs)

  @property
  def x_key(self):
    """Returns key used for caching Y=g(X)."""
    return (self.x,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))

  @property
  def y_key(self):
    """Returns key used for caching X=g^{-1}(Y)."""
    return (self.y,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))

  def merge(self, x=None, y=None, ildj=None, kwargs=None, mapping=None):
    """Returns new _Mapping with args merged with self.

    Args:
      x: `Tensor`. Forward.
      y: `Tensor`. Inverse.
      ildj: `Tensor`. Inverse log det Jacobian.
      kwargs: Python dictionary. Extra args supplied to
        forward/inverse/etc functions.
      mapping: Instance of _Mapping to merge. Can only be specified if no other
        arg is specified.

    Returns:
      mapping: New instance of `_Mapping` which has inputs merged with self.

    Raises:
      ValueError: if mapping and any other arg is not `None`.
    """
    if mapping is None:
      mapping = _Mapping(x=x, y=y, ildj=ildj, kwargs=kwargs)
    elif not all(arg is None for arg in [x, y, ildj, kwargs]):
      raise ValueError("Cannot specify mapping and individual args.")
    return _Mapping(
        x=self._merge(self.x, mapping.x),
        y=self._merge(self.y, mapping.y),
        ildj=self._merge(self.ildj, mapping.ildj),
        kwargs=self._merge(self.kwargs, mapping.kwargs))

  def _merge(self, old, new):
    """Helper for `merge` which merges a single value."""
    if old is None:
      return new
    elif new is not None and old != new:
      raise ValueError("Incompatible values: %s != %s" % (old, new))
    return old

  def _deep_tuple(self, x):
    """Converts lists of lists to tuples of tuples."""
    return (tuple(map(self._deep_tuple, x))
            if isinstance(x, (list, tuple)) else x)


@six.add_metaclass(abc.ABCMeta)
@tf_export("distributions.bijectors.Bijector")
class Bijector(object):
  r"""Interface for transformations of a `Distribution` sample.

  Bijectors can be used to represent any differentiable and injective
  (one-to-one) function defined on an open subset of `R^n`.  Some non-injective
  transformations are also supported (see "Non Injective Transforms" below).

  #### Mathematical Details

  A `Bijector` implements a [smooth covering map](
  https://en.wikipedia.org/wiki/Local_diffeomorphism), i.e., a local
  diffeomorphism such that every point in the target has a neighborhood evenly
  covered by the map ([see also](
  https://en.wikipedia.org/wiki/Covering_space#Covering_of_a_manifold)).
  A `Bijector` is used by `TransformedDistribution` but can be generally used
  for transforming a `Distribution` generated `Tensor`. A `Bijector` is
  characterized by three operations:

  1. Forward\
     Useful for turning one random outcome into another random outcome from a
     different distribution.
  2. Inverse\
     Useful for "reversing" a transformation to compute one probability in
     terms of another.
  3. `(log o det o Jacobian o inverse)(x)`\
     "The log of the determinant of the matrix of all first-order partial
     derivatives of the inverse function."\
     Useful for inverting a transformation to compute one probability in terms
     of another. Geometrically, the det(Jacobian) is the volume of the
     transformation and is used to scale the probability.

  By convention, transformations of random variables are named in terms of the
  forward transformation. The forward transformation creates samples; the
  inverse is useful for computing probabilities.

  #### Example Uses

  - Basic properties:

  ```python
  x = ...  # A tensor.
  # Evaluate forward transformation.
  fwd_x = my_bijector.forward(x)
  x == my_bijector.inverse(fwd_x)
  x != my_bijector.forward(fwd_x)  # Not equal because x != g(g(x)).
  ```

  - Computing a log-likelihood:

  ```python
  def transformed_log_prob(bijector, log_prob, x):
    return (bijector.inverse_log_det_jacobian(x) +
            log_prob(bijector.inverse(x)))
  ```

  - Transforming a random outcome:

  ```python
  def transformed_sample(bijector, x):
    return bijector.forward(x)
  ```

  #### Example Bijectors

  - "Exponential"

    ```none
    Y = g(X) = exp(X)
    X ~ Normal(0, 1)  # Univariate.
    ```

    Implies:

    ```none
      g^{-1}(Y) = log(Y)
      |Jacobian(g^{-1})(y)| = 1 / y
      Y ~ LogNormal(0, 1), i.e.,
      prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
                = (1 / y) Normal(log(y); 0, 1)
    ```

    Here is an example of how one might implement the `Exp` bijector:

    ```python
      class Exp(Bijector):

        def __init__(self, event_ndims=0, validate_args=False, name="exp"):
          super(Exp, self).__init__(
              event_ndims=event_ndims, validate_args=validate_args, name=name)

        def _forward(self, x):
          return math_ops.exp(x)

        def _inverse(self, y):
          return math_ops.log(y)

        def _inverse_log_det_jacobian(self, y):
          return -self._forward_log_det_jacobian(self._inverse(y))

        def _forward_log_det_jacobian(self, x):
          if self.event_ndims is None:
            raise ValueError("Jacobian requires known event_ndims.")
          # Sum over the event dimensions, i.e., the last `event_ndims` axes
          # of `x`; `self._event_dims_tensor` computes their indices.
          event_dims = self._event_dims_tensor(x)
          return math_ops.reduce_sum(x, axis=event_dims)
    ```

  - "Affine"

    ```none
    Y = g(X) = sqrtSigma * X + mu
    X ~ MultivariateNormal(0, I_d)
    ```

    Implies:

    ```none
      g^{-1}(Y) = inv(sqrtSigma) * (Y - mu)
      |Jacobian(g^{-1})(y)| = det(inv(sqrtSigma))
      Y ~ MultivariateNormal(mu, sqrtSigma), i.e.,
      prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
                = det(sqrtSigma)^(-d) *
                  MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
    ```

  #### Jacobian

  The Jacobian is a reduction over event dims. To see this, consider the `Exp`
  `Bijector` applied to a `Tensor` which has sample, batch, and event (S, B, E)
  shape semantics. Suppose the `Tensor`'s partitioned-shape is `(S=[4], B=[2],
  E=[3, 3])`. The shape of the `Tensor` returned by `forward` and `inverse` is
  unchanged, i.e., `[4, 2, 3, 3]`.  However the shape returned by
  `inverse_log_det_jacobian` is `[4, 2]` because the Jacobian is a reduction
  over the event dimensions.
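
  For instance, a rough sketch of these shape semantics using the `Exp`
  bijector implemented above (the shapes in comments are the expected
  results, not verified output):

  ```python
  exp = Exp(event_ndims=2)
  x = tf.ones([4, 2, 3, 3])               # S=[4], B=[2], E=[3, 3].
  y = exp.forward(x)                      # Shape [4, 2, 3, 3]; unchanged.
  ildj = exp.inverse_log_det_jacobian(y)  # Shape [4, 2]; events reduced away.
  ```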

  It is sometimes useful to implement the inverse Jacobian as the negative
  forward Jacobian. For example,

  ```python
  def _inverse_log_det_jacobian(self, y):
    return -self._forward_log_det_jacobian(self._inverse(y))  # Note negation.
  ```

  The correctness of this approach can be seen from the following claim.

  - Claim:

      Assume `Y = g(X)` is a bijection whose derivative exists and is nonzero
      for its domain, i.e., `dY/dX = d/dX g(X) != 0`. Then:

      ```none
      (log o det o jacobian o g^{-1})(Y) = -(log o det o jacobian o g)(X)
      ```

  - Proof:

      From the bijective, nonzero differentiability of `g`, the
      [inverse function theorem](
          https://en.wikipedia.org/wiki/Inverse_function_theorem)
      implies `g^{-1}` is differentiable in the image of `g`.
      Applying the chain rule to `y = g(x) = g(g^{-1}(y))` yields
      `I = g'(g^{-1}(y))*g^{-1}'(y)`.
      The same theorem also implies `g^{-1}'` is non-singular; therefore:
      `inv[ g'(g^{-1}(y)) ] = g^{-1}'(y)`.
      The claim follows from [properties of determinant](
  https://en.wikipedia.org/wiki/Determinant#Multiplicativity_and_matrix_groups).
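
  As a quick numerical sanity check of this claim (a sketch using `g = exp`,
  chosen purely for illustration):

  ```python
  import numpy as np
  x = 0.5
  y = np.exp(x)                     # y = g(x)
  fldj = np.log(np.abs(np.exp(x)))  # (log o det o jacobian o g)(x)
  ildj = np.log(np.abs(1. / y))     # (log o det o jacobian o g^{-1})(y)
  # ildj == -fldj, up to floating-point error.
  ```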

  Generally, it's preferable to directly implement the inverse Jacobian. This
  should have superior numerical stability and will often share subgraphs with
  the `_inverse` implementation.

  #### Subclass Requirements

  - Subclasses typically implement:

      - `_forward`,
      - `_inverse`,
      - `_inverse_log_det_jacobian`,
      - `_forward_log_det_jacobian` (optional).

    The `_forward_log_det_jacobian` is called when the bijector is inverted via
    the `Invert` bijector. If undefined, a slightly less efficient calculation,
    `-1 * _inverse_log_det_jacobian`, is used.

    If the bijector changes the shape of the input, you must also implement
    (see the sketch following this list):

      - `_forward_event_shape_tensor`,
      - `_forward_event_shape` (optional),
      - `_inverse_event_shape_tensor`,
      - `_inverse_event_shape` (optional).

    By default the event-shape is assumed unchanged from input.

  - If the `Bijector`'s use is limited to `TransformedDistribution` (or friends
    like `QuantizedDistribution`) then depending on your use, you may not need
    to implement all of the `_forward` and `_inverse` functions.

    Examples:

      1. Sampling (e.g., `sample`) only requires `_forward`.
      2. Probability functions (e.g., `prob`, `cdf`, `survival`) only require
         `_inverse` (and related).
      3. Only calling probability functions on the output of `sample` means
         `_inverse` can be implemented as a cache lookup.

    See "Example Uses" [above] which shows how these functions are used to
    transform a distribution. (Note: `_forward` could theoretically be
    implemented as a cache lookup but this would require controlling the
    underlying sample generation mechanism.)
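
  A minimal sketch of these shape-changing hooks, using a hypothetical
  `FlattenEvent` bijector which flattens a `[2, 3]` event into a `[6]` event
  (illustrative only; TensorFlow's real shape-changing bijectors are more
  general):

  ```python
  class FlattenEvent(Bijector):
    # Flattens a [2, 3] event into a [6] event.

    def __init__(self, validate_args=False, name="flatten_event"):
      super(FlattenEvent, self).__init__(
          validate_args=validate_args, name=name)

    def _forward(self, x):
      new_shape = array_ops.concat([array_ops.shape(x)[:-2], [6]], axis=0)
      return array_ops.reshape(x, new_shape)

    def _inverse(self, y):
      new_shape = array_ops.concat([array_ops.shape(y)[:-1], [2, 3]], axis=0)
      return array_ops.reshape(y, new_shape)

    def _inverse_log_det_jacobian(self, y):
      return array_ops.zeros([], dtype=y.dtype)  # Reshapes preserve volume.

    def _forward_event_shape_tensor(self, input_shape):
      return ops.convert_to_tensor([6], dtype=dtypes.int32)

    def _inverse_event_shape_tensor(self, output_shape):
      return ops.convert_to_tensor([2, 3], dtype=dtypes.int32)
  ```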

  #### Non Injective Transforms

  **WARNING** Handling of non-injective transforms is subject to change.

  Non-injective maps `g` are supported, provided their domain `D` can be
  partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
  ignoring sets of measure zero, the restriction of `g` to each subset is a
  differentiable bijection onto `g(D)`.  In particular, this implies that for
  `y in g(D)`, the set inverse, i.e. `g^{-1}(y) = {x in D : g(x) = y}`, always
  contains exactly `k` distinct points.

  The property `_is_injective` is set to `False` to indicate that the bijector
  is not injective, yet satisfies the above condition.

  The usual bijector API is modified in the case `_is_injective is False` (see
  method docstrings for specifics).  Here we show by example the `AbsoluteValue`
  bijector.  In this case, the domain `D = (-inf, inf)` can be partitioned
  into `D1 = (-inf, 0)`, `D2 = {0}`, and `D3 = (0, inf)`.  Let `gi` be the
  restriction of `g` to `Di`, then both `g1` and `g3` are bijections onto
  `(0, inf)`, with `g1^{-1}(y) = -y`, and `g3^{-1}(y) = y`.  We will use
  `g1` and `g3` to define bijector methods over `D1` and `D3`.  `D2 = {0}` is
  an oddball in that `g2` is one-to-one, and the derivative is not well defined.
  Fortunately, when considering transformations of probability densities
  (e.g. in `TransformedDistribution`), sets of measure zero have no effect in
  theory, and only a small effect in 32 or 64 bit precision.  For that reason,
  we define `inverse(0)` and `inverse_log_det_jacobian(0)` both as `[0, 0]`,
  which is convenient and results in a left-semicontinuous pdf.

  ```python
  abs = tf.contrib.distributions.bijectors.AbsoluteValue()

  abs.forward(-1.)
  ==> 1.

  abs.forward(1.)
  ==> 1.

  abs.inverse(1.)
  ==> (-1., 1.)

  # The |dX/dY| is constant, == 1.  So log|dX/dY| == 0.
  abs.inverse_log_det_jacobian(1.)
  ==> (0., 0.)

  # Special case handling of 0.
  abs.inverse(0.)
  ==> (0., 0.)

  abs.inverse_log_det_jacobian(0.)
  ==> (0., 0.)
  ```

  """

  @abc.abstractmethod
  def __init__(self,
               event_ndims=None,
               graph_parents=None,
               is_constant_jacobian=False,
               validate_args=False,
               dtype=None,
               name=None):
    """Constructs Bijector.

    A `Bijector` transforms random variables into new random variables.

    Examples:

    ```python
    # Create the Y = g(X) = X transform which operates on vector events.
    identity = Identity(event_ndims=1)

    # Create the Y = g(X) = exp(X) transform which operates on matrices.
    exp = Exp(event_ndims=2)
    ```

    See `Bijector` subclass docstring for more details and specific examples.

    Args:
      event_ndims: number of dimensions associated with event coordinates.
      graph_parents: Python list of graph prerequisites of this `Bijector`.
      is_constant_jacobian: Python `bool` indicating that the Jacobian is not a
        function of the input.
      validate_args: Python `bool`, default `False`. Whether to validate input
        with asserts. If `validate_args` is `False`, and the inputs are invalid,
        correct behavior is not guaranteed.
      dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
        enforced.
      name: The name to give Ops created by the initializer.

    Raises:
      ValueError: If a member of `graph_parents` is not a `Tensor`.
    """
    self._event_ndims = (
        ops.convert_to_tensor(event_ndims, dtype=dtypes.int32)
        if event_ndims is not None else None)
    self._graph_parents = graph_parents or []
    self._is_constant_jacobian = is_constant_jacobian
    self._validate_args = validate_args
    self._dtype = dtype
    self._from_y = {}
    self._from_x = {}
    # Using abbreviation ildj for "inverse log det Jacobian."
    # If `is_constant_jacobian` is `True`, this is lazily set to the cached
    # constant ildj value; otherwise it stays `None`.
    self._constant_ildj = None
    if name:
      self._name = name
    else:
      # We want the default convention to be snake_case rather than CamelCase
      # since `Chain` uses bijector.name as the kwargs dictionary key.
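      # E.g., "CholeskyOuterProduct" becomes "cholesky_outer_product".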
      def camel_to_snake(name):
        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
      self._name = camel_to_snake(type(self).__name__.lstrip("_"))

    for i, t in enumerate(self._graph_parents):
      if t is None or not tensor_util.is_tensor(t):
        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))

  @property
  def event_ndims(self):
    """Returns the number of event dimensions this bijector operates on."""
    return self._event_ndims

  @property
  def graph_parents(self):
    """Returns this `Bijector`'s graph_parents as a Python list."""
    return self._graph_parents

  @property
  def is_constant_jacobian(self):
    """Returns true iff the Jacobian is not a function of x.

    Note: Jacobian is either constant for both forward and inverse or neither.

    Returns:
      is_constant_jacobian: Python `bool`.
    """
    return self._is_constant_jacobian

  @property
  def _is_injective(self):
    """Returns true iff the forward map `g` is injective (one-to-one function).

    **WARNING** This hidden property and its behavior are subject to change.

    Note: Non-injective maps `g` are supported, provided their domain `D` can
    be partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
    ignoring sets of measure zero, the restriction of `g` to each subset is a
    differentiable bijection onto `g(D)`.

    Returns:
      is_injective: Python `bool`.
    """
    return True

  @property
  def validate_args(self):
    """Returns True if Tensor arguments will be validated."""
    return self._validate_args

  @property
  def dtype(self):
    """dtype of `Tensor`s transformable by this bijector."""
    return self._dtype

  @property
  def name(self):
    """Returns the string name of this `Bijector`."""
    return self._name

  def _forward_event_shape_tensor(self, input_shape):
    """Subclass implementation for `forward_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape_tensor(self,
                                 input_shape,
                                 name="forward_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      input_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `forward` function.
      name: The name to give this op.

    Returns:
      forward_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `forward`.
    """
    with self._name_scope(name, [input_shape]):
      input_shape = ops.convert_to_tensor(input_shape, dtype=dtypes.int32,
                                          name="input_shape")
      return self._forward_event_shape_tensor(input_shape)

  def _forward_event_shape(self, input_shape):
    """Subclass implementation for `forward_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return input_shape

  def forward_event_shape(self, input_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `forward_event_shape_tensor`. May be only partially defined.

    Args:
      input_shape: `TensorShape` indicating event-portion shape passed into
        `forward` function.

    Returns:
      forward_event_shape: `TensorShape` indicating event-portion shape
        after applying `forward`. Possibly unknown.
    """
    return self._forward_event_shape(tensor_shape.TensorShape(input_shape))

  def _inverse_event_shape_tensor(self, output_shape):
    """Subclass implementation for `inverse_event_shape_tensor` function."""
    # By default, we assume event_shape is unchanged.
    return output_shape

  def inverse_event_shape_tensor(self,
                                 output_shape,
                                 name="inverse_event_shape_tensor"):
    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.

    Args:
      output_shape: `Tensor`, `int32` vector indicating event-portion shape
        passed into `inverse` function.
      name: The name to give this op.

    Returns:
      inverse_event_shape_tensor: `Tensor`, `int32` vector indicating
        event-portion shape after applying `inverse`.
    """
    with self._name_scope(name, [output_shape]):
      output_shape = ops.convert_to_tensor(output_shape, dtype=dtypes.int32,
                                           name="output_shape")
      return self._inverse_event_shape_tensor(output_shape)

  def _inverse_event_shape(self, output_shape):
    """Subclass implementation for `inverse_event_shape` public function."""
    # By default, we assume event_shape is unchanged.
    return tensor_shape.TensorShape(output_shape)

  def inverse_event_shape(self, output_shape):
    """Shape of a single sample from a single batch as a `TensorShape`.

    Same meaning as `inverse_event_shape_tensor`. May be only partially defined.

    Args:
      output_shape: `TensorShape` indicating event-portion shape passed into
        `inverse` function.

    Returns:
      inverse_event_shape: `TensorShape` indicating event-portion shape
        after applying `inverse`. Possibly unknown.
    """
    return self._inverse_event_shape(output_shape)

  def _forward(self, x):
    """Subclass implementation for `forward` public function."""
    raise NotImplementedError("forward not implemented.")

  def _call_forward(self, x, name, **kwargs):
    with self._name_scope(name, [x]):
      x = ops.convert_to_tensor(x, name="x")
      self._maybe_assert_dtype(x)
      if not self._is_injective:  # No caching for non-injective
        return self._forward(x, **kwargs)
      mapping = self._lookup(x=x, kwargs=kwargs)
      if mapping.y is not None:
        return mapping.y
      mapping = mapping.merge(y=self._forward(x, **kwargs))
      self._cache(mapping)
      return mapping.y

  def forward(self, x, name="forward"):
    """Returns the forward `Bijector` evaluation, i.e., Y = g(X).

    Args:
      x: `Tensor`. The input to the "forward" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_forward` is not implemented.
    """
    return self._call_forward(x, name)

  def _inverse(self, y):
    """Subclass implementation for `inverse` public function."""
    raise NotImplementedError("inverse not implemented.")

  def _call_inverse(self, y, name, **kwargs):
    with self._name_scope(name, [y]):
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      if not self._is_injective:  # No caching for non-injective
        return self._inverse(y, **kwargs)
      mapping = self._lookup(y=y, kwargs=kwargs)
      if mapping.x is not None:
        return mapping.x
      mapping = mapping.merge(x=self._inverse(y, **kwargs))
      self._cache(mapping)
      return mapping.x

  def inverse(self, y, name="inverse"):
    """Returns the inverse `Bijector` evaluation, i.e., X = g^{-1}(Y).

    Args:
      y: `Tensor`. The input to the "inverse" evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective, returns the k-tuple containing the unique
        `k` points `(x1, ..., xk)` such that `g(xi) = y`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse` is not implemented.
    """
    return self._call_inverse(y, name)

  def _inverse_log_det_jacobian(self, y):
    """Subclass implementation of `inverse_log_det_jacobian` public function."""
    raise NotImplementedError("inverse_log_det_jacobian not implemented.")

  def _call_inverse_log_det_jacobian(self, y, name, **kwargs):
    with self._name_scope(name, [y]):
      if self._constant_ildj is not None:
        return self._constant_ildj
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      if not self._is_injective:  # No caching for non-injective
        return self._inverse_log_det_jacobian(y, **kwargs)
      mapping = self._lookup(y=y, kwargs=kwargs)
      if mapping.ildj is not None:
        return mapping.ildj
      try:
        x = None  # Not needed; leave cache as is.
        ildj = self._inverse_log_det_jacobian(y, **kwargs)
      except NotImplementedError as original_exception:
        try:
          x = mapping.x if mapping.x is not None else self._inverse(y, **kwargs)
          ildj = -self._forward_log_det_jacobian(x, **kwargs)
        except NotImplementedError:
          raise original_exception
      mapping = mapping.merge(x=x, ildj=ildj)
      self._cache(mapping)
      if self.is_constant_jacobian:
        self._constant_ildj = mapping.ildj
      return mapping.ildj

  def inverse_log_det_jacobian(self, y, name="inverse_log_det_jacobian"):
    """Returns the (log o det o Jacobian o inverse)(y).

    Mathematically, returns: `log(det(dX/dY))(Y)`. (Recall that: `X=g^{-1}(Y)`.)

    Note that `forward_log_det_jacobian` is the negative of this function,
    evaluated at `g^{-1}(y)`.

    Args:
      y: `Tensor`. The input to the "inverse" Jacobian evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective, returns the tuple of local log det
        Jacobians, `log(det(Dg_i^{-1}(y)))`, where `g_i` is the restriction
        of `g` to the `ith` partition `Di`.

    Raises:
      TypeError: if `self.dtype` is specified and `y.dtype` is not
        `self.dtype`.
      NotImplementedError: if `_inverse_log_det_jacobian` is not implemented.
    """
    return self._call_inverse_log_det_jacobian(y, name)

  def _forward_log_det_jacobian(self, x):
    """Subclass implementation of `forward_log_det_jacobian`."""
    raise NotImplementedError(
        "forward_log_det_jacobian not implemented.")

  def _call_forward_log_det_jacobian(self, x, name, **kwargs):
    with self._name_scope(name, [x]):
      if self._constant_ildj is not None:
        # Need "-1. *" to avoid invalid-unary-operand-type linter warning.
        return -1. * self._constant_ildj
      x = ops.convert_to_tensor(x, name="x")
      self._maybe_assert_dtype(x)
      if not self._is_injective:
        return self._forward_log_det_jacobian(x, **kwargs)  # No caching.
      mapping = self._lookup(x=x, kwargs=kwargs)
      if mapping.ildj is not None:
        return -mapping.ildj
      try:
        y = None  # Not needed; leave cache as is.
        ildj = -self._forward_log_det_jacobian(x, **kwargs)
      except NotImplementedError as original_exception:
        try:
          y = mapping.y if mapping.y is not None else self._forward(x, **kwargs)
          ildj = self._inverse_log_det_jacobian(y, **kwargs)
        except NotImplementedError:
          raise original_exception
      mapping = mapping.merge(y=y, ildj=ildj)
      self._cache(mapping)
      if self.is_constant_jacobian:
        self._constant_ildj = mapping.ildj
      return -mapping.ildj

  def forward_log_det_jacobian(self, x, name="forward_log_det_jacobian"):
    """Returns the forward_log_det_jacobian.

    Args:
      x: `Tensor`. The input to the "forward" Jacobian evaluation.
      name: The name to give this op.

    Returns:
      `Tensor`, if this bijector is injective.
        If not injective this is not implemented.

    Raises:
      TypeError: if `self.dtype` is specified and `x.dtype` is not
        `self.dtype`.
      NotImplementedError: if neither `_forward_log_det_jacobian`
        nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented, or
        this is a non-injective bijector.
    """
    if not self._is_injective:
      raise NotImplementedError(
          "forward_log_det_jacobian cannot be implemented for non-injective "
          "transforms.")
    return self._call_forward_log_det_jacobian(x, name)

  @contextlib.contextmanager
  def _name_scope(self, name=None, values=None):
    """Helper function to standardize op scope."""
    with ops.name_scope(self.name):
      with ops.name_scope(
          name, values=(values or []) + self.graph_parents) as scope:
        yield scope

  def _maybe_assert_dtype(self, x):
    """Helper to check dtype when self.dtype is known."""
    if self.dtype is not None and self.dtype.base_dtype != x.dtype.base_dtype:
      raise TypeError("Input had dtype %s but expected %s." %
                      (x.dtype, self.dtype))

  def _cache(self, mapping):
    """Helper which stores mapping info in forward/inverse dicts."""
    if self._constant_ildj is not None:
      # Fold in ildj if known constant Jacobian.
      mapping = mapping.merge(ildj=self._constant_ildj)
    # Merging from lookup is an added check that we're not overwriting anything
    # which is not None.
    mapping = mapping.merge(mapping=self._lookup(
        mapping.x, mapping.y, mapping.kwargs))
    if mapping.x is None and mapping.y is None:
      raise ValueError("Caching expects at least one of (x, y) to be known, "
                       "i.e., not None.")
    self._from_x[mapping.x_key] = mapping
    self._from_y[mapping.y_key] = mapping

  def _lookup(self, x=None, y=None, kwargs=None):
    """Helper which retrieves mapping info from forward/inverse dicts."""
    mapping = _Mapping(x=x, y=y, kwargs=kwargs)
    # Since _cache stores each mapping under both its x and y keys, a mapping
    # is in both dicts or in neither; one lookup therefore suffices.
    if mapping.x is not None:
      return self._from_x.get(mapping.x_key, mapping)
    if mapping.y is not None:
      return self._from_y.get(mapping.y_key, mapping)
    return mapping

  def _event_dims_tensor(self, sample):
    """Return a 1D `int32` tensor: `range(rank(sample))[-event_ndims:]`."""
    if self.event_ndims is None:
      raise ValueError("Jacobian cannot be computed with unknown event_ndims")
    static_event_ndims = tensor_util.constant_value(self.event_ndims)
    static_rank = sample.get_shape().ndims
    if static_event_ndims is not None and static_rank is not None:
      return ops.convert_to_tensor(
          static_rank + np.arange(-static_event_ndims, 0).astype(np.int32))

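    # At this point at most one of (event_ndims, rank) is statically known;
    # fall back to computing the event-dim indices partly in the graph.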
    if static_event_ndims is not None:
      event_range = np.arange(-static_event_ndims, 0).astype(np.int32)
    else:
      event_range = math_ops.range(-self.event_ndims, 0, dtype=dtypes.int32)

    if static_rank is not None:
      return event_range + static_rank
    else:
      return event_range + array_ops.rank(sample)