# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14"""MNIST model float training script with TensorFlow graph execution."""
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import os
21from absl import flags
22
23import tensorflow as tf
24import tensorflow_datasets as tfds
25from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops
26from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs  # pylint: disable=unused-import
27from tensorflow.python.framework import load_library
28
flags.DEFINE_integer('train_steps', 200, 'Number of steps in training.')

# Load the shared library containing the custom op kernels. The library name
# is derived from the generated Python wrapper by dropping the 'gen_' prefix
# and swapping the extension, e.g. 'gen_mnist_ops.py' -> 'mnist_ops.so'.
_lib_dir = os.path.dirname(gen_mnist_ops.__file__)
_lib_name = os.path.basename(gen_mnist_ops.__file__)[4:].replace('.py', '.so')
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))

# MNIST dataset parameters.
num_classes = 10  # Total classes (digits 0-9).
num_features = 784  # Data features (image shape: 28*28).
num_channels = 1

# Training parameters.
learning_rate = 0.01
display_step = 10
batch_size = 128

# Network parameters.
n_hidden_1 = 32  # Number of filters in the 1st conv layer.
n_hidden_2 = 64  # Number of filters in the 2nd conv layer.
n_hidden_3 = 1024  # Number of neurons in the fully connected layer.
flatten_size = num_features // 16 * n_hidden_2
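# Worked arithmetic for the line above: each 2x2/stride-2 max pool halves both
# spatial dimensions, so 28x28 -> 14x14 -> 7x7 after two pools, leaving
# num_features // 16 = 49 positions. Flattened size: 7 * 7 * 64 = 3136.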

seed = 66478  # Fixed seed for reproducible weight initialization.


class FloatModel(tf.Module):
  """Float inference for mnist model."""

  def __init__(self):
    self.weights = {
        'f1':
            tf.Variable(
                tf.random.truncated_normal([5, 5, num_channels, n_hidden_1],
                                           stddev=0.1,
                                           seed=seed)),
        'f2':
            tf.Variable(
                tf.random.truncated_normal([5, 5, n_hidden_1, n_hidden_2],
                                           stddev=0.1,
                                           seed=seed)),
        # NOTE: the fully connected weights 'f3' and 'f4' are stored
        # transposed, as [out_features, in_features], which is the layout
        # the custom new_fully_connected op consumes.
        'f3':
            tf.Variable(
                tf.random.truncated_normal([n_hidden_3, flatten_size],
                                           stddev=0.1,
                                           seed=seed)),
        'f4':
            tf.Variable(
                tf.random.truncated_normal([num_classes, n_hidden_3],
                                           stddev=0.1,
                                           seed=seed)),
    }

    self.biases = {
        'b1': tf.Variable(tf.zeros([n_hidden_1])),
        'b2': tf.Variable(tf.zeros([n_hidden_2])),
        'b3': tf.Variable(tf.zeros([n_hidden_3])),
        'b4': tf.Variable(tf.zeros([num_classes])),
    }

  @tf.function
  def __call__(self, data):
    """The Model definition."""
    x = tf.reshape(data, [-1, 28, 28, 1])

    # 2D convolution with 'SAME' padding (i.e. the output feature map has
    # the same spatial size as the input).

    # NOTE: The data/x/input is always specified in floating point precision.
    # Output shape: [-1, 28, 28, 32].
    conv1 = gen_mnist_ops.new_conv2d(x, self.weights['f1'], self.biases['b1'],
                                     1, 1, 1, 1, 'SAME', 'RELU')

    # Max pooling. The kernel size spec {ksize} follows the layout of the
    # data: here a 2x2 pooling window with a stride of 2.
    # Output shape: [-1, 14, 14, 32].
    max_pool1 = gen_mnist_ops.new_max_pool(conv1, 2, 2, 2, 2, 'SAME')

    # Output shape: [-1, 14, 14, 64].
    conv2 = gen_mnist_ops.new_conv2d(max_pool1, self.weights['f2'],
                                     self.biases['b2'], 1, 1, 1, 1, 'SAME',
                                     'RELU')

    # Output shape: [-1, 7, 7, 64].
    max_pool2 = gen_mnist_ops.new_max_pool(conv2, 2, 2, 2, 2, 'SAME')

    # Reshape the feature map cuboid into a 2D matrix to feed it to the
    # fully connected layers.
    # Output shape: [-1, 7*7*64].
    reshape = tf.reshape(max_pool2, [-1, flatten_size])

    # Output shape: [-1, 1024].
    fc1 = gen_mnist_ops.new_fully_connected(reshape, self.weights['f3'],
                                            self.biases['b3'], 'RELU')
    # Output shape: [-1, 10]. No activation here: these are the raw logits.
    return gen_mnist_ops.new_fully_connected(fc1, self.weights['f4'],
                                             self.biases['b4'])

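
# An illustrative sketch (not used by the training loop below): what the
# composite ops above are assumed to compute, spelled out with standard TF
# ops. The authoritative decompositions are the ones registered in
# ops_defs.py; this helper only documents the intended math, including the
# [out, in] weight layout of new_fully_connected.
def reference_float_model(model, data):
  """Assumed plain-TF equivalent of FloatModel.__call__ (illustration only)."""
  x = tf.reshape(data, [-1, 28, 28, 1])
  # new_conv2d(..., 'SAME', 'RELU') ~ conv2d + bias_add + relu.
  conv1 = tf.nn.relu(
      tf.nn.bias_add(
          tf.nn.conv2d(x, model.weights['f1'], [1, 1, 1, 1], 'SAME'),
          model.biases['b1']))
  max_pool1 = tf.nn.max_pool2d(conv1, 2, 2, 'SAME')
  conv2 = tf.nn.relu(
      tf.nn.bias_add(
          tf.nn.conv2d(max_pool1, model.weights['f2'], [1, 1, 1, 1], 'SAME'),
          model.biases['b2']))
  max_pool2 = tf.nn.max_pool2d(conv2, 2, 2, 'SAME')
  reshape = tf.reshape(max_pool2, [-1, flatten_size])
  # new_fully_connected stores weights as [out, in], hence transpose_b=True.
  fc1 = tf.nn.relu(
      tf.matmul(reshape, model.weights['f3'], transpose_b=True) +
      model.biases['b3'])
  return (tf.matmul(fc1, model.weights['f4'], transpose_b=True) +
          model.biases['b4'])
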

def main(strategy):
  """Trains an MNIST model using the given tf.distribute.Strategy."""
  # TODO(fengliuai): put this in some automatically generated code.
  os.environ[
      'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/mnist'

  ds_train = tfds.load('mnist', split='train', shuffle_files=True)
  ds_train = ds_train.shuffle(1024).batch(batch_size).prefetch(64)
  ds_train = strategy.experimental_distribute_dataset(ds_train)

  with strategy.scope():
    # Create the model and optimizer under the strategy scope so their
    # variables are created on, and mirrored across, the replicas.
    model = FloatModel()
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)

  def train_step(features):
    inputs = tf.image.convert_image_dtype(
        features['image'], dtype=tf.float32, saturate=False)
    labels = tf.one_hot(features['label'], num_classes)

    with tf.GradientTape() as tape:
      logits = model(inputs)
      loss_value = tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits(labels, logits))

    grads = tape.gradient(loss_value, model.trainable_variables)
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return accuracy, loss_value

  @tf.function
  def distributed_train_step(dist_inputs):
    per_replica_accuracy, per_replica_losses = strategy.run(
        train_step, args=(dist_inputs,))
    # axis=None averages across replicas only; each per-replica value is
    # already a scalar mean over its local batch.
    accuracy = strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_accuracy, axis=None)
    loss_value = strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
    return accuracy, loss_value

  iterator = iter(ds_train)
  accuracy = 0.0
  for step in range(flags.FLAGS.train_steps):
    accuracy, loss_value = distributed_train_step(next(iterator))
    if step % display_step == 0:
      tf.print('Step %d:' % step)
      tf.print('    Loss = %f' % loss_value)
      tf.print('    Batch accuracy = %f' % accuracy)

  return accuracy

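
if __name__ == '__main__':
  # Hedged usage sketch: in the TensorFlow tree this script is driven by a
  # separate harness that supplies the tf.distribute.Strategy, so this block
  # is illustrative only. It runs the training loop on the default
  # (single-device) strategy returned by tf.distribute.get_strategy().
  from absl import app  # pylint: disable=g-import-not-at-top

  def _run(unused_argv):
    main(tf.distribute.get_strategy())

  app.run(_run)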