1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define LOG_TAG "OperationsUtils"
18 
19 #include "OperationsUtils.h"
20 
21 #include <android-base/logging.h>
22 
23 #include <vector>
24 
25 #include "nnapi/Validation.h"
26 
27 namespace android::nn {
28 
SameShape(const Shape & in1,const Shape & in2)29 bool SameShape(const Shape& in1, const Shape& in2) {
30     if (in1.type != in2.type || in1.dimensions.size() != in2.dimensions.size()) {
31         return false;
32     }
33     for (size_t i = 0; i < in1.dimensions.size(); i++) {
34         if (in1.dimensions[i] != in2.dimensions[i]) {
35             return false;
36         }
37     }
38     return true;
39 }
40 
SetShape(const Shape & in,Shape * out)41 bool SetShape(const Shape& in, Shape* out) {
42     if (in.type != out->type) {
43         return false;
44     }
45     out->dimensions = in.dimensions;
46     return true;
47 }
48 
getNumberOfElements(const Shape & shape)49 uint32_t getNumberOfElements(const Shape& shape) {
50     uint32_t count = 1;
51     for (size_t i = 0; i < shape.dimensions.size(); i++) {
52         count *= shape.dimensions[i];
53     }
54     return count;
55 }
56 
getNumberOfElements(const Shape & shape,size_t firstAxisInclusive,size_t lastAxisExclusive)57 uint32_t getNumberOfElements(const Shape& shape, size_t firstAxisInclusive,
58                              size_t lastAxisExclusive) {
59     CHECK_LE(0u, firstAxisInclusive);
60     CHECK_LE(firstAxisInclusive, lastAxisExclusive);
61     CHECK_LE(lastAxisExclusive, shape.dimensions.size());
62     uint32_t count = 1;
63     for (size_t i = firstAxisInclusive; i < lastAxisExclusive; i++) {
64         count *= shape.dimensions[i];
65     }
66     return count;
67 }
68 
getNumberOfDimensions(const Shape & shape)69 uint32_t getNumberOfDimensions(const Shape& shape) {
70     return shape.dimensions.size();
71 }
72 
getSizeOfDimension(const Shape & shape,uint32_t dimensionIdx)73 uint32_t getSizeOfDimension(const Shape& shape, uint32_t dimensionIdx) {
74     CHECK(0 <= dimensionIdx && dimensionIdx < shape.dimensions.size());
75     return shape.dimensions[dimensionIdx];
76 }
77 
hasKnownRank(const Shape & shape)78 uint32_t hasKnownRank(const Shape& shape) {
79     return !shape.dimensions.empty();
80 }
81 
82 }  // namespace android::nn
83