1 /*
2 * Copyright (C) 2017 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #define LOG_TAG "OperationsUtils"
18
19 #include "OperationsUtils.h"
20
21 #include <android-base/logging.h>
22
23 #include <vector>
24
25 #include "nnapi/Validation.h"
26
27 namespace android::nn {
28
SameShape(const Shape & in1,const Shape & in2)29 bool SameShape(const Shape& in1, const Shape& in2) {
30 if (in1.type != in2.type || in1.dimensions.size() != in2.dimensions.size()) {
31 return false;
32 }
33 for (size_t i = 0; i < in1.dimensions.size(); i++) {
34 if (in1.dimensions[i] != in2.dimensions[i]) {
35 return false;
36 }
37 }
38 return true;
39 }
40
SetShape(const Shape & in,Shape * out)41 bool SetShape(const Shape& in, Shape* out) {
42 if (in.type != out->type) {
43 return false;
44 }
45 out->dimensions = in.dimensions;
46 return true;
47 }
48
getNumberOfElements(const Shape & shape)49 uint32_t getNumberOfElements(const Shape& shape) {
50 uint32_t count = 1;
51 for (size_t i = 0; i < shape.dimensions.size(); i++) {
52 count *= shape.dimensions[i];
53 }
54 return count;
55 }
56
getNumberOfElements(const Shape & shape,size_t firstAxisInclusive,size_t lastAxisExclusive)57 uint32_t getNumberOfElements(const Shape& shape, size_t firstAxisInclusive,
58 size_t lastAxisExclusive) {
59 CHECK_LE(0u, firstAxisInclusive);
60 CHECK_LE(firstAxisInclusive, lastAxisExclusive);
61 CHECK_LE(lastAxisExclusive, shape.dimensions.size());
62 uint32_t count = 1;
63 for (size_t i = firstAxisInclusive; i < lastAxisExclusive; i++) {
64 count *= shape.dimensions[i];
65 }
66 return count;
67 }
68
getNumberOfDimensions(const Shape & shape)69 uint32_t getNumberOfDimensions(const Shape& shape) {
70 return shape.dimensions.size();
71 }
72
getSizeOfDimension(const Shape & shape,uint32_t dimensionIdx)73 uint32_t getSizeOfDimension(const Shape& shape, uint32_t dimensionIdx) {
74 CHECK(0 <= dimensionIdx && dimensionIdx < shape.dimensions.size());
75 return shape.dimensions[dimensionIdx];
76 }
77
hasKnownRank(const Shape & shape)78 uint32_t hasKnownRank(const Shape& shape) {
79 return !shape.dimensions.empty();
80 }
81
82 } // namespace android::nn
83