1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/convert_operand_folding.h"
17 
18 #include "absl/strings/substitute.h"
19 #include "tensorflow/compiler/xla/primitive_util.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22 namespace xla {
23 namespace {
24 
25 namespace op = ::xla::testing::opcode_matchers;
26 
27 using ConvertOperandFoldingTest = HloTestBase;
28 
// Both dot operands come from non-narrowing integral converts (s8->s16 and
// s16->s16). The pass is expected to report a change and wire the parameters
// directly into the dot, which keeps its s16[2,2] result shape.
TEST_F(ConvertOperandFoldingTest, IntegralUpcastConvertFolded) {
  constexpr char kHloText[] = R"(
  HloModule module

  ENTRY main {
    p0 = s8[2,3]{1,0} parameter(0)
    p1 = s16[3,2]{0,1} parameter(1)
    c0 = s16[2,3]{1,0} convert(p0)
    c1 = s16[3,2]{0,1} convert(p1)
    ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
                                          rhs_contracting_dims={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  ConvertOperandFolding pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
  EXPECT_TRUE(changed);

  // Both converts are gone: the dot reads the parameters directly.
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(op::Parameter(0), op::Parameter(1)),
                          op::Shape("s16[2,2]{1,0}")));
}
50 
// Both dot operands are floating-point values (f16 and bf16) converted to f32.
// The pass is expected to report a change and fold both converts into the dot,
// which keeps its f32[2,2] result shape.
TEST_F(ConvertOperandFoldingTest, FloatingUpcastConvertFolded) {
  constexpr char kHloText[] = R"(
  HloModule module

  ENTRY main {
    p0 = f16[2,3]{1,0} parameter(0)
    p1 = bf16[3,2]{0,1} parameter(1)
    c0 = f32[2,3]{1,0} convert(p0)
    c1 = f32[3,2]{0,1} convert(p1)
    ROOT dot = f32[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
                                          rhs_contracting_dims={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  ConvertOperandFolding pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
  EXPECT_TRUE(changed);

  // The dot now consumes the f16/bf16 parameters directly.
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(op::Parameter(0), op::Parameter(1)),
                          op::Shape("f32[2,2]{1,0}")));
}
72 
// Integral parameters converted to floating-point types (s8->f16, s16->f32)
// must be left alone. The pass is expected to report no change, and both
// converts stay in front of the dot with their original shapes.
TEST_F(ConvertOperandFoldingTest, IntegralToFloatingConvertNotFolded) {
  constexpr char kHloText[] = R"(
  HloModule module

  ENTRY main {
    p0 = s8[2,3]{1,0} parameter(0)
    p1 = s16[3,2]{0,1} parameter(1)
    c0 = f16[2,3]{1,0} convert(p0)
    c1 = f32[3,2]{0,1} convert(p1)
    ROOT dot = f32[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
                                          rhs_contracting_dims={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  ConvertOperandFolding pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
  EXPECT_FALSE(changed);

  // Each operand is still the untouched convert of its parameter.
  const auto lhs_convert =
      AllOf(op::Convert(op::Parameter(0)), op::Shape("f16[2,3]{1,0}"));
  const auto rhs_convert =
      AllOf(op::Convert(op::Parameter(1)), op::Shape("f32[3,2]{0,1}"));
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_convert, rhs_convert),
                          op::Shape("f32[2,2]{1,0}")));
}
98 
// Both operands are narrowing integral converts (s32->s16 and s16->s8). The
// pass is expected to report no change, and both converts stay in front of
// the dot with their original shapes.
TEST_F(ConvertOperandFoldingTest, DowncastConvertNotFolded) {
  constexpr char kHloText[] = R"(
  HloModule module

  ENTRY main {
    p0 = s32[2,3]{1,0} parameter(0)
    p1 = s16[3,2]{0,1} parameter(1)
    c0 = s16[2,3]{1,0} convert(p0)
    c1 = s8[3,2]{0,1} convert(p1)
    ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
                                          rhs_contracting_dims={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  ConvertOperandFolding pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
  EXPECT_FALSE(changed);

  // Neither narrowing convert may be removed.
  const auto lhs_convert =
      AllOf(op::Convert(op::Parameter(0)), op::Shape("s16[2,3]{1,0}"));
  const auto rhs_convert =
      AllOf(op::Convert(op::Parameter(1)), op::Shape("s8[3,2]{0,1}"));
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_convert, rhs_convert),
                          op::Shape("s16[2,2]{1,0}")));
}
124 
// Mixed case: the lhs convert widens (s8->s16) and the rhs convert narrows
// (s16->s8). The pass is expected to report a change, fold only the lhs
// convert, and keep the rhs convert in place.
TEST_F(ConvertOperandFoldingTest, OneOperandFolded) {
  constexpr char kHloText[] = R"(
  HloModule module

  ENTRY main {
    p0 = s8[2,3]{1,0} parameter(0)
    p1 = s16[3,2]{0,1} parameter(1)
    c0 = s16[2,3]{1,0} convert(p0)
    c1 = s8[3,2]{0,1} convert(p1)
    ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
                                          rhs_contracting_dims={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));

  ConvertOperandFolding pass;
  TF_ASSERT_OK_AND_ASSIGN(bool changed, pass.Run(module.get()));
  EXPECT_TRUE(changed);

  // The lhs is now the bare parameter; the rhs is still the narrowing convert.
  const auto rhs_convert =
      AllOf(op::Convert(op::Parameter(1)), op::Shape("s8[3,2]{0,1}"));
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(op::Parameter(0), rhs_convert),
                          op::Shape("s16[2,2]{1,0}")));
}
148 
149 }  // namespace
150 }  // namespace xla
151