1 // Copyright 2015 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include <arpa/inet.h>
16 #include <map>
17 #include <netdb.h>
18 #include <string>
19 #include <sys/socket.h>
20 #include <sys/types.h>
21 #include <unistd.h>
22 
23 #include <base/bind.h>
24 #include <base/bind_helpers.h>
25 #include <base/files/file_util.h>
26 #include <base/message_loop/message_loop.h>
27 #include <base/strings/stringprintf.h>
28 #include <brillo/bind_lambda.h>
29 #include <brillo/streams/file_stream.h>
30 #include <brillo/streams/tls_stream.h>
31 
32 #include "buffet/socket_stream.h"
33 #include "buffet/weave_error_conversion.h"
34 
35 namespace buffet {
36 
37 using weave::provider::Network;
38 
39 namespace {
40 
GetIPAddress(const sockaddr * sa)41 std::string GetIPAddress(const sockaddr* sa) {
42   std::string addr;
43   char str[INET6_ADDRSTRLEN] = {};
44   switch (sa->sa_family) {
45     case AF_INET:
46       if (inet_ntop(AF_INET,
47                     &(reinterpret_cast<const sockaddr_in*>(sa)->sin_addr), str,
48                     sizeof(str))) {
49         addr = str;
50       }
51       break;
52 
53     case AF_INET6:
54       if (inet_ntop(AF_INET6,
55                     &(reinterpret_cast<const sockaddr_in6*>(sa)->sin6_addr),
56                     str, sizeof(str))) {
57         addr = str;
58       }
59       break;
60   }
61   if (addr.empty())
62     addr = base::StringPrintf("<Unknown address family: %d>", sa->sa_family);
63   return addr;
64 }
65 
ConnectSocket(const std::string & host,uint16_t port)66 int ConnectSocket(const std::string& host, uint16_t port) {
67   std::string service = std::to_string(port);
68   addrinfo hints = {0, AF_UNSPEC, SOCK_STREAM};
69   addrinfo* result = nullptr;
70   if (getaddrinfo(host.c_str(), service.c_str(), &hints, &result)) {
71     PLOG(WARNING) << "Failed to resolve host name: " << host;
72     return -1;
73   }
74 
75   int socket_fd = -1;
76   for (const addrinfo* info = result; info != nullptr; info = info->ai_next) {
77     socket_fd = socket(info->ai_family, info->ai_socktype, info->ai_protocol);
78     if (socket_fd < 0)
79       continue;
80 
81     std::string addr = GetIPAddress(info->ai_addr);
82     LOG(INFO) << "Connecting to address: " << addr;
83     if (connect(socket_fd, info->ai_addr, info->ai_addrlen) == 0)
84       break;  // Success.
85 
86     PLOG(WARNING) << "Failed to connect to address: " << addr;
87     close(socket_fd);
88     socket_fd = -1;
89   }
90 
91   freeaddrinfo(result);
92   return socket_fd;
93 }
94 
OnSuccess(const Network::OpenSslSocketCallback & callback,brillo::StreamPtr tls_stream)95 void OnSuccess(const Network::OpenSslSocketCallback& callback,
96                brillo::StreamPtr tls_stream) {
97   callback.Run(
98       std::unique_ptr<weave::Stream>{new SocketStream{std::move(tls_stream)}},
99       nullptr);
100 }
101 
OnError(const weave::DoneCallback & callback,const brillo::Error * brillo_error)102 void OnError(const weave::DoneCallback& callback,
103              const brillo::Error* brillo_error) {
104   weave::ErrorPtr error;
105   ConvertError(*brillo_error, &error);
106   callback.Run(std::move(error));
107 }
108 
109 }  // namespace
110 
Read(void * buffer,size_t size_to_read,const ReadCallback & callback)111 void SocketStream::Read(void* buffer,
112                         size_t size_to_read,
113                         const ReadCallback& callback) {
114   brillo::ErrorPtr brillo_error;
115   if (!ptr_->ReadAsync(
116           buffer, size_to_read,
117           base::Bind([](const ReadCallback& callback,
118                         size_t size) { callback.Run(size, nullptr); },
119                      callback),
120           base::Bind(&OnError, base::Bind(callback, 0)), &brillo_error)) {
121     weave::ErrorPtr error;
122     ConvertError(*brillo_error, &error);
123     base::MessageLoop::current()->PostTask(
124         FROM_HERE, base::Bind(callback, 0, base::Passed(&error)));
125   }
126 }
127 
Write(const void * buffer,size_t size_to_write,const WriteCallback & callback)128 void SocketStream::Write(const void* buffer,
129                          size_t size_to_write,
130                          const WriteCallback& callback) {
131   brillo::ErrorPtr brillo_error;
132   if (!ptr_->WriteAllAsync(buffer, size_to_write, base::Bind(callback, nullptr),
133                            base::Bind(&OnError, callback), &brillo_error)) {
134     weave::ErrorPtr error;
135     ConvertError(*brillo_error, &error);
136     base::MessageLoop::current()->PostTask(
137         FROM_HERE, base::Bind(callback, base::Passed(&error)));
138   }
139 }
140 
CancelPendingOperations()141 void SocketStream::CancelPendingOperations() {
142   ptr_->CancelPendingAsyncOperations();
143 }
144 
ConnectBlocking(const std::string & host,uint16_t port)145 std::unique_ptr<weave::Stream> SocketStream::ConnectBlocking(
146     const std::string& host,
147     uint16_t port) {
148   int socket_fd = ConnectSocket(host, port);
149   if (socket_fd <= 0)
150     return nullptr;
151 
152   auto ptr_ = brillo::FileStream::FromFileDescriptor(socket_fd, true, nullptr);
153   if (ptr_)
154     return std::unique_ptr<Stream>{new SocketStream{std::move(ptr_)}};
155 
156   close(socket_fd);
157   return nullptr;
158 }
159 
TlsConnect(std::unique_ptr<Stream> socket,const std::string & host,const Network::OpenSslSocketCallback & callback)160 void SocketStream::TlsConnect(std::unique_ptr<Stream> socket,
161                               const std::string& host,
162                               const Network::OpenSslSocketCallback& callback) {
163   SocketStream* stream = static_cast<SocketStream*>(socket.get());
164   brillo::TlsStream::Connect(
165       std::move(stream->ptr_), host, base::Bind(&OnSuccess, callback),
166       base::Bind(&OnError, base::Bind(callback, nullptr)));
167 }
168 
169 }  // namespace buffet
170