1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/reference_util.h"
17 
18 #include <array>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/shape_inference.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/math/math_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 
33 namespace xla {
34 
Array2DF32ToF64(const Array2D<float> & input)35 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
36     const Array2D<float>& input) {
37   auto result =
38       absl::make_unique<Array2D<double>>(input.height(), input.width());
39   for (int64 rowno = 0; rowno < input.height(); ++rowno) {
40     for (int64 colno = 0; colno < input.height(); ++colno) {
41       (*result)(rowno, colno) = input(rowno, colno);
42     }
43   }
44   return result;
45 }
46 
ConvArray3D(const Array3D<float> & lhs,const Array3D<float> & rhs,int64 kernel_stride,Padding padding)47 /*  static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
48     const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
49     Padding padding) {
50   return ConvArray3DGeneralDimensionsDilated(
51       lhs, rhs, kernel_stride, padding, 1, 1,
52       XlaBuilder::CreateDefaultConvDimensionNumbers(1));
53 }
54 
55 /*static*/ std::unique_ptr<Array3D<float>>
ConvArray3DGeneralDimensionsDilated(const Array3D<float> & lhs,const Array3D<float> & rhs,int64 kernel_stride,Padding padding,int64 lhs_dilation,int64 rhs_dilation,const ConvolutionDimensionNumbers & dnums)56 ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
57     const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
58     Padding padding, int64 lhs_dilation, int64 rhs_dilation,
59     const ConvolutionDimensionNumbers& dnums) {
60   CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
61   CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
62   CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
63   // Reuse the code for Array4D-convolution by extending the 3D input into a 4D
64   // array by adding a fourth dummy dimension of size 1 without stride, padding
65   // and dilation.
66   Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
67   a4dlhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
68     CHECK_EQ(indices[3], 0);
69     *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
70   });
71   Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
72   a4drhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
73     CHECK_EQ(indices[3], 0);
74     *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
75   });
76   // Add a second dummy spatial dimensions.
77   ConvolutionDimensionNumbers dnums2d = dnums;
78   dnums2d.add_input_spatial_dimensions(3);
79   dnums2d.add_kernel_spatial_dimensions(3);
80   dnums2d.add_output_spatial_dimensions(3);
81   std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
82       a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
83       {rhs_dilation, 1}, dnums2d);
84 
85   auto convr3 = absl::make_unique<Array3D<float>>(
86       convr4->planes(), convr4->depth(), convr4->height());
87   convr4->Each([&](absl::Span<const int64> indices, float* value_ptr) {
88     CHECK_EQ(indices[3], 0);
89     convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
90   });
91   return convr3;
92 }
93 
ConvArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding)94 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
95     const Array4D<float>& lhs, const Array4D<float>& rhs,
96     std::pair<int64, int64> kernel_stride, Padding padding) {
97   return ConvArray4DGeneralDimensions(
98       lhs, rhs, kernel_stride, padding,
99       XlaBuilder::CreateDefaultConvDimensionNumbers());
100 }
101 
102 /* static */ std::unique_ptr<Array4D<float>>
SeparableConvArray4D(const Array4D<float> & input,const Array4D<float> & depthwise_weights,const Array4D<float> & pointwise_weights,std::pair<int64,int64> kernel_stride,Padding padding)103 ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
104                                     const Array4D<float>& depthwise_weights,
105                                     const Array4D<float>& pointwise_weights,
106                                     std::pair<int64, int64> kernel_stride,
107                                     Padding padding) {
108   const int64 depth_multiplier = depthwise_weights.planes();
109   CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
110 
111   // Combine the two weights by reducing the depth_multiplier, so that we can
112   // apply a single convolution on the combined weights.
113   Array4D<float> weights(pointwise_weights.planes(), input.depth(),
114                          depthwise_weights.height(), depthwise_weights.width());
115   for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) {
116     for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) {
117       for (int64 kz = 0; kz < input.depth(); ++kz) {
118         for (int64 out = 0; out < pointwise_weights.planes(); ++out) {
119           float weight = 0.0;
120           for (int64 depth = 0; depth < depth_multiplier; ++depth) {
121             weight +=
122                 depthwise_weights(depth, kz, ky, kx) *
123                 pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
124           }
125           weights(out, kz, ky, kx) = weight;
126         }
127       }
128     }
129   }
130 
131   return ConvArray4D(input, weights, kernel_stride, padding);
132 }
133 
WindowCount(int64 unpadded_width,int64 window_len,int64 stride,Padding padding)134 /* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
135                                               int64 window_len, int64 stride,
136                                               Padding padding) {
137   if (padding == Padding::kValid) {
138     return window_util::StridedBound(unpadded_width, window_len, stride);
139   }
140   return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
141 }
142 
143 /* static  */ std::unique_ptr<std::vector<float>>
ReduceWindow1DGeneric(absl::Span<const float> operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)144 ReferenceUtil::ReduceWindow1DGeneric(
145     absl::Span<const float> operand, float init,
146     const std::function<float(float, float)>& reduce_func,
147     absl::Span<const int64> window, absl::Span<const int64> stride,
148     absl::Span<const std::pair<int64, int64>> padding) {
149   CHECK_EQ(window.size(), 1);
150   CHECK_EQ(stride.size(), 1);
151   CHECK_EQ(padding.size(), 1);
152 
153   int64 padded_width = padding[0].first + operand.size() + padding[0].second;
154   int64 stride_amount = stride[0];
155   int64 window_size = window[0];
156   int64 result_size =
157       window_util::StridedBound(padded_width, window_size, stride_amount);
158   int64 pad_low = padding[0].first;
159   auto result = absl::make_unique<std::vector<float>>(result_size);
160 
161   // Do a full 1D reduce window.
162   for (int64 i0 = 0; i0 < result_size; ++i0) {
163     int64 i0_base = i0 * stride_amount - pad_low;
164     float val = init;
165     for (int64 i0_win = 0; i0_win < window_size; ++i0_win) {
166       if (i0_base + i0_win >= 0 && i0_base + i0_win < operand.size()) {
167         val = reduce_func(val, operand[i0_base + i0_win]);
168       }
169     }
170     (*result)[i0] = val;
171   }
172   return result;
173 }
174 
175 /* static  */ std::unique_ptr<std::vector<float>>
ReduceWindow1DAdd(absl::Span<const float> operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)176 ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
177                                  absl::Span<const int64> window,
178                                  absl::Span<const int64> stride,
179                                  Padding padding) {
180   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
181   std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
182   return ReduceWindow1DGeneric(
183       operand, init, add_reduce, window, stride,
184       xla::MakePadding(dim_lengths, window, stride, padding));
185 }
186 
ReduceWindow3DAdd(const Array3D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)187 /* static  */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
188     const Array3D<float>& operand, float init, absl::Span<const int64> window,
189     absl::Span<const int64> stride, Padding padding) {
190   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
191   auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
192 
193   std::vector<int64> window_counts(window.size(), 0);
194   std::vector<int64> pad_low(window.size(), 0);
195   for (int64 i = 0; i < window.size(); ++i) {
196     window_counts[i] =
197         WindowCount(dim_lengths[i], window[i], stride[i], padding);
198     pad_low[i] = padding_both[i].first;
199   }
200   auto result = absl::make_unique<Array3D<float>>(
201       window_counts[0], window_counts[1], window_counts[2]);
202 
203   for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
204     for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
205       for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
206         int64 i0_base = i0 * stride[0] - pad_low[0];
207         int64 i1_base = i1 * stride[1] - pad_low[1];
208         int64 i2_base = i2 * stride[2] - pad_low[2];
209 
210         float val = init;
211         for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
212           for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
213             for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
214               if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
215                   i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() &&
216                   i1_base + i1_win < operand.n2() &&
217                   i2_base + i2_win < operand.n3()) {
218                 val += operand(i0_base + i0_win, i1_base + i1_win,
219                                i2_base + i2_win);
220               }
221             }
222           }
223         }
224         (*result)(i0, i1, i2) = val;
225       }
226     }
227   }
228   return result;
229 }
230 
231 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)232 ReferenceUtil::ReduceWindow4DGeneric(
233     const Array4D<float>& operand, float init,
234     const std::function<float(float, float)>& reduce_func,
235     absl::Span<const int64> window, absl::Span<const int64> stride,
236     Padding padding) {
237   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
238                                  operand.n4()};
239   return ReduceWindow4DGeneric(
240       operand, init, reduce_func, window, stride,
241       xla::MakePadding(dim_lengths, window, stride, padding));
242 }
243 
244 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)245 ReferenceUtil::ReduceWindow4DGeneric(
246     const Array4D<float>& operand, float init,
247     const std::function<float(float, float)>& reduce_func,
248     absl::Span<const int64> window, absl::Span<const int64> stride,
249     absl::Span<const std::pair<int64, int64>> padding) {
250   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
251                                  operand.n4()};
252 
253   std::vector<int64> window_counts(window.size(), 0);
254   std::vector<int64> pad_low(window.size(), 0);
255   for (int64 i = 0; i < window.size(); ++i) {
256     int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
257     window_counts[i] =
258         window_util::StridedBound(padded_width, window[i], stride[i]);
259     pad_low[i] = padding[i].first;
260   }
261   auto result = absl::make_unique<Array4D<float>>(
262       window_counts[0], window_counts[1], window_counts[2], window_counts[3]);
263   // Do a full 4D reduce window.
264   for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
265     for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
266       for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
267         for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
268           int64 i0_base = i0 * stride[0] - pad_low[0];
269           int64 i1_base = i1 * stride[1] - pad_low[1];
270           int64 i2_base = i2 * stride[2] - pad_low[2];
271           int64 i3_base = i3 * stride[3] - pad_low[3];
272 
273           float val = init;
274           for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
275             for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
276               for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
277                 for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
278                   if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
279                       i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
280                       i0_base + i0_win < operand.n1() &&
281                       i1_base + i1_win < operand.n2() &&
282                       i2_base + i2_win < operand.n3() &&
283                       i3_base + i3_win < operand.n4()) {
284                     val = reduce_func(
285                         val, operand(i0_base + i0_win, i1_base + i1_win,
286                                      i2_base + i2_win, i3_base + i3_win));
287                   }
288                 }
289               }
290             }
291           }
292           (*result)(i0, i1, i2, i3) = val;
293         }
294       }
295     }
296   }
297   return result;
298 }
299 
ReduceWindow4DAdd(const Array4D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)300 /* static  */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
301     const Array4D<float>& operand, float init, absl::Span<const int64> window,
302     absl::Span<const int64> stride, Padding padding) {
303   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
304   return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
305                                padding);
306 }
307 
BatchNorm4D(const Array4D<float> & input,const Array4D<float> & mean,const Array4D<float> & var,const Array4D<float> & scale,const Array4D<float> & offset,float epsilon)308 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
309     const Array4D<float>& input, const Array4D<float>& mean,
310     const Array4D<float>& var, const Array4D<float>& scale,
311     const Array4D<float>& offset, float epsilon) {
312   auto normalized =
313       *MapArray4D(input, mean, [](float a, float b) { return a - b; });
314   normalized = *MapArray4D(normalized, var, [&](float a, float b) {
315     return a / std::sqrt(b + epsilon);
316   });
317   normalized =
318       *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
319   return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
320 }
321 
// Reference 4-D SelectAndScatter: selects the maximum in each window and
// scatters by addition.
//
// The output has the operand's shape and starts filled with `init`. For each
// window position, the code finds the position of the largest in-range
// operand element in the window. It then adds the `source` element for that
// window position to the output at that position. The window grid must have
// exactly the shape of `source` (CHECKed below). `same_padding` selects
// Padding::kSame; otherwise Padding::kValid is used.
/* static  */ std::unique_ptr<Array4D<float>>
ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
                                        const Array4D<float>& source,
                                        float init,
                                        absl::Span<const int64> window,
                                        absl::Span<const int64> stride,
                                        bool same_padding) {
  Padding padding = same_padding ? Padding::kSame : Padding::kValid;
  auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
                                                  operand.n3(), operand.n4());
  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
                                 operand.n4()};
  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
  // Fill the output with the initial value.
  result->Fill(init);

  // Per dimension: number of window positions and low padding amount.
  std::vector<int64> window_counts(window.size(), 0);
  std::vector<int64> pad_low(window.size(), 0);
  for (int64 i = 0; i < window.size(); ++i) {
    window_counts[i] =
        WindowCount(dim_lengths[i], window[i], stride[i], padding);
    pad_low[i] = padding_both[i].first;
  }
  CHECK_EQ(window_counts[0], source.n1());
  CHECK_EQ(window_counts[1], source.n2());
  CHECK_EQ(window_counts[2], source.n3());
  CHECK_EQ(window_counts[3], source.n4());

  // Do a full 4D select and Scatter.
  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
          // Now we are inside a window and need to find the max and the argmax.
          // iN_base is the window origin in operand coordinates (may be
          // negative when the window starts inside the low padding).
          int64 i0_base = i0 * stride[0] - pad_low[0];
          int64 i1_base = i1 * stride[1] - pad_low[1];
          int64 i2_base = i2 * stride[2] - pad_low[2];
          int64 i3_base = i3 * stride[3] - pad_low[3];
          // Seed the argmax with the window origin clamped to 0 on each
          // dimension. NOTE(review): only the low side is clamped; this
          // relies on the origin never reaching the operand extent for the
          // padding/stride combinations used — confirm.
          int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
          int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
          int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
          int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
          float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
                  // Skip window positions that fall in the padding.
                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
                      i0_base + i0_win < operand.n1() &&
                      i1_base + i1_win < operand.n2() &&
                      i2_base + i2_win < operand.n3() &&
                      i3_base + i3_win < operand.n4()) {
                    float tmp = operand(i0_base + i0_win, i1_base + i1_win,
                                        i2_base + i2_win, i3_base + i3_win);
                    // Strict '>' means that on ties the earlier candidate
                    // (in row-major window order) stays selected.
                    if (tmp > val) {
                      val = tmp;
                      scatter_0 = i0_base + i0_win;
                      scatter_1 = i1_base + i1_win;
                      scatter_2 = i2_base + i2_win;
                      scatter_3 = i3_base + i3_win;
                    }
                  }
                }
              }
            }
          }
          // Scatter: accumulate the source value at the selected position.
          (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
              source(i0, i1, i2, i3);
        }
      }
    }
  }
  return result;
}
397 
398 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensions(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding,ConvolutionDimensionNumbers dimension_numbers)399 ReferenceUtil::ConvArray4DGeneralDimensions(
400     const Array4D<float>& lhs, const Array4D<float>& rhs,
401     std::pair<int64, int64> kernel_stride, Padding padding,
402     ConvolutionDimensionNumbers dimension_numbers) {
403   return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
404                                              {1, 1}, {1, 1},
405                                              std::move(dimension_numbers));
406 }
407 
// Reference 2-D convolution with arbitrary dimension numbers, strides and
// lhs/rhs dilation.
//
// It builds a single HLO convolution over constant copies of `lhs` and `rhs`
// and evaluates it with HloEvaluator. For kernel_stride, lhs_dilation and
// rhs_dilation, `.first` configures window dimension 0 and `.second` window
// dimension 1 (see the Window construction below). `padding` is turned into
// explicit low/high padding with MakePadding. The result has the rank-4 shape
// inferred by ShapeInference::InferConvolveShape.
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
    const Array4D<float>& lhs, const Array4D<float>& rhs,
    std::pair<int64, int64> kernel_stride, Padding padding,
    std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
    ConvolutionDimensionNumbers dnums) {
  HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
  auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
  auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);

  // Strides and spatial extents used only to compute the padding amounts.
  std::array<int64, 2> ordered_kernel_strides;
  std::array<int64, 2> ordered_input_dimensions;
  std::array<int64, 2> ordered_kernel_dimensions;
  // NOTE(review): strides are swapped when the kernel's spatial dimensions
  // are listed in decreasing order, but the input/kernel extents below are
  // not reordered. The Window built further down also uses kernel_stride
  // unswapped. For such dnums the padding may be computed with the other
  // dimension's stride — confirm this is intended.
  if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
    ordered_kernel_strides[0] = kernel_stride.second;
    ordered_kernel_strides[1] = kernel_stride.first;
  } else {
    ordered_kernel_strides[0] = kernel_stride.first;
    ordered_kernel_strides[1] = kernel_stride.second;
  }

  ordered_input_dimensions[0] =
      lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
  ordered_input_dimensions[1] =
      lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
  ordered_kernel_dimensions[0] =
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
  ordered_kernel_dimensions[1] =
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));

  // One (low, high) padding pair per spatial dimension.
  std::vector<std::pair<int64, int64>> paddings =
      MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
                  ordered_kernel_strides, padding);
  CHECK_EQ(paddings.size(), 2);

  Window window;

  // Window dimension for spatial dimension 0.
  WindowDimension dim;
  dim.set_size(
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
  dim.set_stride(kernel_stride.first);
  dim.set_padding_low(paddings[0].first);
  dim.set_padding_high(paddings[0].second);
  dim.set_window_dilation(rhs_dilation.first);
  dim.set_base_dilation(lhs_dilation.first);
  *window.add_dimensions() = dim;

  // Window dimension for spatial dimension 1.
  WindowDimension dim2;
  dim2.set_size(
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
  dim2.set_stride(kernel_stride.second);
  dim2.set_padding_low(paddings[1].first);
  dim2.set_padding_high(paddings[1].second);
  dim2.set_window_dilation(rhs_dilation.second);
  dim2.set_base_dilation(lhs_dilation.second);
  *window.add_dimensions() = dim2;

  // Shape inference errors are fatal here (ConsumeValueOrDie).
  const Shape& shape =
      ShapeInference::InferConvolveShape(
          lhs_literal.shape(), rhs_literal.shape(),
          /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
          /*preferred_element_type=*/absl::nullopt)
          .ConsumeValueOrDie();

  // The literals are moved into the constants; do not use them afterwards.
  HloInstruction* lhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
  HloInstruction* rhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));

  // DEFAULT precision for both operands.
  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      /*new_size=*/2, PrecisionConfig::DEFAULT);
  b.AddInstruction(HloInstruction::CreateConvolve(
      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, precision_config));
  HloModuleConfig config;
  HloModule module("ReferenceUtil", config);
  auto computation = module.AddEntryComputation(b.Build());

  HloEvaluator evaluator;
  Literal result_literal =
      evaluator.Evaluate(*computation, {}).ConsumeValueOrDie();

  // Copy the evaluated literal back into an Array4D of the same dimensions.
  CHECK_EQ(result_literal.shape().rank(), 4);
  auto result =
      absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
                                        result_literal.shape().dimensions(1),
                                        result_literal.shape().dimensions(2),
                                        result_literal.shape().dimensions(3));

  result->Each([&](absl::Span<const int64> indices, float* value) {
    *value = result_literal.Get<float>(indices);
  });

  return result;
}
504 
505 /* static */ std::unique_ptr<std::vector<float>>
ReduceToColArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)506 ReferenceUtil::ReduceToColArray2D(
507     const Array2D<float>& matrix, float init,
508     const std::function<float(float, float)>& reduce_function) {
509   int64 rows = matrix.height();
510   int64 cols = matrix.width();
511   auto result = absl::make_unique<std::vector<float>>();
512   for (int64 i = 0; i < rows; ++i) {
513     float acc = init;
514     for (int64 j = 0; j < cols; ++j) {
515       acc = reduce_function(acc, matrix(i, j));
516     }
517     result->push_back(acc);
518   }
519   return result;
520 }
521 
522 /* static */ std::unique_ptr<std::vector<float>>
ReduceToRowArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)523 ReferenceUtil::ReduceToRowArray2D(
524     const Array2D<float>& matrix, float init,
525     const std::function<float(float, float)>& reduce_function) {
526   int64 rows = matrix.height();
527   int64 cols = matrix.width();
528   auto result = absl::make_unique<std::vector<float>>();
529   for (int64 i = 0; i < cols; ++i) {
530     float acc = init;
531     for (int64 j = 0; j < rows; ++j) {
532       acc = reduce_function(acc, matrix(j, i));
533     }
534     result->push_back(acc);
535   }
536   return result;
537 }
538 
// Reduces three of the four dimensions of `array`, producing one value per
// index of the remaining (kept) dimension.
//
// `dims` must name three distinct dimensions (CHECKed). Each `aN` loop spans
// its full extent only when dimension N is kept; each `iN` loop spans its
// full extent only when dimension N is reduced. Every other loop runs exactly
// once with index 0. The a-loops thus enumerate output positions, and the
// i-loops enumerate the elements folded into each output. Since one of
// aN/iN is always 0, aN + iN is the index along dimension N. Each output
// starts at `init` and is folded with reduce_function in row-major order
// over the reduced dimensions.
//
// NOTE(review): every loop executes at least once (the `== 0` clause). A
// kept dimension of extent 0 therefore still yields one output equal to
// `init` rather than an empty result. A `dims` entry outside 0..3 also makes
// more than one dimension "kept", and their indices are flattened into the
// output. Confirm whether callers rely on either behavior.
/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
    const Array4D<float>& array, float init, absl::Span<const int64> dims,
    const std::function<float(float, float)>& reduce_function) {
  std::vector<float> result;
  CHECK_EQ(dims.size(), 3);
  const absl::flat_hash_set<int64> dim_set(dims.begin(), dims.end());
  CHECK_EQ(dim_set.size(), 3);
  for (int64 a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1());
       ++a0) {
    for (int64 a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2());
         ++a1) {
      for (int64 a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3());
           ++a2) {
        for (int64 a3 = 0; a3 == 0 || (!dim_set.contains(3) && a3 < array.n4());
             ++a3) {
          float accumulator = init;
          for (int64 i0 = 0;
               i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) {
            for (int64 i1 = 0;
                 i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) {
              for (int64 i2 = 0;
                   i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) {
                for (int64 i3 = 0;
                     i3 == 0 || (dim_set.contains(3) && i3 < array.n4());
                     ++i3) {
                  // Handle zero-sized arrays: the loops above always run at
                  // least once, so skip the element read when any extent is
                  // zero.
                  if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
                      array.n4() > 0) {
                    accumulator = reduce_function(
                        accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
                  }
                }
              }
            }
          }
          result.push_back(accumulator);
        }
      }
    }
  }
  return result;
}
581 
Broadcast1DTo4D(const std::vector<float> & array,const std::vector<int64> & bounds,int64 broadcast_from_dim)582 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
583     const std::vector<float>& array, const std::vector<int64>& bounds,
584     int64 broadcast_from_dim) {
585   auto result = absl::make_unique<Array4D<float>>(bounds[0], bounds[1],
586                                                   bounds[2], bounds[3]);
587   for (int64 i = 0; i < result->n1(); ++i) {
588     for (int64 j = 0; j < result->n2(); ++j) {
589       for (int64 k = 0; k < result->n3(); ++k) {
590         for (int64 l = 0; l < result->n4(); ++l) {
591           switch (broadcast_from_dim) {
592             case 0:
593               (*result)(i, j, k, l) = array[i];
594               break;
595             case 1:
596               (*result)(i, j, k, l) = array[j];
597               break;
598             case 2:
599               (*result)(i, j, k, l) = array[k];
600               break;
601             case 3:
602               (*result)(i, j, k, l) = array[l];
603               break;
604             default:
605               break;
606           }
607         }
608       }
609     }
610   }
611   return result;
612 }
613 
Reduce3DTo2D(const Array3D<float> & array,float init,absl::Span<const int64> dims,const std::function<float (float,float)> & reduce_function)614 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
615     const Array3D<float>& array, float init, absl::Span<const int64> dims,
616     const std::function<float(float, float)>& reduce_function) {
617   CHECK_EQ(dims.size(), 1);
618   int64 rows = dims[0] == 0 ? array.n2() : array.n1();
619   int64 cols = dims[0] == 2 ? array.n2() : array.n3();
620   auto result = absl::make_unique<Array2D<float>>(rows, cols);
621   result->Fill(init);
622   for (int i0 = 0; i0 < array.n1(); ++i0) {
623     for (int i1 = 0; i1 < array.n2(); ++i1) {
624       for (int i2 = 0; i2 < array.n3(); ++i2) {
625         int64 row = dims[0] == 0 ? i1 : i0;
626         int64 col = dims[0] == 2 ? i1 : i2;
627         (*result)(row, col) =
628             reduce_function((*result)(row, col), array(i0, i1, i2));
629       }
630     }
631   }
632   return result;
633 }
634 
MapArray2D(const Array2D<float> & matrix,const std::function<float (float)> & map_function)635 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
636     const Array2D<float>& matrix,
637     const std::function<float(float)>& map_function) {
638   int64 rows = matrix.height();
639   int64 cols = matrix.width();
640   auto result = absl::make_unique<Array2D<float>>(rows, cols);
641   for (int64 i = 0; i < rows; ++i) {
642     for (int64 j = 0; j < cols; ++j) {
643       (*result)(i, j) = map_function(matrix(i, j));
644     }
645   }
646   return result;
647 }
648 
MapArray2D(const Array2D<float> & lhs,const Array2D<float> & rhs,const std::function<float (float,float)> & map_function)649 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
650     const Array2D<float>& lhs, const Array2D<float>& rhs,
651     const std::function<float(float, float)>& map_function) {
652   CHECK_EQ(lhs.height(), rhs.height());
653   CHECK_EQ(lhs.width(), rhs.width());
654   int64 rows = lhs.height();
655   int64 cols = rhs.width();
656   auto result = absl::make_unique<Array2D<float>>(rows, cols);
657   for (int64 i = 0; i < rows; ++i) {
658     for (int64 j = 0; j < cols; ++j) {
659       (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
660     }
661   }
662   return result;
663 }
664 
MapWithIndexArray2D(const Array2D<float> & matrix,const std::function<float (float,int64,int64)> & map_function)665 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
666     const Array2D<float>& matrix,
667     const std::function<float(float, int64, int64)>& map_function) {
668   int64 rows = matrix.height();
669   int64 cols = matrix.width();
670   auto result = absl::make_unique<Array2D<float>>(rows, cols);
671   for (int64 i = 0; i < rows; ++i) {
672     for (int64 j = 0; j < cols; ++j) {
673       (*result)(i, j) = map_function(matrix(i, j), i, j);
674     }
675   }
676   return result;
677 }
678 
679 }  // namespace xla
680