1 /* 2 * Copyright (C) 2017 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef ANDROID_ML_NN_RUNTIME_MEMORY_H 18 #define ANDROID_ML_NN_RUNTIME_MEMORY_H 19 20 #include "NeuralNetworks.h" 21 #include "Utils.h" 22 23 #include <cutils/native_handle.h> 24 #include <sys/mman.h> 25 #include <unordered_map> 26 27 namespace android { 28 namespace nn { 29 30 class ModelBuilder; 31 32 // Represents a memory region. 33 class Memory { 34 public: Memory()35 Memory() {} ~Memory()36 virtual ~Memory() {} 37 38 // Disallow copy semantics to ensure the runtime object can only be freed 39 // once. Copy semantics could be enabled if some sort of reference counting 40 // or deep-copy system for runtime objects is added later. 41 Memory(const Memory&) = delete; 42 Memory& operator=(const Memory&) = delete; 43 44 // Creates a shared memory object of the size specified in bytes. 45 int create(uint32_t size); 46 getHidlMemory()47 hardware::hidl_memory getHidlMemory() const { return mHidlMemory; } 48 49 // Returns a pointer to the underlying memory of this memory object. getPointer(uint8_t ** buffer)50 virtual int getPointer(uint8_t** buffer) const { 51 *buffer = static_cast<uint8_t*>(static_cast<void*>(mMemory->getPointer())); 52 return ANEURALNETWORKS_NO_ERROR; 53 } 54 55 virtual bool validateSize(uint32_t offset, uint32_t length) const; 56 protected: 57 // The hidl_memory handle for this shared memory. We will pass this value when 58 // communicating with the drivers. 59 hardware::hidl_memory mHidlMemory; 60 sp<IMemory> mMemory; 61 }; 62 63 class MemoryFd : public Memory { 64 public: MemoryFd()65 MemoryFd() {} 66 ~MemoryFd(); 67 68 // Disallow copy semantics to ensure the runtime object can only be freed 69 // once. Copy semantics could be enabled if some sort of reference counting 70 // or deep-copy system for runtime objects is added later. 71 MemoryFd(const MemoryFd&) = delete; 72 MemoryFd& operator=(const MemoryFd&) = delete; 73 74 // Create the native_handle based on input size, prot, and fd. 75 // Existing native_handle will be deleted, and mHidlMemory will wrap 76 // the newly created native_handle. 77 int set(size_t size, int prot, int fd, size_t offset); 78 79 int getPointer(uint8_t** buffer) const override; 80 81 private: 82 native_handle_t* mHandle = nullptr; 83 mutable uint8_t* mMapping = nullptr; 84 }; 85 86 // A utility class to accumulate mulitple Memory objects and assign each 87 // a distinct index number, starting with 0. 88 // 89 // The user of this class is responsible for avoiding concurrent calls 90 // to this class from multiple threads. 91 class MemoryTracker { 92 private: 93 // The vector of Memory pointers we are building. 94 std::vector<const Memory*> mMemories; 95 // A faster way to see if we already have a memory than doing find(). 96 std::unordered_map<const Memory*, uint32_t> mKnown; 97 98 public: 99 // Adds the memory, if it does not already exists. Returns its index. 100 // The memories should survive the tracker. 101 uint32_t add(const Memory* memory); 102 // Returns the number of memories contained. size()103 uint32_t size() const { return static_cast<uint32_t>(mKnown.size()); } 104 // Returns the ith memory. 105 const Memory* operator[](size_t i) const { return mMemories[i]; } 106 // Iteration begin()107 decltype(mMemories.begin()) begin() { return mMemories.begin(); } end()108 decltype(mMemories.end()) end() { return mMemories.end(); } 109 }; 110 111 } // namespace nn 112 } // namespace android 113 114 #endif // ANDROID_ML_NN_RUNTIME_MEMORY_H 115