1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
30
31 namespace xla {
32 namespace {
33
34 class BroadcastTest : public HloTestBase {};
35
XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
  // Degenerate case: broadcasting a rank-0 value into a rank-0 shape must
  // leave the value untouched.
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      scalar_shape, scalar, /*broadcast_dimensions=*/{}));

  // Build the module around the computation, then compile and run it.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(builder.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  auto expected = LiteralUtil::CreateR0<float>(42.0);
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
52
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
  // A scalar broadcast with no mapped dimensions fills every element of the
  // 2x2 output with the scalar's value.
  HloComputation::Builder builder(TestName());
  HloInstruction* scalar = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  const Shape output_shape = ShapeUtil::MakeShape(F32, {2, 2});
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      output_shape, scalar, /*broadcast_dimensions=*/{}));

  // Build the module around the computation, then compile and run it.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(builder.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  auto expected = LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
69
XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
  // Broadcasts one 3-element vector in two ways: mapped onto output dimension
  // 0 (a 3x2 result) and onto output dimension 1 (a 2x3 result). Both results
  // are packed into a tuple so each can be checked on its own.
  HloComputation::Builder builder(TestName());
  HloInstruction* operand =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));

  HloInstruction* along_dim0 =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {3, 2}), operand, {0}));
  HloInstruction* along_dim1 =
      builder.AddInstruction(HloInstruction::CreateBroadcast(
          ShapeUtil::MakeShape(F32, {2, 3}), operand, {1}));
  builder.AddInstruction(HloInstruction::CreateTuple({along_dim0, along_dim1}));

  // Build the module around the computation, then compile and run it.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(builder.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  // Tuple element 0: row i repeats operand[i].
  auto expected_dim0 =
      LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected_dim0, LiteralSlice(actual, {0}),
                                    error_spec_));

  // Tuple element 1: every row is a full copy of the operand.
  auto expected_dim1 =
      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected_dim1, LiteralSlice(actual, {1}),
                                    error_spec_));
}
96
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
  // Broadcasting a 2x2 operand into a 2x2 shape with the identity dimension
  // mapping {0, 1} must reproduce the operand unchanged.
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  const Shape output_shape = ShapeUtil::MakeShape(F32, {2, 2});
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(output_shape, matrix, {0, 1}));

  // Build the module around the computation, then compile and run it.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(builder.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  auto expected = LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
113
XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
  // A same-rank broadcast with a permuted dimension mapping reorders the
  // dimensions. With {1, 0}, operand dimension 0 lands on output dimension 1
  // and vice versa, i.e. the result is the transpose of the operand.
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  const Shape output_shape = ShapeUtil::MakeShape(F32, {2, 2});
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(output_shape, matrix, {1, 0}));

  // Build the module around the computation, then compile and run it.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(builder.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  auto transposed = LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(transposed, actual, error_spec_));
}
132
XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
  // The 2x2 operand is mapped onto output dimensions 0 and 2. Output
  // dimension 1 (size 3) is new, so each operand row is repeated three times.
  HloComputation::Builder builder(TestName());
  HloInstruction* matrix =
      builder.AddInstruction(HloInstruction::CreateConstant(
          LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  const Shape output_shape = ShapeUtil::MakeShape(F32, {2, 3, 2});
  builder.AddInstruction(
      HloInstruction::CreateBroadcast(output_shape, matrix, {0, 2}));

  // Build the module around the computation, then compile and run it.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(builder.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  auto expected =
      LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
                                    {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
150
XLA_TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
  // Broadcasts a 2-element vector into output dimension 1 of a 2x2x3x3 array.
  auto builder = HloComputation::Builder(TestName());
  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0, 2.0})));

  // Map the operand onto output dimension 1; dimensions 0, 2 and 3 are new.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // Output element (p, z, y, x) equals input[z], independent of p, y and x.
  Array4D<float> expected(2, 2, 3, 3);
  Array2D<float> pz({{1, 2}, {1, 2}});
  expected.FillWithPZ(pz);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
172
XLA_TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
  auto builder = HloComputation::Builder(TestName());
  // Single source of truth for the vector length. The input data, the output
  // shape and the expected array are all sized from it.
  const int64 r1_size = 1025;
  std::vector<float> input_data(r1_size);
  std::iota(input_data.begin(), input_data.end(), 0.0f);  // 0, 1, ..., 1024
  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(input_data)));

  // Map the operand onto output dimension 3 (the last dimension).
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // Output element (p, z, y, x) equals input_data[x].
  Array4D<float> expected(3, 3, 3, r1_size);
  Array2D<float> yx(3, r1_size);
  for (int64 y = 0; y < 3; ++y) {
    for (int64 x = 0; x < r1_size; ++x) {
      yx(y, x) = input_data[x];
    }
  }
  expected.FillWithYX(yx);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
202
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
  auto builder = HloComputation::Builder(TestName());
  // Use distinct per-element values (0, 1, ..., 63) rather than one constant.
  // With a uniform fill, a backend that reads the operand at the wrong index
  // (wrong stride or element order) would still produce the expected output.
  std::vector<float> r1_array(64);
  std::iota(r1_array.begin(), r1_array.end(), 0.0f);

  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(r1_array)));

  // Broadcast vector in dimension 1.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // Output element (p, z, y, x) equals r1_array[z].
  Array2D<float> pz(32, 64);
  for (int64 p = 0; p < 32; ++p) {
    for (int64 z = 0; z < 64; ++z) {
      pz(p, z) = r1_array[z];
    }
  }
  Array4D<float> r4_array(32, 64, 7, 7);
  r4_array.FillWithPZ(pz);

  EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
                                    result, error_spec_));
}
224
XLA_TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
  // Broadcasts a scalar into every element of a 64x64x3x3 array.
  auto builder = HloComputation::Builder(TestName());
  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  // Dumping the module is a debugging aid. Keep it behind verbose logging so
  // normal runs do not print the whole module text at INFO level.
  VLOG(1) << hlo_module->ToString();
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  Array4D<float> expected(64, 64, 3, 3);
  expected.Fill(1.0f);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
244
XLA_TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
  auto builder = HloComputation::Builder(TestName());
  Array2D<float> to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(to_broadcast)));

  // Map the 2x2 matrix onto output dimensions 2 and 3; dimensions 0 and 1
  // are new.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // Output element (p, z, y, x) equals to_broadcast(y, x).
  Array4D<float> expected(3, 3, 2, 2);
  expected.FillWithYX(to_broadcast);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
266
XLA_TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
  auto builder = HloComputation::Builder(TestName());
  Array3D<float> input_vals(2, 3, 4);
  input_vals.FillRandom(1.0);

  // Output element (i, j, k, m) equals input_vals(i, j, k) for every m.
  Array4D<float> expected(2, 3, 4, 5);
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < 4; ++k) {
        for (int m = 0; m < 5; ++m) {
          expected(i, j, k, m) = input_vals(i, j, k);
        }
      }
    }
  }
  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR3FromArray3D<float>(input_vals)));

  // Map the operand's three dimensions onto output dimensions 0, 1 and 2;
  // output dimension 3 (size 5) is the newly added one.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
297
298 } // namespace
299 } // namespace xla
300