1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/coreml/builders/util.h"
16 
17 #include <vector>
18 
19 #include "tensorflow/lite/delegates/coreml/builders/op_validator.h"
20 #include "tensorflow/lite/kernels/kernel_util.h"
21 
22 namespace tflite {
23 namespace delegates {
24 namespace coreml {
25 namespace {
Get4DShape(const TfLiteTensor * tensor,std::vector<int> * shape)26 void Get4DShape(const TfLiteTensor* tensor, std::vector<int>* shape) {
27   const int rank = tensor->dims->size;
28   shape->resize(4);
29   for (int i = 0; i < 4 - rank; i++) {
30     (*shape)[i] = 1;
31   }
32   for (int i = 4 - rank; i < 4; ++i) {
33     (*shape)[i] = tensor->dims->data[i - (4 - rank)];
34   }
35 }
36 }  // namespace
37 
38 // Determines if two tensor shapes are broadcastable. See comment of
39 // IsBinaryOpSupported for more info.
IsBroadcastable(const TfLiteTensor * input_0,const TfLiteTensor * input_1)40 bool IsBroadcastable(const TfLiteTensor* input_0, const TfLiteTensor* input_1) {
41   std::vector<int> shape_0;
42   std::vector<int> shape_1;
43   Get4DShape(input_0, &shape_0);
44   Get4DShape(input_1, &shape_1);
45   const int B_0 = shape_0[0];
46   const int B_1 = shape_1[0];
47   const int H_0 = shape_0[1];
48   const int H_1 = shape_1[1];
49   const int W_0 = shape_0[2];
50   const int W_1 = shape_1[2];
51   const int C_0 = shape_0[3];
52   const int C_1 = shape_1[3];
53 
54   // TFL tensor has [B, H, W, C] format.
55   // comparing B: shape[0], (H, W): (shape[1], shape[2]), C: shape[3].
56 
57   // When B is different, it's not supported unless
58   // one of the tensor is size 1 constant tensor.
59   if (B_0 != B_1) {
60     if (!((IsConstantTensor(input_0) && NumElements(input_0) == 1) ||
61           (IsConstantTensor(input_1) && NumElements(input_1) == 1)))
62       return false;
63   }
64 
65   // When (H, W) are different, one of the (H, W) should be (1, 1).
66   if (H_0 != H_1 || W_0 != W_1) {
67     if (!((H_0 == 1 && W_0 == 1) || (H_1 == 1 && W_1 == 1))) {
68       return false;
69     }
70   }
71 
72   // When C is different, one of the C should be 1.
73   if (C_0 != C_1) {
74     if (C_0 != 1 && C_1 != 1) return false;
75   }
76   return true;
77 }
78 
IsBinaryOpSupported(const TfLiteRegistration * registration,const TfLiteNode * node,TfLiteContext * context)79 bool IsBinaryOpSupported(const TfLiteRegistration* registration,
80                          const TfLiteNode* node, TfLiteContext* context) {
81   return IsBroadcastable(GetInput(context, node, 0),
82                          GetInput(context, node, 1));
83 }
84 
85 }  // namespace coreml
86 }  // namespace delegates
87 }  // namespace tflite
88