1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14"""Tests for tensorflow.compiler.mlir.tfr.examples.mnist.ops_defs.""" 15 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import os 21import tensorflow as tf 22 23from tensorflow.compiler.mlir.tfr.examples.mnist import gen_mnist_ops 24from tensorflow.compiler.mlir.tfr.examples.mnist import ops_defs 25from tensorflow.compiler.mlir.tfr.python import test_utils 26from tensorflow.python.framework import load_library 27from tensorflow.python.platform import test 28 29_lib_dir = os.path.dirname(gen_mnist_ops.__file__) 30_lib_name = os.path.basename(gen_mnist_ops.__file__)[4:].replace('.py', '.so') 31load_library.load_op_library(os.path.join(_lib_dir, _lib_name)) 32 33 34class MnistOpsDefsTest(test_utils.OpsDefsTest): 35 36 def test_new_conv2d_relu(self): 37 input_ = tf.random.uniform([1, 4, 4, 1]) 38 filter_ = tf.random.uniform([2, 2, 1, 8]) 39 bias = tf.zeros([8]) 40 kwargs = { 41 'input_': input_, 42 'filter_': filter_, 43 'bias': bias, 44 'stride_w': 2, 45 'stride_h': 2, 46 'dilation_w': 1, 47 'dilation_h': 1, 48 'padding': 'SAME', 49 'act': 'RELU' 50 } 51 52 self._assertOpAndComposite([input_, filter_, bias], 53 tf.function(gen_mnist_ops.new_conv2d), 54 ops_defs._composite_conv_add_relu, kwargs) 55 56 def test_new_conv2d_relu6(self): 57 input_ = tf.random.uniform([1, 4, 4, 1]) 58 filter_ = tf.random.uniform([2, 2, 1, 8]) 59 bias = tf.zeros([8]) 60 kwargs = { 61 'input_': input_, 62 'filter_': filter_, 63 'bias': bias, 64 'stride_w': 2, 65 'stride_h': 2, 66 'dilation_w': 1, 67 'dilation_h': 1, 68 'padding': 'SAME', 69 'act': 'RELU6' 70 } 71 72 self._assertOpAndComposite([input_, filter_, bias], 73 tf.function(gen_mnist_ops.new_conv2d), 74 ops_defs._composite_conv_add_relu, kwargs) 75 76 def test_new_conv2d_tanh(self): 77 self.skipTest('Fix tanh gradients') 78 input_ = tf.random.uniform([1, 4, 4, 1]) 79 filter_ = tf.random.uniform([2, 2, 1, 8]) 80 bias = tf.zeros([8]) 81 kwargs = { 82 'input_': input_, 83 'filter_': filter_, 84 'bias': bias, 85 'stride_w': 2, 86 'stride_h': 2, 87 'dilation_w': 1, 88 'dilation_h': 1, 89 'padding': 'SAME', 90 'act': 'TANH' 91 } 92 93 self._assertOpAndComposite([input_, filter_, bias], 94 tf.function(gen_mnist_ops.new_conv2d), 95 ops_defs._composite_conv_add_relu, kwargs) 96 97 def test_new_fully_connected(self): 98 input_ = tf.random.uniform([2, 4]) 99 filter_ = tf.random.uniform([3, 4]) 100 bias = tf.zeros([3]) 101 kwargs = {'input_': input_, 'filter_': filter_, 'bias': bias, 'act': 'RELU'} 102 103 self._assertOpAndComposite([input_, filter_, bias], 104 tf.function(gen_mnist_ops.new_fully_connected), 105 ops_defs._composite_fully_connected, kwargs) 106 107 def test_new_max_pool(self): 108 input_ = tf.random.uniform([8, 4, 4, 1]) 109 kwargs = { 110 'input_': input_, 111 'stride_w': 2, 112 'stride_h': 2, 113 'filter_width': 1, 114 'filter_height': 1, 115 'padding': 'SAME', 116 } 117 118 self._assertOpAndComposite([input_], 119 tf.function(gen_mnist_ops.new_max_pool), 120 ops_defs._composite_max_pool, kwargs) 121 122 123if __name__ == '__main__': 124 os.environ[ 125 'TF_MLIR_TFR_LIB_DIR'] = 'tensorflow/compiler/mlir/tfr/examples/mnist' 126 test.main() 127