1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/ipc/host_impl.h"
18 
19 #include <inttypes.h>
20 
21 #include <algorithm>
22 #include <utility>
23 
24 #include "perfetto/base/task_runner.h"
25 #include "perfetto/base/utils.h"
26 #include "perfetto/ipc/service.h"
27 #include "perfetto/ipc/service_descriptor.h"
28 
29 #include "src/ipc/wire_protocol.pb.h"
30 
31 // TODO(primiano): put limits on #connections/uid and req. queue (b/69093705).
32 
33 namespace perfetto {
34 namespace ipc {
35 
36 // static
CreateInstance(const char * socket_name,base::TaskRunner * task_runner)37 std::unique_ptr<Host> Host::CreateInstance(const char* socket_name,
38                                            base::TaskRunner* task_runner) {
39   std::unique_ptr<HostImpl> host(new HostImpl(socket_name, task_runner));
40   if (!host->sock() || !host->sock()->is_listening())
41     return nullptr;
42   return std::move(host);
43 }
44 
45 // static
CreateInstance(base::ScopedFile socket_fd,base::TaskRunner * task_runner)46 std::unique_ptr<Host> Host::CreateInstance(base::ScopedFile socket_fd,
47                                            base::TaskRunner* task_runner) {
48   std::unique_ptr<HostImpl> host(
49       new HostImpl(std::move(socket_fd), task_runner));
50   if (!host->sock() || !host->sock()->is_listening())
51     return nullptr;
52   return std::move(host);
53 }
54 
HostImpl(base::ScopedFile socket_fd,base::TaskRunner * task_runner)55 HostImpl::HostImpl(base::ScopedFile socket_fd, base::TaskRunner* task_runner)
56     : task_runner_(task_runner), weak_ptr_factory_(this) {
57   GOOGLE_PROTOBUF_VERIFY_VERSION;
58   PERFETTO_DCHECK_THREAD(thread_checker_);
59   sock_ = base::UnixSocket::Listen(std::move(socket_fd), this, task_runner_);
60 }
61 
HostImpl(const char * socket_name,base::TaskRunner * task_runner)62 HostImpl::HostImpl(const char* socket_name, base::TaskRunner* task_runner)
63     : task_runner_(task_runner), weak_ptr_factory_(this) {
64   GOOGLE_PROTOBUF_VERIFY_VERSION;
65   PERFETTO_DCHECK_THREAD(thread_checker_);
66   sock_ = base::UnixSocket::Listen(socket_name, this, task_runner_);
67 }
68 
69 HostImpl::~HostImpl() = default;
70 
ExposeService(std::unique_ptr<Service> service)71 bool HostImpl::ExposeService(std::unique_ptr<Service> service) {
72   PERFETTO_DCHECK_THREAD(thread_checker_);
73   const std::string& service_name = service->GetDescriptor().service_name;
74   if (GetServiceByName(service_name)) {
75     PERFETTO_DLOG("Duplicate ExposeService(): %s", service_name.c_str());
76     return false;
77   }
78   ServiceID sid = ++last_service_id_;
79   ExposedService exposed_service(sid, service_name, std::move(service));
80   services_.emplace(sid, std::move(exposed_service));
81   return true;
82 }
83 
OnNewIncomingConnection(base::UnixSocket *,std::unique_ptr<base::UnixSocket> new_conn)84 void HostImpl::OnNewIncomingConnection(
85     base::UnixSocket*,
86     std::unique_ptr<base::UnixSocket> new_conn) {
87   PERFETTO_DCHECK_THREAD(thread_checker_);
88   std::unique_ptr<ClientConnection> client(new ClientConnection());
89   ClientID client_id = ++last_client_id_;
90   clients_by_socket_[new_conn.get()] = client.get();
91   client->id = client_id;
92   client->sock = std::move(new_conn);
93   clients_[client_id] = std::move(client);
94 }
95 
OnDataAvailable(base::UnixSocket * sock)96 void HostImpl::OnDataAvailable(base::UnixSocket* sock) {
97   PERFETTO_DCHECK_THREAD(thread_checker_);
98   auto it = clients_by_socket_.find(sock);
99   if (it == clients_by_socket_.end())
100     return;
101   ClientConnection* client = it->second;
102   BufferedFrameDeserializer& frame_deserializer = client->frame_deserializer;
103 
104   size_t rsize;
105   do {
106     auto buf = frame_deserializer.BeginReceive();
107     base::ScopedFile fd;
108     rsize = client->sock->Receive(buf.data, buf.size, &fd);
109     if (fd) {
110       PERFETTO_DCHECK(!client->received_fd);
111       client->received_fd = std::move(fd);
112     }
113     if (!frame_deserializer.EndReceive(rsize))
114       return OnDisconnect(client->sock.get());
115   } while (rsize > 0);
116 
117   for (;;) {
118     std::unique_ptr<Frame> frame = frame_deserializer.PopNextFrame();
119     if (!frame)
120       break;
121     OnReceivedFrame(client, *frame);
122   }
123 }
124 
OnReceivedFrame(ClientConnection * client,const Frame & req_frame)125 void HostImpl::OnReceivedFrame(ClientConnection* client,
126                                const Frame& req_frame) {
127   if (req_frame.msg_case() == Frame::kMsgBindService)
128     return OnBindService(client, req_frame);
129   if (req_frame.msg_case() == Frame::kMsgInvokeMethod)
130     return OnInvokeMethod(client, req_frame);
131 
132   PERFETTO_DLOG("Received invalid RPC frame %u from client %" PRIu64,
133                 req_frame.msg_case(), client->id);
134   Frame reply_frame;
135   reply_frame.set_request_id(req_frame.request_id());
136   reply_frame.mutable_msg_request_error()->set_error("unknown request");
137   SendFrame(client, reply_frame);
138 }
139 
OnBindService(ClientConnection * client,const Frame & req_frame)140 void HostImpl::OnBindService(ClientConnection* client, const Frame& req_frame) {
141   // Binding a service doesn't do anything major. It just returns back the
142   // service id and its method map.
143   const Frame::BindService& req = req_frame.msg_bind_service();
144   Frame reply_frame;
145   reply_frame.set_request_id(req_frame.request_id());
146   auto* reply = reply_frame.mutable_msg_bind_service_reply();
147   const ExposedService* service = GetServiceByName(req.service_name());
148   if (service) {
149     reply->set_success(true);
150     reply->set_service_id(service->id);
151     uint32_t method_id = 1;  // method ids start at index 1.
152     for (const auto& desc_method : service->instance->GetDescriptor().methods) {
153       Frame::BindServiceReply::MethodInfo* method_info = reply->add_methods();
154       method_info->set_name(desc_method.name);
155       method_info->set_id(method_id++);
156     }
157   }
158   SendFrame(client, reply_frame);
159 }
160 
OnInvokeMethod(ClientConnection * client,const Frame & req_frame)161 void HostImpl::OnInvokeMethod(ClientConnection* client,
162                               const Frame& req_frame) {
163   const Frame::InvokeMethod& req = req_frame.msg_invoke_method();
164   Frame reply_frame;
165   RequestID request_id = req_frame.request_id();
166   reply_frame.set_request_id(request_id);
167   reply_frame.mutable_msg_invoke_method_reply()->set_success(false);
168   auto svc_it = services_.find(req.service_id());
169   if (svc_it == services_.end())
170     return SendFrame(client, reply_frame);  // |success| == false by default.
171 
172   Service* service = svc_it->second.instance.get();
173   const ServiceDescriptor& svc = service->GetDescriptor();
174   const auto& methods = svc.methods;
175   const uint32_t method_id = req.method_id();
176   if (method_id == 0 || method_id > methods.size())
177     return SendFrame(client, reply_frame);
178 
179   const ServiceDescriptor::Method& method = methods[method_id - 1];
180   std::unique_ptr<ProtoMessage> decoded_req_args(
181       method.request_proto_decoder(req.args_proto()));
182   if (!decoded_req_args)
183     return SendFrame(client, reply_frame);
184 
185   Deferred<ProtoMessage> deferred_reply;
186   base::WeakPtr<HostImpl> host_weak_ptr = weak_ptr_factory_.GetWeakPtr();
187   ClientID client_id = client->id;
188 
189   if (!req.drop_reply()) {
190     deferred_reply.Bind([host_weak_ptr, client_id,
191                          request_id](AsyncResult<ProtoMessage> reply) {
192       if (!host_weak_ptr)
193         return;  // The reply came too late, the HostImpl has gone.
194       host_weak_ptr->ReplyToMethodInvocation(client_id, request_id,
195                                              std::move(reply));
196     });
197   }
198 
199   service->client_info_ = ClientInfo(client->id, client->sock->peer_uid());
200   service->received_fd_ = &client->received_fd;
201   method.invoker(service, *decoded_req_args, std::move(deferred_reply));
202   service->received_fd_ = nullptr;
203   service->client_info_ = ClientInfo();
204 }
205 
ReplyToMethodInvocation(ClientID client_id,RequestID request_id,AsyncResult<ProtoMessage> reply)206 void HostImpl::ReplyToMethodInvocation(ClientID client_id,
207                                        RequestID request_id,
208                                        AsyncResult<ProtoMessage> reply) {
209   auto client_iter = clients_.find(client_id);
210   if (client_iter == clients_.end())
211     return;  // client has disconnected by the time we got the async reply.
212 
213   ClientConnection* client = client_iter->second.get();
214   Frame reply_frame;
215   reply_frame.set_request_id(request_id);
216 
217   // TODO(fmayer): add a test to guarantee that the reply is consumed within the
218   // same call stack and not kept around. ConsumerIPCService::OnTraceData()
219   // relies on this behavior.
220   auto* reply_frame_data = reply_frame.mutable_msg_invoke_method_reply();
221   reply_frame_data->set_has_more(reply.has_more());
222   if (reply.success()) {
223     std::string reply_proto;
224     if (reply->SerializeToString(&reply_proto)) {
225       reply_frame_data->set_reply_proto(reply_proto);
226       reply_frame_data->set_success(true);
227     }
228   }
229   SendFrame(client, reply_frame, reply.fd());
230 }
231 
232 // static
SendFrame(ClientConnection * client,const Frame & frame,int fd)233 void HostImpl::SendFrame(ClientConnection* client, const Frame& frame, int fd) {
234   std::string buf = BufferedFrameDeserializer::Serialize(frame);
235 
236   // TODO(primiano): this should do non-blocking I/O. But then what if the
237   // socket buffer is full? We might want to either drop the request or throttle
238   // the send and PostTask the reply later? Right now we are making Send()
239   // blocking as a workaround. Propagate bakpressure to the caller instead.
240   bool res = client->sock->Send(buf.data(), buf.size(), fd,
241                                 base::UnixSocket::BlockingMode::kBlocking);
242   PERFETTO_CHECK(res || !client->sock->is_connected());
243 }
244 
OnDisconnect(base::UnixSocket * sock)245 void HostImpl::OnDisconnect(base::UnixSocket* sock) {
246   PERFETTO_DCHECK_THREAD(thread_checker_);
247   auto it = clients_by_socket_.find(sock);
248   if (it == clients_by_socket_.end())
249     return;
250   ClientID client_id = it->second->id;
251   ClientInfo client_info(client_id, sock->peer_uid());
252   clients_by_socket_.erase(it);
253   PERFETTO_DCHECK(clients_.count(client_id));
254   clients_.erase(client_id);
255 
256   for (const auto& service_it : services_) {
257     Service& service = *service_it.second.instance;
258     service.client_info_ = client_info;
259     service.OnClientDisconnected();
260     service.client_info_ = ClientInfo();
261   }
262 }
263 
GetServiceByName(const std::string & name)264 const HostImpl::ExposedService* HostImpl::GetServiceByName(
265     const std::string& name) {
266   // This could be optimized by using another map<name,ServiceID>. However this
267   // is used only by Bind/ExposeService that are quite rare (once per client
268   // connection and once per service instance), not worth it.
269   for (const auto& it : services_) {
270     if (it.second.name == name)
271       return &it.second;
272   }
273   return nullptr;
274 }
275 
ExposedService(ServiceID id_,const std::string & name_,std::unique_ptr<Service> instance_)276 HostImpl::ExposedService::ExposedService(ServiceID id_,
277                                          const std::string& name_,
278                                          std::unique_ptr<Service> instance_)
279     : id(id_), name(name_), instance(std::move(instance_)) {}
280 
281 HostImpl::ExposedService::ExposedService(ExposedService&&) noexcept = default;
282 HostImpl::ExposedService& HostImpl::ExposedService::operator=(
283     HostImpl::ExposedService&&) = default;
284 HostImpl::ExposedService::~ExposedService() = default;
285 
286 HostImpl::ClientConnection::~ClientConnection() = default;
287 
288 }  // namespace ipc
289 }  // namespace perfetto
290