# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""L2HMC on simple Gaussian mixture model with TensorFlow eager."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

from absl import flags
import numpy as np
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc
try:
  import matplotlib.pyplot as plt  # pylint: disable=g-import-not-at-top
  HAS_MATPLOTLIB = True
except ImportError:
  HAS_MATPLOTLIB = False
tfe = tf.contrib.eager


def main(_):
  tf.enable_eager_execution()
  global_step = tf.train.get_or_create_global_step()
  global_step.assign(1)

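  # Pick the target distribution: "scg" is the strongly correlated Gaussian
  # and "rw" the rough well from the L2HMC paper (both defined in l2hmc.py).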
  energy_fn, mean, covar = {
      "scg": l2hmc.get_scg_energy_fn(),
      "rw": l2hmc.get_rw_energy_fn()
  }[FLAGS.energy_fn]

  x_dim = 2
  train_iters = 5000
  eval_iters = 2000
  eps = 0.1
  n_steps = 10  # Number of leapfrog steps per transition
  n_samples = 200
  record_loss_every = 100

  dynamics = l2hmc.Dynamics(
      x_dim=x_dim, minus_loglikelihood_fn=energy_fn, n_steps=n_steps, eps=eps)
  learning_rate = tf.train.exponential_decay(
      1e-3, global_step, 1000, 0.96, staircase=True)
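  # With staircase=True this decays as 1e-3 * 0.96 ** (global_step // 1000),
  # i.e. the learning rate drops by 4% after every 1000 training iterations.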
  optimizer = tf.train.AdamOptimizer(learning_rate)
  checkpointer = tf.train.Checkpoint(
      optimizer=optimizer, dynamics=dynamics, global_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpointer.restore(latest_path)
      print("Restored latest checkpoint at path: \"{}\"".format(latest_path))
      sys.stdout.flush()

  if not FLAGS.restore:
    # Training
    if FLAGS.use_defun:
      # Use `tfe.function` to boost performance with many small ops
      loss_fn = tfe.function(l2hmc.compute_loss)
    else:
      loss_fn = l2hmc.compute_loss
    samples = tf.random_normal(shape=[n_samples, x_dim])
    for i in range(1, train_iters + 1):
      loss, samples, accept_prob = train_one_iter(
          dynamics,
          samples,
          optimizer,
          loss_fn=loss_fn,
          global_step=global_step)

      if i % record_loss_every == 0:
        print("Iteration {}, loss {:.4f}, x_accept_prob {:.4f}".format(
            i, loss.numpy(),
            accept_prob.numpy().mean()))
        if FLAGS.train_dir:
          with summary_writer.as_default():
            with tf.contrib.summary.always_record_summaries():
              tf.contrib.summary.scalar("Training loss", loss, step=global_step)
    print("Training complete.")
    sys.stdout.flush()

    if FLAGS.train_dir:
      saved_path = checkpointer.save(
          file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
      print("Saved checkpoint at path: \"{}\"".format(saved_path))
      sys.stdout.flush()

  # Evaluation
  if FLAGS.use_defun:
    # Use `tfe.function` to boost performance with many small ops
    apply_transition = tfe.function(dynamics.apply_transition)
  else:
    apply_transition = dynamics.apply_transition

  samples = tf.random_normal(shape=[n_samples, x_dim])
  samples_history = []
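  # Each iteration snapshots every chain at once, so the stacked history
  # below has shape [eval_iters, n_samples, x_dim].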
  for i in range(eval_iters):
    samples_history.append(samples.numpy())
    _, _, _, samples = apply_transition(samples)
  samples_history = np.array(samples_history)
  print("Sampling complete.")
  sys.stdout.flush()

  # Mean and covariance of the target distribution
  mean = mean.numpy()
  covar = covar.numpy()
  ac_spectrum = compute_ac_spectrum(samples_history, mean, covar)
  print("First 25 entries of the auto-correlation spectrum: {}".format(
      ac_spectrum[:25]))
  ess = compute_ess(ac_spectrum)
  print("Effective sample size per Metropolis-Hastings step: {}".format(ess))
  sys.stdout.flush()

  if FLAGS.train_dir:
    # Plot the autocorrelation spectrum in TensorBoard
    plot_step = tfe.Variable(1, trainable=False, dtype=tf.int64)

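    # Each transition comprises n_steps leapfrog steps, so advancing the plot
    # counter by n_steps makes the TensorBoard x-axis count leapfrog steps
    # rather than Metropolis-Hastings transitions.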
    for ac in ac_spectrum:
      with summary_writer.as_default():
        with tf.contrib.summary.always_record_summaries():
          tf.contrib.summary.scalar("Autocorrelation", ac, step=plot_step)
      plot_step.assign(plot_step + n_steps)

    if HAS_MATPLOTLIB:
      # Choose a single chain and plot its trajectory
      single_chain = samples_history[:, 0, :]
      xs = single_chain[:100, 0]
      ys = single_chain[:100, 1]
      plt.figure()
      plt.plot(xs, ys, color="orange", marker="o", alpha=0.6)  # Trained chain
      plt.savefig(os.path.join(FLAGS.train_dir, "single_chain.png"))


def train_one_iter(dynamics,
                   x,
                   optimizer,
                   loss_fn=l2hmc.compute_loss,
                   global_step=None):
  """Train the sampler for one iteration."""
  loss, grads, out, accept_prob = l2hmc.loss_and_grads(
      dynamics, x, loss_fn=loss_fn)
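  # Passing global_step makes apply_gradients increment it, which in turn
  # advances the exponential learning-rate decay schedule.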
  optimizer.apply_gradients(
      zip(grads, dynamics.trainable_variables), global_step=global_step)

  return loss, out, accept_prob


def compute_ac_spectrum(samples_history, target_mean, target_covar):
  """Compute autocorrelation spectrum.

  Follows equation 15 from the L2HMC paper.

  Args:
    samples_history: Numpy array of shape [T, B, D], where T is the total
        number of time steps, B is the batch size, and D is the dimensionality
        of sample space.
    target_mean: 1D Numpy array holding the mean of the target (true)
        distribution.
    target_covar: 2D Numpy array holding the (symmetric) covariance matrix of
        the target distribution.
  Returns:
    Autocorrelation spectrum, Numpy array of shape [T-1].
  """

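  # Equation 15 of the L2HMC paper, with mu and Sigma the target mean and
  # covariance:
  #   rho_t = sum_{tau=0}^{T-t-1} E_batch[(x_tau - mu) . (x_{tau+t} - mu)]
  #           / (trace(Sigma) * (T - t))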
  # Use NumPy here since eager execution is a bit slow due to the loop
  time_steps = samples_history.shape[0]
  trace = np.trace(target_covar)

  rhos = []
  for t in range(time_steps - 1):
    rho_t = 0.
    for tau in range(time_steps - t):
      v_tau = samples_history[tau, :, :] - target_mean
      v_tau_plus_t = samples_history[tau + t, :, :] - target_mean
      # Take the dot product over observation dims and the mean over batch dims
      rho_t += np.mean(np.sum(v_tau * v_tau_plus_t, axis=1))

    rho_t /= trace * (time_steps - t)
    rhos.append(rho_t)

  return np.array(rhos)


def compute_ess(ac_spectrum):
  """Compute the effective sample size based on the autocorrelation spectrum.

  This follows equation 16 from the L2HMC paper.

  Args:
    ac_spectrum: Autocorrelation spectrum as returned by `compute_ac_spectrum`.
  Returns:
    The effective sample size.
  """
  # Find the first lag whose autocorrelation drops below 0.05 (as an index
  # into ac_spectrum, which starts at lag 0); keep the whole spectrum if no
  # lag falls below the threshold.
  below = np.nonzero(ac_spectrum[1:] < .05)[0]
  cutoff = below[0] + 1 if below.size else len(ac_spectrum)
  ess = 1. / (1. + 2. * np.sum(ac_spectrum[1:cutoff]))
  return ess
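
# Illustrative sanity check (not part of the example run): white noise
# decorrelates immediately, so its ESS per step should be close to 1:
#
#   history = np.random.randn(500, 8, 2)  # [T, B, D]
#   spectrum = compute_ac_spectrum(history, np.zeros(2), np.eye(2))
#   print(compute_ess(spectrum))  # ~1.0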


if __name__ == "__main__":
  flags.DEFINE_string(
      "train_dir",
      default=None,
      help="[Optional] Directory to store the training information")
  flags.DEFINE_boolean(
      "restore",
      default=False,
      help="[Optional] Restore the latest checkpoint from `train_dir` if True")
  flags.DEFINE_boolean(
      "use_defun",
      default=False,
      help="[Optional] Use `tfe.function` to boost performance")
  flags.DEFINE_string(
      "energy_fn",
      default="scg",
      help="[Optional] The energy function used for experimentation. "
      "Other options include `rw`.")
  FLAGS = flags.FLAGS
  tf.app.run(main)