# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines all the new composite ops used in the mnist example."""

# pylint: disable=g-direct-tensorflow-import
# pylint: disable=missing-function-docstring

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

import tensorflow as tf

from tensorflow.compiler.mlir.tfr.python import composite
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import app
from tensorflow.python.platform import flags

Composite = composite.Composite
FLAGS = flags.FLAGS
flags.DEFINE_string(
    'output', None,
    'Path to write the generated register op file or MLIR file.')

flags.DEFINE_bool('gen_register_op', True,
                  'If True, generate the register op cc file; otherwise '
                  'generate the tfr mlir file.')

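# NewConv2D is a composite op fusing a 2-D convolution, a bias add, and an
# optional activation ("RELU", "RELU6", or "TANH"). The @Composite decorator
# records the op signature so that gen_register_op and tfr_gen_from_module
# (see main below) can emit the op registration and the MLIR decomposition
# from this Python body.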
@Composite(
    'NewConv2D',
    inputs=['input_: T', 'filter_: T', 'bias: T'],
    attrs=[
        'stride_w: int', 'stride_h: int', 'dilation_w: int', 'dilation_h: int',
        'padding: {"SAME", "VALID"}', 'act: {"", "RELU", "RELU6", "TANH"} = ""'
    ],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_conv_add_relu(input_, filter_, bias, stride_w, stride_h,
                             dilation_w, dilation_h, padding, act):
  res = tf.raw_ops.Conv2D(
      input=input_,
      filter=filter_,
      strides=[1, stride_w, stride_h, 1],
      dilations=[1, dilation_w, dilation_h, 1],
      padding=padding)
  res = tf.raw_ops.Add(x=res, y=bias)
  if act == 'RELU':
    return tf.raw_ops.Relu(features=res)
  elif act == 'RELU6':
    return tf.raw_ops.Relu6(features=res)
  elif act == 'TANH':
    return tf.raw_ops.Tanh(x=res)
  else:
    return res


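# Gradient for NewConv2D: backprop through the fused activation first, then
# reduce the incoming gradient over the broadcast axes to recover the bias
# gradient, and finally compute the input and filter gradients with the
# standard Conv2D backprop ops.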
@tf.RegisterGradient('NewConv2D')
def _conv_add_relu_grad(op, grad):
  act = op.get_attr('act')
  y = op.outputs[0]
  if act == 'RELU':
    grad = gen_nn_ops.relu_grad(grad, y)
  elif act == 'RELU6':
    grad = gen_nn_ops.relu6_grad(grad, y)
  elif act == 'TANH':
    y = math_ops.conj(y)
    grad = gen_math_ops.tanh_grad(y, grad)

  broadcast_shape = tf.shape(y)
  input_value_shape = tf.shape(op.inputs[2])
  _, reduction_axes = tf.raw_ops.BroadcastGradientArgs(
      s0=broadcast_shape, s1=input_value_shape)
  updates_grad_reshaped = tf.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  bias_grad = tf.reshape(updates_grad_reshaped, input_value_shape)

  dilations = [1, op.get_attr('dilation_w'), op.get_attr('dilation_h'), 1]
  strides = [1, op.get_attr('stride_w'), op.get_attr('stride_h'), 1]
  padding = op.get_attr('padding')
  shape_0, shape_1 = tf.shape_n([op.inputs[0], op.inputs[1]])
  return [
      tf.compat.v1.nn.conv2d_backprop_input(
          shape_0,
          op.inputs[1],
          grad,
          strides=strides,
          padding=padding,
          dilations=dilations,
          data_format='NHWC'),
      tf.compat.v1.nn.conv2d_backprop_filter(
          op.inputs[0],
          shape_1,
          grad,
          strides=strides,
          padding=padding,
          dilations=dilations,
          data_format='NHWC'), bias_grad
  ]


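# NewFullyConnected fuses a dense layer's MatMul (with the filter stored
# transposed, hence transpose_b=True) with a bias add and the same optional
# activations as NewConv2D.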
@Composite(
    'NewFullyConnected',
    inputs=['input_: T', 'filter_: T', 'bias: T'],
    attrs=['act: {"", "RELU", "RELU6", "TANH"} = ""'],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_fully_connected(input_, filter_, bias, act):
  res = tf.raw_ops.MatMul(
      a=input_, b=filter_, transpose_a=False, transpose_b=True)
  res = tf.raw_ops.Add(x=res, y=bias)
  if act == 'RELU':
    return tf.raw_ops.Relu(features=res)
  elif act == 'RELU6':
    return tf.raw_ops.Relu6(features=res)
  elif act == 'TANH':
    return tf.raw_ops.Tanh(x=res)
  else:
    return res


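# Gradient for NewFullyConnected: mirror the conv gradient above, i.e.
# backprop through the activation, reduce over broadcast axes for the bias
# gradient, then apply the usual MatMul gradients for input and filter.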
@tf.RegisterGradient('NewFullyConnected')
def _fully_connected_grad(op, grad):
  act = op.get_attr('act')
  y = op.outputs[0]
  if act == 'RELU':
    grad = gen_nn_ops.relu_grad(grad, y)
  elif act == 'RELU6':
    grad = gen_nn_ops.relu6_grad(grad, y)
  elif act == 'TANH':
    y = math_ops.conj(y)
    grad = gen_math_ops.tanh_grad(y, grad)

  broadcast_shape = tf.shape(y)
  input_value_shape = tf.shape(op.inputs[2])
  _, reduction_axes = tf.raw_ops.BroadcastGradientArgs(
      s0=broadcast_shape, s1=input_value_shape)
  updates_grad_reshaped = tf.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  bias_grad = tf.reshape(updates_grad_reshaped, input_value_shape)

  a = math_ops.conj(op.inputs[0])
  b = math_ops.conj(op.inputs[1])
  grad_a = gen_math_ops.mat_mul(grad, b)
  grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
  return [grad_a, grad_b, bias_grad]


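# NewMaxPool is a thin composite wrapper around tf.raw_ops.MaxPool that
# exposes the window size and strides as individual scalar attributes.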
@Composite(
    'NewMaxPool',
    inputs=['input_: T'],
    attrs=[
        'stride_w: int', 'stride_h: int', 'filter_width: int',
        'filter_height: int', 'padding: {"SAME", "VALID"}'
    ],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_max_pool(input_, stride_w, stride_h, filter_width, filter_height,
                        padding):
  ksize = [1, filter_width, filter_height, 1]
  strides = [1, stride_w, stride_h, 1]
  return tf.raw_ops.MaxPool(
      input=input_, ksize=ksize, strides=strides, padding=padding)


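# Gradient for NewMaxPool: delegate to tf.raw_ops.MaxPoolGrad with the same
# window, strides, and padding that the forward op used.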
@tf.RegisterGradient('NewMaxPool')
def _max_pool_grad(op, grad):
  filter_width = op.get_attr('filter_width')
  filter_height = op.get_attr('filter_height')
  stride_w = op.get_attr('stride_w')
  stride_h = op.get_attr('stride_h')
  padding = op.get_attr('padding')
  return tf.raw_ops.MaxPoolGrad(
      orig_input=op.inputs[0],
      orig_output=op.outputs[0],
      grad=grad,
      ksize=[1, filter_width, filter_height, 1],
      strides=[1, stride_w, stride_h, 1],
      padding=padding,
      data_format='NHWC')


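# Depending on --gen_register_op, emit either the generated C++ op
# registration (a .cc file) or the TFR MLIR decompositions (a .mlir file)
# for the composites above (the functions prefixed with '_composite_').
# A typical invocation might look like the following (the script name is
# only an example):
#
#   python ops_defs.py --output=/tmp/mnist/ops_defs.cc --gen_register_op=true
#   python ops_defs.py --output=/tmp/mnist/ops_defs.mlir --gen_register_op=false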
def main(_):
  if FLAGS.gen_register_op:
    assert FLAGS.output.endswith('.cc')
    generated_code = gen_register_op(sys.modules[__name__], '_composite_')
  else:
    assert FLAGS.output.endswith('.mlir')
    generated_code = tfr_gen_from_module(sys.modules[__name__], '_composite_')

  dirname = os.path.dirname(FLAGS.output)
  # Guard against an empty dirname when the output path has no directory part.
  if dirname and not os.path.exists(dirname):
    os.makedirs(dirname)
  with open(FLAGS.output, 'w') as f:
    f.write(generated_code)


if __name__ == '__main__':
  app.run(main=main)