# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""L2HMC compatible with TensorFlow's eager execution.

Reference [Generalizing Hamiltonian Monte Carlo with Neural
Networks](https://arxiv.org/pdf/1711.09268.pdf)

Code adapted from the released TensorFlow graph implementation by original
authors https://github.com/brain-research/l2hmc.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import numpy.random as npr
import tensorflow as tf
import tensorflow.contrib.eager as tfe
from tensorflow.contrib.eager.python.examples.l2hmc import neural_nets


class Dynamics(tf.keras.Model):
  """Dynamics engine of naive L2HMC sampler."""

  def __init__(self,
               x_dim,
               minus_loglikelihood_fn,
               n_steps=25,
               eps=.1,
               np_seed=1):
    """Initialization.

    Args:
      x_dim: dimensionality of observed data
      minus_loglikelihood_fn: negative log-likelihood function, used as the
        potential energy of the sampler
      n_steps: number of leapfrog steps within each transition
      eps: initial value of the learnable leapfrog step size
      np_seed: Random seed for numpy; used to control sampled masks.
    """
    super(Dynamics, self).__init__()

    npr.seed(np_seed)
    self.x_dim = x_dim
    self.potential = minus_loglikelihood_fn
    self.n_steps = n_steps

    self._construct_time()
    self._construct_masks()

    self.position_fn = neural_nets.GenericNet(x_dim, factor=2.)
    self.momentum_fn = neural_nets.GenericNet(x_dim, factor=1.)

    self.eps = tf.Variable(
        initial_value=eps, name="eps", dtype=tf.float32, trainable=True)

  def apply_transition(self, position):
    """Propose a new state and perform the accept or reject step."""

    # Simulate dynamics both forward and backward, then combine the two
    # proposals using sampled Bernoulli masks to pick a direction per chain
    position_f, momentum_f, accept_prob_f = self.transition_kernel(
        position, forward=True)
    position_b, momentum_b, accept_prob_b = self.transition_kernel(
        position, forward=False)

    # Decide direction uniformly
    batch_size = tf.shape(position)[0]
    forward_mask = tf.cast(tf.random_uniform((batch_size,)) > .5, tf.float32)
    backward_mask = 1. - forward_mask

    # Obtain proposed states
    position_post = (
        forward_mask[:, None] * position_f +
        backward_mask[:, None] * position_b)
    momentum_post = (
        forward_mask[:, None] * momentum_f +
        backward_mask[:, None] * momentum_b)

    # Probability of accepting the proposed states
    accept_prob = forward_mask * accept_prob_f + backward_mask * accept_prob_b

    # Accept or reject step
    accept_mask = tf.cast(
        accept_prob > tf.random_uniform(tf.shape(accept_prob)), tf.float32)
    reject_mask = 1. - accept_mask

    # Samples after accept/reject step
    position_out = (
        accept_mask[:, None] * position_post + reject_mask[:, None] * position)

    return position_post, momentum_post, accept_prob, position_out

  def transition_kernel(self, position, forward=True):
    """Transition kernel of augmented leapfrog integrator."""

    lf_fn = self._forward_lf if forward else self._backward_lf

    # Resample momentum
    momentum = tf.random_normal(tf.shape(position))
    position_post, momentum_post = position, momentum
    sumlogdet = 0.
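    # sumlogdet accumulates the log-determinant of the Jacobian of each
    # leapfrog update; it enters the Metropolis-Hastings acceptance
    # probability computed below.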
    # Apply augmented leapfrog steps
    for i in range(self.n_steps):
      position_post, momentum_post, logdet = lf_fn(position_post, momentum_post,
                                                   i)
      sumlogdet += logdet
    accept_prob = self._compute_accept_prob(position, momentum, position_post,
                                            momentum_post, sumlogdet)

    return position_post, momentum_post, accept_prob

  def _forward_lf(self, position, momentum, i):
    """One forward augmented leapfrog step. See eq (5-6) in paper."""

    t = self._get_time(i)
    mask, mask_inv = self._get_mask(i)
    sumlogdet = 0.

    momentum, logdet = self._update_momentum_forward(position, momentum, t)
    sumlogdet += logdet

    position, logdet = self._update_position_forward(position, momentum, t,
                                                     mask, mask_inv)
    sumlogdet += logdet

    position, logdet = self._update_position_forward(position, momentum, t,
                                                     mask_inv, mask)
    sumlogdet += logdet

    momentum, logdet = self._update_momentum_forward(position, momentum, t)
    sumlogdet += logdet

    return position, momentum, sumlogdet

  def _backward_lf(self, position, momentum, i):
    """One backward augmented leapfrog step. See Appendix A in paper."""

    # Reversed index/sinusoidal time
    t = self._get_time(self.n_steps - i - 1)
    mask, mask_inv = self._get_mask(self.n_steps - i - 1)
    sumlogdet = 0.

    momentum, logdet = self._update_momentum_backward(position, momentum, t)
    sumlogdet += logdet

    position, logdet = self._update_position_backward(position, momentum, t,
                                                      mask_inv, mask)
    sumlogdet += logdet

    position, logdet = self._update_position_backward(position, momentum, t,
                                                      mask, mask_inv)
    sumlogdet += logdet

    momentum, logdet = self._update_momentum_backward(position, momentum, t)
    sumlogdet += logdet

    return position, momentum, sumlogdet

  def _update_momentum_forward(self, position, momentum, t):
    """Update v in the forward leapfrog step."""

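    # Scale-and-translate momentum half-step from the paper:
    #   v <- v * exp(eps/2 * S_v) - eps/2 * (exp(eps * Q_v) * grad_U(x) - T_v)
    # where (S_v, T_v, Q_v) are the scale, translation and transformation
    # outputs of the momentum network. The returned log-det term is the
    # scale exponent summed over dimensions.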
    grad = self.grad_potential(position)
    scale, translation, transformed = self.momentum_fn([position, grad, t])
    scale *= .5 * self.eps
    transformed *= self.eps
    momentum = (
        momentum * tf.exp(scale) -
        .5 * self.eps * (tf.exp(transformed) * grad - translation))

    return momentum, tf.reduce_sum(scale, axis=1)

  def _update_position_forward(self, position, momentum, t, mask, mask_inv):
    """Update x in the forward leapfrog step."""

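    # Masked position update from the paper: only coordinates selected by
    # mask_inv are moved,
    #   x <- mask * x + mask_inv * (x * exp(eps * S_x)
    #                               + eps * (exp(eps * Q_x) * v + T_x)),
    # with (S_x, T_x, Q_x) produced by the position network from
    # (v, mask * x, t).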
    scale, translation, transformed = self.position_fn(
        [momentum, mask * position, t])
    scale *= self.eps
    transformed *= self.eps
    position = (
        mask * position +
        mask_inv * (position * tf.exp(scale) + self.eps *
                    (tf.exp(transformed) * momentum + translation)))
    return position, tf.reduce_sum(mask_inv * scale, axis=1)

  def _update_momentum_backward(self, position, momentum, t):
    """Update v in the backward leapfrog step. Inverting the forward update."""

    grad = self.grad_potential(position)
    scale, translation, transformed = self.momentum_fn([position, grad, t])
    scale *= -.5 * self.eps
    transformed *= self.eps
    momentum = (
        tf.exp(scale) * (momentum + .5 * self.eps *
                         (tf.exp(transformed) * grad - translation)))

    return momentum, tf.reduce_sum(scale, axis=1)

  def _update_position_backward(self, position, momentum, t, mask, mask_inv):
    """Update x in the backward leapfrog step. Inverting the forward update."""

    scale, translation, transformed = self.position_fn(
        [momentum, mask * position, t])
    scale *= -self.eps
    transformed *= self.eps
    position = (
        mask * position + mask_inv * tf.exp(scale) *
        (position - self.eps * (tf.exp(transformed) * momentum + translation)))

    return position, tf.reduce_sum(mask_inv * scale, axis=1)

  def _compute_accept_prob(self, position, momentum, position_post,
                           momentum_post, sumlogdet):
    """Compute the prob of accepting the proposed state given old state."""

    old_hamil = self.hamiltonian(position, momentum)
    new_hamil = self.hamiltonian(position_post, momentum_post)
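    # Metropolis-Hastings acceptance probability with the Jacobian correction
    # of the augmented leapfrog integrator:
    #   A = min(1, exp(H(x, v) - H(x', v') + sumlogdet))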
    prob = tf.exp(tf.minimum(old_hamil - new_hamil + sumlogdet, 0.))

    # Ensure numerical stability as well as correct gradients
    return tf.where(tf.is_finite(prob), prob, tf.zeros_like(prob))

  def _construct_time(self):
    """Convert leapfrog step index into sinusoidal time."""

    self.ts = []
    for i in range(self.n_steps):
      t = tf.constant(
          [
              np.cos(2 * np.pi * i / self.n_steps),
              np.sin(2 * np.pi * i / self.n_steps)
          ],
          dtype=tf.float32)
      self.ts.append(t[None, :])

  def _get_time(self, i):
    """Get sinusoidal time for i-th augmented leapfrog step."""

    return self.ts[i]

  def _construct_masks(self):
    """Construct different binary masks for different time steps."""

    self.masks = []
    for _ in range(self.n_steps):
      # Need to use npr here because tf would generate different random
      # values across different `sess.run` calls
      idx = npr.permutation(np.arange(self.x_dim))[:self.x_dim // 2]
      mask = np.zeros((self.x_dim,))
      mask[idx] = 1.
      mask = tf.constant(mask, dtype=tf.float32)
      self.masks.append(mask[None, :])

  def _get_mask(self, i):
    """Get binary masks for i-th augmented leapfrog step."""

    m = self.masks[i]
    return m, 1. - m

  def kinetic(self, v):
    """Compute the kinetic energy."""

    return .5 * tf.reduce_sum(v**2, axis=1)

  def hamiltonian(self, position, momentum):
    """Compute the overall Hamiltonian."""

    return self.potential(position) + self.kinetic(momentum)

  def grad_potential(self, position, check_numerics=True):
    """Get gradient of potential function at current location."""

    if tf.executing_eagerly():
      grad = tfe.gradients_function(self.potential)(position)[0]
    else:
      grad = tf.gradients(self.potential(position), position)[0]

    return grad


# Examples of unnormalized log densities
def get_scg_energy_fn():
  """Get energy function for 2d strongly correlated Gaussian."""

  # Avoid recreating tf constants on each invocation of gradients
  mu = tf.constant([0., 0.])
  sigma = tf.constant([[50.05, -49.95], [-49.95, 50.05]])
  sigma_inv = tf.matrix_inverse(sigma)

  def energy(x):
    """Unnormalized minus log density of 2d strongly correlated Gaussian."""

    xmmu = x - mu
    return .5 * tf.diag_part(
        tf.matmul(tf.matmul(xmmu, sigma_inv), tf.transpose(xmmu)))

  return energy, mu, sigma


def get_rw_energy_fn():
  """Get energy function for rough well distribution."""
  # For small eta, the density underlying the rough-well energy is very close
  # to a unit Gaussian; however, the gradient is greatly affected by the small
  # cosine perturbations
  eta = 1e-2
  mu = tf.constant([0., 0.])
  sigma = tf.constant([[1., 0.], [0., 1.]])

  def energy(x):
    ip = tf.reduce_sum(x**2., axis=1)
    return .5 * ip + eta * tf.reduce_sum(tf.cos(x / eta), axis=1)

  return energy, mu, sigma


# Loss function
def compute_loss(dynamics, x, scale=.1, eps=1e-4):
  """Compute loss defined in equation (8)."""

  z = tf.random_normal(tf.shape(x))  # Auxiliary variable
  x_, _, x_accept_prob, x_out = dynamics.apply_transition(x)
  z_, _, z_accept_prob, _ = dynamics.apply_transition(z)

  # Add eps for numerical stability; following released impl
  x_loss = tf.reduce_sum((x - x_)**2, axis=1) * x_accept_prob + eps
  z_loss = tf.reduce_sum((z - z_)**2, axis=1) * z_accept_prob + eps

  loss = tf.reduce_mean(
      (1. / x_loss + 1. / z_loss) * scale - (x_loss + z_loss) / scale, axis=0)

  return loss, x_out, x_accept_prob


def loss_and_grads(dynamics, x, loss_fn=compute_loss):
  """Obtain loss value and gradients."""
  with tf.GradientTape() as tape:
    loss_val, out, accept_prob = loss_fn(dynamics, x)
  grads = tape.gradient(loss_val, dynamics.trainable_variables)

  return loss_val, grads, out, accept_prob
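
# Illustrative sketch only, not part of the original example: a minimal
# eager-mode training loop wiring together get_scg_energy_fn, Dynamics and
# loss_and_grads. Hyperparameters (batch size, learning rate, number of
# steps) are arbitrary choices for demonstration.
if __name__ == "__main__":
  tf.enable_eager_execution()

  energy_fn, _, _ = get_scg_energy_fn()
  dynamics = Dynamics(
      x_dim=2, minus_loglikelihood_fn=energy_fn, n_steps=10, eps=.1)
  optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)

  # Chains are advanced in place: `samples` is replaced by the accepted
  # positions returned from each transition.
  samples = tf.random_normal(shape=[200, 2])
  for step in range(100):
    loss, grads, samples, accept_prob = loss_and_grads(dynamics, samples)
    optimizer.apply_gradients(zip(grads, dynamics.trainable_variables))
    if step % 10 == 0:
      print("Step %d: loss %.4f, mean accept prob %.4f" %
            (step, loss.numpy(), tf.reduce_mean(accept_prob).numpy()))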