1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/ipc/host_impl.h"
18 
19 #include <inttypes.h>
20 
21 #include <algorithm>
22 #include <utility>
23 
24 #include "perfetto/base/task_runner.h"
25 #include "perfetto/ext/base/utils.h"
26 #include "perfetto/ext/ipc/service.h"
27 #include "perfetto/ext/ipc/service_descriptor.h"
28 
29 #include "protos/perfetto/ipc/wire_protocol.gen.h"
30 
31 // TODO(primiano): put limits on #connections/uid and req. queue (b/69093705).
32 
33 namespace perfetto {
34 namespace ipc {
35 
36 namespace {
37 
38 constexpr base::SockFamily kHostSockFamily =
39     kUseTCPSocket ? base::SockFamily::kInet : base::SockFamily::kUnix;
40 
GetPosixPeerUid(base::UnixSocket * sock)41 uid_t GetPosixPeerUid(base::UnixSocket* sock) {
42 #if PERFETTO_BUILDFLAG(PERFETTO_OS_WIN)
43   base::ignore_result(sock);
44   // Unsupported. Must be != kInvalidUid or the PacketValidator will fail.
45   return 0;
46 #else
47   return sock->peer_uid_posix();
48 #endif
49 }
50 
51 }  // namespace
52 
53 // static
CreateInstance(const char * socket_name,base::TaskRunner * task_runner)54 std::unique_ptr<Host> Host::CreateInstance(const char* socket_name,
55                                            base::TaskRunner* task_runner) {
56   std::unique_ptr<HostImpl> host(new HostImpl(socket_name, task_runner));
57   if (!host->sock() || !host->sock()->is_listening())
58     return nullptr;
59   return std::unique_ptr<Host>(std::move(host));
60 }
61 
62 // static
CreateInstance(base::ScopedSocketHandle socket_fd,base::TaskRunner * task_runner)63 std::unique_ptr<Host> Host::CreateInstance(base::ScopedSocketHandle socket_fd,
64                                            base::TaskRunner* task_runner) {
65   std::unique_ptr<HostImpl> host(
66       new HostImpl(std::move(socket_fd), task_runner));
67   if (!host->sock() || !host->sock()->is_listening())
68     return nullptr;
69   return std::unique_ptr<Host>(std::move(host));
70 }
71 
HostImpl(base::ScopedSocketHandle socket_fd,base::TaskRunner * task_runner)72 HostImpl::HostImpl(base::ScopedSocketHandle socket_fd,
73                    base::TaskRunner* task_runner)
74     : task_runner_(task_runner), weak_ptr_factory_(this) {
75   PERFETTO_DCHECK_THREAD(thread_checker_);
76   sock_ = base::UnixSocket::Listen(std::move(socket_fd), this, task_runner_,
77                                    kHostSockFamily, base::SockType::kStream);
78 }
79 
HostImpl(const char * socket_name,base::TaskRunner * task_runner)80 HostImpl::HostImpl(const char* socket_name, base::TaskRunner* task_runner)
81     : task_runner_(task_runner), weak_ptr_factory_(this) {
82   PERFETTO_DCHECK_THREAD(thread_checker_);
83   sock_ = base::UnixSocket::Listen(socket_name, this, task_runner_,
84                                    kHostSockFamily, base::SockType::kStream);
85   if (!sock_) {
86     PERFETTO_PLOG("Failed to create %s", socket_name);
87   }
88 }
89 
// Defaulted destructor, defined out-of-line in this translation unit.
HostImpl::~HostImpl() = default;
91 
ExposeService(std::unique_ptr<Service> service)92 bool HostImpl::ExposeService(std::unique_ptr<Service> service) {
93   PERFETTO_DCHECK_THREAD(thread_checker_);
94   const std::string& service_name = service->GetDescriptor().service_name;
95   if (GetServiceByName(service_name)) {
96     PERFETTO_DLOG("Duplicate ExposeService(): %s", service_name.c_str());
97     return false;
98   }
99   ServiceID sid = ++last_service_id_;
100   ExposedService exposed_service(sid, service_name, std::move(service));
101   services_.emplace(sid, std::move(exposed_service));
102   return true;
103 }
104 
OnNewIncomingConnection(base::UnixSocket *,std::unique_ptr<base::UnixSocket> new_conn)105 void HostImpl::OnNewIncomingConnection(
106     base::UnixSocket*,
107     std::unique_ptr<base::UnixSocket> new_conn) {
108   PERFETTO_DCHECK_THREAD(thread_checker_);
109   std::unique_ptr<ClientConnection> client(new ClientConnection());
110   ClientID client_id = ++last_client_id_;
111   clients_by_socket_[new_conn.get()] = client.get();
112   client->id = client_id;
113   client->sock = std::move(new_conn);
114   // Watchdog is 30 seconds, so set the socket timeout to 10 seconds.
115   client->sock->SetTxTimeout(10000);
116   clients_[client_id] = std::move(client);
117 }
118 
OnDataAvailable(base::UnixSocket * sock)119 void HostImpl::OnDataAvailable(base::UnixSocket* sock) {
120   PERFETTO_DCHECK_THREAD(thread_checker_);
121   auto it = clients_by_socket_.find(sock);
122   if (it == clients_by_socket_.end())
123     return;
124   ClientConnection* client = it->second;
125   BufferedFrameDeserializer& frame_deserializer = client->frame_deserializer;
126 
127   size_t rsize;
128   do {
129     auto buf = frame_deserializer.BeginReceive();
130     base::ScopedFile fd;
131     rsize = client->sock->Receive(buf.data, buf.size, &fd);
132     if (fd) {
133       PERFETTO_DCHECK(!client->received_fd);
134       client->received_fd = std::move(fd);
135     }
136     if (!frame_deserializer.EndReceive(rsize))
137       return OnDisconnect(client->sock.get());
138   } while (rsize > 0);
139 
140   for (;;) {
141     std::unique_ptr<Frame> frame = frame_deserializer.PopNextFrame();
142     if (!frame)
143       break;
144     OnReceivedFrame(client, *frame);
145   }
146 }
147 
OnReceivedFrame(ClientConnection * client,const Frame & req_frame)148 void HostImpl::OnReceivedFrame(ClientConnection* client,
149                                const Frame& req_frame) {
150   if (req_frame.has_msg_bind_service())
151     return OnBindService(client, req_frame);
152   if (req_frame.has_msg_invoke_method())
153     return OnInvokeMethod(client, req_frame);
154 
155   PERFETTO_DLOG("Received invalid RPC frame from client %" PRIu64, client->id);
156   Frame reply_frame;
157   reply_frame.set_request_id(req_frame.request_id());
158   reply_frame.mutable_msg_request_error()->set_error("unknown request");
159   SendFrame(client, reply_frame);
160 }
161 
OnBindService(ClientConnection * client,const Frame & req_frame)162 void HostImpl::OnBindService(ClientConnection* client, const Frame& req_frame) {
163   // Binding a service doesn't do anything major. It just returns back the
164   // service id and its method map.
165   const Frame::BindService& req = req_frame.msg_bind_service();
166   Frame reply_frame;
167   reply_frame.set_request_id(req_frame.request_id());
168   auto* reply = reply_frame.mutable_msg_bind_service_reply();
169   const ExposedService* service = GetServiceByName(req.service_name());
170   if (service) {
171     reply->set_success(true);
172     reply->set_service_id(service->id);
173     uint32_t method_id = 1;  // method ids start at index 1.
174     for (const auto& desc_method : service->instance->GetDescriptor().methods) {
175       Frame::BindServiceReply::MethodInfo* method_info = reply->add_methods();
176       method_info->set_name(desc_method.name);
177       method_info->set_id(method_id++);
178     }
179   }
180   SendFrame(client, reply_frame);
181 }
182 
OnInvokeMethod(ClientConnection * client,const Frame & req_frame)183 void HostImpl::OnInvokeMethod(ClientConnection* client,
184                               const Frame& req_frame) {
185   const Frame::InvokeMethod& req = req_frame.msg_invoke_method();
186   Frame reply_frame;
187   RequestID request_id = req_frame.request_id();
188   reply_frame.set_request_id(request_id);
189   reply_frame.mutable_msg_invoke_method_reply()->set_success(false);
190   auto svc_it = services_.find(req.service_id());
191   if (svc_it == services_.end())
192     return SendFrame(client, reply_frame);  // |success| == false by default.
193 
194   Service* service = svc_it->second.instance.get();
195   const ServiceDescriptor& svc = service->GetDescriptor();
196   const auto& methods = svc.methods;
197   const uint32_t method_id = req.method_id();
198   if (method_id == 0 || method_id > methods.size())
199     return SendFrame(client, reply_frame);
200 
201   const ServiceDescriptor::Method& method = methods[method_id - 1];
202   std::unique_ptr<ProtoMessage> decoded_req_args(
203       method.request_proto_decoder(req.args_proto()));
204   if (!decoded_req_args)
205     return SendFrame(client, reply_frame);
206 
207   Deferred<ProtoMessage> deferred_reply;
208   base::WeakPtr<HostImpl> host_weak_ptr = weak_ptr_factory_.GetWeakPtr();
209   ClientID client_id = client->id;
210 
211   if (!req.drop_reply()) {
212     deferred_reply.Bind([host_weak_ptr, client_id,
213                          request_id](AsyncResult<ProtoMessage> reply) {
214       if (!host_weak_ptr)
215         return;  // The reply came too late, the HostImpl has gone.
216       host_weak_ptr->ReplyToMethodInvocation(client_id, request_id,
217                                              std::move(reply));
218     });
219   }
220 
221   service->client_info_ =
222       ClientInfo(client->id, GetPosixPeerUid(client->sock.get()));
223   service->received_fd_ = &client->received_fd;
224   method.invoker(service, *decoded_req_args, std::move(deferred_reply));
225   service->received_fd_ = nullptr;
226   service->client_info_ = ClientInfo();
227 }
228 
ReplyToMethodInvocation(ClientID client_id,RequestID request_id,AsyncResult<ProtoMessage> reply)229 void HostImpl::ReplyToMethodInvocation(ClientID client_id,
230                                        RequestID request_id,
231                                        AsyncResult<ProtoMessage> reply) {
232   auto client_iter = clients_.find(client_id);
233   if (client_iter == clients_.end())
234     return;  // client has disconnected by the time we got the async reply.
235 
236   ClientConnection* client = client_iter->second.get();
237   Frame reply_frame;
238   reply_frame.set_request_id(request_id);
239 
240   // TODO(fmayer): add a test to guarantee that the reply is consumed within the
241   // same call stack and not kept around. ConsumerIPCService::OnTraceData()
242   // relies on this behavior.
243   auto* reply_frame_data = reply_frame.mutable_msg_invoke_method_reply();
244   reply_frame_data->set_has_more(reply.has_more());
245   if (reply.success()) {
246     std::string reply_proto = reply->SerializeAsString();
247     reply_frame_data->set_reply_proto(reply_proto);
248     reply_frame_data->set_success(true);
249   }
250   SendFrame(client, reply_frame, reply.fd());
251 }
252 
253 // static
SendFrame(ClientConnection * client,const Frame & frame,int fd)254 void HostImpl::SendFrame(ClientConnection* client, const Frame& frame, int fd) {
255   std::string buf = BufferedFrameDeserializer::Serialize(frame);
256 
257   // When a new Client connects in OnNewClientConnection we set a timeout on
258   // Send (see call to SetTxTimeout).
259   //
260   // The old behaviour was to do a blocking I/O call, which caused crashes from
261   // misbehaving producers (see b/169051440).
262   bool res = client->sock->Send(buf.data(), buf.size(), fd);
263   // If we timeout |res| will be false, but the UnixSocket will have called
264   // UnixSocket::ShutDown() and thus |is_connected()| is false.
265   PERFETTO_CHECK(res || !client->sock->is_connected());
266 }
267 
OnDisconnect(base::UnixSocket * sock)268 void HostImpl::OnDisconnect(base::UnixSocket* sock) {
269   PERFETTO_DCHECK_THREAD(thread_checker_);
270   auto it = clients_by_socket_.find(sock);
271   if (it == clients_by_socket_.end())
272     return;
273   ClientID client_id = it->second->id;
274 
275   ClientInfo client_info(client_id, GetPosixPeerUid(sock));
276   clients_by_socket_.erase(it);
277   PERFETTO_DCHECK(clients_.count(client_id));
278   clients_.erase(client_id);
279 
280   for (const auto& service_it : services_) {
281     Service& service = *service_it.second.instance;
282     service.client_info_ = client_info;
283     service.OnClientDisconnected();
284     service.client_info_ = ClientInfo();
285   }
286 }
287 
GetServiceByName(const std::string & name)288 const HostImpl::ExposedService* HostImpl::GetServiceByName(
289     const std::string& name) {
290   // This could be optimized by using another map<name,ServiceID>. However this
291   // is used only by Bind/ExposeService that are quite rare (once per client
292   // connection and once per service instance), not worth it.
293   for (const auto& it : services_) {
294     if (it.second.name == name)
295       return &it.second;
296   }
297   return nullptr;
298 }
299 
// Stores the service id and name and takes ownership of the service instance.
HostImpl::ExposedService::ExposedService(ServiceID id_,
                                         const std::string& name_,
                                         std::unique_ptr<Service> instance_)
    : id(id_), name(name_), instance(std::move(instance_)) {}
304 
// Defaulted move constructor, move assignment and destructor, defined
// out-of-line in this translation unit. Only the move constructor is declared
// noexcept.
HostImpl::ExposedService::ExposedService(ExposedService&&) noexcept = default;
HostImpl::ExposedService& HostImpl::ExposedService::operator=(
    HostImpl::ExposedService&&) = default;
HostImpl::ExposedService::~ExposedService() = default;
309 
// Defaulted destructor, defined out-of-line in this translation unit.
HostImpl::ClientConnection::~ClientConnection() = default;
311 
312 }  // namespace ipc
313 }  // namespace perfetto
314