1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef ANDROID_ML_NN_RUNTIME_MEMORY_H
18 #define ANDROID_ML_NN_RUNTIME_MEMORY_H
19 
#include "NeuralNetworks.h"
#include "Utils.h"

#include <cutils/native_handle.h>
#include <sys/mman.h>
#include <unordered_map>
#include <vector>
26 
27 namespace android {
28 namespace nn {
29 
30 class ModelBuilder;
31 
32 // Represents a memory region.
33 class Memory {
34 public:
Memory()35     Memory() {}
~Memory()36     virtual ~Memory() {}
37 
38     // Disallow copy semantics to ensure the runtime object can only be freed
39     // once. Copy semantics could be enabled if some sort of reference counting
40     // or deep-copy system for runtime objects is added later.
41     Memory(const Memory&) = delete;
42     Memory& operator=(const Memory&) = delete;
43 
44     // Creates a shared memory object of the size specified in bytes.
45     int create(uint32_t size);
46 
getHidlMemory()47     hardware::hidl_memory getHidlMemory() const { return mHidlMemory; }
48 
49     // Returns a pointer to the underlying memory of this memory object.
getPointer(uint8_t ** buffer)50     virtual int getPointer(uint8_t** buffer) const {
51         *buffer = static_cast<uint8_t*>(static_cast<void*>(mMemory->getPointer()));
52         return ANEURALNETWORKS_NO_ERROR;
53     }
54 
55     virtual bool validateSize(uint32_t offset, uint32_t length) const;
56 protected:
57     // The hidl_memory handle for this shared memory.  We will pass this value when
58     // communicating with the drivers.
59     hardware::hidl_memory mHidlMemory;
60     sp<IMemory> mMemory;
61 };
62 
63 class MemoryFd : public Memory {
64 public:
MemoryFd()65     MemoryFd() {}
66     ~MemoryFd();
67 
68     // Disallow copy semantics to ensure the runtime object can only be freed
69     // once. Copy semantics could be enabled if some sort of reference counting
70     // or deep-copy system for runtime objects is added later.
71     MemoryFd(const MemoryFd&) = delete;
72     MemoryFd& operator=(const MemoryFd&) = delete;
73 
74     // Create the native_handle based on input size, prot, and fd.
75     // Existing native_handle will be deleted, and mHidlMemory will wrap
76     // the newly created native_handle.
77     int set(size_t size, int prot, int fd, size_t offset);
78 
79     int getPointer(uint8_t** buffer) const override;
80 
81 private:
82     native_handle_t* mHandle = nullptr;
83     mutable uint8_t* mMapping = nullptr;
84 };
85 
86 // A utility class to accumulate mulitple Memory objects and assign each
87 // a distinct index number, starting with 0.
88 //
89 // The user of this class is responsible for avoiding concurrent calls
90 // to this class from multiple threads.
91 class MemoryTracker {
92 private:
93     // The vector of Memory pointers we are building.
94     std::vector<const Memory*> mMemories;
95     // A faster way to see if we already have a memory than doing find().
96     std::unordered_map<const Memory*, uint32_t> mKnown;
97 
98 public:
99     // Adds the memory, if it does not already exists.  Returns its index.
100     // The memories should survive the tracker.
101     uint32_t add(const Memory* memory);
102     // Returns the number of memories contained.
size()103     uint32_t size() const { return static_cast<uint32_t>(mKnown.size()); }
104     // Returns the ith memory.
105     const Memory* operator[](size_t i) const { return mMemories[i]; }
106     // Iteration
begin()107     decltype(mMemories.begin()) begin() { return mMemories.begin(); }
end()108     decltype(mMemories.end())   end()   { return mMemories.end(); }
109 };
110 
111 }  // namespace nn
112 }  // namespace android
113 
114 #endif  // ANDROID_ML_NN_RUNTIME_MEMORY_H
115