# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the `mlir_gen` module."""

# pylint: disable=missing-function-docstring
# pylint: disable=invalid-name

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import test
from tensorflow.python.types import core
from tensorflow.python.tf_program.mlir_gen import mlir_gen

import tensorflow.compiler.mlir.python.mlir_wrapper.filecheck_wrapper as fw


class MLIRGenTestBase(test.TestCase):
  """Base test case providing a FileCheck assertion for generated MLIR."""

  def _check_code(self, mlir_code, exp_mlir_code):
    """Asserts that the generated code satisfies the FileCheck patterns.

    Args:
      mlir_code: The value returned by `mlir_gen`; it is converted with
        `str()` before being checked.
      exp_mlir_code: A string of FileCheck directives (`CHECK:`,
        `CHECK-NEXT:`, `CHECK-LABEL:`, ...) the generated code must satisfy.
    """
    generated = str(mlir_code)
    # Attach the generated IR and the patterns to the failure message so a
    # mismatch can be diagnosed from the test log alone; without a message,
    # assertTrue only reports "False is not true".
    self.assertTrue(
        fw.check(generated, exp_mlir_code),
        msg='Generated MLIR did not match the FileCheck patterns.\n'
        'Generated MLIR:\n{}\nExpected patterns:\n{}'.format(
            generated, exp_mlir_code))


