/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/delegates/gpu/common/operations.h"

#include <algorithm>
#include <cstdint>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/lite/delegates/gpu/common/shape.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/delegates/gpu/common/tensor.h"

namespace tflite {
namespace gpu {

Padding2D& Padding2D::operator=(const Padding2D& value) {
  prepended = value.prepended;
  appended = value.appended;
  return *this;
}

bool Padding2D::operator==(const Padding2D& value) {
  return this->prepended == value.prepended && this->appended == value.appended;
}

bool Padding2D::operator!=(const Padding2D& value) { return !(*this == value); }

Padding2D& Padding2D::operator-(const Padding2D& value) {
  prepended.h -= value.prepended.h;
  prepended.w -= value.prepended.w;
  appended.h -= value.appended.h;
  appended.w -= value.appended.w;
  return *this;
}

Padding3D& Padding3D::operator=(const Padding3D& value) {
  prepended = value.prepended;
  appended = value.appended;
  return *this;
}

bool Padding3D::operator==(const Padding3D& value) {
  return this->prepended == value.prepended && this->appended == value.appended;
}

bool Padding3D::operator!=(const Padding3D& value) { return !(*this == value); }

Padding3D& Padding3D::operator-(const Padding3D& value) {
  prepended.h -= value.prepended.h;
  prepended.w -= value.prepended.w;
  prepended.d -= value.prepended.d;
  appended.h -= value.appended.h;
  appended.w -= value.appended.w;
  appended.d -= value.appended.d;
  return *this;
}

std::string ToString(enum OperationType op) {
  switch (op) {
    case OperationType::ABS:
      return "abs";
    case OperationType::ADD:
      return "add";
    case OperationType::BATCH_NORMALIZATION:
      return "batch_normalization";
    case OperationType::BATCH_TO_SPACE:
      return "batch_to_space";
    case OperationType::BATCHED_MATMUL:
      return "batched_matmul";
    case OperationType::CONCAT:
      return "concat";
    case OperationType::CONSTANT:
      return "const";
    case OperationType::CONVOLUTION_2D:
      return "convolution_2d";
    case OperationType::CONVOLUTION_TRANSPOSED:
      return "convolution_transposed";
    case OperationType::COPY:
      return "copy";
    case OperationType::COS:
      return "cos";
    case OperationType::DEPTHWISE_CONVOLUTION:
      return "depthwise_convolution";
    case OperationType::DIV:
      return "div";
    case OperationType::ELU:
      return "elu";
    case OperationType::EQUAL:
      return "equal";
    case OperationType::EXP:
      return "exp";
    case OperationType::FULLY_CONNECTED:
      return "fully_connected";
    case OperationType::GREATER:
      return "greater";
    case OperationType::GREATER_EQUAL:
      return "greater_equal";
    case OperationType::HARD_SWISH:
      return "hard_swish";
    case OperationType::LESS:
      return "less";
    case OperationType::LESS_EQUAL:
      return "less_equal";
    case OperationType::LOG:
      return "log";
    case OperationType::LSTM:
      return "lstm";
    case OperationType::MAXIMUM:
      return "maximum";
    case OperationType::MAX_UNPOOLING_2D:
      return "max_unpooling";
    case OperationType::MEAN:
      return "mean";
    case OperationType::MEAN_STDDEV_NORMALIZATION:
      return "mean_stddev_normalization";
    case OperationType::MINIMUM:
      return "minimum";
    case OperationType::MUL:
      return "mul";
    case OperationType::NEG:
      return "neg";
    case OperationType::NOT_EQUAL:
      return "not_equal";
    case OperationType::PAD:
      return "pad";
    case OperationType::POOLING_2D:
      return "pooling_2d";
    case OperationType::POW:
      return "pow";
    case OperationType::PRELU:
      return "prelu";
    case OperationType::QUANTIZE_AND_DEQUANTIZE:
      return "quantize_and_dequantize";
    case OperationType::REDUCE_MAXIMUM:
      return "reduce_maximum";
    case OperationType::REDUCE_MINIMUM:
      return "reduce_minimum";
    case OperationType::REDUCE_PRODUCT:
      return "reduce_product";
    case OperationType::REDUCE_SUM:
      return "reduce_sum";
    case OperationType::RELU:
      return "relu";
    case OperationType::RESHAPE:
      return "reshape";
    case OperationType::RESIZE:
      return "resize";
    case OperationType::RSQRT:
      return "rsqrt";
    case OperationType::SIGMOID:
      return "sigmoid";
    case OperationType::SIN:
      return "sin";
    case OperationType::SLICE:
      return "slice";
    case OperationType::SOFTMAX:
      return "softmax";
    case OperationType::SPACE_TO_BATCH:
      return "space_to_batch";
    case OperationType::SPACE_TO_DEPTH:
      return "space_to_depth";
    case OperationType::SPLIT:
      return "split";
    case OperationType::SQRT:
      return "sqrt";
    case OperationType::SQUARE:
      return "square";
    case OperationType::SQUARED_DIFF:
      return "squared_diff";
    case OperationType::SUB:
      return "subtract";
    case OperationType::TANH:
      return "tanh";
    case OperationType::TRANSPOSE:
      return "transpose";
    case OperationType::UNKNOWN:
      return "unknown_operation";
  }
}

OperationType OperationTypeFromString(const std::string& name) {
  static const auto operations =
      new absl::flat_hash_map<std::string, OperationType>({
          {"abs", OperationType::ABS},
          {"add", OperationType::ADD},
          {"batch_normalization", OperationType::BATCH_NORMALIZATION},
          {"batched_matmul", OperationType::BATCHED_MATMUL},
          {"concat", OperationType::CONCAT},
          {"const", OperationType::CONSTANT},
          {"convolution_2d", OperationType::CONVOLUTION_2D},
          {"convolution_transposed", OperationType::CONVOLUTION_TRANSPOSED},
          {"copy", OperationType::COPY},
          {"cos", OperationType::COS},
          {"depthwise_convolution", OperationType::DEPTHWISE_CONVOLUTION},
          {"div", OperationType::DIV},
          {"elu", OperationType::ELU},
          {"equal", OperationType::EQUAL},
          {"exp", OperationType::EXP},
          {"fully_connected", OperationType::FULLY_CONNECTED},
          {"greater", OperationType::GREATER},
          {"greater_equal", OperationType::GREATER_EQUAL},
          {"hard_swish", OperationType::HARD_SWISH},
          {"less", OperationType::LESS},
          {"less_equal", OperationType::LESS_EQUAL},
          {"log", OperationType::LOG},
          {"lstm", OperationType::LSTM},
          {"maximum", OperationType::MAXIMUM},
          {"max_unpooling", OperationType::MAX_UNPOOLING_2D},
          {"mean", OperationType::MEAN},
          {"mean_stddev_normalization",
           OperationType::MEAN_STDDEV_NORMALIZATION},
          {"minimum", OperationType::MINIMUM},
          {"mul", OperationType::MUL},
          {"neg", OperationType::NEG},
          {"not_equal", OperationType::NOT_EQUAL},
          {"pad", OperationType::PAD},
          {"pooling_2d", OperationType::POOLING_2D},
          {"pow", OperationType::POW},
          {"prelu", OperationType::PRELU},
          {"quantize_and_dequantize", OperationType::QUANTIZE_AND_DEQUANTIZE},
          {"reduce_maximum", OperationType::REDUCE_MAXIMUM},
          {"reduce_minimum", OperationType::REDUCE_MINIMUM},
          {"reduce_product", OperationType::REDUCE_PRODUCT},
          {"reduce_sum", OperationType::REDUCE_SUM},
          {"relu", OperationType::RELU},
          {"resize", OperationType::RESIZE},
          {"reshape", OperationType::RESHAPE},
          {"rsqrt", OperationType::RSQRT},
          {"sigmoid", OperationType::SIGMOID},
          {"sin", OperationType::SIN},
          {"slice", OperationType::SLICE},
          {"softmax", OperationType::SOFTMAX},
          {"space_to_depth", OperationType::SPACE_TO_DEPTH},
          {"split", OperationType::SPLIT},
          {"sqrt", OperationType::SQRT},
          {"square", OperationType::SQUARE},
          {"squared_diff", OperationType::SQUARED_DIFF},
          {"subtract", OperationType::SUB},
          {"tanh", OperationType::TANH},
          {"transpose", OperationType::TRANSPOSE},
      });
  auto op = operations->find(name);
  return op == operations->end() ? OperationType::UNKNOWN : op->second;
}

namespace {

template <typename T>
T DivideRoundUp(T n, T divisor) {
  return (n - 1) / divisor + 1;
}
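// Sanity check with illustrative values: DivideRoundUp(7, 2) == (7 - 1) / 2 + 1
// == 4, i.e. integer ceil(7 / 2). The formula assumes divisor > 0.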

int32_t CalculateOutputSizeBeforeStrides(int32_t input, int32_t kernel,
                                         int32_t padding, int32_t dilation) {
  const int32_t dilated_kernel = (kernel - 1) * dilation + 1;
  return input + padding - dilated_kernel + 1;
}
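// Worked example (illustrative values): input = 10, kernel = 3, dilation = 2
// gives dilated_kernel = (3 - 1) * 2 + 1 = 5; with total padding = 2 the
// pre-stride output size is 10 + 2 - 5 + 1 = 8.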

template <Axis T>
int32_t CalculateOutputWithoutStrides(const BHWC& input,
                                      const Convolution2DAttributes& attr) {
  return CalculateOutputSizeBeforeStrides(
      input.get<T>(), attr.weights.shape.get<T>(),
      attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
      attr.dilations.get<T>());
}

template <Axis T>
int32_t CalculateOutputWithoutStrides(const BHWDC& input,
                                      const Convolution3DAttributes& attr) {
  return CalculateOutputSizeBeforeStrides(
      input.get<T>(), attr.weights.shape.get<T>(),
      attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
      attr.dilations.get<T>());
}

template <Axis T>
int32_t CalculateOutputWithoutStrides(const BHWC& input,
                                      const Pooling2DAttributes& attr) {
  return CalculateOutputSizeBeforeStrides(
      input.get<T>(), attr.kernel.get<T>(),
      attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
      /*dilation=*/1);
}

template <Axis T>
int32_t CalculateOutputWithoutStrides(const BHWDC& input,
                                      const Pooling3DAttributes& attr) {
  return CalculateOutputSizeBeforeStrides(
      input.get<T>(), attr.kernel.get<T>(),
      attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
      /*dilation=*/1);
}

template <Axis T>
int32_t CalculateOutput(const BHWC& input,
                        const ConvolutionTransposedAttributes& attr) {
  return (input.get<T>() - 1) * attr.stride.get<T>() -
         (attr.padding.prepended.get<T>() + attr.padding.appended.get<T>()) +
         attr.weights.shape.get<T>() + attr.adjacent.get<T>();
}
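// Worked example (illustrative values): input = 4, stride = 2, total padding
// = 1, kernel = 3, adjacent = 0 gives (4 - 1) * 2 - 1 + 3 + 0 = 8, the usual
// transposed-convolution output size.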

template <Axis T>
int32_t CalculateOutput(const BHWDC& input,
                        const ConvolutionTransposed3DAttributes& attr) {
  return (input.get<T>() - 1) * attr.stride.get<T>() -
         (attr.padding.prepended.get<T>() + attr.padding.appended.get<T>()) +
         attr.weights.shape.get<T>();
}

inline int32_t StridedSize(int32_t size, int32_t stride) {
  return stride == 0 ? -1 : DivideRoundUp(size, stride);
}
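// StridedSize(10, 3) == DivideRoundUp(10, 3) == 4 (illustrative values); a
// zero stride yields the -1 sentinel instead of dividing by zero.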

template <Axis AxisT, typename AttrT>
int32_t CalculateOutput(const BHWC& input, const AttrT& attr) {
  return StridedSize(CalculateOutputWithoutStrides<AxisT>(input, attr),
                     attr.strides.template get<AxisT>());
}

template <Axis AxisT, typename AttrT>
int32_t CalculateOutput(const BHWDC& input, const AttrT& attr) {
  return StridedSize(CalculateOutputWithoutStrides<AxisT>(input, attr),
                     attr.strides.template get<AxisT>());
}

int32_t CalculateSamePadding(int32_t input, int32_t kernel, int32_t dilation,
                             int32_t stride) {
  const int32_t dilated_kernel = (kernel - 1) * dilation + 1;
  return std::max(0, dilated_kernel - (input - 1) % stride - 1);
}
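// Worked example (illustrative values): input = 224, kernel = 3, dilation = 1,
// stride = 2 gives max(0, 3 - (223 % 2) - 1) = 1, i.e. one pixel of total
// padding keeps the strided output at ceil(224 / 2) = 112.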

// Returns a padding that should be present to make sure image size stays
// the same.
template <Axis AxisT>
int32_t CalculateSamePadding(const BHWC& input,
                             const Convolution2DAttributes& attr) {
  return CalculateSamePadding(
      input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
      attr.dilations.get<AxisT>(), attr.strides.get<AxisT>());
}

// Returns a padding that should be present to make sure image size stays
// the same.
template <Axis AxisT>
int32_t CalculateSamePadding(const BHWDC& input,
                             const Convolution3DAttributes& attr) {
  return CalculateSamePadding(
      input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
      attr.dilations.get<AxisT>(), attr.strides.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWC& input,
                             const ConvolutionTransposedAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(),
                              attr.weights.shape.get<AxisT>(),
                              /*dilation=*/1, attr.stride.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWDC& input,
                             const ConvolutionTransposed3DAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(),
                              attr.weights.shape.get<AxisT>(),
                              /*dilation=*/1, attr.stride.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWC& input,
                             const Pooling2DAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
                              /*dilation=*/1, attr.strides.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWDC& input,
                             const Pooling3DAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
                              /*dilation=*/1, attr.strides.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWC& input,
                             const MaxUnpooling2DAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
                              /*dilation=*/1, attr.strides.get<AxisT>());
}

template <Axis AxisT>
int32_t CalculateSamePadding(const BHWDC& input,
                             const MaxUnpooling3DAttributes& attr) {
  return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
                              /*dilation=*/1, attr.strides.get<AxisT>());
}

Padding2D MakeSamePadding(const BHWC& input,
                          const ConvolutionTransposedAttributes& attr) {
  int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
  int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
  Padding2D padding;
  padding.prepended = HW(padding_height / 2, padding_width / 2);
  padding.appended = HW(padding_height - padding_height / 2,
                        padding_width - padding_width / 2);
  return padding;
}
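// An odd total padding is split in favor of the appended side, e.g. a total of
// 3 becomes prepended = 1 and appended = 2 (illustrative values).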

Padding3D MakeSamePadding(const BHWDC& input,
                          const ConvolutionTransposed3DAttributes& attr) {
  int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
  int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
  int32_t padding_depth = CalculateSamePadding<Axis::DEPTH>(input, attr);
  Padding3D padding;
  padding.prepended =
      HWD(padding_height / 2, padding_width / 2, padding_depth / 2);
  padding.appended =
      HWD(padding_height - padding_height / 2,
          padding_width - padding_width / 2, padding_depth - padding_depth / 2);
  return padding;
}

// If padding depends on input, convert it into fixed padding.
template <class AttrT>
Padding2D MakeSamePadding(const BHWC& input, const AttrT& attr) {
  int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
  int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
  Padding2D padding;
  padding.prepended = HW(padding_height / 2, padding_width / 2);
  padding.appended = HW(padding_height - padding_height / 2,
                        padding_width - padding_width / 2);
  return padding;
}

// If padding depends on input, convert it into fixed padding.
template <class AttrT>
Padding3D MakeSamePadding(const BHWDC& input, const AttrT& attr) {
  int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
  int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
  int32_t padding_depth = CalculateSamePadding<Axis::DEPTH>(input, attr);
  Padding3D padding;
  padding.prepended =
      HWD(padding_height / 2, padding_width / 2, padding_depth / 2);
  padding.appended =
      HWD(padding_height - padding_height / 2,
          padding_width - padding_width / 2, padding_depth - padding_depth / 2);
  return padding;
}

}  // namespace

BHWC CalculateOutputShape(const BHWC& input,
                          const MaxUnpooling2DAttributes& attr) {
  return BHWC(input.b,
              input.h * attr.strides.h - attr.padding.prepended.h -
                  attr.padding.appended.h,
              input.w * attr.strides.w - attr.padding.prepended.w -
                  attr.padding.appended.w,
              input.c);
}
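// Worked example (illustrative values): h = 56, strides.h = 2 and one pixel of
// total height padding give an unpooled height of 56 * 2 - 1 = 111.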

BHWDC CalculateOutputShape(const BHWDC& input,
                           const MaxUnpooling3DAttributes& attr) {
  return BHWDC(input.b,
               input.h * attr.strides.h - attr.padding.prepended.h -
                   attr.padding.appended.h,
               input.w * attr.strides.w - attr.padding.prepended.w -
                   attr.padding.appended.w,
               input.d * attr.strides.d - attr.padding.prepended.d -
                   attr.padding.appended.d,
               input.c);
}

BHWC CalculateOutputShape(const BHWC& input, const Pooling2DAttributes& attr) {
  return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
              CalculateOutput<Axis::WIDTH>(input, attr), input.c);
}

BHWDC CalculateOutputShape(const BHWDC& input,
                           const Pooling3DAttributes& attr) {
  return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
               CalculateOutput<Axis::WIDTH>(input, attr),
               CalculateOutput<Axis::DEPTH>(input, attr), input.c);
}

BHWC CalculateOutputShape(const BHWC& input,
                          const Convolution2DAttributes& attr) {
  return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
              CalculateOutput<Axis::WIDTH>(input, attr),
              attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}

BHWDC CalculateOutputShape(const BHWDC& input,
                           const Convolution3DAttributes& attr) {
  return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
               CalculateOutput<Axis::WIDTH>(input, attr),
               CalculateOutput<Axis::DEPTH>(input, attr),
               attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}

BHWC CalculateOutputShape(const BHWC& input,
                          const ConvolutionTransposedAttributes& attr) {
  return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
              CalculateOutput<Axis::WIDTH>(input, attr),
              attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}

BHWDC CalculateOutputShape(const BHWDC& input,
                           const ConvolutionTransposed3DAttributes& attr) {
  return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
               CalculateOutput<Axis::WIDTH>(input, attr),
               CalculateOutput<Axis::DEPTH>(input, attr),
               attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}

BHWC CalculateOutputShape(const BHWC& input,
                          const DepthwiseConvolution2DAttributes& attr) {
  return BHWC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
              CalculateOutput<Axis::WIDTH>(input, attr),
              attr.weights.shape.get<Axis::OUTPUT_CHANNELS>() *
                  attr.weights.shape.get<Axis::INPUT_CHANNELS>());
}
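// For depthwise weights the OUTPUT_CHANNELS dimension acts as the per-channel
// multiplier, so e.g. 32 input channels with a multiplier of 2 yield 64 output
// channels (illustrative values).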

BHWDC CalculateOutputShape(const BHWDC& input,
                           const DepthwiseConvolution3DAttributes& attr) {
  return BHWDC(input.b, CalculateOutput<Axis::HEIGHT>(input, attr),
               CalculateOutput<Axis::WIDTH>(input, attr),
               CalculateOutput<Axis::DEPTH>(input, attr),
               attr.weights.shape.get<Axis::OUTPUT_CHANNELS>() *
                   attr.weights.shape.get<Axis::INPUT_CHANNELS>());
}

BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) {
  return BHWC(StridedSize(attr.ends.b - attr.starts.b, attr.strides.b),
              StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
              StridedSize(attr.ends.w - attr.starts.w, attr.strides.w),
              StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
}
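// Worked example (illustrative values): starts.w = 1, ends.w = 7, strides.w = 2
// selects indices 1, 3 and 5, and StridedSize(7 - 1, 2) == 3 matches that count.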

BHWDC CalculateOutputShape(const BHWDC& input, const Slice3DAttributes& attr) {
  return BHWDC(StridedSize(attr.ends.b - attr.starts.b, attr.strides.b),
               StridedSize(attr.ends.h - attr.starts.h, attr.strides.h),
               StridedSize(attr.ends.w - attr.starts.w, attr.strides.w),
               StridedSize(attr.ends.d - attr.starts.d, attr.strides.d),
               StridedSize(attr.ends.c - attr.starts.c, attr.strides.c));
}

BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr) {
  return BHWC(attr.appended.b + attr.prepended.b + input.b,
              attr.appended.h + attr.prepended.h + input.h,
              attr.appended.w + attr.prepended.w + input.w,
              attr.appended.c + attr.prepended.c + input.c);
}

BHWDC CalculateOutputShape(const BHWDC& input, const Pad3DAttributes& attr) {
  return BHWDC(attr.appended.b + attr.prepended.b + input.b,
               attr.appended.h + attr.prepended.h + input.h,
               attr.appended.w + attr.prepended.w + input.w,
               attr.appended.d + attr.prepended.d + input.d,
               attr.appended.c + attr.prepended.c + input.c);
}

BHWC CalculateOutputShape(const BHWC& input,
                          const FullyConnectedAttributes& attr) {
  return BHWC(input.b, 1, 1, attr.weights.shape.o);
}

BHWC CalculateOutputShape(const BHWC& input, const MeanAttributes& attr) {
  const int b = attr.dims.find(Axis::BATCH) == attr.dims.end() ? input.b : 1;
  const int h = attr.dims.find(Axis::HEIGHT) == attr.dims.end() ? input.h : 1;
  const int w = attr.dims.find(Axis::WIDTH) == attr.dims.end() ? input.w : 1;
  const int c = attr.dims.find(Axis::CHANNELS) == attr.dims.end() ? input.c : 1;
  return BHWC(b, h, w, c);
}

BHWDC CalculateOutputShape(const BHWDC& input, const MeanAttributes& attr) {
  const int b = attr.dims.find(Axis::BATCH) == attr.dims.end() ? input.b : 1;
  const int h = attr.dims.find(Axis::HEIGHT) == attr.dims.end() ? input.h : 1;
  const int w = attr.dims.find(Axis::WIDTH) == attr.dims.end() ? input.w : 1;
  const int d = attr.dims.find(Axis::DEPTH) == attr.dims.end() ? input.d : 1;
  const int c = attr.dims.find(Axis::CHANNELS) == attr.dims.end() ? input.c : 1;
  return BHWDC(b, h, w, d, c);
}

absl::Status CalculateOutputShape(const std::vector<BHWC>& input,
                                  const ConcatAttributes& attr,
                                  BHWC* output_shape) {
  BHWC new_shape = input[0];
  switch (attr.axis) {
    case Axis::CHANNELS:
      for (int i = 1; i < input.size(); i++) {
        if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
            input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Height, Width and Batch must be the same when concatenating "
              "by channels axis");
        }
        new_shape.c += input[i].c;
      }
      break;
    case Axis::HEIGHT:
      for (int i = 1; i < input.size(); i++) {
        if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
            input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Channels, Width and Batch must be the same when concatenating "
              "by height axis");
        }
        new_shape.h += input[i].h;
      }
      break;
    case Axis::WIDTH:
      for (int i = 1; i < input.size(); i++) {
        if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
            input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Height, Channels and Batch must be the same when concatenating "
              "by width axis");
        }
        new_shape.w += input[i].w;
      }
      break;
    case Axis::BATCH:
      for (int i = 1; i < input.size(); i++) {
        if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
            input[i].w != new_shape.w) {
          return absl::InvalidArgumentError(
              "Width, Height and Channels must be the same when concatenating "
              "by batch axis");
        }
        new_shape.b += input[i].b;
      }
      break;
    default:
      return absl::InvalidArgumentError("Invalid axis");
      break;
  }
  *output_shape = new_shape;
  return absl::OkStatus();
}

absl::Status CalculateOutputShape(const std::vector<BHWDC>& input,
                                  const ConcatAttributes& attr,
                                  BHWDC* output_shape) {
  BHWDC new_shape = input[0];
  switch (attr.axis) {
    case Axis::CHANNELS:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
            input[i].d != new_shape.d || input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Height, Width, Batch and Depth must be the same when "
              "concatenating "
              "by channels axis");
        }
        new_shape.c += input[i].c;
      }
      break;
    case Axis::HEIGHT:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
            input[i].d != new_shape.d || input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Width, Depth, Batch and Channels must be the same when "
              "concatenating "
              "by height axis");
        }
        new_shape.h += input[i].h;
      }
      break;
    case Axis::WIDTH:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
            input[i].d != new_shape.d || input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Height, Depth, Batch and Channels must be the same when "
              "concatenating "
              "by width axis");
        }
        new_shape.w += input[i].w;
      }
      break;
    case Axis::DEPTH:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
            input[i].c != new_shape.c || input[i].b != new_shape.b) {
          return absl::InvalidArgumentError(
              "Width, Height, Batch and Channels must be the same when "
              "concatenating "
              "by depth axis");
        }
        new_shape.d += input[i].d;
      }
      break;
    case Axis::BATCH:
      for (int i = 1; i < input.size(); ++i) {
        if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
            input[i].c != new_shape.c || input[i].d != new_shape.d) {
          return absl::InvalidArgumentError(
              "Width, Height, Depth and Channels must be the same when "
              "concatenating "
              "by batch axis");
        }
        new_shape.b += input[i].b;
      }
      break;
    default:
      return absl::InvalidArgumentError("Invalid axis");
  }
  *output_shape = new_shape;
  return absl::OkStatus();
}

Padding2D CalculateSamePadding(const BHWC& input,
                               const Convolution2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding3D CalculateSamePadding(const BHWDC& input,
                               const Convolution3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding2D CalculateSamePadding(const BHWC& input,
                               const ConvolutionTransposedAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding3D CalculateSamePadding(const BHWDC& input,
                               const ConvolutionTransposed3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding2D CalculateSamePadding(const BHWC& input,
                               const DepthwiseConvolution2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding3D CalculateSamePadding(const BHWDC& input,
                               const DepthwiseConvolution3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding2D CalculateSamePadding(const BHWC& input,
                               const Pooling2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding3D CalculateSamePadding(const BHWDC& input,
                               const Pooling3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding2D CalculateSamePadding(const BHWC& input,
                               const MaxUnpooling2DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

Padding3D CalculateSamePadding(const BHWDC& input,
                               const MaxUnpooling3DAttributes& attr) {
  return MakeSamePadding(input, attr);
}

float CalculateResizeScale(int32_t input_size, int32_t output_size,
                           const Resize2DAttributes& attr) {
  return attr.align_corners && input_size > 1 && output_size > 1
             ? static_cast<float>(input_size - 1) / (output_size - 1)
             : static_cast<float>(input_size) / output_size;
}
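// Worked example (illustrative values): input_size = 4, output_size = 7 gives
// a scale of (4 - 1) / (7 - 1) = 0.5 with align_corners, and 4 / 7 (about
// 0.571) otherwise.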

float CalculateResizeScale(int32_t input_size, int32_t output_size,
                           const Resize3DAttributes& attr) {
  return attr.align_corners && input_size > 1 && output_size > 1
             ? static_cast<float>(input_size - 1) / (output_size - 1)
             : static_cast<float>(input_size) / output_size;
}

BHWC CalculateOutputShape(const BHWC& input, const Resize2DAttributes& attr) {
  return BHWC(input.b, attr.new_shape.h, attr.new_shape.w, input.c);
}

BHWDC CalculateOutputShape(const BHWDC& input, const Resize3DAttributes& attr) {
  return BHWDC(input.b, attr.new_shape.h, attr.new_shape.w, attr.new_shape.d,
               input.c);
}

BHWC CalculateOutputShape(const BHWC& input, const TransposeAttributes& attr) {
  return BHWC(input.get(attr.perm.b), input.get(attr.perm.h),
              input.get(attr.perm.w), input.get(attr.perm.c));
}

BHWDC CalculateOutputShape(const BHWDC& input,
                           const Transpose3DAttributes& attr) {
  return BHWDC(input.get(attr.perm.b), input.get(attr.perm.h),
               input.get(attr.perm.w), input.get(attr.perm.d),
               input.get(attr.perm.c));
}

}  // namespace gpu
}  // namespace tflite