1 /*
2  * Copyright 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include "test_channel_transport.h"

#include <errno.h>   // for errno, EBADF
#include <stddef.h>  // for size_t

#include <cstdint>      // for uint8_t
#include <cstring>      // for strerror
#include <type_traits>  // for remove_extent_t
#include <utility>      // for move

#include "log.h"
#include "net/async_data_channel.h"  // for AsyncDataChannel
28 
29 using std::vector;
30 
31 namespace rootcanal {
32 
SetUp(std::shared_ptr<AsyncDataChannelServer> server,ConnectCallback connection_callback)33 bool TestChannelTransport::SetUp(std::shared_ptr<AsyncDataChannelServer> server,
34                                  ConnectCallback connection_callback) {
35   socket_server_ = server;
36   socket_server_->SetOnConnectCallback(connection_callback);
37   socket_server_->StartListening();
38   return socket_server_ != nullptr;
39 }
40 
CleanUp()41 void TestChannelTransport::CleanUp() {
42   socket_server_->StopListening();
43   socket_server_->Close();
44 }
45 
OnCommandReady(AsyncDataChannel * socket,std::function<void (void)> unwatch)46 void TestChannelTransport::OnCommandReady(AsyncDataChannel* socket,
47                                           std::function<void(void)> unwatch) {
48   uint8_t command_name_size = 0;
49   ssize_t bytes_read = socket->Recv(&command_name_size, 1);
50   if (bytes_read != 1) {
51     INFO("Unexpected (command_name_size) bytes_read: {} != {}, {}", bytes_read,
52          1, strerror(errno));
53     socket->Close();
54   }
55   vector<uint8_t> command_name_raw;
56   command_name_raw.resize(command_name_size);
57   bytes_read = socket->Recv(command_name_raw.data(), command_name_size);
58   if (bytes_read != command_name_size) {
59     INFO("Unexpected (command_name) bytes_read: {} != {}, {}", bytes_read,
60          command_name_size, strerror(errno));
61   }
62   std::string command_name(command_name_raw.begin(), command_name_raw.end());
63 
64   if (command_name == "CLOSE_TEST_CHANNEL" || command_name.empty()) {
65     INFO("Test channel closed");
66     unwatch();
67     socket->Close();
68     return;
69   }
70 
71   uint8_t num_args = 0;
72   bytes_read = socket->Recv(&num_args, 1);
73   if (bytes_read != 1) {
74     INFO("Unexpected (num_args) bytes_read: {} != {}, {}", bytes_read, 1,
75          strerror(errno));
76   }
77   vector<std::string> args;
78   for (uint8_t i = 0; i < num_args; ++i) {
79     uint8_t arg_size = 0;
80     bytes_read = socket->Recv(&arg_size, 1);
81     if (bytes_read != 1) {
82       INFO("Unexpected (arg_size) bytes_read: {} != {}, {}", bytes_read, 1,
83            strerror(errno));
84     }
85     vector<uint8_t> arg;
86     arg.resize(arg_size);
87     bytes_read = socket->Recv(arg.data(), arg_size);
88     if (bytes_read != arg_size) {
89       INFO("Unexpected (arg) bytes_read: {} != {}, {}", bytes_read, arg_size,
90            strerror(errno));
91     }
92     args.push_back(std::string(arg.begin(), arg.end()));
93   }
94 
95   command_handler_(command_name, args);
96 }
97 
SendResponse(std::shared_ptr<AsyncDataChannel> socket,const std::string & response)98 void TestChannelTransport::SendResponse(
99     std::shared_ptr<AsyncDataChannel> socket, const std::string& response) {
100   size_t size = response.size();
101   // Cap to 64K
102   if (size > 0xffff) {
103     size = 0xffff;
104   }
105   uint8_t size_buf[4] = {static_cast<uint8_t>(size & 0xff),
106                          static_cast<uint8_t>((size >> 8) & 0xff),
107                          static_cast<uint8_t>((size >> 16) & 0xff),
108                          static_cast<uint8_t>((size >> 24) & 0xff)};
109   ssize_t written = socket->Send(size_buf, 4);
110   if (written == -1 && errno == EBADF) {
111     WARNING("Unable to send a response.  EBADF");
112     return;
113   }
114   ASSERT_LOG(written == 4, "What happened? written = %zd errno = %d", written,
115              errno);
116   written =
117       socket->Send(reinterpret_cast<const uint8_t*>(response.c_str()), size);
118   ASSERT_LOG(written == static_cast<int>(size),
119              "What happened? written = %zd errno = %d", written, errno);
120 }
121 
RegisterCommandHandler(const std::function<void (const std::string &,const std::vector<std::string> &)> & callback)122 void TestChannelTransport::RegisterCommandHandler(
123     const std::function<void(const std::string&,
124                              const std::vector<std::string>&)>& callback) {
125   command_handler_ = callback;
126 }
127 
128 }  // namespace rootcanal
129