1 /*
2  * Copyright 2021 HIMSA II K/S - www.himsa.com.
3  * Represented by EHIMA - www.ehima.com
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 #pragma once
19 
20 #include <base/strings/string_number_conversions.h>
21 #include <bluetooth/log.h>
22 
23 #include <algorithm>
24 #include <map>
25 #include <vector>
26 
27 #include "bta_csis_api.h"
28 #include "bta_gatt_api.h"
29 #include "bta_groups.h"
30 #include "btif/include/btif_storage.h"
31 #include "common/init_flags.h"
32 #include "common/strings.h"
33 #include "crypto_toolbox/crypto_toolbox.h"
34 #include "gap_api.h"
35 
36 // Uncomment to debug SIRK calculations
37 // #define CSIS_DEBUG
38 
39 namespace bluetooth {
40 namespace csis {
41 
42 using bluetooth::csis::CsisLockCb;
43 
44 // CSIP additions
45 /* Generic UUID is used when CSIS is not included in any context */
46 static const bluetooth::Uuid kCsisServiceUuid =
47     bluetooth::Uuid::From16Bit(0x1846);
48 static const bluetooth::Uuid kCsisSirkUuid = bluetooth::Uuid::From16Bit(0x2B84);
49 static const bluetooth::Uuid kCsisSizeUuid = bluetooth::Uuid::From16Bit(0x2B85);
50 static const bluetooth::Uuid kCsisLockUuid = bluetooth::Uuid::From16Bit(0x2B86);
51 static const bluetooth::Uuid kCsisRankUuid = bluetooth::Uuid::From16Bit(0x2B87);
52 
/* CSIS error code values (0x80-0x85). */
static constexpr uint8_t kCsisErrorCodeLockDenied{0x80};
static constexpr uint8_t kCsisErrorCodeReleaseNotAllowed{0x81};
static constexpr uint8_t kCsisErrorCodeInvalidValue{0x82};
static constexpr uint8_t kCsisErrorCodeLockAccessSirkRejected{0x83};
static constexpr uint8_t kCsisErrorCodeLockOobSirkOnly{0x84};
static constexpr uint8_t kCsisErrorCodeLockAlreadyGranted{0x85};

/* SIRK characteristic: type value meaning "encrypted", and the characteristic
 * length (presumably 1 type octet + 16 SIRK octets — confirm against spec). */
static constexpr uint8_t kCsisSirkTypeEncrypted{0x00};
static constexpr uint8_t kCsisSirkCharLen{17};
62 
/*
 * Pair of GATT attribute handles: a characteristic value handle and the
 * handle of its CCC descriptor. A default-constructed pair holds 0 in both
 * fields instead of indeterminate values.
 */
struct hdl_pair {
  hdl_pair() = default;
  hdl_pair(uint16_t val_hdl, uint16_t ccc_hdl)
      : val_hdl(val_hdl), ccc_hdl(ccc_hdl) {}

  // Zero-initialized so that reading a default-constructed pair is defined.
  uint16_t val_hdl = 0;
  uint16_t ccc_hdl = 0;
};
71 
/* CSIS Types */
/* Default scan duration (presumably seconds, per the "S" suffix). */
static constexpr uint8_t kDefaultScanDurationS{5};
/* Initial desired set size used by the CsisGroup constructor. */
static constexpr uint8_t kDefaultCsisSetSize{1};
/* Rank held by a CsisInstance until SetRank() is called. */
static constexpr uint8_t kUnknownRank{0xff};
76 
77 /* Enums */
/* Lock state of a CSIS instance or of a whole group. UNSET is the initial
 * value assigned by the constructors in this file. */
enum class CsisLockState : uint8_t {
  CSIS_STATE_UNSET = 0x00,
  CSIS_STATE_UNLOCKED = 0x01,
  CSIS_STATE_LOCKED = 0x02,
};

