1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A deep MNIST classifier using convolutional layers.
16
17Sample usage:
18  python mnist.py --help
19"""
20
21from __future__ import absolute_import
22from __future__ import division
23from __future__ import print_function
24
25import argparse
26import os
27import sys
28import time
29
30import tensorflow as tf
31
32import tensorflow.contrib.eager as tfe
33from tensorflow.examples.tutorials.mnist import input_data
34
# Parsed command-line flags. Starts as None and is assigned from
# `parser.parse_known_args()` in the `__main__` block before `main` is run.
FLAGS = None
36
37
class MNISTModel(tf.keras.Model):
  """MNIST Network.

  Network structure is equivalent to:
  https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/examples/tutorials/mnist/mnist_deep.py
  and
  https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py

  But written using the tf.layers API.
  """

  def __init__(self, data_format):
    """Creates a model for classifying a hand-written digit.

    Args:
      data_format: Either 'channels_first' or 'channels_last'.
        'channels_first' is typically faster on GPUs while 'channels_last' is
        typically faster on CPUs. See
        https://www.tensorflow.org/performance/performance_guide#data_formats

    Raises:
      ValueError: If `data_format` is neither 'channels_first' nor
        'channels_last'.
    """
    super(MNISTModel, self).__init__(name='')
    # Validate explicitly rather than with `assert`, which is removed when
    # Python runs with -O and would let a bad value fall through silently.
    if data_format == 'channels_first':
      self._input_shape = [-1, 1, 28, 28]
    elif data_format == 'channels_last':
      self._input_shape = [-1, 28, 28, 1]
    else:
      raise ValueError(
          "data_format must be 'channels_first' or 'channels_last', got %r" %
          (data_format,))
    self.conv1 = tf.layers.Conv2D(
        32, 5, data_format=data_format, activation=tf.nn.relu)
    self.conv2 = tf.layers.Conv2D(
        64, 5, data_format=data_format, activation=tf.nn.relu)
    self.fc1 = tf.layers.Dense(1024, activation=tf.nn.relu)
    self.fc2 = tf.layers.Dense(10)
    self.dropout = tf.layers.Dropout(0.5)
    # A single pooling layer object is applied after each convolution.
    self.max_pool2d = tf.layers.MaxPooling2D(
        (2, 2), (2, 2), padding='SAME', data_format=data_format)

  def call(self, inputs, training=False):
    """Computes labels from inputs.

    Users should invoke __call__ to run the network, which delegates to this
    method (and not call this method directly).

    Args:
      inputs: A batch of images as a Tensor with shape [batch_size, 784].
      training: True if invoked in the context of training (causing dropout to
        be applied).  False otherwise.

    Returns:
      A Tensor with shape [batch_size, 10] containing the predicted logits
      for each image in the batch, for each of the 10 classes.
    """
    # Reshape the flat inputs into the image layout chosen in __init__.
    x = tf.reshape(inputs, self._input_shape)
    x = self.conv1(x)
    x = self.max_pool2d(x)
    x = self.conv2(x)
    x = self.max_pool2d(x)
    x = tf.layers.flatten(x)
    x = self.fc1(x)
    # Dropout only takes effect when `training` is True.
    x = self.dropout(x, training=training)
    # The final dense layer has no activation, so the outputs are raw logits.
    x = self.fc2(x)
    return x
100
101
def loss(predictions, labels):
  """Returns the mean softmax cross-entropy of `predictions` against `labels`.

  Args:
    predictions: Logits, passed as `logits=` to the cross-entropy op.
    labels: Targets, passed as `labels=` to the cross-entropy op.

  Returns:
    The cross-entropy values reduced to their mean with `tf.reduce_mean`.
  """
  cross_entropy = tf.nn.softmax_cross_entropy_with_logits(
      logits=predictions, labels=labels)
  return tf.reduce_mean(cross_entropy)
106
107
def compute_accuracy(predictions, labels):
  """Returns the fraction of rows where the arg-max of both inputs agrees.

  Args:
    predictions: Per-class scores; the arg-max along axis 1 is the predicted
      class.
    labels: Per-class targets; the arg-max along axis 1 is the true class.

  Returns:
    The number of matching rows divided by the static size of the first
    dimension of `predictions`.
  """
  predicted_classes = tf.argmax(predictions, axis=1, output_type=tf.int64)
  true_classes = tf.argmax(labels, axis=1, output_type=tf.int64)
  matches = tf.cast(tf.equal(predicted_classes, true_classes),
                    dtype=tf.float32)
  num_correct = tf.reduce_sum(matches)
  return num_correct / float(predictions.shape[0].value)
117
118
def train_one_epoch(model, optimizer, dataset, log_interval=None):
  """Trains model on `dataset` using `optimizer`.

  Makes one pass over `dataset`. For every batch it computes the loss under a
  gradient tape, applies the gradients to `model.variables`, and advances the
  global step by one.

  Args:
    model: Callable invoked as `model(images, training=True)` that exposes a
      `variables` attribute.
    optimizer: Optimizer providing `apply_gradients`.
    dataset: Dataset yielding `(images, labels)` batches.
    log_interval: If truthy, the loss is printed for every batch whose index
      is a multiple of this value.
  """

  # Keep a handle on the step counter: apply_gradients only increments it when
  # it is passed explicitly. Without that the step stays at its initial value,
  # so every summary and checkpoint would be tagged with the same step.
  global_step = tf.train.get_or_create_global_step()

  for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)):
    with tf.contrib.summary.record_summaries_every_n_global_steps(10):
      with tfe.GradientTape() as tape:
        prediction = model(images, training=True)
        loss_value = loss(prediction, labels)
        tf.contrib.summary.scalar('loss', loss_value)
        tf.contrib.summary.scalar('accuracy',
                                  compute_accuracy(prediction, labels))
      grads = tape.gradient(loss_value, model.variables)
      optimizer.apply_gradients(
          zip(grads, model.variables), global_step=global_step)
      if log_interval and batch % log_interval == 0:
        print('Batch #%d\tLoss: %.6f' % (batch, loss_value))
136
137
def test(model, dataset):
  """Evaluates `model` on every batch of `dataset` and reports the results.

  Accumulates the mean loss and the arg-max accuracy over the dataset, prints
  both, and writes them as 'loss' and 'accuracy' scalar summaries.

  Args:
    model: Callable invoked as `model(images, training=False)`.
    dataset: Dataset yielding `(images, labels)` batches.
  """
  loss_metric = tfe.metrics.Mean('loss')
  accuracy_metric = tfe.metrics.Accuracy('accuracy')

  for (images, labels) in tfe.Iterator(dataset):
    logits = model(images, training=False)
    loss_metric(loss(logits, labels))
    predicted_classes = tf.argmax(logits, axis=1, output_type=tf.int64)
    true_classes = tf.argmax(labels, axis=1, output_type=tf.int64)
    accuracy_metric(predicted_classes, true_classes)

  report = 'Test set: Average loss: %.4f, Accuracy: %4f%%\n' % (
      loss_metric.result(), 100 * accuracy_metric.result())
  print(report)

  # Evaluation summaries are written unconditionally.
  with tf.contrib.summary.always_record_summaries():
    tf.contrib.summary.scalar('loss', loss_metric.result())
    tf.contrib.summary.scalar('accuracy', accuracy_metric.result())
153
154
def load_data(data_dir):
  """Builds the training and test `tf.data.Dataset` objects for MNIST.

  Args:
    data_dir: Directory passed to `input_data.read_data_sets`.

  Returns:
    A `(train_ds, test_ds)` tuple. The training dataset is sliced along the
    first dimension of its arrays; the test dataset holds the whole test split
    as one single element. Labels are read with `one_hot=True`.
  """
  mnist = input_data.read_data_sets(data_dir, one_hot=True)

  train_arrays = (mnist.train.images, mnist.train.labels)
  train_ds = tf.data.Dataset.from_tensor_slices(train_arrays)

  # `from_tensors` (not `from_tensor_slices`): the test split stays in one piece.
  test_arrays = (mnist.test.images, mnist.test.labels)
  test_ds = tf.data.Dataset.from_tensors(test_arrays)

  return (train_ds, test_ds)
162
163
def main(_):
  """Trains the MNIST model for ten epochs, evaluating and saving after each.

  Configuration is read from the module-level `FLAGS`. The positional argument
  is ignored.
  """
  tfe.enable_eager_execution()

  # Use the GPU unless it was disabled by flag or none is available.
  if not FLAGS.no_gpu and tfe.num_gpus() > 0:
    device, data_format = '/gpu:0', 'channels_first'
  else:
    device, data_format = '/cpu:0', 'channels_last'
  print('Using device %s, and data format %s.' % (device, data_format))

  # Input pipeline: only the training set is shuffled and batched here.
  train_ds, test_ds = load_data(FLAGS.data_dir)
  train_ds = train_ds.shuffle(60000).batch(FLAGS.batch_size)

  # Model and optimizer.
  model = MNISTModel(data_format)
  optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum)

  # Summary directories stay None when no output directory was requested.
  train_dir = test_dir = None
  if FLAGS.output_dir:
    train_dir = os.path.join(FLAGS.output_dir, 'train')
    test_dir = os.path.join(FLAGS.output_dir, 'eval')
    tf.gfile.MakeDirs(FLAGS.output_dir)
  train_writer = tf.contrib.summary.create_file_writer(
      train_dir, flush_millis=10000)
  eval_writer = tf.contrib.summary.create_file_writer(
      test_dir, flush_millis=10000, name='test')
  checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')

  with tf.device(device):
    for epoch in range(1, 11):
      latest_checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
      with tfe.restore_variables_on_create(latest_checkpoint):
        global_step = tf.train.get_or_create_global_step()
        epoch_start = time.time()
        with train_writer.as_default():
          train_one_epoch(model, optimizer, train_ds, FLAGS.log_interval)
        elapsed = time.time() - epoch_start
        print('\nTrain time for epoch #%d (global step %d): %f' % (
            epoch, global_step.numpy(), elapsed))

      with eval_writer.as_default():
        test(model, test_ds)

      # Save model weights, optimizer state and the step counter together.
      variables_to_save = (
          model.variables + optimizer.variables() + [global_step])
      saver = tfe.Saver(variables_to_save)
      saver.save(checkpoint_prefix, global_step=global_step)
212
213
if __name__ == '__main__':
  # Define the command-line interface. Flags argparse does not recognise are
  # forwarded untouched to tf.app.run.
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '--data-dir',
      type=str,
      default='/tmp/tensorflow/mnist/input_data',
      help='Directory for storing input data')
  parser.add_argument(
      '--batch-size',
      type=int,
      default=64,
      metavar='N',
      help='input batch size for training (default: 64)')
  parser.add_argument(
      '--log-interval',
      type=int,
      default=10,
      metavar='N',
      help='how many batches to wait before logging training status')
  # The two directory flags take a path, so their help placeholder is DIR
  # rather than the N used by the integer flags.
  parser.add_argument(
      '--output_dir',
      type=str,
      default=None,
      metavar='DIR',
      help='Directory to write TensorBoard summaries')
  parser.add_argument(
      '--checkpoint_dir',
      type=str,
      default='/tmp/tensorflow/mnist/checkpoints/',
      metavar='DIR',
      help='Directory to save checkpoints in (once per epoch)')
  parser.add_argument(
      '--lr',
      type=float,
      default=0.01,
      metavar='LR',
      help='learning rate (default: 0.01)')
  parser.add_argument(
      '--momentum',
      type=float,
      default=0.5,
      metavar='M',
      help='SGD momentum (default: 0.5)')
  parser.add_argument(
      '--no-gpu',
      action='store_true',
      default=False,
      help='disables GPU usage even if a GPU is available')

  # Populate the module-level FLAGS that main() reads.
  FLAGS, unparsed = parser.parse_known_args()
  tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
265