1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14"""Tests for tensorflow.compiler.mlir.tfr.examples.mnist.ops_defs."""
15
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import os
21import tensorflow as tf
22
23from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops
24from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs
25from tensorflow.compiler.mlir.tfr.python import test_utils
26from tensorflow.python.framework import load_library
27from tensorflow.python.platform import test
28
# The compiled op library is loaded from the same directory as the generated
# wrapper module: `gen_<name>.py` -> `<name>.so`.
_lib_dir = os.path.dirname(gen_mnist_ops.__file__)
# Use splitext instead of str.replace('.py', '.so'): __file__ may end in
# '.pyc' (e.g. under Python 2), which replace() would turn into '.soc'.
_lib_stem = os.path.splitext(os.path.basename(gen_mnist_ops.__file__))[0]
# Drop the 'gen_' prefix only when it is actually there, rather than blindly
# slicing off the first four characters.
if _lib_stem.startswith('gen_'):
  _lib_stem = _lib_stem[len('gen_'):]
_lib_name = _lib_stem + '.so'
load_library.load_op_library(os.path.join(_lib_dir, _lib_name))
32
33
class MnistOpsDefsTest(test_utils.OpsDefsTest):
  """Tests the MNIST example ops against their `ops_defs` composites.

  Each test passes its inputs, the `gen_mnist_ops` op wrapped in
  `tf.function`, and the matching `ops_defs` composite function to
  `_assertOpAndComposite` (inherited from `test_utils.OpsDefsTest`).
  """

  def _assert_new_conv2d(self, act):
    """Runs the shared `new_conv2d` vs. `_composite_conv_add_relu` check.

    Inputs: a random tensor of shape [1, 4, 4, 1], a random filter of shape
    [2, 2, 1, 8] and a zero bias of length 8, with stride 2 along w and h,
    dilation 1 along w and h, and 'SAME' padding.

    Args:
      act: Activation name forwarded as the op's `act` attribute, e.g.
        'RELU', 'RELU6' or 'TANH'.
    """
    input_ = tf.random.uniform([1, 4, 4, 1])
    filter_ = tf.random.uniform([2, 2, 1, 8])
    bias = tf.zeros([8])
    kwargs = {
        'input_': input_,
        'filter_': filter_,
        'bias': bias,
        'stride_w': 2,
        'stride_h': 2,
        'dilation_w': 1,
        'dilation_h': 1,
        'padding': 'SAME',
        'act': act,
    }

    self._assertOpAndComposite([input_, filter_, bias],
                               tf.function(gen_mnist_ops.new_conv2d),
                               ops_defs._composite_conv_add_relu, kwargs)

  def test_new_conv2d_relu(self):
    """`new_conv2d` with a RELU activation."""
    self._assert_new_conv2d('RELU')

  def test_new_conv2d_relu6(self):
    """`new_conv2d` with a RELU6 activation."""
    self._assert_new_conv2d('RELU6')

  def test_new_conv2d_tanh(self):
    """`new_conv2d` with a TANH activation (currently skipped)."""
    # TODO: re-enable once the tanh gradient issue named in the skip reason
    # is resolved; the check below is kept so it runs as soon as the skip
    # is removed.
    self.skipTest('Fix tanh gradients')
    self._assert_new_conv2d('TANH')

  def test_new_fully_connected(self):
    """`new_fully_connected` (RELU) vs. `_composite_fully_connected`."""
    input_ = tf.random.uniform([2, 4])
    filter_ = tf.random.uniform([3, 4])
    bias = tf.zeros([3])
    kwargs = {'input_': input_, 'filter_': filter_, 'bias': bias, 'act': 'RELU'}

    self._assertOpAndComposite([input_, filter_, bias],
                               tf.function(gen_mnist_ops.new_fully_connected),
                               ops_defs._composite_fully_connected, kwargs)

  def test_new_max_pool(self):
    """`new_max_pool` with a 1x1 window and stride 2 vs. `_composite_max_pool`."""
    input_ = tf.random.uniform([8, 4, 4, 1])
    kwargs = {
        'input_': input_,
        'stride_w': 2,
        'stride_h': 2,
        'filter_width': 1,
        'filter_height': 1,
        'padding': 'SAME',
    }

    self._assertOpAndComposite([input_],
                               tf.function(gen_mnist_ops.new_max_pool),
                               ops_defs._composite_max_pool, kwargs)
121
122
if __name__ == '__main__':
  # Export this example's directory through TF_MLIR_TFR_LIB_DIR (overwriting
  # any existing value) before handing control to the test runner.
  _tfr_lib_dir = 'tensorflow/compiler/mlir/tfr/examples/mnist'
  os.environ['TF_MLIR_TFR_LIB_DIR'] = _tfr_lib_dir
  test.main()
127