1 /*
2 * Copyright 2021 HIMSA II K/S - www.himsa.com.
3 * Represented by EHIMA - www.ehima.com
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 #include <bluetooth/log.h>
19
20 #include <algorithm>
21 #include <limits>
22 #include <map>
23 #include <mutex>
24 #include <unordered_set>
25
26 #include "bta_groups.h"
27 #include "btif/include/btif_profile_storage.h"
28 #include "os/logging/log_adapter.h"
29 #include "stack/include/bt_types.h"
30 #include "types/bluetooth/uuid.h"
31 #include "types/raw_address.h"
32
33 using bluetooth::Uuid;
34
35 namespace bluetooth {
36 namespace groups {
37
/* The implementation class is defined further down in this file. */
class DeviceGroupsImpl;

/* Singleton implementation. Created by DeviceGroups::Initialize() and
 * deleted by DeviceGroups::CleanUp(); both of those take instance_mutex. */
DeviceGroupsImpl* instance = nullptr;
std::mutex instance_mutex;

/* Bound for automatically allocated group ids: ids are searched starting at
 * 1 and strictly below this value. */
static constexpr int kMaxGroupId = 0xEF;
42
43 class DeviceGroup {
44 public:
DeviceGroup(int group_id,Uuid uuid)45 DeviceGroup(int group_id, Uuid uuid)
46 : group_id_(group_id), group_uuid_(uuid) {}
Add(const RawAddress & addr)47 void Add(const RawAddress& addr) { devices_.insert(addr); }
Remove(const RawAddress & addr)48 void Remove(const RawAddress& addr) { devices_.erase(addr); }
Contains(const RawAddress & addr) const49 bool Contains(const RawAddress& addr) const {
50 return (devices_.count(addr) != 0);
51 }
52
ForEachDevice(std::function<void (const RawAddress &)> cb) const53 void ForEachDevice(std::function<void(const RawAddress&)> cb) const {
54 for (auto const& addr : devices_) {
55 cb(addr);
56 }
57 }
58
Size(void) const59 int Size(void) const { return devices_.size(); }
GetGroupId(void) const60 int GetGroupId(void) const { return group_id_; }
GetUuid(void) const61 const Uuid& GetUuid(void) const { return group_uuid_; }
62
63 private:
64 friend std::ostream& operator<<(std::ostream& out,
65 const bluetooth::groups::DeviceGroup& value);
66 int group_id_;
67 Uuid group_uuid_;
68 std::unordered_set<RawAddress> devices_;
69 };
70
71 class DeviceGroupsImpl : public DeviceGroups {
72 static constexpr uint8_t GROUP_STORAGE_CURRENT_LAYOUT_MAGIC = 0x10;
73 static constexpr size_t GROUP_STORAGE_HEADER_SZ =
74 sizeof(GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) +
75 sizeof(uint8_t); /* num_of_groups */
76 static constexpr size_t GROUP_STORAGE_ENTRY_SZ =
77 sizeof(uint8_t) /* group_id */ + Uuid::kNumBytes128;
78
79 public:
DeviceGroupsImpl(DeviceGroupsCallbacks * callbacks)80 DeviceGroupsImpl(DeviceGroupsCallbacks* callbacks) {
81 AddCallbacks(callbacks);
82 btif_storage_load_bonded_groups();
83 }
84
GetGroupId(const RawAddress & addr,Uuid uuid) const85 int GetGroupId(const RawAddress& addr, Uuid uuid) const override {
86 for (const auto& [id, g] : groups_) {
87 if ((g.Contains(addr)) && (uuid == g.GetUuid())) return id;
88 }
89 return kGroupUnknown;
90 }
91
add_to_group(const RawAddress & addr,DeviceGroup * group)92 void add_to_group(const RawAddress& addr, DeviceGroup* group) {
93 group->Add(addr);
94
95 bool first_device_in_group = (group->Size() == 1);
96
97 for (auto c : callbacks_) {
98 if (first_device_in_group) {
99 c->OnGroupAdded(addr, group->GetUuid(), group->GetGroupId());
100 } else {
101 c->OnGroupMemberAdded(addr, group->GetGroupId());
102 }
103 }
104 }
105
AddDevice(const RawAddress & addr,Uuid uuid,int group_id)106 int AddDevice(const RawAddress& addr, Uuid uuid, int group_id) override {
107 DeviceGroup* group = nullptr;
108
109 if (group_id == kGroupUnknown) {
110 auto gid = GetGroupId(addr, uuid);
111 if (gid != kGroupUnknown) return gid;
112 group = create_group(uuid);
113 } else {
114 group = get_or_create_group_with_id(group_id, uuid);
115 if (!group) {
116 return kGroupUnknown;
117 }
118 }
119
120 log::assert_that(group, "assert failed: group");
121
122 if (group->Contains(addr)) {
123 log::error("device {} already in the group: {}", addr, group_id);
124 return group->GetGroupId();
125 }
126
127 add_to_group(addr, group);
128
129 btif_storage_add_groups(addr);
130 return group->GetGroupId();
131 }
132
RemoveDevice(const RawAddress & addr,int group_id)133 void RemoveDevice(const RawAddress& addr, int group_id) override {
134 int num_of_groups_dev_belongs = 0;
135
136 /* Remove from all the groups. Usually happens on unbond */
137 for (auto it = groups_.begin(); it != groups_.end();) {
138 auto& [id, g] = *it;
139 if (!g.Contains(addr)) {
140 ++it;
141 continue;
142 }
143
144 num_of_groups_dev_belongs++;
145
146 if ((group_id != bluetooth::groups::kGroupUnknown) && (group_id != id)) {
147 ++it;
148 continue;
149 }
150
151 num_of_groups_dev_belongs--;
152
153 g.Remove(addr);
154 for (auto c : callbacks_) {
155 c->OnGroupMemberRemoved(addr, id);
156 }
157
158 if (g.Size() == 0) {
159 for (auto c : callbacks_) {
160 c->OnGroupRemoved(g.GetUuid(), g.GetGroupId());
161 }
162 it = groups_.erase(it);
163 } else {
164 ++it;
165 }
166 }
167
168 btif_storage_remove_groups(addr);
169 if (num_of_groups_dev_belongs > 0) {
170 btif_storage_add_groups(addr);
171 }
172 }
173
SerializeGroups(const RawAddress & addr,std::vector<uint8_t> & out) const174 bool SerializeGroups(const RawAddress& addr,
175 std::vector<uint8_t>& out) const {
176 auto num_groups = std::count_if(
177 groups_.begin(), groups_.end(), [&addr](auto& id_group_pair) {
178 return id_group_pair.second.Contains(addr);
179 });
180 if ((num_groups == 0) || (num_groups > std::numeric_limits<uint8_t>::max()))
181 return false;
182
183 out.resize(GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ));
184 auto* ptr = out.data();
185
186 /* header */
187 UINT8_TO_STREAM(ptr, GROUP_STORAGE_CURRENT_LAYOUT_MAGIC);
188 UINT8_TO_STREAM(ptr, num_groups);
189
190 /* group entries */
191 for (const auto& [id, g] : groups_) {
192 if (g.Contains(addr)) {
193 UINT8_TO_STREAM(ptr, id);
194
195 Uuid::UUID128Bit uuid128 = g.GetUuid().To128BitLE();
196 memcpy(ptr, uuid128.data(), Uuid::kNumBytes128);
197 ptr += Uuid::kNumBytes128;
198 }
199 }
200
201 return true;
202 }
203
DeserializeGroups(const RawAddress & addr,const std::vector<uint8_t> & in)204 void DeserializeGroups(const RawAddress& addr,
205 const std::vector<uint8_t>& in) {
206 if (in.size() < GROUP_STORAGE_HEADER_SZ + GROUP_STORAGE_ENTRY_SZ) return;
207
208 auto* ptr = in.data();
209
210 uint8_t magic;
211 STREAM_TO_UINT8(magic, ptr);
212
213 if (magic == GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) {
214 uint8_t num_groups;
215 STREAM_TO_UINT8(num_groups, ptr);
216
217 if (in.size() <
218 GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ)) {
219 log::error("Invalid persistent storage data");
220 return;
221 }
222
223 /* group entries */
224 while (num_groups--) {
225 uint8_t id;
226 STREAM_TO_UINT8(id, ptr);
227
228 Uuid::UUID128Bit uuid128;
229 STREAM_TO_ARRAY(uuid128.data(), ptr, (int)Uuid::kNumBytes128);
230
231 auto* group =
232 get_or_create_group_with_id(id, Uuid::From128BitLE(uuid128));
233 if (group) add_to_group(addr, group);
234
235 for (auto c : callbacks_) {
236 c->OnGroupAddFromStorage(addr, Uuid::From128BitLE(uuid128), id);
237 }
238 }
239 }
240 }
241
AddCallbacks(DeviceGroupsCallbacks * callbacks)242 void AddCallbacks(DeviceGroupsCallbacks* callbacks) {
243 callbacks_.push_back(std::move(callbacks));
244
245 /* Notify new user about known groups */
246 for (const auto& [id, g] : groups_) {
247 auto group_uuid = g.GetUuid();
248 auto group_id = g.GetGroupId();
249 g.ForEachDevice([&](auto& dev) {
250 callbacks->OnGroupAdded(dev, group_uuid, group_id);
251 });
252 }
253 }
254
Clear(DeviceGroupsCallbacks * callbacks)255 bool Clear(DeviceGroupsCallbacks* callbacks) {
256 auto it = find_if(callbacks_.begin(), callbacks_.end(),
257 [callbacks](auto c) { return c == callbacks; });
258
259 if (it != callbacks_.end()) callbacks_.erase(it);
260
261 if (callbacks_.size() != 0) {
262 return false;
263 }
264 /* When all clients were unregistered */
265 groups_.clear();
266 return true;
267 }
268
Dump(int fd)269 void Dump(int fd) {
270 std::stringstream stream;
271
272 stream << " Num. registered clients: " << callbacks_.size() << std::endl;
273 stream << " Groups:\n";
274 for (const auto& kv_pair : groups_) {
275 stream << kv_pair.second << std::endl;
276 }
277
278 dprintf(fd, "%s", stream.str().c_str());
279 }
280
281 private:
find_device_group(int group_id)282 DeviceGroup* find_device_group(int group_id) {
283 return groups_.count(group_id) ? &groups_.at(group_id) : nullptr;
284 }
285
get_or_create_group_with_id(int group_id,Uuid uuid)286 DeviceGroup* get_or_create_group_with_id(int group_id, Uuid uuid) {
287 auto group = find_device_group(group_id);
288 if (group) {
289 if (group->GetUuid() != uuid) {
290 log::error(
291 "group {} exists but for different uuid: {}, user request uuid: {}",
292 group_id, group->GetUuid(), uuid);
293 return nullptr;
294 }
295
296 log::info("group already exists: {}", group_id);
297 return group;
298 }
299
300 DeviceGroup new_group(group_id, uuid);
301 groups_.insert({group_id, std::move(new_group)});
302
303 return &groups_.at(group_id);
304 }
305
create_group(Uuid & uuid)306 DeviceGroup* create_group(Uuid& uuid) {
307 /* Generate new group id and return empty group */
308 /* Find first free id */
309
310 int group_id = -1;
311 for (int i = 1; i < kMaxGroupId; i++) {
312 if (groups_.count(i) == 0) {
313 group_id = i;
314 break;
315 }
316 }
317
318 if (group_id < 0) {
319 log::error("too many groups");
320 return nullptr;
321 }
322
323 DeviceGroup group(group_id, uuid);
324 groups_.insert({group_id, std::move(group)});
325
326 return &groups_.at(group_id);
327 }
328
329 std::map<int, DeviceGroup> groups_;
330 std::list<DeviceGroupsCallbacks*> callbacks_;
331 };
332
Initialize(DeviceGroupsCallbacks * callbacks)333 void DeviceGroups::Initialize(DeviceGroupsCallbacks* callbacks) {
334 std::scoped_lock<std::mutex> lock(instance_mutex);
335 if (instance == nullptr) {
336 instance = new DeviceGroupsImpl(callbacks);
337 return;
338 }
339
340 instance->AddCallbacks(callbacks);
341 }
342
AddFromStorage(const RawAddress & addr,const std::vector<uint8_t> & in)343 void DeviceGroups::AddFromStorage(const RawAddress& addr,
344 const std::vector<uint8_t>& in) {
345 if (!instance) {
346 log::error("Not initialized yet");
347 return;
348 }
349
350 instance->DeserializeGroups(addr, in);
351 }
352
GetForStorage(const RawAddress & addr,std::vector<uint8_t> & out)353 bool DeviceGroups::GetForStorage(const RawAddress& addr,
354 std::vector<uint8_t>& out) {
355 if (!instance) {
356 log::error("Not initialized yet");
357 return false;
358 }
359
360 return instance->SerializeGroups(addr, out);
361 }
362
CleanUp(DeviceGroupsCallbacks * callbacks)363 void DeviceGroups::CleanUp(DeviceGroupsCallbacks* callbacks) {
364 std::scoped_lock<std::mutex> lock(instance_mutex);
365 if (!instance) return;
366
367 if (instance->Clear(callbacks)) {
368 delete (instance);
369 instance = nullptr;
370 }
371 }
372
operator <<(std::ostream & out,bluetooth::groups::DeviceGroup const & group)373 std::ostream& operator<<(std::ostream& out,
374 bluetooth::groups::DeviceGroup const& group) {
375 out << " == Group id: " << group.group_id_ << " == \n"
376 << " Uuid: " << group.group_uuid_ << std::endl;
377 out << " Devices:\n";
378 for (auto const& addr : group.devices_) {
379 out << " " << ADDRESS_TO_LOGGABLE_STR(addr) << std::endl;
380 }
381 return out;
382 }
383
DebugDump(int fd)384 void DeviceGroups::DebugDump(int fd) {
385 std::scoped_lock<std::mutex> lock(instance_mutex);
386 dprintf(fd, "Device Groups Manager:\n");
387 if (instance)
388 instance->Dump(fd);
389 else
390 dprintf(fd, " Not initialized \n");
391 }
392
Get()393 DeviceGroups* DeviceGroups::Get() { return instance; }
394
395 } // namespace groups
396 } // namespace bluetooth
397