1 // Copyright 2015 The Android Open Source Project
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include <arpa/inet.h>
16 #include <map>
17 #include <netdb.h>
18 #include <string>
19 #include <sys/socket.h>
20 #include <sys/types.h>
21 #include <unistd.h>
22
23 #include <base/bind.h>
24 #include <base/bind_helpers.h>
25 #include <base/files/file_util.h>
26 #include <base/message_loop/message_loop.h>
27 #include <base/strings/stringprintf.h>
28 #include <brillo/bind_lambda.h>
29 #include <brillo/streams/file_stream.h>
30 #include <brillo/streams/tls_stream.h>
31
32 #include "buffet/socket_stream.h"
33 #include "buffet/weave_error_conversion.h"
34
35 namespace buffet {
36
37 using weave::provider::Network;
38
39 namespace {
40
// Returns a printable form of the address held in |sa|.
// IPv4 (AF_INET) and IPv6 (AF_INET6) addresses are rendered with inet_ntop().
// Any other family, or a formatting failure, yields the placeholder
// "<Unknown address family: N>", where N is the numeric family.
std::string GetIPAddress(const sockaddr* sa) {
  const void* raw_addr = nullptr;
  if (sa->sa_family == AF_INET) {
    raw_addr = &reinterpret_cast<const sockaddr_in*>(sa)->sin_addr;
  } else if (sa->sa_family == AF_INET6) {
    raw_addr = &reinterpret_cast<const sockaddr_in6*>(sa)->sin6_addr;
  }

  char text[INET6_ADDRSTRLEN] = {};
  const bool formatted =
      raw_addr != nullptr &&
      inet_ntop(sa->sa_family, raw_addr, text, sizeof(text)) != nullptr &&
      text[0] != '\0';
  if (formatted)
    return std::string(text);

  return "<Unknown address family: " +
         std::to_string(static_cast<int>(sa->sa_family)) + ">";
}
65
ConnectSocket(const std::string & host,uint16_t port)66 int ConnectSocket(const std::string& host, uint16_t port) {
67 std::string service = std::to_string(port);
68 addrinfo hints = {0, AF_UNSPEC, SOCK_STREAM};
69 addrinfo* result = nullptr;
70 if (getaddrinfo(host.c_str(), service.c_str(), &hints, &result)) {
71 PLOG(WARNING) << "Failed to resolve host name: " << host;
72 return -1;
73 }
74
75 int socket_fd = -1;
76 for (const addrinfo* info = result; info != nullptr; info = info->ai_next) {
77 socket_fd = socket(info->ai_family, info->ai_socktype, info->ai_protocol);
78 if (socket_fd < 0)
79 continue;
80
81 std::string addr = GetIPAddress(info->ai_addr);
82 LOG(INFO) << "Connecting to address: " << addr;
83 if (connect(socket_fd, info->ai_addr, info->ai_addrlen) == 0)
84 break; // Success.
85
86 PLOG(WARNING) << "Failed to connect to address: " << addr;
87 close(socket_fd);
88 socket_fd = -1;
89 }
90
91 freeaddrinfo(result);
92 return socket_fd;
93 }
94
OnSuccess(const Network::OpenSslSocketCallback & callback,brillo::StreamPtr tls_stream)95 void OnSuccess(const Network::OpenSslSocketCallback& callback,
96 brillo::StreamPtr tls_stream) {
97 callback.Run(
98 std::unique_ptr<weave::Stream>{new SocketStream{std::move(tls_stream)}},
99 nullptr);
100 }
101
OnError(const weave::DoneCallback & callback,const brillo::Error * brillo_error)102 void OnError(const weave::DoneCallback& callback,
103 const brillo::Error* brillo_error) {
104 weave::ErrorPtr error;
105 ConvertError(*brillo_error, &error);
106 callback.Run(std::move(error));
107 }
108
109 } // namespace
110
Read(void * buffer,size_t size_to_read,const ReadCallback & callback)111 void SocketStream::Read(void* buffer,
112 size_t size_to_read,
113 const ReadCallback& callback) {
114 brillo::ErrorPtr brillo_error;
115 if (!ptr_->ReadAsync(
116 buffer, size_to_read,
117 base::Bind([](const ReadCallback& callback,
118 size_t size) { callback.Run(size, nullptr); },
119 callback),
120 base::Bind(&OnError, base::Bind(callback, 0)), &brillo_error)) {
121 weave::ErrorPtr error;
122 ConvertError(*brillo_error, &error);
123 base::MessageLoop::current()->PostTask(
124 FROM_HERE, base::Bind(callback, 0, base::Passed(&error)));
125 }
126 }
127
Write(const void * buffer,size_t size_to_write,const WriteCallback & callback)128 void SocketStream::Write(const void* buffer,
129 size_t size_to_write,
130 const WriteCallback& callback) {
131 brillo::ErrorPtr brillo_error;
132 if (!ptr_->WriteAllAsync(buffer, size_to_write, base::Bind(callback, nullptr),
133 base::Bind(&OnError, callback), &brillo_error)) {
134 weave::ErrorPtr error;
135 ConvertError(*brillo_error, &error);
136 base::MessageLoop::current()->PostTask(
137 FROM_HERE, base::Bind(callback, base::Passed(&error)));
138 }
139 }
140
CancelPendingOperations()141 void SocketStream::CancelPendingOperations() {
142 ptr_->CancelPendingAsyncOperations();
143 }
144
ConnectBlocking(const std::string & host,uint16_t port)145 std::unique_ptr<weave::Stream> SocketStream::ConnectBlocking(
146 const std::string& host,
147 uint16_t port) {
148 int socket_fd = ConnectSocket(host, port);
149 if (socket_fd <= 0)
150 return nullptr;
151
152 auto ptr_ = brillo::FileStream::FromFileDescriptor(socket_fd, true, nullptr);
153 if (ptr_)
154 return std::unique_ptr<Stream>{new SocketStream{std::move(ptr_)}};
155
156 close(socket_fd);
157 return nullptr;
158 }
159
TlsConnect(std::unique_ptr<Stream> socket,const std::string & host,const Network::OpenSslSocketCallback & callback)160 void SocketStream::TlsConnect(std::unique_ptr<Stream> socket,
161 const std::string& host,
162 const Network::OpenSslSocketCallback& callback) {
163 SocketStream* stream = static_cast<SocketStream*>(socket.get());
164 brillo::TlsStream::Connect(
165 std::move(stream->ptr_), host, base::Bind(&OnSuccess, callback),
166 base::Bind(&OnError, base::Bind(callback, nullptr)));
167 }
168
169 } // namespace buffet
170