# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MNIST model float training script with TensorFlow graph execution."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from absl import flags

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops
from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs  # pylint: disable=unused-import
from tensorflow.python.framework import load_library

flags.DEFINE_integer('train_steps', 200, 'Number of steps in training.')

_lib_dir = os.path.dirname(gen_mnist_ops.__file__)
_lib_name = os.path.basename(gen_mnist_ops.__file__)[4:].replace('.py', '.so')
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))

# MNIST dataset parameters.
num_classes = 10  # total classes (0-9 digits).
num_features = 784  # data features (img shape: 28*28).
num_channels = 1

# Training parameters.
learning_rate = 0.01
display_step = 10
batch_size = 128

# Network parameters.
n_hidden_1 = 32  # 1st conv layer number of neurons.
n_hidden_2 = 64  # 2nd conv layer number of neurons.
n_hidden_3 = 1024  # 1st fully connected layer of neurons.
flatten_size = num_features // 16 * n_hidden_2

seed = 66478


class FloatModel(tf.Module):
  """Float inference for mnist model."""

  def __init__(self):
    self.weights = {
        'f1':
            tf.Variable(
                tf.random.truncated_normal([5, 5, num_channels, n_hidden_1],
                                           stddev=0.1,
                                           seed=seed)),
        'f2':
            tf.Variable(
                tf.random.truncated_normal([5, 5, n_hidden_1, n_hidden_2],
                                           stddev=0.1,
                                           seed=seed)),
        'f3':
            tf.Variable(
                tf.random.truncated_normal([n_hidden_3, flatten_size],
                                           stddev=0.1,
                                           seed=seed)),
        'f4':
            tf.Variable(
                tf.random.truncated_normal([num_classes, n_hidden_3],
                                           stddev=0.1,
                                           seed=seed)),
    }

    self.biases = {
        'b1': tf.Variable(tf.zeros([n_hidden_1])),
        'b2': tf.Variable(tf.zeros([n_hidden_2])),
        'b3': tf.Variable(tf.zeros([n_hidden_3])),
        'b4': tf.Variable(tf.zeros([num_classes])),
    }

  @tf.function
  def __call__(self, data):
    """The Model definition."""
    x = tf.reshape(data, [-1, 28, 28, 1])

    # 2D convolution, with 'SAME' padding (i.e. the output feature map has
    # the same size as the input).

    # NOTE: The data/x/input is always specified in floating point precision.
    # output shape: [-1, 28, 28, 32]
    conv1 = gen_mnist_ops.new_conv2d(x, self.weights['f1'], self.biases['b1'],
                                     1, 1, 1, 1, 'SAME', 'RELU')

    # Max pooling. The kernel size spec {ksize} also follows the layout of
    # the data. Here we have a pooling window of 2, and a stride of 2.
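    # With 'SAME' padding, each spatial dimension becomes
    # ceil(input_dim / stride), so this 2x2/stride-2 pool halves 28x28 to
    # 14x14 while leaving the channel count unchanged.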
    # output shape: [-1, 14, 14, 32]
    max_pool1 = gen_mnist_ops.new_max_pool(conv1, 2, 2, 2, 2, 'SAME')

    # output shape: [-1, 14, 14, 64]
    conv2 = gen_mnist_ops.new_conv2d(max_pool1, self.weights['f2'],
                                     self.biases['b2'], 1, 1, 1, 1, 'SAME',
                                     'RELU')

    # output shape: [-1, 7, 7, 64]
    max_pool2 = gen_mnist_ops.new_max_pool(conv2, 2, 2, 2, 2, 'SAME')

    # Reshape the feature map cuboid into a 2D matrix to feed it to the
    # fully connected layers.
    # output shape: [-1, 7*7*64]
    reshape = tf.reshape(max_pool2, [-1, flatten_size])

    # output shape: [-1, 1024]
    fc1 = gen_mnist_ops.new_fully_connected(reshape, self.weights['f3'],
                                            self.biases['b3'], 'RELU')
    # output shape: [-1, 10]
    return gen_mnist_ops.new_fully_connected(fc1, self.weights['f4'],
                                             self.biases['b4'])


def main(strategy):
  """Trains an MNIST model using the given tf.distribute.Strategy."""
  # TODO(fengliuai): put this in some automatically generated code.
  os.environ[
      'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/mnist'

  ds_train = tfds.load('mnist', split='train', shuffle_files=True)
  ds_train = ds_train.shuffle(1024).batch(batch_size).prefetch(64)
  ds_train = strategy.experimental_distribute_dataset(ds_train)

  with strategy.scope():
    # Create an mnist float model with the specified float state.
    model = FloatModel()
    optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)

  def train_step(features):
    inputs = tf.image.convert_image_dtype(
        features['image'], dtype=tf.float32, saturate=False)
    labels = tf.one_hot(features['label'], num_classes)

    with tf.GradientTape() as tape:
      logits = model(inputs)
      loss_value = tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits(labels, logits))

    grads = tape.gradient(loss_value, model.trainable_variables)
    correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(labels, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return accuracy, loss_value

  @tf.function
  def distributed_train_step(dist_inputs):
    per_replica_accuracy, per_replica_losses = strategy.run(
        train_step, args=(dist_inputs,))
    accuracy = strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_accuracy, axis=None)
    loss_value = strategy.reduce(
        tf.distribute.ReduceOp.MEAN, per_replica_losses, axis=None)
    return accuracy, loss_value

  iterator = iter(ds_train)
  accuracy = 0.0
  for step in range(flags.FLAGS.train_steps):
    accuracy, loss_value = distributed_train_step(next(iterator))
    if step % display_step == 0:
      tf.print('Step %d:' % step)
      tf.print('  Loss = %f' % loss_value)
      tf.print('  Batch accuracy = %f' % accuracy)

  return accuracy
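

# A minimal entry-point sketch, not part of the original example: the TFR
# test harness normally drives `main` with its own tf.distribute.Strategy
# instances, so the strategy choice below is an assumption made only for
# illustration. MirroredStrategy degenerates to a single replica when just
# one device is visible, and absl's app.run parses the --train_steps flag
# defined above before `run_main` is invoked.
def run_main(argv):
  del argv  # Unused; flags are read via flags.FLAGS.
  strategy = tf.distribute.MirroredStrategy()
  final_accuracy = main(strategy)
  tf.print('Final batch accuracy = %f' % final_accuracy)


if __name__ == '__main__':
  from absl import app  # Imported here to keep the sketch self-contained.
  app.run(run_main)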