# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""L2HMC on simple Gaussian mixture model with TensorFlow eager."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

from absl import flags
import numpy as np
import tensorflow as tf
from tensorflow.contrib.eager.python.examples.l2hmc import l2hmc

try:
  import matplotlib.pyplot as plt  # pylint: disable=g-import-not-at-top
  HAS_MATPLOTLIB = True
except ImportError:
  HAS_MATPLOTLIB = False

tfe = tf.contrib.eager


def main(_):
  tf.enable_eager_execution()
  global_step = tf.train.get_or_create_global_step()
  global_step.assign(1)

  energy_fn, mean, covar = {
      "scg": l2hmc.get_scg_energy_fn(),
      "rw": l2hmc.get_rw_energy_fn()
  }[FLAGS.energy_fn]

  x_dim = 2
  train_iters = 5000
  eval_iters = 2000
  eps = 0.1
  n_steps = 10  # Chain length
  n_samples = 200
  record_loss_every = 100

  dynamics = l2hmc.Dynamics(
      x_dim=x_dim, minus_loglikelihood_fn=energy_fn, n_steps=n_steps, eps=eps)
  learning_rate = tf.train.exponential_decay(
      1e-3, global_step, 1000, 0.96, staircase=True)
  optimizer = tf.train.AdamOptimizer(learning_rate)
  checkpointer = tf.train.Checkpoint(
      optimizer=optimizer, dynamics=dynamics, global_step=global_step)

  if FLAGS.train_dir:
    summary_writer = tf.contrib.summary.create_file_writer(FLAGS.train_dir)
    if FLAGS.restore:
      latest_path = tf.train.latest_checkpoint(FLAGS.train_dir)
      checkpointer.restore(latest_path)
      print("Restored latest checkpoint at path: \"{}\"".format(latest_path))
      sys.stdout.flush()

  if not FLAGS.restore:
    # Training
    if FLAGS.use_defun:
      # Use `tfe.defun` to boost performance when there are lots of small ops
      loss_fn = tfe.function(l2hmc.compute_loss)
    else:
      loss_fn = l2hmc.compute_loss

    samples = tf.random_normal(shape=[n_samples, x_dim])
    for i in range(1, train_iters + 1):
      loss, samples, accept_prob = train_one_iter(
          dynamics,
          samples,
          optimizer,
          loss_fn=loss_fn,
          global_step=global_step)

      if i % record_loss_every == 0:
        print("Iteration {}, loss {:.4f}, x_accept_prob {:.4f}".format(
            i, loss.numpy(),
            accept_prob.numpy().mean()))
        if FLAGS.train_dir:
          with summary_writer.as_default():
            with tf.contrib.summary.always_record_summaries():
              tf.contrib.summary.scalar(
                  "Training loss", loss, step=global_step)
    print("Training complete.")
    sys.stdout.flush()

    if FLAGS.train_dir:
      saved_path = checkpointer.save(
          file_prefix=os.path.join(FLAGS.train_dir, "ckpt"))
      print("Saved checkpoint at path: \"{}\"".format(saved_path))
      sys.stdout.flush()

  # Evaluation
  if FLAGS.use_defun:
    # Use `tfe.defun` to boost performance when there are lots of small ops
    apply_transition = tfe.function(dynamics.apply_transition)
  else:
    apply_transition = dynamics.apply_transition

  samples = tf.random_normal(shape=[n_samples, x_dim])
  samples_history = []
  for i in range(eval_iters):
    samples_history.append(samples.numpy())
    _, _, _, samples = apply_transition(samples)
  samples_history = np.array(samples_history)
  print("Sampling complete.")
  sys.stdout.flush()

  # Mean and covariance of target distribution
  mean = mean.numpy()
  covar = covar.numpy()
  ac_spectrum = compute_ac_spectrum(samples_history, mean, covar)
  print("First 25 entries of the auto-correlation spectrum: {}".format(
      ac_spectrum[:25]))
  ess = compute_ess(ac_spectrum)
  print("Effective sample size per Metropolis-Hastings step: {}".format(ess))
  sys.stdout.flush()

  if FLAGS.train_dir:
    # Plot autocorrelation spectrum in TensorBoard
    plot_step = tfe.Variable(1, trainable=False, dtype=tf.int64)

    for ac in ac_spectrum:
      with summary_writer.as_default():
        with tf.contrib.summary.always_record_summaries():
          tf.contrib.summary.scalar("Autocorrelation", ac, step=plot_step)
      plot_step.assign(plot_step + n_steps)

    if HAS_MATPLOTLIB:
      # Choose a single chain and plot the trajectory
      single_chain = samples_history[:, 0, :]
      xs = single_chain[:100, 0]
      ys = single_chain[:100, 1]
      plt.figure()
      plt.plot(xs, ys, color="orange", marker="o", alpha=0.6)  # Trained chain
      plt.savefig(os.path.join(FLAGS.train_dir, "single_chain.png"))


def train_one_iter(dynamics,
                   x,
                   optimizer,
                   loss_fn=l2hmc.compute_loss,
                   global_step=None):
  """Train the sampler for one iteration."""
  loss, grads, out, accept_prob = l2hmc.loss_and_grads(
      dynamics, x, loss_fn=loss_fn)
  optimizer.apply_gradients(
      zip(grads, dynamics.trainable_variables), global_step=global_step)

  return loss, out, accept_prob


def compute_ac_spectrum(samples_history, target_mean, target_covar):
  """Compute autocorrelation spectrum.

  Follows equation 15 from the L2HMC paper.

  Args:
    samples_history: Numpy array of shape [T, B, D], where T is the total
      number of time steps, B is the batch size, and D is the dimensionality
      of sample space.
    target_mean: 1D Numpy array of the mean of the target (true) distribution.
    target_covar: 2D Numpy array representing the symmetric covariance matrix
      of the target distribution.

  Returns:
    Autocorrelation spectrum, Numpy array of shape [T-1].
  """

  # Using numpy here since eager is a bit slow due to the loop
  time_steps = samples_history.shape[0]
  trace = np.trace(target_covar)

  rhos = []
  for t in range(time_steps - 1):
    rho_t = 0.
    for tau in range(time_steps - t):
      v_tau = samples_history[tau, :, :] - target_mean
      v_tau_plus_t = samples_history[tau + t, :, :] - target_mean
      # Take dot product over observation dims and take mean over batch dims
      rho_t += np.mean(np.sum(v_tau * v_tau_plus_t, axis=1))

    rho_t /= trace * (time_steps - t)
    rhos.append(rho_t)

  return np.array(rhos)


def compute_ess(ac_spectrum):
  """Compute the effective sample size based on autocorrelation spectrum.

  This follows equation 16 from the L2HMC paper.

  Args:
    ac_spectrum: Autocorrelation spectrum, as returned by
      `compute_ac_spectrum`.

  Returns:
    The effective sample size.
  """
  # Cutoff from the first value less than 0.05
  cutoff = np.argmax(ac_spectrum[1:] < .05)
  if cutoff == 0:
    cutoff = len(ac_spectrum)
  ess = 1. / (1. + 2. * np.sum(ac_spectrum[1:cutoff]))
  return ess


if __name__ == "__main__":
  flags.DEFINE_string(
      "train_dir",
      default=None,
      help="[Optional] Directory to store the training information")
  flags.DEFINE_boolean(
      "restore",
      default=False,
      help="[Optional] Restore the latest checkpoint from `train_dir` if True")
  flags.DEFINE_boolean(
      "use_defun",
      default=False,
      help="[Optional] Use `tfe.defun` to boost performance")
  flags.DEFINE_string(
      "energy_fn",
      default="scg",
      help="[Optional] The energy function used for experimentation. "
      "Other options include `rw`")
  FLAGS = flags.FLAGS
  tf.app.run(main)
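
# A minimal example invocation, shown here as a sketch. It assumes this script
# is saved as `main.py` and that the `l2hmc` module is importable on
# PYTHONPATH; both are assumptions, not guaranteed by the script itself.
#
#   python main.py --train_dir=/tmp/l2hmc --use_defun=True --energy_fn=rw
#
# `--energy_fn` accepts "scg" (the default) or "rw", matching the keys of the
# energy-function dictionary at the top of `main()`.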