1 /*
2  * Copyright (C) 2007 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #define TRACE_TAG TRANSPORT
18 
19 #include "sysdeps.h"
20 
21 #include "client/usb.h"
22 
23 #include <memory>
24 
25 #include "sysdeps.h"
26 #include "transport.h"
27 
28 #include <stdio.h>
29 #include <stdlib.h>
30 #include <string.h>
31 
32 #include "adb.h"
33 
34 #if ADB_HOST
35 
// CHECK_PACKET_OVERFLOW selects at compile time how the host-side UsbReadMessage() and
// UsbReadPayload() size their usb_read() calls: when 1, the read size is derived from
// usb_get_max_packet_size(); when 0, exactly the needed number of bytes is requested.
// It is 0 on Apple platforms and 1 everywhere else.
// NOTE(review): presumably the macOS USB backend does not need the workaround -- confirm.
#if defined(__APPLE__)
#define CHECK_PACKET_OVERFLOW 0
#else
#define CHECK_PACKET_OVERFLOW 1
#endif
41 
42 // Call usb_read using a buffer having a multiple of usb_get_max_packet_size() bytes
43 // to avoid overflow. See http://libusb.sourceforge.net/api-1.0/packetoverflow.html.
44 static int UsbReadMessage(usb_handle* h, amessage* msg) {
45     D("UsbReadMessage");
46 
47 #if CHECK_PACKET_OVERFLOW
48     size_t usb_packet_size = usb_get_max_packet_size(h);
49     CHECK_GE(usb_packet_size, sizeof(*msg));
50     CHECK_LT(usb_packet_size, 4096ULL);
51 
52     char buffer[4096];
53     int n = usb_read(h, buffer, usb_packet_size);
54     if (n != sizeof(*msg)) {
55         D("usb_read returned unexpected length %d (expected %zu)", n, sizeof(*msg));
56         return -1;
57     }
58     memcpy(msg, buffer, sizeof(*msg));
59     return n;
60 #else
61     return usb_read(h, msg, sizeof(*msg));
62 #endif
63 }
64 
65 // Call usb_read using a buffer having a multiple of usb_get_max_packet_size() bytes
66 // to avoid overflow. See http://libusb.sourceforge.net/api-1.0/packetoverflow.html.
67 static int UsbReadPayload(usb_handle* h, apacket* p) {
68     D("UsbReadPayload(%d)", p->msg.data_length);
69 
70     if (p->msg.data_length > MAX_PAYLOAD) {
71         return -1;
72     }
73 
74 #if CHECK_PACKET_OVERFLOW
75     size_t usb_packet_size = usb_get_max_packet_size(h);
76 
77     // Round the data length up to the nearest packet size boundary.
78     // The device won't send a zero packet for packet size aligned payloads,
79     // so don't read any more packets than needed.
80     size_t len = p->msg.data_length;
81     size_t rem_size = len % usb_packet_size;
82     if (rem_size) {
83         len += usb_packet_size - rem_size;
84     }
85 
86     p->payload.resize(len);
87     int rc = usb_read(h, &p->payload[0], p->payload.size());
88     if (rc != static_cast<int>(p->msg.data_length)) {
89         return -1;
90     }
91 
92     p->payload.resize(rc);
93     return rc;
94 #else
95     p->payload.resize(p->msg.data_length);
96     return usb_read(h, &p->payload[0], p->payload.size());
97 #endif
98 }
99 
100 static int remote_read(apacket* p, usb_handle* usb) {
101     int n = UsbReadMessage(usb, &p->msg);
102     if (n < 0) {
103         D("remote usb: read terminated (message)");
104         return -1;
105     }
106     if (static_cast<size_t>(n) != sizeof(p->msg)) {
107         D("remote usb: read received unexpected header length %d", n);
108         return -1;
109     }
110     if (p->msg.data_length) {
111         n = UsbReadPayload(usb, p);
112         if (n < 0) {
113             D("remote usb: terminated (data)");
114             return -1;
115         }
116         if (static_cast<uint32_t>(n) != p->msg.data_length) {
117             D("remote usb: read payload failed (need %u bytes, give %d bytes), skip it",
118               p->msg.data_length, n);
119             return -1;
120         }
121     }
122     return 0;
123 }
124 
125 #else
126 
127 // On Android devices, we rely on the kernel to provide buffered read.
128 // So we can recover automatically from EOVERFLOW.
129 static int remote_read(apacket* p, usb_handle* usb) {
130     if (usb_read(usb, &p->msg, sizeof(amessage)) != sizeof(amessage)) {
131         PLOG(ERROR) << "remote usb: read terminated (message)";
132         return -1;
133     }
134 
135     if (p->msg.data_length) {
136         if (p->msg.data_length > MAX_PAYLOAD) {
137             PLOG(ERROR) << "remote usb: read overflow (data length = " << p->msg.data_length << ")";
138             return -1;
139         }
140 
141         p->payload.resize(p->msg.data_length);
142         if (usb_read(usb, &p->payload[0], p->payload.size()) !=
143             static_cast<int>(p->payload.size())) {
144             PLOG(ERROR) << "remote usb: terminated (data)";
145             return -1;
146         }
147     }
148 
149     return 0;
150 }
151 #endif
152 
153 UsbConnection::~UsbConnection() {
154     usb_close(handle_);
155 }
156 
157 bool UsbConnection::Read(apacket* packet) {
158     int rc = remote_read(packet, handle_);
159     return rc == 0;
160 }
161 
162 bool UsbConnection::Write(apacket* packet) {
163     int size = packet->msg.data_length;
164 
165     if (usb_write(handle_, &packet->msg, sizeof(packet->msg)) != sizeof(packet->msg)) {
166         PLOG(ERROR) << "remote usb: 1 - write terminated";
167         return false;
168     }
169 
170     if (packet->msg.data_length != 0 && usb_write(handle_, packet->payload.data(), size) != size) {
171         PLOG(ERROR) << "remote usb: 2 - write terminated";
172         return false;
173     }
174 
175     return true;
176 }
177 
178 bool UsbConnection::DoTlsHandshake(RSA* key, std::string* auth_key) {
179     // TODO: support TLS for usb connections
180     LOG(FATAL) << "Not supported yet.";
181     return false;
182 }
183 
184 void UsbConnection::Reset() {
185     usb_reset(handle_);
186     usb_kick(handle_);
187 }
188 
189 void UsbConnection::Close() {
190     usb_kick(handle_);
191 }
192 
193 void init_usb_transport(atransport* t, usb_handle* h) {
194     D("transport: usb");
195     auto connection = std::make_unique<UsbConnection>(h);
196     t->SetConnection(std::make_unique<BlockingConnectionAdapter>(std::move(connection)));
197     t->type = kTransportUsb;
198     t->SetUsbHandle(h);
199 }
200 
201 int is_adb_interface(int usb_class, int usb_subclass, int usb_protocol) {
202     return (usb_class == ADB_CLASS && usb_subclass == ADB_SUBCLASS && usb_protocol == ADB_PROTOCOL);
203 }
204 
// Reports whether the libusb backend should be used.
// Always false when not built with ADB_HOST. On host builds it is true only when the
// ADB_LIBUSB environment variable is exactly "1"; the environment is consulted on the
// first call and the answer is cached for later calls.
bool should_use_libusb() {
#if ADB_HOST
    static bool enable = [] {
        const char* value = getenv("ADB_LIBUSB");
        return value != nullptr && strcmp(value, "1") == 0;
    }();
    return enable;
#else
    return false;
#endif
}
213