class MLIRGenTest(MLIRGenTestBase):
  """MLIR generation tests for TensorFlow programs."""

  def test_simple(self):
    """An empty function still produces a func labelled with its name."""

    def test_fn():
      pass

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: @test_fn
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_argument(self):
    """A Tensor-annotated argument is returned unchanged."""

    def test_fn(x: core.Tensor) -> core.Tensor:
      return x

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32> {
      CHECK-NEXT: return %arg0 : tensor<*xi32>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_constant(self):
    """Returning an int literal lowers to a tf.Const."""

    def test_fn() -> int:
      return 23

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn() -> i32
      CHECK: %[[r0:[0-9]+]] = "tf.Const"() {value = dense<23> : tensor<i32>} : () -> tensor<i32>
      CHECK: return %[[r0]] : tensor<i32>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_BoolOp(self):
    """Chained `and`/`or` lower to variadic tfp.And / tfp.Or ops."""

    def test_fn(x: bool, y: bool) -> bool:
      return x or y or x and x and y

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn(%arg0: i1, %arg1: i1) -> i1
      CHECK: %[[r0:[0-9]+]] = "tfp.And"(%arg0, %arg0, %arg1) : (i1, i1, i1) -> tensor<*xi1>
      CHECK: %[[r1:[0-9]+]] = "tfp.Or"(%arg0, %arg1, %[[r0]]) : (i1, i1, tensor<*xi1>) -> tensor<*xi1>
      CHECK: return %[[r1]] : tensor<*xi1>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_Call(self):
    """Nested calls lower to tf.LegacyCall; callees become separate funcs."""

    def test_fn():

      def f1():
        return 23

      def f2():
        return f1()

      f2()

    mlir_code = mlir_gen(test_fn)
    # The final line must carry the `CHECK:` prefix: FileCheck ignores lines
    # without a directive, so a bare `}` would never be checked.
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn()
      CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @f2} : () -> ()
      CHECK: }
      CHECK-LABEL: func @f1() {
      CHECK: %[[r0:[0-9]+]] = "tf.Const"() {value = dense<23> : tensor<i32>} : () -> tensor<i32>
      CHECK: return %[[r0]] : tensor<i32>
      CHECK: }
      CHECK-LABEL: func @f2() {
      CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @f1} : () -> ()
      CHECK: }
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_Compare(self):
    """A comparison chain lowers to tf.Greater followed by tf.Less."""

    def test_fn(x: core.Tensor, y: core.Tensor, z: core.Tensor):
      return x > y < z

    mlir_code = mlir_gen(test_fn)
    # NOTE(review): Python evaluates `x > y < z` as `x > y and y < z`, but the
    # pattern below pins `Less(Greater(x, y), z)`, i.e. `(x > y) < z`. This
    # records current mlir_gen behavior -- confirm whether it is intended.
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>)
      CHECK: %[[r0:[0-9]+]] = "tf.Greater"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
      CHECK: %[[r1:[0-9]+]] = "tf.Less"(%[[r0]], %arg2) : (tensor<*xi1>, tensor<*xi32>) -> tensor<*xi1>
      CHECK: return %[[r1]] : tensor<*xi1>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_Assign_BinOp(self):
    """`12 + 23 - 24` lowers to tf.AddV2 followed by tf.Sub."""

    def test_fn() -> int:
      y = 12 + 23 - 24
      return y

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn() -> i32
      CHECK: %[[r0:[0-9]+]] = "tf.AddV2"(%{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
      CHECK: %[[r1:[0-9]+]] = "tf.Sub"(%{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
      CHECK: return %[[r1]] : tensor<i32>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_if(self):
    """An if/elif/else chain lowers to nested tfp.If ops."""

    def test_fn(x: core.Tensor) -> int:
      res = 0
      if x > 0:
        res = 1
      elif x < 0:
        res = -1
      else:
        res = 0
      return res

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> i32

      CHECK: %[[r1:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
      CHECK-NEXT: %[[r2:[0-9]+]] = "tfp.If"(%[[r1]]) ( {
      CHECK: return %{{[0-9]+}} : tensor<i32>
      CHECK-NEXT: }, {
      CHECK: %[[r3:[0-9]+]] = "tf.Less"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
      CHECK: %[[r4:[0-9]+]] = "tfp.If"(%[[r3]]) ( {
      CHECK: %[[r5:[0-9]+]] = "tf.Neg"(%{{[0-9]+}}) : (tensor<i32>) -> tensor<i32>
      CHECK: return %[[r5]] : tensor<i32>
      CHECK-NEXT: }, {
      CHECK: return %{{[0-9]+}} : tensor<i32>
      CHECK-NEXT: }) : (tensor<*xi1>) -> tensor<i32>
      CHECK: return %[[r4]] : tensor<i32>
      CHECK-NEXT: }) : (tensor<*xi1>) -> tensor<i32>
      CHECK-NEXT: return %[[r2]] : tensor<i32>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_while(self):
    """A while loop lowers to tfp.While with a condition and a body region."""

    # test_fn is only handed to mlir_gen, never called here; note that its
    # loop never updates `x`, so calling it with x > 0 would not terminate.
    def test_fn(x: core.Tensor) -> core.Tensor:
      s = 0
      while x > 0:
        s = s + x
      return s

    mlir_code = mlir_gen(test_fn)
    # Match the While's initial operand with a pattern rather than a fixed
    # SSA id, consistent with the other tests.
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32>

      CHECK: %[[r1:[0-9]+]] = "tfp.While"(%{{[0-9]+}}) ( {
      CHECK-NEXT: ^{{[^ ]+}}(%arg1: tensor<i32>):
      CHECK: %[[r2:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
      CHECK-NEXT: return %[[r2]] : tensor<*xi1>
      CHECK-NEXT: }, {
      CHECK-NEXT: ^{{[^ ]+}}(%arg1: tensor<i32>):
      CHECK: %[[r3:[0-9]+]] = "tf.AddV2"(%arg1, %arg0) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi32>
      CHECK-NEXT: return %[[r3]] : tensor<*xi32>
      CHECK-NEXT: }) : (tensor<i32>) -> tensor<i32>
      CHECK-NEXT: return %[[r1]] : tensor<i32>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_fibonacci(self):
    """Combines `or`, if/else and a while loop carrying four values."""

    def test_fn(x: core.Tensor) -> core.Tensor:
      res, idx = 0, 2
      a, b = 0, 1
      if x == 0 or x == 1:
        res = x
      else:
        while idx <= x:
          res = a + b
          a = b
          b = res
          idx = idx + 1
      return res

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32>
      CHECK: %[[r5:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
      CHECK: %[[r7:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
      CHECK: %[[r8:[0-9]+]] = "tfp.Or"(%[[r5]], %[[r7]]) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1>

      CHECK: %[[r9:[0-9]+]]:4 = "tfp.If"(%[[r8]]) ( {
      CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>
      CHECK-NEXT: }, {
      CHECK-NEXT: %[[r10:[0-9]+]]:4 = "tfp.While"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
      CHECK-NEXT: ^{{[^ ]*}}(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
      CHECK-NEXT: %[[r11:[0-9]+]] = "tf.LessEqual"(%arg{{[0-9]+}}, %arg{{[0-9]+}}) : (tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>) -> tensor<*xi1>
      CHECK-NEXT: return %[[r11]] : tensor<*xi1>
      CHECK-NEXT: }, {
      CHECK-NEXT: ^{{[^ ]*}}(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
      CHECK-NEXT: %[[r12:[0-9]+]] = "tf.AddV2"(%arg{{[0-9]+}}, %arg{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
      CHECK: %[[r13:[0-9]+]] = "tf.AddV2"(%arg{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
      CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
      CHECK-NEXT: }) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
      CHECK-NEXT: return %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}} : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
      CHECK-NEXT: }) : (tensor<*xi1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
      CHECK-NEXT: return %[[r9]]#{{[0-9]+}} : tensor<i32>
    """
    self._check_code(mlir_code, exp_mlir_code)


if __name__ == '__main__':
  test.main()