1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
18 
19 #include <array>
20 #include <functional>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/memory/memory.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/array2d.h"
28 #include "tensorflow/compiler/xla/array3d.h"
29 #include "tensorflow/compiler/xla/array4d.h"
30 #include "tensorflow/compiler/xla/client/padding.h"
31 #include "tensorflow/compiler/xla/util.h"
32 #include "tensorflow/compiler/xla/xla_data.pb.h"
33 #include "tensorflow/core/platform/macros.h"
34 #include "tensorflow/core/platform/types.h"
35 
36 namespace xla {
37 
38 // Utility class for reference implementations of linear algebra routines.
39 class ReferenceUtil {
40  public:
41   // Returns the result of a transpose operation on the input matrix.
42   template <typename T>
TransposeArray2D(const Array2D<T> & operand)43   static std::unique_ptr<Array2D<T>> TransposeArray2D(
44       const Array2D<T>& operand) {
45     auto result =
46         absl::make_unique<Array2D<T>>(operand.width(), operand.height());
47     for (int64 w = 0; w < operand.width(); ++w) {
48       for (int64 h = 0; h < operand.height(); ++h) {
49         (*result)(w, h) = operand(h, w);
50       }
51     }
52 
53     return result;
54   }
55 
  // Returns the result of a matrix multiply `lhs x rhs`. Overloads are
  // declared for Eigen::half, float and double element types.
  static std::unique_ptr<Array2D<Eigen::half>> MatmulArray2D(
      const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs);
  static std::unique_ptr<Array2D<float>> MatmulArray2D(
      const Array2D<float>& lhs, const Array2D<float>& rhs);
  static std::unique_ptr<Array2D<double>> MatmulArray2D(
      const Array2D<double>& lhs, const Array2D<double>& rhs);

  // Converts the input operand to use f64 values instead of f32 values.
  static std::unique_ptr<Array2D<double>> Array2DF32ToF64(
      const Array2D<float>& input);

  // Returns the result of a convolution `lhs <conv> rhs`, with the default
  // convolution dimension numbers returned from
  // ComputationBuilder::CreateDefaultConvDimensionNumbers().
  static std::unique_ptr<Array4D<float>> ConvArray4D(
      const Array4D<float>& lhs, const Array4D<float>& rhs,
      std::pair<int64, int64> kernel_stride, Padding padding);

  // Returns the result of a convolution `lhs <conv> rhs`, with the given
  // convolution dimension numbers.
  static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensions(
      const Array4D<float>& lhs, const Array4D<float>& rhs,
      std::pair<int64, int64> kernel_stride, Padding padding,
      ConvolutionDimensionNumbers dimension_numbers);

  // Returns the result of a convolution `lhs <conv> rhs`, with the given
  // dilation factors and convolution dimension numbers.
  static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensionsDilated(
      const Array4D<float>& lhs, const Array4D<float>& rhs,
      std::pair<int64, int64> kernel_stride, Padding padding,
      std::pair<int64, int64> lhs_dilation,
      std::pair<int64, int64> rhs_dilation, ConvolutionDimensionNumbers dnums);

  // Returns the result of a convolution `lhs <conv> rhs` on 3D arrays, with
  // the default convolution dimension numbers returned from
  // ComputationBuilder::CreateDefaultConvDimensionNumbers().
  static std::unique_ptr<Array3D<float>> ConvArray3D(const Array3D<float>& lhs,
                                                     const Array3D<float>& rhs,
                                                     int64 kernel_stride,
                                                     Padding padding);

  // Returns the result of a convolution `lhs <conv> rhs` on 3D arrays, with
  // the given dilation factors and convolution dimension numbers.
  static std::unique_ptr<Array3D<float>> ConvArray3DGeneralDimensionsDilated(
      const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
      Padding padding, int64 lhs_dilation, int64 rhs_dilation,
      const ConvolutionDimensionNumbers& dnums);

  // Returns the result of a separable convolution with the given parameters.
  // kernel_stride and padding apply to the depthwise convolution during
  // the separable convolution. pointwise_weights.depth() must be equal to
  // input.depth() * depthwise_weights.planes().
  static std::unique_ptr<Array4D<float>> SeparableConvArray4D(
      const Array4D<float>& input, const Array4D<float>& depthwise_weights,
      const Array4D<float>& pointwise_weights,
      std::pair<int64, int64> kernel_stride, Padding padding);

  // Returns the result of reducing a matrix to a column vector. init is the
  // initial value for the reduce operation, and reduce_function is the function
  // to apply for each reduction step.
  static std::unique_ptr<std::vector<float>> ReduceToColArray2D(
      const Array2D<float>& matrix, float init,
      const std::function<float(float, float)>& reduce_function);

  // Returns the result of reducing a matrix to a row vector. init is the
  // initial value for the reduce operation, and reduce_function is the function
  // to apply for each reduction step.
  static std::unique_ptr<std::vector<float>> ReduceToRowArray2D(
      const Array2D<float>& matrix, float init,
      const std::function<float(float, float)>& reduce_function);
126 
127   // Performs a R2=>R1 reduction by reducing away the dimension specified in
128   // 'dimension_to_reduce'.
129   template <typename T>
ReduceR2ToR1(const Array2D<T> & input,int dimension_to_reduce,T init,const std::function<T (T,T)> & freduce)130   static std::vector<T> ReduceR2ToR1(const Array2D<T>& input,
131                                      int dimension_to_reduce, T init,
132                                      const std::function<T(T, T)>& freduce) {
133     std::vector<T> result(dimension_to_reduce == 0 ? input.n2() : input.n1(),
134                           init);
135     for (int i0 = 0; i0 < input.n1(); ++i0) {
136       for (int i1 = 0; i1 < input.n2(); ++i1) {
137         int output = dimension_to_reduce == 0 ? i1 : i0;
138         result[output] = freduce(result[output], input(i0, i1));
139       }
140     }
141     return result;
142   }
143 
  // Returns the result of reducing the 4D array to a vector, reducing away
  // the dimensions specified in dims. init is the initial value for the
  // reduction and reduce_function is applied for each reduction step.
  static std::vector<float> Reduce4DTo1D(
      const Array4D<float>& array, float init, absl::Span<const int64> dims,
      const std::function<float(float, float)>& reduce_function);

  // Broadcasts a 1D array to 4D with the given bounds, from the dimension
  // `broadcast_from_dim`.
  static std::unique_ptr<Array4D<float>> Broadcast1DTo4D(
      const std::vector<float>& array, const std::vector<int64>& bounds,
      int64 broadcast_from_dim);

  // Returns the result of reducing the 3D array to a 2D array, reducing away
  // the dimensions specified in dims.
  static std::unique_ptr<Array2D<float>> Reduce3DTo2D(
      const Array3D<float>& array, float init, absl::Span<const int64> dims,
      const std::function<float(float, float)>& reduce_function);

  // Applies map_function to each element in the input (2D array) and returns
  // the result.
  static std::unique_ptr<Array2D<float>> MapArray2D(
      const Array2D<float>& matrix,
      const std::function<float(float)>& map_function);

  // Applies map_function to each pair of corresponding elements in the two
  // input arrays and returns the result.
  static std::unique_ptr<Array2D<float>> MapArray2D(
      const Array2D<float>& lhs, const Array2D<float>& rhs,
      const std::function<float(float, float)>& map_function);

  // Number of windows in a given dimension. Calculation taken from
  // xla::MakePadding().
  static int64 WindowCount(int64 unpadded_width, int64 window_len, int64 stride,
                           Padding padding);

  // Windowed reductions with Add as the function to apply. One overload per
  // operand rank (1D through 4D); window and stride are given per dimension.
  static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
      absl::Span<const float> operand, float init,
      absl::Span<const int64> window, absl::Span<const int64> stride,
      Padding padding);
  static std::unique_ptr<Array2D<float>> ReduceWindow2DAdd(
      const Array2D<float>& operand, float init, absl::Span<const int64> window,
      absl::Span<const int64> stride, Padding padding);
  static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
      const Array3D<float>& operand, float init, absl::Span<const int64> window,
      absl::Span<const int64> stride, Padding padding);
  static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
      const Array4D<float>& operand, float init, absl::Span<const int64> window,
      absl::Span<const int64> stride, Padding padding);

  // Windowed reductions with a generic reduce function.
  // NOTE(review): each `padding` pair is presumably the (low, high) padding of
  // one dimension -- confirm against the implementation.
  static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
      absl::Span<const float> operand, float init,
      const std::function<float(float, float)>& reduce_func,
      absl::Span<const int64> window, absl::Span<const int64> stride,
      absl::Span<const std::pair<int64, int64>> padding);
  static std::unique_ptr<Array2D<float>> ReduceWindow2DGeneric(
      const Array2D<float>& operand, float init,
      const std::function<float(float, float)>& reduce_func,
      absl::Span<const int64> window, absl::Span<const int64> stride,
      absl::Span<const std::pair<int64, int64>> padding);
  // 4D variant taking a Padding enum value instead of explicit padding pairs.
  static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
      const Array4D<float>& operand, float init,
      const std::function<float(float, float)>& reduce_func,
      absl::Span<const int64> window, absl::Span<const int64> stride,
      Padding padding);
  // With arbitrary padding.
  static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
      const Array4D<float>& operand, float init,
      const std::function<float(float, float)>& reduce_func,
      absl::Span<const int64> window, absl::Span<const int64> stride,
      absl::Span<const std::pair<int64, int64>> padding);

  // Batch normalize data.
  static std::unique_ptr<Array4D<float>> BatchNorm4D(
      const Array4D<float>& input, const Array4D<float>& mean,
      const Array4D<float>& var, const Array4D<float>& scale,
      const Array4D<float>& offset, float epsilon);

  // Performs select and scatter with Greater Than or equal as the select, plus
  // as the scatter, and Same Padding.
  // NOTE(review): the signature also takes a `same_padding` flag, so the
  // "Same Padding" wording above may only describe same_padding == true --
  // confirm against the implementation.
  // TODO(b/74533103) Switch tests to evaluator and remove this implementation.
  static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
      const Array4D<float>& operand, const Array4D<float>& source, float init,
      absl::Span<const int64> window, absl::Span<const int64> stride,
      bool same_padding);
