1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <memory>
#include <numeric>
#include <utility>
#include <vector>

#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/test.h"
30 
31 namespace xla {
32 namespace {
33 
34 class BroadcastTest : public HloTestBase {};
35 
XLA_TEST_F(BroadcastTest, BroadcastScalarToScalar) {
  // Degenerate case: a rank-0 operand broadcast into a rank-0 result with no
  // broadcast dimensions must reproduce the operand unchanged.
  HloComputation::Builder b(TestName());
  HloInstruction* scalar = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  b.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {}), scalar, /*broadcast_dimensions=*/{}));

  // Wrap the computation in a module, then compile, run and fetch the output.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(b.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  auto expected = LiteralUtil::CreateR0<float>(42.0);
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
52 
XLA_TEST_F(BroadcastTest, BroadcastScalarTo2D) {
  // A scalar broadcast with an empty dimension list fills every element of
  // the 2x2 result with the scalar's value.
  HloComputation::Builder b(TestName());
  HloInstruction* scalar = b.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(42.0)));
  const Shape out_shape = ShapeUtil::MakeShape(F32, {2, 2});
  b.AddInstruction(HloInstruction::CreateBroadcast(
      out_shape, scalar, /*broadcast_dimensions=*/{}));

  // Wrap the computation in a module, then compile, run and fetch the output.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(b.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  auto expected = LiteralUtil::CreateR2<float>({{42.0, 42.0}, {42.0, 42.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
69 
XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) {
  HloComputation::Builder b(TestName());
  HloInstruction* vec = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR1<float>({1.0, 2.0, 3.0})));

  // The same 3-element vector is broadcast twice:
  //   * mapped onto dimension 0 of a [3,2] result (out(i, j) == vec(i)), and
  //   * mapped onto dimension 1 of a [2,3] result (out(i, j) == vec(j)).
  // Both results are returned as a tuple so one execution checks them both.
  HloInstruction* along_dim0 = b.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 2}), vec, {0}));
  HloInstruction* along_dim1 = b.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 3}), vec, {1}));
  b.AddInstruction(HloInstruction::CreateTuple({along_dim0, along_dim1}));

  // Wrap the computation in a module, then compile, run and fetch the output.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(b.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  // Tuple element 0: each row repeats one vector element.
  auto expected_dim0 =
      LiteralUtil::CreateR2<float>({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected_dim0, LiteralSlice(actual, {0}),
                                    error_spec_));

  // Tuple element 1: each row is a full copy of the vector.
  auto expected_dim1 =
      LiteralUtil::CreateR2<float>({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected_dim1, LiteralSlice(actual, {1}),
                                    error_spec_));
}
96 
XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) {
  // Broadcasting a matrix into a same-shaped result with the identity
  // dimension mapping {0, 1} must return the operand unchanged.
  auto make_matrix = [] {
    return LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}});
  };

  HloComputation::Builder b(TestName());
  HloInstruction* matrix =
      b.AddInstruction(HloInstruction::CreateConstant(make_matrix()));
  b.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 2}), matrix, {0, 1}));

  // Wrap the computation in a module, then compile, run and fetch the output.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(b.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  EXPECT_TRUE(LiteralTestUtil::Near(make_matrix(), actual, error_spec_));
}
113 
XLA_TEST_F(BroadcastTest, Broadcast2DTo2DTranspose) {
  // Broadcasting into a result of the same rank with the permuted dimension
  // list {1, 0} maps operand dimension 0 to output dimension 1 and vice versa,
  // i.e. it behaves as a transpose: out(i, j) == in(j, i).
  HloComputation::Builder b(TestName());
  HloInstruction* matrix = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  const Shape out_shape = ShapeUtil::MakeShape(F32, {2, 2});
  b.AddInstruction(HloInstruction::CreateBroadcast(out_shape, matrix, {1, 0}));

  // Wrap the computation in a module, then compile, run and fetch the output.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(b.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  auto transposed = LiteralUtil::CreateR2<float>({{1.0, 3.0}, {2.0, 4.0}});
  EXPECT_TRUE(LiteralTestUtil::Near(transposed, actual, error_spec_));
}
132 
XLA_TEST_F(BroadcastTest, Broadcast2DTo3D) {
  // The 2x2 operand is mapped onto output dimensions 0 and 2 of a [2,3,2]
  // result, so out(i, j, k) == in(i, k) for every j: each operand row is
  // repeated three times along the new middle dimension.
  HloComputation::Builder b(TestName());
  HloInstruction* matrix = b.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2<float>({{1.0, 2.0}, {3.0, 4.0}})));
  const Shape out_shape = ShapeUtil::MakeShape(F32, {2, 3, 2});
  b.AddInstruction(HloInstruction::CreateBroadcast(out_shape, matrix, {0, 2}));

  // Wrap the computation in a module, then compile, run and fetch the output.
  auto module = CreateNewUnverifiedModule();
  module->AddEntryComputation(b.Build());
  auto actual = ExecuteAndTransfer(std::move(module), {});

  auto expected =
      LiteralUtil::CreateR3<float>({{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}},
                                    {{3.0, 4.0}, {3.0, 4.0}, {3.0, 4.0}}});
  EXPECT_TRUE(LiteralTestUtil::Near(expected, actual, error_spec_));
}
150 
// Uses XLA_TEST_F (like the rest of this file) rather than plain TEST_F so the
// test can be disabled per backend through the XLA test manifest.
XLA_TEST_F(BroadcastTest, Broadcast_R1_2_To_R4_2x2x3x3) {
  auto builder = HloComputation::Builder(TestName());
  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>({1.0, 2.0})));

  // Broadcast the vector into dimension 1: out(p, z, y, x) == input(z).
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 2, 3, 3}), input, {1}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // FillWithPZ replicates the (p, z) slice across the two minor dimensions;
  // every row of `pz` is the input vector because the value depends only on z.
  Array4D<float> expected(2, 2, 3, 3);
  Array2D<float> pz({{1, 2}, {1, 2}});
  expected.FillWithPZ(pz);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
