1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/public/cpp/bindings/interface_endpoint_client.h"
6 
7 #include <stdint.h>
8 
9 #include "base/bind.h"
10 #include "base/location.h"
11 #include "base/logging.h"
12 #include "base/macros.h"
13 #include "base/memory/ptr_util.h"
14 #include "base/sequenced_task_runner.h"
15 #include "base/stl_util.h"
16 #include "mojo/public/cpp/bindings/associated_group.h"
17 #include "mojo/public/cpp/bindings/associated_group_controller.h"
18 #include "mojo/public/cpp/bindings/interface_endpoint_controller.h"
19 #include "mojo/public/cpp/bindings/lib/task_runner_helper.h"
20 #include "mojo/public/cpp/bindings/lib/validation_util.h"
21 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
22 
23 namespace mojo {
24 
25 // ----------------------------------------------------------------------------
26 
27 namespace {
28 
DetermineIfEndpointIsConnected(const base::WeakPtr<InterfaceEndpointClient> & client,base::OnceCallback<void (bool)> callback)29 void DetermineIfEndpointIsConnected(
30     const base::WeakPtr<InterfaceEndpointClient>& client,
31     base::OnceCallback<void(bool)> callback) {
32   std::move(callback).Run(client && !client->encountered_error());
33 }
34 
35 // When receiving an incoming message which expects a response,
36 // InterfaceEndpointClient creates a ResponderThunk object and passes it to the
37 // incoming message receiver. When the receiver finishes processing the message,
38 // it can provide a response using this object.
39 class ResponderThunk : public MessageReceiverWithStatus {
40  public:
ResponderThunk(const base::WeakPtr<InterfaceEndpointClient> & endpoint_client,scoped_refptr<base::SequencedTaskRunner> runner)41   explicit ResponderThunk(
42       const base::WeakPtr<InterfaceEndpointClient>& endpoint_client,
43       scoped_refptr<base::SequencedTaskRunner> runner)
44       : endpoint_client_(endpoint_client),
45         accept_was_invoked_(false),
46         task_runner_(std::move(runner)) {}
~ResponderThunk()47   ~ResponderThunk() override {
48     if (!accept_was_invoked_) {
49       // The Service handled a message that was expecting a response
50       // but did not send a response.
51       // We raise an error to signal the calling application that an error
52       // condition occurred. Without this the calling application would have no
53       // way of knowing it should stop waiting for a response.
54       if (task_runner_->RunsTasksInCurrentSequence()) {
55         // Please note that even if this code is run from a different task
56         // runner on the same thread as |task_runner_|, it is okay to directly
57         // call InterfaceEndpointClient::RaiseError(), because it will raise
58         // error from the correct task runner asynchronously.
59         if (endpoint_client_) {
60           endpoint_client_->RaiseError();
61         }
62       } else {
63         task_runner_->PostTask(
64             FROM_HERE,
65             base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_));
66       }
67     }
68   }
69 
70   // MessageReceiver implementation:
PrefersSerializedMessages()71   bool PrefersSerializedMessages() override {
72     return endpoint_client_ && endpoint_client_->PrefersSerializedMessages();
73   }
74 
Accept(Message * message)75   bool Accept(Message* message) override {
76     DCHECK(task_runner_->RunsTasksInCurrentSequence());
77     accept_was_invoked_ = true;
78     DCHECK(message->has_flag(Message::kFlagIsResponse));
79 
80     bool result = false;
81 
82     if (endpoint_client_)
83       result = endpoint_client_->Accept(message);
84 
85     return result;
86   }
87 
88   // MessageReceiverWithStatus implementation:
IsConnected()89   bool IsConnected() override {
90     DCHECK(task_runner_->RunsTasksInCurrentSequence());
91     return endpoint_client_ && !endpoint_client_->encountered_error();
92   }
93 
IsConnectedAsync(base::OnceCallback<void (bool)> callback)94   void IsConnectedAsync(base::OnceCallback<void(bool)> callback) override {
95     if (task_runner_->RunsTasksInCurrentSequence()) {
96       DetermineIfEndpointIsConnected(endpoint_client_, std::move(callback));
97     } else {
98       task_runner_->PostTask(
99           FROM_HERE, base::BindOnce(&DetermineIfEndpointIsConnected,
100                                     endpoint_client_, std::move(callback)));
101     }
102   }
103 
104  private:
105   base::WeakPtr<InterfaceEndpointClient> endpoint_client_;
106   bool accept_was_invoked_;
107   scoped_refptr<base::SequencedTaskRunner> task_runner_;
108 
109   DISALLOW_COPY_AND_ASSIGN(ResponderThunk);
110 };
111 
112 }  // namespace
113 
114 // ----------------------------------------------------------------------------
115 
SyncResponseInfo(bool * in_response_received)116 InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo(
117     bool* in_response_received)
118     : response_received(in_response_received) {}
119 
~SyncResponseInfo()120 InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {}
121 
122 // ----------------------------------------------------------------------------
123 
HandleIncomingMessageThunk(InterfaceEndpointClient * owner)124 InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk(
125     InterfaceEndpointClient* owner)
126     : owner_(owner) {}
127 
128 InterfaceEndpointClient::HandleIncomingMessageThunk::
~HandleIncomingMessageThunk()129     ~HandleIncomingMessageThunk() {}
130 
Accept(Message * message)131 bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept(
132     Message* message) {
133   return owner_->HandleValidatedMessage(message);
134 }
135 
136 // ----------------------------------------------------------------------------
137 
InterfaceEndpointClient(ScopedInterfaceEndpointHandle handle,MessageReceiverWithResponderStatus * receiver,std::unique_ptr<MessageReceiver> payload_validator,bool expect_sync_requests,scoped_refptr<base::SequencedTaskRunner> runner,uint32_t interface_version)138 InterfaceEndpointClient::InterfaceEndpointClient(
139     ScopedInterfaceEndpointHandle handle,
140     MessageReceiverWithResponderStatus* receiver,
141     std::unique_ptr<MessageReceiver> payload_validator,
142     bool expect_sync_requests,
143     scoped_refptr<base::SequencedTaskRunner> runner,
144     uint32_t interface_version)
145     : expect_sync_requests_(expect_sync_requests),
146       handle_(std::move(handle)),
147       incoming_receiver_(receiver),
148       thunk_(this),
149       filters_(&thunk_),
150       task_runner_(std::move(runner)),
151       control_message_proxy_(this),
152       control_message_handler_(interface_version),
153       weak_ptr_factory_(this) {
154   DCHECK(handle_.is_valid());
155 
156   // TODO(yzshen): the way to use validator (or message filter in general)
157   // directly is a little awkward.
158   if (payload_validator)
159     filters_.Append(std::move(payload_validator));
160 
161   if (handle_.pending_association()) {
162     handle_.SetAssociationEventHandler(base::Bind(
163         &InterfaceEndpointClient::OnAssociationEvent, base::Unretained(this)));
164   } else {
165     InitControllerIfNecessary();
166   }
167 }
168 
~InterfaceEndpointClient()169 InterfaceEndpointClient::~InterfaceEndpointClient() {
170   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
171   if (controller_)
172     handle_.group_controller()->DetachEndpointClient(handle_);
173 }
174 
associated_group()175 AssociatedGroup* InterfaceEndpointClient::associated_group() {
176   if (!associated_group_)
177     associated_group_ = std::make_unique<AssociatedGroup>(handle_);
178   return associated_group_.get();
179 }
180 
PassHandle()181 ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() {
182   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
183   DCHECK(!has_pending_responders());
184 
185   if (!handle_.is_valid())
186     return ScopedInterfaceEndpointHandle();
187 
188   handle_.SetAssociationEventHandler(
189       ScopedInterfaceEndpointHandle::AssociationEventCallback());
190 
191   if (controller_) {
192     controller_ = nullptr;
193     handle_.group_controller()->DetachEndpointClient(handle_);
194   }
195 
196   return std::move(handle_);
197 }
198 
AddFilter(std::unique_ptr<MessageReceiver> filter)199 void InterfaceEndpointClient::AddFilter(
200     std::unique_ptr<MessageReceiver> filter) {
201   filters_.Append(std::move(filter));
202 }
203 
RaiseError()204 void InterfaceEndpointClient::RaiseError() {
205   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
206 
207   if (!handle_.pending_association())
208     handle_.group_controller()->RaiseError();
209 }
210 
CloseWithReason(uint32_t custom_reason,const std::string & description)211 void InterfaceEndpointClient::CloseWithReason(uint32_t custom_reason,
212                                               const std::string& description) {
213   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
214 
215   auto handle = PassHandle();
216   handle.ResetWithReason(custom_reason, description);
217 }
218 
PrefersSerializedMessages()219 bool InterfaceEndpointClient::PrefersSerializedMessages() {
220   auto* controller = handle_.group_controller();
221   return controller && controller->PrefersSerializedMessages();
222 }
223 
Accept(Message * message)224 bool InterfaceEndpointClient::Accept(Message* message) {
225   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
226   DCHECK(!message->has_flag(Message::kFlagExpectsResponse));
227   DCHECK(!handle_.pending_association());
228 
229   // This has to been done even if connection error has occurred. For example,
230   // the message contains a pending associated request. The user may try to use
231   // the corresponding associated interface pointer after sending this message.
232   // That associated interface pointer has to join an associated group in order
233   // to work properly.
234   if (!message->associated_endpoint_handles()->empty())
235     message->SerializeAssociatedEndpointHandles(handle_.group_controller());
236 
237   if (encountered_error_)
238     return false;
239 
240   InitControllerIfNecessary();
241 
242   return controller_->SendMessage(message);
243 }
244 
AcceptWithResponder(Message * message,std::unique_ptr<MessageReceiver> responder)245 bool InterfaceEndpointClient::AcceptWithResponder(
246     Message* message,
247     std::unique_ptr<MessageReceiver> responder) {
248   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
249   DCHECK(message->has_flag(Message::kFlagExpectsResponse));
250   DCHECK(!handle_.pending_association());
251 
252   // Please see comments in Accept().
253   if (!message->associated_endpoint_handles()->empty())
254     message->SerializeAssociatedEndpointHandles(handle_.group_controller());
255 
256   if (encountered_error_)
257     return false;
258 
259   InitControllerIfNecessary();
260 
261   // Reserve 0 in case we want it to convey special meaning in the future.
262   uint64_t request_id = next_request_id_++;
263   if (request_id == 0)
264     request_id = next_request_id_++;
265 
266   message->set_request_id(request_id);
267 
268   bool is_sync = message->has_flag(Message::kFlagIsSync);
269   if (!controller_->SendMessage(message))
270     return false;
271 
272   if (!is_sync) {
273     async_responders_[request_id] = std::move(responder);
274     return true;
275   }
276 
277   SyncCallRestrictions::AssertSyncCallAllowed();
278 
279   bool response_received = false;
280   sync_responses_.insert(std::make_pair(
281       request_id, std::make_unique<SyncResponseInfo>(&response_received)));
282 
283   base::WeakPtr<InterfaceEndpointClient> weak_self =
284       weak_ptr_factory_.GetWeakPtr();
285   controller_->SyncWatch(&response_received);
286   // Make sure that this instance hasn't been destroyed.
287   if (weak_self) {
288     DCHECK(base::ContainsKey(sync_responses_, request_id));
289     auto iter = sync_responses_.find(request_id);
290     DCHECK_EQ(&response_received, iter->second->response_received);
291     if (response_received) {
292       ignore_result(responder->Accept(&iter->second->response));
293     } else {
294       DVLOG(1) << "Mojo sync call returns without receiving a response. "
295                << "Typcially it is because the interface has been "
296                << "disconnected.";
297     }
298     sync_responses_.erase(iter);
299   }
300 
301   return true;
302 }
303 
HandleIncomingMessage(Message * message)304 bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) {
305   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
306   return filters_.Accept(message);
307 }
308 
NotifyError(const base::Optional<DisconnectReason> & reason)309 void InterfaceEndpointClient::NotifyError(
310     const base::Optional<DisconnectReason>& reason) {
311   DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_);
312 
313   if (encountered_error_)
314     return;
315   encountered_error_ = true;
316 
317   // Response callbacks may hold on to resource, and there's no need to keep
318   // them alive any longer. Note that it's allowed that a pending response
319   // callback may own this endpoint, so we simply move the responders onto the
320   // stack here and let them be destroyed when the stack unwinds.
321   AsyncResponderMap responders = std::move(async_responders_);
322 
323   control_message_proxy_.OnConnectionError();
324 
325   if (error_handler_) {
326     std::move(error_handler_).Run();
327   } else if (error_with_reason_handler_) {
328     if (reason) {
329       std::move(error_with_reason_handler_)
330           .Run(reason->custom_reason, reason->description);
331     } else {
332       std::move(error_with_reason_handler_).Run(0, std::string());
333     }
334   }
335 }
336 
QueryVersion(const base::Callback<void (uint32_t)> & callback)337 void InterfaceEndpointClient::QueryVersion(
338     const base::Callback<void(uint32_t)>& callback) {
339   control_message_proxy_.QueryVersion(callback);
340 }
341 
RequireVersion(uint32_t version)342 void InterfaceEndpointClient::RequireVersion(uint32_t version) {
343   control_message_proxy_.RequireVersion(version);
344 }
345 
FlushForTesting()346 void InterfaceEndpointClient::FlushForTesting() {
347   control_message_proxy_.FlushForTesting();
348 }
349 
InitControllerIfNecessary()350 void InterfaceEndpointClient::InitControllerIfNecessary() {
351   if (controller_ || handle_.pending_association())
352     return;
353 
354   controller_ = handle_.group_controller()->AttachEndpointClient(handle_, this,
355                                                                  task_runner_);
356   if (expect_sync_requests_)
357     controller_->AllowWokenUpBySyncWatchOnSameThread();
358 }
359 
OnAssociationEvent(ScopedInterfaceEndpointHandle::AssociationEvent event)360 void InterfaceEndpointClient::OnAssociationEvent(
361     ScopedInterfaceEndpointHandle::AssociationEvent event) {
362   if (event == ScopedInterfaceEndpointHandle::ASSOCIATED) {
363     InitControllerIfNecessary();
364   } else if (event ==
365              ScopedInterfaceEndpointHandle::PEER_CLOSED_BEFORE_ASSOCIATION) {
366     task_runner_->PostTask(FROM_HERE,
367                            base::Bind(&InterfaceEndpointClient::NotifyError,
368                                       weak_ptr_factory_.GetWeakPtr(),
369                                       handle_.disconnect_reason()));
370   }
371 }
372 
HandleValidatedMessage(Message * message)373 bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) {
374   DCHECK_EQ(handle_.id(), message->interface_id());
375 
376   if (encountered_error_) {
377     // This message is received after error has been encountered. For associated
378     // interfaces, this means the remote side sends a
379     // PeerAssociatedEndpointClosed event but continues to send more messages
380     // for the same interface. Close the pipe because this shouldn't happen.
381     DVLOG(1) << "A message is received for an interface after it has been "
382              << "disconnected. Closing the pipe.";
383     return false;
384   }
385 
386   if (message->has_flag(Message::kFlagExpectsResponse)) {
387     std::unique_ptr<MessageReceiverWithStatus> responder =
388         std::make_unique<ResponderThunk>(weak_ptr_factory_.GetWeakPtr(),
389                                          task_runner_);
390     if (mojo::internal::ControlMessageHandler::IsControlMessage(message)) {
391       return control_message_handler_.AcceptWithResponder(message,
392                                                           std::move(responder));
393     } else {
394       return incoming_receiver_->AcceptWithResponder(message,
395                                                      std::move(responder));
396     }
397   } else if (message->has_flag(Message::kFlagIsResponse)) {
398     uint64_t request_id = message->request_id();
399 
400     if (message->has_flag(Message::kFlagIsSync)) {
401       auto it = sync_responses_.find(request_id);
402       if (it == sync_responses_.end())
403         return false;
404       it->second->response = std::move(*message);
405       *it->second->response_received = true;
406       return true;
407     }
408 
409     auto it = async_responders_.find(request_id);
410     if (it == async_responders_.end())
411       return false;
412     std::unique_ptr<MessageReceiver> responder = std::move(it->second);
413     async_responders_.erase(it);
414     return responder->Accept(message);
415   } else {
416     if (mojo::internal::ControlMessageHandler::IsControlMessage(message))
417       return control_message_handler_.Accept(message);
418 
419     return incoming_receiver_->Accept(message);
420   }
421 }
422 
423 }  // namespace mojo
424