1 /*
2  * Copyright (C) 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_BUFFER_TRACKER_H
18 #define ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_BUFFER_TRACKER_H
19 
20 #include <android-base/macros.h>
21 #include <android-base/thread_annotations.h>
22 
23 #include <map>
24 #include <memory>
25 #include <mutex>
26 #include <set>
27 #include <stack>
28 #include <utility>
29 #include <vector>
30 
31 #include "CpuExecutor.h"
32 #include "LegacyUtils.h"
33 #include "nnapi/Types.h"
34 #include "nnapi/Validation.h"
35 
36 namespace android::nn {
37 
38 // This class manages a CPU buffer allocated on heap and provides validation methods.
39 class ManagedBuffer {
40    public:
41     static std::shared_ptr<ManagedBuffer> create(uint32_t size, std::set<PreparedModelRole> roles,
42                                                  const Operand& operand);
43 
44     // Prefer ManagedBuffer::create.
45     ManagedBuffer(std::unique_ptr<uint8_t[]> buffer, uint32_t size,
46                   std::set<PreparedModelRole> roles, const Operand& operand);
47 
createRunTimePoolInfo()48     RunTimePoolInfo createRunTimePoolInfo() const {
49         return RunTimePoolInfo::createFromExistingBuffer(kBuffer.get(), kSize);
50     }
51 
52     // "poolIndex" is the index of this buffer in the request.pools.
53     ErrorStatus validateRequest(uint32_t poolIndex, const Request& request,
54                                 const IPreparedModel* preparedModel) const;
55 
56     // "size" is the byte size of the Memory provided to the copyFrom or copyTo method.
57     ErrorStatus validateCopyFrom(const Dimensions& dimensions, uint32_t size) const;
58     ErrorStatus validateCopyTo(uint32_t size) const;
59 
60     bool updateDimensions(const Dimensions& dimensions);
61     void setInitialized(bool initialized);
62 
63    private:
64     mutable std::mutex mMutex;
65     const std::unique_ptr<uint8_t[]> kBuffer;
66     const uint32_t kSize;
67     const std::set<PreparedModelRole> kRoles;
68     const OperandType kOperandType;
69     const Dimensions kInitialDimensions;
70     Dimensions mUpdatedDimensions GUARDED_BY(mMutex);
71     bool mInitialized GUARDED_BY(mMutex) = false;
72 };
73 
74 // Keep track of all ManagedBuffers and assign each with a unique token.
75 class BufferTracker : public std::enable_shared_from_this<BufferTracker> {
76     DISALLOW_COPY_AND_ASSIGN(BufferTracker);
77 
78    public:
79     // A RAII class to help manage the lifetime of the token.
80     // It is only supposed to be constructed in BufferTracker::add.
81     class Token {
82         DISALLOW_COPY_AND_ASSIGN(Token);
83 
84        public:
Token(Request::MemoryDomainToken token,std::shared_ptr<BufferTracker> tracker)85         Token(Request::MemoryDomainToken token, std::shared_ptr<BufferTracker> tracker)
86             : kToken(token), kBufferTracker(std::move(tracker)) {}
~Token()87         ~Token() { kBufferTracker->free(kToken); }
get()88         Request::MemoryDomainToken get() const { return kToken; }
89 
90        private:
91         const Request::MemoryDomainToken kToken;
92         const std::shared_ptr<BufferTracker> kBufferTracker;
93     };
94 
95     // The factory of BufferTracker. This ensures that the BufferTracker is always managed by a
96     // shared_ptr.
create()97     static std::shared_ptr<BufferTracker> create() { return std::make_shared<BufferTracker>(); }
98 
99     // Prefer BufferTracker::create.
100     BufferTracker();
101 
102     std::unique_ptr<Token> add(std::shared_ptr<ManagedBuffer> buffer);
103     std::shared_ptr<ManagedBuffer> get(Request::MemoryDomainToken token) const;
104 
105    private:
106     void free(Request::MemoryDomainToken token);
107 
108     mutable std::mutex mMutex;
109     std::stack<Request::MemoryDomainToken, std::vector<Request::MemoryDomainToken>> mFreeTokens
110             GUARDED_BY(mMutex);
111 
112     // Since the tokens are allocated in a non-sparse way, we use a vector to represent the mapping.
113     // The index of the vector is the token. When the token gets freed, the corresponding entry is
114     // set to nullptr. mTokenToBuffers[0] is always set to nullptr because 0 is an invalid token.
115     std::vector<std::shared_ptr<ManagedBuffer>> mTokenToBuffers GUARDED_BY(mMutex);
116 };
117 
118 }  // namespace android::nn
119 
120 #endif  // ANDROID_PACKAGES_MODULES_NEURALNETWORKS_COMMON_BUFFER_TRACKER_H
121