1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/map_inliner.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "absl/memory/memory.h"
22 #include "tensorflow/compiler/xla/literal.h"
23 #include "tensorflow/compiler/xla/service/hlo_computation.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/test.h"
29 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
30 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
31 #include "tensorflow/compiler/xla/xla_data.pb.h"
32
33 namespace op = xla::testing::opcode_matchers;
34
35 namespace xla {
36 namespace {
37
38 using MapInlinerTest = HloTestBase;
39
40 // Test that `map` with `max` is transformed to `max`
TEST_F(MapInlinerTest,MapMax)41 TEST_F(MapInlinerTest, MapMax) {
42 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
43
44 auto max_builder = HloComputation::Builder(TestName());
45 auto param1 = max_builder.AddInstruction(
46 HloInstruction::CreateParameter(0, r0f32, "x"));
47 auto param2 = max_builder.AddInstruction(
48 HloInstruction::CreateParameter(1, r0f32, "y"));
49 max_builder.AddInstruction(HloInstruction::CreateBinary(
50 param1->shape(), HloOpcode::kMaximum, param1, param2));
51 auto max_f32 = max_builder.Build();
52
53 auto builder = HloComputation::Builder("MapMaxFunction");
54 auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
55 LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
56 auto rhs = builder.AddInstruction(HloInstruction::CreateConstant(
57 LiteralUtil::CreateR1<float>({4, 3, 2, 1})));
58 builder.AddInstruction(
59 HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get()));
60
61 auto computation = builder.Build();
62 auto hlo_module = CreateNewVerifiedModule();
63 hlo_module->AddEmbeddedComputation(std::move(max_f32));
64 hlo_module->AddEntryComputation(std::move(computation));
65
66 MapInliner inliner;
67 EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
68 EXPECT_THAT(hlo_module->entry_computation()->root_instruction(),
69 op::Maximum(lhs, rhs));
70
71 // Verify execution on CPU.
72 auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
73 auto expected = LiteralUtil::CreateR1<float>({4, 3, 3, 4});
74 EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
75 }
76
77 // Test that `constant` function is changed to `broadcast`.
TEST_F(MapInlinerTest,MapConstant)78 TEST_F(MapInlinerTest, MapConstant) {
79 Shape r0f32 = ShapeUtil::MakeShape(F32, {});
80
81 auto const2_builder = HloComputation::Builder(TestName());
82 auto param1 = const2_builder.AddInstruction(
83 HloInstruction::CreateParameter(0, r0f32, "x"));
84 (void)param1;
85 const2_builder.AddInstruction(
86 HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(2.0f)));
87 auto const2_f32 = const2_builder.Build();
88
89 auto builder = HloComputation::Builder("MapConstFunction");
90 auto lhs = builder.AddInstruction(HloInstruction::CreateConstant(
91 LiteralUtil::CreateR2<float>({{1, 2, 3, 4}, {5, 6, 7, 8}})));
92 builder.AddInstruction(
93 HloInstruction::CreateMap(lhs->shape(), {lhs}, const2_f32.get()));
94
95 auto computation = builder.Build();
96 auto hlo_module = CreateNewVerifiedModule();
97 hlo_module->AddEmbeddedComputation(std::move(const2_f32));
98 hlo_module->AddEntryComputation(std::move(computation));
99 HloInstruction* root = hlo_module->entry_computation()->root_instruction();
100 MapInliner inliner;
101 EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie());
102 root = hlo_module->entry_computation()->root_instruction();
103 EXPECT_THAT(root, op::Broadcast(op::Constant()));
104
105 // Verify execution on CPU.
106 auto result = ExecuteAndTransfer(hlo_module->Clone(), {});
107 auto expected = LiteralUtil::CreateR2<float>({{2, 2, 2, 2}, {2, 2, 2, 2}});
108 EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
109 }
110
// Checks that operands are matched to parameters by parameter number, not by
// the order in which the parameters were added: the mapped computation
// computes parameter(1) - parameter(0), so the inlined root is expected to be
// subtract(rhs, lhs).
TEST_F(MapInlinerTest, MapSubtractOppositeOrder) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // The first parameter added has ordinal 1 and the second has ordinal 0,
  // i.e. the ordinals are reversed relative to the operand positions.
  HloComputation::Builder sub_builder(TestName());
  HloInstruction* minuend = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "x"));
  HloInstruction* subtrahend = sub_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "y"));
  sub_builder.AddInstruction(HloInstruction::CreateBinary(
      minuend->shape(), HloOpcode::kSubtract, minuend, subtrahend));
  std::unique_ptr<HloComputation> sub_computation = sub_builder.Build();

  // Entry computation: map(lhs, rhs) over two rank-1 constants.
  HloComputation::Builder entry_builder("MapSubFunction");
  HloInstruction* lhs =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1, 2, 3, 4})));
  HloInstruction* rhs =
      entry_builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({4, 3, 2, 1})));
  entry_builder.AddInstruction(HloInstruction::CreateMap(
      lhs->shape(), {lhs, rhs}, sub_computation.get()));
  std::unique_ptr<HloComputation> entry_computation = entry_builder.Build();

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEmbeddedComputation(std::move(sub_computation));
  hlo_module->AddEntryComputation(std::move(entry_computation));

  MapInliner inliner;
  const bool changed = inliner.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  HloInstruction* new_root = hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(new_root, op::Subtract(rhs, lhs));

  // Execute the rewritten module; rhs - lhs = {4-1, 3-2, 2-3, 1-4}.
  Literal result = ExecuteAndTransfer(hlo_module->Clone(), {});
  Literal expected = LiteralUtil::CreateR1<float>({3, 1, -1, -3});
  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
