# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""L2HMC compatible with TensorFlow's eager execution.

Reference [Generalizing Hamiltonian Monte Carlo with Neural
Networks](https://arxiv.org/pdf/1711.09268.pdf)

Code adapted from the released TensorFlow graph implementation by the original
authors https://github.com/brain-research/l2hmc.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import numpy.random as npr
import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets


class Dynamics(tf.keras.Model):
  """Dynamics engine of naive L2HMC sampler."""

  def __init__(self,
               x_dim,
               minus_loglikelihood_fn,
               n_steps=25,
               eps=.1,
               np_seed=1):
    """Initialization.

    Args:
      x_dim: dimensionality of observed data
      minus_loglikelihood_fn: negative log-likelihood function of the target
        distribution; used as the potential energy
      n_steps: number of leapfrog steps within each transition
      eps: initial value of the learnable leapfrog step size
      np_seed: random seed for numpy; used to control the sampled masks
    """
    super(Dynamics, self).__init__()

    npr.seed(np_seed)
    self.x_dim = x_dim
    self.potential = minus_loglikelihood_fn
    self.n_steps = n_steps

    self._construct_time()
    self._construct_masks()

    self.position_fn = neural_nets.GenericNet(x_dim, factor=2.)
    self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.)

    self.eps = tf.Variable(
        initial_value=eps, name="eps", dtype=tf.float32, trainable=True)

  def apply_transition(self, position):
    """Propose a new state and perform the accept or reject step."""

    # Simulate dynamics both forward and backward;
    # use sampled Bernoulli masks to compute the actual solutions
    position_f, momentum_f, accept_prob_f = self.transition_kernel(
        position, forward=True)
    position_b, momentum_b, accept_prob_b = self.transition_kernel(
        position, forward=False)

    # Decide direction uniformly
    batch_size = tf.shape(position)[0]
    forward_mask = tf.cast(tf.random_uniform((batch_size,)) > .5, tf.float32)
    backward_mask = 1. - forward_mask

    # Obtain proposed states
    position_post = (
        forward_mask[:, None] * position_f +
        backward_mask[:, None] * position_b)
    momentum_post = (
        forward_mask[:, None] * momentum_f +
        backward_mask[:, None] * momentum_b)

    # Probability of accepting the proposed states
    accept_prob = forward_mask * accept_prob_f + backward_mask * accept_prob_b

    # Accept or reject step
    accept_mask = tf.cast(
        accept_prob > tf.random_uniform(tf.shape(accept_prob)), tf.float32)
    reject_mask = 1. - accept_mask

    # Samples after accept/reject step
    position_out = (
        accept_mask[:, None] * position_post + reject_mask[:, None] * position)

    return position_post, momentum_post, accept_prob, position_out

  def transition_kernel(self, position, forward=True):
    """Transition kernel of augmented leapfrog integrator."""

    lf_fn = self._forward_lf if forward else self._backward_lf

    # Resample momentum
    momentum = tf.random_normal(tf.shape(position))
    position_post, momentum_post = position, momentum
    sumlogdet = 0.
    # Apply augmented leapfrog steps
    for i in range(self.n_steps):
      position_post, momentum_post, logdet = lf_fn(position_post,
                                                   momentum_post, i)
      sumlogdet += logdet
    accept_prob = self._compute_accept_prob(position, momentum, position_post,
                                            momentum_post, sumlogdet)

    return position_post, momentum_post, accept_prob

  def _forward_lf(self, position, momentum, i):
    """One forward augmented leapfrog step. See eq (5-6) in paper."""

    t = self._get_time(i)
    mask, mask_inv = self._get_mask(i)
    sumlogdet = 0.

    momentum, logdet = self._update_momentum_forward(position, momentum, t)
    sumlogdet += logdet

    position, logdet = self._update_position_forward(position, momentum, t,
                                                     mask, mask_inv)
    sumlogdet += logdet

    position, logdet = self._update_position_forward(position, momentum, t,
                                                     mask_inv, mask)
    sumlogdet += logdet

    momentum, logdet = self._update_momentum_forward(position, momentum, t)
    sumlogdet += logdet

    return position, momentum, sumlogdet

  def _backward_lf(self, position, momentum, i):
    """One backward augmented leapfrog step. See Appendix A in paper."""

    # Reversed index/sinusoidal time
    t = self._get_time(self.n_steps - i - 1)
    mask, mask_inv = self._get_mask(self.n_steps - i - 1)
    sumlogdet = 0.

    momentum, logdet = self._update_momentum_backward(position, momentum, t)
    sumlogdet += logdet

    position, logdet = self._update_position_backward(position, momentum, t,
                                                      mask_inv, mask)
    sumlogdet += logdet

    position, logdet = self._update_position_backward(position, momentum, t,
                                                      mask, mask_inv)
    sumlogdet += logdet

    momentum, logdet = self._update_momentum_backward(position, momentum, t)
    sumlogdet += logdet

    return position, momentum, sumlogdet

  def _update_momentum_forward(self, position, momentum, t):
    """Update v in the forward leapfrog step."""

    grad = self.grad_potential(position)
    scale, translation, transformed = self.momentum_fn([position, grad, t])
    scale *= .5 * self.eps
    transformed *= self.eps
    momentum = (
        momentum * tf.exp(scale) -
        .5 * self.eps * (tf.exp(transformed) * grad - translation))

    return momentum, tf.reduce_sum(scale, axis=1)

  def _update_position_forward(self, position, momentum, t, mask, mask_inv):
    """Update x in the forward leapfrog step."""

    scale, translation, transformed = self.position_fn(
        [momentum, mask * position, t])
    scale *= self.eps
    transformed *= self.eps
    position = (
        mask * position +
        mask_inv * (position * tf.exp(scale) + self.eps *
                    (tf.exp(transformed) * momentum + translation)))
    return position, tf.reduce_sum(mask_inv * scale, axis=1)

  def _update_momentum_backward(self, position, momentum, t):
    """Update v in the backward leapfrog step. Inverting the forward update."""

    grad = self.grad_potential(position)
    scale, translation, transformed = self.momentum_fn([position, grad, t])
    scale *= -.5 * self.eps
    transformed *= self.eps
    momentum = (
        tf.exp(scale) * (momentum + .5 * self.eps *
                         (tf.exp(transformed) * grad - translation)))

    return momentum, tf.reduce_sum(scale, axis=1)

  def _update_position_backward(self, position, momentum, t, mask, mask_inv):
    """Update x in the backward leapfrog step. Inverting the forward update."""

    scale, translation, transformed = self.position_fn(
        [momentum, mask * position, t])
    scale *= -self.eps
    transformed *= self.eps
    position = (
        mask * position + mask_inv * tf.exp(scale) *
        (position - self.eps * (tf.exp(transformed) * momentum + translation)))

    return position, tf.reduce_sum(mask_inv * scale, axis=1)

  def _compute_accept_prob(self, position, momentum, position_post,
                           momentum_post, sumlogdet):
    """Compute the probability of accepting the proposed state."""

    old_hamil = self.hamiltonian(position, momentum)
    new_hamil = self.hamiltonian(position_post, momentum_post)
    prob = tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.))

    # Ensure numerical stability as well as correct gradients
    return tf.where(tf.is_finite(prob), prob, tf.zeros_like(prob))

  def _construct_time(self):
    """Convert leapfrog step index into sinusoidal time."""

    self.ts = []
    for i in range(self.n_steps):
      t = tf.constant(
          [
              np.cos(2 * np.pi * i / self.n_steps),
              np.sin(2 * np.pi * i / self.n_steps)
          ],
          dtype=tf.float32)
      self.ts.append(t[None, :])

  def _get_time(self, i):
    """Get sinusoidal time for i-th augmented leapfrog step."""

    return self.ts[i]

  def _construct_masks(self):
    """Construct different binary masks for different time steps."""

    self.masks = []
    for _ in range(self.n_steps):
      # Need to use npr here because tf would generate different random
      # values across different `sess.run`
      idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2]
      mask = np.zeros((self.x_dim,))
      mask[idx] = 1.
      mask = tf.constant(mask, dtype=tf.float32)
      self.masks.append(mask[None, :])

  def _get_mask(self, i):
    """Get binary masks for i-th augmented leapfrog step."""

    m = self.masks[i]
    return m, 1. - m

  def kinetic(self, v):
    """Compute the kinetic energy."""

    return .5 * tf.reduce_sum(v**2, axis=1)

  def hamiltonian(self, position, momentum):
    """Compute the overall Hamiltonian."""

    return self.potential(position) + self.kinetic(momentum)

  def grad_potential(self, position, check_numerics=True):
    """Get gradient of potential function at current location."""

    if tf.executing_eagerly():
      grad = tfe.gradients_function(self.potential)(position)[0]
    else:
      grad = tf.gradients(self.potential(position), position)[0]

    return grad


# Examples of unnormalized log densities
def get_scg_energy_fn():
  """Get energy function for 2d strongly correlated Gaussian."""

  # Avoid recreating tf constants on each invocation of gradients
  mu = tf.constant([0., 0.])
  sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]])
  sigma_inv = tf.matrix_inverse(sigma)

  def energy(x):
    """Unnormalized minus log density of 2d strongly correlated Gaussian."""

    xmmu = x - mu
    return .5 * tf.diag_part(
        tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu)))

  return energy, mu, sigma


def get_rw_energy_fn():
  """Get energy function for rough well distribution."""
  # For small eta, the density underlying the rough-well energy is very close
  # to a unit Gaussian; however, the gradient is greatly affected by the small
  # cosine perturbations
  eta = 1e-2
  mu = tf.constant([0., 0.])
  sigma = tf.constant([[1., 0.], [0., 1.]])

  def energy(x):
    ip = tf.reduce_sum(x**2., axis=1)
    return .5 * ip + eta * tf.reduce_sum(tf.cos(x / eta), axis=1)

  return energy, mu, sigma


# Loss function
def compute_loss(dynamics, x, scale=.1, eps=1e-4):
  """Compute loss defined in equation (8)."""

  z = tf.random_normal(tf.shape(x))  # Auxiliary variable
  x_, _, x_accept_prob, x_out = dynamics.apply_transition(x)
  z_, _, z_accept_prob, _ = dynamics.apply_transition(z)

  # Add eps for numerical stability; following released impl
  x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps
  z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps

  loss = tf.reduce_mean(
      (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0)

  return loss, x_out, x_accept_prob


def loss_and_grads(dynamics, x, loss_fn=compute_loss):
  """Obtain loss value and gradients."""
  with tf.GradientTape() as tape:
    loss_val, out, accept_prob = loss_fn(dynamics, x)
  grads = tape.gradient(loss_val, dynamics.trainable_variables)

  return loss_val, grads, out, accept_prob
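

# A minimal usage sketch, assuming eager execution and TF 1.x-era APIs
# (`tf.enable_eager_execution`, `tf.train.AdamOptimizer`); the batch size,
# number of training steps, and learning rate below are illustrative only.
if __name__ == "__main__":
  tf.enable_eager_execution()

  # Build the sampler for the 2d strongly correlated Gaussian energy.
  energy_fn, _, _ = get_scg_energy_fn()
  dynamics = Dynamics(
      x_dim=2, minus_loglikelihood_fn=energy_fn, n_steps=10, eps=.1)
  optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)

  # Chains are initialized from a standard normal and updated in place.
  samples = tf.random_normal(shape=[200, 2])
  for step in range(100):
    loss, grads, samples, accept_prob = loss_and_grads(dynamics, samples)
    optimizer.apply_gradients(zip(grads, dynamics.trainable_variables))
    if step % 10 == 0:
      print("Step %d: loss %.4f, mean accept prob %.4f" %
            (step, loss.numpy(), tf.reduce_mean(accept_prob).numpy()))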