172 
// Uses XLA_TEST_F (like the rest of this file) rather than plain TEST_F so the
// test can be disabled per backend through the XLA test manifest.
XLA_TEST_F(BroadcastTest, Broadcast_R1_1025_To_R4_3x3x3x1025) {
  auto builder = HloComputation::Builder(TestName());
  // Input holds 0, 1, ..., 1024 so every position along the broadcast
  // dimension carries a distinct value.
  std::vector<float> input_data(1025);
  const int64 r1_size = static_cast<int64>(input_data.size());
  std::iota(input_data.begin(), input_data.end(), 0.0f);
  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(input_data)));

  // Broadcast the vector into dimension 3 (the last output dimension).
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 3, 3, r1_size}), input, {3}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // expected(p, z, y, x) == input_data[x]. The last extent is derived from
  // r1_size (not a second hard-coded 1025) so it cannot drift from the input.
  Array4D<float> expected(3, 3, 3, r1_size);
  Array2D<float> yx(3, r1_size);
  for (int64 y = 0; y < 3; ++y) {
    for (int64 x = 0; x < r1_size; ++x) {
      yx(y, x) = input_data[x];
    }
  }
  expected.FillWithYX(yx);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
202 
XLA_TEST_F(BroadcastTest, Broadcast_R1_64_To_R4_32x64x7x7) {
  auto builder = HloComputation::Builder(TestName());
  // Give every operand element a distinct value (0, 1, ..., 63). With a
  // uniform fill the expected result is identical no matter which output
  // dimension the operand is mapped to, so the test could not detect a
  // broadcast along the wrong dimension.
  std::vector<float> r1_array(64);
  std::iota(r1_array.begin(), r1_array.end(), 0.0f);

  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR1<float>(r1_array)));

  // Broadcast the vector into dimension 1: out(p, z, y, x) == r1_array[z].
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {32, 64, 7, 7}), input, {1}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // Build the (p, z) slice and replicate it across the two minor dimensions.
  Array2D<float> pz(32, 64);
  for (int64 p = 0; p < 32; ++p) {
    for (int64 z = 0; z < 64; ++z) {
      pz(p, z) = r1_array[z];
    }
  }
  Array4D<float> r4_array(32, 64, 7, 7);
  r4_array.FillWithPZ(pz);

  EXPECT_TRUE(LiteralTestUtil::Near(LiteralUtil::CreateR4FromArray4D(r4_array),
                                    result, error_spec_));
}
224 
// Uses XLA_TEST_F (like the rest of this file) rather than plain TEST_F so the
// test can be disabled per backend through the XLA test manifest.
XLA_TEST_F(BroadcastTest, Broadcast_R0_to_R4_64x64x3x3) {
  auto builder = HloComputation::Builder(TestName());
  auto input = builder.AddInstruction(
      HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0f)));
  // A scalar broadcast with no broadcast dimensions fills the whole R4 result.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {64, 64, 3, 3}), input, {}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  // Module dump is debugging aid only; keep it behind verbose logging instead
  // of unconditionally printing it on every run.
  VLOG(1) << hlo_module->ToString();
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  Array4D<float> expected(64, 64, 3, 3);
  expected.Fill(1.0f);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
