# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Multivariate Normal distribution classes."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.distributions.python.ops import distribution_util
from tensorflow.contrib.distributions.python.ops import mvn_linear_operator as mvn_linop
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
from tensorflow.python.util import deprecation


__all__ = [
    "MultivariateNormalDiag",
    "MultivariateNormalDiagWithSoftplusScale",
]


class MultivariateNormalDiag(
    mvn_linop.MultivariateNormalLinearOperator):
  """The multivariate normal distribution on `R^k`.

  The Multivariate Normal distribution is defined over `R^k` and parameterized
  by a (batch of) length-`k` `loc` vector (aka "mu") and a (batch of) `k x k`
  `scale` matrix; `covariance = scale @ scale.T` where `@` denotes
  matrix-multiplication.

  #### Mathematical Details

  The probability density function (pdf) is,

  ```none
  pdf(x; loc, scale) = exp(-0.5 ||y||**2) / Z,
  y = inv(scale) @ (x - loc),
  Z = (2 pi)**(0.5 k) |det(scale)|,
  ```

  where:

  * `loc` is a vector in `R^k`,
  * `scale` is a linear operator in `R^{k x k}`, `cov = scale @ scale.T`,
  * `Z` denotes the normalization constant, and,
  * `||y||**2` denotes the squared Euclidean norm of `y`.

  A (non-batch) `scale` matrix is:

  ```none
  scale = diag(scale_diag + scale_identity_multiplier * ones(k))
  ```

  where:

  * `scale_diag.shape = [k]`, and,
  * `scale_identity_multiplier.shape = []`.

  Additional leading dimensions (if any) will index batches.

  If both `scale_diag` and `scale_identity_multiplier` are `None`, then
  `scale` is the Identity matrix.

  The MultivariateNormal distribution is a member of the [location-scale
  family](https://en.wikipedia.org/wiki/Location-scale_family), i.e., it can be
  constructed as,

  ```none
  X ~ MultivariateNormal(loc=0, scale=1)   # Identity scale, zero shift.
  Y = scale @ X + loc
  ```

  #### Examples

  ```python
  import tensorflow_probability as tfp
  tfd = tfp.distributions

  # Initialize a single 2-variate Gaussian.
  mvn = tfd.MultivariateNormalDiag(
      loc=[1., -1],
      scale_diag=[1, 2.])

  mvn.mean().eval()
  # ==> [1., -1]

  mvn.stddev().eval()
  # ==> [1., 2]

  # Evaluate this on an observation in `R^2`, returning a scalar.
  mvn.prob([-1., 0]).eval()  # shape: []

  # Initialize a 3-batch, 2-variate scaled-identity Gaussian.
  mvn = tfd.MultivariateNormalDiag(
      loc=[1., -1],
      scale_identity_multiplier=[1, 2., 3])

  mvn.mean().eval()  # shape: [3, 2]
  # ==> [[1., -1],
  #      [1, -1],
  #      [1, -1]]

  mvn.stddev().eval()  # shape: [3, 2]
  # ==> [[1., 1],
  #      [2, 2],
  #      [3, 3]]

  # Evaluate this on an observation in `R^2`, returning a length-3 vector.
  mvn.prob([-1., 0]).eval()  # shape: [3]

  # Initialize a 2-batch of 3-variate Gaussians.
  mvn = tfd.MultivariateNormalDiag(
      loc=[[1., 2, 3],
           [11, 22, 33]],           # shape: [2, 3]
      scale_diag=[[1., 2, 3],
                  [0.5, 1, 1.5]])  # shape: [2, 3]

  # Evaluate this on two observations, each in `R^3`, returning a length-2
  # vector.
  x = [[-1., 0, 1],
       [-11, 0, 11.]]   # shape: [2, 3].
  mvn.prob(x).eval()    # shape: [2]
  ```

  """

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               loc=None,
               scale_diag=None,
               scale_identity_multiplier=None,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalDiag"):
    """Construct Multivariate Normal distribution on `R^k`.

    The `batch_shape` is the broadcast shape between `loc` and `scale`
    arguments.

    The `event_shape` is given by last dimension of the matrix implied by
    `scale`. The last dimension of `loc` (if provided) must broadcast with this.

    Recall that `covariance = scale @ scale.T`. A (non-batch) `scale` matrix is:

    ```none
    scale = diag(scale_diag + scale_identity_multiplier * ones(k))
    ```

    where:

    * `scale_diag.shape = [k]`, and,
    * `scale_identity_multiplier.shape = []`.

    Additional leading dimensions (if any) will index batches.

    If both `scale_diag` and `scale_identity_multiplier` are `None`, then
    `scale` is the Identity matrix.

    Args:
      loc: Floating-point `Tensor`. If this is set to `None`, `loc` is
        implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where
        `b >= 0` and `k` is the event size.
      scale_diag: Non-zero, floating-point `Tensor` representing a diagonal
        matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`,
        and characterizes `b`-batches of `k x k` diagonal matrices added to
        `scale`. When both `scale_identity_multiplier` and `scale_diag` are
        `None` then `scale` is the `Identity`.
      scale_identity_multiplier: Non-zero, floating-point `Tensor` representing
        a scaled-identity-matrix added to `scale`. May have shape
        `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scaled
        `k x k` identity matrices added to `scale`. When both
        `scale_identity_multiplier` and `scale_diag` are `None` then `scale` is
        the `Identity`.
      validate_args: Python `bool`, default `False`. When `True` distribution
        parameters are checked for validity despite possibly degrading runtime
        performance. When `False` invalid inputs may silently render incorrect
        outputs.
      allow_nan_stats: Python `bool`, default `True`. When `True`,
        statistics (e.g., mean, mode, variance) use the value "`NaN`" to
        indicate the result is undefined. When `False`, an exception is raised
        if one or more of the statistic's batch members are undefined.
      name: Python `str` name prefixed to Ops created by this class.

    Raises:
      ValueError: if at most `scale_identity_multiplier` is specified, i.e.
        neither `loc` nor `scale_diag` is given.
    """
    # Snapshot the constructor arguments. This must remain the first statement:
    # `locals()` would otherwise also pick up locals bound later (e.g. `scale`).
    parameters = dict(locals())
    with ops.name_scope(name) as name:
      with ops.name_scope("init", values=[
          loc, scale_diag, scale_identity_multiplier]):
        # No need to validate_args while making diag_scale.  The returned
        # LinearOperatorDiag has an assert_non_singular method that is called by
        # the Bijector.
        scale = distribution_util.make_diag_scale(
            loc=loc,
            scale_diag=scale_diag,
            scale_identity_multiplier=scale_identity_multiplier,
            validate_args=False,
            assert_positive=False)
    # The caller's `validate_args` is forwarded here, even though it was
    # deliberately disabled while building `scale` above.
    super(MultivariateNormalDiag, self).__init__(
        loc=loc,
        scale=scale,
        validate_args=validate_args,
        allow_nan_stats=allow_nan_stats,
        name=name)
    # Record this class's own constructor arguments (assigned after the parent
    # constructor has run, so this is the value that remains on the instance).
    self._parameters = parameters


