1 /*
2  * Copyright 2022 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "hci/acl_manager.h"
18 
19 #include <gmock/gmock.h>
20 #include <gtest/gtest.h>
21 
22 #include <chrono>
23 #include <deque>
24 #include <future>
25 #include <list>
26 #include <map>
27 
28 #include "common/bind.h"
29 #include "common/init_flags.h"
30 #include "hci/address.h"
31 #include "hci/address_with_type.h"
32 #include "hci/class_of_device.h"
33 #include "hci/controller.h"
34 #include "hci/hci_layer.h"
35 #include "hci/hci_layer_fake.h"
36 #include "os/thread.h"
37 #include "packet/raw_builder.h"
38 
39 using namespace std::chrono_literals;
40 
41 namespace bluetooth {
42 namespace hci {
43 namespace acl_manager {
44 namespace {
45 
46 using common::BidiQueue;
47 using common::BidiQueueEnd;
48 using packet::kLittleEndian;
49 using packet::PacketView;
50 using packet::RawBuilder;
51 
52 namespace {
53 constexpr char kLocalRandomAddressString[] = "D0:05:04:03:02:01";
54 constexpr char kRemotePublicDeviceStringA[] = "11:A2:A3:A4:A5:A6";
55 constexpr char kRemotePublicDeviceStringB[] = "11:B2:B3:B4:B5:B6";
56 constexpr uint16_t kHciHandleA = 123;
57 constexpr uint16_t kHciHandleB = 456;
58 
59 constexpr auto kMinimumRotationTime = std::chrono::milliseconds(7 * 60 * 1000);
60 constexpr auto kMaximumRotationTime = std::chrono::milliseconds(15 * 60 * 1000);
61 
62 const AddressWithType empty_address_with_type = hci::AddressWithType();
63 
64 struct {
65   Address address;
66   ClassOfDevice class_of_device;
67   const uint16_t handle;
68 } remote_device[2] = {
69     {.address = {}, .class_of_device = {}, .handle = kHciHandleA},
70     {.address = {}, .class_of_device = {}, .handle = kHciHandleB},
71 };
72 }  // namespace
73 
NextPayload(uint16_t handle)74 std::unique_ptr<BasePacketBuilder> NextPayload(uint16_t handle) {
75   static uint32_t packet_number = 1;
76   auto payload = std::make_unique<RawBuilder>();
77   payload->AddOctets2(6);  // L2CAP PDU size
78   payload->AddOctets2(2);  // L2CAP CID
79   payload->AddOctets2(handle);
80   payload->AddOctets4(packet_number++);
81   return std::move(payload);
82 }
83 
84 class TestController : public Controller {
85  public:
GetAclPacketLength() const86   uint16_t GetAclPacketLength() const override {
87     return acl_buffer_length_;
88   }
89 
GetNumAclPacketBuffers() const90   uint16_t GetNumAclPacketBuffers() const override {
91     return total_acl_buffers_;
92   }
93 
IsSupported(bluetooth::hci::OpCode) const94   bool IsSupported(bluetooth::hci::OpCode /* op_code */) const override {
95     return false;
96   }
97 
GetLeBufferSize() const98   LeBufferSize GetLeBufferSize() const override {
99     LeBufferSize le_buffer_size;
100     le_buffer_size.total_num_le_packets_ = 2;
101     le_buffer_size.le_data_packet_length_ = 32;
102     return le_buffer_size;
103   }
104 
105  protected:
Start()106   void Start() override {}
Stop()107   void Stop() override {}
ListDependencies(ModuleList *) const108   void ListDependencies(ModuleList* /* list */) const {}
109 
110  private:
111   uint16_t acl_buffer_length_ = 1024;
112   uint16_t total_acl_buffers_ = 2;
113   common::ContextualCallback<void(uint16_t /* handle */, uint16_t /* packets */)> acl_cb_;
114 };
115 
116 class MockConnectionCallback : public ConnectionCallbacks {
117  public:
OnConnectSuccess(std::unique_ptr<ClassicAclConnection> connection)118   void OnConnectSuccess(std::unique_ptr<ClassicAclConnection> connection) override {
119     // Convert to std::shared_ptr during push_back()
120     connections_.push_back(std::move(connection));
121     if (is_promise_set_) {
122       is_promise_set_ = false;
123       connection_promise_.set_value(connections_.back());
124     }
125   }
126   MOCK_METHOD(void, OnConnectRequest, (Address, ClassOfDevice), (override));
127   MOCK_METHOD(void, OnConnectFail, (Address, ErrorCode reason, bool locally_initiated), (override));
128 
NumberOfConnections() const129   size_t NumberOfConnections() const {
130     return connections_.size();
131   }
132 
133  private:
134   friend class AclManagerWithCallbacksTest;
135   friend class AclManagerNoCallbacksTest;
136 
137   std::deque<std::shared_ptr<ClassicAclConnection>> connections_;
138   std::promise<std::shared_ptr<ClassicAclConnection>> connection_promise_;
139   bool is_promise_set_{false};
140 };
141 
142 class MockLeConnectionCallbacks : public LeConnectionCallbacks {
143  public:
OnLeConnectSuccess(AddressWithType,std::unique_ptr<LeAclConnection> connection)144   void OnLeConnectSuccess(
145       AddressWithType /* address_with_type */,
146       std::unique_ptr<LeAclConnection> connection) override {
147     le_connections_.push_back(std::move(connection));
148     if (le_connection_promise_ != nullptr) {
149       std::promise<void>* prom = le_connection_promise_.release();
150       prom->set_value();
151       delete prom;
152     }
153   }
154   MOCK_METHOD(void, OnLeConnectFail, (AddressWithType, ErrorCode reason), (override));
155 
156   std::deque<std::shared_ptr<LeAclConnection>> le_connections_;
157   std::unique_ptr<std::promise<void>> le_connection_promise_;
158 };
159 
160 class AclManagerBaseTest : public ::testing::Test {
161  protected:
SetUp()162   void SetUp() override {
163     common::InitFlags::SetAllForTesting();
164     test_hci_layer_ = new HciLayerFake;  // Ownership is transferred to registry
165     test_controller_ = new TestController;
166     fake_registry_.InjectTestModule(&HciLayer::Factory, test_hci_layer_);
167     fake_registry_.InjectTestModule(&Controller::Factory, test_controller_);
168     client_handler_ = fake_registry_.GetTestModuleHandler(&HciLayer::Factory);
169     ASSERT_NE(client_handler_, nullptr);
170     fake_registry_.Start<AclManager>(&thread_);
171   }
172 
TearDown()173   void TearDown() override {
174     fake_registry_.SynchronizeModuleHandler(&AclManager::Factory, std::chrono::milliseconds(20));
175     fake_registry_.StopAll();
176   }
177 
sync_client_handler()178   void sync_client_handler() {
179     std::promise<void> promise;
180     auto future = promise.get_future();
181     client_handler_->Post(common::BindOnce(&std::promise<void>::set_value, common::Unretained(&promise)));
182     auto future_status = future.wait_for(std::chrono::seconds(1));
183     ASSERT_EQ(future_status, std::future_status::ready);
184   }
185 
186   HciLayerFake* test_hci_layer_ = nullptr;
187   TestController* test_controller_ = nullptr;
188 
189   TestModuleRegistry fake_registry_;
190   os::Thread& thread_ = fake_registry_.GetTestThread();
191   AclManager* acl_manager_ = nullptr;
192   os::Handler* client_handler_ = nullptr;
193 };
194 
195 class AclManagerNoCallbacksTest : public AclManagerBaseTest {
196  protected:
SetUp()197   void SetUp() override {
198     AclManagerBaseTest::SetUp();
199 
200     acl_manager_ = static_cast<AclManager*>(fake_registry_.GetModuleUnderTest(&AclManager::Factory));
201 
202     local_address_with_type_ = AddressWithType(
203         Address::FromString(kLocalRandomAddressString).value(), hci::AddressType::RANDOM_DEVICE_ADDRESS);
204 
205     acl_manager_->SetPrivacyPolicyForInitiatorAddress(
206         LeAddressManager::AddressPolicy::USE_STATIC_ADDRESS,
207         local_address_with_type_,
208         kMinimumRotationTime,
209         kMaximumRotationTime);
210 
211     auto command = test_hci_layer_->GetCommand();
212     ASSERT_TRUE(command.IsValid());
213     ASSERT_EQ(OpCode::LE_SET_RANDOM_ADDRESS, command.GetOpCode());
214   }
215 
TearDown()216   void TearDown() override {
217     AclManagerBaseTest::TearDown();
218   }
219 
220   AddressWithType local_address_with_type_;
221   const bool use_accept_list_ = true;  // gd currently only supports connect list
222 
SendAclData(uint16_t handle,AclConnection::QueueUpEnd * queue_end)223   void SendAclData(uint16_t handle, AclConnection::QueueUpEnd* queue_end) {
224     std::promise<void> promise;
225     auto future = promise.get_future();
226     queue_end->RegisterEnqueue(
227         client_handler_,
228         common::Bind(
229             [](decltype(queue_end) queue_end, uint16_t handle, std::promise<void> promise) {
230               queue_end->UnregisterEnqueue();
231               promise.set_value();
232               return NextPayload(handle);
233             },
234             queue_end,
235             handle,
236             common::Passed(std::move(promise))));
237     auto status = future.wait_for(2s);
238     ASSERT_EQ(status, std::future_status::ready);
239   }
240 };
241 
242 class AclManagerWithCallbacksTest : public AclManagerNoCallbacksTest {
243  protected:
SetUp()244   void SetUp() override {
245     AclManagerNoCallbacksTest::SetUp();
246     acl_manager_->RegisterCallbacks(&mock_connection_callbacks_, client_handler_);
247     acl_manager_->RegisterLeCallbacks(&mock_le_connection_callbacks_, client_handler_);
248   }
249 
TearDown()250   void TearDown() override {
251     fake_registry_.SynchronizeModuleHandler(&HciLayer::Factory, std::chrono::milliseconds(20));
252     fake_registry_.SynchronizeModuleHandler(&AclManager::Factory, std::chrono::milliseconds(20));
253     fake_registry_.SynchronizeModuleHandler(&HciLayer::Factory, std::chrono::milliseconds(20));
254     {
255       std::promise<void> promise;
256       auto future = promise.get_future();
257       acl_manager_->UnregisterLeCallbacks(&mock_le_connection_callbacks_, std::move(promise));
258       future.wait_for(2s);
259     }
260     {
261       std::promise<void> promise;
262       auto future = promise.get_future();
263       acl_manager_->UnregisterCallbacks(&mock_connection_callbacks_, std::move(promise));
264       future.wait_for(2s);
265     }
266 
267     mock_connection_callbacks_.connections_.clear();
268     mock_le_connection_callbacks_.le_connections_.clear();
269 
270     AclManagerNoCallbacksTest::TearDown();
271   }
272 
GetConnectionFuture()273   std::future<std::shared_ptr<ClassicAclConnection>> GetConnectionFuture() {
274     // Run on main thread
275     mock_connection_callbacks_.connection_promise_ = std::promise<std::shared_ptr<ClassicAclConnection>>();
276     mock_connection_callbacks_.is_promise_set_ = true;
277     return mock_connection_callbacks_.connection_promise_.get_future();
278   }
279 
GetLeConnectionFuture()280   std::future<void> GetLeConnectionFuture() {
281     mock_le_connection_callbacks_.le_connection_promise_ = std::make_unique<std::promise<void>>();
282     return mock_le_connection_callbacks_.le_connection_promise_->get_future();
283   }
284 
GetLastConnection()285   std::shared_ptr<ClassicAclConnection> GetLastConnection() {
286     return mock_connection_callbacks_.connections_.back();
287   }
288 
NumberOfConnections()289   size_t NumberOfConnections() {
290     return mock_connection_callbacks_.connections_.size();
291   }
292 
GetLastLeConnection()293   std::shared_ptr<LeAclConnection> GetLastLeConnection() {
294     return mock_le_connection_callbacks_.le_connections_.back();
295   }
296 
NumberOfLeConnections()297   size_t NumberOfLeConnections() {
298     return mock_le_connection_callbacks_.le_connections_.size();
299   }
300 
301   MockConnectionCallback mock_connection_callbacks_;
302   MockLeConnectionCallbacks mock_le_connection_callbacks_;
303 };
304 
305 class AclManagerWithConnectionTest : public AclManagerWithCallbacksTest {
306  protected:
SetUp()307   void SetUp() override {
308     AclManagerWithCallbacksTest::SetUp();
309 
310     handle_ = 0x123;
311     Address::FromString("A1:A2:A3:A4:A5:A6", remote);
312 
313     acl_manager_->CreateConnection(remote);
314 
315     // Wait for the connection request
316     auto last_command = test_hci_layer_->GetCommand(OpCode::CREATE_CONNECTION);
317 
318     EXPECT_CALL(mock_connection_management_callbacks_, OnRoleChange(hci::ErrorCode::SUCCESS, Role::CENTRAL));
319 
320     auto first_connection = GetConnectionFuture();
321     test_hci_layer_->IncomingEvent(ConnectionCompleteBuilder::Create(
322         ErrorCode::SUCCESS, handle_, remote, LinkType::ACL, Enable::DISABLED));
323 
324     auto first_connection_status = first_connection.wait_for(2s);
325     ASSERT_EQ(first_connection_status, std::future_status::ready);
326 
327     connection_ = GetLastConnection();
328     connection_->RegisterCallbacks(&mock_connection_management_callbacks_, client_handler_);
329   }
330 
TearDown()331   void TearDown() override {
332     fake_registry_.SynchronizeModuleHandler(&HciLayer::Factory, std::chrono::milliseconds(20));
333     fake_registry_.SynchronizeModuleHandler(&AclManager::Factory, std::chrono::milliseconds(20));
334     fake_registry_.StopAll();
335   }
336 
337   uint16_t handle_;
338   Address remote;
339   std::shared_ptr<ClassicAclConnection> connection_;
340 
341   class MockConnectionManagementCallbacks : public ConnectionManagementCallbacks {
342    public:
343     MOCK_METHOD1(OnConnectionPacketTypeChanged, void(uint16_t packet_type));
344     MOCK_METHOD1(OnAuthenticationComplete, void(hci::ErrorCode hci_status));
345     MOCK_METHOD1(OnEncryptionChange, void(EncryptionEnabled enabled));
346     MOCK_METHOD0(OnChangeConnectionLinkKeyComplete, void());
347     MOCK_METHOD1(OnReadClockOffsetComplete, void(uint16_t clock_offse));
348     MOCK_METHOD3(OnModeChange, void(ErrorCode status, Mode current_mode, uint16_t interval));
349     MOCK_METHOD5(
350         OnSniffSubrating,
351         void(
352             ErrorCode status,
353             uint16_t maximum_transmit_latency,
354             uint16_t maximum_receive_latency,
355             uint16_t minimum_remote_timeout,
356             uint16_t minimum_local_timeout));
357     MOCK_METHOD5(
358         OnQosSetupComplete,
359         void(
360             ServiceType service_type,
361             uint32_t token_rate,
362             uint32_t peak_bandwidth,
363             uint32_t latency,
364             uint32_t delay_variation));
365     MOCK_METHOD6(
366         OnFlowSpecificationComplete,
367         void(
368             FlowDirection flow_direction,
369             ServiceType service_type,
370             uint32_t token_rate,
371             uint32_t token_bucket_size,
372             uint32_t peak_bandwidth,
373             uint32_t access_latency));
374     MOCK_METHOD0(OnFlushOccurred, void());
375     MOCK_METHOD1(OnRoleDiscoveryComplete, void(Role current_role));
376     MOCK_METHOD1(OnReadLinkPolicySettingsComplete, void(uint16_t link_policy_settings));
377     MOCK_METHOD1(OnReadAutomaticFlushTimeoutComplete, void(uint16_t flush_timeout));
378     MOCK_METHOD1(OnReadTransmitPowerLevelComplete, void(uint8_t transmit_power_level));
379     MOCK_METHOD1(OnReadLinkSupervisionTimeoutComplete, void(uint16_t link_supervision_timeout));
380     MOCK_METHOD1(OnReadFailedContactCounterComplete, void(uint16_t failed_contact_counter));
381     MOCK_METHOD1(OnReadLinkQualityComplete, void(uint8_t link_quality));
382     MOCK_METHOD2(OnReadAfhChannelMapComplete, void(AfhMode afh_mode, std::array<uint8_t, 10> afh_channel_map));
383     MOCK_METHOD1(OnReadRssiComplete, void(uint8_t rssi));
384     MOCK_METHOD2(OnReadClockComplete, void(uint32_t clock, uint16_t accuracy));
385     MOCK_METHOD1(OnCentralLinkKeyComplete, void(KeyFlag flag));
386     MOCK_METHOD2(OnRoleChange, void(ErrorCode hci_status, Role new_role));
387     MOCK_METHOD1(OnDisconnection, void(ErrorCode reason));
388     MOCK_METHOD4(
389         OnReadRemoteVersionInformationComplete,
390         void(hci::ErrorCode hci_status, uint8_t lmp_version, uint16_t manufacturer_name, uint16_t sub_version));
391     MOCK_METHOD1(OnReadRemoteSupportedFeaturesComplete, void(uint64_t features));
392     MOCK_METHOD3(
393         OnReadRemoteExtendedFeaturesComplete, void(uint8_t page_number, uint8_t max_page_number, uint64_t features));
394   } mock_connection_management_callbacks_;
395 };
396 
TEST_F(AclManagerWithCallbacksTest,startup_teardown)397 TEST_F(AclManagerWithCallbacksTest, startup_teardown) {}
398 
399 class AclManagerWithLeConnectionTest : public AclManagerWithCallbacksTest {
400  protected:
SetUp()401   void SetUp() override {
402     AclManagerWithCallbacksTest::SetUp();
403 
404     Address remote_public_address = Address::FromString(kRemotePublicDeviceStringA).value();
405     remote_with_type_ = AddressWithType(remote_public_address, AddressType::PUBLIC_DEVICE_ADDRESS);
406     acl_manager_->CreateLeConnection(remote_with_type_, true);
407     test_hci_layer_->GetCommand(OpCode::LE_ADD_DEVICE_TO_FILTER_ACCEPT_LIST);
408     test_hci_layer_->IncomingEvent(
409         LeAddDeviceToFilterAcceptListCompleteBuilder::Create(0x01, ErrorCode::SUCCESS));
410     auto packet = test_hci_layer_->GetCommand(OpCode::LE_CREATE_CONNECTION);
411     auto le_connection_management_command_view =
412         LeConnectionManagementCommandView::Create(AclCommandView::Create(packet));
413     auto command_view = LeCreateConnectionView::Create(le_connection_management_command_view);
414     ASSERT_TRUE(command_view.IsValid());
415     if (use_accept_list_) {
416       ASSERT_EQ(command_view.GetPeerAddress(), empty_address_with_type.GetAddress());
417       ASSERT_EQ(command_view.GetPeerAddressType(), empty_address_with_type.GetAddressType());
418     } else {
419       ASSERT_EQ(command_view.GetPeerAddress(), remote_public_address);
420       ASSERT_EQ(command_view.GetPeerAddressType(), AddressType::PUBLIC_DEVICE_ADDRESS);
421     }
422 
423     test_hci_layer_->IncomingEvent(
424         LeCreateConnectionStatusBuilder::Create(ErrorCode::SUCCESS, 0x01));
425 
426     auto first_connection = GetLeConnectionFuture();
427 
428     test_hci_layer_->IncomingLeMetaEvent(LeConnectionCompleteBuilder::Create(
429         ErrorCode::SUCCESS,
430         handle_,
431         Role::PERIPHERAL,
432         AddressType::PUBLIC_DEVICE_ADDRESS,
433         remote_public_address,
434         0x0100,
435         0x0010,
436         0x0C80,
437         ClockAccuracy::PPM_30));
438 
439     test_hci_layer_->GetCommand(OpCode::LE_REMOVE_DEVICE_FROM_FILTER_ACCEPT_LIST);
440     test_hci_layer_->IncomingEvent(
441         LeRemoveDeviceFromFilterAcceptListCompleteBuilder::Create(0x01, ErrorCode::SUCCESS));
442 
443     auto first_connection_status = first_connection.wait_for(2s);
444     ASSERT_EQ(first_connection_status, std::future_status::ready);
445 
446     connection_ = GetLastLeConnection();
447   }
448 
TearDown()449   void TearDown() override {
450     fake_registry_.SynchronizeModuleHandler(&HciLayer::Factory, std::chrono::milliseconds(20));
451     fake_registry_.SynchronizeModuleHandler(&AclManager::Factory, std::chrono::milliseconds(20));
452     fake_registry_.StopAll();
453   }
454 
sync_client_handler()455   void sync_client_handler() {
456     std::promise<void> promise;
457     auto future = promise.get_future();
458     client_handler_->Post(common::BindOnce(&std::promise<void>::set_value, common::Unretained(&promise)));
459     auto future_status = future.wait_for(std::chrono::seconds(1));
460     ASSERT_EQ(future_status, std::future_status::ready);
461   }
462 
463   uint16_t handle_ = 0x123;
464   std::shared_ptr<LeAclConnection> connection_;
465   AddressWithType remote_with_type_;
466 
467   class MockLeConnectionManagementCallbacks : public LeConnectionManagementCallbacks {
468    public:
469     MOCK_METHOD1(OnDisconnection, void(ErrorCode reason));
470     MOCK_METHOD4(
471         OnConnectionUpdate,
472         void(
473             hci::ErrorCode hci_status,
474             uint16_t connection_interval,
475             uint16_t connection_latency,
476             uint16_t supervision_timeout));
477     MOCK_METHOD4(OnDataLengthChange, void(uint16_t tx_octets, uint16_t tx_time, uint16_t rx_octets, uint16_t rx_time));
478     MOCK_METHOD4(
479         OnReadRemoteVersionInformationComplete,
480         void(hci::ErrorCode hci_status, uint8_t version, uint16_t manufacturer_name, uint16_t sub_version));
481     MOCK_METHOD2(OnLeReadRemoteFeaturesComplete, void(hci::ErrorCode hci_status, uint64_t features));
482     MOCK_METHOD3(OnPhyUpdate, void(hci::ErrorCode hci_status, uint8_t tx_phy, uint8_t rx_phy));
483     MOCK_METHOD5(
484         OnLeSubrateChange,
485         void(
486             hci::ErrorCode hci_status,
487             uint16_t subrate_factor,
488             uint16_t peripheral_latency,
489             uint16_t continuation_number,
490             uint16_t supervision_timeout));
491   } mock_le_connection_management_callbacks_;
492 };
493 
494 class AclManagerWithResolvableAddressTest : public AclManagerWithCallbacksTest {
495  protected:
SetUp()496   void SetUp() override {
497     test_hci_layer_ = new HciLayerFake;  // Ownership is transferred to registry
498     test_controller_ = new TestController;
499     fake_registry_.InjectTestModule(&HciLayer::Factory, test_hci_layer_);
500     fake_registry_.InjectTestModule(&Controller::Factory, test_controller_);
501     client_handler_ = fake_registry_.GetTestModuleHandler(&HciLayer::Factory);
502     ASSERT_NE(client_handler_, nullptr);
503     fake_registry_.Start<AclManager>(&thread_);
504     acl_manager_ = static_cast<AclManager*>(fake_registry_.GetModuleUnderTest(&AclManager::Factory));
505     hci::Address address;
506     Address::FromString("D0:05:04:03:02:01", address);
507     hci::AddressWithType address_with_type(address, hci::AddressType::RANDOM_DEVICE_ADDRESS);
508     acl_manager_->RegisterCallbacks(&mock_connection_callbacks_, client_handler_);
509     acl_manager_->RegisterLeCallbacks(&mock_le_connection_callbacks_, client_handler_);
510     auto minimum_rotation_time = std::chrono::milliseconds(7 * 60 * 1000);
511     auto maximum_rotation_time = std::chrono::milliseconds(15 * 60 * 1000);
512     acl_manager_->SetPrivacyPolicyForInitiatorAddress(
513         LeAddressManager::AddressPolicy::USE_RESOLVABLE_ADDRESS,
514         address_with_type,
515         minimum_rotation_time,
516         maximum_rotation_time);
517 
518     test_hci_layer_->GetCommand(OpCode::LE_SET_RANDOM_ADDRESS);
519     test_hci_layer_->IncomingEvent(
520         LeSetRandomAddressCompleteBuilder::Create(0x01, ErrorCode::SUCCESS));
521   }
522 };
523 
TEST_F(AclManagerNoCallbacksTest,unregister_classic_before_connection_request)524 TEST_F(AclManagerNoCallbacksTest, unregister_classic_before_connection_request) {
525   ClassOfDevice class_of_device;
526 
527   MockConnectionCallback mock_connection_callbacks_;
528 
529   acl_manager_->RegisterCallbacks(&mock_connection_callbacks_, client_handler_);
530 
531   // Unregister callbacks before receiving connection request
532   auto promise = std::promise<void>();
533   auto future = promise.get_future();
534   acl_manager_->UnregisterCallbacks(&mock_connection_callbacks_, std::move(promise));
535   future.get();
536 
537   // Inject peer sending connection request
538   test_hci_layer_->IncomingEvent(ConnectionRequestBuilder::Create(
539       local_address_with_type_.GetAddress(), class_of_device, ConnectionRequestLinkType::ACL));
540   sync_client_handler();
541 
542   // There should be no connections
543   ASSERT_EQ(0UL, mock_connection_callbacks_.NumberOfConnections());
544 
545   auto command = test_hci_layer_->GetCommand(OpCode::REJECT_CONNECTION_REQUEST);
546 }
547 
TEST_F(AclManagerWithCallbacksTest,two_remote_connection_requests_ABAB)548 TEST_F(AclManagerWithCallbacksTest, two_remote_connection_requests_ABAB) {
549   Address::FromString(kRemotePublicDeviceStringA, remote_device[0].address);
550   Address::FromString(kRemotePublicDeviceStringB, remote_device[1].address);
551 
552   {
553     // Device A sends connection request
554     test_hci_layer_->IncomingEvent(ConnectionRequestBuilder::Create(
555         remote_device[0].address,
556         remote_device[0].class_of_device,
557         ConnectionRequestLinkType::ACL));
558     sync_client_handler();
559     // Verify we accept this connection
560     auto command = test_hci_layer_->GetCommand(OpCode::ACCEPT_CONNECTION_REQUEST);
561   }
562 
563   {
564     // Device B sends connection request
565     test_hci_layer_->IncomingEvent(ConnectionRequestBuilder::Create(
566         remote_device[1].address,
567         remote_device[1].class_of_device,
568         ConnectionRequestLinkType::ACL));
569     sync_client_handler();
570     // Verify we accept this connection
571     auto command = test_hci_layer_->GetCommand(OpCode::ACCEPT_CONNECTION_REQUEST);
572   }
573 
574   ASSERT_EQ(0UL, NumberOfConnections());
575 
576   {
577     // Device A completes first connection
578     auto future = GetConnectionFuture();
579     test_hci_layer_->IncomingEvent(ConnectionCompleteBuilder::Create(
580         ErrorCode::SUCCESS,
581         remote_device[0].handle,
582         remote_device[0].address,
583         LinkType::ACL,
584         Enable::DISABLED));
585     ASSERT_EQ(std::future_status::ready, future.wait_for(2s)) << "Timeout waiting for first connection complete";
586     ASSERT_EQ(1UL, NumberOfConnections());
587     auto connection = future.get();
588     ASSERT_EQ(connection->GetAddress(), remote_device[0].address) << "First connection remote address mismatch";
589   }
590 
591   {
592     // Device B completes second connection
593     auto future = GetConnectionFuture();
594     test_hci_layer_->IncomingEvent(ConnectionCompleteBuilder::Create(
595         ErrorCode::SUCCESS,
596         remote_device[1].handle,
597         remote_device[1].address,
598         LinkType::ACL,
599         Enable::DISABLED));
600     ASSERT_EQ(std::future_status::ready, future.wait_for(2s)) << "Timeout waiting for second connection complete";
601     ASSERT_EQ(2UL, NumberOfConnections());
602     auto connection = future.get();
603     ASSERT_EQ(connection->GetAddress(), remote_device[1].address) << "Second connection remote address mismatch";
604   }
605 }
606 
TEST_F(AclManagerWithCallbacksTest,two_remote_connection_requests_ABBA)607 TEST_F(AclManagerWithCallbacksTest, two_remote_connection_requests_ABBA) {
608   Address::FromString(kRemotePublicDeviceStringA, remote_device[0].address);
609   Address::FromString(kRemotePublicDeviceStringB, remote_device[1].address);
610 
611   {
612     // Device A sends connection request
613     test_hci_layer_->IncomingEvent(ConnectionRequestBuilder::Create(
614         remote_device[0].address,
615         remote_device[0].class_of_device,
616         ConnectionRequestLinkType::ACL));
617     sync_client_handler();
618     // Verify we accept this connection
619     auto command = test_hci_layer_->GetCommand(OpCode::ACCEPT_CONNECTION_REQUEST);
620   }
621 
622   {
623     // Device B sends connection request
624     test_hci_layer_->IncomingEvent(ConnectionRequestBuilder::Create(
625         remote_device[1].address,
626         remote_device[1].class_of_device,
627         ConnectionRequestLinkType::ACL));
628     sync_client_handler();
629     // Verify we accept this connection
630     auto command = test_hci_layer_->GetCommand(OpCode::ACCEPT_CONNECTION_REQUEST);
631   }
632 
633   ASSERT_EQ(0UL, NumberOfConnections());
634 
635   {
636     // Device B completes first connection
637     auto future = GetConnectionFuture();
638     test_hci_layer_->IncomingEvent(ConnectionCompleteBuilder::Create(
639         ErrorCode::SUCCESS,
640         remote_device[1].handle,
641         remote_device[1].address,
642         LinkType::ACL,
643         Enable::DISABLED));
644     ASSERT_EQ(std::future_status::ready, future.wait_for(2s)) << "Timeout waiting for first connection complete";
645     ASSERT_EQ(1UL, NumberOfConnections());
646     auto connection = future.get();
647     ASSERT_EQ(connection->GetAddress(), remote_device[1].address) << "First connection remote address mismatch";
648   }
649 
650   {
651     // Device A completes second connection
652     auto future = GetConnectionFuture();
653     test_hci_layer_->IncomingEvent(ConnectionCompleteBuilder::Create(
654         ErrorCode::SUCCESS,
655         remote_device[0].handle,
656         remote_device[0].address,
657         LinkType::ACL,
658         Enable::DISABLED));
659     ASSERT_EQ(std::future_status::ready, future.wait_for(2s)) << "Timeout waiting for second connection complete";
660     ASSERT_EQ(2UL, NumberOfConnections());
661     auto connection = future.get();
662     ASSERT_EQ(connection->GetAddress(), remote_device[0].address) << "Second connection remote address mismatch";
663   }
664 }
665 
TEST_F(AclManagerWithCallbacksTest,test_disconnection_after_request)666 TEST_F(AclManagerWithCallbacksTest, test_disconnection_after_request) {
667   Address remote = *Address::FromString("12:34:56:78:9a:bc");
668   EXPECT_CALL(mock_connection_callbacks_, OnConnectRequest).Times(1);
669   test_hci_layer_->IncomingEvent(ConnectionRequestBuilder::Create(
670       remote, ClassOfDevice({1, 2, 3}), ConnectionRequestLinkType::ACL));
671   test_hci_layer_->IncomingEvent(ConnectionCompleteBuilder::Create(
672       ErrorCode::REMOTE_USER_TERMINATED_CONNECTION, 0, remote, LinkType::ACL, Enable::DISABLED));
673 }
674 
TEST_F(AclManagerWithCallbacksTest,test_disconnection_after_request_sync)675 TEST_F(AclManagerWithCallbacksTest, test_disconnection_after_request_sync) {
676   std::promise<void> request_promise;
677   auto request_future = request_promise.get_future();
678 
679   Address remote = *Address::FromString("12:34:56:78:9a:bc");
680   EXPECT_CALL(mock_connection_callbacks_, OnConnectRequest).WillOnce([&request_promise]() {
681     request_promise.set_value();
682   });
683   test_hci_layer_->IncomingEvent(ConnectionRequestBuilder::Create(
684       remote, ClassOfDevice({1, 2, 3}), ConnectionRequestLinkType::ACL));
685   ASSERT_EQ(std::future_status::ready, request_future.wait_for(std::chrono::seconds(1)));
686   test_hci_layer_->IncomingEvent(ConnectionCompleteBuilder::Create(
687       ErrorCode::REMOTE_USER_TERMINATED_CONNECTION, 0, remote, LinkType::ACL, Enable::DISABLED));
688 }
689 
690 }  // namespace
691 }  // namespace acl_manager
692 }  // namespace hci
693 }  // namespace bluetooth
694