1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/convert_operand_folding.h"
17
18 #include "absl/strings/substitute.h"
19 #include "tensorflow/compiler/xla/primitive_util.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
22 namespace xla {
23 namespace {
24
// Short alias for the HLO opcode matchers (op::Dot, op::Convert,
// op::Parameter, op::Shape) used in the expectations of the tests below.
namespace op = ::xla::testing::opcode_matchers;

// Fixture name for the ConvertOperandFolding tests; it is simply HloTestBase
// under a test-specific name.
using ConvertOperandFoldingTest = HloTestBase;
28
// Both dot operands are produced by converts to s16 (from an s8 and an s16
// parameter). The pass is expected to report a change and to leave the dot
// consuming the two parameters directly, with its s16[2,2] shape unchanged.
TEST_F(ConvertOperandFoldingTest, IntegralUpcastConvertFolded) {
  const absl::string_view hlo_text = R"(
  HloModule module

  ENTRY main {
    p0 = s8[2,3]{1,0} parameter(0)
    p1 = s16[3,2]{0,1} parameter(1)
    c0 = s16[2,3]{1,0} convert(p0)
    c1 = s16[3,2]{0,1} convert(p1)
    ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
                                          rhs_contracting_dims={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Run the pass under test and record whether it modified the module.
  ConvertOperandFolding pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(hlo_module.get()));
  EXPECT_TRUE(changed);

  // After folding, neither operand of the root dot is a convert any more.
  auto* root = hlo_module->entry_computation()->root_instruction();
  const auto folded_dot = op::Dot(op::Parameter(0), op::Parameter(1));
  EXPECT_THAT(root, AllOf(folded_dot, op::Shape("s16[2,2]{1,0}")));
}
50
// The dot operands are converts from f16 and bf16 parameters to f32. The pass
// is expected to report a change and to leave the dot consuming the two
// parameters directly, with its f32[2,2] shape unchanged.
TEST_F(ConvertOperandFoldingTest, FloatingUpcastConvertFolded) {
  const absl::string_view hlo_text = R"(
  HloModule module

  ENTRY main {
    p0 = f16[2,3]{1,0} parameter(0)
    p1 = bf16[3,2]{0,1} parameter(1)
    c0 = f32[2,3]{1,0} convert(p0)
    c1 = f32[3,2]{0,1} convert(p1)
    ROOT dot = f32[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
                                          rhs_contracting_dims={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Run the pass under test and record whether it modified the module.
  ConvertOperandFolding pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(hlo_module.get()));
  EXPECT_TRUE(changed);

  // After folding, neither operand of the root dot is a convert any more.
  auto* root = hlo_module->entry_computation()->root_instruction();
  const auto folded_dot = op::Dot(op::Parameter(0), op::Parameter(1));
  EXPECT_THAT(root, AllOf(folded_dot, op::Shape("f32[2,2]{1,0}")));
}
72
// The dot operands are converts from integral parameters (s8, s16) to
// floating-point types (f16, f32). The pass is expected to report no change,
// and both converts must still feed the dot with their original shapes.
TEST_F(ConvertOperandFoldingTest, IntegralToFloatingConvertNotFolded) {
  const absl::string_view hlo_text = R"(
  HloModule module

  ENTRY main {
    p0 = s8[2,3]{1,0} parameter(0)
    p1 = s16[3,2]{0,1} parameter(1)
    c0 = f16[2,3]{1,0} convert(p0)
    c1 = f32[3,2]{0,1} convert(p1)
    ROOT dot = f32[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
                                          rhs_contracting_dims={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Run the pass under test and record whether it modified the module.
  ConvertOperandFolding pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(hlo_module.get()));
  EXPECT_FALSE(changed);

  // Each operand must still be the convert of its parameter.
  const auto lhs_convert =
      AllOf(op::Convert(op::Parameter(0)), op::Shape("f16[2,3]{1,0}"));
  const auto rhs_convert =
      AllOf(op::Convert(op::Parameter(1)), op::Shape("f32[3,2]{0,1}"));
  auto* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_convert, rhs_convert),
                          op::Shape("f32[2,2]{1,0}")));
}
98
// The dot operands are converts that narrow their inputs (s32 -> s16 and
// s16 -> s8). The pass is expected to report no change, and both converts
// must still feed the dot with their original shapes.
TEST_F(ConvertOperandFoldingTest, DowncastConvertNotFolded) {
  const absl::string_view hlo_text = R"(
  HloModule module

  ENTRY main {
    p0 = s32[2,3]{1,0} parameter(0)
    p1 = s16[3,2]{0,1} parameter(1)
    c0 = s16[2,3]{1,0} convert(p0)
    c1 = s8[3,2]{0,1} convert(p1)
    ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
                                          rhs_contracting_dims={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Run the pass under test and record whether it modified the module.
  ConvertOperandFolding pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(hlo_module.get()));
  EXPECT_FALSE(changed);

  // Each operand must still be the convert of its parameter.
  const auto lhs_convert =
      AllOf(op::Convert(op::Parameter(0)), op::Shape("s16[2,3]{1,0}"));
  const auto rhs_convert =
      AllOf(op::Convert(op::Parameter(1)), op::Shape("s8[3,2]{0,1}"));
  auto* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(lhs_convert, rhs_convert),
                          op::Shape("s16[2,2]{1,0}")));
}
124
// Mixed case: the lhs convert widens s8 to s16 while the rhs convert narrows
// s16 to s8. The pass is expected to report a change, with the lhs of the dot
// becoming parameter 0 and the rhs remaining the s8 convert of parameter 1.
TEST_F(ConvertOperandFoldingTest, OneOperandFolded) {
  const absl::string_view hlo_text = R"(
  HloModule module

  ENTRY main {
    p0 = s8[2,3]{1,0} parameter(0)
    p1 = s16[3,2]{0,1} parameter(1)
    c0 = s16[2,3]{1,0} convert(p0)
    c1 = s8[3,2]{0,1} convert(p1)
    ROOT dot = s16[2,2]{1,0} dot(c0, c1), lhs_contracting_dims={1},
                                          rhs_contracting_dims={0}
  })";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Run the pass under test and record whether it modified the module.
  ConvertOperandFolding pass;
  TF_ASSERT_OK_AND_ASSIGN(const bool changed, pass.Run(hlo_module.get()));
  EXPECT_TRUE(changed);

  // Only the rhs keeps its convert; the lhs reads the parameter directly.
  const auto rhs_convert =
      AllOf(op::Convert(op::Parameter(1)), op::Shape("s8[3,2]{0,1}"));
  auto* root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, AllOf(op::Dot(op::Parameter(0), rhs_convert),
                          op::Shape("s16[2,2]{1,0}")));
}
148
149 } // namespace
150 } // namespace xla
151