1 /*
2  * Copyright (C) 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef BERBERIS_RUNTIME_PRIMITIVES_TABLE_OF_TABLES_H_
18 #define BERBERIS_RUNTIME_PRIMITIVES_TABLE_OF_TABLES_H_
19 
20 #include <sys/mman.h>
21 
22 #include <atomic>
23 #include <cstdint>
24 #include <mutex>
25 
26 #include "berberis/base/logging.h"
27 #include "berberis/base/memfd_backed_mmap.h"
28 #include "berberis/base/mmap.h"
29 
30 namespace berberis {
31 
32 template <typename Key, typename T>
33 class TableOfTables {
34  public:
TableOfTables(T default_value)35   explicit TableOfTables(T default_value) : default_value_(default_value) {
36     static_assert(sizeof(T) == sizeof(uintptr_t));
37     default_table_ = static_cast<decltype(default_table_)>(CreateMemfdBackedMapOrDie(
38         GetOrAllocDefaultMemfdUnsafe(), kChildTableBytes, kMemfdRegionSize));
39 
40     int main_memfd =
41         CreateAndFillMemfd("main", kMemfdRegionSize, reinterpret_cast<uintptr_t>(default_table_));
42     main_table_ = static_cast<decltype(main_table_)>(
43         CreateMemfdBackedMapOrDie(main_memfd, kTableSize * sizeof(T*), kMemfdRegionSize));
44     close(main_memfd);
45 
46     // The default table is read-only.
47     MprotectOrDie(default_table_, kChildTableBytes, PROT_READ);
48   }
49 
~TableOfTables()50   ~TableOfTables() {
51     for (size_t i = 0; i < kTableSize; ++i) {
52       if (main_table_[i] != default_table_) {
53         MunmapOrDie(main_table_[i], kChildTableBytes);
54       }
55     }
56 
57     MunmapOrDie(main_table_, kTableSize * sizeof(T*));
58     MunmapOrDie(default_table_, kChildTableBytes);
59     CloseDefaultMemfdUnsafe();
60   }
61 
Put(Key key,T value)62   /*may_discard*/ std::atomic<T>* Put(Key key, T value) {
63     SplitKey split_key(key);
64 
65     AllocateIfNecessary(split_key.high);
66 
67     main_table_[split_key.high][split_key.low] = value;
68     return &main_table_[split_key.high][split_key.low];
69   }
70 
Get(Key key)71   [[nodiscard]] T Get(Key key) const {
72     SplitKey split_key(key);
73     return main_table_[split_key.high][split_key.low];
74   }
75 
76   // This function returns a value address.
77   //
78   // Note that since this function has additional checks and may
79   // result in memory allocation, it is considerably slower than Get().
GetPointer(Key key)80   [[nodiscard]] std::atomic<T>* GetPointer(Key key) {
81     SplitKey split_key(key);
82 
83     AllocateIfNecessary(split_key.high);
84 
85     return &main_table_[split_key.high][split_key.low];
86   }
87 
main_table()88   [[nodiscard]] const std::atomic<std::atomic<T>*>* main_table() const { return main_table_; }
89 
CloseDefaultMemfdUnsafe()90   void CloseDefaultMemfdUnsafe() {
91     if (default_memfd_ == -1) {
92       return;
93     }
94     close(default_memfd_);
95     default_memfd_ = -1;
96   }
97 
98  private:
99   struct SplitKey {
SplitKeySplitKey100     explicit SplitKey(Key key) : low(key & (kTableSize - 1)), high(key >> kTableBits) {
101       CHECK_EQ(high & ~(kTableSize - 1), 0);
102     }
103 
104     const uint32_t low;
105     const uint32_t high;
106     static_assert(sizeof(Key) <= sizeof(low) * 2);
107   };
108 
GetOrAllocDefaultMemfdUnsafe()109   int GetOrAllocDefaultMemfdUnsafe() {
110     if (default_memfd_ == -1) {
111       default_memfd_ = CreateAndFillMemfd(
112           "child", kMemfdRegionSize, reinterpret_cast<uintptr_t>(default_value_));
113     }
114     return default_memfd_;
115   }
116 
117   // TODO(b/191390557): Inlining this function breaks app execution. Need to figure out
118   // the root cause and remove noinline.
AllocateIfNecessary(uint32_t high_word)119   void __attribute__((noinline)) AllocateIfNecessary(uint32_t high_word) {
120     // Fast fallback to avoid expensive mutex lock.
121     if (main_table_[high_word] != default_table_) {
122       return;
123     }
124 
125     std::lock_guard<std::mutex> lock(mutex_);
126     // Check again since the value could have been modified by other threads.
127     if (main_table_[high_word] == default_table_) {
128       auto tmp = static_cast<std::atomic<T>*>(CreateMemfdBackedMapOrDie(
129           GetOrAllocDefaultMemfdUnsafe(), kChildTableBytes, kMemfdRegionSize));
130       // Use fence to make sure the allocated table has been fully initialized
131       // before main_table_ is updated to point to it.
132       std::atomic_thread_fence(std::memory_order_release);
133       main_table_[high_word] = tmp;
134     }
135   }
136 
137 #if defined(__LP64__) && defined(BERBERIS_GUEST_LP64)
138   // On 64-bit architectures the effective pointer bits are limited to 48
139   // which makes it possible to split tables into 2^24 + 2^24.
140   static constexpr size_t kTableBits = 24;
141   // Use a 16Mb memfd region to fill the main/default table.
142   // Linux has a limited number of maps (sysctl vm.max_map_count).
143   // A larger region size allows us to stay within the limit.
144   static constexpr size_t kMemfdRegionSize = 1 << 24;
145   static_assert(sizeof(Key) == 8);
146 #elif !defined(BERBERIS_GUEST_LP64)
147   static constexpr size_t kTableBits = 16;
148   // Use a 64k memfd region to fill the main/default table.
149   // Linux has a limited number of maps (sysctl vm.max_map_count).
150   // A larger region size allows us to stay within the limit.
151   static constexpr size_t kMemfdRegionSize = 1 << 16;
152   static_assert(sizeof(Key) == 4);
153 #else
154 #error "Unsupported combination of a 32-bit host with a 64-bit guest"
155 #endif
156   static constexpr size_t kTableSize = 1 << kTableBits;
157   static constexpr size_t kChildTableBytes = kTableSize * sizeof(T);
158   std::mutex mutex_;
159   std::atomic<std::atomic<T>*>* main_table_;
160   std::atomic<T>* default_table_;
161   int default_memfd_{-1};
162   T default_value_;
163 };
164 
165 }  // namespace berberis
166 
167 #endif  // BERBERIS_RUNTIME_PRIMITIVES_TABLE_OF_TABLES_H_
168