• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Multivariate Normal distribution classes."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.distributions.python.ops import distribution_util
22from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import nn
25from tensorflow.python.util import deprecation
26
27
# Public names exported by a wildcard import of this module.
__all__ = ["MultivariateNormalDiag", "MultivariateNormalDiagWithSoftplusScale"]
32
33
class MultivariateNormalDiag(
    mvn_linop.MultivariateNormalLinearOperator):
  """The multivariate normal distribution on `R^k`.

  The Multivariate Normal distribution is defined over `R^k` and parameterized
  by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
  `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
  matrix-multiplication.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
  y = inv(scale) @ (x - loc),
  Z = (2 pi)**(0.5 k) |det(scale)|,
  ```

  where:

  * `loc` is a vector in `R^k`,
  * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
  * `Z` denotes the normalization constant, and,
  * `||y||**2` denotes the squared Euclidean norm of `y`.

  A (non-batch) `scale` matrix is:

  ```none
  scale = diag(scale_diag + scale_identity_multiplier * ones(k))
  ```

  where:

  * `scale_diag.shape = [k]`, and,
  * `scale_identity_multiplier.shape = []`.

  Additional leading dimensions (if any) will index batches.

  If both `scale_diag` and `scale_identity_multiplier` are `None`, then
  `scale` is the Identity matrix.

  The MultivariateNormal distribution is a member of the [location-scale
  family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed as,

  ```none
  X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
  Y = scale @ X + loc
  ```

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Initialize a single 2-variate Gaussian.
  mvn = tfd.MultivariateNormalDiag(
      loc=[1., -1],
      scale_diag=[1, 2.])

  mvn.mean().eval()
  # ==> [1., -1]

  mvn.stddev().eval()
  # ==> [1., 2]

  # Evaluate this on an observation in `R^2`, returning a scalar.
  mvn.prob([-1., 0]).eval()  # shape: []

  # Initialize a 3-batch, 2-variate scaled-identity Gaussian.
  mvn = tfd.MultivariateNormalDiag(
      loc=[1., -1],
      scale_identity_multiplier=[1, 2., 3])

  mvn.mean().eval()  # shape: [3, 2]
  # ==> [[1., -1],
  #      [1, -1],
  #      [1, -1]]

  mvn.stddev().eval()  # shape: [3, 2]
  # ==> [[1., 1],
  #      [2, 2],
  #      [3, 3]]

  # Evaluate this on an observation in `R^2`, returning a length-3 vector.
  mvn.prob([-1., 0]).eval()  # shape: [3]

  # Initialize a 2-batch of 3-variate Gaussians.
  mvn = tfd.MultivariateNormalDiag(
      loc=[[1., 2, 3],
           [11, 22, 33]],          # shape: [2, 3]
      scale_diag=[[1., 2, 3],
                  [0.5, 1, 1.5]])  # shape: [2, 3]

  # Evaluate this on two observations, each in `R^3`, returning a length-2
  # vector.
  x = [[-1., 0, 1],
       [-11, 0, 11.]]   # shape: [2, 3].
  mvn.prob(x).eval()    # shape: [2]
  ```

  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalDiag"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by the last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k))
    ```

    where:

    * `scale_diag.shape = [k]`, and,
    * `scale_identity_multiplier.shape = []`.

    Additional leading dimensions (if any) will index batches.

    If both `scale_diag` and `scale_identity_multiplier` are `None`, then
    `scale` is the Identity matrix.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scaled-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
        the `Identity`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified (i.e.,
        neither `loc` nor `scale_diag` is given), since then nothing
        determines the event size `k`.
    """
    # Snapshot the constructor arguments first, before any other local name is
    # bound, so the dict holds exactly the values this call received.
    parameters = dict(locals())
    with ops.name_scope(name) as name:
      # `name` is now rebound to the full scope string opened here; it is
      # forwarded to the base-class constructor below.
      with ops.name_scope("init", values=[
          loc, scale_diag, scale_identity_multiplier]):
        # No need to validate_args while making diag_scale.  The returned
        # LinearOperatorDiag has an assert_non_singular method that is called by
        # the Bijector.
        scale = distribution_util.make_diag_scale(
            loc=loc,
            scale_diag=scale_diag,
            scale_identity_multiplier=scale_identity_multiplier,
            validate_args=False,
            assert_positive=False)
    super(MultivariateNormalDiag, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    # Assigned after the base-class constructor, presumably so these
    # user-facing arguments replace whatever `_parameters` the base class
    # records. TODO(review): confirm against MultivariateNormalLinearOperator.
    self._parameters = parameters
226
227
class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag):
  """MultivariateNormalDiag with `scale_diag = softplus(scale_diag)`.

  The given `scale_diag` is passed through `softplus` before being forwarded
  to `MultivariateNormalDiag`. This maps unconstrained real values to a
  positive diagonal scale: `softplus(x) = log(1 + exp(x)) > 0`, up to
  floating-point underflow for very negative inputs.
  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               loc,
               scale_diag,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalDiagWithSoftplusScale"):
    """Construct a diagonal MVN whose scale diagonal is `softplus(scale_diag)`.

    Args:
      loc: Floating-point `Tensor` (or `None`), forwarded unchanged as the
        `loc` argument of `MultivariateNormalDiag`.
      scale_diag: Floating-point `Tensor` of unconstrained values. Its
        elementwise softplus is used as `MultivariateNormalDiag`'s
        `scale_diag`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.
    """
    # Snapshot the constructor arguments first, before any other local name is
    # bound, so the dict holds exactly the values this call received.
    parameters = dict(locals())
    # List every tensor input (as MultivariateNormalDiag's "init" scope does)
    # so the scope is entered in the graph that owns `loc`/`scale_diag`, and
    # the softplus op is created alongside them.
    with ops.name_scope(name, values=[loc, scale_diag]) as name:
      super(MultivariateNormalDiagWithSoftplusScale, self).__init__(
          loc=loc,
          scale_diag=nn.softplus(scale_diag),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    # Overwrite after the parent constructor so the recorded parameters are
    # this subclass's (pre-softplus) arguments.
    self._parameters = parameters
254