1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_RUNTIME_MEMORY_H
18 #define ANDROID_ML_NN_RUNTIME_MEMORY_H
19 
#include "NeuralNetworks.h"
#include "Utils.h"

#include <cutils/native_handle.h>
#include <sys/mman.h>
#include <cstdint>
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>
#include "vndk/hardware_buffer.h"
28 
29 namespace android {
30 namespace nn {
31 
32 class ExecutionBurstController;
33 class ModelBuilder;
34 
35 // Represents a memory region.
36 class Memory {
37    public:
Memory()38     Memory() {}
39     virtual ~Memory();
40 
41     // Disallow copy semantics to ensure the runtime object can only be freed
42     // once. Copy semantics could be enabled if some sort of reference counting
43     // or deep-copy system for runtime objects is added later.
44     Memory(const Memory&) = delete;
45     Memory& operator=(const Memory&) = delete;
46 
47     // Creates a shared memory object of the size specified in bytes.
48     int create(uint32_t size);
49 
getHidlMemory()50     hardware::hidl_memory getHidlMemory() const { return mHidlMemory; }
51 
52     // Returns a pointer to the underlying memory of this memory object.
53     // The function will fail if the memory is not CPU accessible and nullptr
54     // will be returned.
getPointer(uint8_t ** buffer)55     virtual int getPointer(uint8_t** buffer) const {
56         *buffer = static_cast<uint8_t*>(static_cast<void*>(mMemory->getPointer()));
57         if (*buffer == nullptr) {
58             return ANEURALNETWORKS_BAD_DATA;
59         }
60         return ANEURALNETWORKS_NO_ERROR;
61     }
62 
63     virtual bool validateSize(uint32_t offset, uint32_t length) const;
64 
65     // Unique key representing this memory object.
66     intptr_t getKey() const;
67 
68     // Marks a burst object as currently using this memory. When this
69     // memory object is destroyed, it will automatically free this memory from
70     // the bursts' memory cache.
71     void usedBy(const std::shared_ptr<ExecutionBurstController>& burst) const;
72 
73    protected:
74     // The hidl_memory handle for this shared memory.  We will pass this value when
75     // communicating with the drivers.
76     hardware::hidl_memory mHidlMemory;
77     sp<IMemory> mMemory;
78 
79     mutable std::mutex mMutex;
80     // mUsedBy is essentially a set of burst objects which use this Memory
81     // object. However, std::weak_ptr does not have comparison operations nor a
82     // std::hash implementation. This is because it is either a valid pointer
83     // (non-null) if the shared object is still alive, or it is null if the
84     // object has been freed. To circumvent this, mUsedBy is a map with the raw
85     // pointer as the key and the weak_ptr as the value.
86     mutable std::unordered_map<const ExecutionBurstController*,
87                                std::weak_ptr<ExecutionBurstController>>
88             mUsedBy;
89 };
90 
91 class MemoryFd : public Memory {
92    public:
MemoryFd()93     MemoryFd() {}
94     ~MemoryFd() override;
95 
96     // Disallow copy semantics to ensure the runtime object can only be freed
97     // once. Copy semantics could be enabled if some sort of reference counting
98     // or deep-copy system for runtime objects is added later.
99     MemoryFd(const MemoryFd&) = delete;
100     MemoryFd& operator=(const MemoryFd&) = delete;
101 
102     // Create the native_handle based on input size, prot, and fd.
103     // Existing native_handle will be deleted, and mHidlMemory will wrap
104     // the newly created native_handle.
105     int set(size_t size, int prot, int fd, size_t offset);
106 
107     int getPointer(uint8_t** buffer) const override;
108 
109    private:
110     native_handle_t* mHandle = nullptr;
111     mutable uint8_t* mMapping = nullptr;
112 };
113 
114 // TODO(miaowang): move function definitions to Memory.cpp
115 class MemoryAHWB : public Memory {
116    public:
MemoryAHWB()117     MemoryAHWB() {}
~MemoryAHWB()118     ~MemoryAHWB() override{};
119 
120     // Disallow copy semantics to ensure the runtime object can only be freed
121     // once. Copy semantics could be enabled if some sort of reference counting
122     // or deep-copy system for runtime objects is added later.
123     MemoryAHWB(const MemoryAHWB&) = delete;
124     MemoryAHWB& operator=(const MemoryAHWB&) = delete;
125 
126     // Keep track of the provided AHardwareBuffer handle.
set(const AHardwareBuffer * ahwb)127     int set(const AHardwareBuffer* ahwb) {
128         AHardwareBuffer_describe(ahwb, &mBufferDesc);
129         const native_handle_t* handle = AHardwareBuffer_getNativeHandle(ahwb);
130         mHardwareBuffer = ahwb;
131         if (mBufferDesc.format == AHARDWAREBUFFER_FORMAT_BLOB) {
132             mHidlMemory = hidl_memory("hardware_buffer_blob", handle, mBufferDesc.width);
133         } else {
134             // memory size is not used.
135             mHidlMemory = hidl_memory("hardware_buffer", handle, 0);
136         }
137         return ANEURALNETWORKS_NO_ERROR;
138     };
139 
getPointer(uint8_t ** buffer)140     int getPointer(uint8_t** buffer) const override {
141         *buffer = nullptr;
142         return ANEURALNETWORKS_BAD_DATA;
143     };
144 
145     // validateSize should only be called for blob mode AHardwareBuffer.
146     // Calling it on non-blob mode AHardwareBuffer will result in an error.
147     // TODO(miaowang): consider separate blob and non-blob into different classes.
validateSize(uint32_t offset,uint32_t length)148     bool validateSize(uint32_t offset, uint32_t length) const override {
149         if (mHardwareBuffer == nullptr) {
150             LOG(ERROR) << "MemoryAHWB has not been initialized.";
151             return false;
152         }
153         // validateSize should only be called on BLOB mode buffer.
154         if (mBufferDesc.format == AHARDWAREBUFFER_FORMAT_BLOB) {
155             if (offset + length > mBufferDesc.width) {
156                 LOG(ERROR) << "Request size larger than the memory size.";
157                 return false;
158             } else {
159                 return true;
160             }
161         } else {
162             LOG(ERROR) << "Invalid AHARDWAREBUFFER_FORMAT, must be AHARDWAREBUFFER_FORMAT_BLOB.";
163             return false;
164         }
165     }
166 
167    private:
168     const AHardwareBuffer* mHardwareBuffer = nullptr;
169     AHardwareBuffer_Desc mBufferDesc;
170 };
171 
172 // A utility class to accumulate mulitple Memory objects and assign each
173 // a distinct index number, starting with 0.
174 //
175 // The user of this class is responsible for avoiding concurrent calls
176 // to this class from multiple threads.
177 class MemoryTracker {
178    private:
179     // The vector of Memory pointers we are building.
180     std::vector<const Memory*> mMemories;
181     // A faster way to see if we already have a memory than doing find().
182     std::unordered_map<const Memory*, uint32_t> mKnown;
183 
184    public:
185     // Adds the memory, if it does not already exists.  Returns its index.
186     // The memories should survive the tracker.
187     uint32_t add(const Memory* memory);
188     // Returns the number of memories contained.
size()189     uint32_t size() const { return static_cast<uint32_t>(mKnown.size()); }
190     // Returns the ith memory.
191     const Memory* operator[](size_t i) const { return mMemories[i]; }
192     // Iteration
begin()193     decltype(mMemories.begin()) begin() { return mMemories.begin(); }
end()194     decltype(mMemories.end()) end() { return mMemories.end(); }
195 };
196 
197 }  // namespace nn
198 }  // namespace android
199 
200 #endif  // ANDROID_ML_NN_RUNTIME_MEMORY_H
201