1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/common/operations.h"
17
18 #include <algorithm>
19 #include <cstdint>
20 #include <set>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/container/flat_hash_map.h"
26 #include "tensorflow/lite/delegates/gpu/common/shape.h"
27 #include "tensorflow/lite/delegates/gpu/common/status.h"
28 #include "tensorflow/lite/delegates/gpu/common/tensor.h"
29
30 namespace tflite {
31 namespace gpu {
32
// Copies both padding components from `value`; returns *this for chaining.
Padding2D& Padding2D::operator=(const Padding2D& value) {
  this->prepended = value.prepended;
  this->appended = value.appended;
  return *this;
}
38
operator ==(const Padding2D & value)39 bool Padding2D::operator==(const Padding2D& value) {
40 return this->prepended == value.prepended && this->appended == value.appended;
41 }
42
operator !=(const Padding2D & value)43 bool Padding2D::operator!=(const Padding2D& value) { return !(*this == value); }
44
// NOTE: despite the binary-minus name, this mutates *this in place
// (it behaves like operator-=) and returns a reference to it.
Padding2D& Padding2D::operator-(const Padding2D& value) {
  prepended.w -= value.prepended.w;
  prepended.h -= value.prepended.h;
  appended.w -= value.appended.w;
  appended.h -= value.appended.h;
  return *this;
}
52
// Copies both padding components from `value`; returns *this for chaining.
Padding3D& Padding3D::operator=(const Padding3D& value) {
  this->prepended = value.prepended;
  this->appended = value.appended;
  return *this;
}
58
operator ==(const Padding3D & value)59 bool Padding3D::operator==(const Padding3D& value) {
60 return this->prepended == value.prepended && this->appended == value.appended;
61 }
62
operator !=(const Padding3D & value)63 bool Padding3D::operator!=(const Padding3D& value) { return !(*this == value); }
64
// NOTE: despite the binary-minus name, this mutates *this in place
// (it behaves like operator-=) and returns a reference to it.
Padding3D& Padding3D::operator-(const Padding3D& value) {
  prepended.w -= value.prepended.w;
  prepended.h -= value.prepended.h;
  prepended.d -= value.prepended.d;
  appended.w -= value.appended.w;
  appended.h -= value.appended.h;
  appended.d -= value.appended.d;
  return *this;
}
74
ToString(enum OperationType op)75 std::string ToString(enum OperationType op) {
76 switch (op) {
77 case OperationType::ABS:
78 return "abs";
79 case OperationType::ADD:
80 return "add";
81 case OperationType::BATCH_NORMALIZATION:
82 return "batch_normalization";
83 case OperationType::BATCH_TO_SPACE:
84 return "batch_to_space";
85 case OperationType::BATCHED_MATMUL:
86 return "batched_matmul";
87 case OperationType::CONCAT:
88 return "concat";
89 case OperationType::CONSTANT:
90 return "const";
91 case OperationType::CONVOLUTION_2D:
92 return "convolution_2d";
93 case OperationType::CONVOLUTION_TRANSPOSED:
94 return "convolution_transposed";
95 case OperationType::COPY:
96 return "copy";
97 case OperationType::COS:
98 return "cos";
99 case OperationType::DEPTHWISE_CONVOLUTION:
100 return "depthwise_convolution";
101 case OperationType::DIV:
102 return "div";
103 case OperationType::ELU:
104 return "elu";
105 case OperationType::EQUAL:
106 return "equal";
107 case OperationType::EXP:
108 return "exp";
109 case OperationType::FULLY_CONNECTED:
110 return "fully_connected";
111 case OperationType::GREATER:
112 return "greater";
113 case OperationType::GREATER_EQUAL:
114 return "greater_equal";
115 case OperationType::HARD_SWISH:
116 return "hard_swish";
117 case OperationType::LESS:
118 return "less";
119 case OperationType::LESS_EQUAL:
120 return "less_equal";
121 case OperationType::LOG:
122 return "log";
123 case OperationType::LSTM:
124 return "lstm";
125 case OperationType::MAXIMUM:
126 return "maximum";
127 case OperationType::MAX_UNPOOLING_2D:
128 return "max_unpooling";
129 case OperationType::MEAN:
130 return "mean";
131 case OperationType::MEAN_STDDEV_NORMALIZATION:
132 return "mean_stddev_normalization";
133 case OperationType::MINIMUM:
134 return "minimum";
135 case OperationType::MUL:
136 return "mul";
137 case OperationType::NEG:
138 return "neg";
139 case OperationType::NOT_EQUAL:
140 return "not_equal";
141 case OperationType::PAD:
142 return "pad";
143 case OperationType::POOLING_2D:
144 return "pooling_2d";
145 case OperationType::POW:
146 return "pow";
147 case OperationType::PRELU:
148 return "prelu";
149 case OperationType::QUANTIZE_AND_DEQUANTIZE:
150 return "quantize_and_dequantize";
151 case OperationType::REDUCE_MAXIMUM:
152 return "reduce_maximum";
153 case OperationType::REDUCE_MINIMUM:
154 return "reduce_minimum";
155 case OperationType::REDUCE_PRODUCT:
156 return "reduce_product";
157 case OperationType::REDUCE_SUM:
158 return "reduce_sum";
159 case OperationType::RELU:
160 return "relu";
161 case OperationType::RESHAPE:
162 return "reshape";
163 case OperationType::RESIZE:
164 return "resize";
165 case OperationType::RSQRT:
166 return "rsqrt";
167 case OperationType::SIGMOID:
168 return "sigmoid";
169 case OperationType::SIN:
170 return "sin";
171 case OperationType::SLICE:
172 return "slice";
173 case OperationType::SOFTMAX:
174 return "softmax";
175 case OperationType::SPACE_TO_BATCH:
176 return "space_to_batch";
177 case OperationType::SPACE_TO_DEPTH:
178 return "space_to_depth";
179 case OperationType::SPLIT:
180 return "split";
181 case OperationType::SQRT:
182 return "sqrt";
183 case OperationType::SQUARE:
184 return "square";
185 case OperationType::SQUARED_DIFF:
186 return "squared_diff";
187 case OperationType::SUB:
188 return "subtract";
189 case OperationType::TANH:
190 return "tanh";
191 case OperationType::TRANSPOSE:
192 return "transpose";
193 case OperationType::UNKNOWN:
194 return "unknown_operation";
195 }
196 }
197
OperationTypeFromString(const std::string & name)198 OperationType OperationTypeFromString(const std::string& name) {
199 static const auto operations =
200 new absl::flat_hash_map<std::string, OperationType>({
201 {"abs", OperationType::ABS},
202 {"add", OperationType::ADD},
203 {"batch_normalization", OperationType::BATCH_NORMALIZATION},
204 {"batched_matmul", OperationType::BATCHED_MATMUL},
205 {"concat", OperationType::CONCAT},
206 {"const", OperationType::CONSTANT},
207 {"convolution_2d", OperationType::CONVOLUTION_2D},
208 {"convolution_transposed", OperationType::CONVOLUTION_TRANSPOSED},
209 {"copy", OperationType::COPY},
210 {"cos", OperationType::COS},
211 {"depthwise_convolution", OperationType::DEPTHWISE_CONVOLUTION},
212 {"div", OperationType::DIV},
213 {"elu", OperationType::ELU},
214 {"equal", OperationType::EQUAL},
215 {"exp", OperationType::EXP},
216 {"fully_connected", OperationType::FULLY_CONNECTED},
217 {"greater", OperationType::GREATER},
218 {"greater_equal", OperationType::GREATER_EQUAL},
219 {"hard_swish", OperationType::HARD_SWISH},
220 {"less", OperationType::LESS},
221 {"less_equal", OperationType::LESS_EQUAL},
222 {"log", OperationType::LOG},
223 {"lstm", OperationType::LSTM},
224 {"maximum", OperationType::MAXIMUM},
225 {"max_unpooling", OperationType::MAX_UNPOOLING_2D},
226 {"mean", OperationType::MEAN},
227 {"mean_stddev_normalization",
228 OperationType::MEAN_STDDEV_NORMALIZATION},
229 {"minimum", OperationType::MINIMUM},
230 {"mul", OperationType::MUL},
231 {"neg", OperationType::NEG},
232 {"not_equal", OperationType::NOT_EQUAL},
233 {"pad", OperationType::PAD},
234 {"pooling_2d", OperationType::POOLING_2D},
235 {"pow", OperationType::POW},
236 {"prelu", OperationType::PRELU},
237 {"quantize_and_dequantize", OperationType::QUANTIZE_AND_DEQUANTIZE},
238 {"reduce_maximum", OperationType::REDUCE_MAXIMUM},
239 {"reduce_minimum", OperationType::REDUCE_MINIMUM},
240 {"reduce_product", OperationType::REDUCE_PRODUCT},
241 {"reduce_sum", OperationType::REDUCE_SUM},
242 {"relu", OperationType::RELU},
243 {"resize", OperationType::RESIZE},
244 {"reshape", OperationType::RESHAPE},
245 {"rsqrt", OperationType::RSQRT},
246 {"sigmoid", OperationType::SIGMOID},
247 {"sin", OperationType::SIN},
248 {"slice", OperationType::SLICE},
249 {"softmax", OperationType::SOFTMAX},
250 {"space_to_depth", OperationType::SPACE_TO_DEPTH},
251 {"split", OperationType::SPLIT},
252 {"sqrt", OperationType::SQRT},
253 {"square", OperationType::SQUARE},
254 {"squared_diff", OperationType::SQUARED_DIFF},
255 {"subtract", OperationType::SUB},
256 {"tanh", OperationType::TANH},
257 {"transpose", OperationType::TRANSPOSE},
258 });
259 auto op = operations->find(name);
260 return op == operations->end() ? OperationType::UNKNOWN : op->second;
261 }
262
263 namespace {
264
// Integer ceiling division: smallest k with k * divisor >= n for positive n.
// Uses the (n - 1) / divisor + 1 formulation to avoid overflow in n + divisor.
template <typename T>
T DivideRoundUp(T n, T divisor) {
  const T floor_of_predecessor = (n - 1) / divisor;
  return floor_of_predecessor + 1;
}
269
// Output extent of a convolution along one axis before striding is applied.
// The effective (dilated) kernel extent is (kernel - 1) * dilation + 1.
int32_t CalculateOutputSizeBeforeStrides(int32_t input, int32_t kernel,
                                         int32_t padding, int32_t dilation) {
  const int32_t kernel_extent = (kernel - 1) * dilation + 1;
  const int32_t padded_input = input + padding;
  return padded_input - kernel_extent + 1;
}
275
276 template <Axis T>
CalculateOutputWithoutStrides(const BHWC & input,const Convolution2DAttributes & attr)277 int32_t CalculateOutputWithoutStrides(const BHWC& input,
278 const Convolution2DAttributes& attr) {
279 return CalculateOutputSizeBeforeStrides(
280 input.get<T>(), attr.weights.shape.get<T>(),
281 attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
282 attr.dilations.get<T>());
283 }
284
285 template <Axis T>
CalculateOutputWithoutStrides(const BHWDC & input,const Convolution3DAttributes & attr)286 int32_t CalculateOutputWithoutStrides(const BHWDC& input,
287 const Convolution3DAttributes& attr) {
288 return CalculateOutputSizeBeforeStrides(
289 input.get<T>(), attr.weights.shape.get<T>(),
290 attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
291 attr.dilations.get<T>());
292 }
293
294 template <Axis T>
CalculateOutputWithoutStrides(const BHWC & input,const Pooling2DAttributes & attr)295 int32_t CalculateOutputWithoutStrides(const BHWC& input,
296 const Pooling2DAttributes& attr) {
297 return CalculateOutputSizeBeforeStrides(
298 input.get<T>(), attr.kernel.get<T>(),
299 attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
300 /*dilation=*/1);
301 }
302
303 template <Axis T>
CalculateOutputWithoutStrides(const BHWDC & input,const Pooling3DAttributes & attr)304 int32_t CalculateOutputWithoutStrides(const BHWDC& input,
305 const Pooling3DAttributes& attr) {
306 return CalculateOutputSizeBeforeStrides(
307 input.get<T>(), attr.kernel.get<T>(),
308 attr.padding.prepended.get<T>() + attr.padding.appended.get<T>(),
309 /*dilation=*/1);
310 }
311
312 template <Axis T>
CalculateOutput(const BHWC & input,const ConvolutionTransposedAttributes & attr)313 int32_t CalculateOutput(const BHWC& input,
314 const ConvolutionTransposedAttributes& attr) {
315 return (input.get<T>() - 1) * attr.stride.get<T>() -
316 (attr.padding.prepended.get<T>() + attr.padding.appended.get<T>()) +
317 attr.weights.shape.get<T>() + attr.adjacent.get<T>();
318 }
319
320 template <Axis T>
CalculateOutput(const BHWDC & input,const ConvolutionTransposed3DAttributes & attr)321 int32_t CalculateOutput(const BHWDC& input,
322 const ConvolutionTransposed3DAttributes& attr) {
323 return (input.get<T>() - 1) * attr.stride.get<T>() -
324 (attr.padding.prepended.get<T>() + attr.padding.appended.get<T>()) +
325 attr.weights.shape.get<T>();
326 }
327
StridedSize(int32_t size,int32_t stride)328 inline int32_t StridedSize(int32_t size, int32_t stride) {
329 return stride == 0 ? -1 : DivideRoundUp(size, stride);
330 }
331
332 template <Axis AxisT, typename AttrT>
CalculateOutput(const BHWC & input,const AttrT & attr)333 int32_t CalculateOutput(const BHWC& input, const AttrT& attr) {
334 return StridedSize(CalculateOutputWithoutStrides<AxisT>(input, attr),
335 attr.strides.template get<AxisT>());
336 }
337
338 template <Axis AxisT, typename AttrT>
CalculateOutput(const BHWDC & input,const AttrT & attr)339 int32_t CalculateOutput(const BHWDC& input, const AttrT& attr) {
340 return StridedSize(CalculateOutputWithoutStrides<AxisT>(input, attr),
341 attr.strides.template get<AxisT>());
342 }
343
// Total SAME padding needed along one axis so the output size is
// ceil(input / stride); never negative.
int32_t CalculateSamePadding(int32_t input, int32_t kernel, int32_t dilation,
                             int32_t stride) {
  const int32_t kernel_extent = (kernel - 1) * dilation + 1;
  const int32_t needed = kernel_extent - (input - 1) % stride - 1;
  return needed > 0 ? needed : 0;
}
349
350 // Returns a padding that should be present to make sure image size stays
351 // the same.
352 template <Axis AxisT>
CalculateSamePadding(const BHWC & input,const Convolution2DAttributes & attr)353 int32_t CalculateSamePadding(const BHWC& input,
354 const Convolution2DAttributes& attr) {
355 return CalculateSamePadding(
356 input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
357 attr.dilations.get<AxisT>(), attr.strides.get<AxisT>());
358 }
359
360 // Returns a padding that should be present to make sure image size stays
361 // the same.
362 template <Axis AxisT>
CalculateSamePadding(const BHWDC & input,const Convolution3DAttributes & attr)363 int32_t CalculateSamePadding(const BHWDC& input,
364 const Convolution3DAttributes& attr) {
365 return CalculateSamePadding(
366 input.get<AxisT>(), attr.weights.shape.get<AxisT>(),
367 attr.dilations.get<AxisT>(), attr.strides.get<AxisT>());
368 }
369
370 template <Axis AxisT>
CalculateSamePadding(const BHWC & input,const ConvolutionTransposedAttributes & attr)371 int32_t CalculateSamePadding(const BHWC& input,
372 const ConvolutionTransposedAttributes& attr) {
373 return CalculateSamePadding(input.get<AxisT>(),
374 attr.weights.shape.get<AxisT>(),
375 /*dilation=*/1, attr.stride.get<AxisT>());
376 }
377
378 template <Axis AxisT>
CalculateSamePadding(const BHWDC & input,const ConvolutionTransposed3DAttributes & attr)379 int32_t CalculateSamePadding(const BHWDC& input,
380 const ConvolutionTransposed3DAttributes& attr) {
381 return CalculateSamePadding(input.get<AxisT>(),
382 attr.weights.shape.get<AxisT>(),
383 /*dilation=*/1, attr.stride.get<AxisT>());
384 }
385
386 template <Axis AxisT>
CalculateSamePadding(const BHWC & input,const Pooling2DAttributes & attr)387 int32_t CalculateSamePadding(const BHWC& input,
388 const Pooling2DAttributes& attr) {
389 return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
390 /*dilation=*/1, attr.strides.get<AxisT>());
391 }
392
393 template <Axis AxisT>
CalculateSamePadding(const BHWDC & input,const Pooling3DAttributes & attr)394 int32_t CalculateSamePadding(const BHWDC& input,
395 const Pooling3DAttributes& attr) {
396 return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
397 /*dilation=*/1, attr.strides.get<AxisT>());
398 }
399
400 template <Axis AxisT>
CalculateSamePadding(const BHWC & input,const MaxUnpooling2DAttributes & attr)401 int32_t CalculateSamePadding(const BHWC& input,
402 const MaxUnpooling2DAttributes& attr) {
403 return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
404 /*dilation=*/1, attr.strides.get<AxisT>());
405 }
406
407 template <Axis AxisT>
CalculateSamePadding(const BHWDC & input,const MaxUnpooling3DAttributes & attr)408 int32_t CalculateSamePadding(const BHWDC& input,
409 const MaxUnpooling3DAttributes& attr) {
410 return CalculateSamePadding(input.get<AxisT>(), attr.kernel.get<AxisT>(),
411 /*dilation=*/1, attr.strides.get<AxisT>());
412 }
413
MakeSamePadding(const BHWC & input,const ConvolutionTransposedAttributes & attr)414 Padding2D MakeSamePadding(const BHWC& input,
415 const ConvolutionTransposedAttributes& attr) {
416 int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
417 int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
418 Padding2D padding;
419 padding.prepended = HW(padding_height / 2, padding_width / 2);
420 padding.appended = HW(padding_height - padding_height / 2,
421 padding_width - padding_width / 2);
422 return padding;
423 }
424
MakeSamePadding(const BHWDC & input,const ConvolutionTransposed3DAttributes & attr)425 Padding3D MakeSamePadding(const BHWDC& input,
426 const ConvolutionTransposed3DAttributes& attr) {
427 int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
428 int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
429 int32_t padding_depth = CalculateSamePadding<Axis::DEPTH>(input, attr);
430 Padding3D padding;
431 padding.prepended =
432 HWD(padding_height / 2, padding_width / 2, padding_depth / 2);
433 padding.appended =
434 HWD(padding_height - padding_height / 2,
435 padding_width - padding_width / 2, padding_depth - padding_depth / 2);
436 return padding;
437 }
438
439 // If padding depends on input, convert it into fixed padding.
440 template <class AttrT>
MakeSamePadding(const BHWC & input,const AttrT & attr)441 Padding2D MakeSamePadding(const BHWC& input, const AttrT& attr) {
442 int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
443 int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
444 Padding2D padding;
445 padding.prepended = HW(padding_height / 2, padding_width / 2);
446 padding.appended = HW(padding_height - padding_height / 2,
447 padding_width - padding_width / 2);
448 return padding;
449 }
450
451 // If padding depends on input, convert it into fixed padding.
452 template <class AttrT>
MakeSamePadding(const BHWDC & input,const AttrT & attr)453 Padding3D MakeSamePadding(const BHWDC& input, const AttrT& attr) {
454 int32_t padding_height = CalculateSamePadding<Axis::HEIGHT>(input, attr);
455 int32_t padding_width = CalculateSamePadding<Axis::WIDTH>(input, attr);
456 int32_t padding_depth = CalculateSamePadding<Axis::DEPTH>(input, attr);
457 Padding3D padding;
458 padding.prepended =
459 HWD(padding_height / 2, padding_width / 2, padding_depth / 2);
460 padding.appended =
461 HWD(padding_height - padding_height / 2,
462 padding_width - padding_width / 2, padding_depth - padding_depth / 2);
463 return padding;
464 }
465
466 } // namespace
467
// Output shape of 2D max-unpooling: spatial dims are upsampled by the
// strides, then shrunk by the total padding; batch/channels pass through.
BHWC CalculateOutputShape(const BHWC& input,
                          const MaxUnpooling2DAttributes& attr) {
  const int32_t out_h = input.h * attr.strides.h - attr.padding.prepended.h -
                        attr.padding.appended.h;
  const int32_t out_w = input.w * attr.strides.w - attr.padding.prepended.w -
                        attr.padding.appended.w;
  return BHWC(input.b, out_h, out_w, input.c);
}
477
// Output shape of 3D max-unpooling: spatial dims are upsampled by the
// strides, then shrunk by the total padding; batch/channels pass through.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const MaxUnpooling3DAttributes& attr) {
  const int32_t out_h = input.h * attr.strides.h - attr.padding.prepended.h -
                        attr.padding.appended.h;
  const int32_t out_w = input.w * attr.strides.w - attr.padding.prepended.w -
                        attr.padding.appended.w;
  const int32_t out_d = input.d * attr.strides.d - attr.padding.prepended.d -
                        attr.padding.appended.d;
  return BHWDC(input.b, out_h, out_w, out_d, input.c);
}
489
// Output shape of 2D pooling; batch and channels are unchanged.
BHWC CalculateOutputShape(const BHWC& input, const Pooling2DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  return BHWC(input.b, out_h, out_w, input.c);
}
494
// Output shape of 3D pooling; batch and channels are unchanged.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const Pooling3DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_d = CalculateOutput<Axis::DEPTH>(input, attr);
  return BHWDC(input.b, out_h, out_w, out_d, input.c);
}
501
// Output shape of a 2D convolution: spatial dims from stride/padding math,
// channels from the number of filters.
BHWC CalculateOutputShape(const BHWC& input,
                          const Convolution2DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  return BHWC(input.b, out_h, out_w,
              attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}
508
// Output shape of a 3D convolution: spatial dims from stride/padding math,
// channels from the number of filters.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const Convolution3DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_d = CalculateOutput<Axis::DEPTH>(input, attr);
  return BHWDC(input.b, out_h, out_w, out_d,
               attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}
516
// Output shape of a 2D transposed convolution.
BHWC CalculateOutputShape(const BHWC& input,
                          const ConvolutionTransposedAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  return BHWC(input.b, out_h, out_w,
              attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}
523
// Output shape of a 3D transposed convolution.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const ConvolutionTransposed3DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_d = CalculateOutput<Axis::DEPTH>(input, attr);
  return BHWDC(input.b, out_h, out_w, out_d,
               attr.weights.shape.get<Axis::OUTPUT_CHANNELS>());
}
531
// Output shape of a 2D depthwise convolution: channels are the product of
// the depth multiplier (output channels) and the input channels.
BHWC CalculateOutputShape(const BHWC& input,
                          const DepthwiseConvolution2DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_c = attr.weights.shape.get<Axis::OUTPUT_CHANNELS>() *
                        attr.weights.shape.get<Axis::INPUT_CHANNELS>();
  return BHWC(input.b, out_h, out_w, out_c);
}
539
// Output shape of a 3D depthwise convolution: channels are the product of
// the depth multiplier (output channels) and the input channels.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const DepthwiseConvolution3DAttributes& attr) {
  const int32_t out_h = CalculateOutput<Axis::HEIGHT>(input, attr);
  const int32_t out_w = CalculateOutput<Axis::WIDTH>(input, attr);
  const int32_t out_d = CalculateOutput<Axis::DEPTH>(input, attr);
  const int32_t out_c = attr.weights.shape.get<Axis::OUTPUT_CHANNELS>() *
                        attr.weights.shape.get<Axis::INPUT_CHANNELS>();
  return BHWDC(input.b, out_h, out_w, out_d, out_c);
}
548
// Output shape of a strided slice; each dim spans [starts, ends) with the
// given per-axis stride. The input shape itself is not consulted.
BHWC CalculateOutputShape(const BHWC& input, const SliceAttributes& attr) {
  const int32_t out_b = StridedSize(attr.ends.b - attr.starts.b, attr.strides.b);
  const int32_t out_h = StridedSize(attr.ends.h - attr.starts.h, attr.strides.h);
  const int32_t out_w = StridedSize(attr.ends.w - attr.starts.w, attr.strides.w);
  const int32_t out_c = StridedSize(attr.ends.c - attr.starts.c, attr.strides.c);
  return BHWC(out_b, out_h, out_w, out_c);
}
555
// Output shape of a strided 3D slice; each dim spans [starts, ends) with the
// given per-axis stride. The input shape itself is not consulted.
BHWDC CalculateOutputShape(const BHWDC& input, const Slice3DAttributes& attr) {
  const int32_t out_b = StridedSize(attr.ends.b - attr.starts.b, attr.strides.b);
  const int32_t out_h = StridedSize(attr.ends.h - attr.starts.h, attr.strides.h);
  const int32_t out_w = StridedSize(attr.ends.w - attr.starts.w, attr.strides.w);
  const int32_t out_d = StridedSize(attr.ends.d - attr.starts.d, attr.strides.d);
  const int32_t out_c = StridedSize(attr.ends.c - attr.starts.c, attr.strides.c);
  return BHWDC(out_b, out_h, out_w, out_d, out_c);
}
563
// Output shape of a pad op: every dim grows by its prepended + appended pad.
BHWC CalculateOutputShape(const BHWC& input, const PadAttributes& attr) {
  const int32_t out_b = input.b + attr.prepended.b + attr.appended.b;
  const int32_t out_h = input.h + attr.prepended.h + attr.appended.h;
  const int32_t out_w = input.w + attr.prepended.w + attr.appended.w;
  const int32_t out_c = input.c + attr.prepended.c + attr.appended.c;
  return BHWC(out_b, out_h, out_w, out_c);
}
570
// Output shape of a 3D pad op: every dim grows by prepended + appended pad.
BHWDC CalculateOutputShape(const BHWDC& input, const Pad3DAttributes& attr) {
  const int32_t out_b = input.b + attr.prepended.b + attr.appended.b;
  const int32_t out_h = input.h + attr.prepended.h + attr.appended.h;
  const int32_t out_w = input.w + attr.prepended.w + attr.appended.w;
  const int32_t out_d = input.d + attr.prepended.d + attr.appended.d;
  const int32_t out_c = input.c + attr.prepended.c + attr.appended.c;
  return BHWDC(out_b, out_h, out_w, out_d, out_c);
}
578
// Fully-connected output: spatial dims collapse to 1x1, channels become the
// layer's output-neuron count.
BHWC CalculateOutputShape(const BHWC& input,
                          const FullyConnectedAttributes& attr) {
  const int32_t out_c = attr.weights.shape.o;
  return BHWC(input.b, /*h=*/1, /*w=*/1, out_c);
}
583
// Mean reduction output: each axis listed in attr.dims collapses to 1; all
// other axes keep their input extent.
BHWC CalculateOutputShape(const BHWC& input, const MeanAttributes& attr) {
  const auto reduced = [&attr](Axis axis) {
    return attr.dims.find(axis) != attr.dims.end();
  };
  return BHWC(reduced(Axis::BATCH) ? 1 : input.b,
              reduced(Axis::HEIGHT) ? 1 : input.h,
              reduced(Axis::WIDTH) ? 1 : input.w,
              reduced(Axis::CHANNELS) ? 1 : input.c);
}
591
// Mean reduction output (3D): each axis listed in attr.dims collapses to 1;
// all other axes keep their input extent.
BHWDC CalculateOutputShape(const BHWDC& input, const MeanAttributes& attr) {
  const auto reduced = [&attr](Axis axis) {
    return attr.dims.find(axis) != attr.dims.end();
  };
  return BHWDC(reduced(Axis::BATCH) ? 1 : input.b,
               reduced(Axis::HEIGHT) ? 1 : input.h,
               reduced(Axis::WIDTH) ? 1 : input.w,
               reduced(Axis::DEPTH) ? 1 : input.d,
               reduced(Axis::CHANNELS) ? 1 : input.c);
}
600
CalculateOutputShape(const std::vector<BHWC> & input,const ConcatAttributes & attr,BHWC * output_shape)601 absl::Status CalculateOutputShape(const std::vector<BHWC>& input,
602 const ConcatAttributes& attr,
603 BHWC* output_shape) {
604 BHWC new_shape = input[0];
605 switch (attr.axis) {
606 case Axis::CHANNELS:
607 for (int i = 1; i < input.size(); i++) {
608 if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
609 input[i].b != new_shape.b) {
610 return absl::InvalidArgumentError(
611 "Height, Width and Batch must be the same when concatenating "
612 "by channels axis");
613 }
614 new_shape.c += input[i].c;
615 }
616 break;
617 case Axis::HEIGHT:
618 for (int i = 1; i < input.size(); i++) {
619 if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
620 input[i].b != new_shape.b) {
621 return absl::InvalidArgumentError(
622 "Channels, Width and Batch must be the same when concatenating "
623 "by height axis");
624 }
625 new_shape.h += input[i].h;
626 }
627 break;
628 case Axis::WIDTH:
629 for (int i = 1; i < input.size(); i++) {
630 if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
631 input[i].b != new_shape.b) {
632 return absl::InvalidArgumentError(
633 "Height, Channels and Batch must be the same when concatenating "
634 "by width axis");
635 }
636 new_shape.w += input[i].w;
637 }
638 break;
639 case Axis::BATCH:
640 for (int i = 1; i < input.size(); i++) {
641 if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
642 input[i].w != new_shape.w) {
643 return absl::InvalidArgumentError(
644 "Width, Height and Channels must be the same when concatenating "
645 "by batch axis");
646 }
647 new_shape.b += input[i].b;
648 }
649 break;
650 default:
651 return absl::InvalidArgumentError("Invalid axis");
652 break;
653 }
654 *output_shape = new_shape;
655 return absl::OkStatus();
656 }
657
CalculateOutputShape(const std::vector<BHWDC> & input,const ConcatAttributes & attr,BHWDC * output_shape)658 absl::Status CalculateOutputShape(const std::vector<BHWDC>& input,
659 const ConcatAttributes& attr,
660 BHWDC* output_shape) {
661 BHWDC new_shape = input[0];
662 switch (attr.axis) {
663 case Axis::CHANNELS:
664 for (int i = 1; i < input.size(); ++i) {
665 if (input[i].h != new_shape.h || input[i].w != new_shape.w ||
666 input[i].d != new_shape.d || input[i].b != new_shape.b) {
667 return absl::InvalidArgumentError(
668 "Height, Width, Batch and Depth must be the same when "
669 "concatenating "
670 "by channels axis");
671 }
672 new_shape.c += input[i].c;
673 }
674 break;
675 case Axis::HEIGHT:
676 for (int i = 1; i < input.size(); ++i) {
677 if (input[i].w != new_shape.w || input[i].c != new_shape.c ||
678 input[i].d != new_shape.d || input[i].b != new_shape.b) {
679 return absl::InvalidArgumentError(
680 "Width, Depth, Batch and Channels must be the same when "
681 "concatenating "
682 "by height axis");
683 }
684 new_shape.h += input[i].h;
685 }
686 break;
687 case Axis::WIDTH:
688 for (int i = 1; i < input.size(); ++i) {
689 if (input[i].h != new_shape.h || input[i].c != new_shape.c ||
690 input[i].d != new_shape.d || input[i].b != new_shape.b) {
691 return absl::InvalidArgumentError(
692 "Height, Depth, Batch and Channels must be the same when "
693 "concatenating "
694 "by width axis");
695 }
696 new_shape.w += input[i].w;
697 }
698 break;
699 case Axis::DEPTH:
700 for (int i = 1; i < input.size(); ++i) {
701 if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
702 input[i].c != new_shape.c || input[i].b != new_shape.b) {
703 return absl::InvalidArgumentError(
704 "Width, Height, Batch and Channels must be the same when "
705 "concatenating "
706 "by depth axis");
707 }
708 new_shape.d += input[i].d;
709 }
710 break;
711 case Axis::BATCH:
712 for (int i = 1; i < input.size(); ++i) {
713 if (input[i].w != new_shape.w || input[i].h != new_shape.h ||
714 input[i].c != new_shape.c || input[i].d != new_shape.d) {
715 return absl::InvalidArgumentError(
716 "Width, Height, Depth and Channels must be the same when "
717 "concatenating "
718 "by batch axis");
719 }
720 new_shape.b += input[i].b;
721 }
722 break;
723 default:
724 return absl::InvalidArgumentError("Invalid axis");
725 }
726 *output_shape = new_shape;
727 return absl::OkStatus();
728 }
729
CalculateSamePadding(const BHWC & input,const Convolution2DAttributes & attr)730 Padding2D CalculateSamePadding(const BHWC& input,
731 const Convolution2DAttributes& attr) {
732 return MakeSamePadding(input, attr);
733 }
734
CalculateSamePadding(const BHWDC & input,const Convolution3DAttributes & attr)735 Padding3D CalculateSamePadding(const BHWDC& input,
736 const Convolution3DAttributes& attr) {
737 return MakeSamePadding(input, attr);
738 }
739
CalculateSamePadding(const BHWC & input,const ConvolutionTransposedAttributes & attr)740 Padding2D CalculateSamePadding(const BHWC& input,
741 const ConvolutionTransposedAttributes& attr) {
742 return MakeSamePadding(input, attr);
743 }
744
CalculateSamePadding(const BHWDC & input,const ConvolutionTransposed3DAttributes & attr)745 Padding3D CalculateSamePadding(const BHWDC& input,
746 const ConvolutionTransposed3DAttributes& attr) {
747 return MakeSamePadding(input, attr);
748 }
749
CalculateSamePadding(const BHWC & input,const DepthwiseConvolution2DAttributes & attr)750 Padding2D CalculateSamePadding(const BHWC& input,
751 const DepthwiseConvolution2DAttributes& attr) {
752 return MakeSamePadding(input, attr);
753 }
754
CalculateSamePadding(const BHWDC & input,const DepthwiseConvolution3DAttributes & attr)755 Padding3D CalculateSamePadding(const BHWDC& input,
756 const DepthwiseConvolution3DAttributes& attr) {
757 return MakeSamePadding(input, attr);
758 }
759
CalculateSamePadding(const BHWC & input,const Pooling2DAttributes & attr)760 Padding2D CalculateSamePadding(const BHWC& input,
761 const Pooling2DAttributes& attr) {
762 return MakeSamePadding(input, attr);
763 }
764
CalculateSamePadding(const BHWDC & input,const Pooling3DAttributes & attr)765 Padding3D CalculateSamePadding(const BHWDC& input,
766 const Pooling3DAttributes& attr) {
767 return MakeSamePadding(input, attr);
768 }
769
CalculateSamePadding(const BHWC & input,const MaxUnpooling2DAttributes & attr)770 Padding2D CalculateSamePadding(const BHWC& input,
771 const MaxUnpooling2DAttributes& attr) {
772 return MakeSamePadding(input, attr);
773 }
774
CalculateSamePadding(const BHWDC & input,const MaxUnpooling3DAttributes & attr)775 Padding3D CalculateSamePadding(const BHWDC& input,
776 const MaxUnpooling3DAttributes& attr) {
777 return MakeSamePadding(input, attr);
778 }
779
CalculateResizeScale(int32_t input_size,int32_t output_size,const Resize2DAttributes & attr)780 float CalculateResizeScale(int32_t input_size, int32_t output_size,
781 const Resize2DAttributes& attr) {
782 return attr.align_corners && input_size > 1 && output_size > 1
783 ? static_cast<float>(input_size - 1) / (output_size - 1)
784 : static_cast<float>(input_size) / output_size;
785 }
786
CalculateResizeScale(int32_t input_size,int32_t output_size,const Resize3DAttributes & attr)787 float CalculateResizeScale(int32_t input_size, int32_t output_size,
788 const Resize3DAttributes& attr) {
789 return attr.align_corners && input_size > 1 && output_size > 1
790 ? static_cast<float>(input_size - 1) / (output_size - 1)
791 : static_cast<float>(input_size) / output_size;
792 }
793
// Resize output: spatial dims come from attr.new_shape; batch/channels pass
// through unchanged.
BHWC CalculateOutputShape(const BHWC& input, const Resize2DAttributes& attr) {
  const auto& target = attr.new_shape;
  return BHWC(input.b, target.h, target.w, input.c);
}
797
// 3D resize output: spatial dims come from attr.new_shape; batch/channels
// pass through unchanged.
BHWDC CalculateOutputShape(const BHWDC& input, const Resize3DAttributes& attr) {
  const auto& target = attr.new_shape;
  return BHWDC(input.b, target.h, target.w, target.d, input.c);
}
802
// Transpose output: each output dim takes the input extent of the axis that
// attr.perm maps it to.
BHWC CalculateOutputShape(const BHWC& input, const TransposeAttributes& attr) {
  const int out_b = input.get(attr.perm.b);
  const int out_h = input.get(attr.perm.h);
  const int out_w = input.get(attr.perm.w);
  const int out_c = input.get(attr.perm.c);
  return BHWC(out_b, out_h, out_w, out_c);
}
807
// 3D transpose output: each output dim takes the input extent of the axis
// that attr.perm maps it to.
BHWDC CalculateOutputShape(const BHWDC& input,
                           const Transpose3DAttributes& attr) {
  const int out_b = input.get(attr.perm.b);
  const int out_h = input.get(attr.perm.h);
  const int out_w = input.get(attr.perm.w);
  const int out_d = input.get(attr.perm.d);
  const int out_c = input.get(attr.perm.c);
  return BHWDC(out_b, out_h, out_w, out_d, out_c);
}
814
815 } // namespace gpu
816 } // namespace tflite
817