/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <memory>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

class BatchNormalizationTest
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<bool /*use_cudnn_batchnorm*/> {
 protected:
  BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) {
    mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(GetParam());

    Array2D<float> pz({
        // z0 z1
        {-1.0f, 4.1f},  // p0
        {2.0f, 4.1f},   // p1
        {5.0f, 4.4f},   // p2
    });
    input_array_.FillWithPZ(pz);
    input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
    CHECK_EQ(kSamples, input_array_.planes());
    CHECK_EQ(kZ, input_array_.depth());
    CHECK_EQ(kY, input_array_.height());
    CHECK_EQ(kX, input_array_.width());
  }

  XlaOp CheckShape(XlaBuilder* b, const XlaOp& operand,
                   const Shape& expected_shape) const {
    Shape actual_shape = b->GetShape(operand).ConsumeValueOrDie();
    CHECK(ShapeUtil::Equal(expected_shape, actual_shape))
        << "want " << ShapeUtil::HumanString(expected_shape) << " got "
        << ShapeUtil::HumanString(actual_shape);
    return operand;
  }

  static constexpr int64 kSamples = 3;
  static constexpr int64 kX = 1;
  static constexpr int64 kY = 1;
  static constexpr int64 kZ = 2;

  Array4D<float> input_array_;
  Literal input_literal_;
  const ErrorSpec error_spec_{0.001, 0.001};
};

// If testing the GPU backend, run the tests twice, with and without cudnn
// batchnorm. Otherwise, just run the tests once -- the value of this flag
// doesn't matter.
#ifdef XLA_TEST_BACKEND_GPU
INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
                        ::testing::Bool());
#else
INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
                        ::testing::Values(false));
#endif

XLA_TEST_P(BatchNormalizationTest, SubtractInZ) {
  XlaBuilder builder("subtract_in_z_one_sample");
  auto x = ConstantLiteral(&builder, input_literal_);
  auto y = ConstantR1<float>(&builder, {3.14, 4.25});
  Sub(x, y, /*broadcast_dimensions=*/{1});

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-1.0f - 3.14f, 4.1f - 4.25f},  // p0
      {2.0f - 3.14f, 4.1f - 4.25f},   // p1
      {5.0f - 3.14f, 4.4f - 4.25f},   // p2
  });
  expected.FillWithPZ(pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) {
  XlaBuilder builder("square_tesseract_elementwise");
  auto x = ConstantLiteral(&builder, input_literal_);
  Square(x);

  using tensorflow::MathUtil;

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> expected_pz({
      {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)},
  });
  expected.FillWithPZ(expected_pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, SumToZ) {
  XlaBuilder builder("sum_to_z");
  auto input_activations = ConstantLiteral(&builder, input_literal_);
  XlaComputation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all but the Z dimension.
  Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3});

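  // Expected per-feature sums (derivation, for reference):
  // z0: -1 + 2 + 5 = 6; z1: 4.1 + 4.1 + 4.4 = 12.6.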
  std::vector<float> expected = {6, 12.6};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
  XlaBuilder builder("square_and_reduce");
  auto input_activations = ConstantLiteral(&builder, input_literal_);
  auto set_means = ConstantR1<float>(&builder, {2.f, 4.2f});
  auto activation_deviations = Sub(input_activations, set_means,
                                   /*broadcast_dimensions=*/{1});
  XlaComputation add = CreateScalarAddComputation(F32, &builder);
  auto dev_squares = Square(activation_deviations);
  Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3});

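  // Expected sums of squared deviations (derivation, for reference):
  // z0: (-3)^2 + 0^2 + 3^2 = 18; z1: (-0.1)^2 + (-0.1)^2 + 0.2^2 = 0.06.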
  std::vector<float> expected = {18, 0.06};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
  XlaBuilder builder("variance_to_stddev");
  auto variance = ConstantR1<float>(&builder, {6.f, .02f});
  Sqrt(variance);

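  // sqrt(6) ~= 2.44948974 and sqrt(0.02) ~= 0.14142136.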
  std::vector<float> expected = {2.44948974f, 0.14142136f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

// Compare against a forward batch normalization example in the NN spec
// reference.
XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
  XlaBuilder builder("batch_normalize_per_spec");
  auto input_activations =
      CheckShape(&builder, ConstantLiteral(&builder, input_literal_),
                 ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
  auto gamma = ConstantR1<float>(&builder, {1.0, 1.0});
  auto beta = ConstantR1<float>(&builder, {0.0, 0.0});
  XlaComputation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all dimensions except dimension 1.
  Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
  auto sum = CheckShape(
      &builder,
      Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
             /*dimensions_to_reduce=*/{0, 2, 3}),
      TwoElementVectorF32);
  auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie();
  auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie();
  auto count =
      ConstantR0<float>(&builder, ShapeUtil::ElementsIn(input_shape) /
                                      ShapeUtil::ElementsIn(sum_shape));
  auto set_means = Div(sum, count);

  const float kEpsilon = 1e-9f;
  auto epsilon = ConstantR0<float>(&builder, kEpsilon);
  auto epsilon2 = ConstantR1<float>(&builder, {kEpsilon, kEpsilon});
  auto activation_deviations = Sub(input_activations, set_means,
                                   /*broadcast_dimensions=*/{1});
  auto dev_squares = Square(activation_deviations);
  auto sum_of_squares =
      CheckShape(&builder,
                 Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add,
                        /*dimensions_to_reduce=*/{0, 2, 3}),
                 TwoElementVectorF32);
  auto variance = Div(sum_of_squares, count);
  auto standard_deviation = Sqrt(variance);
  auto standard_deviation_above_epsilon =
      CheckShape(&builder, Gt(standard_deviation, epsilon),
                 ShapeUtil::MakeShape(PRED, {2}));
  auto gt_eps =
      Select(standard_deviation_above_epsilon, standard_deviation, epsilon2);
  auto normalization_factors = Reciprocal(gt_eps);
  auto normalized_input_activations =
      Mul(activation_deviations, normalization_factors,
          /*broadcast_dimensions=*/{1});
  /* auto output_activations = */ Add(Mul(normalized_input_activations, gamma,
                                          /*broadcast_dimensions=*/{1}),
                                      beta, /*broadcast_dimensions=*/{1});

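  // With gamma = 1 and beta = 0 the output is just the normalized input.
  // Per-feature statistics (derivation, for reference): mean = {2, 4.2} and
  // variance = {18, 0.06} / 3 = {6, 0.02}, so each expected entry below is
  // (x - mean) / sqrt(variance).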
  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
      {0.f, -.1f / std::sqrt(.02f)},
      {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
  });
  expected.FillWithPZ(pz);

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
  const int kFeatureIndex = 3;
  XlaBuilder builder(TestName());

  auto operand = ConstantR4FromArray4D<float>(
      &builder, {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});

  auto scale = ConstantR1<float>(&builder, {2.0f, 3.0f});

  auto offset = ConstantR1<float>(&builder, {1.0f, 2.0f});

  BatchNormTraining(operand, scale, offset,
                    /*epsilon=*/0.001, kFeatureIndex);

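  // Derivation, for reference: with the feature on dimension 3, feature 0
  // sees {1, 3, 5, 7} (mean 4, variance 5) and feature 1 sees {2, 4, 6, 8}
  // (mean 5, variance 5). Each normalized value is
  // scale * (x - mean) / sqrt(variance + epsilon) + offset, e.g.
  // 2 * (1 - 4) / sqrt(5.001) + 1 ~= -1.68, within ErrorSpec(0.1) of the
  // -1.6 below.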
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
                                     {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
       LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  auto operand = ConstantR4FromArray4D<float>(
      &builder,
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  auto scale = ConstantR1<float>(&builder, {2.0f, 3.0f});

  auto offset = ConstantR1<float>(&builder, {1.0f, 2.0f});

  BatchNormTraining(operand, scale, offset,
                    /*epsilon=*/0.001, kFeatureIndex);

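  // Same data and statistics as BasicTraining (mean {4, 5}, variance {5, 5}),
  // just with the feature on dimension 2.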
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
                                     {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
       LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
  // Use 0 dimension as feature, tests layout analyzer.
  const int kFeatureIndex = 0;
  XlaBuilder builder(TestName());

  XlaOp h0;
  auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f),
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  XlaOp h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  XlaOp h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  BatchNormTraining(h0, h1, h2,
                    /*epsilon=*/1, kFeatureIndex);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});

  ComputeAndCompareTuple(&builder, expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
  // Test the correctness of choosing a large epsilon value.
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  XlaOp h0;
  auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}},
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  XlaOp h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(1, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  XlaOp h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(1, 0.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  // var = 125, mean = 15, epsilon = -100
  BatchNormTraining(h0, h1, h2,
                    /*epsilon=*/-100, kFeatureIndex);

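  // Derivation, for reference: var + epsilon = 125 - 100 = 25, so the
  // effective standard deviation is 5 and each output is (x - 15) / 5.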
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR3FromArray3D<float>(
           {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});

  ComputeAndCompareTuple(&builder, expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  auto operand =
      ConstantR4FromArray4D<float>(&builder, Array4D<float>(2, 2, 2, 1, 0.0f));

  auto scale = ConstantR1<float>(&builder, {1.0f, 1.0f});

  auto mean = ConstantR1<float>(&builder, {0.0f, 0.0f});

  auto var = ConstantR1<float>(&builder, {1.0f, 1.0f});

  auto grad_output = ConstantR4FromArray4D<float>(
      &builder,
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  BatchNormGrad(operand, scale, mean, var, grad_output,
                /*epsilon=*/0.0, kFeatureIndex);

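  // Derivation, for reference: every activation equals the mean, so
  // grad_scale = sum(grad_output * (x - mean) * rsqrt(var + eps)) = 0 and
  // grad_offset is the per-feature sum of grad_output:
  // {1 + 3 + 5 + 7, 2 + 4 + 6 + 8} = {16, 20}. With scale = 1 and
  // var + eps = 1, grad_activation reduces to grad_output minus its
  // per-feature mean ({4, 5}).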
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
                                     {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
       LiteralUtil::CreateR1<float>({0, 0}),
       LiteralUtil::CreateR1<float>({16, 20})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

struct BatchNormTestParam {
  std::vector<int64> bounds;
  int64 feature_index;
  float random_value_mean;
  float random_value_var;
  bool use_cudnn_batchnorm;

  friend ::std::ostream& operator<<(::std::ostream& os,
                                    const BatchNormTestParam& p) {
    os << "bounds={" << absl::StrJoin(p.bounds, ", ") << "}, ";
    os << "feature_index=" << p.feature_index << ", ";
    os << "random_value_mean=" << p.random_value_mean << ", ";
    os << "random_value_var=" << p.random_value_var;

    // Don't print use_cudnn_batchnorm when it's false, because most backends
    // never set it to true.
    if (p.use_cudnn_batchnorm) {
      os << ", use_cudnn_batchnorm=true";
    }
    return os;
  }
};

// Tests for the fused BatchNorm operation.
class BatchNormTestManySizes
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<BatchNormTestParam> {
 public:
  BatchNormTestManySizes() {
    mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(
        GetParam().use_cudnn_batchnorm);
  }
};

std::vector<BatchNormTestParam> BuildBatchNormTestParams() {
  std::vector<BatchNormTestParam> params;

  auto add_testcase = [&](std::vector<int64> bounds, int64 feature_index,
                          float random_value_mean, float random_value_var) {
    BatchNormTestParam p{bounds, feature_index, random_value_mean,
                         random_value_var, /*use_cudnn_batchnorm=*/false};
    params.push_back(p);

    // If testing the GPU backend, also run with cudnn batchnorm enabled.
#ifdef XLA_TEST_BACKEND_GPU
    p.use_cudnn_batchnorm = true;
    params.push_back(p);
#endif
  };

  add_testcase({2, 2, 2, 2}, 0, 100.2f, 200.0f);
  add_testcase({2, 2, 2, 2}, 3, 300.f, 400.0f);

  add_testcase({1, 10, 1, 1}, 0, 10.1f, 20.1f);
  add_testcase({10, 10, 10, 10}, 1, 3.14f, 314.15f);
  add_testcase({10, 10, 10, 10}, 2, 666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 1, -666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 130}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 130, 11}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 1}, 3, 888.8f, 9.9f);

  add_testcase({24, 129, 1, 2}, 2, 10000, 10000);
  add_testcase({24, 129, 1, 2}, 3, 10000, 10000);

  // Feature on low dimension to trigger relayout, check that internal logical
  // to physical dimension calculation is correct after relayout.
  add_testcase({1, 2, 3, 4}, 0, 100, 100);

  // Zero-sized tensor.
  add_testcase({1, 0, 100, 42}, 0, 100, 100);

  return params;
}

INSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes,
                        ::testing::ValuesIn(BuildBatchNormTestParams()));

XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
  float epsilon = 0.001;
  XlaBuilder builder(TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

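  // Reference variance via the moment identity Var[x] = E[x^2] - (E[x])^2.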
  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto expected_normalized =
      LiteralUtil::CreateR4FromArray4D<float>(normalized);

  auto offset_literal = LiteralUtil::CreateR1<float>(offset);
  auto scale_literal = LiteralUtil::CreateR1<float>(scale);
  auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);

  auto input_activations =
      Parameter(&builder, 0, input_literal.shape(), "input");
  auto scale_activations =
      Parameter(&builder, 1, scale_literal.shape(), "scale");
  auto offset_activations =
      Parameter(&builder, 2, offset_literal.shape(), "offset");

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {expected_normalized, LiteralUtil::CreateR1<float>(mean),
       LiteralUtil::CreateR1<float>(var)});

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(offset_literal).ConsumeValueOrDie();

  BatchNormTraining(input_activations, scale_activations, offset_activations,
                    epsilon, feature_index);

  // Run all HLO passes during this test. In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized
  // tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
  ComputeAndCompareTuple(
      &builder, expected,
      {input_data.get(), scale_data.get(), offset_data.get()},
      ErrorSpec(0.01, 1));
}


XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
  float epsilon = 0.001;
  XlaBuilder builder(TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto offset_literal = LiteralUtil::CreateR1<float>(offset);
  auto scale_literal = LiteralUtil::CreateR1<float>(scale);
  auto mean_literal = LiteralUtil::CreateR1<float>(mean);
  auto var_literal = LiteralUtil::CreateR1<float>(var);
  auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);

  auto input_activations =
      Parameter(&builder, 0, input_literal.shape(), "input");
  auto scale_activations =
      Parameter(&builder, 1, scale_literal.shape(), "scale");
  auto offset_activations =
      Parameter(&builder, 2, offset_literal.shape(), "offset");
  auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
  auto variance_activations =
      Parameter(&builder, 4, var_literal.shape(), "variance");

  Array4D<float> expected = normalized;

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(offset_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> variance_data =
      client_->TransferToServer(var_literal).ConsumeValueOrDie();

  BatchNormInference(input_activations, scale_activations, offset_activations,
                     mean_activations, variance_activations, epsilon,
                     feature_index);

  // Run all HLO passes during this test. In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized
  // tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareR4<float>(
      &builder, expected,
      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
       variance_data.get()},
      ErrorSpec(0.01, 1));
}

XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
  float epsilon = 0.001;
  XlaBuilder builder(TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  grad_output_array.FillRandom(GetParam().random_value_var,
                               GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      mean[i] = sum[i] / num_elements_per_feature;
    } else {
      mean[i] = 0;
    }
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      square_mean[i] = sum_squared[i] / num_elements_per_feature;
    } else {
      square_mean[i] = 0;
    }
  }

  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);

  auto var_add_epsilon = *ReferenceUtil::MapArray4D(
      var4D, [epsilon](float a) { return a + epsilon; });

  auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      var_add_epsilon, [](float a) { return 1 / std::sqrt(a); });

  auto grad_output_times_var =
      *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
                                 [](float a, float b) { return a * b; });

  auto activation_shifted = *ReferenceUtil::MapArray4D(
      input_array, mean4D, [](float a, float b) { return a - b; });

  auto activation_shifted_times_grad_output =
      *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted,
                                 [](float a, float b) { return a * b; });

  auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D(
      activation_shifted_times_grad_output, rsqrt_var_add_epsilon,
      [](float a, float b) { return a * b; });

  auto grad_scale = ReferenceUtil::Reduce4DTo1D(
      grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
      [](float a, float b) { return a + b; });

  auto grad_offset =
      ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; });

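  // Reference gradient w.r.t. the activations, in the closed form
  //   grad_activation = scale * rsqrt(var + eps) / N *
  //       (N * grad_output - sum(grad_output)
  //        - (x - mean) * sum(grad_output * (x - mean)) / (var + eps)),
  // assembled below from the intermediate terms I1..I5.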
  auto I1 = *ReferenceUtil::MapArray4D(
      grad_output_array, [&](float a) { return num_elements_per_feature * a; });

  auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index);

  // I3 = sum(output_grad * (activation - mean(activation)))
  auto I3 = *ReferenceUtil::Broadcast1DTo4D(
      ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output,
                                  /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; }),
      bounds, feature_index);

  // I4 = (activation - mean(activation)) *
  //      sum(output_grad * (activation - mean(activation)))
  auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted,
                                       [](float a, float b) { return a * b; });

  // I5 = (activation - mean(activation)) *
  //      sum(output_grad * (activation - mean(activation))) /
  //      (variance + epsilon)
  auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon,
                                       [](float a, float b) { return a / b; });

  auto grad_activation = *ReferenceUtil::MapArray4D(
      I1, I2, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, I5, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, scale4D, [](float a, float b) { return a * b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, rsqrt_var_add_epsilon, [=](float a, float b) {
        if (num_elements_per_feature > 0) {
          return a * b / num_elements_per_feature;
        }
        return 0.f;
      });

  auto expected_grad_activation =
      LiteralUtil::CreateR4FromArray4D<float>(grad_activation);

  auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
  auto scale_literal = LiteralUtil::CreateR1<float>(scale);
  auto mean_literal = LiteralUtil::CreateR1<float>(mean);
  auto var_literal = LiteralUtil::CreateR1<float>(var);
  auto grad_output_literal =
      LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);

  auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
  auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
  auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
  auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
  auto grad_output_parameter =
      Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> var_data =
      client_->TransferToServer(var_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> grad_output_data =
      client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();

  BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
                grad_output_parameter, epsilon, feature_index);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
       LiteralUtil::CreateR1<float>(grad_offset)});

  // Run all HLO passes during this test. In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized
  // tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareTuple(&builder, expected,
                         {input_data.get(), scale_data.get(), mean_data.get(),
                          var_data.get(), grad_output_data.get()},
                         ErrorSpec(0.01, 1));
}

}  // namespace
}  // namespace xla
847