1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #pragma once
18 
19 #include <aidl/android/hardware/neuralnetworks/BnPreparedModel.h>
20 #include <android-base/logging.h>
21 
22 #include <memory>
23 #include <utility>
24 #include <vector>
25 
26 #include "ShimDevice.h"
27 #include "SupportLibrary.h"
28 #include "SupportLibraryWrapper.h"
29 
30 namespace aidl::android::hardware::neuralnetworks {
31 
32 class ShimPreparedModel : public BnPreparedModel {
33    public:
ShimPreparedModel(std::shared_ptr<const NnApiSupportLibrary> nnapi,std::shared_ptr<ShimBufferTracker> bufferTracker,::android::nn::sl_wrapper::Compilation compilation,std::vector<::android::nn::sl_wrapper::Model> mainAndReferencedModels,std::vector<std::unique_ptr<::android::nn::sl_wrapper::Memory>> memoryPools,std::vector<uint8_t> copiedOperandValues)34     ShimPreparedModel(std::shared_ptr<const NnApiSupportLibrary> nnapi,
35                       std::shared_ptr<ShimBufferTracker> bufferTracker,
36                       ::android::nn::sl_wrapper::Compilation compilation,
37                       std::vector<::android::nn::sl_wrapper::Model> mainAndReferencedModels,
38                       std::vector<std::unique_ptr<::android::nn::sl_wrapper::Memory>> memoryPools,
39                       std::vector<uint8_t> copiedOperandValues)
40         : mNnapi(nnapi),
41           mBufferTracker(bufferTracker),
42           mCompilation(std::move(compilation)),
43           mMainAndReferencedModels(std::move(mainAndReferencedModels)),
44           mMemoryPools(std::move(memoryPools)),
45           mCopiedOperandValues(std::move(copiedOperandValues)) {
46         CHECK(mMainAndReferencedModels.size() > 0);
47     };
48 
49     ::ndk::ScopedAStatus executeSynchronously(const Request& request, bool measureTiming,
50                                               int64_t deadlineNs, int64_t loopTimeoutDurationNs,
51                                               ExecutionResult* executionResults) override;
52     ::ndk::ScopedAStatus executeFenced(const Request& request,
53                                        const std::vector<::ndk::ScopedFileDescriptor>& waitFor,
54                                        bool measureTiming, int64_t deadlineNs,
55                                        int64_t loopTimeoutDurationNs, int64_t durationNs,
56                                        FencedExecutionResult* fencedExecutionResult) override;
57     ::ndk::ScopedAStatus executeSynchronouslyWithConfig(const Request& request,
58                                                         const ExecutionConfig& config,
59                                                         int64_t deadlineNs,
60                                                         ExecutionResult* executionResult) override;
61     ::ndk::ScopedAStatus executeFencedWithConfig(
62             const Request& request, const std::vector<ndk::ScopedFileDescriptor>& waitFor,
63             const ExecutionConfig& config, int64_t deadlineNs, int64_t durationNs,
64             FencedExecutionResult* executionResult) override;
65 
66     ndk::ScopedAStatus configureExecutionBurst(std::shared_ptr<IBurst>* burst) override;
67     ndk::ScopedAStatus createReusableExecution(const Request& request,
68                                                const ExecutionConfig& config,
69                                                std::shared_ptr<IExecution>* execution) override;
70 
getCompilation()71     const ::android::nn::sl_wrapper::Compilation& getCompilation() const { return mCompilation; }
getMainModel()72     const ::android::nn::sl_wrapper::Model& getMainModel() const {
73         return mMainAndReferencedModels[0];
74     }
75 
76    private:
77     ErrorStatus parseInputs(
78             const Request& request, bool measure, int64_t deadlineNs, int64_t loopTimeoutDurationNs,
79             ::android::nn::sl_wrapper::Execution* execution,
80             std::vector<std::shared_ptr<::android::nn::sl_wrapper::Memory>>* requestMemoryPools,
81             const std::vector<TokenValuePair>& executionHints,
82             const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix);
83 
84     ::ndk::ScopedAStatus executeSynchronouslyCommon(
85             const Request& request, bool measureTiming, int64_t deadlineNs,
86             int64_t loopTimeoutDurationNs, const std::vector<TokenValuePair>& executionHints,
87             const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
88             ExecutionResult* executionResult);
89     ::ndk::ScopedAStatus executeFencedCommon(
90             const Request& request, const std::vector<::ndk::ScopedFileDescriptor>& waitFor,
91             bool measureTiming, int64_t deadlineNs, int64_t loopTimeoutDurationNs,
92             int64_t durationNs, const std::vector<TokenValuePair>& executionHints,
93             const std::vector<ExtensionNameAndPrefix>& extensionNameToPrefix,
94             FencedExecutionResult* fencedExecutionResult);
95 
96     std::shared_ptr<const NnApiSupportLibrary> mNnapi;
97     std::shared_ptr<ShimBufferTracker> mBufferTracker;
98 
99     ::android::nn::sl_wrapper::Compilation mCompilation;
100     std::vector<::android::nn::sl_wrapper::Model> mMainAndReferencedModels;
101     std::vector<std::unique_ptr<::android::nn::sl_wrapper::Memory>> mMemoryPools;
102     std::vector<uint8_t> mCopiedOperandValues;
103 };
104 
105 }  // namespace aidl::android::hardware::neuralnetworks
106