1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/public/cpp/bindings/interface_endpoint_client.h"
6 
7 #include <stdint.h>
8 
9 #include <utility>
10 
11 #include "base/bind.h"
12 #include "base/location.h"
13 #include "base/macros.h"
14 #include "base/memory/ptr_util.h"
15 #include "base/single_thread_task_runner.h"
16 #include "base/stl_util.h"
17 #include "mojo/public/cpp/bindings/associated_group.h"
18 #include "mojo/public/cpp/bindings/associated_group_controller.h"
19 #include "mojo/public/cpp/bindings/interface_endpoint_controller.h"
20 #include "mojo/public/cpp/bindings/sync_call_restrictions.h"
21 
22 namespace mojo {
23 
24 // ----------------------------------------------------------------------------
25 
26 namespace {
27 
DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient> & client,const std::string & message)28 void DCheckIfInvalid(const base::WeakPtr<InterfaceEndpointClient>& client,
29                    const std::string& message) {
30   bool is_valid = client && !client->encountered_error();
31   DCHECK(!is_valid) << message;
32 }
33 
// When receiving an incoming message which expects a response,
// InterfaceEndpointClient creates a ResponderThunk object and passes it to the
// incoming message receiver. When the receiver finishes processing the message,
// it can provide a response using this object.
38 class ResponderThunk : public MessageReceiverWithStatus {
39  public:
ResponderThunk(const base::WeakPtr<InterfaceEndpointClient> & endpoint_client,scoped_refptr<base::SingleThreadTaskRunner> runner)40   explicit ResponderThunk(
41       const base::WeakPtr<InterfaceEndpointClient>& endpoint_client,
42       scoped_refptr<base::SingleThreadTaskRunner> runner)
43       : endpoint_client_(endpoint_client),
44         accept_was_invoked_(false),
45         task_runner_(std::move(runner)) {}
~ResponderThunk()46   ~ResponderThunk() override {
47     if (!accept_was_invoked_) {
48       // The Mojo application handled a message that was expecting a response
49       // but did not send a response.
50       // We raise an error to signal the calling application that an error
51       // condition occurred. Without this the calling application would have no
52       // way of knowing it should stop waiting for a response.
53       if (task_runner_->RunsTasksOnCurrentThread()) {
54         // Please note that even if this code is run from a different task
55         // runner on the same thread as |task_runner_|, it is okay to directly
56         // call InterfaceEndpointClient::RaiseError(), because it will raise
57         // error from the correct task runner asynchronously.
58         if (endpoint_client_) {
59           endpoint_client_->RaiseError();
60         }
61       } else {
62         task_runner_->PostTask(
63             FROM_HERE,
64             base::Bind(&InterfaceEndpointClient::RaiseError, endpoint_client_));
65       }
66     }
67   }
68 
69   // MessageReceiver implementation:
Accept(Message * message)70   bool Accept(Message* message) override {
71     DCHECK(task_runner_->RunsTasksOnCurrentThread());
72     accept_was_invoked_ = true;
73     DCHECK(message->has_flag(Message::kFlagIsResponse));
74 
75     bool result = false;
76 
77     if (endpoint_client_)
78       result = endpoint_client_->Accept(message);
79 
80     return result;
81   }
82 
83   // MessageReceiverWithStatus implementation:
IsValid()84   bool IsValid() override {
85     DCHECK(task_runner_->RunsTasksOnCurrentThread());
86     return endpoint_client_ && !endpoint_client_->encountered_error();
87   }
88 
DCheckInvalid(const std::string & message)89   void DCheckInvalid(const std::string& message) override {
90     if (task_runner_->RunsTasksOnCurrentThread()) {
91       DCheckIfInvalid(endpoint_client_, message);
92     } else {
93       task_runner_->PostTask(
94           FROM_HERE, base::Bind(&DCheckIfInvalid, endpoint_client_, message));
95     }
96  }
97 
98  private:
99   base::WeakPtr<InterfaceEndpointClient> endpoint_client_;
100   bool accept_was_invoked_;
101   scoped_refptr<base::SingleThreadTaskRunner> task_runner_;
102 
103   DISALLOW_COPY_AND_ASSIGN(ResponderThunk);
104 };
105 
106 }  // namespace
107 
108 // ----------------------------------------------------------------------------
109 
SyncResponseInfo(bool * in_response_received)110 InterfaceEndpointClient::SyncResponseInfo::SyncResponseInfo(
111     bool* in_response_received)
112     : response_received(in_response_received) {}
113 
~SyncResponseInfo()114 InterfaceEndpointClient::SyncResponseInfo::~SyncResponseInfo() {}
115 
116 // ----------------------------------------------------------------------------
117 
HandleIncomingMessageThunk(InterfaceEndpointClient * owner)118 InterfaceEndpointClient::HandleIncomingMessageThunk::HandleIncomingMessageThunk(
119     InterfaceEndpointClient* owner)
120     : owner_(owner) {}
121 
122 InterfaceEndpointClient::HandleIncomingMessageThunk::
~HandleIncomingMessageThunk()123     ~HandleIncomingMessageThunk() {}
124 
Accept(Message * message)125 bool InterfaceEndpointClient::HandleIncomingMessageThunk::Accept(
126     Message* message) {
127   return owner_->HandleValidatedMessage(message);
128 }
129 
130 // ----------------------------------------------------------------------------
131 
InterfaceEndpointClient(ScopedInterfaceEndpointHandle handle,MessageReceiverWithResponderStatus * receiver,std::unique_ptr<MessageFilter> payload_validator,bool expect_sync_requests,scoped_refptr<base::SingleThreadTaskRunner> runner)132 InterfaceEndpointClient::InterfaceEndpointClient(
133     ScopedInterfaceEndpointHandle handle,
134     MessageReceiverWithResponderStatus* receiver,
135     std::unique_ptr<MessageFilter> payload_validator,
136     bool expect_sync_requests,
137     scoped_refptr<base::SingleThreadTaskRunner> runner)
138     : handle_(std::move(handle)),
139       incoming_receiver_(receiver),
140       payload_validator_(std::move(payload_validator)),
141       thunk_(this),
142       next_request_id_(1),
143       encountered_error_(false),
144       task_runner_(std::move(runner)),
145       weak_ptr_factory_(this) {
146   DCHECK(handle_.is_valid());
147   DCHECK(handle_.is_local());
148 
149   // TODO(yzshen): the way to use validator (or message filter in general)
150   // directly is a little awkward.
151   payload_validator_->set_sink(&thunk_);
152 
153   controller_ = handle_.group_controller()->AttachEndpointClient(
154       handle_, this, task_runner_);
155   if (expect_sync_requests)
156     controller_->AllowWokenUpBySyncWatchOnSameThread();
157 }
158 
~InterfaceEndpointClient()159 InterfaceEndpointClient::~InterfaceEndpointClient() {
160   DCHECK(thread_checker_.CalledOnValidThread());
161 
162   handle_.group_controller()->DetachEndpointClient(handle_);
163 }
164 
associated_group()165 AssociatedGroup* InterfaceEndpointClient::associated_group() {
166   if (!associated_group_)
167     associated_group_ = handle_.group_controller()->CreateAssociatedGroup();
168   return associated_group_.get();
169 }
170 
interface_id() const171 uint32_t InterfaceEndpointClient::interface_id() const {
172   DCHECK(thread_checker_.CalledOnValidThread());
173   return handle_.id();
174 }
175 
PassHandle()176 ScopedInterfaceEndpointHandle InterfaceEndpointClient::PassHandle() {
177   DCHECK(thread_checker_.CalledOnValidThread());
178   DCHECK(!has_pending_responders());
179 
180   if (!handle_.is_valid())
181     return ScopedInterfaceEndpointHandle();
182 
183   controller_ = nullptr;
184   handle_.group_controller()->DetachEndpointClient(handle_);
185 
186   return std::move(handle_);
187 }
188 
RaiseError()189 void InterfaceEndpointClient::RaiseError() {
190   DCHECK(thread_checker_.CalledOnValidThread());
191 
192   handle_.group_controller()->RaiseError();
193 }
194 
Accept(Message * message)195 bool InterfaceEndpointClient::Accept(Message* message) {
196   DCHECK(thread_checker_.CalledOnValidThread());
197   DCHECK(controller_);
198   DCHECK(!message->has_flag(Message::kFlagExpectsResponse));
199 
200   if (encountered_error_)
201     return false;
202 
203   return controller_->SendMessage(message);
204 }
205 
AcceptWithResponder(Message * message,MessageReceiver * responder)206 bool InterfaceEndpointClient::AcceptWithResponder(Message* message,
207                                                   MessageReceiver* responder) {
208   DCHECK(thread_checker_.CalledOnValidThread());
209   DCHECK(controller_);
210   DCHECK(message->has_flag(Message::kFlagExpectsResponse));
211 
212   if (encountered_error_)
213     return false;
214 
215   // Reserve 0 in case we want it to convey special meaning in the future.
216   uint64_t request_id = next_request_id_++;
217   if (request_id == 0)
218     request_id = next_request_id_++;
219 
220   message->set_request_id(request_id);
221 
222   bool is_sync = message->has_flag(Message::kFlagIsSync);
223   if (!controller_->SendMessage(message))
224     return false;
225 
226   if (!is_sync) {
227     // We assume ownership of |responder|.
228     async_responders_[request_id] = base::WrapUnique(responder);
229     return true;
230   }
231 
232   SyncCallRestrictions::AssertSyncCallAllowed();
233 
234   bool response_received = false;
235   std::unique_ptr<MessageReceiver> sync_responder(responder);
236   sync_responses_.insert(std::make_pair(
237       request_id, base::WrapUnique(new SyncResponseInfo(&response_received))));
238 
239   base::WeakPtr<InterfaceEndpointClient> weak_self =
240       weak_ptr_factory_.GetWeakPtr();
241   controller_->SyncWatch(&response_received);
242   // Make sure that this instance hasn't been destroyed.
243   if (weak_self) {
244     DCHECK(ContainsKey(sync_responses_, request_id));
245     auto iter = sync_responses_.find(request_id);
246     DCHECK_EQ(&response_received, iter->second->response_received);
247     if (response_received) {
248       std::unique_ptr<Message> response = std::move(iter->second->response);
249       ignore_result(sync_responder->Accept(response.get()));
250     }
251     sync_responses_.erase(iter);
252   }
253 
254   // Return true means that we take ownership of |responder|.
255   return true;
256 }
257 
HandleIncomingMessage(Message * message)258 bool InterfaceEndpointClient::HandleIncomingMessage(Message* message) {
259   DCHECK(thread_checker_.CalledOnValidThread());
260 
261   return payload_validator_->Accept(message);
262 }
263 
NotifyError()264 void InterfaceEndpointClient::NotifyError() {
265   DCHECK(thread_checker_.CalledOnValidThread());
266 
267   if (encountered_error_)
268     return;
269   encountered_error_ = true;
270   if (!error_handler_.is_null())
271     error_handler_.Run();
272 }
273 
HandleValidatedMessage(Message * message)274 bool InterfaceEndpointClient::HandleValidatedMessage(Message* message) {
275   DCHECK_EQ(handle_.id(), message->interface_id());
276 
277   if (message->has_flag(Message::kFlagExpectsResponse)) {
278     if (!incoming_receiver_)
279       return false;
280 
281     MessageReceiverWithStatus* responder =
282         new ResponderThunk(weak_ptr_factory_.GetWeakPtr(), task_runner_);
283     bool ok = incoming_receiver_->AcceptWithResponder(message, responder);
284     if (!ok)
285       delete responder;
286     return ok;
287   } else if (message->has_flag(Message::kFlagIsResponse)) {
288     uint64_t request_id = message->request_id();
289 
290     if (message->has_flag(Message::kFlagIsSync)) {
291       auto it = sync_responses_.find(request_id);
292       if (it == sync_responses_.end())
293         return false;
294       it->second->response.reset(new Message());
295       message->MoveTo(it->second->response.get());
296       *it->second->response_received = true;
297       return true;
298     }
299 
300     auto it = async_responders_.find(request_id);
301     if (it == async_responders_.end())
302       return false;
303     std::unique_ptr<MessageReceiver> responder = std::move(it->second);
304     async_responders_.erase(it);
305     return responder->Accept(message);
306   } else {
307     if (!incoming_receiver_)
308       return false;
309 
310     return incoming_receiver_->Accept(message);
311   }
312 }
313 
314 }  // namespace mojo
315