229 
230   // Concatenates the lhs and rhs arrays along the concatenate_dimension.
231   // E.g. if concatenate_dimension is 0, the "n1"/height dimension is
232   // concatenated, so the arrays are stacked on top of each other.
233   template <typename T>
Concat2D(const Array2D<T> & lhs,const Array2D<T> & rhs,int concatenate_dimension)234   static std::unique_ptr<Array2D<T>> Concat2D(const Array2D<T>& lhs,
235                                               const Array2D<T>& rhs,
236                                               int concatenate_dimension) {
237     CHECK(0 <= concatenate_dimension && concatenate_dimension < 2);
238     auto result = absl::make_unique<Array2D<T>>(
239         concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(),
240         concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2());
241     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
242       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
243         // If we exceed the bounds of the LHS, draw from the RHS, where the
244         // result index is adjusted by the number of values present in the LHS.
245         (*result)(i0, i1) = i0 < lhs.n1() && i1 < lhs.n2()
246                                 ? lhs(i0, i1)
247                                 : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
248                                       i1 >= lhs.n2() ? i1 - lhs.n2() : i1);
249       }
250     }
251     return result;
252   }
253 
254   // Concatenates the lhs and rhs 3D arrays along the concatenate_dimension. lhs
255   // and rhs must have the same dimensions except for the concatenate dimension.
256   template <typename T>
Concat3D(const Array3D<T> & lhs,const Array3D<T> & rhs,int concatenate_dimension)257   static std::unique_ptr<Array3D<T>> Concat3D(const Array3D<T>& lhs,
258                                               const Array3D<T>& rhs,
259                                               int concatenate_dimension) {
260     CHECK(0 <= concatenate_dimension && concatenate_dimension < 3);
261     const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3()};
262     const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()};
263     int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()};
264     for (int i = 0; i < 3; ++i) {
265       if (i != concatenate_dimension) {
266         out_dims[i] = lhs_dims[i];
267         CHECK_EQ(lhs_dims[i], rhs_dims[i]);
268       } else {
269         out_dims[i] = lhs_dims[i] + rhs_dims[i];
270       }
271     }
272     auto result =
273         absl::make_unique<Array3D<T>>(out_dims[0], out_dims[1], out_dims[2]);
274     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
275       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
276         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
277           (*result)(i0, i1, i2) =
278               i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3()
279                   ? lhs(i0, i1, i2)
280                   : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
281                         i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
282                         i2 >= lhs.n3() ? i2 - lhs.n3() : i2);
283         }
284       }
285     }
286     return result;
287   }
288 
289   // Concatenates the lhs and rhs 4D arrays along the concatenate_dimension. lhs
290   // and rhs must have the same dimensions except for the concatenate dimension.
291   template <typename T>
Concat4D(const Array4D<T> & lhs,const Array4D<T> & rhs,int concatenate_dimension)292   static std::unique_ptr<Array4D<T>> Concat4D(const Array4D<T>& lhs,
293                                               const Array4D<T>& rhs,
294                                               int concatenate_dimension) {
295     CHECK(0 <= concatenate_dimension && concatenate_dimension < 4);
296     const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()};
297     const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
298     int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
299     for (int i = 0; i < 4; ++i) {
300       if (i != concatenate_dimension) {
301         out_dims[i] = lhs_dims[i];
302         CHECK_EQ(lhs_dims[i], rhs_dims[i]);
303       } else {
304         out_dims[i] = lhs_dims[i] + rhs_dims[i];
305       }
306     }
307     auto result = absl::make_unique<Array4D<T>>(out_dims[0], out_dims[1],
308                                                 out_dims[2], out_dims[3]);
309     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
310       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
311         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
312           for (int64 i3 = 0; i3 < result->n4(); ++i3) {
313             (*result)(i0, i1, i2, i3) =
314                 i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3() && i3 < lhs.n4()
315                     ? lhs(i0, i1, i2, i3)
316                     : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
317                           i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
318                           i2 >= lhs.n3() ? i2 - lhs.n3() : i2,
319                           i3 >= lhs.n4() ? i3 - lhs.n4() : i3);
320           }
321         }
322       }
323     }
324     return result;
325   }
326 
327   // Slices with index clamping
328   template <typename T>
ClampSlice1D(absl::Span<const T> input,int64 start,int64 size)329   static std::vector<T> ClampSlice1D(absl::Span<const T> input, int64 start,
330                                      int64 size) {
331     start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
332     std::vector<T> result;
333     for (int64 i = 0; i < size; ++i) {
334       result.push_back(input[(start + i)]);
335     }
336     return result;
337   }
338 
339   // Slices the input array given starting indices, limit indices, and strides
340   // in each dimension.
341   template <typename T>
Slice2D(const Array2D<T> & input,std::array<int64,2> starts,std::array<int64,2> limits,std::array<int64,2> strides)342   static std::unique_ptr<Array2D<T>> Slice2D(const Array2D<T>& input,
343                                              std::array<int64, 2> starts,
344                                              std::array<int64, 2> limits,
345                                              std::array<int64, 2> strides) {
346     CHECK_LE(starts[0], input.n1());
347     CHECK_LE(starts[1], input.n2());
348     CHECK_LE(limits[0], input.n1());
349     CHECK_LE(limits[1], input.n2());
350     CHECK_GE(strides[0], 1);
351     CHECK_GE(strides[1], 1);
352     auto result = absl::make_unique<Array2D<T>>(
353         CeilOfRatio(limits[0] - starts[0], strides[0]),
354         CeilOfRatio(limits[1] - starts[1], strides[1]));
355     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
356       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
357         (*result)(i0, i1) =
358             input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1]);
359       }
360     }
361     return result;
362   }
363 
364   template <typename T>
Slice3D(const Array3D<T> & input,std::array<int64,3> starts,std::array<int64,3> limits,std::array<int64,3> strides)365   static std::unique_ptr<Array3D<T>> Slice3D(const Array3D<T>& input,
366                                              std::array<int64, 3> starts,
367                                              std::array<int64, 3> limits,
368                                              std::array<int64, 3> strides) {
369     CHECK_LE(starts[0], input.n1());
370     CHECK_LE(starts[1], input.n2());
371     CHECK_LE(starts[2], input.n3());
372     CHECK_LE(limits[0], input.n1());
373     CHECK_LE(limits[1], input.n2());
374     CHECK_LE(limits[2], input.n3());
375     CHECK_GE(strides[0], 1);
376     CHECK_GE(strides[1], 1);
377     CHECK_GE(strides[2], 1);
378     auto result = absl::make_unique<Array3D<T>>(
379         CeilOfRatio(limits[0] - starts[0], strides[0]),
380         CeilOfRatio(limits[1] - starts[1], strides[1]),
381         CeilOfRatio(limits[2] - starts[2], strides[2]));
382 
383     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
384       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
385         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
386           (*result)(i0, i1, i2) =
387               input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
388                     starts[2] + i2 * strides[2]);
389         }
390       }
391     }
392     return result;
393   }
394 
395   template <typename T>
Slice4D(const Array4D<T> & input,std::array<int64,4> starts,std::array<int64,4> limits,std::array<int64,4> strides)396   static std::unique_ptr<Array4D<T>> Slice4D(const Array4D<T>& input,
397                                              std::array<int64, 4> starts,
398                                              std::array<int64, 4> limits,
399                                              std::array<int64, 4> strides) {
400     CHECK_LE(starts[0], input.n1());
401     CHECK_LE(starts[1], input.n2());
402     CHECK_LE(starts[2], input.n3());
403     CHECK_LE(starts[3], input.n4());
404     CHECK_LE(limits[0], input.n1());
405     CHECK_LE(limits[1], input.n2());
406     CHECK_LE(limits[2], input.n3());
407     CHECK_LE(limits[3], input.n4());
408     CHECK_GE(strides[0], 1);
409     CHECK_GE(strides[1], 1);
410     CHECK_GE(strides[2], 1);
411     CHECK_GE(strides[3], 1);
412     auto result = absl::make_unique<Array4D<T>>(
413         CeilOfRatio(limits[0] - starts[0], strides[0]),
414         CeilOfRatio(limits[1] - starts[1], strides[1]),
415         CeilOfRatio(limits[2] - starts[2], strides[2]),
416         CeilOfRatio(limits[3] - starts[3], strides[3]));
417     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
418       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
419         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
420           for (int64 i3 = 0; i3 < result->n4(); ++i3) {
421             (*result)(i0, i1, i2, i3) =
422                 input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
423                       starts[2] + i2 * strides[2], starts[3] + i3 * strides[3]);
424           }
425         }
426       }
427     }
428     return result;
429   }
430 
  // Applies map_function to each element in the input (2D array) and returns
  // the result as a newly allocated array.
  // The (row, column) index of each element is also provided as arguments to
  // map_function, after the element value.
  static std::unique_ptr<Array2D<float>> MapWithIndexArray2D(
      const Array2D<float>& matrix,
      const std::function<float(float, int64, int64)>& map_function);
