1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Defines all the new composite ops used in the mnist example.""" 15 16# pylint: disable=g-direct-tensorflow-import 17# pylint: disable=missing-function-docstring 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import os 24import sys 25 26import tensorflow as tf 27 28from tensorflow.compiler.mlir.tfr.python import composite 29from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op 30from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module 31from tensorflow.python.ops import gen_math_ops 32from tensorflow.python.ops import gen_nn_ops 33from tensorflow.python.ops import math_ops 34from tensorflow.python.platform import app 35from tensorflow.python.platform import flags 36 37Composite = composite.Composite 38FLAGS = flags.FLAGS 39 40flags.DEFINE_string( 41 'output', None, 42 'Path to write the genereated register op file and MLIR file.') 43 44flags.DEFINE_bool('gen_register_op', True, 45 'Generate register op cc file or tfr mlir file.') 46 47 48@Composite( 49 'NewConv2D', 50 inputs=['input_: T', 'filter_: T', 'bias: T'], 51 attrs=[ 52 'stride_w: int', 'stride_h: int', 'dilation_w: int', 'dilation_h: int', 53 'padding: {"SAME", "VALID"}', 'act: {"", "RELU", "RELU6", "TANH"} = ""' 54 ], 55 derived_attrs=['T: {float, int8}'], 56 outputs=['o: T']) 57def _composite_conv_add_relu(input_, filter_, bias, stride_w, stride_h, 58 dilation_w, dilation_h, padding, act): 59 res = tf.raw_ops.Conv2D( 60 input=input_, 61 filter=filter_, 62 strides=[1, stride_w, stride_h, 1], 63 dilations=[1, dilation_w, dilation_h, 1], 64 padding=padding) 65 res = tf.raw_ops.Add(x=res, y=bias) 66 if act == 'RELU': 67 return tf.raw_ops.Relu(features=res) 68 elif act == 'RELU6': 69 return tf.raw_ops.Relu6(features=res) 70 elif act == 'TANH': 71 return tf.raw_ops.Tanh(x=res) 72 else: 73 return res 74 75 76@tf.RegisterGradient('NewConv2D') 77def _conv_add_relu_grad(op, grad): 78 act = op.get_attr('act') 79 y = op.outputs[0] 80 if act == 'RELU': 81 grad = gen_nn_ops.relu_grad(grad, y) 82 elif act == 'RELU6': 83 grad = gen_nn_ops.relu6_grad(grad, y) 84 elif act == 'TANH': 85 y = math_ops.conj(y) 86 grad = gen_math_ops.tanh_grad(y, grad) 87 88 broadcast_shape = tf.shape(y) 89 input_value_shape = tf.shape(op.inputs[2]) 90 _, reduction_axes = tf.raw_ops.BroadcastGradientArgs( 91 s0=broadcast_shape, s1=input_value_shape) 92 updates_grad_reshaped = tf.reduce_sum( 93 grad, axis=reduction_axes, keepdims=True) 94 bias_grad = tf.reshape(updates_grad_reshaped, input_value_shape) 95 96 dilations = [1, op.get_attr('dilation_w'), op.get_attr('dilation_h'), 1] 97 strides = [1, op.get_attr('stride_w'), op.get_attr('stride_h'), 1] 98 padding = op.get_attr('padding') 99 shape_0, shape_1 = tf.shape_n([op.inputs[0], op.inputs[1]]) 100 return [ 101 tf.compat.v1.nn.conv2d_backprop_input( 102 shape_0, 103 op.inputs[1], 104 grad, 105 strides=strides, 106 padding=padding, 107 dilations=dilations, 108 data_format='NHWC'), 109 tf.compat.v1.nn.conv2d_backprop_filter( 110 op.inputs[0], 111 shape_1, 112 grad, 113 strides=strides, 114 padding=padding, 115 dilations=dilations, 116 data_format='NHWC'), bias_grad 117 ] 118 119 120@Composite( 121 'NewFullyConnected', 122 inputs=['input_: T', 'filter_: T', 'bias: T'], 123 attrs=['act: {"", "RELU", "RELU6", "TANH"} = ""'], 124 derived_attrs=['T: {float, int8}'], 125 outputs=['o: T']) 126def _composite_fully_connected(input_, filter_, bias, act): 127 res = tf.raw_ops.MatMul( 128 a=input_, b=filter_, transpose_a=False, transpose_b=True) 129 res = tf.raw_ops.Add(x=res, y=bias) 130 if act == 'RELU': 131 return tf.raw_ops.Relu(features=res) 132 elif act == 'RELU6': 133 return tf.raw_ops.Relu6(features=res) 134 elif act == 'TANH': 135 return tf.raw_ops.Tanh(x=res) 136 else: 137 return res 138 139 140@tf.RegisterGradient('NewFullyConnected') 141def _fully_connected_grad(op, grad): 142 act = op.get_attr('act') 143 y = op.outputs[0] 144 if act == 'RELU': 145 grad = gen_nn_ops.relu_grad(grad, y) 146 elif act == 'RELU6': 147 grad = gen_nn_ops.relu6_grad(grad, y) 148 elif act == 'TANH': 149 y = math_ops.conj(y) 150 grad = gen_math_ops.tanh_grad(y, grad) 151 152 broadcast_shape = tf.shape(y) 153 input_value_shape = tf.shape(op.inputs[2]) 154 _, reduction_axes = tf.raw_ops.BroadcastGradientArgs( 155 s0=broadcast_shape, s1=input_value_shape) 156 updates_grad_reshaped = tf.reduce_sum( 157 grad, axis=reduction_axes, keepdims=True) 158 bias_grad = tf.reshape(updates_grad_reshaped, input_value_shape) 159 160 a = math_ops.conj(op.inputs[0]) 161 b = math_ops.conj(op.inputs[1]) 162 grad_a = gen_math_ops.mat_mul(grad, b) 163 grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True) 164 return [grad_a, grad_b, bias_grad] 165 166 167@Composite( 168 'NewMaxPool', 169 inputs=['input_: T'], 170 attrs=[ 171 'stride_w: int', 'stride_h: int', 'filter_width: int', 172 'filter_height: int', 'padding: {"SAME", "VALID"}' 173 ], 174 derived_attrs=['T: {float, int8}'], 175 outputs=['o: T']) 176def _composite_max_pool(input_, stride_w, stride_h, filter_width, filter_height, 177 padding): 178 ksize = [1, filter_width, filter_height, 1] 179 strides = [1, stride_w, stride_h, 1] 180 return tf.raw_ops.MaxPool( 181 input=input_, ksize=ksize, strides=strides, padding=padding) 182 183 184@tf.RegisterGradient('NewMaxPool') 185def _max_pool_grad(op, grad): 186 filter_width = op.get_attr('filter_width') 187 filter_height = op.get_attr('filter_height') 188 stride_w = op.get_attr('stride_w') 189 stride_h = op.get_attr('stride_h') 190 padding = op.get_attr('padding') 191 return tf.raw_ops.MaxPoolGrad( 192 orig_input=op.inputs[0], 193 orig_output=op.outputs[0], 194 grad=grad, 195 ksize=[1, filter_width, filter_height, 1], 196 strides=[1, stride_w, stride_h, 1], 197 padding=padding, 198 data_format='NHWC') 199 200 201def main(_): 202 if FLAGS.gen_register_op: 203 assert FLAGS.output.endswith('.cc') 204 generated_code = gen_register_op(sys.modules[__name__], '_composite_') 205 else: 206 assert FLAGS.output.endswith('.mlir') 207 generated_code = tfr_gen_from_module(sys.modules[__name__], '_composite_',) 208 209 dirname = os.path.dirname(FLAGS.output) 210 if not os.path.exists(dirname): 211 os.makedirs(dirname) 212 with open(FLAGS.output, 'w') as f: 213 f.write(generated_code) 214 215 216if __name__ == '__main__': 217 app.run(main=main) 218