1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "common/libs/utils/tcp_socket.h"
18 
19 #include <netinet/in.h>
20 #include <sys/types.h>
21 #include <sys/socket.h>
22 
23 #include <cerrno>
24 #include <cstring>
25 #include <memory>
26 #include <ostream>
27 #include <string>
28 
29 #include <android-base/logging.h>
30 
31 namespace cuttlefish {
32 
ClientSocket(int port)33 ClientSocket::ClientSocket(int port)
34     : fd_(SharedFD::SocketLocalClient(port, SOCK_STREAM)) {}
35 
RecvAny(std::size_t length)36 Message ClientSocket::RecvAny(std::size_t length) {
37   Message buf(length);
38   auto read_count = fd_->Read(buf.data(), buf.size());
39   if (read_count < 0) {
40     read_count = 0;
41   }
42   buf.resize(read_count);
43   return buf;
44 }
45 
closed() const46 bool ClientSocket::closed() const {
47   std::lock_guard<std::mutex> guard(closed_lock_);
48   return other_side_closed_;
49 }
50 
Recv(std::size_t length)51 Message ClientSocket::Recv(std::size_t length) {
52   Message buf(length);
53   ssize_t total_read = 0;
54   while (total_read < static_cast<ssize_t>(length)) {
55     auto just_read = fd_->Read(&buf[total_read], buf.size() - total_read);
56     if (just_read <= 0) {
57       if (just_read < 0) {
58         LOG(ERROR) << "read() error: " << strerror(errno);
59       }
60       {
61         std::lock_guard<std::mutex> guard(closed_lock_);
62         other_side_closed_ = true;
63       }
64       return Message{};
65     }
66     total_read += just_read;
67   }
68   CHECK(total_read == static_cast<ssize_t>(length));
69   return buf;
70 }
71 
SendNoSignal(const uint8_t * data,std::size_t size)72 ssize_t ClientSocket::SendNoSignal(const uint8_t* data, std::size_t size) {
73   std::lock_guard<std::mutex> lock(send_lock_);
74   ssize_t written{};
75   while (written < static_cast<ssize_t>(size)) {
76     if (!fd_->IsOpen()) {
77       LOG(ERROR) << "fd_ is closed";
78     }
79     auto just_written = fd_->Send(data + written, size - written, MSG_NOSIGNAL);
80     if (just_written <= 0) {
81       LOG(INFO) << "Couldn't write to client: " << strerror(errno);
82       {
83         std::lock_guard<std::mutex> guard(closed_lock_);
84         other_side_closed_ = true;
85       }
86       return just_written;
87     }
88     written += just_written;
89   }
90   return written;
91 }
92 
SendNoSignal(const Message & message)93 ssize_t ClientSocket::SendNoSignal(const Message& message) {
94   return SendNoSignal(&message[0], message.size());
95 }
96 
ServerSocket(int port)97 ServerSocket::ServerSocket(int port)
98     : fd_{SharedFD::SocketLocalServer(port, SOCK_STREAM)} {
99   if (!fd_->IsOpen()) {
100     LOG(FATAL) << "Couldn't open streaming server on port " << port;
101   }
102 }
103 
Accept()104 ClientSocket ServerSocket::Accept() {
105   SharedFD client = SharedFD::Accept(*fd_);
106   if (!client->IsOpen()) {
107     LOG(FATAL) << "Error attempting to accept: " << strerror(errno);
108   }
109   return ClientSocket{client};
110 }
111 
AppendInNetworkByteOrder(Message * msg,const std::uint8_t b)112 void AppendInNetworkByteOrder(Message* msg, const std::uint8_t b) {
113   msg->push_back(b);
114 }
115 
AppendInNetworkByteOrder(Message * msg,const std::uint16_t s)116 void AppendInNetworkByteOrder(Message* msg, const std::uint16_t s) {
117   const std::uint16_t n = htons(s);
118   auto p = reinterpret_cast<const std::uint8_t*>(&n);
119   msg->insert(msg->end(), p, p + sizeof n);
120 }
121 
AppendInNetworkByteOrder(Message * msg,const std::uint32_t w)122 void AppendInNetworkByteOrder(Message* msg, const std::uint32_t w) {
123   const std::uint32_t n = htonl(w);
124   auto p = reinterpret_cast<const std::uint8_t*>(&n);
125   msg->insert(msg->end(), p, p + sizeof n);
126 }
127 
AppendInNetworkByteOrder(Message * msg,const std::int32_t w)128 void AppendInNetworkByteOrder(Message* msg, const std::int32_t w) {
129   std::uint32_t u{};
130   std::memcpy(&u, &w, sizeof u);
131   AppendInNetworkByteOrder(msg, u);
132 }
133 
AppendInNetworkByteOrder(Message * msg,const std::string & str)134 void AppendInNetworkByteOrder(Message* msg, const std::string& str) {
135   msg->insert(msg->end(), str.begin(), str.end());
136 }
137 
138 }
139