class MultivariateNormalDiagWithSoftplusScale(MultivariateNormalDiag):
  """MultivariateNormalDiag with `scale_diag = softplus(scale_diag)`."""

  @deprecation.deprecated(
      "2018-10-01",
      "The TensorFlow Distributions library has moved to "
      "TensorFlow Probability "
      "(https://github.com/tensorflow/probability). You "
      "should update all references to use `tfp.distributions` "
      "instead of `tf.contrib.distributions`.",
      warn_once=True)
  def __init__(self,
               loc,
               scale_diag,
               validate_args=False,
               allow_nan_stats=True,
               name="MultivariateNormalDiagWithSoftplusScale"):
    """Construct a `MultivariateNormalDiag` from a pre-softplus `scale_diag`.

    `scale_diag` is passed through `nn.softplus` and the result is handed to
    the `MultivariateNormalDiag` constructor as its `scale_diag`. No
    `scale_identity_multiplier` is supplied.

    Args:
      loc: Forwarded unchanged as `loc` to `MultivariateNormalDiag`.
      scale_diag: Value to which `nn.softplus` is applied; the result is used
        as the `scale_diag` of `MultivariateNormalDiag`.
      validate_args: Forwarded unchanged to `MultivariateNormalDiag`.
      allow_nan_stats: Forwarded unchanged to `MultivariateNormalDiag`.
      name: Python `str` used for the name scope opened by this constructor;
        the resulting scope name is forwarded to `MultivariateNormalDiag`.
    """
    # Snapshot the constructor arguments before any other local is bound, so
    # the recorded `scale_diag` is the raw (pre-softplus) argument.
    parameters = dict(locals())
    with ops.name_scope(name, values=[scale_diag]) as name:
      super(MultivariateNormalDiagWithSoftplusScale, self).__init__(
          loc=loc,
          scale_diag=nn.softplus(scale_diag),
          validate_args=validate_args,
          allow_nan_stats=allow_nan_stats,
          name=name)
    # The parent constructor stored its own argument snapshot in
    # `_parameters`; replace it with the arguments this subclass received.
    self._parameters = parameters