/* Progress of set-member discovery for a group; IDLE initially. */
enum class CsisDiscoveryState : uint8_t {
  CSIS_DISCOVERY_IDLE = 0x00,
  CSIS_DISCOVERY_ONGOING = 0x01,
  CSIS_DISCOVERY_COMPLETED = 0x02,
};
89 
90 class GattServiceDevice {
91  public:
92   RawAddress addr;
93   /*
94    * We are making active attempt to connect to this device, 'direct connect'.
95    */
96   bool connecting_actively = false;
97 
98   uint16_t conn_id = GATT_INVALID_CONN_ID;
99   uint16_t service_handle = GAP_INVALID_HANDLE;
100   bool is_gatt_service_valid = false;
101 
GattServiceDevice(const RawAddress & addr,bool first_connection)102   GattServiceDevice(const RawAddress& addr, bool first_connection)
103       : addr(addr) {}
104 
GattServiceDevice()105   GattServiceDevice() : GattServiceDevice(RawAddress::kEmpty, false) {}
106 
IsConnected()107   bool IsConnected() const { return (conn_id != GATT_INVALID_CONN_ID); }
108 
109   class MatchAddress {
110    private:
111     RawAddress addr;
112 
113    public:
MatchAddress(const RawAddress & addr)114     MatchAddress(const RawAddress& addr) : addr(addr) {}
operator()115     bool operator()(const std::shared_ptr<GattServiceDevice>& other) const {
116       return (addr == other->addr);
117     }
118   };
119 
120   class MatchConnId {
121    private:
122     uint16_t conn_id;
123 
124    public:
MatchConnId(uint16_t conn_id)125     MatchConnId(uint16_t conn_id) : conn_id(conn_id) {}
operator()126     bool operator()(const std::shared_ptr<GattServiceDevice>& other) const {
127       return (conn_id == other->conn_id);
128     }
129   };
130 };
131 
132 /*
133  * CSIS instance represents single CSIS service on the remote device
134  * along with the handle in database and specific data to control CSIS like:
135  * rank, lock state.
136  *
137  * It also inclues UUID of the primary service which includes that CSIS
138  * instance. If this is 0x0000 it means CSIS is per device and not for specific
139  * service.
140  */
141 class CsisInstance {
142  public:
143   bluetooth::Uuid coordinated_service = bluetooth::groups::kGenericContextUuid;
144 
145   struct SvcData {
146     uint16_t start_handle;
147     uint16_t end_handle;
148     struct hdl_pair sirk_handle;
149     struct hdl_pair lock_handle;
150     uint16_t rank_handle;
151     struct hdl_pair size_handle;
152   } svc_data = {
153       GAP_INVALID_HANDLE,
154       GAP_INVALID_HANDLE,
155       {GAP_INVALID_HANDLE, GAP_INVALID_HANDLE},
156       {GAP_INVALID_HANDLE, GAP_INVALID_HANDLE},
157       GAP_INVALID_HANDLE,
158       {GAP_INVALID_HANDLE, GAP_INVALID_HANDLE},
159   };
160 
CsisInstance(uint16_t start_handle,uint16_t end_handle,const bluetooth::Uuid & uuid)161   CsisInstance(uint16_t start_handle, uint16_t end_handle,
162                const bluetooth::Uuid& uuid)
163       : coordinated_service(uuid),
164         group_id_(bluetooth::groups::kGroupUnknown),
165         rank_(kUnknownRank),
166         lock_state_(CsisLockState::CSIS_STATE_UNSET) {
167     svc_data.start_handle = start_handle;
168     svc_data.end_handle = end_handle;
169   }
170 
SetLockState(CsisLockState state)171   void SetLockState(CsisLockState state) {
172     log::debug("current lock state: {}, new lock state: {}",
173                static_cast<int>(lock_state_), static_cast<int>(state));
174     lock_state_ = state;
175   }
GetLockState(void)176   CsisLockState GetLockState(void) const { return lock_state_; }
GetRank(void)177   uint8_t GetRank(void) const { return rank_; }
SetRank(uint8_t rank)178   void SetRank(uint8_t rank) {
179     log::debug("current rank: {}, new rank: {}", static_cast<int>(rank_),
180                static_cast<int>(rank));
181     rank_ = rank;
182   }
183 
SetGroupId(int group_id)184   void SetGroupId(int group_id) {
185     log::info("set group id: {}, instance handle: 0x{:04x}", group_id,
186               svc_data.start_handle);
187     group_id_ = group_id;
188   }
189 
GetGroupId(void)190   int GetGroupId(void) const { return group_id_; }
191 
HasSameUuid(const CsisInstance & csis_instance)192   bool HasSameUuid(const CsisInstance& csis_instance) const {
193     return (csis_instance.coordinated_service == coordinated_service);
194   }
195 
GetUuid(void)196   const bluetooth::Uuid& GetUuid(void) const { return coordinated_service; }
IsForUuid(const bluetooth::Uuid & uuid)197   bool IsForUuid(const bluetooth::Uuid& uuid) const {
198     return coordinated_service == uuid;
199   }
200 
201  private:
202   int group_id_;
203   uint8_t rank_;
204   CsisLockState lock_state_;
205 };
206 
207 /*
208  * Csis Device represents remote device and its all CSIS instances.
209  * It can happen that device can have more than one CSIS service instance
210  * if those instances are included in other services. In this way, coordinated
211  * set is within the context of the primary service which includes the instance.
212  *
213  * CsisDevice contains vector of the instances.
214  */
215 class CsisDevice : public GattServiceDevice {
216  public:
217   using GattServiceDevice::GattServiceDevice;
218 
ClearSvcData()219   void ClearSvcData() {
220     GattServiceDevice::service_handle = GAP_INVALID_HANDLE;
221     GattServiceDevice::is_gatt_service_valid = false;
222 
223     csis_instances_.clear();
224   }
225 
FindValueHandleByCccHandle(uint16_t ccc_handle)226   uint16_t FindValueHandleByCccHandle(uint16_t ccc_handle) {
227     uint16_t val_handle = 0;
228     for (const auto& [_, inst] : csis_instances_) {
229       if (inst->svc_data.sirk_handle.ccc_hdl == ccc_handle) {
230         val_handle = inst->svc_data.sirk_handle.val_hdl;
231       } else if (inst->svc_data.lock_handle.ccc_hdl == ccc_handle) {
232         val_handle = inst->svc_data.lock_handle.val_hdl;
233       } else if (inst->svc_data.size_handle.ccc_hdl == ccc_handle) {
234         val_handle = inst->svc_data.size_handle.val_hdl;
235       }
236       if (val_handle) {
237         break;
238       }
239     }
240     return val_handle;
241   }
242 
GetCsisInstanceByOwningHandle(uint16_t handle)243   std::shared_ptr<CsisInstance> GetCsisInstanceByOwningHandle(uint16_t handle) {
244     uint16_t hdl = 0;
245     for (const auto& [h, inst] : csis_instances_) {
246       if (handle >= inst->svc_data.start_handle &&
247           handle <= inst->svc_data.end_handle) {
248         hdl = h;
249         log::verbose("found 0x{:04x}", hdl);
250         break;
251       }
252     }
253     return (hdl > 0) ? csis_instances_.at(hdl) : nullptr;
254   }
255 
GetCsisInstanceByGroupId(int group_id)256   std::shared_ptr<CsisInstance> GetCsisInstanceByGroupId(int group_id) {
257     uint16_t hdl = 0;
258     for (const auto& [handle, inst] : csis_instances_) {
259       if (inst->GetGroupId() == group_id) {
260         hdl = handle;
261         break;
262       }
263     }
264     return (hdl > 0) ? csis_instances_.at(hdl) : nullptr;
265   }
266 
SetCsisInstance(uint16_t handle,std::shared_ptr<CsisInstance> csis_instance)267   void SetCsisInstance(uint16_t handle,
268                        std::shared_ptr<CsisInstance> csis_instance) {
269     if (csis_instances_.count(handle)) {
270       log::debug("instance is already here: {}",
271                  csis_instance->GetUuid().ToString());
272       return;
273     }
274 
275     csis_instances_.insert({handle, csis_instance});
276     log::debug("instance added: 0x{:04x}, device {}", handle, addr);
277   }
278 
RemoveCsisInstance(int group_id)279   void RemoveCsisInstance(int group_id) {
280     for (auto it = csis_instances_.begin(); it != csis_instances_.end(); it++) {
281       if (it->second->GetGroupId() == group_id) {
282         csis_instances_.erase(it);
283         return;
284       }
285     }
286   }
287 
GetNumberOfCsisInstances(void)288   int GetNumberOfCsisInstances(void) { return csis_instances_.size(); }
289 
ForEachCsisInstance(std::function<void (const std::shared_ptr<CsisInstance> &)> cb)290   void ForEachCsisInstance(
291       std::function<void(const std::shared_ptr<CsisInstance>&)> cb) {
292     for (auto const& kv_pair : csis_instances_) {
293       cb(kv_pair.second);
294     }
295   }
296 
SetExpectedGroupIdMember(int group_id)297   void SetExpectedGroupIdMember(int group_id) {
298     log::info("Expected Group ID: {}, for member: {} is set", group_id, addr);
299     expected_group_id_member_ = group_id;
300   }
301 
SetPairingSirkReadFlag(bool flag)302   void SetPairingSirkReadFlag(bool flag) {
303     log::info("Pairing flag for Group ID: {}, member: {} is set to {}",
304               expected_group_id_member_, addr, flag);
305     pairing_sirk_read_flag_ = flag;
306   }
307 
GetExpectedGroupIdMember()308   inline int GetExpectedGroupIdMember() { return expected_group_id_member_; }
GetPairingSirkReadFlag()309   inline bool GetPairingSirkReadFlag() { return pairing_sirk_read_flag_; }
310 
311  private:
312   /* Instances per start handle  */
313   std::map<uint16_t, std::shared_ptr<CsisInstance>> csis_instances_;
314   int expected_group_id_member_ = bluetooth::groups::kGroupUnknown;
315   bool pairing_sirk_read_flag_ = false;
316 };
317 
318 /*
319  * CSIS group gathers devices which belongs to specific group.
320  * It also contains methond to decode encrypted SIRK and also to
321  * resolve PRSI in order to find out if device belongs to given group
322  */
323 class CsisGroup {
324  public:
CsisGroup(int group_id,const bluetooth::Uuid & uuid)325   CsisGroup(int group_id, const bluetooth::Uuid& uuid)
326       : group_id_(group_id),
327         size_(kDefaultCsisSetSize),
328         uuid_(uuid),
329         member_discovery_state_(CsisDiscoveryState::CSIS_DISCOVERY_IDLE),
330         lock_state_(CsisLockState::CSIS_STATE_UNSET),
331         target_lock_state_(CsisLockState::CSIS_STATE_UNSET),
332         lock_transition_cnt_(0) {
333     devices_.clear();
334     BTIF_STORAGE_FILL_PROPERTY(&model_name, BT_PROPERTY_REMOTE_MODEL_NUM,
335                                sizeof(model_name_val), &model_name_val);
336   }
337 
338   bt_property_t model_name;
339   bt_bdname_t model_name_val = {0};
340 
AddDevice(std::shared_ptr<CsisDevice> csis_device)341   void AddDevice(std::shared_ptr<CsisDevice> csis_device) {
342     auto it = find_if(devices_.begin(), devices_.end(),
343                       CsisDevice::MatchAddress(csis_device->addr));
344     if (it != devices_.end()) return;
345 
346     devices_.push_back(std::move(csis_device));
347   }
348 
RemoveDevice(const RawAddress & bd_addr)349   void RemoveDevice(const RawAddress& bd_addr) {
350     auto it = find_if(devices_.begin(), devices_.end(),
351                       CsisDevice::MatchAddress(bd_addr));
352     if (it != devices_.end()) devices_.erase(it);
353   }
354 
GetCurrentSize(void)355   int GetCurrentSize(void) const { return devices_.size(); }
GetUuid()356   bluetooth::Uuid GetUuid() const { return uuid_; }
SetUuid(const bluetooth::Uuid & uuid)357   void SetUuid(const bluetooth::Uuid& uuid) { uuid_ = uuid; }
GetGroupId(void)358   int GetGroupId(void) const { return group_id_; }
GetDesiredSize(void)359   int GetDesiredSize(void) const { return size_; }
SetDesiredSize(int size)360   void SetDesiredSize(int size) { size_ = size; }
IsGroupComplete(void)361   bool IsGroupComplete(void) const { return size_ == (int)devices_.size(); }
IsEmpty(void)362   bool IsEmpty(void) const { return devices_.empty(); }
363 
IsDeviceInTheGroup(std::shared_ptr<CsisDevice> & csis_device)364   bool IsDeviceInTheGroup(std::shared_ptr<CsisDevice>& csis_device) {
365     auto it = find_if(devices_.begin(), devices_.end(),
366                       CsisDevice::MatchAddress(csis_device->addr));
367     return (it != devices_.end());
368   }
IsRsiMatching(const RawAddress & rsi)369   bool IsRsiMatching(const RawAddress& rsi) const {
370     return is_rsi_match_sirk(rsi, GetSirk());
371   }
IsSirkBelongsToGroup(Octet16 sirk)372   bool IsSirkBelongsToGroup(Octet16 sirk) const {
373     return (sirk_available_ && sirk_ == sirk);
374   }
GetSirk(void)375   Octet16 GetSirk(void) const { return sirk_; }
SetSirk(Octet16 & sirk)376   void SetSirk(Octet16& sirk) {
377     if (sirk_available_) {
378       log::debug("Updating SIRK");
379     }
380     sirk_available_ = true;
381     sirk_ = sirk;
382   }
383 
GetNumOfConnectedDevices(void)384   int GetNumOfConnectedDevices(void) {
385     return std::count_if(devices_.begin(), devices_.end(),
386                          [](auto& d) { return d->IsConnected(); });
387   }
388 
GetDiscoveryState(void)389   CsisDiscoveryState GetDiscoveryState(void) const {
390     return member_discovery_state_;
391   }
SetDiscoveryState(CsisDiscoveryState state)392   void SetDiscoveryState(CsisDiscoveryState state) {
393     log::debug("current discovery state: {}, new discovery state: {}",
394                static_cast<int>(member_discovery_state_),
395                static_cast<int>(state));
396     member_discovery_state_ = state;
397   }
398 
SetCurrentLockState(CsisLockState state)399   void SetCurrentLockState(CsisLockState state) { lock_state_ = state; }
400 
401   void SetTargetLockState(CsisLockState state,
402                           CsisLockCb cb = base::DoNothing()) {
403     target_lock_state_ = state;
404     cb_ = std::move(cb);
405     switch (state) {
406       case CsisLockState::CSIS_STATE_LOCKED:
407         lock_transition_cnt_ = GetNumOfConnectedDevices();
408         break;
409       case CsisLockState::CSIS_STATE_UNLOCKED:
410       case CsisLockState::CSIS_STATE_UNSET:
411         lock_transition_cnt_ = 0;
412         break;
413     }
414   }
415 
GetLockCb(void)416   CsisLockCb GetLockCb(void) { return std::move(cb_); }
417 
GetCurrentLockState(void)418   CsisLockState GetCurrentLockState(void) const { return lock_state_; }
GetTargetLockState(void)419   CsisLockState GetTargetLockState(void) const { return target_lock_state_; }
420 
IsAvailableForCsisLockOperation(void)421   bool IsAvailableForCsisLockOperation(void) {
422     int id = group_id_;
423     int number_of_connected = 0;
424     auto iter = std::find_if(
425         devices_.begin(), devices_.end(), [id, &number_of_connected](auto& d) {
426           if (!d->IsConnected()) {
427             log::debug("Device {} is not connected in group {}", d->addr, id);
428             return false;
429           }
430           auto inst = d->GetCsisInstanceByGroupId(id);
431           if (!inst) {
432             log::debug("Instance not available for group {}", id);
433             return false;
434           }
435           number_of_connected++;
436           log::debug("Device {},  lock state: {}", d->addr,
437                      (int)inst->GetLockState());
438           return inst->GetLockState() == CsisLockState::CSIS_STATE_LOCKED;
439         });
440 
441     log::debug("Locked set: {}, number of connected {}", iter != devices_.end(),
442                number_of_connected);
443     /* If there is no locked device, we are good to go */
444     if (iter != devices_.end()) {
445       log::warn("Device {} is locked", (*iter)->addr);
446       return false;
447     }
448 
449     return (number_of_connected > 0);
450   }
451 
SortByCsisRank(void)452   void SortByCsisRank(void) {
453     int id = group_id_;
454     std::sort(devices_.begin(), devices_.end(), [id](auto& dev1, auto& dev2) {
455       auto inst1 = dev1->GetCsisInstanceByGroupId(id);
456       auto inst2 = dev2->GetCsisInstanceByGroupId(id);
457       if (!inst1 || !inst2) {
458         /* One of the device is not connected */
459         log::debug("Device  {} is not connected.",
460                    inst1 == nullptr ? ADDRESS_TO_LOGGABLE_CSTR(dev1->addr)
461                                     : ADDRESS_TO_LOGGABLE_CSTR(dev2->addr));
462         return dev1->IsConnected();
463       }
464       return (inst1->GetRank() < inst2->GetRank());
465     });
466   }
467 
GetFirstDevice(void)468   std::shared_ptr<CsisDevice> GetFirstDevice(void) {
469     return (devices_.front());
470   }
GetLastDevice(void)471   std::shared_ptr<CsisDevice> GetLastDevice(void) { return (devices_.back()); }
GetNextDevice(std::shared_ptr<CsisDevice> & device)472   std::shared_ptr<CsisDevice> GetNextDevice(
473       std::shared_ptr<CsisDevice>& device) {
474     auto iter = std::find_if(devices_.begin(), devices_.end(),
475                              CsisDevice::MatchAddress(device->addr));
476 
477     /* If reference device not found */
478     if (iter == devices_.end()) return nullptr;
479 
480     iter++;
481     /* If reference device is last in group */
482     if (iter == devices_.end()) return nullptr;
483 
484     return (*iter);
485   }
GetPrevDevice(std::shared_ptr<CsisDevice> & device)486   std::shared_ptr<CsisDevice> GetPrevDevice(
487       std::shared_ptr<CsisDevice>& device) {
488     auto iter = std::find_if(devices_.rbegin(), devices_.rend(),
489                              CsisDevice::MatchAddress(device->addr));
490 
491     /* If reference device not found */
492     if (iter == devices_.rend()) return nullptr;
493 
494     iter++;
495 
496     if (iter == devices_.rend()) return nullptr;
497     return (*iter);
498   }
499 
GetLockTransitionCnt(void)500   int GetLockTransitionCnt(void) const { return lock_transition_cnt_; }
UpdateLockTransitionCnt(int i)501   int UpdateLockTransitionCnt(int i) {
502     lock_transition_cnt_ += i;
503     return lock_transition_cnt_;
504   }
505 
506   /* Return true if given Autoset Private Address |srpa| matches Set Identity
507    * Resolving Key |sirk| */
is_rsi_match_sirk(const RawAddress & rsi,const Octet16 & sirk)508   static bool is_rsi_match_sirk(const RawAddress& rsi, const Octet16& sirk) {
509     /* use the 3 MSB of bd address as prand */
510     Octet16 rand{};
511     rand[0] = rsi.address[2];
512     rand[1] = rsi.address[1];
513     rand[2] = rsi.address[0];
514 #ifdef CSIS_DEBUG
515     log::info("Prand {}", base::HexEncode(rand.data(), 3));
516     log::info("SIRK {}", base::HexEncode(sirk.data(), 16));
517 #endif
518 
519     /* generate X = E irk(R0, R1, R2) and R is random address 3 LSO */
520     Octet16 x = crypto_toolbox::aes_128(sirk, rand);
521 
522 #ifdef CSIS_DEBUG
523     log::info("X {}", base::HexEncode(x.data(), 16));
524 #endif
525 
526     rand[0] = rsi.address[5];
527     rand[1] = rsi.address[4];
528     rand[2] = rsi.address[3];
529 
530 #ifdef CSIS_DEBUG
531     log::info("Hash {}", base::HexEncode(rand.data(), 3));
532 #endif
533 
534     if (memcmp(x.data(), &rand[0], 3) == 0) {
535       // match
536       return true;
537     }
538     // not a match
539     return false;
540   }
541 
542  private:
543   int group_id_;
544   Octet16 sirk_ = {0};
545   bool sirk_available_ = false;
546   int size_;
547   bluetooth::Uuid uuid_;
548 
549   std::vector<std::shared_ptr<CsisDevice>> devices_;
550   CsisDiscoveryState member_discovery_state_;
551 
552   CsisLockState lock_state_;
553   CsisLockState target_lock_state_;
554   int lock_transition_cnt_;
555 
556   CsisLockCb cb_;
557 };
558 
559 }  // namespace csis
560 }  // namespace bluetooth
561