1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/reference_util.h"
17 
18 #include <array>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/shape_inference.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/math/math_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 
33 namespace xla {
34 
MatmulArray2D(const Array2D<Eigen::half> & lhs,const Array2D<Eigen::half> & rhs)35 /* static */ std::unique_ptr<Array2D<Eigen::half>> ReferenceUtil::MatmulArray2D(
36     const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
37   return HloEvaluator::MatmulArray2D(lhs, rhs);
38 }
39 
MatmulArray2D(const Array2D<float> & lhs,const Array2D<float> & rhs)40 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D(
41     const Array2D<float>& lhs, const Array2D<float>& rhs) {
42   return HloEvaluator::MatmulArray2D(lhs, rhs);
43 }
44 
MatmulArray2D(const Array2D<double> & lhs,const Array2D<double> & rhs)45 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D(
46     const Array2D<double>& lhs, const Array2D<double>& rhs) {
47   return HloEvaluator::MatmulArray2D(lhs, rhs);
48 }
49 
Array2DF32ToF64(const Array2D<float> & input)50 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
51     const Array2D<float>& input) {
52   auto result =
53       absl::make_unique<Array2D<double>>(input.height(), input.width());
54   for (int64 rowno = 0; rowno < input.height(); ++rowno) {
55     for (int64 colno = 0; colno < input.height(); ++colno) {
56       (*result)(rowno, colno) = input(rowno, colno);
57     }
58   }
59   return result;
60 }
61 
ConvArray3D(const Array3D<float> & lhs,const Array3D<float> & rhs,int64 kernel_stride,Padding padding)62 /*  static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
63     const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
64     Padding padding) {
65   return ConvArray3DGeneralDimensionsDilated(
66       lhs, rhs, kernel_stride, padding, 1, 1,
67       XlaBuilder::CreateDefaultConvDimensionNumbers(1));
68 }
69 
70 /*static*/ std::unique_ptr<Array3D<float>>
ConvArray3DGeneralDimensionsDilated(const Array3D<float> & lhs,const Array3D<float> & rhs,int64 kernel_stride,Padding padding,int64 lhs_dilation,int64 rhs_dilation,const ConvolutionDimensionNumbers & dnums)71 ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
72     const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
73     Padding padding, int64 lhs_dilation, int64 rhs_dilation,
74     const ConvolutionDimensionNumbers& dnums) {
75   CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
76   CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
77   CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
78   // Reuse the code for Array4D-convolution by extending the 3D input into a 4D
79   // array by adding a fourth dummy dimension of size 1 without stride, padding
80   // and dilation.
81   Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
82   a4dlhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
83     CHECK_EQ(indices[3], 0);
84     *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
85   });
86   Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
87   a4drhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
88     CHECK_EQ(indices[3], 0);
89     *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
90   });
91   // Add a second dummy spatial dimensions.
92   ConvolutionDimensionNumbers dnums2d = dnums;
93   dnums2d.add_input_spatial_dimensions(3);
94   dnums2d.add_kernel_spatial_dimensions(3);
95   dnums2d.add_output_spatial_dimensions(3);
96   std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
97       a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
98       {rhs_dilation, 1}, dnums2d);
99 
100   auto convr3 = absl::make_unique<Array3D<float>>(
101       convr4->planes(), convr4->depth(), convr4->height());
102   convr4->Each([&](absl::Span<const int64> indices, float* value_ptr) {
103     CHECK_EQ(indices[3], 0);
104     convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
105   });
106   return convr3;
107 }
108 
ConvArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding)109 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
110     const Array4D<float>& lhs, const Array4D<float>& rhs,
111     std::pair<int64, int64> kernel_stride, Padding padding) {
112   return ConvArray4DGeneralDimensions(
113       lhs, rhs, kernel_stride, padding,
114       XlaBuilder::CreateDefaultConvDimensionNumbers());
115 }
116 
117 /* static */ std::unique_ptr<Array4D<float>>
SeparableConvArray4D(const Array4D<float> & input,const Array4D<float> & depthwise_weights,const Array4D<float> & pointwise_weights,std::pair<int64,int64> kernel_stride,Padding padding)118 ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
119                                     const Array4D<float>& depthwise_weights,
120                                     const Array4D<float>& pointwise_weights,
121                                     std::pair<int64, int64> kernel_stride,
122                                     Padding padding) {
123   const int64 depth_multiplier = depthwise_weights.planes();
124   CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
125 
126   // Combine the two weights by reducing the depth_multiplier, so that we can
127   // apply a single convolution on the combined weights.
128   Array4D<float> weights(pointwise_weights.planes(), input.depth(),
129                          depthwise_weights.height(), depthwise_weights.width());
130   for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) {
131     for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) {
132       for (int64 kz = 0; kz < input.depth(); ++kz) {
133         for (int64 out = 0; out < pointwise_weights.planes(); ++out) {
134           float weight = 0.0;
135           for (int64 depth = 0; depth < depth_multiplier; ++depth) {
136             weight +=
137                 depthwise_weights(depth, kz, ky, kx) *
138                 pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
139           }
140           weights(out, kz, ky, kx) = weight;
141         }
142       }
143     }
144   }
145 
146   return ConvArray4D(input, weights, kernel_stride, padding);
147 }
148 
WindowCount(int64 unpadded_width,int64 window_len,int64 stride,Padding padding)149 /* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
150                                               int64 window_len, int64 stride,
151                                               Padding padding) {
152   if (padding == Padding::kValid) {
153     return window_util::StridedBound(unpadded_width, window_len, stride);
154   }
155   return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
156 }
157 
158 /* static  */ std::unique_ptr<std::vector<float>>
ReduceWindow1DGeneric(absl::Span<const float> operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)159 ReferenceUtil::ReduceWindow1DGeneric(
160     absl::Span<const float> operand, float init,
161     const std::function<float(float, float)>& reduce_func,
162     absl::Span<const int64> window, absl::Span<const int64> stride,
163     absl::Span<const std::pair<int64, int64>> padding) {
164   std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
165   std::vector<int64> window_counts(window.size(), 0);
166   std::vector<int64> pad_low(window.size(), 0);
167   for (int64 i = 0; i < window.size(); ++i) {
168     int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
169     window_counts[i] =
170         window_util::StridedBound(padded_width, window[i], stride[i]);
171     pad_low[i] = padding[i].first;
172   }
173   auto result = absl::make_unique<std::vector<float>>(window_counts[0]);
174 
175   // Do a full 1D reduce window.
176   for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
177     int64 i0_base = i0 * stride[0] - pad_low[0];
178 
179     float val = init;
180     for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
181       if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) {
182         val = reduce_func(val, operand[i0_base + i0_win]);
183       }
184     }
185     (*result)[i0] = val;
186   }
187   return result;
188 }
189 
190 /* static  */ std::unique_ptr<std::vector<float>>
ReduceWindow1DAdd(absl::Span<const float> operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)191 ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
192                                  absl::Span<const int64> window,
193                                  absl::Span<const int64> stride,
194                                  Padding padding) {
195   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
196   std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
197   return ReduceWindow1DGeneric(
198       operand, init, add_reduce, window, stride,
199       xla::MakePadding(dim_lengths, window, stride, padding));
200 }
201 
202 /* static */ std::unique_ptr<Array2D<float>>
ReduceWindow2DGeneric(const Array2D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)203 ReferenceUtil::ReduceWindow2DGeneric(
204     const Array2D<float>& operand, float init,
205     const std::function<float(float, float)>& reduce_func,
206     absl::Span<const int64> window, absl::Span<const int64> stride,
207     absl::Span<const std::pair<int64, int64>> padding) {
208   std::vector<int64> dim_lengths{operand.height(), operand.width()};
209 
210   std::vector<int64> window_counts(window.size(), 0);
211   std::vector<int64> pad_low(window.size(), 0);
212   for (int64 i = 0; i < window.size(); ++i) {
213     int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
214     window_counts[i] =
215         window_util::StridedBound(padded_width, window[i], stride[i]);
216     pad_low[i] = padding[i].first;
217   }
218   auto result =
219       absl::make_unique<Array2D<float>>(window_counts[0], window_counts[1]);
220 
221   // Do a full 2D reduce window.
222   for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
223     for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
224       int64 i0_base = i0 * stride[0] - pad_low[0];
225       int64 i1_base = i1 * stride[1] - pad_low[1];
226 
227       float val = init;
228       for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
229         for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
230           if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
231               i0_base + i0_win < operand.n1() &&
232               i1_base + i1_win < operand.n2()) {
233             val = reduce_func(val, operand(i0_base + i0_win, i1_base + i1_win));
234           }
235         }
236       }
237       (*result)(i0, i1) = val;
238     }
239   }
240   return result;
241 }
242 
ReduceWindow2DAdd(const Array2D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)243 /* static  */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
244     const Array2D<float>& operand, float init, absl::Span<const int64> window,
245     absl::Span<const int64> stride, Padding padding) {
246   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
247   std::vector<int64> dim_lengths{operand.height(), operand.width()};
248   return ReduceWindow2DGeneric(
249       operand, init, add_reduce, window, stride,
250       xla::MakePadding(dim_lengths, window, stride, padding));
251 }
252 
ReduceWindow3DAdd(const Array3D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)253 /* static  */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
254     const Array3D<float>& operand, float init, absl::Span<const int64> window,
255     absl::Span<const int64> stride, Padding padding) {
256   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
257   auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
258 
259   std::vector<int64> window_counts(window.size(), 0);
260   std::vector<int64> pad_low(window.size(), 0);
261   for (int64 i = 0; i < window.size(); ++i) {
262     window_counts[i] =
263         WindowCount(dim_lengths[i], window[i], stride[i], padding);
264     pad_low[i] = padding_both[i].first;
265   }
266   auto result = absl::make_unique<Array3D<float>>(
267       window_counts[0], window_counts[1], window_counts[2]);
268 
269   for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
270     for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
271       for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
272         int64 i0_base = i0 * stride[0] - pad_low[0];
273         int64 i1_base = i1 * stride[1] - pad_low[1];
274         int64 i2_base = i2 * stride[2] - pad_low[2];
275 
276         float val = init;
277         for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
278           for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
279             for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
280               if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
281                   i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() &&
282                   i1_base + i1_win < operand.n2() &&
283                   i2_base + i2_win < operand.n3()) {
284                 val += operand(i0_base + i0_win, i1_base + i1_win,
285                                i2_base + i2_win);
286               }
287             }
288           }
289         }
290         (*result)(i0, i1, i2) = val;
291       }
292     }
293   }
294   return result;
295 }
296 
297 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)298 ReferenceUtil::ReduceWindow4DGeneric(
299     const Array4D<float>& operand, float init,
300     const std::function<float(float, float)>& reduce_func,
301     absl::Span<const int64> window, absl::Span<const int64> stride,
302     Padding padding) {
303   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
304                                  operand.n4()};
305   return ReduceWindow4DGeneric(
306       operand, init, reduce_func, window, stride,
307       xla::MakePadding(dim_lengths, window, stride, padding));
308 }
309 
310 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)311 ReferenceUtil::ReduceWindow4DGeneric(
312     const Array4D<float>& operand, float init,
313     const std::function<float(float, float)>& reduce_func,
314     absl::Span<const int64> window, absl::Span<const int64> stride,
315     absl::Span<const std::pair<int64, int64>> padding) {
316   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
317                                  operand.n4()};
318 
319   std::vector<int64> window_counts(window.size(), 0);
320   std::vector<int64> pad_low(window.size(), 0);
321   for (int64 i = 0; i < window.size(); ++i) {
322     int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
323     window_counts[i] =
324         window_util::StridedBound(padded_width, window[i], stride[i]);
325     pad_low[i] = padding[i].first;
326   }
327   auto result = absl::make_unique<Array4D<float>>(
328       window_counts[0], window_counts[1], window_counts[2], window_counts[3]);
329   // Do a full 4D reduce window.
330   for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
331     for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
332       for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
333         for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
334           int64 i0_base = i0 * stride[0] - pad_low[0];
335           int64 i1_base = i1 * stride[1] - pad_low[1];
336           int64 i2_base = i2 * stride[2] - pad_low[2];
337           int64 i3_base = i3 * stride[3] - pad_low[3];
338 
339           float val = init;
340           for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
341             for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
342               for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
343                 for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
344                   if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
345                       i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
346                       i0_base + i0_win < operand.n1() &&
347                       i1_base + i1_win < operand.n2() &&
348                       i2_base + i2_win < operand.n3() &&
349                       i3_base + i3_win < operand.n4()) {
350                     val = reduce_func(
351                         val, operand(i0_base + i0_win, i1_base + i1_win,
352                                      i2_base + i2_win, i3_base + i3_win));
353                   }
354                 }
355               }
356             }
357           }
358           (*result)(i0, i1, i2, i3) = val;
359         }
360       }
361     }
362   }
363   return result;
364 }
365 
ReduceWindow4DAdd(const Array4D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)366 /* static  */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
367     const Array4D<float>& operand, float init, absl::Span<const int64> window,
368     absl::Span<const int64> stride, Padding padding) {
369   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
370   return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
371                                padding);
372 }
373 
BatchNorm4D(const Array4D<float> & input,const Array4D<float> & mean,const Array4D<float> & var,const Array4D<float> & scale,const Array4D<float> & offset,float epsilon)374 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
375     const Array4D<float>& input, const Array4D<float>& mean,
376     const Array4D<float>& var, const Array4D<float>& scale,
377     const Array4D<float>& offset, float epsilon) {
378   auto normalized =
379       *MapArray4D(input, mean, [](float a, float b) { return a - b; });
380   normalized = *MapArray4D(normalized, var, [&](float a, float b) {
381     return a / std::sqrt(b + epsilon);
382   });
383   normalized =
384       *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
385   return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
386 }
387 
388 /* static  */ std::unique_ptr<Array4D<float>>
SelectAndScatter4DGePlus(const Array4D<float> & operand,const Array4D<float> & source,float init,absl::Span<const int64> window,absl::Span<const int64> stride,bool same_padding)389 ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
390                                         const Array4D<float>& source,
391                                         float init,
392                                         absl::Span<const int64> window,
393                                         absl::Span<const int64> stride,
394                                         bool same_padding) {
395   Padding padding = same_padding ? Padding::kSame : Padding::kValid;
396   auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
397                                                   operand.n3(), operand.n4());
398   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
399                                  operand.n4()};
400   auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
401   // Fill the output, with the initial value.
402   result->Fill(init);
403 
404   std::vector<int64> window_counts(window.size(), 0);
405   std::vector<int64> pad_low(window.size(), 0);
406   for (int64 i = 0; i < window.size(); ++i) {
407     window_counts[i] =
408         WindowCount(dim_lengths[i], window[i], stride[i], padding);
409     pad_low[i] = padding_both[i].first;
410   }
411   CHECK_EQ(window_counts[0], source.n1());
412   CHECK_EQ(window_counts[1], source.n2());
413   CHECK_EQ(window_counts[2], source.n3());
414   CHECK_EQ(window_counts[3], source.n4());
415 
416   // Do a full 4D select and Scatter.
417   for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
418     for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
419       for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
420         for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
421           // Now we are inside a window and need to find the max and the argmax.
422           int64 i0_base = i0 * stride[0] - pad_low[0];
423           int64 i1_base = i1 * stride[1] - pad_low[1];
424           int64 i2_base = i2 * stride[2] - pad_low[2];
425           int64 i3_base = i3 * stride[3] - pad_low[3];
426           int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
427           int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
428           int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
429           int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
430           float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
431           for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
432             for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
433               for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
434                 for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
435                   if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
436                       i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
437                       i0_base + i0_win < operand.n1() &&
438                       i1_base + i1_win < operand.n2() &&
439                       i2_base + i2_win < operand.n3() &&
440                       i3_base + i3_win < operand.n4()) {
441                     float tmp = operand(i0_base + i0_win, i1_base + i1_win,
442                                         i2_base + i2_win, i3_base + i3_win);
443                     if (tmp > val) {
444                       val = tmp;
445                       scatter_0 = i0_base + i0_win;
446                       scatter_1 = i1_base + i1_win;
447                       scatter_2 = i2_base + i2_win;
448                       scatter_3 = i3_base + i3_win;
449                     }
450                   }
451                 }
452               }
453             }
454           }
455           (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
456               source(i0, i1, i2, i3);
457         }
458       }
459     }
460   }
461   return result;
462 }
463 
464 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensions(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding,ConvolutionDimensionNumbers dimension_numbers)465 ReferenceUtil::ConvArray4DGeneralDimensions(
466     const Array4D<float>& lhs, const Array4D<float>& rhs,
467     std::pair<int64, int64> kernel_stride, Padding padding,
468     ConvolutionDimensionNumbers dimension_numbers) {
469   return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
470                                              {1, 1}, {1, 1},
471                                              std::move(dimension_numbers));
472 }
473 
474 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensionsDilated(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding,std::pair<int64,int64> lhs_dilation,std::pair<int64,int64> rhs_dilation,ConvolutionDimensionNumbers dnums)475 ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
476     const Array4D<float>& lhs, const Array4D<float>& rhs,
477     std::pair<int64, int64> kernel_stride, Padding padding,
478     std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
479     ConvolutionDimensionNumbers dnums) {
480   HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
481   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
482   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);
483 
484   std::array<int64, 2> ordered_kernel_strides;
485   std::array<int64, 2> ordered_input_dimensions;
486   std::array<int64, 2> ordered_kernel_dimensions;
487   if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
488     ordered_kernel_strides[0] = kernel_stride.second;
489     ordered_kernel_strides[1] = kernel_stride.first;
490   } else {
491     ordered_kernel_strides[0] = kernel_stride.first;
492     ordered_kernel_strides[1] = kernel_stride.second;
493   }
494 
495   ordered_input_dimensions[0] =
496       lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
497   ordered_input_dimensions[1] =
498       lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
499   ordered_kernel_dimensions[0] =
500       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
501   ordered_kernel_dimensions[1] =
502       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
503 
504   std::vector<std::pair<int64, int64>> paddings =
505       MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
506                   ordered_kernel_strides, padding);
507   CHECK_EQ(paddings.size(), 2);
508 
509   Window window;
510 
511   WindowDimension dim;
512   dim.set_size(
513       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
514   dim.set_stride(kernel_stride.first);
515   dim.set_padding_low(paddings[0].first);
516   dim.set_padding_high(paddings[0].second);
517   dim.set_window_dilation(rhs_dilation.first);
518   dim.set_base_dilation(lhs_dilation.first);
519   *window.add_dimensions() = dim;
520 
521   WindowDimension dim2;
522   dim2.set_size(
523       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
524   dim2.set_stride(kernel_stride.second);
525   dim2.set_padding_low(paddings[1].first);
526   dim2.set_padding_high(paddings[1].second);
527   dim2.set_window_dilation(rhs_dilation.second);
528   dim2.set_base_dilation(lhs_dilation.second);
529   *window.add_dimensions() = dim2;
530 
531   const Shape& shape =
532       ShapeInference::InferConvolveShape(
533           lhs_literal.shape(), rhs_literal.shape(),
534           /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums)
535           .ConsumeValueOrDie();
536 
537   HloInstruction* lhs_instruction =
538       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
539   HloInstruction* rhs_instruction =
540       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
541 
542   PrecisionConfig precision_config;
543   precision_config.mutable_operand_precision()->Resize(
544       /*new_size=*/2, PrecisionConfig::DEFAULT);
545   b.AddInstruction(HloInstruction::CreateConvolve(
546       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
547       /*batch_group_count=*/1, window, dnums, precision_config));
548   HloModuleConfig config;
549   HloModule module("ReferenceUtil", config);
550   auto computation = module.AddEntryComputation(b.Build());
551 
552   HloEvaluator evaluator;
553   Literal result_literal =
554       evaluator.Evaluate(*computation, {}).ConsumeValueOrDie();
555 
556   CHECK_EQ(result_literal.shape().rank(), 4);
557   auto result =
558       absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
559                                         result_literal.shape().dimensions(1),
560                                         result_literal.shape().dimensions(2),
561                                         result_literal.shape().dimensions(3));
562 
563   result->Each([&](absl::Span<const int64> indices, float* value) {
564     *value = result_literal.Get<float>(indices);
565   });
566 
567   return result;
568 }
569 
570 /* static */ std::unique_ptr<std::vector<float>>
ReduceToColArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)571 ReferenceUtil::ReduceToColArray2D(
572     const Array2D<float>& matrix, float init,
573     const std::function<float(float, float)>& reduce_function) {
574   int64 rows = matrix.height();
575   int64 cols = matrix.width();
576   auto result = absl::make_unique<std::vector<float>>();
577   for (int64 i = 0; i < rows; ++i) {
578     float acc = init;
579     for (int64 j = 0; j < cols; ++j) {
580       acc = reduce_function(acc, matrix(i, j));
581     }
582     result->push_back(acc);
583   }
584   return result;
585 }
586 
587 /* static */ std::unique_ptr<std::vector<float>>
ReduceToRowArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)588 ReferenceUtil::ReduceToRowArray2D(
589     const Array2D<float>& matrix, float init,
590     const std::function<float(float, float)>& reduce_function) {
591   int64 rows = matrix.height();
592   int64 cols = matrix.width();
593   auto result = absl::make_unique<std::vector<float>>();
594   for (int64 i = 0; i < cols; ++i) {
595     float acc = init;
596     for (int64 j = 0; j < rows; ++j) {
597       acc = reduce_function(acc, matrix(j, i));
598     }
599     result->push_back(acc);
600   }
601   return result;
602 }
603 
Reduce4DTo1D(const Array4D<float> & array,float init,absl::Span<const int64> dims,const std::function<float (float,float)> & reduce_function)604 /*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
605     const Array4D<float>& array, float init, absl::Span<const int64> dims,
606     const std::function<float(float, float)>& reduce_function) {
607   std::vector<float> result;
608   CHECK_EQ(dims.size(), 3);
609   const absl::flat_hash_set<int64> dim_set(dims.begin(), dims.end());
610   CHECK_EQ(dim_set.size(), 3);
611   for (int64 a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1());
612        ++a0) {
613     for (int64 a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2());
614          ++a1) {
615       for (int64 a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3());
616            ++a2) {
617         for (int64 a3 = 0; a3 == 0 || (!dim_set.contains(3) && a3 < array.n4());
618              ++a3) {
619           float accumulator = init;
620           for (int64 i0 = 0;
621                i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) {
622             for (int64 i1 = 0;
623                  i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) {
624               for (int64 i2 = 0;
625                    i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) {
626                 for (int64 i3 = 0;
627                      i3 == 0 || (dim_set.contains(3) && i3 < array.n4());
628                      ++i3) {
629                   // Handle zero-sized arrays.
630                   if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
631                       array.n4() > 0) {
632                     accumulator = reduce_function(
633                         accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
634                   }
635                 }
636               }
637             }
638           }
639           result.push_back(accumulator);
640         }
641       }
642     }
643   }
644   return result;
645 }
646 
Broadcast1DTo4D(const std::vector<float> & array,const std::vector<int64> & bounds,int64 broadcast_from_dim)647 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
648     const std::vector<float>& array, const std::vector<int64>& bounds,
649     int64 broadcast_from_dim) {
650   auto result = absl::make_unique<Array4D<float>>(bounds[0], bounds[1],
651                                                   bounds[2], bounds[3]);
652   for (int64 i = 0; i < result->n1(); ++i) {
653     for (int64 j = 0; j < result->n2(); ++j) {
654       for (int64 k = 0; k < result->n3(); ++k) {
655         for (int64 l = 0; l < result->n4(); ++l) {
656           switch (broadcast_from_dim) {
657             case 0:
658               (*result)(i, j, k, l) = array[i];
659               break;
660             case 1:
661               (*result)(i, j, k, l) = array[j];
662               break;
663             case 2:
664               (*result)(i, j, k, l) = array[k];
665               break;
666             case 3:
667               (*result)(i, j, k, l) = array[l];
668               break;
669             default:
670               break;
671           }
672         }
673       }
674     }
675   }
676   return result;
677 }
678 
Reduce3DTo2D(const Array3D<float> & array,float init,absl::Span<const int64> dims,const std::function<float (float,float)> & reduce_function)679 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
680     const Array3D<float>& array, float init, absl::Span<const int64> dims,
681     const std::function<float(float, float)>& reduce_function) {
682   CHECK_EQ(dims.size(), 1);
683   int64 rows = dims[0] == 0 ? array.n2() : array.n1();
684   int64 cols = dims[0] == 2 ? array.n2() : array.n3();
685   auto result = absl::make_unique<Array2D<float>>(rows, cols);
686   result->Fill(init);
687   for (int i0 = 0; i0 < array.n1(); ++i0) {
688     for (int i1 = 0; i1 < array.n2(); ++i1) {
689       for (int i2 = 0; i2 < array.n3(); ++i2) {
690         int64 row = dims[0] == 0 ? i1 : i0;
691         int64 col = dims[0] == 2 ? i1 : i2;
692         (*result)(row, col) =
693             reduce_function((*result)(row, col), array(i0, i1, i2));
694       }
695     }
696   }
697   return result;
698 }
699 
MapArray2D(const Array2D<float> & matrix,const std::function<float (float)> & map_function)700 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
701     const Array2D<float>& matrix,
702     const std::function<float(float)>& map_function) {
703   int64 rows = matrix.height();
704   int64 cols = matrix.width();
705   auto result = absl::make_unique<Array2D<float>>(rows, cols);
706   for (int64 i = 0; i < rows; ++i) {
707     for (int64 j = 0; j < cols; ++j) {
708       (*result)(i, j) = map_function(matrix(i, j));
709     }
710   }
711   return result;
712 }
713 
MapArray2D(const Array2D<float> & lhs,const Array2D<float> & rhs,const std::function<float (float,float)> & map_function)714 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
715     const Array2D<float>& lhs, const Array2D<float>& rhs,
716     const std::function<float(float, float)>& map_function) {
717   CHECK_EQ(lhs.height(), rhs.height());
718   CHECK_EQ(lhs.width(), rhs.width());
719   int64 rows = lhs.height();
720   int64 cols = rhs.width();
721   auto result = absl::make_unique<Array2D<float>>(rows, cols);
722   for (int64 i = 0; i < rows; ++i) {
723     for (int64 j = 0; j < cols; ++j) {
724       (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
725     }
726   }
727   return result;
728 }
729 
MapWithIndexArray2D(const Array2D<float> & matrix,const std::function<float (float,int64,int64)> & map_function)730 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
731     const Array2D<float>& matrix,
732     const std::function<float(float, int64, int64)>& map_function) {
733   int64 rows = matrix.height();
734   int64 cols = matrix.width();
735   auto result = absl::make_unique<Array2D<float>>(rows, cols);
736   for (int64 i = 0; i < rows; ++i) {
737     for (int64 j = 0; j < cols; ++j) {
738       (*result)(i, j) = map_function(matrix(i, j), i, j);
739     }
740   }
741   return result;
742 }
743 
744 }  // namespace xla
745