1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""A deep MNIST classifier using convolutional layers. 16 17Sample usage: 18 python mnist.py --help 19""" 20 21from __future__ import absolute_import 22from __future__ import division 23from __future__ import print_function 24 25import argparse 26import os 27import sys 28import time 29 30import tensorflow as tf 31 32import tensorflow.contrib.eager as tfe 33from tensorflow.examples.tutorials.mnist import input_data 34 35FLAGS = None 36 37 38class MNISTModel(tf.keras.Model): 39 """MNIST Network. 40 41 Network structure is equivalent to: 42 https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/examples/tutorials/mnist/mnist_deep.py 43 and 44 https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py 45 46 But written using the tf.layers API. 47 """ 48 49 def __init__(self, data_format): 50 """Creates a model for classifying a hand-written digit. 51 52 Args: 53 data_format: Either 'channels_first' or 'channels_last'. 54 'channels_first' is typically faster on GPUs while 'channels_last' is 55 typically faster on CPUs. See 56 https://www.tensorflow.org/performance/performance_guide#data_formats 57 """ 58 super(MNISTModel, self).__init__(name='') 59 if data_format == 'channels_first': 60 self._input_shape = [-1, 1, 28, 28] 61 else: 62 assert data_format == 'channels_last' 63 self._input_shape = [-1, 28, 28, 1] 64 self.conv1 = tf.layers.Conv2D( 65 32, 5, data_format=data_format, activation=tf.nn.relu) 66 self.conv2 = tf.layers.Conv2D( 67 64, 5, data_format=data_format, activation=tf.nn.relu) 68 self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu) 69 self.fc2 = tf.layers.Dense(10) 70 self.dropout = tf.layers.Dropout(0.5) 71 self.max_pool2d = tf.layers.MaxPooling2D( 72 (2, 2), (2, 2), padding='SAME', data_format=data_format) 73 74 def call(self, inputs, training=False): 75 """Computes labels from inputs. 76 77 Users should invoke __call__ to run the network, which delegates to this 78 method (and not call this method directly). 79 80 Args: 81 inputs: A batch of images as a Tensor with shape [batch_size, 784]. 82 training: True if invoked in the context of training (causing dropout to 83 be applied). False otherwise. 84 85 Returns: 86 A Tensor with shape [batch_size, 10] containing the predicted logits 87 for each image in the batch, for each of the 10 classes. 88 """ 89 90 x = tf.reshape(inputs, self._input_shape) 91 x = self.conv1(x) 92 x = self.max_pool2d(x) 93 x = self.conv2(x) 94 x = self.max_pool2d(x) 95 x = tf.layers.flatten(x) 96 x = self.fc1(x) 97 x = self.dropout(x, training=training) 98 x = self.fc2(x) 99 return x 100 101 102def loss(predictions, labels): 103 return tf.reduce_mean( 104 tf.nn.softmax_cross_entropy_with_logits( 105 logits=predictions, labels=labels)) 106 107 108def compute_accuracy(predictions, labels): 109 return tf.reduce_sum( 110 tf.cast( 111 tf.equal( 112 tf.argmax(predictions, axis=1, 113 output_type=tf.int64), 114 tf.argmax(labels, axis=1, 115 output_type=tf.int64)), 116 dtype=tf.float32)) / float(predictions.shape[0].value) 117 118 119def train_one_epoch(model, optimizer, dataset, log_interval=None): 120 """Trains model on `dataset` using `optimizer`.""" 121 122 tf.train.get_or_create_global_step() 123 124 for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)): 125 with tf.contrib.summary.record_summaries_every_n_global_steps(10): 126 with tfe.GradientTape() as tape: 127 prediction = model(images, training=True) 128 loss_value = loss(prediction, labels) 129 tf.contrib.summary.scalar('loss', loss_value) 130 tf.contrib.summary.scalar('accuracy', 131 compute_accuracy(prediction, labels)) 132 grads = tape.gradient(loss_value, model.variables) 133 optimizer.apply_gradients(zip(grads, model.variables)) 134 if log_interval and batch % log_interval == 0: 135 print('Batch #%d\tLoss: %.6f' % (batch, loss_value)) 136 137 138def test(model, dataset): 139 """Perform an evaluation of `model` on the examples from `dataset`.""" 140 avg_loss = tfe.metrics.Mean('loss') 141 accuracy = tfe.metrics.Accuracy('accuracy') 142 143 for (images, labels) in tfe.Iterator(dataset): 144 predictions = model(images, training=False) 145 avg_loss(loss(predictions, labels)) 146 accuracy(tf.argmax(predictions, axis=1, output_type=tf.int64), 147 tf.argmax(labels, axis=1, output_type=tf.int64)) 148 print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' % 149 (avg_loss.result(), 100 * accuracy.result())) 150 with tf.contrib.summary.always_record_summaries(): 151 tf.contrib.summary.scalar('loss', avg_loss.result()) 152 tf.contrib.summary.scalar('accuracy', accuracy.result()) 153 154 155def load_data(data_dir): 156 """Returns training and test tf.data.Dataset objects.""" 157 data = input_data.read_data_sets(data_dir, one_hot=True) 158 train_ds = tf.data.Dataset.from_tensor_slices((data.train.images, 159 data.train.labels)) 160 test_ds = tf.data.Dataset.from_tensors((data.test.images, data.test.labels)) 161 return (train_ds, test_ds) 162 163 164def main(_): 165 tfe.enable_eager_execution() 166 167 (device, data_format) = ('/gpu:0', 'channels_first') 168 if FLAGS.no_gpu or tfe.num_gpus() <= 0: 169 (device, data_format) = ('/cpu:0', 'channels_last') 170 print('Using device %s, and data format %s.' % (device, data_format)) 171 172 # Load the datasets 173 (train_ds, test_ds) = load_data(FLAGS.data_dir) 174 train_ds = train_ds.shuffle(60000).batch(FLAGS.batch_size) 175 176 # Create the model and optimizer 177 model = MNISTModel(data_format) 178 optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum) 179 180 if FLAGS.output_dir: 181 train_dir = os.path.join(FLAGS.output_dir, 'train') 182 test_dir = os.path.join(FLAGS.output_dir, 'eval') 183 tf.gfile.MakeDirs(FLAGS.output_dir) 184 else: 185 train_dir = None 186 test_dir = None 187 summary_writer = tf.contrib.summary.create_file_writer( 188 train_dir, flush_millis=10000) 189 test_summary_writer = tf.contrib.summary.create_file_writer( 190 test_dir, flush_millis=10000, name='test') 191 checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') 192 193 with tf.device(device): 194 for epoch in range(1, 11): 195 with tfe.restore_variables_on_create( 196 tf.train.latest_checkpoint(FLAGS.checkpoint_dir)): 197 global_step = tf.train.get_or_create_global_step() 198 start = time.time() 199 with summary_writer.as_default(): 200 train_one_epoch(model, optimizer, train_ds, FLAGS.log_interval) 201 end = time.time() 202 print('\nTrain time for epoch #%d (global step %d): %f' % ( 203 epoch, global_step.numpy(), end - start)) 204 with test_summary_writer.as_default(): 205 test(model, test_ds) 206 all_variables = ( 207 model.variables 208 + optimizer.variables() 209 + [global_step]) 210 tfe.Saver(all_variables).save( 211 checkpoint_prefix, global_step=global_step) 212 213 214if __name__ == '__main__': 215 parser = argparse.ArgumentParser() 216 parser.add_argument( 217 '--data-dir', 218 type=str, 219 default='/tmp/tensorflow/mnist/input_data', 220 help='Directory for storing input data') 221 parser.add_argument( 222 '--batch-size', 223 type=int, 224 default=64, 225 metavar='N', 226 help='input batch size for training (default: 64)') 227 parser.add_argument( 228 '--log-interval', 229 type=int, 230 default=10, 231 metavar='N', 232 help='how many batches to wait before logging training status') 233 parser.add_argument( 234 '--output_dir', 235 type=str, 236 default=None, 237 metavar='N', 238 help='Directory to write TensorBoard summaries') 239 parser.add_argument( 240 '--checkpoint_dir', 241 type=str, 242 default='/tmp/tensorflow/mnist/checkpoints/', 243 metavar='N', 244 help='Directory to save checkpoints in (once per epoch)') 245 parser.add_argument( 246 '--lr', 247 type=float, 248 default=0.01, 249 metavar='LR', 250 help='learning rate (default: 0.01)') 251 parser.add_argument( 252 '--momentum', 253 type=float, 254 default=0.5, 255 metavar='M', 256 help='SGD momentum (default: 0.5)') 257 parser.add_argument( 258 '--no-gpu', 259 action='store_true', 260 default=False, 261 help='disables GPU usage even if a GPU is available') 262 263 FLAGS, unparsed = parser.parse_known_args() 264 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 265