1 // Copyright 2015 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/public/cpp/bindings/lib/control_message_proxy.h"
6 
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <utility>
10 
11 #include "base/bind.h"
12 #include "base/callback_helpers.h"
13 #include "base/macros.h"
14 #include "base/run_loop.h"
15 #include "mojo/public/cpp/bindings/lib/serialization.h"
16 #include "mojo/public/cpp/bindings/lib/validation_util.h"
17 #include "mojo/public/cpp/bindings/message.h"
18 #include "mojo/public/interfaces/bindings/interface_control_messages.mojom.h"
19 
20 namespace mojo {
21 namespace internal {
22 
23 namespace {
24 
ValidateControlResponse(Message * message)25 bool ValidateControlResponse(Message* message) {
26   ValidationContext validation_context(message->payload(),
27                                        message->payload_num_bytes(), 0, 0,
28                                        message, "ControlResponseValidator");
29   if (!ValidateMessageIsResponse(message, &validation_context))
30     return false;
31 
32   switch (message->header()->name) {
33     case interface_control::kRunMessageId:
34       return ValidateMessagePayload<
35           interface_control::internal::RunResponseMessageParams_Data>(
36           message, &validation_context);
37   }
38   return false;
39 }
40 
41 using RunCallback =
42     base::Callback<void(interface_control::RunResponseMessageParamsPtr)>;
43 
44 class RunResponseForwardToCallback : public MessageReceiver {
45  public:
RunResponseForwardToCallback(const RunCallback & callback)46   explicit RunResponseForwardToCallback(const RunCallback& callback)
47       : callback_(callback) {}
48   bool Accept(Message* message) override;
49 
50  private:
51   RunCallback callback_;
52   DISALLOW_COPY_AND_ASSIGN(RunResponseForwardToCallback);
53 };
54 
Accept(Message * message)55 bool RunResponseForwardToCallback::Accept(Message* message) {
56   if (!ValidateControlResponse(message))
57     return false;
58 
59   interface_control::internal::RunResponseMessageParams_Data* params =
60       reinterpret_cast<
61           interface_control::internal::RunResponseMessageParams_Data*>(
62           message->mutable_payload());
63   interface_control::RunResponseMessageParamsPtr params_ptr;
64   SerializationContext context;
65   Deserialize<interface_control::RunResponseMessageParamsDataView>(
66       params, &params_ptr, &context);
67 
68   callback_.Run(std::move(params_ptr));
69   return true;
70 }
71 
SendRunMessage(MessageReceiverWithResponder * receiver,interface_control::RunInputPtr input_ptr,const RunCallback & callback)72 void SendRunMessage(MessageReceiverWithResponder* receiver,
73                     interface_control::RunInputPtr input_ptr,
74                     const RunCallback& callback) {
75   auto params_ptr = interface_control::RunMessageParams::New();
76   params_ptr->input = std::move(input_ptr);
77   Message message(interface_control::kRunMessageId,
78                   Message::kFlagExpectsResponse, 0, 0, nullptr);
79   SerializationContext context;
80   interface_control::internal::RunMessageParams_Data::BufferWriter params;
81   Serialize<interface_control::RunMessageParamsDataView>(
82       params_ptr, message.payload_buffer(), &params, &context);
83   std::unique_ptr<MessageReceiver> responder =
84       std::make_unique<RunResponseForwardToCallback>(callback);
85   ignore_result(receiver->AcceptWithResponder(&message, std::move(responder)));
86 }
87 
ConstructRunOrClosePipeMessage(interface_control::RunOrClosePipeInputPtr input_ptr)88 Message ConstructRunOrClosePipeMessage(
89     interface_control::RunOrClosePipeInputPtr input_ptr) {
90   auto params_ptr = interface_control::RunOrClosePipeMessageParams::New();
91   params_ptr->input = std::move(input_ptr);
92   Message message(interface_control::kRunOrClosePipeMessageId, 0, 0, 0,
93                   nullptr);
94   SerializationContext context;
95   interface_control::internal::RunOrClosePipeMessageParams_Data::BufferWriter
96       params;
97   Serialize<interface_control::RunOrClosePipeMessageParamsDataView>(
98       params_ptr, message.payload_buffer(), &params, &context);
99   return message;
100 }
101 
SendRunOrClosePipeMessage(MessageReceiverWithResponder * receiver,interface_control::RunOrClosePipeInputPtr input_ptr)102 void SendRunOrClosePipeMessage(
103     MessageReceiverWithResponder* receiver,
104     interface_control::RunOrClosePipeInputPtr input_ptr) {
105   Message message(ConstructRunOrClosePipeMessage(std::move(input_ptr)));
106   ignore_result(receiver->Accept(&message));
107 }
108 
RunVersionCallback(const base::Callback<void (uint32_t)> & callback,interface_control::RunResponseMessageParamsPtr run_response)109 void RunVersionCallback(
110     const base::Callback<void(uint32_t)>& callback,
111     interface_control::RunResponseMessageParamsPtr run_response) {
112   uint32_t version = 0u;
113   if (run_response->output && run_response->output->is_query_version_result())
114     version = run_response->output->get_query_version_result()->version;
115   callback.Run(version);
116 }
117 
RunClosure(const base::Closure & callback,interface_control::RunResponseMessageParamsPtr run_response)118 void RunClosure(const base::Closure& callback,
119                 interface_control::RunResponseMessageParamsPtr run_response) {
120   callback.Run();
121 }
122 
123 }  // namespace
124 
ControlMessageProxy(MessageReceiverWithResponder * receiver)125 ControlMessageProxy::ControlMessageProxy(MessageReceiverWithResponder* receiver)
126     : receiver_(receiver) {
127 }
128 
129 ControlMessageProxy::~ControlMessageProxy() = default;
130 
QueryVersion(const base::Callback<void (uint32_t)> & callback)131 void ControlMessageProxy::QueryVersion(
132     const base::Callback<void(uint32_t)>& callback) {
133   auto input_ptr = interface_control::RunInput::New();
134   input_ptr->set_query_version(interface_control::QueryVersion::New());
135   SendRunMessage(receiver_, std::move(input_ptr),
136                  base::Bind(&RunVersionCallback, callback));
137 }
138 
RequireVersion(uint32_t version)139 void ControlMessageProxy::RequireVersion(uint32_t version) {
140   auto require_version = interface_control::RequireVersion::New();
141   require_version->version = version;
142   auto input_ptr = interface_control::RunOrClosePipeInput::New();
143   input_ptr->set_require_version(std::move(require_version));
144   SendRunOrClosePipeMessage(receiver_, std::move(input_ptr));
145 }
146 
FlushForTesting()147 void ControlMessageProxy::FlushForTesting() {
148   if (encountered_error_)
149     return;
150 
151   auto input_ptr = interface_control::RunInput::New();
152   input_ptr->set_flush_for_testing(interface_control::FlushForTesting::New());
153   base::RunLoop run_loop(base::RunLoop::Type::kNestableTasksAllowed);
154   run_loop_quit_closure_ = run_loop.QuitClosure();
155   SendRunMessage(
156       receiver_, std::move(input_ptr),
157       base::Bind(&RunClosure,
158                  base::Bind(&ControlMessageProxy::RunFlushForTestingClosure,
159                             base::Unretained(this))));
160   run_loop.Run();
161 }
162 
RunFlushForTestingClosure()163 void ControlMessageProxy::RunFlushForTestingClosure() {
164   DCHECK(!run_loop_quit_closure_.is_null());
165   base::ResetAndReturn(&run_loop_quit_closure_).Run();
166 }
167 
OnConnectionError()168 void ControlMessageProxy::OnConnectionError() {
169   encountered_error_ = true;
170   if (!run_loop_quit_closure_.is_null())
171     RunFlushForTestingClosure();
172 }
173 
174 }  // namespace internal
175 }  // namespace mojo
176