1 // Copyright 2014 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/public/cpp/bindings/lib/router.h"
6 
7 #include <stdint.h>
8 
9 #include <utility>
10 
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/logging.h"
14 #include "base/memory/ptr_util.h"
15 #include "base/stl_util.h"
16 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
17 
18 namespace mojo {
19 namespace internal {
20 
21 // ----------------------------------------------------------------------------
22 
23 namespace {
24 
DCheckIfInvalid(const base::WeakPtr<Router> & router,const std::string & message)25 void DCheckIfInvalid(const base::WeakPtr<Router>& router,
26                    const std::string& message) {
27   bool is_valid = router && !router->encountered_error() && router->is_valid();
28   DCHECK(!is_valid) << message;
29 }
30 
31 class ResponderThunk : public MessageReceiverWithStatus {
32  public:
ResponderThunk(const base::WeakPtr<Router> & router,scoped_refptr<base::SingleThreadTaskRunner> runner)33   explicit ResponderThunk(const base::WeakPtr<Router>& router,
34                           scoped_refptr<base::SingleThreadTaskRunner> runner)
35       : router_(router),
36         accept_was_invoked_(false),
37         task_runner_(std::move(runner)) {}
~ResponderThunk()38   ~ResponderThunk() override {
39     if (!accept_was_invoked_) {
40       // The Mojo application handled a message that was expecting a response
41       // but did not send a response.
42       // We raise an error to signal the calling application that an error
43       // condition occurred. Without this the calling application would have no
44       // way of knowing it should stop waiting for a response.
45       if (task_runner_->RunsTasksOnCurrentThread()) {
46         // Please note that even if this code is run from a different task
47         // runner on the same thread as |task_runner_|, it is okay to directly
48         // call Router::RaiseError(), because it will raise error from the
49         // correct task runner asynchronously.
50         if (router_)
51           router_->RaiseError();
52       } else {
53         task_runner_->PostTask(FROM_HERE,
54                                base::Bind(&Router::RaiseError, router_));
55       }
56     }
57   }
58 
59   // MessageReceiver implementation:
Accept(Message * message)60   bool Accept(Message* message) override {
61     DCHECK(task_runner_->RunsTasksOnCurrentThread());
62     accept_was_invoked_ = true;
63     DCHECK(message->has_flag(Message::kFlagIsResponse));
64 
65     bool result = false;
66 
67     if (router_)
68       result = router_->Accept(message);
69 
70     return result;
71   }
72 
73   // MessageReceiverWithStatus implementation:
IsValid()74   bool IsValid() override {
75     DCHECK(task_runner_->RunsTasksOnCurrentThread());
76     return router_ && !router_->encountered_error() && router_->is_valid();
77   }
78 
DCheckInvalid(const std::string & message)79   void DCheckInvalid(const std::string& message) override {
80     if (task_runner_->RunsTasksOnCurrentThread()) {
81       DCheckIfInvalid(router_, message);
82     } else {
83       task_runner_->PostTask(FROM_HERE,
84                              base::Bind(&DCheckIfInvalid, router_, message));
85     }
86   }
87 
88  private:
89   base::WeakPtr<Router> router_;
90   bool accept_was_invoked_;
91   scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
92 };
93 
94 }  // namespace
95 
96 // ----------------------------------------------------------------------------
97 
SyncResponseInfo(bool * in_response_received)98 Router::SyncResponseInfo::SyncResponseInfo(bool* in_response_received)
99     : response_received(in_response_received) {}
100 
~SyncResponseInfo()101 Router::SyncResponseInfo::~SyncResponseInfo() {}
102 
103 // ----------------------------------------------------------------------------
104 
HandleIncomingMessageThunk(Router * router)105 Router::HandleIncomingMessageThunk::HandleIncomingMessageThunk(Router* router)
106     : router_(router) {
107 }
108 
~HandleIncomingMessageThunk()109 Router::HandleIncomingMessageThunk::~HandleIncomingMessageThunk() {
110 }
111 
Accept(Message * message)112 bool Router::HandleIncomingMessageThunk::Accept(Message* message) {
113   return router_->HandleIncomingMessage(message);
114 }
115 
116 // ----------------------------------------------------------------------------
117 
Router(ScopedMessagePipeHandle message_pipe,FilterChain filters,bool expects_sync_requests,scoped_refptr<base::SingleThreadTaskRunner> runner)118 Router::Router(ScopedMessagePipeHandle message_pipe,
119                FilterChain filters,
120                bool expects_sync_requests,
121                scoped_refptr<base::SingleThreadTaskRunner> runner)
122     : thunk_(this),
123       filters_(std::move(filters)),
124       connector_(std::move(message_pipe),
125                  Connector::SINGLE_THREADED_SEND,
126                  std::move(runner)),
127       incoming_receiver_(nullptr),
128       next_request_id_(0),
129       testing_mode_(false),
130       pending_task_for_messages_(false),
131       encountered_error_(false),
132       weak_factory_(this) {
133   filters_.SetSink(&thunk_);
134   if (expects_sync_requests)
135     connector_.AllowWokenUpBySyncWatchOnSameThread();
136   connector_.set_incoming_receiver(filters_.GetHead());
137   connector_.set_connection_error_handler(
138       base::Bind(&Router::OnConnectionError, base::Unretained(this)));
139 }
140 
~Router()141 Router::~Router() {}
142 
Accept(Message * message)143 bool Router::Accept(Message* message) {
144   DCHECK(thread_checker_.CalledOnValidThread());
145   DCHECK(!message->has_flag(Message::kFlagExpectsResponse));
146   return connector_.Accept(message);
147 }
148 
AcceptWithResponder(Message * message,MessageReceiver * responder)149 bool Router::AcceptWithResponder(Message* message, MessageReceiver* responder) {
150   DCHECK(thread_checker_.CalledOnValidThread());
151   DCHECK(message->has_flag(Message::kFlagExpectsResponse));
152 
153   // Reserve 0 in case we want it to convey special meaning in the future.
154   uint64_t request_id = next_request_id_++;
155   if (request_id == 0)
156     request_id = next_request_id_++;
157 
158   bool is_sync = message->has_flag(Message::kFlagIsSync);
159   message->set_request_id(request_id);
160   if (!connector_.Accept(message))
161     return false;
162 
163   if (!is_sync) {
164     // We assume ownership of |responder|.
165     async_responders_[request_id] = base::WrapUnique(responder);
166     return true;
167   }
168 
169   SyncCallRestrictions::AssertSyncCallAllowed();
170 
171   bool response_received = false;
172   std::unique_ptr<MessageReceiver> sync_responder(responder);
173   sync_responses_.insert(std::make_pair(
174       request_id, base::WrapUnique(new SyncResponseInfo(&response_received))));
175 
176   base::WeakPtr<Router> weak_self = weak_factory_.GetWeakPtr();
177   connector_.SyncWatch(&response_received);
178   // Make sure that this instance hasn't been destroyed.
179   if (weak_self) {
180     DCHECK(ContainsKey(sync_responses_, request_id));
181     auto iter = sync_responses_.find(request_id);
182     DCHECK_EQ(&response_received, iter->second->response_received);
183     if (response_received) {
184       std::unique_ptr<Message> response = std::move(iter->second->response);
185       ignore_result(sync_responder->Accept(response.get()));
186     }
187     sync_responses_.erase(iter);
188   }
189 
190   // Return true means that we take ownership of |responder|.
191   return true;
192 }
193 
EnableTestingMode()194 void Router::EnableTestingMode() {
195   DCHECK(thread_checker_.CalledOnValidThread());
196   testing_mode_ = true;
197   connector_.set_enforce_errors_from_incoming_receiver(false);
198 }
199 
HandleIncomingMessage(Message * message)200 bool Router::HandleIncomingMessage(Message* message) {
201   DCHECK(thread_checker_.CalledOnValidThread());
202 
203   const bool during_sync_call =
204       connector_.during_sync_handle_watcher_callback();
205   if (!message->has_flag(Message::kFlagIsSync) &&
206       (during_sync_call || !pending_messages_.empty())) {
207     std::unique_ptr<Message> pending_message(new Message);
208     message->MoveTo(pending_message.get());
209     pending_messages_.push(std::move(pending_message));
210 
211     if (!pending_task_for_messages_) {
212       pending_task_for_messages_ = true;
213       connector_.task_runner()->PostTask(
214           FROM_HERE, base::Bind(&Router::HandleQueuedMessages,
215                                 weak_factory_.GetWeakPtr()));
216     }
217 
218     return true;
219   }
220 
221   return HandleMessageInternal(message);
222 }
223 
HandleQueuedMessages()224 void Router::HandleQueuedMessages() {
225   DCHECK(thread_checker_.CalledOnValidThread());
226   DCHECK(pending_task_for_messages_);
227 
228   base::WeakPtr<Router> weak_self = weak_factory_.GetWeakPtr();
229   while (!pending_messages_.empty()) {
230     std::unique_ptr<Message> message(std::move(pending_messages_.front()));
231     pending_messages_.pop();
232 
233     bool result = HandleMessageInternal(message.get());
234     if (!weak_self)
235       return;
236 
237     if (!result && !testing_mode_) {
238       connector_.RaiseError();
239       break;
240     }
241   }
242 
243   pending_task_for_messages_ = false;
244 
245   // We may have already seen a connection error from the connector, but
246   // haven't notified the user because we want to process all the queued
247   // messages first. We should do it now.
248   if (connector_.encountered_error() && !encountered_error_)
249     OnConnectionError();
250 }
251 
HandleMessageInternal(Message * message)252 bool Router::HandleMessageInternal(Message* message) {
253   if (message->has_flag(Message::kFlagExpectsResponse)) {
254     if (!incoming_receiver_)
255       return false;
256 
257     MessageReceiverWithStatus* responder = new ResponderThunk(
258         weak_factory_.GetWeakPtr(), connector_.task_runner());
259     bool ok = incoming_receiver_->AcceptWithResponder(message, responder);
260     if (!ok)
261       delete responder;
262     return ok;
263 
264   } else if (message->has_flag(Message::kFlagIsResponse)) {
265     uint64_t request_id = message->request_id();
266 
267     if (message->has_flag(Message::kFlagIsSync)) {
268       auto it = sync_responses_.find(request_id);
269       if (it == sync_responses_.end()) {
270         DCHECK(testing_mode_);
271         return false;
272       }
273       it->second->response.reset(new Message());
274       message->MoveTo(it->second->response.get());
275       *it->second->response_received = true;
276       return true;
277     }
278 
279     auto it = async_responders_.find(request_id);
280     if (it == async_responders_.end()) {
281       DCHECK(testing_mode_);
282       return false;
283     }
284     std::unique_ptr<MessageReceiver> responder = std::move(it->second);
285     async_responders_.erase(it);
286     return responder->Accept(message);
287   } else {
288     if (!incoming_receiver_)
289       return false;
290 
291     return incoming_receiver_->Accept(message);
292   }
293 }
294 
OnConnectionError()295 void Router::OnConnectionError() {
296   if (encountered_error_)
297     return;
298 
299   if (!pending_messages_.empty()) {
300     // After all the pending messages are processed, we will check whether an
301     // error has been encountered and run the user's connection error handler
302     // if necessary.
303     DCHECK(pending_task_for_messages_);
304     return;
305   }
306 
307   if (connector_.during_sync_handle_watcher_callback()) {
308     // We don't want the error handler to reenter an ongoing sync call.
309     connector_.task_runner()->PostTask(
310         FROM_HERE,
311         base::Bind(&Router::OnConnectionError, weak_factory_.GetWeakPtr()));
312     return;
313   }
314 
315   encountered_error_ = true;
316   if (!error_handler_.is_null())
317     error_handler_.Run();
318 }
319 
320 // ----------------------------------------------------------------------------
321 
322 }  // namespace internal
323 }  // namespace mojo
324