1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Bijector base."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import collections
23import contextlib
24import re
25
26import numpy as np
27import six
28
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import check_ops
35from tensorflow.python.ops import math_ops
36from tensorflow.python.ops.distributions import util as distribution_util
37
38
# Names exported by `from <this module> import *`; only the `Bijector` base
# class is listed, so the `_Mapping` caching helper stays private.
__all__ = [
    "Bijector",
]
42
43
44class _Mapping(collections.namedtuple(
45    "_Mapping", ["x", "y", "ildj_map", "kwargs"])):
46  """Helper class to make it easier to manage caching in `Bijector`."""
47
48  def __new__(cls, x=None, y=None, ildj_map=None, kwargs=None):
49    """Custom __new__ so namedtuple items have defaults.
50
51    Args:
52      x: `Tensor`. Forward.
53      y: `Tensor`. Inverse.
54      ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor`
55        representing the inverse log det jacobian.
56      kwargs: Python dictionary. Extra args supplied to
57        forward/inverse/etc functions.
58
59    Returns:
60      mapping: New instance of _Mapping.
61    """
62    return super(_Mapping, cls).__new__(cls, x, y, ildj_map, kwargs)
63
64  @property
65  def x_key(self):
66    """Returns key used for caching Y=g(X)."""
67    return (self.x,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))
68
69  @property
70  def y_key(self):
71    """Returns key used for caching X=g^{-1}(Y)."""
72    return (self.y,) + self._deep_tuple(tuple(sorted(self.kwargs.items())))
73
74  def merge(self, x=None, y=None, ildj_map=None, kwargs=None, mapping=None):
75    """Returns new _Mapping with args merged with self.
76
77    Args:
78      x: `Tensor`. Forward.
79      y: `Tensor`. Inverse.
80      ildj_map: `Dictionary`. This is a mapping from event_ndims to a `Tensor`
81        representing the inverse log det jacobian.
82      kwargs: Python dictionary. Extra args supplied to
83        forward/inverse/etc functions.
84      mapping: Instance of _Mapping to merge. Can only be specified if no other
85        arg is specified.
86
87    Returns:
88      mapping: New instance of `_Mapping` which has inputs merged with self.
89
90    Raises:
91      ValueError: if mapping and any other arg is not `None`.
92    """
93    if mapping is None:
94      mapping = _Mapping(x=x, y=y, ildj_map=ildj_map, kwargs=kwargs)
95    elif any(arg is not None for arg in [x, y, ildj_map, kwargs]):
96      raise ValueError("Cannot simultaneously specify mapping and individual "
97                       "arguments.")
98
99    return _Mapping(
100        x=self._merge(self.x, mapping.x),
101        y=self._merge(self.y, mapping.y),
102        ildj_map=self._merge_dicts(self.ildj_map, mapping.ildj_map),
103        kwargs=self._merge(self.kwargs, mapping.kwargs))
104
105  def _merge_dicts(self, old=None, new=None):
106    """Helper to merge two dictionaries."""
107    old = dict() if old is None else old
108    new = dict() if new is None else new
109    for k, v in six.iteritems(new):
110      val = old.get(k, None)
111      if val is not None and val != v:
112        raise ValueError("Found different value for existing key "
113                         "(key:{} old_value:{} new_value:{}".format(
114                             k, old[k], v))
115      old[k] = v
116    return old
117
118  def _merge(self, old, new):
119    """Helper to merge which handles merging one value."""
120    if old is None:
121      return new
122    elif new is not None and old != new:
123      raise ValueError("Incompatible values: %s != %s" % (old, new))
124    return old
125
126  def _deep_tuple(self, x):
127    """Converts lists of lists to tuples of tuples."""
128    return (tuple(map(self._deep_tuple, x))
129            if isinstance(x, (list, tuple)) else x)
130
131
132@six.add_metaclass(abc.ABCMeta)
133class Bijector(object):
134  r"""Interface for transformations of a `Distribution` sample.
135
136  Bijectors can be used to represent any differentiable and injective
137  (one to one) function defined on an open subset of `R^n`.  Some non-injective
138  transformations are also supported (see "Non Injective Transforms" below).
139
140  #### Mathematical Details
141
142  A `Bijector` implements a [smooth covering map](
143  https://en.wikipedia.org/wiki/Local_diffeomorphism), i.e., a local
144  diffeomorphism such that every point in the target has a neighborhood evenly
145  covered by a map ([see also](
146  https://en.wikipedia.org/wiki/Covering_space#Covering_of_a_manifold)).
147  A `Bijector` is used by `TransformedDistribution` but can be generally used
148  for transforming a `Distribution` generated `Tensor`. A `Bijector` is
149  characterized by three operations:
150
151  1. Forward
152
153     Useful for turning one random outcome into another random outcome from a
154     different distribution.
155
156  2. Inverse
157
158     Useful for "reversing" a transformation to compute one probability in
159     terms of another.
160
161  3. `log_det_jacobian(x)`
162
163     "The log of the absolute value of the determinant of the matrix of all
164     first-order partial derivatives of the inverse function."
165
166     Useful for inverting a transformation to compute one probability in terms
167     of another. Geometrically, the Jacobian determinant is the volume of the
168     transformation and is used to scale the probability.
169
170     We take the absolute value of the determinant before log to avoid NaN
171     values.  Geometrically, a negative determinant corresponds to an
172     orientation-reversing transformation.  It is ok for us to discard the sign
173     of the determinant because we only integrate everywhere-nonnegative
174     functions (probability densities) and the correct orientation is always the
175     one that produces a nonnegative integrand.
176
177  By convention, transformations of random variables are named in terms of the
178  forward transformation. The forward transformation creates samples, the
179  inverse is useful for computing probabilities.
180
181  #### Example Uses
182
183  - Basic properties:
184
185  ```python
186  x = ...  # A tensor.
187  # Evaluate forward transformation.
188  fwd_x = my_bijector.forward(x)
189  x == my_bijector.inverse(fwd_x)
190  x != my_bijector.forward(fwd_x)  # Not equal because x != g(g(x)).
191  ```
192
193  - Computing a log-likelihood:
194
195  ```python
196  def transformed_log_prob(bijector, log_prob, x):
197    return (bijector.inverse_log_det_jacobian(x, event_ndims=0) +
198            log_prob(bijector.inverse(x)))
199  ```
200
201  - Transforming a random outcome:
202
203  ```python
204  def transformed_sample(bijector, x):
205    return bijector.forward(x)
206  ```
207
208  #### Example Bijectors
209
210  - "Exponential"
211
212    ```none
213    Y = g(X) = exp(X)
214    X ~ Normal(0, 1)  # Univariate.
215    ```
216
217    Implies:
218
219    ```none
220      g^{-1}(Y) = log(Y)
221      |Jacobian(g^{-1})(y)| = 1 / y
222      Y ~ LogNormal(0, 1), i.e.,
223      prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
224                = (1 / y) Normal(log(y); 0, 1)
225    ```
226
227    Here is an example of how one might implement the `Exp` bijector:
228
229    ```python
230      class Exp(Bijector):
231
232        def __init__(self, validate_args=False, name="exp"):
233          super(Exp, self).__init__(
234              validate_args=validate_args,
235              forward_min_event_ndims=0,
236              name=name)
237
238        def _forward(self, x):
239          return math_ops.exp(x)
240
241        def _inverse(self, y):
242          return math_ops.log(y)
243
244        def _inverse_log_det_jacobian(self, y):
245          return -self._forward_log_det_jacobian(self._inverse(y))
246
247        def _forward_log_det_jacobian(self, x):
248          # Notice that we needn't do any reducing, even when`event_ndims > 0`.
249          # The base Bijector class will handle reducing for us; it knows how
250          # to do so because we called `super` `__init__` with
251          # `forward_min_event_ndims = 0`.
252          return x
253      ```
254
255  - "Affine"
256
257    ```none
258    Y = g(X) = sqrtSigma * X + mu
259    X ~ MultivariateNormal(0, I_d)
260    ```
261
262    Implies:
263
264    ```none
265      g^{-1}(Y) = inv(sqrtSigma) * (Y - mu)
266      |Jacobian(g^{-1})(y)| = det(inv(sqrtSigma))
267      Y ~ MultivariateNormal(mu, sqrtSigma) , i.e.,
268      prob(Y=y) = |Jacobian(g^{-1})(y)| * prob(X=g^{-1}(y))
                = det(sqrtSigma)^(-1) *
270                  MultivariateNormal(inv(sqrtSigma) * (y - mu); 0, I_d)
271      ```
272
273  #### Min_event_ndims and Naming
274
275  Bijectors are named for the dimensionality of data they act on (i.e. without
  broadcasting). We can think of bijectors as having an intrinsic
  `min_event_ndims`, which is the minimum number of dimensions for the bijector
  to act on. For
278  instance, a Cholesky decomposition requires a matrix, and hence
279  `min_event_ndims=2`.
280
281  Some examples:
282
283  `AffineScalar:  min_event_ndims=0`
284  `Affine:  min_event_ndims=1`
285  `Cholesky:  min_event_ndims=2`
286  `Exp:  min_event_ndims=0`
287  `Sigmoid:  min_event_ndims=0`
288  `SoftmaxCentered:  min_event_ndims=1`
289
290  Note the difference between `Affine` and `AffineScalar`. `AffineScalar`
291  operates on scalar events, whereas `Affine` operates on vector-valued events.
292
293  More generally, there is a `forward_min_event_ndims` and an
294  `inverse_min_event_ndims`. In most cases, these will be the same.
  However, for some shape-changing bijectors, these will be different
  (e.g. a bijector which pads an extra dimension at the end might have
  `forward_min_event_ndims=0` and `inverse_min_event_ndims=1`).
298
299
300  #### Jacobian Determinant
301
302  The Jacobian determinant is a reduction over `event_ndims - min_event_ndims`
303  (`forward_min_event_ndims` for `forward_log_det_jacobian` and
304  `inverse_min_event_ndims` for `inverse_log_det_jacobian`).
305  To see this, consider the `Exp` `Bijector` applied to a `Tensor` which has
306  sample, batch, and event (S, B, E) shape semantics. Suppose the `Tensor`'s
307  partitioned-shape is `(S=[4], B=[2], E=[3, 3])`. The shape of the `Tensor`
308  returned by `forward` and `inverse` is unchanged, i.e., `[4, 2, 3, 3]`.
309  However the shape returned by `inverse_log_det_jacobian` is `[4, 2]` because
310  the Jacobian determinant is a reduction over the event dimensions.
311
312  Another example is the `Affine` `Bijector`. Because `min_event_ndims = 1`, the
313  Jacobian determinant reduction is over `event_ndims - 1`.
314
315  It is sometimes useful to implement the inverse Jacobian determinant as the
316  negative forward Jacobian determinant. For example,
317
318  ```python
319  def _inverse_log_det_jacobian(self, y):
     return -self._forward_log_det_jacobian(self._inverse(y))  # Note negation.
321  ```
322
323  The correctness of this approach can be seen from the following claim.
324
325  - Claim:
326
327      Assume `Y = g(X)` is a bijection whose derivative exists and is nonzero
328      for its domain, i.e., `dY/dX = d/dX g(X) != 0`. Then:
329
330      ```none
331      (log o det o jacobian o g^{-1})(Y) = -(log o det o jacobian o g)(X)
332      ```
333
334  - Proof:
335
336      From the bijective, nonzero differentiability of `g`, the
337      [inverse function theorem](
338          https://en.wikipedia.org/wiki/Inverse_function_theorem)
339      implies `g^{-1}` is differentiable in the image of `g`.
340      Applying the chain rule to `y = g(x) = g(g^{-1}(y))` yields
341      `I = g'(g^{-1}(y))*g^{-1}'(y)`.
342      The same theorem also implies `g^{-1}'` is non-singular therefore:
343      `inv[ g'(g^{-1}(y)) ] = g^{-1}'(y)`.
344      The claim follows from [properties of determinant](
345  https://en.wikipedia.org/wiki/Determinant#Multiplicativity_and_matrix_groups).
346
  Generally it's preferable to directly implement the inverse Jacobian
348  determinant.  This should have superior numerical stability and will often
349  share subgraphs with the `_inverse` implementation.
350
351  #### Is_constant_jacobian
352
353  Certain bijectors will have constant jacobian matrices. For instance, the
354  `Affine` bijector encodes multiplication by a matrix plus a shift, with
355  jacobian matrix, the same aforementioned matrix.
356
357  `is_constant_jacobian` encodes the fact that the jacobian matrix is constant.
358  The semantics of this argument are the following:
359
360    * Repeated calls to "log_det_jacobian" functions with the same
361      `event_ndims` (but not necessarily same input), will return the first
362      computed jacobian (because the matrix is constant, and hence is input
363      independent).
364    * `log_det_jacobian` implementations are merely broadcastable to the true
365      `log_det_jacobian` (because, again, the jacobian matrix is input
366      independent). Specifically, `log_det_jacobian` is implemented as the
367      log jacobian determinant for a single input.
368
369      ```python
370      class Identity(Bijector):
371
372        def __init__(self, validate_args=False, name="identity"):
373          super(Identity, self).__init__(
374              is_constant_jacobian=True,
375              validate_args=validate_args,
376              forward_min_event_ndims=0,
377              name=name)
378
379        def _forward(self, x):
380          return x
381
382        def _inverse(self, y):
383          return y
384
385        def _inverse_log_det_jacobian(self, y):
386          return -self._forward_log_det_jacobian(self._inverse(y))
387
388        def _forward_log_det_jacobian(self, x):
          # The full log jacobian determinant would be array_ops.zeros_like(x).
390          # However, we circumvent materializing that, since the jacobian
391          # calculation is input independent, and we specify it for one input.
392          return constant_op.constant(0., x.dtype.base_dtype)
393
394      ```
395
396  #### Subclass Requirements
397
398  - Subclasses typically implement:
399
400      - `_forward`,
401      - `_inverse`,
402      - `_inverse_log_det_jacobian`,
403      - `_forward_log_det_jacobian` (optional).
404
405    The `_forward_log_det_jacobian` is called when the bijector is inverted via
    the `Invert` bijector. If undefined, a slightly less efficient
    calculation, `-1 * _inverse_log_det_jacobian`, is used.
408
409    If the bijector changes the shape of the input, you must also implement:
410
411      - _forward_event_shape_tensor,
412      - _forward_event_shape (optional),
413      - _inverse_event_shape_tensor,
414      - _inverse_event_shape (optional).
415
416    By default the event-shape is assumed unchanged from input.
417
418  - If the `Bijector`'s use is limited to `TransformedDistribution` (or friends
419    like `QuantizedDistribution`) then depending on your use, you may not need
420    to implement all of `_forward` and `_inverse` functions.
421
422    Examples:
423
424      1. Sampling (e.g., `sample`) only requires `_forward`.
425      2. Probability functions (e.g., `prob`, `cdf`, `survival`) only require
426         `_inverse` (and related).
427      3. Only calling probability functions on the output of `sample` means
428        `_inverse` can be implemented as a cache lookup.
429
430    See "Example Uses" [above] which shows how these functions are used to
431    transform a distribution. (Note: `_forward` could theoretically be
432    implemented as a cache lookup but this would require controlling the
433    underlying sample generation mechanism.)
434
435  #### Non Injective Transforms
436
  **WARNING** Handling of non-injective transforms is subject to change.
438
439  Non injective maps `g` are supported, provided their domain `D` can be
440  partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
441  ignoring sets of measure zero, the restriction of `g` to each subset is a
  differentiable bijection onto `g(D)`.  In particular, this implies that for
443  `y in g(D)`, the set inverse, i.e. `g^{-1}(y) = {x in D : g(x) = y}`, always
444  contains exactly `k` distinct points.
445
446  The property, `_is_injective` is set to `False` to indicate that the bijector
447  is not injective, yet satisfies the above condition.
448
449  The usual bijector API is modified in the case `_is_injective is False` (see
450  method docstrings for specifics).  Here we show by example the `AbsoluteValue`
451  bijector.  In this case, the domain `D = (-inf, inf)`, can be partitioned
452  into `D1 = (-inf, 0)`, `D2 = {0}`, and `D3 = (0, inf)`.  Let `gi` be the
453  restriction of `g` to `Di`, then both `g1` and `g3` are bijections onto
454  `(0, inf)`, with `g1^{-1}(y) = -y`, and `g3^{-1}(y) = y`.  We will use
455  `g1` and `g3` to define bijector methods over `D1` and `D3`.  `D2 = {0}` is
456  an oddball in that `g2` is one to one, and the derivative is not well defined.
457  Fortunately, when considering transformations of probability densities
458  (e.g. in `TransformedDistribution`), sets of measure zero have no effect in
459  theory, and only a small effect in 32 or 64 bit precision.  For that reason,
460  we define `inverse(0)` and `inverse_log_det_jacobian(0)` both as `[0, 0]`,
461  which is convenient and results in a left-semicontinuous pdf.
462
463
464  ```python
465  abs = tfp.distributions.bijectors.AbsoluteValue()
466
467  abs.forward(-1.)
468  ==> 1.
469
470  abs.forward(1.)
471  ==> 1.
472
473  abs.inverse(1.)
474  ==> (-1., 1.)
475
476  # The |dX/dY| is constant, == 1.  So Log|dX/dY| == 0.
477  abs.inverse_log_det_jacobian(1., event_ndims=0)
478  ==> (0., 0.)
479
480  # Special case handling of 0.
481  abs.inverse(0.)
482  ==> (0., 0.)
483
484  abs.inverse_log_det_jacobian(0., event_ndims=0)
485  ==> (0., 0.)
486  ```
487
488  """
489
490  @abc.abstractmethod
491  def __init__(self,
492               graph_parents=None,
493               is_constant_jacobian=False,
494               validate_args=False,
495               dtype=None,
496               forward_min_event_ndims=None,
497               inverse_min_event_ndims=None,
498               name=None):
499    """Constructs Bijector.
500
501    A `Bijector` transforms random variables into new random variables.
502
503    Examples:
504
505    ```python
506    # Create the Y = g(X) = X transform.
507    identity = Identity()
508
509    # Create the Y = g(X) = exp(X) transform.
510    exp = Exp()
511    ```
512
513    See `Bijector` subclass docstring for more details and specific examples.
514
515    Args:
516      graph_parents: Python list of graph prerequisites of this `Bijector`.
517      is_constant_jacobian: Python `bool` indicating that the Jacobian matrix is
518        not a function of the input.
519      validate_args: Python `bool`, default `False`. Whether to validate input
520        with asserts. If `validate_args` is `False`, and the inputs are invalid,
521        correct behavior is not guaranteed.
522      dtype: `tf.dtype` supported by this `Bijector`. `None` means dtype is not
523        enforced.
524      forward_min_event_ndims: Python `integer` indicating the minimum number of
525        dimensions `forward` operates on.
526      inverse_min_event_ndims: Python `integer` indicating the minimum number of
527        dimensions `inverse` operates on. Will be set to
528        `forward_min_event_ndims` by default, if no value is provided.
529      name: The name to give Ops created by the initializer.
530
531    Raises:
532      ValueError:  If neither `forward_min_event_ndims` and
533        `inverse_min_event_ndims` are specified, or if either of them is
534        negative.
535      ValueError:  If a member of `graph_parents` is not a `Tensor`.
536    """
537    self._graph_parents = graph_parents or []
538
539    if forward_min_event_ndims is None and inverse_min_event_ndims is None:
540      raise ValueError("Must specify at least one of `forward_min_event_ndims` "
541                       "and `inverse_min_event_ndims`.")
542    elif inverse_min_event_ndims is None:
543      inverse_min_event_ndims = forward_min_event_ndims
544    elif forward_min_event_ndims is None:
545      forward_min_event_ndims = inverse_min_event_ndims
546
547    if not isinstance(forward_min_event_ndims, int):
548      raise TypeError("Expected forward_min_event_ndims to be of "
549                      "type int, got {}".format(
550                          type(forward_min_event_ndims).__name__))
551
552    if not isinstance(inverse_min_event_ndims, int):
553      raise TypeError("Expected inverse_min_event_ndims to be of "
554                      "type int, got {}".format(
555                          type(inverse_min_event_ndims).__name__))
556
557    if forward_min_event_ndims < 0:
558      raise ValueError("forward_min_event_ndims must be a non-negative "
559                       "integer.")
560    if inverse_min_event_ndims < 0:
561      raise ValueError("inverse_min_event_ndims must be a non-negative "
562                       "integer.")
563
564    self._forward_min_event_ndims = forward_min_event_ndims
565    self._inverse_min_event_ndims = inverse_min_event_ndims
566    self._is_constant_jacobian = is_constant_jacobian
567    self._constant_ildj_map = {}
568    self._validate_args = validate_args
569    self._dtype = dtype
570    self._from_y = {}
571    self._from_x = {}
572    if name:
573      self._name = name
574    else:
575      # We want the default convention to be snake_case rather than CamelCase
576      # since `Chain` uses bijector.name as the kwargs dictionary key.
577      def camel_to_snake(name):
578        s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
579        return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
580      self._name = camel_to_snake(type(self).__name__.lstrip("_"))
581
582    for i, t in enumerate(self._graph_parents):
583      if t is None or not tensor_util.is_tensor(t):
584        raise ValueError("Graph parent item %d is not a Tensor; %s." % (i, t))
585
586  @property
587  def graph_parents(self):
588    """Returns this `Bijector`'s graph_parents as a Python list."""
589    return self._graph_parents
590
591  @property
592  def forward_min_event_ndims(self):
593    """Returns the minimal number of dimensions bijector.forward operates on."""
594    return self._forward_min_event_ndims
595
596  @property
597  def inverse_min_event_ndims(self):
598    """Returns the minimal number of dimensions bijector.inverse operates on."""
599    return self._inverse_min_event_ndims
600
601  @property
602  def is_constant_jacobian(self):
603    """Returns true iff the Jacobian matrix is not a function of x.
604
605    Note: Jacobian matrix is either constant for both forward and inverse or
606    neither.
607
608    Returns:
609      is_constant_jacobian: Python `bool`.
610    """
611    return self._is_constant_jacobian
612
613  @property
614  def _is_injective(self):
615    """Returns true iff the forward map `g` is injective (one-to-one function).
616
617    **WARNING** This hidden property and its behavior are subject to change.
618
619    Note:  Non-injective maps `g` are supported, provided their domain `D` can
620    be partitioned into `k` disjoint subsets, `Union{D1, ..., Dk}`, such that,
621    ignoring sets of measure zero, the restriction of `g` to each subset is a
622    differentiable bijection onto `g(D)`.
623
624    Returns:
625      is_injective: Python `bool`.
626    """
627    return True
628
629  @property
630  def validate_args(self):
631    """Returns True if Tensor arguments will be validated."""
632    return self._validate_args
633
634  @property
635  def dtype(self):
636    """dtype of `Tensor`s transformable by this distribution."""
637    return self._dtype
638
639  @property
640  def name(self):
641    """Returns the string name of this `Bijector`."""
642    return self._name
643
644  def _forward_event_shape_tensor(self, input_shape):
645    """Subclass implementation for `forward_event_shape_tensor` function."""
646    # By default, we assume event_shape is unchanged.
647    return input_shape
648
649  def forward_event_shape_tensor(self,
650                                 input_shape,
651                                 name="forward_event_shape_tensor"):
652    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.
653
654    Args:
655      input_shape: `Tensor`, `int32` vector indicating event-portion shape
656        passed into `forward` function.
657      name: name to give to the op
658
659    Returns:
660      forward_event_shape_tensor: `Tensor`, `int32` vector indicating
661        event-portion shape after applying `forward`.
662    """
663    with self._name_scope(name, [input_shape]):
664      input_shape = ops.convert_to_tensor(input_shape, dtype=dtypes.int32,
665                                          name="input_shape")
666      return self._forward_event_shape_tensor(input_shape)
667
668  def _forward_event_shape(self, input_shape):
669    """Subclass implementation for `forward_event_shape` public function."""
670    # By default, we assume event_shape is unchanged.
671    return input_shape
672
673  def forward_event_shape(self, input_shape):
674    """Shape of a single sample from a single batch as a `TensorShape`.
675
676    Same meaning as `forward_event_shape_tensor`. May be only partially defined.
677
678    Args:
679      input_shape: `TensorShape` indicating event-portion shape passed into
680        `forward` function.
681
682    Returns:
683      forward_event_shape_tensor: `TensorShape` indicating event-portion shape
684        after applying `forward`. Possibly unknown.
685    """
686    return self._forward_event_shape(tensor_shape.TensorShape(input_shape))
687
688  def _inverse_event_shape_tensor(self, output_shape):
689    """Subclass implementation for `inverse_event_shape_tensor` function."""
690    # By default, we assume event_shape is unchanged.
691    return output_shape
692
693  def inverse_event_shape_tensor(self,
694                                 output_shape,
695                                 name="inverse_event_shape_tensor"):
696    """Shape of a single sample from a single batch as an `int32` 1D `Tensor`.
697
698    Args:
699      output_shape: `Tensor`, `int32` vector indicating event-portion shape
700        passed into `inverse` function.
701      name: name to give to the op
702
703    Returns:
704      inverse_event_shape_tensor: `Tensor`, `int32` vector indicating
705        event-portion shape after applying `inverse`.
706    """
707    with self._name_scope(name, [output_shape]):
708      output_shape = ops.convert_to_tensor(output_shape, dtype=dtypes.int32,
709                                           name="output_shape")
710      return self._inverse_event_shape_tensor(output_shape)
711
712  def _inverse_event_shape(self, output_shape):
713    """Subclass implementation for `inverse_event_shape` public function."""
714    # By default, we assume event_shape is unchanged.
715    return tensor_shape.TensorShape(output_shape)
716
717  def inverse_event_shape(self, output_shape):
718    """Shape of a single sample from a single batch as a `TensorShape`.
719
720    Same meaning as `inverse_event_shape_tensor`. May be only partially defined.
721
722    Args:
723      output_shape: `TensorShape` indicating event-portion shape passed into
724        `inverse` function.
725
726    Returns:
727      inverse_event_shape_tensor: `TensorShape` indicating event-portion shape
728        after applying `inverse`. Possibly unknown.
729    """
730    return self._inverse_event_shape(output_shape)
731
732  def _forward(self, x):
733    """Subclass implementation for `forward` public function."""
734    raise NotImplementedError("forward not implemented.")
735
736  def _call_forward(self, x, name, **kwargs):
737    with self._name_scope(name, [x]):
738      x = ops.convert_to_tensor(x, name="x")
739      self._maybe_assert_dtype(x)
740      if not self._is_injective:  # No caching for non-injective
741        return self._forward(x, **kwargs)
742      mapping = self._lookup(x=x, kwargs=kwargs)
743      if mapping.y is not None:
744        return mapping.y
745      mapping = mapping.merge(y=self._forward(x, **kwargs))
746      self._cache(mapping)
747      return mapping.y
748
749  def forward(self, x, name="forward"):
750    """Returns the forward `Bijector` evaluation, i.e., X = g(Y).
751
752    Args:
753      x: `Tensor`. The input to the "forward" evaluation.
754      name: The name to give this op.
755
756    Returns:
757      `Tensor`.
758
759    Raises:
760      TypeError: if `self.dtype` is specified and `x.dtype` is not
761        `self.dtype`.
762      NotImplementedError: if `_forward` is not implemented.
763    """
764    return self._call_forward(x, name)
765
766  def _inverse(self, y):
767    """Subclass implementation for `inverse` public function."""
768    raise NotImplementedError("inverse not implemented")
769
770  def _call_inverse(self, y, name, **kwargs):
771    with self._name_scope(name, [y]):
772      y = ops.convert_to_tensor(y, name="y")
773      self._maybe_assert_dtype(y)
774      if not self._is_injective:  # No caching for non-injective
775        return self._inverse(y, **kwargs)
776      mapping = self._lookup(y=y, kwargs=kwargs)
777      if mapping.x is not None:
778        return mapping.x
779      mapping = mapping.merge(x=self._inverse(y, **kwargs))
780      self._cache(mapping)
781      return mapping.x
782
783  def inverse(self, y, name="inverse"):
784    """Returns the inverse `Bijector` evaluation, i.e., X = g^{-1}(Y).
785
786    Args:
787      y: `Tensor`. The input to the "inverse" evaluation.
788      name: The name to give this op.
789
790    Returns:
791      `Tensor`, if this bijector is injective.
792        If not injective, returns the k-tuple containing the unique
793        `k` points `(x1, ..., xk)` such that `g(xi) = y`.
794
795    Raises:
796      TypeError: if `self.dtype` is specified and `y.dtype` is not
797        `self.dtype`.
798      NotImplementedError: if `_inverse` is not implemented.
799    """
800    return self._call_inverse(y, name)
801
802  def _inverse_log_det_jacobian(self, y):
803    """Subclass implementation of `inverse_log_det_jacobian` public function.
804
805    In particular, this method differs from the public function, in that it
806    does not take `event_ndims`. Thus, this implements the minimal Jacobian
807    determinant calculation (i.e. over `inverse_min_event_ndims`).
808
809    Args:
810      y: `Tensor`. The input to the "inverse_log_det_jacobian" evaluation.
811    Returns:
812      inverse_log_det_jacobian: `Tensor`, if this bijector is injective.
813        If not injective, returns the k-tuple containing jacobians for the
814        unique `k` points `(x1, ..., xk)` such that `g(xi) = y`.
815    """
816    raise NotImplementedError("inverse_log_det_jacobian not implemented.")
817
  def _call_inverse_log_det_jacobian(self, y, event_ndims, name, **kwargs):
    """Computes the inverse log det Jacobian, with validation and caching.

    Args:
      y: Value convertible to a `Tensor`; the input to the evaluation.
      event_ndims: Number of event dimensions to reduce the Jacobian over. It
        is validated against `self.inverse_min_event_ndims`.
      name: Name of the op scope under which the computation is created.
      **kwargs: Extra keyword arguments forwarded to the subclass hooks
        (`_inverse_log_det_jacobian`, `_inverse`, `_forward_log_det_jacobian`)
        and used for the cache lookup.

    Returns:
      For injective bijectors, the reduced inverse log det Jacobian. For
      non-injective bijectors, a tuple with one reduced value per element
      returned by the subclass hook.

    Raises:
      NotImplementedError: the error raised by `_inverse_log_det_jacobian`, if
        the fallback through `_inverse` and `_forward_log_det_jacobian` also
        raises `NotImplementedError`.
    """
    with self._name_scope(name, [y]):
      # A value stored for a constant Jacobian is returned before `y` is even
      # converted to a tensor.
      if event_ndims in self._constant_ildj_map:
        return self._constant_ildj_map[event_ndims]
      y = ops.convert_to_tensor(y, name="y")
      self._maybe_assert_dtype(y)
      # Ops created below run after whatever assertions the check returns.
      with ops.control_dependencies(self._check_valid_event_ndims(
          min_event_ndims=self.inverse_min_event_ndims,
          event_ndims=event_ndims)):
        if not self._is_injective:  # No caching for non-injective
          try:
            ildjs = self._inverse_log_det_jacobian(y, **kwargs)
            return tuple(self._reduce_jacobian_det_over_event(
                y, ildj, self.inverse_min_event_ndims, event_ndims)
                         for ildj in ildjs)
          except NotImplementedError as original_exception:
            # Fallback: negate the forward log det Jacobian evaluated at the
            # inverse of `y`, reducing over the forward min event ndims.
            try:
              x = self._inverse(y, **kwargs)
              fldjs = self._forward_log_det_jacobian(x, **kwargs)
              return tuple(self._reduce_jacobian_det_over_event(
                  x, -fldj, self.forward_min_event_ndims, event_ndims)
                           for fldj in fldjs)
            except NotImplementedError:
              # Report the failure of the primary path, not of the fallback.
              raise original_exception

        # Injective case: reuse a value cached for this `y`, `kwargs` and
        # `event_ndims` when there is one.
        mapping = self._lookup(y=y, kwargs=kwargs)
        if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
          return mapping.ildj_map[event_ndims]
        try:
          x = None  # Not needed; leave cache as is.
          ildj = self._inverse_log_det_jacobian(y, **kwargs)
          ildj = self._reduce_jacobian_det_over_event(
              y, ildj, self.inverse_min_event_ndims, event_ndims)
        except NotImplementedError as original_exception:
          # Fallback through the forward Jacobian; reuse a cached `x` if the
          # mapping already holds one, otherwise compute the inverse.
          try:
            x = (mapping.x if mapping.x is not None
                 else self._inverse(y, **kwargs))
            ildj = -self._forward_log_det_jacobian(x, **kwargs)
            ildj = self._reduce_jacobian_det_over_event(
                x, ildj, self.forward_min_event_ndims, event_ndims)
          except NotImplementedError:
            # Report the failure of the primary path, not of the fallback.
            raise original_exception

        # Record the result (and `x`, when the fallback produced one).
        mapping = mapping.merge(x=x, ildj_map={event_ndims: ildj})
        self._cache(mapping)
        # For a constant Jacobian the value is also stored per `event_ndims`
        # so later calls can take the early return at the top.
        if self.is_constant_jacobian:
          self._constant_ildj_map[event_ndims] = ildj
        return ildj
866
867  def inverse_log_det_jacobian(
868      self, y, event_ndims, name="inverse_log_det_jacobian"):
869    """Returns the (log o det o Jacobian o inverse)(y).
870
871    Mathematically, returns: `log(det(dX/dY))(Y)`. (Recall that: `X=g^{-1}(Y)`.)
872
873    Note that `forward_log_det_jacobian` is the negative of this function,
874    evaluated at `g^{-1}(y)`.
875
876    Args:
877      y: `Tensor`. The input to the "inverse" Jacobian determinant evaluation.
878      event_ndims: Number of dimensions in the probabilistic events being
879        transformed. Must be greater than or equal to
880        `self.inverse_min_event_ndims`. The result is summed over the final
881        dimensions to produce a scalar Jacobian determinant for each event,
882        i.e. it has shape `y.shape.ndims - event_ndims` dimensions.
883      name: The name to give this op.
884
885    Returns:
886      `Tensor`, if this bijector is injective.
887        If not injective, returns the tuple of local log det
888        Jacobians, `log(det(Dg_i^{-1}(y)))`, where `g_i` is the restriction
889        of `g` to the `ith` partition `Di`.
890
891    Raises:
892      TypeError: if `self.dtype` is specified and `y.dtype` is not
893        `self.dtype`.
894      NotImplementedError: if `_inverse_log_det_jacobian` is not implemented.
895    """
896    return self._call_inverse_log_det_jacobian(y, event_ndims, name)
897
898  def _forward_log_det_jacobian(self, x):
899    """Subclass implementation of `forward_log_det_jacobian` public function.
900
901    In particular, this method differs from the public function, in that it
902    does not take `event_ndims`. Thus, this implements the minimal Jacobian
903    determinant calculation (i.e. over `forward_min_event_ndims`).
904
905    Args:
906      x: `Tensor`. The input to the "forward_log_det_jacobian" evaluation.
907
908    Returns:
909      forward_log_det_jacobian: `Tensor`, if this bijector is injective.
910        If not injective, returns the k-tuple containing jacobians for the
911        unique `k` points `(x1, ..., xk)` such that `g(xi) = y`.
912    """
913
914    raise NotImplementedError(
915        "forward_log_det_jacobian not implemented.")
916
917  def _call_forward_log_det_jacobian(self, x, event_ndims, name, **kwargs):
918    if not self._is_injective:
919      raise NotImplementedError(
920          "forward_log_det_jacobian cannot be implemented for non-injective "
921          "transforms.")
922    with self._name_scope(name, [x]):
923      with ops.control_dependencies(self._check_valid_event_ndims(
924          min_event_ndims=self.forward_min_event_ndims,
925          event_ndims=event_ndims)):
926        if event_ndims in self._constant_ildj_map:
927          # Need "-1. *" to avoid invalid-unary-operand-type linter warning.
928          return -1. * self._constant_ildj_map[event_ndims]
929        x = ops.convert_to_tensor(x, name="x")
930        self._maybe_assert_dtype(x)
931        if not self._is_injective:  # No caching for non-injective
932          try:
933            fldjs = self._forward_log_det_jacobian(x, **kwargs)  # No caching.
934            return tuple(self._reduce_jacobian_det_over_event(
935                x, fldj, self.forward_min_event_ndims, event_ndims)
936                         for fldj in fldjs)
937          except NotImplementedError as original_exception:
938            try:
939              y = self._forward(x, **kwargs)
940              ildjs = self._inverse_log_det_jacobian(y, **kwargs)
941              return tuple(self._reduce_jacobian_det_over_event(
942                  y, -ildj, self.inverse_min_event_ndims, event_ndims)
943                           for ildj in ildjs)
944            except NotImplementedError:
945              raise original_exception
946        mapping = self._lookup(x=x, kwargs=kwargs)
947        if mapping.ildj_map is not None and event_ndims in mapping.ildj_map:
948          return -mapping.ildj_map[event_ndims]
949        try:
950          y = None  # Not needed; leave cache as is.
951          ildj = -self._forward_log_det_jacobian(x, **kwargs)
952          ildj = self._reduce_jacobian_det_over_event(
953              x, ildj, self.forward_min_event_ndims, event_ndims)
954        except NotImplementedError as original_exception:
955          try:
956            y = (mapping.y if mapping.y is not None
957                 else self._forward(x, **kwargs))
958            ildj = self._inverse_log_det_jacobian(y, **kwargs)
959            ildj = self._reduce_jacobian_det_over_event(
960                y, ildj, self.inverse_min_event_ndims, event_ndims)
961          except NotImplementedError:
962            raise original_exception
963        mapping = mapping.merge(y=y, ildj_map={event_ndims: ildj})
964        self._cache(mapping)
965        if self.is_constant_jacobian:
966          self._constant_ildj_map[event_ndims] = ildj
967        return -ildj
968
969  def forward_log_det_jacobian(
970      self, x, event_ndims, name="forward_log_det_jacobian"):
971    """Returns both the forward_log_det_jacobian.
972
973    Args:
974      x: `Tensor`. The input to the "forward" Jacobian determinant evaluation.
975      event_ndims: Number of dimensions in the probabilistic events being
976        transformed. Must be greater than or equal to
977        `self.forward_min_event_ndims`. The result is summed over the final
978        dimensions to produce a scalar Jacobian determinant for each event,
979        i.e. it has shape `x.shape.ndims - event_ndims` dimensions.
980      name: The name to give this op.
981
982    Returns:
983      `Tensor`, if this bijector is injective.
984        If not injective this is not implemented.
985
986    Raises:
987      TypeError: if `self.dtype` is specified and `y.dtype` is not
988        `self.dtype`.
989      NotImplementedError: if neither `_forward_log_det_jacobian`
990        nor {`_inverse`, `_inverse_log_det_jacobian`} are implemented, or
991        this is a non-injective bijector.
992    """
993    return self._call_forward_log_det_jacobian(x, event_ndims, name)
994
995  @contextlib.contextmanager
996  def _name_scope(self, name=None, values=None):
997    """Helper function to standardize op scope."""
998    with ops.name_scope(self.name):
999      with ops.name_scope(
1000          name, values=(values or []) + self.graph_parents) as scope:
1001        yield scope
1002
1003  def _maybe_assert_dtype(self, x):
1004    """Helper to check dtype when self.dtype is known."""
1005    if self.dtype is not None and self.dtype.base_dtype != x.dtype.base_dtype:
1006      raise TypeError("Input had dtype %s but expected %s." %
1007                      (self.dtype, x.dtype))
1008
1009  def _cache(self, mapping):
1010    """Helper which stores mapping info in forward/inverse dicts."""
1011    # Merging from lookup is an added check that we're not overwriting anything
1012    # which is not None.
1013    mapping = mapping.merge(mapping=self._lookup(
1014        mapping.x, mapping.y, mapping.kwargs))
1015    if mapping.x is None and mapping.y is None:
1016      raise ValueError("Caching expects at least one of (x,y) to be known, "
1017                       "i.e., not None.")
1018    self._from_x[mapping.x_key] = mapping
1019    self._from_y[mapping.y_key] = mapping
1020
1021  def _lookup(self, x=None, y=None, kwargs=None):
1022    """Helper which retrieves mapping info from forward/inverse dicts."""
1023    mapping = _Mapping(x=x, y=y, kwargs=kwargs)
1024    # Since _cache requires both x,y to be set, we only need to do one cache
1025    # lookup since the mapping is always in both or neither.
1026    if mapping.x is not None:
1027      return self._from_x.get(mapping.x_key, mapping)
1028    if mapping.y is not None:
1029      return self._from_y.get(mapping.y_key, mapping)
1030    return mapping
1031
1032  def _reduce_jacobian_det_over_event(
1033      self, y, ildj, min_event_ndims, event_ndims):
1034    """Reduce jacobian over event_ndims - min_event_ndims."""
1035    # In this case, we need to tile the Jacobian over the event and reduce.
1036    y_rank = array_ops.rank(y)
1037    y_shape = array_ops.shape(y)[
1038        y_rank - event_ndims : y_rank - min_event_ndims]
1039
1040    ones = array_ops.ones(y_shape, ildj.dtype)
1041    reduced_ildj = math_ops.reduce_sum(
1042        ones * ildj,
1043        axis=self._get_event_reduce_dims(min_event_ndims, event_ndims))
1044    # The multiplication by ones can change the inferred static shape so we try
1045    # to recover as much as possible.
1046    event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
1047    if (event_ndims_ is not None and
1048        y.shape.ndims is not None and
1049        ildj.shape.ndims is not None):
1050      y_shape = y.shape[y.shape.ndims - event_ndims_ :
1051                        y.shape.ndims - min_event_ndims]
1052      broadcast_shape = array_ops.broadcast_static_shape(ildj.shape, y_shape)
1053      reduced_ildj.set_shape(
1054          broadcast_shape[: broadcast_shape.ndims - (
1055              event_ndims_ - min_event_ndims)])
1056
1057    return reduced_ildj
1058
1059  def _get_event_reduce_dims(self, min_event_ndims, event_ndims):
1060    """Compute the reduction dimensions given event_ndims."""
1061    event_ndims_ = self._maybe_get_static_event_ndims(event_ndims)
1062
1063    if event_ndims_ is not None:
1064      return [-index for index in range(1, event_ndims_ - min_event_ndims + 1)]
1065    else:
1066      reduce_ndims = event_ndims - min_event_ndims
1067      return math_ops.range(-reduce_ndims, 0)
1068
1069  def _check_valid_event_ndims(self, min_event_ndims, event_ndims):
1070    """Check whether event_ndims is atleast min_event_ndims."""
1071    event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims")
1072    event_ndims_ = tensor_util.constant_value(event_ndims)
1073    assertions = []
1074
1075    if not event_ndims.dtype.is_integer:
1076      raise ValueError("Expected integer dtype, got dtype {}".format(
1077          event_ndims.dtype))
1078
1079    if event_ndims_ is not None:
1080      if event_ndims.shape.ndims != 0:
1081        raise ValueError("Expected scalar event_ndims, got shape {}".format(
1082            event_ndims.shape))
1083      if min_event_ndims > event_ndims_:
1084        raise ValueError("event_ndims ({}) must be larger than "
1085                         "min_event_ndims ({})".format(
1086                             event_ndims_, min_event_ndims))
1087    elif self.validate_args:
1088      assertions += [
1089          check_ops.assert_greater_equal(event_ndims, min_event_ndims)]
1090
1091    if event_ndims.shape.is_fully_defined():
1092      if event_ndims.shape.ndims != 0:
1093        raise ValueError("Expected scalar shape, got ndims {}".format(
1094            event_ndims.shape.ndims))
1095
1096    elif self.validate_args:
1097      assertions += [
1098          check_ops.assert_rank(event_ndims, 0, message="Expected scalar.")]
1099    return assertions
1100
1101  def _maybe_get_static_event_ndims(self, event_ndims):
1102    """Helper which returns tries to return an integer static value."""
1103    event_ndims_ = distribution_util.maybe_get_static_value(event_ndims)
1104
1105    if isinstance(event_ndims_, (np.generic, np.ndarray)):
1106      if event_ndims_.dtype not in (np.int32, np.int64):
1107        raise ValueError("Expected integer dtype, got dtype {}".format(
1108            event_ndims_.dtype))
1109
1110      if isinstance(event_ndims_, np.ndarray) and len(event_ndims_.shape):
1111        raise ValueError("Expected a scalar integer, got {}".format(
1112            event_ndims_))
1113      event_ndims_ = int(event_ndims_)
1114
1115    return event_ndims_
1116