148
// Checks that a `map` whose mapped computation contains only parameter
// instructions is replaced by one of the map's own operands: after the pass
// the entry root is expected to be `rhs`, the operand in position 1.
TEST_F(MapInlinerTest, MapParameter) {
  const Shape scalar_f32 = ShapeUtil::MakeShape(F32, {});

  // Mapped computation made of two parameters and nothing else.
  // NOTE(review): presumably the last instruction added ("p1") becomes the
  // computation's root, which is why `rhs` is the expected result -- confirm
  // against HloComputation::Builder::Build.
  HloComputation::Builder param_builder(TestName());
  param_builder.AddInstruction(
      HloInstruction::CreateParameter(0, scalar_f32, "p0"));
  param_builder.AddInstruction(
      HloInstruction::CreateParameter(1, scalar_f32, "p1"));
  std::unique_ptr<HloComputation> param_computation = param_builder.Build();

  // Entry computation: map(lhs, rhs) over two scalar constants.
  HloComputation::Builder entry_builder("MapParamFunction");
  HloInstruction* lhs = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1)));
  HloInstruction* rhs = entry_builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(4)));
  entry_builder.AddInstruction(HloInstruction::CreateMap(
      lhs->shape(), {lhs, rhs}, param_computation.get()));
  std::unique_ptr<HloComputation> entry_computation = entry_builder.Build();

  auto hlo_module = CreateNewVerifiedModule();
  hlo_module->AddEmbeddedComputation(std::move(param_computation));
  hlo_module->AddEntryComputation(std::move(entry_computation));

  MapInliner inliner;
  const bool changed = inliner.Run(hlo_module.get()).ValueOrDie();
  EXPECT_TRUE(changed);
  EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), rhs);

  // Execute the rewritten module and compare against the expected value.
  Literal result = ExecuteAndTransfer(hlo_module->Clone(), {});
  Literal expected = LiteralUtil::CreateR0<float>(4);
  EXPECT_TRUE(LiteralTestUtil::Equal(result, expected));
}
179
180 } // namespace
181 } // namespace xla
182