244 
// Uses XLA_TEST_F (like the rest of this file) rather than plain TEST_F so the
// test can be disabled per backend through the XLA test manifest.
XLA_TEST_F(BroadcastTest, Broadcast_R2_2x2_To_R4_3x3x2x2) {
  auto builder = HloComputation::Builder(TestName());
  Array2D<float> to_broadcast({{1.0f, 2.0f}, {3.0f, 4.0f}});
  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR2FromArray2D<float>(to_broadcast)));

  // Broadcast the 2x2 matrix into dimensions 2 and 3:
  // out(p, z, y, x) == to_broadcast(y, x).
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {3, 3, 2, 2}), input, {2, 3}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  // FillWithYX replicates the matrix across the two major dimensions.
  Array4D<float> expected(3, 3, 2, 2);
  expected.FillWithYX(to_broadcast);

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
266 
// Uses XLA_TEST_F (like the rest of this file) rather than plain TEST_F so the
// test can be disabled per backend through the XLA test manifest.
XLA_TEST_F(BroadcastTest, Broadcast_R3_2x3x4_to_R4_2x3x4x5) {
  auto builder = HloComputation::Builder(TestName());
  Array3D<float> input_vals(2, 3, 4);
  input_vals.FillRandom(1.0);

  // expected(i, j, k, m) == input_vals(i, j, k): the operand is repeated
  // along the new last dimension of extent 5.
  Array4D<float> expected(2, 3, 4, 5);
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 3; ++j) {
      for (int k = 0; k < 4; ++k) {
        for (int m = 0; m < 5; ++m) {
          expected(i, j, k, m) = input_vals(i, j, k);
        }
      }
    }
  }
  auto input = builder.AddInstruction(HloInstruction::CreateConstant(
      LiteralUtil::CreateR3FromArray3D<float>(input_vals)));

  // Map the R3 operand onto output dimensions 0, 1 and 2; dimension 3 is the
  // newly added broadcast dimension.
  builder.AddInstruction(HloInstruction::CreateBroadcast(
      ShapeUtil::MakeShape(F32, {2, 3, 4, 5}), input, {0, 1, 2}));

  // Create HLO module, compile, and execute.
  auto hlo_module = CreateNewUnverifiedModule();
  hlo_module->AddEntryComputation(builder.Build());
  auto result = ExecuteAndTransfer(std::move(hlo_module), {});

  EXPECT_TRUE(LiteralTestUtil::Near(
      LiteralUtil::CreateR4FromArray4D<float>(expected), result, error_spec_));
}
297 
298 }  // namespace
299 }  // namespace xla
300