1 /*
2  * Copyright 2021 HIMSA II K/S - www.himsa.com.
3  * Represented by EHIMA - www.ehima.com
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 #include <bluetooth/log.h>
19 
20 #include <algorithm>
21 #include <limits>
22 #include <map>
23 #include <mutex>
24 #include <unordered_set>
25 
26 #include "bta_groups.h"
27 #include "btif/include/btif_profile_storage.h"
28 #include "os/logging/log_adapter.h"
29 #include "stack/include/bt_types.h"
30 #include "types/bluetooth/uuid.h"
31 #include "types/raw_address.h"
32 
33 using bluetooth::Uuid;
34 
35 namespace bluetooth {
36 namespace groups {
37 
class DeviceGroupsImpl;

/* Process-wide manager. Created by DeviceGroups::Initialize() and reset to
 * nullptr by DeviceGroups::CleanUp() once the last client unregisters. */
DeviceGroupsImpl* instance = nullptr;

/* Taken by Initialize(), CleanUp() and DebugDump() while they touch
 * `instance`. */
std::mutex instance_mutex;

/* Exclusive upper bound for automatically allocated group ids. */
constexpr int kMaxGroupId{0xEF};
42 
43 class DeviceGroup {
44  public:
DeviceGroup(int group_id,Uuid uuid)45   DeviceGroup(int group_id, Uuid uuid)
46       : group_id_(group_id), group_uuid_(uuid) {}
Add(const RawAddress & addr)47   void Add(const RawAddress& addr) { devices_.insert(addr); }
Remove(const RawAddress & addr)48   void Remove(const RawAddress& addr) { devices_.erase(addr); }
Contains(const RawAddress & addr) const49   bool Contains(const RawAddress& addr) const {
50     return (devices_.count(addr) != 0);
51   }
52 
ForEachDevice(std::function<void (const RawAddress &)> cb) const53   void ForEachDevice(std::function<void(const RawAddress&)> cb) const {
54     for (auto const& addr : devices_) {
55       cb(addr);
56     }
57   }
58 
Size(void) const59   int Size(void) const { return devices_.size(); }
GetGroupId(void) const60   int GetGroupId(void) const { return group_id_; }
GetUuid(void) const61   const Uuid& GetUuid(void) const { return group_uuid_; }
62 
63  private:
64   friend std::ostream& operator<<(std::ostream& out,
65                                   const bluetooth::groups::DeviceGroup& value);
66   int group_id_;
67   Uuid group_uuid_;
68   std::unordered_set<RawAddress> devices_;
69 };
70 
71 class DeviceGroupsImpl : public DeviceGroups {
72   static constexpr uint8_t GROUP_STORAGE_CURRENT_LAYOUT_MAGIC = 0x10;
73   static constexpr size_t GROUP_STORAGE_HEADER_SZ =
74       sizeof(GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) +
75       sizeof(uint8_t); /* num_of_groups */
76   static constexpr size_t GROUP_STORAGE_ENTRY_SZ =
77       sizeof(uint8_t) /* group_id */ + Uuid::kNumBytes128;
78 
79  public:
DeviceGroupsImpl(DeviceGroupsCallbacks * callbacks)80   DeviceGroupsImpl(DeviceGroupsCallbacks* callbacks) {
81     AddCallbacks(callbacks);
82     btif_storage_load_bonded_groups();
83   }
84 
GetGroupId(const RawAddress & addr,Uuid uuid) const85   int GetGroupId(const RawAddress& addr, Uuid uuid) const override {
86     for (const auto& [id, g] : groups_) {
87       if ((g.Contains(addr)) && (uuid == g.GetUuid())) return id;
88     }
89     return kGroupUnknown;
90   }
91 
add_to_group(const RawAddress & addr,DeviceGroup * group)92   void add_to_group(const RawAddress& addr, DeviceGroup* group) {
93     group->Add(addr);
94 
95     bool first_device_in_group = (group->Size() == 1);
96 
97     for (auto c : callbacks_) {
98       if (first_device_in_group) {
99         c->OnGroupAdded(addr, group->GetUuid(), group->GetGroupId());
100       } else {
101         c->OnGroupMemberAdded(addr, group->GetGroupId());
102       }
103     }
104   }
105 
AddDevice(const RawAddress & addr,Uuid uuid,int group_id)106   int AddDevice(const RawAddress& addr, Uuid uuid, int group_id) override {
107     DeviceGroup* group = nullptr;
108 
109     if (group_id == kGroupUnknown) {
110       auto gid = GetGroupId(addr, uuid);
111       if (gid != kGroupUnknown) return gid;
112       group = create_group(uuid);
113     } else {
114       group = get_or_create_group_with_id(group_id, uuid);
115       if (!group) {
116         return kGroupUnknown;
117       }
118     }
119 
120     log::assert_that(group, "assert failed: group");
121 
122     if (group->Contains(addr)) {
123       log::error("device {} already in the group: {}", addr, group_id);
124       return group->GetGroupId();
125     }
126 
127     add_to_group(addr, group);
128 
129     btif_storage_add_groups(addr);
130     return group->GetGroupId();
131   }
132 
RemoveDevice(const RawAddress & addr,int group_id)133   void RemoveDevice(const RawAddress& addr, int group_id) override {
134     int num_of_groups_dev_belongs = 0;
135 
136     /* Remove from all the groups. Usually happens on unbond */
137     for (auto it = groups_.begin(); it != groups_.end();) {
138       auto& [id, g] = *it;
139       if (!g.Contains(addr)) {
140         ++it;
141         continue;
142       }
143 
144       num_of_groups_dev_belongs++;
145 
146       if ((group_id != bluetooth::groups::kGroupUnknown) && (group_id != id)) {
147         ++it;
148         continue;
149       }
150 
151       num_of_groups_dev_belongs--;
152 
153       g.Remove(addr);
154       for (auto c : callbacks_) {
155         c->OnGroupMemberRemoved(addr, id);
156       }
157 
158       if (g.Size() == 0) {
159         for (auto c : callbacks_) {
160           c->OnGroupRemoved(g.GetUuid(), g.GetGroupId());
161         }
162         it = groups_.erase(it);
163       } else {
164         ++it;
165       }
166     }
167 
168     btif_storage_remove_groups(addr);
169     if (num_of_groups_dev_belongs > 0) {
170       btif_storage_add_groups(addr);
171     }
172   }
173 
SerializeGroups(const RawAddress & addr,std::vector<uint8_t> & out) const174   bool SerializeGroups(const RawAddress& addr,
175                        std::vector<uint8_t>& out) const {
176     auto num_groups = std::count_if(
177         groups_.begin(), groups_.end(), [&addr](auto& id_group_pair) {
178           return id_group_pair.second.Contains(addr);
179         });
180     if ((num_groups == 0) || (num_groups > std::numeric_limits<uint8_t>::max()))
181       return false;
182 
183     out.resize(GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ));
184     auto* ptr = out.data();
185 
186     /* header */
187     UINT8_TO_STREAM(ptr, GROUP_STORAGE_CURRENT_LAYOUT_MAGIC);
188     UINT8_TO_STREAM(ptr, num_groups);
189 
190     /* group entries */
191     for (const auto& [id, g] : groups_) {
192       if (g.Contains(addr)) {
193         UINT8_TO_STREAM(ptr, id);
194 
195         Uuid::UUID128Bit uuid128 = g.GetUuid().To128BitLE();
196         memcpy(ptr, uuid128.data(), Uuid::kNumBytes128);
197         ptr += Uuid::kNumBytes128;
198       }
199     }
200 
201     return true;
202   }
203 
DeserializeGroups(const RawAddress & addr,const std::vector<uint8_t> & in)204   void DeserializeGroups(const RawAddress& addr,
205                          const std::vector<uint8_t>& in) {
206     if (in.size() < GROUP_STORAGE_HEADER_SZ + GROUP_STORAGE_ENTRY_SZ) return;
207 
208     auto* ptr = in.data();
209 
210     uint8_t magic;
211     STREAM_TO_UINT8(magic, ptr);
212 
213     if (magic == GROUP_STORAGE_CURRENT_LAYOUT_MAGIC) {
214       uint8_t num_groups;
215       STREAM_TO_UINT8(num_groups, ptr);
216 
217       if (in.size() <
218           GROUP_STORAGE_HEADER_SZ + (num_groups * GROUP_STORAGE_ENTRY_SZ)) {
219         log::error("Invalid persistent storage data");
220         return;
221       }
222 
223       /* group entries */
224       while (num_groups--) {
225         uint8_t id;
226         STREAM_TO_UINT8(id, ptr);
227 
228         Uuid::UUID128Bit uuid128;
229         STREAM_TO_ARRAY(uuid128.data(), ptr, (int)Uuid::kNumBytes128);
230 
231         auto* group =
232             get_or_create_group_with_id(id, Uuid::From128BitLE(uuid128));
233         if (group) add_to_group(addr, group);
234 
235         for (auto c : callbacks_) {
236           c->OnGroupAddFromStorage(addr, Uuid::From128BitLE(uuid128), id);
237         }
238       }
239     }
240   }
241 
AddCallbacks(DeviceGroupsCallbacks * callbacks)242   void AddCallbacks(DeviceGroupsCallbacks* callbacks) {
243     callbacks_.push_back(std::move(callbacks));
244 
245     /* Notify new user about known groups */
246     for (const auto& [id, g] : groups_) {
247       auto group_uuid = g.GetUuid();
248       auto group_id = g.GetGroupId();
249       g.ForEachDevice([&](auto& dev) {
250         callbacks->OnGroupAdded(dev, group_uuid, group_id);
251       });
252     }
253   }
254 
Clear(DeviceGroupsCallbacks * callbacks)255   bool Clear(DeviceGroupsCallbacks* callbacks) {
256     auto it = find_if(callbacks_.begin(), callbacks_.end(),
257                       [callbacks](auto c) { return c == callbacks; });
258 
259     if (it != callbacks_.end()) callbacks_.erase(it);
260 
261     if (callbacks_.size() != 0) {
262       return false;
263     }
264     /* When all clients were unregistered */
265     groups_.clear();
266     return true;
267   }
268 
Dump(int fd)269   void Dump(int fd) {
270     std::stringstream stream;
271 
272     stream << "  Num. registered clients: " << callbacks_.size() << std::endl;
273     stream << "  Groups:\n";
274     for (const auto& kv_pair : groups_) {
275       stream << kv_pair.second << std::endl;
276     }
277 
278     dprintf(fd, "%s", stream.str().c_str());
279   }
280 
281  private:
find_device_group(int group_id)282   DeviceGroup* find_device_group(int group_id) {
283     return groups_.count(group_id) ? &groups_.at(group_id) : nullptr;
284   }
285 
get_or_create_group_with_id(int group_id,Uuid uuid)286   DeviceGroup* get_or_create_group_with_id(int group_id, Uuid uuid) {
287     auto group = find_device_group(group_id);
288     if (group) {
289       if (group->GetUuid() != uuid) {
290         log::error(
291             "group {} exists but for different uuid: {}, user request uuid: {}",
292             group_id, group->GetUuid(), uuid);
293         return nullptr;
294       }
295 
296       log::info("group already exists: {}", group_id);
297       return group;
298     }
299 
300     DeviceGroup new_group(group_id, uuid);
301     groups_.insert({group_id, std::move(new_group)});
302 
303     return &groups_.at(group_id);
304   }
305 
create_group(Uuid & uuid)306   DeviceGroup* create_group(Uuid& uuid) {
307     /* Generate new group id and return empty group */
308     /* Find first free id */
309 
310     int group_id = -1;
311     for (int i = 1; i < kMaxGroupId; i++) {
312       if (groups_.count(i) == 0) {
313         group_id = i;
314         break;
315       }
316     }
317 
318     if (group_id < 0) {
319       log::error("too many groups");
320       return nullptr;
321     }
322 
323     DeviceGroup group(group_id, uuid);
324     groups_.insert({group_id, std::move(group)});
325 
326     return &groups_.at(group_id);
327   }
328 
329   std::map<int, DeviceGroup> groups_;
330   std::list<DeviceGroupsCallbacks*> callbacks_;
331 };
332 
Initialize(DeviceGroupsCallbacks * callbacks)333 void DeviceGroups::Initialize(DeviceGroupsCallbacks* callbacks) {
334   std::scoped_lock<std::mutex> lock(instance_mutex);
335   if (instance == nullptr) {
336     instance = new DeviceGroupsImpl(callbacks);
337     return;
338   }
339 
340   instance->AddCallbacks(callbacks);
341 }
342 
AddFromStorage(const RawAddress & addr,const std::vector<uint8_t> & in)343 void DeviceGroups::AddFromStorage(const RawAddress& addr,
344                                   const std::vector<uint8_t>& in) {
345   if (!instance) {
346     log::error("Not initialized yet");
347     return;
348   }
349 
350   instance->DeserializeGroups(addr, in);
351 }
352 
GetForStorage(const RawAddress & addr,std::vector<uint8_t> & out)353 bool DeviceGroups::GetForStorage(const RawAddress& addr,
354                                  std::vector<uint8_t>& out) {
355   if (!instance) {
356     log::error("Not initialized yet");
357     return false;
358   }
359 
360   return instance->SerializeGroups(addr, out);
361 }
362 
CleanUp(DeviceGroupsCallbacks * callbacks)363 void DeviceGroups::CleanUp(DeviceGroupsCallbacks* callbacks) {
364   std::scoped_lock<std::mutex> lock(instance_mutex);
365   if (!instance) return;
366 
367   if (instance->Clear(callbacks)) {
368     delete (instance);
369     instance = nullptr;
370   }
371 }
372 
operator <<(std::ostream & out,bluetooth::groups::DeviceGroup const & group)373 std::ostream& operator<<(std::ostream& out,
374                          bluetooth::groups::DeviceGroup const& group) {
375   out << "    == Group id: " << group.group_id_ << " == \n"
376       << "      Uuid: " << group.group_uuid_ << std::endl;
377   out << "      Devices:\n";
378   for (auto const& addr : group.devices_) {
379     out << "        " << ADDRESS_TO_LOGGABLE_STR(addr) << std::endl;
380   }
381   return out;
382 }
383 
DebugDump(int fd)384 void DeviceGroups::DebugDump(int fd) {
385   std::scoped_lock<std::mutex> lock(instance_mutex);
386   dprintf(fd, "Device Groups Manager:\n");
387   if (instance)
388     instance->Dump(fd);
389   else
390     dprintf(fd, "  Not initialized \n");
391 }
392 
Get()393 DeviceGroups* DeviceGroups::Get() { return instance; }
394 
395 }  // namespace groups
396 }  // namespace bluetooth
397