/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cmath>
#include <memory>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace {

class BatchNormalizationTest
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<bool /*use_cudnn_batchnorm*/> {
 protected:
  BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) {
    mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(GetParam());

    Array2D<float> pz({
        // z0 z1
        {-1.0f, 4.1f},  // p0
        {2.0f, 4.1f},   // p1
        {5.0f, 4.4f},   // p2
    });
    input_array_.FillWithPZ(pz);
    input_literal_ = LiteralUtil::CreateR4FromArray4D(input_array_);
    CHECK_EQ(kSamples, input_array_.planes());
    CHECK_EQ(kZ, input_array_.depth());
    CHECK_EQ(kY, input_array_.height());
    CHECK_EQ(kX, input_array_.width());
  }

  XlaOp CheckShape(XlaBuilder* b, const XlaOp& operand,
                   const Shape& expected_shape) const {
    Shape actual_shape = b->GetShape(operand).ConsumeValueOrDie();
    CHECK(ShapeUtil::Equal(expected_shape, actual_shape))
        << "want " << ShapeUtil::HumanString(expected_shape) << " got "
        << ShapeUtil::HumanString(actual_shape);
    return operand;
  }

  static constexpr int64 kSamples = 3;
  static constexpr int64 kX = 1;
  static constexpr int64 kY = 1;
  static constexpr int64 kZ = 2;

  Array4D<float> input_array_;
  Literal input_literal_;
  const ErrorSpec error_spec_{0.001, 0.001};
};

// If testing the GPU backend, run the tests twice, with and without cudnn
// batchnorm.  Otherwise, just run the tests once -- the value of this flag
// doesn't matter.
#ifdef XLA_TEST_BACKEND_GPU
INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
                        ::testing::Bool());
#else
INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest,
                        ::testing::Values(false));
#endif

XLA_TEST_P(BatchNormalizationTest, SubtractInZ) {
  XlaBuilder builder("subtract_in_z_one_sample");
  auto x = ConstantLiteral(&builder, input_literal_);
  auto y = ConstantR1<float>(&builder, {3.14, 4.25});
  Sub(x, y, /*broadcast_dimensions=*/{1});

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-1.0f - 3.14f, 4.1f - 4.25f},  // p0
      {2.0f - 3.14f, 4.1f - 4.25f},   // p1
      {5.0f - 3.14f, 4.4f - 4.25f},   // p2
  });
  expected.FillWithPZ(pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) {
  XlaBuilder builder("square_tesseract_elementwise");
  auto x = ConstantLiteral(&builder, input_literal_);
  Square(x);

  using tensorflow::MathUtil;

  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> expected_pz({
      {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)},
      {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)},
  });
  expected.FillWithPZ(expected_pz);
  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, SumToZ) {
  XlaBuilder builder("sum_to_z");
  auto input_activations = ConstantLiteral(&builder, input_literal_);
  XlaComputation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all but the Z dimension.
  Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3});

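  // Per-z sums over the reduced dimensions:
  //   z0: -1.0 + 2.0 + 5.0 = 6.0
  //   z1: 4.1 + 4.1 + 4.4 = 12.6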
  std::vector<float> expected = {6, 12.6};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) {
  XlaBuilder builder("square_and_reduce");
  auto input_activations = ConstantLiteral(&builder, input_literal_);
  auto set_means = ConstantR1<float>(&builder, {2.f, 4.2f});
  auto activation_deviations = Sub(input_activations, set_means,
                                   /*broadcast_dimensions=*/{1});
  XlaComputation add = CreateScalarAddComputation(F32, &builder);
  auto dev_squares = Square(activation_deviations);
  Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add, {0, 2, 3});

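  // Deviations from the set means are {-3, 0, 3} for z0 and
  // {-0.1, -0.1, 0.2} for z1; their squared sums are 18 and 0.06.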
  std::vector<float> expected = {18, 0.06};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) {
  XlaBuilder builder("variance_to_stddev");
  auto variance = ConstantR1<float>(&builder, {6.f, .02f});
  Sqrt(variance);

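  // sqrt(6) ~= 2.44948974 and sqrt(0.02) ~= 0.14142136.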
  std::vector<float> expected = {2.44948974f, 0.14142136f};
  ComputeAndCompareR1<float>(&builder, expected, {}, error_spec_);
}

// Compare against a forward batch normalization example in the NN spec
// reference.
XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) {
  XlaBuilder builder("batch_normalize_per_spec");
  auto input_activations =
      CheckShape(&builder, ConstantLiteral(&builder, input_literal_),
                 ShapeUtil::MakeShape(F32, {3, 2, 1, 1}));
  auto gamma = ConstantR1<float>(&builder, {1.0, 1.0});
  auto beta = ConstantR1<float>(&builder, {0.0, 0.0});
  XlaComputation add = CreateScalarAddComputation(F32, &builder);
  // Reduce all dimensions except dimension 1.
  Shape TwoElementVectorF32 = ShapeUtil::MakeShape(F32, {2});
  auto sum = CheckShape(
      &builder,
      Reduce(input_activations, ConstantR0<float>(&builder, 0.0f), add,
             /*dimensions_to_reduce=*/{0, 2, 3}),
      TwoElementVectorF32);
  auto input_shape = builder.GetShape(input_activations).ConsumeValueOrDie();
  auto sum_shape = builder.GetShape(sum).ConsumeValueOrDie();
  auto count =
      ConstantR0<float>(&builder, ShapeUtil::ElementsIn(input_shape) /
                                      ShapeUtil::ElementsIn(sum_shape));
  auto set_means = Div(sum, count);

  const float kEpsilon = 1e-9f;
  auto epsilon = ConstantR0<float>(&builder, kEpsilon);
  auto epsilon2 = ConstantR1<float>(&builder, {kEpsilon, kEpsilon});
  auto activation_deviations = Sub(input_activations, set_means,
                                   /*broadcast_dimensions=*/{1});
  auto dev_squares = Square(activation_deviations);
  auto sum_of_squares =
      CheckShape(&builder,
                 Reduce(dev_squares, ConstantR0<float>(&builder, 0.0f), add,
                        /*dimensions_to_reduce=*/{0, 2, 3}),
                 TwoElementVectorF32);
  auto variance = Div(sum_of_squares, count);
  auto standard_deviation = Sqrt(variance);
  auto standard_deviation_above_epsilon =
      CheckShape(&builder, Gt(standard_deviation, epsilon),
                 ShapeUtil::MakeShape(PRED, {2}));
  auto gt_eps =
      Select(standard_deviation_above_epsilon, standard_deviation, epsilon2);
  auto normalization_factors = Reciprocal(gt_eps);
  auto normalized_input_activations =
      Mul(activation_deviations, normalization_factors,
          /*broadcast_dimensions=*/{1});
  /* auto output_activations = */ Add(Mul(normalized_input_activations, gamma,
                                          /*broadcast_dimensions=*/{1}),
                                      beta, /*broadcast_dimensions=*/{1});

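  // For this input the per-z means are {2, 4.2} and the per-z variances are
  // {18 / 3, 0.06 / 3} = {6, 0.02}, so each expected entry below is
  // (x - mean) / sqrt(variance).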
  Array4D<float> expected(kSamples, kZ, kY, kX);
  Array2D<float> pz({
      {-3.f / std::sqrt(6.f), -.1f / std::sqrt(.02f)},
      {0.f, -.1f / std::sqrt(.02f)},
      {3.f / std::sqrt(6.f), .2f / std::sqrt(.02f)},
  });
  expected.FillWithPZ(pz);

  ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}

XLA_TEST_P(BatchNormalizationTest, BasicTraining) {
  const int kFeatureIndex = 3;
  XlaBuilder builder(TestName());

  auto operand = ConstantR4FromArray4D<float>(
      &builder, {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}});

  auto scale = ConstantR1<float>(&builder, {2.0f, 3.0f});

  auto offset = ConstantR1<float>(&builder, {1.0f, 2.0f});

  BatchNormTraining(operand, scale, offset,
                    /*epsilon=*/0.001, kFeatureIndex);

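  // With the feature on dimension 3, feature 0 holds {1, 3, 5, 7} (mean 4,
  // variance 5) and feature 1 holds {2, 4, 6, 8} (mean 5, variance 5), so
  // e.g. the first output is (1 - 4) / sqrt(5 + 0.001) * 2 + 1 ~= -1.68,
  // which matches -1.6f within the 0.1 error spec.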
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR4<float>({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}},
                                     {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}),
       LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnDimension2) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  auto operand = ConstantR4FromArray4D<float>(
      &builder,
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  auto scale = ConstantR1<float>(&builder, {2.0f, 3.0f});

  auto offset = ConstantR1<float>(&builder, {1.0f, 2.0f});

  BatchNormTraining(operand, scale, offset,
                    /*epsilon=*/0.001, kFeatureIndex);

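  // Same values as BasicTraining, laid out with the feature on dimension 2;
  // the per-feature statistics are unchanged: means {4, 5}, variances {5, 5}.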
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR4<float>({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}},
                                     {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}),
       LiteralUtil::CreateR1<float>({4, 5}),
       LiteralUtil::CreateR1<float>({5, 5})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) {
  // Use dimension 0 as the feature dimension; this exercises the layout
  // analyzer.
  const int kFeatureIndex = 0;
  XlaBuilder builder(TestName());

  XlaOp h0;
  auto operand = CreateR3Parameter<float>(Array3D<float>(260, 2, 2, 1.0f),
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  XlaOp h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  XlaOp h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(260, 1.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  BatchNormTraining(h0, h1, h2,
                    /*epsilon=*/1, kFeatureIndex);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR3FromArray3D<float>(Array3D<float>(260, 2, 2, 1.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 1.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(260, 0.0f))});

  ComputeAndCompareTuple(&builder, expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) {
  // Test correctness when a large-magnitude (here, negative) epsilon value is
  // chosen.
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  XlaOp h0;
  auto operand = CreateR3Parameter<float>({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}},
                                          /*parameter_number=*/0, "operand",
                                          &builder, &h0);
  XlaOp h1;
  auto scale =
      CreateR1Parameter<float>(std::vector<float>(1, 1.0f),
                               /*parameter_number=*/1, "scale", &builder, &h1);
  XlaOp h2;
  auto offset =
      CreateR1Parameter<float>(std::vector<float>(1, 0.0f),
                               /*parameter_number=*/2, "offset", &builder, &h2);

  // var = 125, mean = 15, epsilon = -100
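  // (mean = (0 + 10 + 20 + 30) / 4; var = (15^2 + 5^2 + 5^2 + 15^2) / 4.)
  // sqrt(var + epsilon) = sqrt(25) = 5, so the expected normalized outputs
  // are (x - 15) / 5 = {-3, -1, 1, 3}.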
  BatchNormTraining(h0, h1, h2,
                    /*epsilon=*/-100, kFeatureIndex);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR3FromArray3D<float>(
           {{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}),
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 15.0f)),
       LiteralUtil::CreateR1<float>(std::vector<float>(1, 125.0f))});

  ComputeAndCompareTuple(&builder, expected,
                         {operand.get(), scale.get(), offset.get()},
                         ErrorSpec(0.1));
}

XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) {
  const int kFeatureIndex = 2;
  XlaBuilder builder(TestName());

  auto operand =
      ConstantR4FromArray4D<float>(&builder, Array4D<float>(2, 2, 2, 1, 0.0f));

  auto scale = ConstantR1<float>(&builder, {1.0f, 1.0f});

  auto mean = ConstantR1<float>(&builder, {0.0f, 0.0f});

  auto var = ConstantR1<float>(&builder, {1.0f, 1.0f});

  auto grad_output = ConstantR4FromArray4D<float>(
      &builder,
      {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}});

  BatchNormGrad(operand, scale, mean, var, grad_output,
                /*epsilon=*/0.0, kFeatureIndex);

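  // With a zero operand, mean 0, and variance 1, the activation-shift term of
  // the gradient vanishes, so grad_activation reduces to
  // scale * (grad_output - mean(grad_output)) with per-feature grad_output
  // means {4, 5}; grad_scale is 0 and grad_offset is the per-feature sum of
  // grad_output, {16, 20}.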
  auto expected = LiteralUtil::MakeTupleFromSlices(
      {LiteralUtil::CreateR4<float>({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}},
                                     {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}),
       LiteralUtil::CreateR1<float>({0, 0}),
       LiteralUtil::CreateR1<float>({16, 20})});

  ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1));
}

struct BatchNormTestParam {
  std::vector<int64> bounds;
  int64 feature_index;
  float random_value_mean;
  float random_value_var;
  bool use_cudnn_batchnorm;

  friend ::std::ostream& operator<<(::std::ostream& os,
                                    const BatchNormTestParam& p) {
    os << "bounds={" << absl::StrJoin(p.bounds, ", ") << "}, ";
    os << "feature_index=" << p.feature_index << ", ";
    os << "random_value_mean=" << p.random_value_mean << ", ";
    os << "random_value_var=" << p.random_value_var;

    // Don't print use_cudnn_batchnorm when it's false, because most backends
    // never set it to true.
    if (p.use_cudnn_batchnorm) {
      os << ", use_cudnn_batchnorm=true";
    }
    return os;
  }
};

// Tests for the fused BatchNorm operations.
class BatchNormTestManySizes
    : public ClientLibraryTestBase,
      public ::testing::WithParamInterface<BatchNormTestParam> {
 public:
  BatchNormTestManySizes() {
    mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(
        GetParam().use_cudnn_batchnorm);
  }
};

std::vector<BatchNormTestParam> BuildBatchNormTestParams() {
  std::vector<BatchNormTestParam> params;

  auto add_testcase = [&](std::vector<int64> bounds, int64 feature_index,
                          float random_value_mean, float random_value_var) {
    BatchNormTestParam p{bounds, feature_index, random_value_mean,
                         random_value_var, /*use_cudnn_batchnorm=*/false};
    params.push_back(p);

    // If testing the GPU backend, also run with cudnn batchnorm enabled.
#ifdef XLA_TEST_BACKEND_GPU
    p.use_cudnn_batchnorm = true;
    params.push_back(p);
#endif
  };

  add_testcase({2, 2, 2, 2}, 0, 100.2f, 200.0f);
  add_testcase({2, 2, 2, 2}, 3, 300.f, 400.0f);

  add_testcase({1, 10, 1, 1}, 0, 10.1f, 20.1f);
  add_testcase({10, 10, 10, 10}, 1, 3.14f, 314.15f);
  add_testcase({10, 10, 10, 10}, 2, 666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 1, -666.6f, 777.7f);
  add_testcase({10, 10, 10, 10}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 130}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 130, 11}, 2, 0.f, 777.7f);
  add_testcase({1, 1, 10, 1}, 3, 888.8f, 9.9f);

  add_testcase({24, 129, 1, 2}, 2, 10000, 10000);
  add_testcase({24, 129, 1, 2}, 3, 10000, 10000);

  // Feature on low dimension to trigger relayout, check that internal logical
  // to physical dimension calculation is correct after relayout.
  add_testcase({1, 2, 3, 4}, 0, 100, 100);

  // Zero-sized tensor.
  add_testcase({1, 0, 100, 42}, 0, 100, 100);

  return params;
}

INSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes,
                        ::testing::ValuesIn(BuildBatchNormTestParams()));
XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) {
  float epsilon = 0.001;
  XlaBuilder builder(TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

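  // Compute the reference variance via the identity
  // Var(x) = E[x^2] - (E[x])^2.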
  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto expected_normalized =
      LiteralUtil::CreateR4FromArray4D<float>(normalized);

  auto offset_literal = LiteralUtil::CreateR1<float>(offset);
  auto scale_literal = LiteralUtil::CreateR1<float>(scale);
  auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);

  auto input_activations =
      Parameter(&builder, 0, input_literal.shape(), "input");
  auto scale_activations =
      Parameter(&builder, 1, scale_literal.shape(), "scale");
  auto offset_activations =
      Parameter(&builder, 2, offset_literal.shape(), "offset");

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {expected_normalized, LiteralUtil::CreateR1<float>(mean),
       LiteralUtil::CreateR1<float>(var)});

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(offset_literal).ConsumeValueOrDie();

  BatchNormTraining(input_activations, scale_activations, offset_activations,
                    epsilon, feature_index);

  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized
  // tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();
  ComputeAndCompareTuple(
      &builder, expected,
      {input_data.get(), scale_data.get(), offset_data.get()},
      ErrorSpec(0.01, 1));
}

XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) {
  float epsilon = 0.001;
  XlaBuilder builder(TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> offset(feature_bound, 1);
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    mean[i] = sum[i] / num_elements_per_feature;
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    square_mean[i] = sum_squared[i] / num_elements_per_feature;
  }

  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);
  auto offset4D =
      *ReferenceUtil::Broadcast1DTo4D(offset, bounds, feature_index);

  auto normalized = *ReferenceUtil::BatchNorm4D(input_array, mean4D, var4D,
                                                scale4D, offset4D, epsilon);

  auto offset_literal = LiteralUtil::CreateR1<float>(offset);
  auto scale_literal = LiteralUtil::CreateR1<float>(scale);
  auto mean_literal = LiteralUtil::CreateR1<float>(mean);
  auto var_literal = LiteralUtil::CreateR1<float>(var);
  auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);

  auto input_activations =
      Parameter(&builder, 0, input_literal.shape(), "input");
  auto scale_activations =
      Parameter(&builder, 1, scale_literal.shape(), "scale");
  auto offset_activations =
      Parameter(&builder, 2, offset_literal.shape(), "offset");
  auto mean_activations = Parameter(&builder, 3, mean_literal.shape(), "mean");
  auto variance_activations =
      Parameter(&builder, 4, var_literal.shape(), "variance");

  Array4D<float> expected = normalized;

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> offset_data =
      client_->TransferToServer(offset_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> variance_data =
      client_->TransferToServer(var_literal).ConsumeValueOrDie();

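  // Unlike BatchNormTraining, BatchNormInference normalizes with the
  // caller-supplied mean and variance rather than computing batch statistics,
  // so the reference `normalized` array computed above is exactly the
  // expected output.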
  BatchNormInference(input_activations, scale_activations, offset_activations,
                     mean_activations, variance_activations, epsilon,
                     feature_index);

  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized
  // tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareR4<float>(
      &builder, expected,
      {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(),
       variance_data.get()},
      ErrorSpec(0.01, 1));
}

XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) {
  float epsilon = 0.001;
  XlaBuilder builder(TestName());
  const std::vector<int64>& bounds = GetParam().bounds;
  Array4D<float> input_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  input_array.FillRandom(GetParam().random_value_var,
                         GetParam().random_value_mean);

  Array4D<float> grad_output_array(bounds[0], bounds[1], bounds[2], bounds[3]);
  grad_output_array.FillRandom(GetParam().random_value_var,
                               GetParam().random_value_mean);

  const int64 feature_index = GetParam().feature_index;
  const int64 num_elements_per_feature =
      Product(bounds) / bounds[feature_index];
  const int64 feature_bound = bounds[feature_index];
  std::vector<float> scale(feature_bound, 2);

  auto input_squared =
      ReferenceUtil::MapArray4D(input_array, [](float a) { return a * a; });
  std::vector<int64> reduce_dims;
  for (int64 i = 0; i < static_cast<int64>(bounds.size()); ++i) {
    if (i != feature_index) {
      reduce_dims.push_back(i);
    }
  }

  auto sum =
      ReferenceUtil::Reduce4DTo1D(input_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto sum_squared =
      ReferenceUtil::Reduce4DTo1D(*input_squared, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  std::vector<float> mean(feature_bound);

  for (int64 i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      mean[i] = sum[i] / num_elements_per_feature;
    } else {
      mean[i] = 0;
    }
  }

  std::vector<float> mean_square(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    mean_square[i] = mean[i] * mean[i];
  }

  std::vector<float> square_mean(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    if (num_elements_per_feature > 0) {
      square_mean[i] = sum_squared[i] / num_elements_per_feature;
    } else {
      square_mean[i] = 0;
    }
  }

  std::vector<float> var(feature_bound);
  for (int64 i = 0; i < feature_bound; ++i) {
    var[i] = square_mean[i] - mean_square[i];
  }

  Array4D<float> mean4D =
      *ReferenceUtil::Broadcast1DTo4D(mean, bounds, feature_index);
  auto var4D = *ReferenceUtil::Broadcast1DTo4D(var, bounds, feature_index);
  auto scale4D = *ReferenceUtil::Broadcast1DTo4D(scale, bounds, feature_index);

  auto var_add_epsilon = *ReferenceUtil::MapArray4D(
      var4D, [epsilon](float a) { return a + epsilon; });

  auto rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      var_add_epsilon, [](float a) { return 1 / std::sqrt(a); });

  auto grad_output_times_var =
      *ReferenceUtil::MapArray4D(grad_output_array, var_add_epsilon,
                                 [](float a, float b) { return a * b; });

  auto activation_shifted = *ReferenceUtil::MapArray4D(
      input_array, mean4D, [](float a, float b) { return a - b; });

  auto activation_shifted_times_grad_output =
      *ReferenceUtil::MapArray4D(grad_output_array, activation_shifted,
                                 [](float a, float b) { return a * b; });

  auto grad_scale_before_reduction = *ReferenceUtil::MapArray4D(
      activation_shifted_times_grad_output, rsqrt_var_add_epsilon,
      [](float a, float b) { return a * b; });

  auto grad_scale = ReferenceUtil::Reduce4DTo1D(
      grad_scale_before_reduction, /*init=*/0.0f, reduce_dims,
      [](float a, float b) { return a + b; });

  auto grad_offset =
      ReferenceUtil::Reduce4DTo1D(grad_output_array, /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; });

  auto scale_times_rsqrt_var_add_epsilon = *ReferenceUtil::MapArray4D(
      scale4D, rsqrt_var_add_epsilon, [](float a, float b) { return a * b; });

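  // The reference gradient with respect to the activations is
  //   grad_activation = scale * rsqrt(var + epsilon) / N *
  //       (N * grad_output - sum(grad_output)
  //        - (activation - mean) * sum(grad_output * (activation - mean))
  //              / (var + epsilon)),
  // assembled below from the intermediate terms I1 through I5.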
  auto I1 = *ReferenceUtil::MapArray4D(
      grad_output_array, [&](float a) { return num_elements_per_feature * a; });

  auto I2 = *ReferenceUtil::Broadcast1DTo4D(grad_offset, bounds, feature_index);

  // I3 = sum(output_grad * (activation - mean(activation)))
  auto I3 = *ReferenceUtil::Broadcast1DTo4D(
      ReferenceUtil::Reduce4DTo1D(activation_shifted_times_grad_output,
                                  /*init=*/0.0f, reduce_dims,
                                  [](float a, float b) { return a + b; }),
      bounds, feature_index);

  // I4 = (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation)))
  auto I4 = *ReferenceUtil::MapArray4D(I3, activation_shifted,
                                       [](float a, float b) { return a * b; });

  // I5 = (activation - mean(activation)) *
  //   sum(output_grad * (activation - mean(activation))) / (variance +
  //   epsilon)
  auto I5 = *ReferenceUtil::MapArray4D(I4, var_add_epsilon,
                                       [](float a, float b) { return a / b; });

  auto grad_activation = *ReferenceUtil::MapArray4D(
      I1, I2, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, I5, [](float a, float b) { return a - b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, scale4D, [](float a, float b) { return a * b; });

  grad_activation = *ReferenceUtil::MapArray4D(
      grad_activation, rsqrt_var_add_epsilon, [=](float a, float b) {
        if (num_elements_per_feature > 0) {
          return a * b / num_elements_per_feature;
        }
        return 0.f;
      });

  auto expected_grad_activation =
      LiteralUtil::CreateR4FromArray4D<float>(grad_activation);

  auto input_literal = LiteralUtil::CreateR4FromArray4D<float>(input_array);
  auto scale_literal = LiteralUtil::CreateR1<float>(scale);
  auto mean_literal = LiteralUtil::CreateR1<float>(mean);
  auto var_literal = LiteralUtil::CreateR1<float>(var);
  auto grad_output_literal =
      LiteralUtil::CreateR4FromArray4D<float>(grad_output_array);

  auto input_parameter = Parameter(&builder, 0, input_literal.shape(), "input");
  auto scale_parameter = Parameter(&builder, 1, scale_literal.shape(), "scale");
  auto mean_parameter = Parameter(&builder, 2, mean_literal.shape(), "mean");
  auto var_parameter = Parameter(&builder, 3, var_literal.shape(), "variance");
  auto grad_output_parameter =
      Parameter(&builder, 4, grad_output_literal.shape(), "grad_output");

  std::unique_ptr<GlobalData> input_data =
      client_->TransferToServer(input_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> scale_data =
      client_->TransferToServer(scale_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> mean_data =
      client_->TransferToServer(mean_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> var_data =
      client_->TransferToServer(var_literal).ConsumeValueOrDie();
  std::unique_ptr<GlobalData> grad_output_data =
      client_->TransferToServer(grad_output_literal).ConsumeValueOrDie();

  BatchNormGrad(input_parameter, scale_parameter, mean_parameter, var_parameter,
                grad_output_parameter, epsilon, feature_index);

  auto expected = LiteralUtil::MakeTupleFromSlices(
      {expected_grad_activation, LiteralUtil::CreateR1<float>(grad_scale),
       LiteralUtil::CreateR1<float>(grad_offset)});

  // Run all HLO passes during this test.  In particular, ClientLibraryTestBase
  // disables constant folding, but we want it enabled for our zero-sized
  // tensor testcase.
  execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes();

  ComputeAndCompareTuple(&builder, expected,
                         {input_data.get(), scale_data.get(), mean_data.get(),
                          var_data.get(), grad_output_data.get()},
                         ErrorSpec(0.01, 1));
}

}  // namespace
}  // namespace xla