1 /* 2 * Copyright (C) 2018 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include <stdint.h> 18 19 #include <deque> 20 #include <memory> 21 #include <mutex> 22 #include <string> 23 #include <thread> 24 25 #include <android-base/logging.h> 26 #include <android-base/stringprintf.h> 27 #include <android-base/thread_annotations.h> 28 29 #include "adb_unique_fd.h" 30 #include "adb_utils.h" 31 #include "sysdeps.h" 32 #include "transport.h" 33 #include "types.h" 34 35 static void CreateWakeFds(unique_fd* read, unique_fd* write) { 36 // TODO: eventfd on linux? 37 int wake_fds[2]; 38 int rc = adb_socketpair(wake_fds); 39 set_file_block_mode(wake_fds[0], false); 40 set_file_block_mode(wake_fds[1], false); 41 CHECK_EQ(0, rc); 42 *read = unique_fd(wake_fds[0]); 43 *write = unique_fd(wake_fds[1]); 44 } 45 46 struct NonblockingFdConnection : public Connection { 47 NonblockingFdConnection(unique_fd fd) : started_(false), fd_(std::move(fd)) { 48 set_file_block_mode(fd_.get(), false); 49 CreateWakeFds(&wake_fd_read_, &wake_fd_write_); 50 } 51 52 void SetRunning(bool value) { 53 std::lock_guard<std::mutex> lock(run_mutex_); 54 running_ = value; 55 } 56 57 bool IsRunning() { 58 std::lock_guard<std::mutex> lock(run_mutex_); 59 return running_; 60 } 61 62 void Run(std::string* error) { 63 SetRunning(true); 64 while (IsRunning()) { 65 adb_pollfd pfds[2] = { 66 {.fd = fd_.get(), .events = POLLIN}, 67 {.fd = wake_fd_read_.get(), .events = POLLIN}, 68 }; 69 70 { 71 std::lock_guard<std::mutex> lock(this->write_mutex_); 72 if (!writable_) { 73 pfds[0].events |= POLLOUT; 74 } 75 } 76 77 int rc = adb_poll(pfds, 2, -1); 78 if (rc == -1) { 79 *error = android::base::StringPrintf("poll failed: %s", strerror(errno)); 80 return; 81 } else if (rc == 0) { 82 LOG(FATAL) << "poll timed out with an infinite timeout?"; 83 } 84 85 if (pfds[0].revents) { 86 if ((pfds[0].revents & POLLOUT)) { 87 std::lock_guard<std::mutex> lock(this->write_mutex_); 88 if (DispatchWrites() == WriteResult::Error) { 89 *error = "write failed"; 90 return; 91 } 92 } 93 94 if (pfds[0].revents & POLLIN) { 95 // TODO: Should we be getting blocks from a free list? 96 auto block = IOVector::block_type(MAX_PAYLOAD); 97 rc = adb_read(fd_.get(), &block[0], block.size()); 98 if (rc == -1) { 99 *error = std::string("read failed: ") + strerror(errno); 100 return; 101 } else if (rc == 0) { 102 *error = "read failed: EOF"; 103 return; 104 } 105 block.resize(rc); 106 read_buffer_.append(std::move(block)); 107 108 if (!read_header_ && read_buffer_.size() >= sizeof(amessage)) { 109 auto header_buf = read_buffer_.take_front(sizeof(amessage)).coalesce(); 110 CHECK_EQ(sizeof(amessage), header_buf.size()); 111 read_header_ = std::make_unique<amessage>(); 112 memcpy(read_header_.get(), header_buf.data(), sizeof(amessage)); 113 } 114 115 if (read_header_ && read_buffer_.size() >= read_header_->data_length) { 116 auto data_chain = read_buffer_.take_front(read_header_->data_length); 117 118 // TODO: Make apacket carry around a IOVector instead of coalescing. 119 auto payload = std::move(data_chain).coalesce(); 120 auto packet = std::make_unique<apacket>(); 121 packet->msg = *read_header_; 122 packet->payload = std::move(payload); 123 read_header_ = nullptr; 124 read_callback_(this, std::move(packet)); 125 } 126 } 127 } 128 129 if (pfds[1].revents) { 130 uint64_t buf; 131 rc = adb_read(wake_fd_read_.get(), &buf, sizeof(buf)); 132 CHECK_EQ(static_cast<int>(sizeof(buf)), rc); 133 134 // We were woken up either to add POLLOUT to our events, or to exit. 135 // Do nothing. 136 } 137 } 138 } 139 140 void Start() override final { 141 if (started_.exchange(true)) { 142 LOG(FATAL) << "Connection started multiple times?"; 143 } 144 145 thread_ = std::thread([this]() { 146 std::string error = "connection closed"; 147 Run(&error); 148 this->error_callback_(this, error); 149 }); 150 } 151 152 void Stop() override final { 153 SetRunning(false); 154 WakeThread(); 155 thread_.join(); 156 } 157 158 bool DoTlsHandshake(RSA* key, std::string* auth_key) override final { 159 LOG(FATAL) << "Not supported yet"; 160 return false; 161 } 162 163 void WakeThread() { 164 uint64_t buf = 0; 165 if (TEMP_FAILURE_RETRY(adb_write(wake_fd_write_.get(), &buf, sizeof(buf))) != sizeof(buf)) { 166 LOG(FATAL) << "failed to wake up thread"; 167 } 168 } 169 170 enum class WriteResult { 171 Error, 172 Completed, 173 TryAgain, 174 }; 175 176 WriteResult DispatchWrites() REQUIRES(write_mutex_) { 177 CHECK(!write_buffer_.empty()); 178 auto iovs = write_buffer_.iovecs(); 179 ssize_t rc = adb_writev(fd_.get(), iovs.data(), iovs.size()); 180 if (rc == -1) { 181 if (errno == EAGAIN || errno == EWOULDBLOCK) { 182 writable_ = false; 183 return WriteResult::TryAgain; 184 } 185 186 return WriteResult::Error; 187 } else if (rc == 0) { 188 errno = 0; 189 return WriteResult::Error; 190 } 191 192 write_buffer_.drop_front(rc); 193 writable_ = write_buffer_.empty(); 194 if (write_buffer_.empty()) { 195 return WriteResult::Completed; 196 } 197 198 // There's data left in the range, which means our write returned early. 199 return WriteResult::TryAgain; 200 } 201 202 bool Write(std::unique_ptr<apacket> packet) final { 203 std::lock_guard<std::mutex> lock(write_mutex_); 204 const char* header_begin = reinterpret_cast<const char*>(&packet->msg); 205 const char* header_end = header_begin + sizeof(packet->msg); 206 auto header_block = IOVector::block_type(header_begin, header_end); 207 write_buffer_.append(std::move(header_block)); 208 if (!packet->payload.empty()) { 209 write_buffer_.append(std::move(packet->payload)); 210 } 211 212 WriteResult result = DispatchWrites(); 213 if (result == WriteResult::TryAgain) { 214 WakeThread(); 215 } 216 return result != WriteResult::Error; 217 } 218 219 std::thread thread_; 220 221 std::atomic<bool> started_; 222 std::mutex run_mutex_; 223 bool running_ GUARDED_BY(run_mutex_); 224 225 std::unique_ptr<amessage> read_header_; 226 IOVector read_buffer_; 227 228 unique_fd fd_; 229 unique_fd wake_fd_read_; 230 unique_fd wake_fd_write_; 231 232 std::mutex write_mutex_; 233 bool writable_ GUARDED_BY(write_mutex_) = true; 234 IOVector write_buffer_ GUARDED_BY(write_mutex_); 235 236 IOVector incoming_queue_; 237 }; 238 239 std::unique_ptr<Connection> Connection::FromFd(unique_fd fd) { 240 return std::make_unique<NonblockingFdConnection>(std::move(fd)); 241 } 242