1 /*
2  * Copyright (C) 2018 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "CpuOperationUtils.h"
#include "IndexedShapeWrapper.h"
#include "OperationResolver.h"

#include <cstddef>
#include <cstdint>
#include <vector>
21 
22 namespace android {
23 namespace nn {
24 namespace slice {
25 
// Name under which the operation is registered; also used in diagnostics.
constexpr char kOperationName[] = "SLICE";

// Input operand layout. SLICE takes exactly kNumInputs inputs.
constexpr uint32_t kNumInputs{3};
constexpr uint32_t kInputTensor{0};  // Tensor being sliced.
constexpr uint32_t kBeginTensor{1};  // 1-D INT32 start offsets, one per input dimension.
constexpr uint32_t kSizeTensor{2};   // 1-D INT32 extents, one per dimension; -1 = "to the end".

// Output operand layout. SLICE produces exactly kNumOutputs outputs.
constexpr uint32_t kNumOutputs{1};
constexpr uint32_t kOutputTensor{0};
35 
36 namespace {
37 
// Element-wise sum: (*res)[i] = a[i] + b[i] for every i in [0, res->size()).
//
// The length of *res decides how many elements are written. Precondition:
// a and b each hold at least res->size() elements (their extra elements, if
// any, are ignored).
template <typename T>
void addVectors(const std::vector<T>& a, const std::vector<T>& b, std::vector<T>* res) {
    const size_t count = res->size();
    for (size_t i = 0; i < count; ++i) {
        (*res)[i] = a[i] + b[i];
    }
}
44 
45 template <typename T>
evalGeneric(const T * inputData,const Shape & inputShape,const int32_t * beginData,const Shape & beginShape,const int32_t * sizeData,const Shape & sizeShape,T * outputData,const Shape & outputShape)46 bool evalGeneric(const T* inputData, const Shape& inputShape, const int32_t* beginData,
47                  const Shape& beginShape, const int32_t* sizeData, const Shape& sizeShape,
48                  T* outputData, const Shape& outputShape) {
49     const int outputSize = getNumberOfElements(outputShape);
50     const IndexedShapeWrapper indexedOutput = IndexedShapeWrapper(outputShape);
51     const IndexedShapeWrapper indexedInput = IndexedShapeWrapper(inputShape);
52     std::vector<uint32_t> outputIndex(getNumberOfDimensions(outputShape), 0);
53     std::vector<uint32_t> beginIndex(getSizeOfDimension(beginShape, 0));
54     std::vector<uint32_t> inputIndex(getNumberOfDimensions(inputShape));
55 
56     for (int i = 0; i < beginIndex.size(); ++i) {
57         beginIndex[i] = static_cast<uint32_t>(beginData[i]);
58     }
59 
60     bool lastIndex = false;
61     uint32_t outputOffset;
62     uint32_t inputOffset;
63 
64     do {
65         addVectors(outputIndex, beginIndex, &inputIndex);
66 
67         NN_RET_CHECK(indexedOutput.indexToFlatIndex(outputIndex, &outputOffset));
68         NN_RET_CHECK(indexedInput.indexToFlatIndex(inputIndex, &inputOffset));
69 
70         outputData[outputOffset] = inputData[inputOffset];
71         NN_RET_CHECK(indexedOutput.nextIndexInplace(&outputIndex, &lastIndex));
72     } while (!lastIndex);
73     return true;
74 }
75 
76 }  // namespace
77 
validate(const IOperationValidationContext * context)78 bool validate(const IOperationValidationContext* context) {
79     NN_RET_CHECK_EQ(context->getNumInputs(), kNumInputs);
80     NN_RET_CHECK_EQ(context->getNumOutputs(), kNumOutputs);
81 
82     const OperandType inputType = context->getInputType(kInputTensor);
83     NN_RET_CHECK(
84             inputType == OperandType::TENSOR_FLOAT16 || inputType == OperandType::TENSOR_FLOAT32 ||
85             inputType == OperandType::TENSOR_INT32 || inputType == OperandType::TENSOR_QUANT8_ASYMM)
86             << "Unsupported tensor type for operation " << kOperationName;
87     NN_RET_CHECK(validateHalVersion(context, HalVersion::V1_2));
88     return validateInputTypes(context,
89                               {inputType, OperandType::TENSOR_INT32, OperandType::TENSOR_INT32}) &&
90            validateOutputTypes(context, {inputType});
91 }
92 
prepare(IOperationExecutionContext * context)93 bool prepare(IOperationExecutionContext* context) {
94     const Shape& inputShape = context->getInputShape(kInputTensor);
95     const int32_t n_dims = getNumberOfDimensions(inputShape);
96     NN_RET_CHECK(n_dims > 0);
97 
98     const Shape& beginShape = context->getInputShape(kBeginTensor);
99     NN_RET_CHECK_EQ(getNumberOfDimensions(beginShape), 1);
100     NN_RET_CHECK_EQ(getSizeOfDimension(beginShape, 0), n_dims);
101 
102     const Shape& sizeShape = context->getInputShape(kSizeTensor);
103     NN_RET_CHECK_EQ(getNumberOfDimensions(sizeShape), 1);
104     NN_RET_CHECK_EQ(getSizeOfDimension(sizeShape, 0), n_dims);
105 
106     const int32_t* beginData = context->getInputBuffer<int32_t>(kBeginTensor);
107     const int32_t* sizeData = context->getInputBuffer<int32_t>(kSizeTensor);
108 
109     Shape outputShape = context->getOutputShape(kOutputTensor);
110     outputShape.dimensions.resize(n_dims);
111     for (int i = 0; i < n_dims; ++i) {
112         const int32_t sliceBegin = beginData[i];
113         int32_t sliceSize = sizeData[i];
114         if (sliceSize == -1) {
115             sliceSize = getSizeOfDimension(inputShape, i) - sliceBegin;
116         }
117         NN_RET_CHECK_LE(beginData[i], getSizeOfDimension(inputShape, i));
118         NN_RET_CHECK_GE(sliceSize, 0);
119         NN_RET_CHECK_LE(sliceBegin + sliceSize, getSizeOfDimension(inputShape, i));
120         outputShape.dimensions[i] = sliceSize;
121     }
122     return context->setOutputShape(kOutputTensor, outputShape);
123 }
124 
execute(IOperationExecutionContext * context)125 bool execute(IOperationExecutionContext* context) {
126     // Bypass execution in the case of zero-sized input.
127     if (getNumberOfElements(context->getOutputShape(kOutputTensor)) == 0) return true;
128     switch (context->getInputType(kInputTensor)) {
129         case OperandType::TENSOR_FLOAT16:
130             return evalGeneric(context->getInputBuffer<_Float16>(kInputTensor),
131                                context->getInputShape(kInputTensor),
132                                context->getInputBuffer<int32_t>(kBeginTensor),
133                                context->getInputShape(kBeginTensor),
134                                context->getInputBuffer<int32_t>(kSizeTensor),
135                                context->getInputShape(kSizeTensor),
136                                context->getOutputBuffer<_Float16>(kOutputTensor),
137                                context->getOutputShape(kOutputTensor));
138         case OperandType::TENSOR_FLOAT32:
139             return evalGeneric(context->getInputBuffer<float>(kInputTensor),
140                                context->getInputShape(kInputTensor),
141                                context->getInputBuffer<int32_t>(kBeginTensor),
142                                context->getInputShape(kBeginTensor),
143                                context->getInputBuffer<int32_t>(kSizeTensor),
144                                context->getInputShape(kSizeTensor),
145                                context->getOutputBuffer<float>(kOutputTensor),
146                                context->getOutputShape(kOutputTensor));
147         case OperandType::TENSOR_INT32:
148             return evalGeneric(context->getInputBuffer<int32_t>(kInputTensor),
149                                context->getInputShape(kInputTensor),
150                                context->getInputBuffer<int32_t>(kBeginTensor),
151                                context->getInputShape(kBeginTensor),
152                                context->getInputBuffer<int32_t>(kSizeTensor),
153                                context->getInputShape(kSizeTensor),
154                                context->getOutputBuffer<int32_t>(kOutputTensor),
155                                context->getOutputShape(kOutputTensor));
156         case OperandType::TENSOR_QUANT8_ASYMM:
157             return evalGeneric(context->getInputBuffer<uint8_t>(kInputTensor),
158                                context->getInputShape(kInputTensor),
159                                context->getInputBuffer<int32_t>(kBeginTensor),
160                                context->getInputShape(kBeginTensor),
161                                context->getInputBuffer<int32_t>(kSizeTensor),
162                                context->getInputShape(kSizeTensor),
163                                context->getOutputBuffer<uint8_t>(kOutputTensor),
164                                context->getOutputShape(kOutputTensor));
165         default:
166             NN_RET_CHECK_FAIL() << "Unsupported tensor type for operation " << kOperationName;
167     }
168 }
169 
170 }  // namespace slice
171 
// Registers SLICE's callbacks: slice::validate checks the operand signature,
// slice::prepare computes the output shape, and slice::execute copies the slice.
// allowZeroSizedInput is set because zero-sized tensors are handled here
// (execute() returns early for a zero-element output).
// NOTE(review): the exact effect of allowZeroSizedInput is defined by
// OperationResolver.h — confirm there.
NN_REGISTER_OPERATION(SLICE, slice::kOperationName, slice::validate, slice::prepare, slice::execute,
                      .allowZeroSizedInput = true);
174 
175 }  // namespace nn
176 }  // namespace android
177