• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the `mlir_gen` module."""
16
17# pylint: disable=missing-function-docstring
18# pylint: disable=invalid-name
19
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24from tensorflow.python.platform import test
25from tensorflow.python.types import core
26from tensorflow.python.tf_program.mlir_gen import mlir_gen
27
28import tensorflow.compiler.mlir.python.mlir_wrapper.filecheck_wrapper as fw
29
30
class MLIRGenTestBase(test.TestCase):
  """Base test case providing a FileCheck-based assertion for `mlir_gen`."""

  def _check_code(self, mlir_code, exp_mlir_code):
    """Asserts that the generated MLIR satisfies the given FileCheck patterns.

    Args:
      mlir_code: The generated MLIR; converted with `str()` before checking.
      exp_mlir_code: FileCheck directives (`CHECK:`, `CHECK-LABEL:`,
        `CHECK-NEXT:`, ...) that the generated code must match.
    """
    actual = str(mlir_code)
    # `assertTrue` alone reports only "False is not true" on failure. Include
    # the generated code and the expected patterns so a mismatch can be
    # diagnosed without re-running the test. (The original `return` of the
    # assertTrue result was always None, so nothing is returned here either.)
    self.assertTrue(
        fw.check(actual, exp_mlir_code),
        msg='Generated MLIR did not match the expected FileCheck patterns.\n'
        'Generated code:\n%s\nExpected patterns:\n%s' % (actual,
                                                          exp_mlir_code))
36
class MLIRGenTest(MLIRGenTestBase):
  """MLIR generation tests for TensorFlow programs.

  Each test converts a small Python function with `mlir_gen` and checks the
  printed MLIR against FileCheck directives via `_check_code`.
  """

  def test_simple(self):

    def test_fn():
      pass

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: @test_fn
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_argument(self):

    def test_fn(x: core.Tensor) -> core.Tensor:
      return x

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32> {
        CHECK-NEXT: return %arg0 : tensor<*xi32>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_constant(self):

    def test_fn() -> int:
      return 23

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn() -> i32
      CHECK: %[[r0:[0-9]+]] = "tf.Const"() {value = dense<23> : tensor<i32>} : () -> tensor<i32>
      CHECK: return %[[r0]] : tensor<i32>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_BoolOp(self):

    def test_fn(x: bool, y: bool) -> bool:
      return x or y or x and x and y

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn(%arg0: i1, %arg1: i1) -> i1
      CHECK: %[[r0:[0-9]+]] = "tfp.And"(%arg0, %arg0, %arg1) : (i1, i1, i1) -> tensor<*xi1>
      CHECK: %[[r1:[0-9]+]] = "tfp.Or"(%arg0, %arg1, %[[r0]]) : (i1, i1, tensor<*xi1>) -> tensor<*xi1>
      CHECK: return %[[r1]] : tensor<*xi1>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_Call(self):

    def test_fn():

      def f1():
        return 23

      def f2():
        return f1()

      f2()

    mlir_code = mlir_gen(test_fn)
    # Every function's closing brace is matched with an explicit `CHECK: }`.
    # A bare `}` line carries no check prefix, so FileCheck would ignore it.
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn()
        CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @f2} : () -> ()
      CHECK: }
      CHECK-LABEL: func @f1() {
        CHECK: %[[r0:[0-9]+]] = "tf.Const"() {value = dense<23> : tensor<i32>} : () -> tensor<i32>
        CHECK: return %[[r0]] : tensor<i32>
      CHECK: }
      CHECK-LABEL: func @f2() {
        CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, f = @f1} : () -> ()
      CHECK: }
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_Compare(self):

    def test_fn(x: core.Tensor, y: core.Tensor, z: core.Tensor):
      return x > y < z

    mlir_code = mlir_gen(test_fn)
    # NOTE(review): Python evaluates `x > y < z` as `(x > y) and (y < z)`, but
    # the expected IR below feeds the `Greater` result into `Less` with `z`.
    # This pins current `mlir_gen` behavior; confirm whether it is intended.
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>, %arg2: tensor<*xi32>)
      CHECK: %[[r0:[0-9]+]] = "tf.Greater"(%arg0, %arg1) : (tensor<*xi32>, tensor<*xi32>) -> tensor<*xi1>
      CHECK: %[[r1:[0-9]+]] = "tf.Less"(%[[r0]], %arg2) : (tensor<*xi1>, tensor<*xi32>) -> tensor<*xi1>
      CHECK: return %[[r1]] : tensor<*xi1>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_Assign_BinOp(self):

    def test_fn() -> int:
      y = 12 + 23 - 24
      return y

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn() -> i32
      CHECK: %[[r0:[0-9]+]] = "tf.AddV2"(%{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
      CHECK: %[[r1:[0-9]+]] = "tf.Sub"(%{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
      CHECK: return %[[r1]] : tensor<i32>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_if(self):

    def test_fn(x: core.Tensor) -> int:
      res = 0
      if x > 0:
        res = 1
      elif x < 0:
        res = -1
      else:
        res = 0
      return res

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> i32

      CHECK: %[[r1:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
      CHECK-NEXT: %[[r2:[0-9]+]] = "tfp.If"(%[[r1]]) ( {
        CHECK: return %{{[0-9]+}} : tensor<i32>
      CHECK-NEXT: },  {
        CHECK: %[[r3:[0-9]+]] = "tf.Less"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
        CHECK: %[[r4:[0-9]+]] = "tfp.If"(%[[r3]]) ( {
          CHECK: %[[r5:[0-9]+]] = "tf.Neg"(%{{[0-9]+}}) : (tensor<i32>) -> tensor<i32>
          CHECK: return %[[r5]] : tensor<i32>
        CHECK-NEXT: },  {
          CHECK: return %{{[0-9]+}} : tensor<i32>
        CHECK-NEXT: }) : (tensor<*xi1>) -> tensor<i32>
        CHECK: return %[[r4]] : tensor<i32>
      CHECK-NEXT: }) : (tensor<*xi1>) -> tensor<i32>
      CHECK-NEXT: return %[[r2]] : tensor<i32>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_while(self):

    def test_fn(x: core.Tensor) -> core.Tensor:
      s = 0
      while x > 0:
        s = s + x
      return s

    mlir_code = mlir_gen(test_fn)
    # The loop's initial operand is matched by pattern rather than a
    # hard-coded `%0`, consistent with the other checks: SSA value numbering
    # is a printer detail, not part of what this test verifies.
    exp_mlir_code = r"""
      CHECK-LABEL: func @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32>

      CHECK: %[[r1:[0-9]+]] = "tfp.While"(%{{[0-9]+}}) ( {
      CHECK-NEXT: ^{{[^ ]+}}(%arg1: tensor<i32>):
        CHECK: %[[r2:[0-9]+]] = "tf.Greater"(%arg0, %{{[0-9]+}}) : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
        CHECK-NEXT: return %[[r2]] : tensor<*xi1>
      CHECK-NEXT: },  {
      CHECK-NEXT: ^{{[^ ]+}}(%arg1: tensor<i32>):
        CHECK: %[[r3:[0-9]+]] = "tf.AddV2"(%arg1, %arg0) : (tensor<i32>, tensor<*xi32>) -> tensor<*xi32>
        CHECK-NEXT: return %[[r3]] : tensor<*xi32>
      CHECK-NEXT: }) : (tensor<i32>) -> tensor<i32>
      CHECK-NEXT: return %[[r1]] : tensor<i32>
    """
    self._check_code(mlir_code, exp_mlir_code)

  def test_fibonacci(self):

    def test_fn(x: core.Tensor) -> core.Tensor:
      res, idx = 0, 2
      a, b = 0, 1
      if x == 0 or x == 1:
        res = x
      else:
        while idx <= x:
          res = a + b
          a = b
          b = res
          idx = idx + 1
      return res

    mlir_code = mlir_gen(test_fn)
    exp_mlir_code = r"""
      CHECK-LABEL: @test_fn(%arg0: tensor<*xi32>) -> tensor<*xi32>
      CHECK: %[[r5:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
      CHECK: %[[r7:[0-9]+]] = "tf.Equal"(%arg0, %{{[0-9]+}}) {incompatible_shape_error = true} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi1>
      CHECK: %[[r8:[0-9]+]] = "tfp.Or"(%[[r5]], %[[r7]]) : (tensor<*xi1>, tensor<*xi1>) -> tensor<*xi1>

      CHECK: %[[r9:[0-9]+]]:4 = "tfp.If"(%[[r8]]) ( {
        CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>
        CHECK-NEXT: },  {
        CHECK-NEXT: %[[r10:[0-9]+]]:4 = "tfp.While"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) ( {
          CHECK-NEXT: ^{{[^ ]*}}(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
          CHECK-NEXT: %[[r11:[0-9]+]] = "tf.LessEqual"(%arg{{[0-9]+}}, %arg{{[0-9]+}}) : (tensor<{{(\*x)?}}i32>, tensor<{{(\*x)?}}i32>) -> tensor<*xi1>
          CHECK-NEXT: return %[[r11]] : tensor<*xi1>
        CHECK-NEXT: },  {
          CHECK-NEXT: ^{{[^ ]*}}(%arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>, %arg4: tensor<i32>):
          CHECK-NEXT: %[[r12:[0-9]+]] = "tf.AddV2"(%arg{{[0-9]+}}, %arg{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
          CHECK: %[[r13:[0-9]+]] = "tf.AddV2"(%arg{{[0-9]+}}, %{{[0-9]+}}) : (tensor<i32>, tensor<i32>) -> tensor<i32>
          CHECK-NEXT: return %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
        CHECK-NEXT: }) : (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
        CHECK-NEXT: return %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}}, %[[r10]]#{{[0-9]+}} : tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>
      CHECK-NEXT: }) : (tensor<*xi1>) -> (tensor<i32>, tensor<i32>, tensor<i32>, tensor<i32>)
      CHECK-NEXT: return %[[r9]]#{{[0-9]+}} : tensor<i32>
    """
    self._check_code(mlir_code, exp_mlir_code)
244
245
if __name__ == '__main__':
  # Script entry point: hand control to TensorFlow's test runner
  # (`tensorflow.python.platform.test.main`).
  test.main()
248