438 
439   // Applies map_function to each element in the input (4D array) and returns
440   // the result.
441   template <typename F>
MapArray4D(const Array4D<float> & input,F && map_function)442   static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& input,
443                                                     F&& map_function) {
444     return MapWithIndexArray4D(input,
445                                [&](float value, int64, int64, int64, int64) {
446                                  return map_function(value);
447                                });
448   }
449 
450   // Applies map_function to each element in the input (4D array) and returns
451   // the result.
452   // (plane, depth, height, width) index of each element is also provided as
453   // arguments to map_function.
454   template <typename F>
MapWithIndexArray4D(const Array4D<float> & input,F && map_function)455   static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
456       const Array4D<float>& input, F&& map_function) {
457     auto result = absl::make_unique<Array4D<float>>(
458         input.planes(), input.depth(), input.height(), input.width());
459     for (int64 plane = 0; plane < input.planes(); ++plane) {
460       for (int64 depth = 0; depth < input.depth(); ++depth) {
461         for (int64 height = 0; height < input.height(); ++height) {
462           for (int64 width = 0; width < input.width(); ++width) {
463             (*result)(plane, depth, height, width) =
464                 map_function(input(plane, depth, height, width), plane, depth,
465                              height, width);
466           }
467         }
468       }
469     }
470     return result;
471   }
472 
473   // Applies map_function to each pair of elements in the input lhs and rhs
474   // (4D array) and returns the result.
475   template <typename F>
MapArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,F && map_function)476   static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& lhs,
477                                                     const Array4D<float>& rhs,
478                                                     F&& map_function) {
479     return MapWithIndexArray4D(
480         lhs, rhs, [&](float lhs, float rhs, int64, int64, int64, int64) {
481           return map_function(lhs, rhs);
482         });
483   }
484 
485   // Applies map_function to each pair of element in lhs and rhs (4D array) and
486   // returns the result.
487   // (plane, depth, height, width) index of each element is also provided as
488   // arguments to map_function.
489   template <typename F>
MapWithIndexArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,F && map_function)490   static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
491       const Array4D<float>& lhs, const Array4D<float>& rhs, F&& map_function) {
492     auto result = absl::make_unique<Array4D<float>>(lhs.planes(), lhs.depth(),
493                                                     lhs.height(), lhs.width());
494     for (int64 plane = 0; plane < lhs.planes(); ++plane) {
495       for (int64 depth = 0; depth < lhs.depth(); ++depth) {
496         for (int64 height = 0; height < lhs.height(); ++height) {
497           for (int64 width = 0; width < lhs.width(); ++width) {
498             (*result)(plane, depth, height, width) = map_function(
499                 lhs(plane, depth, height, width),
500                 rhs(plane, depth, height, width), plane, depth, height, width);
501           }
502         }
503       }
504     }
505     return result;
506   }
507 
508   // Returns the result of a 2D pad on an input matrix.
509   template <typename NativeT>
PadArray2D(const Array2D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)510   static std::unique_ptr<Array2D<NativeT>> PadArray2D(
511       const Array2D<NativeT>& operand, const PaddingConfig& padding,
512       const NativeT pad) {
513     int64 in0 = operand.n1();
514     int64 high_padding0 = padding.dimensions(0).edge_padding_high();
515     int64 low_padding0 = padding.dimensions(0).edge_padding_low();
516     int64 interior_padding0 = padding.dimensions(0).interior_padding();
517     int64 out0 =
518         in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
519 
520     int64 in1 = operand.n2();
521     int64 high_padding1 = padding.dimensions(1).edge_padding_high();
522     int64 low_padding1 = padding.dimensions(1).edge_padding_low();
523     int64 interior_padding1 = padding.dimensions(1).interior_padding();
524     int64 out1 =
525         in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
526 
527     auto result = absl::make_unique<Array2D<NativeT>>(out0, out1);
528     result->Fill(pad);
529     int64 o0 = low_padding0;
530     for (int64 i0 = 0; i0 < in0; ++i0) {
531       int64 o1 = low_padding1;
532       for (int64 i1 = 0; i1 < in1; ++i1) {
533         if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) {
534           (*result)(o0, o1) = operand(i0, i1);
535         }
536         o1 += interior_padding1 + 1;
537       }
538       o0 += interior_padding0 + 1;
539     }
540     return result;
541   }
542 
543   // Returns the result of a 3D pad on an input matrix.
544   template <typename NativeT>
PadArray3D(const Array3D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)545   static Array3D<NativeT> PadArray3D(const Array3D<NativeT>& operand,
546                                      const PaddingConfig& padding,
547                                      const NativeT pad) {
548     CHECK_EQ(padding.dimensions_size(), 3);
549 
550     const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3()};
551     int64 pad_low[3];
552     int64 pad_high[3];
553     int64 pad_interior[3];
554     int64 output_bounds[3];
555     for (int64 i = 0; i < 3; ++i) {
556       pad_low[i] = padding.dimensions(i).edge_padding_low();
557       pad_high[i] = padding.dimensions(i).edge_padding_high();
558       CHECK_LE(0, pad_low[i]);
559       CHECK_LE(0, pad_high[i]);
560       CHECK_LE(0, padding.dimensions(i).interior_padding())
561           << "not implemented";
562       pad_interior[i] = padding.dimensions(i).interior_padding();
563 
564       output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
565                          (input_bounds[i] - 1) * pad_interior[i];
566     }
567 
568     Array3D<NativeT> result(output_bounds[0], output_bounds[1],
569                             output_bounds[2]);
570     int indices[] = {0, 0, 0};
571     for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) {
572       for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) {
573         for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) {
574           NativeT* value = &result(indices[0], indices[1], indices[2]);
575           bool value_padded = false;
576           for (int i = 0; i < 3; ++i) {
577             bool in_low_padding = indices[i] < pad_low[i];
578             bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
579             if (in_low_padding || in_high_padding) {
580               *value = pad;
581               value_padded = true;
582             }
583             if (pad_interior[i] &&
584                 (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
585               *value = pad;
586               value_padded = true;
587             }
588           }
589           if (value_padded) {
590             continue;
591           }
592           *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
593                            (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
594                            (indices[2] - pad_low[2]) / (pad_interior[2] + 1));
595         }
596       }
597     }
598     return result;
599   }
600 
601   // Returns the result of a 4D pad on an input array.
602   template <typename NativeT>
PadArray4D(const Array4D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)603   static Array4D<NativeT> PadArray4D(const Array4D<NativeT>& operand,
604                                      const PaddingConfig& padding,
605                                      const NativeT pad) {
606     CHECK_EQ(padding.dimensions_size(), 4);
607 
608     const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3(),
609                                   operand.n4()};
610     int64 pad_low[4];
611     int64 pad_high[4];
612     int64 pad_interior[4];
613     int64 output_bounds[4];
614     for (int64 i = 0; i < 4; ++i) {
615       pad_low[i] = padding.dimensions(i).edge_padding_low();
616       pad_high[i] = padding.dimensions(i).edge_padding_high();
617       CHECK_LE(0, padding.dimensions(i).interior_padding())
618           << "not implemented";
619       pad_interior[i] = padding.dimensions(i).interior_padding();
620 
621       output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
622                          (input_bounds[i] - 1) * pad_interior[i];
623     }
624 
625     Array4D<NativeT> result(output_bounds[0], output_bounds[1],
626                             output_bounds[2], output_bounds[3]);
627     result.Each(
628         [&](absl::Span<const int64> indices, NativeT* value) {
629           for (int i = 0; i < 4; ++i) {
630             bool in_low_padding = indices[i] < pad_low[i];
631             bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
632             if (in_low_padding || in_high_padding) {
633               *value = pad;
634               return;
635             }
636             if (pad_interior[i] &&
637                 (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
638               *value = pad;
639               return;
640             }
641           }
642           *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
643                            (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
644                            (indices[2] - pad_low[2]) / (pad_interior[2] + 1),
645                            (indices[3] - pad_low[3]) / (pad_interior[3] + 1));
646         });
647     return result;
648   }
649 
650   // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running
651   // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, ....
652   //
653   // The given arrays must have the same size and element type, and the return
654   // type of f must be implicitly convertible to the arrays' element type.
655   //
656   // Example usage:
657   //
658   //   Array2D<float> x, y, z = ...;
659   //   std::unique_ptr<Array2D> result = ReferenceUtil::ApplyElementwise2D(
660   //     [](float a, float b, float c) { return a * b + c; }, x, y, z);
661   //
662   template <typename F, typename T1, typename... Ts>
ApplyElementwise2D(F && f,const Array2D<T1> & array1,const Array2D<Ts> &...arrays)663   static std::unique_ptr<Array2D<T1>> ApplyElementwise2D(
664       F&& f, const Array2D<T1>& array1, const Array2D<Ts>&... arrays) {
665     AssertSameSize2D(array1, arrays...);
666     auto result = absl::make_unique<Array2D<T1>>(array1.n1(), array1.n2());
667     for (int64 i = 0; i < array1.n1(); ++i) {
668       for (int64 j = 0; j < array1.n2(); ++j) {
669         (*result)(i, j) = f(array1(i, j), arrays(i, j)...);
670       }
671     }
672     return result;
673   }
674 
675  private:
676   template <typename T1, typename T2, typename... Ts>
AssertSameSize2D(const Array2D<T1> & array1,const Array2D<T2> & array2,const Array2D<Ts> &...arrays)677   static void AssertSameSize2D(const Array2D<T1>& array1,
678                                const Array2D<T2>& array2,
679                                const Array2D<Ts>&... arrays) {
680     static_assert(std::is_same<T1, T2>::value, "Args must be same type.");
681     CHECK_EQ(array1.n1(), array2.n1());
682     CHECK_EQ(array1.n2(), array2.n2());
683     AssertSameSize2D(array2, arrays...);
684   }
685 
686   // Recursive base case for AssertSameSize2D.
687   template <typename Array1>
AssertSameSize2D(const Array1 & array1)688   static void AssertSameSize2D(const Array1& array1) {}
689 
690   TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
691 };
692 
693 }  // namespace xla
694 
695 #endif  // TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
696