1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/public/cpp/bindings/message.h"
6 
7 #include <stddef.h>
8 #include <stdint.h>
9 #include <stdlib.h>
10 
11 #include <algorithm>
12 #include <utility>
13 
14 #include "base/bind.h"
15 #include "base/lazy_instance.h"
16 #include "base/logging.h"
17 #include "base/numerics/safe_math.h"
18 #include "base/strings/stringprintf.h"
19 #include "base/threading/thread_local.h"
20 #include "mojo/public/cpp/bindings/associated_group_controller.h"
21 #include "mojo/public/cpp/bindings/lib/array_internal.h"
22 #include "mojo/public/cpp/bindings/lib/unserialized_message_context.h"
23 
24 namespace mojo {
25 
26 namespace {
27 
28 base::LazyInstance<base::ThreadLocalPointer<internal::MessageDispatchContext>>::
29     Leaky g_tls_message_dispatch_context = LAZY_INSTANCE_INITIALIZER;
30 
31 base::LazyInstance<base::ThreadLocalPointer<SyncMessageResponseContext>>::Leaky
32     g_tls_sync_response_context = LAZY_INSTANCE_INITIALIZER;
33 
DoNotifyBadMessage(Message message,const std::string & error)34 void DoNotifyBadMessage(Message message, const std::string& error) {
35   message.NotifyBadMessage(error);
36 }
37 
38 template <typename HeaderType>
AllocateHeaderFromBuffer(internal::Buffer * buffer,HeaderType ** header)39 void AllocateHeaderFromBuffer(internal::Buffer* buffer, HeaderType** header) {
40   *header = buffer->AllocateAndGet<HeaderType>();
41   (*header)->num_bytes = sizeof(HeaderType);
42 }
43 
WriteMessageHeader(uint32_t name,uint32_t flags,size_t payload_interface_id_count,internal::Buffer * payload_buffer)44 void WriteMessageHeader(uint32_t name,
45                         uint32_t flags,
46                         size_t payload_interface_id_count,
47                         internal::Buffer* payload_buffer) {
48   if (payload_interface_id_count > 0) {
49     // Version 2
50     internal::MessageHeaderV2* header;
51     AllocateHeaderFromBuffer(payload_buffer, &header);
52     header->version = 2;
53     header->name = name;
54     header->flags = flags;
55     // The payload immediately follows the header.
56     header->payload.Set(header + 1);
57   } else if (flags &
58              (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) {
59     // Version 1
60     internal::MessageHeaderV1* header;
61     AllocateHeaderFromBuffer(payload_buffer, &header);
62     header->version = 1;
63     header->name = name;
64     header->flags = flags;
65   } else {
66     internal::MessageHeader* header;
67     AllocateHeaderFromBuffer(payload_buffer, &header);
68     header->version = 0;
69     header->name = name;
70     header->flags = flags;
71   }
72 }
73 
CreateSerializedMessageObject(uint32_t name,uint32_t flags,size_t payload_size,size_t payload_interface_id_count,std::vector<ScopedHandle> * handles,ScopedMessageHandle * out_handle,internal::Buffer * out_buffer)74 void CreateSerializedMessageObject(uint32_t name,
75                                    uint32_t flags,
76                                    size_t payload_size,
77                                    size_t payload_interface_id_count,
78                                    std::vector<ScopedHandle>* handles,
79                                    ScopedMessageHandle* out_handle,
80                                    internal::Buffer* out_buffer) {
81   ScopedMessageHandle handle;
82   MojoResult rv = mojo::CreateMessage(&handle);
83   DCHECK_EQ(MOJO_RESULT_OK, rv);
84   DCHECK(handle.is_valid());
85 
86   void* buffer;
87   uint32_t buffer_size;
88   size_t total_size = internal::ComputeSerializedMessageSize(
89       flags, payload_size, payload_interface_id_count);
90   DCHECK(base::IsValueInRangeForNumericType<uint32_t>(total_size));
91   DCHECK(!handles ||
92          base::IsValueInRangeForNumericType<uint32_t>(handles->size()));
93   rv = MojoAppendMessageData(
94       handle->value(), static_cast<uint32_t>(total_size),
95       handles ? reinterpret_cast<MojoHandle*>(handles->data()) : nullptr,
96       handles ? static_cast<uint32_t>(handles->size()) : 0, nullptr, &buffer,
97       &buffer_size);
98   DCHECK_EQ(MOJO_RESULT_OK, rv);
99   if (handles) {
100     // Handle ownership has been taken by MojoAppendMessageData.
101     for (size_t i = 0; i < handles->size(); ++i)
102       ignore_result(handles->at(i).release());
103   }
104 
105   internal::Buffer payload_buffer(handle.get(), total_size, buffer,
106                                   buffer_size);
107 
108   // Make sure we zero the memory first!
109   memset(payload_buffer.data(), 0, total_size);
110   WriteMessageHeader(name, flags, payload_interface_id_count, &payload_buffer);
111 
112   *out_handle = std::move(handle);
113   *out_buffer = std::move(payload_buffer);
114 }
115 
SerializeUnserializedContext(MojoMessageHandle message,uintptr_t context_value)116 void SerializeUnserializedContext(MojoMessageHandle message,
117                                   uintptr_t context_value) {
118   auto* context =
119       reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
120   void* buffer;
121   uint32_t buffer_size;
122   MojoResult attach_result = MojoAppendMessageData(
123       message, 0, nullptr, 0, nullptr, &buffer, &buffer_size);
124   if (attach_result != MOJO_RESULT_OK)
125     return;
126 
127   internal::Buffer payload_buffer(MessageHandle(message), 0, buffer,
128                                   buffer_size);
129   WriteMessageHeader(context->message_name(), context->message_flags(),
130                      0 /* payload_interface_id_count */, &payload_buffer);
131 
132   // We need to copy additional header data which may have been set after
133   // message construction, as this codepath may be reached at some arbitrary
134   // time between message send and message dispatch.
135   static_cast<internal::MessageHeader*>(buffer)->interface_id =
136       context->header()->interface_id;
137   if (context->header()->flags &
138       (Message::kFlagExpectsResponse | Message::kFlagIsResponse)) {
139     DCHECK_GE(context->header()->version, 1u);
140     static_cast<internal::MessageHeaderV1*>(buffer)->request_id =
141         context->header()->request_id;
142   }
143 
144   internal::SerializationContext serialization_context;
145   context->Serialize(&serialization_context, &payload_buffer);
146 
147   // TODO(crbug.com/753433): Support lazy serialization of associated endpoint
148   // handles. See corresponding TODO in the bindings generator for proof that
149   // this DCHECK is indeed valid.
150   DCHECK(serialization_context.associated_endpoint_handles()->empty());
151   if (!serialization_context.handles()->empty())
152     payload_buffer.AttachHandles(serialization_context.mutable_handles());
153   payload_buffer.Seal();
154 }
155 
DestroyUnserializedContext(uintptr_t context)156 void DestroyUnserializedContext(uintptr_t context) {
157   delete reinterpret_cast<internal::UnserializedMessageContext*>(context);
158 }
159 
CreateUnserializedMessageObject(std::unique_ptr<internal::UnserializedMessageContext> context)160 ScopedMessageHandle CreateUnserializedMessageObject(
161     std::unique_ptr<internal::UnserializedMessageContext> context) {
162   ScopedMessageHandle handle;
163   MojoResult rv = mojo::CreateMessage(&handle);
164   DCHECK_EQ(MOJO_RESULT_OK, rv);
165   DCHECK(handle.is_valid());
166 
167   rv = MojoSetMessageContext(
168       handle->value(), reinterpret_cast<uintptr_t>(context.release()),
169       &SerializeUnserializedContext, &DestroyUnserializedContext, nullptr);
170   DCHECK_EQ(MOJO_RESULT_OK, rv);
171   return handle;
172 }
173 
174 }  // namespace
175 
176 Message::Message() = default;
177 
Message(Message && other)178 Message::Message(Message&& other)
179     : handle_(std::move(other.handle_)),
180       payload_buffer_(std::move(other.payload_buffer_)),
181       handles_(std::move(other.handles_)),
182       associated_endpoint_handles_(
183           std::move(other.associated_endpoint_handles_)),
184       transferable_(other.transferable_),
185       serialized_(other.serialized_) {
186   other.transferable_ = false;
187   other.serialized_ = false;
188 #if defined(ENABLE_IPC_FUZZER)
189   interface_name_ = other.interface_name_;
190   method_name_ = other.method_name_;
191 #endif
192 }
193 
Message(std::unique_ptr<internal::UnserializedMessageContext> context)194 Message::Message(std::unique_ptr<internal::UnserializedMessageContext> context)
195     : Message(CreateUnserializedMessageObject(std::move(context))) {}
196 
Message(uint32_t name,uint32_t flags,size_t payload_size,size_t payload_interface_id_count,std::vector<ScopedHandle> * handles)197 Message::Message(uint32_t name,
198                  uint32_t flags,
199                  size_t payload_size,
200                  size_t payload_interface_id_count,
201                  std::vector<ScopedHandle>* handles) {
202   CreateSerializedMessageObject(name, flags, payload_size,
203                                 payload_interface_id_count, handles, &handle_,
204                                 &payload_buffer_);
205   transferable_ = true;
206   serialized_ = true;
207 }
208 
Message(ScopedMessageHandle handle)209 Message::Message(ScopedMessageHandle handle) {
210   DCHECK(handle.is_valid());
211 
212   uintptr_t context_value = 0;
213   MojoResult get_context_result =
214       MojoGetMessageContext(handle->value(), nullptr, &context_value);
215   if (get_context_result == MOJO_RESULT_NOT_FOUND) {
216     // It's a serialized message. Extract handles if possible.
217     uint32_t num_bytes;
218     void* buffer;
219     uint32_t num_handles = 0;
220     MojoResult rv = MojoGetMessageData(handle->value(), nullptr, &buffer,
221                                        &num_bytes, nullptr, &num_handles);
222     if (rv == MOJO_RESULT_RESOURCE_EXHAUSTED) {
223       handles_.resize(num_handles);
224       rv = MojoGetMessageData(handle->value(), nullptr, &buffer, &num_bytes,
225                               reinterpret_cast<MojoHandle*>(handles_.data()),
226                               &num_handles);
227     } else {
228       // No handles, so it's safe to retransmit this message if the caller
229       // really wants to.
230       transferable_ = true;
231     }
232 
233     if (rv != MOJO_RESULT_OK) {
234       // Failed to deserialize handles. Leave the Message uninitialized.
235       return;
236     }
237 
238     payload_buffer_ = internal::Buffer(buffer, num_bytes, num_bytes);
239     serialized_ = true;
240   } else {
241     DCHECK_EQ(MOJO_RESULT_OK, get_context_result);
242     auto* context =
243         reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
244     // Dummy data address so common header accessors still behave properly. The
245     // choice is V1 reflects unserialized message capabilities: we may or may
246     // not need to support request IDs (which require at least V1), but we never
247     // (for now, anyway) need to support associated interface handles (V2).
248     payload_buffer_ =
249         internal::Buffer(context->header(), sizeof(internal::MessageHeaderV1),
250                          sizeof(internal::MessageHeaderV1));
251     transferable_ = true;
252     serialized_ = false;
253   }
254 
255   handle_ = std::move(handle);
256 }
257 
258 Message::~Message() = default;
259 
operator =(Message && other)260 Message& Message::operator=(Message&& other) {
261   handle_ = std::move(other.handle_);
262   payload_buffer_ = std::move(other.payload_buffer_);
263   handles_ = std::move(other.handles_);
264   associated_endpoint_handles_ = std::move(other.associated_endpoint_handles_);
265   transferable_ = other.transferable_;
266   other.transferable_ = false;
267   serialized_ = other.serialized_;
268   other.serialized_ = false;
269 #if defined(ENABLE_IPC_FUZZER)
270   interface_name_ = other.interface_name_;
271   method_name_ = other.method_name_;
272 #endif
273   return *this;
274 }
275 
Reset()276 void Message::Reset() {
277   handle_.reset();
278   payload_buffer_.Reset();
279   handles_.clear();
280   associated_endpoint_handles_.clear();
281   transferable_ = false;
282   serialized_ = false;
283 }
284 
payload() const285 const uint8_t* Message::payload() const {
286   if (version() < 2)
287     return data() + header()->num_bytes;
288 
289   DCHECK(!header_v2()->payload.is_null());
290   return static_cast<const uint8_t*>(header_v2()->payload.Get());
291 }
292 
payload_num_bytes() const293 uint32_t Message::payload_num_bytes() const {
294   DCHECK_GE(data_num_bytes(), header()->num_bytes);
295   size_t num_bytes;
296   if (version() < 2) {
297     num_bytes = data_num_bytes() - header()->num_bytes;
298   } else {
299     auto payload_begin =
300         reinterpret_cast<uintptr_t>(header_v2()->payload.Get());
301     auto payload_end =
302         reinterpret_cast<uintptr_t>(header_v2()->payload_interface_ids.Get());
303     if (!payload_end)
304       payload_end = reinterpret_cast<uintptr_t>(data() + data_num_bytes());
305     DCHECK_GE(payload_end, payload_begin);
306     num_bytes = payload_end - payload_begin;
307   }
308   DCHECK(base::IsValueInRangeForNumericType<uint32_t>(num_bytes));
309   return static_cast<uint32_t>(num_bytes);
310 }
311 
payload_num_interface_ids() const312 uint32_t Message::payload_num_interface_ids() const {
313   auto* array_pointer =
314       version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
315   return array_pointer ? static_cast<uint32_t>(array_pointer->size()) : 0;
316 }
317 
payload_interface_ids() const318 const uint32_t* Message::payload_interface_ids() const {
319   auto* array_pointer =
320       version() < 2 ? nullptr : header_v2()->payload_interface_ids.Get();
321   return array_pointer ? array_pointer->storage() : nullptr;
322 }
323 
AttachHandlesFromSerializationContext(internal::SerializationContext * context)324 void Message::AttachHandlesFromSerializationContext(
325     internal::SerializationContext* context) {
326   if (context->handles()->empty() &&
327       context->associated_endpoint_handles()->empty()) {
328     // No handles attached, so no extra serialization work.
329     return;
330   }
331 
332   if (context->associated_endpoint_handles()->empty()) {
333     // Attaching only non-associated handles is easier since we don't have to
334     // modify the message header. Faster path for that.
335     payload_buffer_.AttachHandles(context->mutable_handles());
336     return;
337   }
338 
339   // Allocate a new message with enough space to hold all attached handles. Copy
340   // this message's contents into the new one and use it to replace ourself.
341   //
342   // TODO(rockot): We could avoid the extra full message allocation by instead
343   // growing the buffer and carefully moving its contents around. This errs on
344   // the side of less complexity with probably only marginal performance cost.
345   uint32_t payload_size = payload_num_bytes();
346   mojo::Message new_message(name(), header()->flags, payload_size,
347                             context->associated_endpoint_handles()->size(),
348                             context->mutable_handles());
349   std::swap(*context->mutable_associated_endpoint_handles(),
350             new_message.associated_endpoint_handles_);
351   memcpy(new_message.payload_buffer()->AllocateAndGet(payload_size), payload(),
352          payload_size);
353   *this = std::move(new_message);
354 }
355 
TakeMojoMessage()356 ScopedMessageHandle Message::TakeMojoMessage() {
357   // If there are associated endpoints transferred,
358   // SerializeAssociatedEndpointHandles() must be called before this method.
359   DCHECK(associated_endpoint_handles_.empty());
360   DCHECK(transferable_);
361   payload_buffer_.Seal();
362   auto handle = std::move(handle_);
363   Reset();
364   return handle;
365 }
366 
NotifyBadMessage(const std::string & error)367 void Message::NotifyBadMessage(const std::string& error) {
368   DCHECK(handle_.is_valid());
369   mojo::NotifyBadMessage(handle_.get(), error);
370 }
371 
SerializeAssociatedEndpointHandles(AssociatedGroupController * group_controller)372 void Message::SerializeAssociatedEndpointHandles(
373     AssociatedGroupController* group_controller) {
374   if (associated_endpoint_handles_.empty())
375     return;
376 
377   DCHECK_GE(version(), 2u);
378   DCHECK(header_v2()->payload_interface_ids.is_null());
379   DCHECK(payload_buffer_.is_valid());
380   DCHECK(handle_.is_valid());
381 
382   size_t size = associated_endpoint_handles_.size();
383 
384   internal::Array_Data<uint32_t>::BufferWriter handle_writer;
385   handle_writer.Allocate(size, &payload_buffer_);
386   header_v2()->payload_interface_ids.Set(handle_writer.data());
387 
388   for (size_t i = 0; i < size; ++i) {
389     ScopedInterfaceEndpointHandle& handle = associated_endpoint_handles_[i];
390 
391     DCHECK(handle.pending_association());
392     handle_writer->storage()[i] =
393         group_controller->AssociateInterface(std::move(handle));
394   }
395   associated_endpoint_handles_.clear();
396 }
397 
DeserializeAssociatedEndpointHandles(AssociatedGroupController * group_controller)398 bool Message::DeserializeAssociatedEndpointHandles(
399     AssociatedGroupController* group_controller) {
400   if (!serialized_)
401     return true;
402 
403   associated_endpoint_handles_.clear();
404 
405   uint32_t num_ids = payload_num_interface_ids();
406   if (num_ids == 0)
407     return true;
408 
409   associated_endpoint_handles_.reserve(num_ids);
410   uint32_t* ids = header_v2()->payload_interface_ids.Get()->storage();
411   bool result = true;
412   for (uint32_t i = 0; i < num_ids; ++i) {
413     auto handle = group_controller->CreateLocalEndpointHandle(ids[i]);
414     if (IsValidInterfaceId(ids[i]) && !handle.is_valid()) {
415       // |ids[i]| itself is valid but handle creation failed. In that case, mark
416       // deserialization as failed but continue to deserialize the rest of
417       // handles.
418       result = false;
419     }
420 
421     associated_endpoint_handles_.push_back(std::move(handle));
422     ids[i] = kInvalidInterfaceId;
423   }
424   return result;
425 }
426 
SerializeIfNecessary()427 void Message::SerializeIfNecessary() {
428   MojoResult rv = MojoSerializeMessage(handle_->value(), nullptr);
429   if (rv == MOJO_RESULT_FAILED_PRECONDITION)
430     return;
431 
432   // Reconstruct this Message instance from the serialized message's handle.
433   *this = Message(std::move(handle_));
434 }
435 
436 std::unique_ptr<internal::UnserializedMessageContext>
TakeUnserializedContext(const internal::UnserializedMessageContext::Tag * tag)437 Message::TakeUnserializedContext(
438     const internal::UnserializedMessageContext::Tag* tag) {
439   DCHECK(handle_.is_valid());
440   uintptr_t context_value = 0;
441   MojoResult rv =
442       MojoGetMessageContext(handle_->value(), nullptr, &context_value);
443   if (rv == MOJO_RESULT_NOT_FOUND)
444     return nullptr;
445   DCHECK_EQ(MOJO_RESULT_OK, rv);
446 
447   auto* context =
448       reinterpret_cast<internal::UnserializedMessageContext*>(context_value);
449   if (context->tag() != tag)
450     return nullptr;
451 
452   // Detach the context from the message.
453   rv = MojoSetMessageContext(handle_->value(), 0, nullptr, nullptr, nullptr);
454   DCHECK_EQ(MOJO_RESULT_OK, rv);
455   return base::WrapUnique(context);
456 }
457 
PrefersSerializedMessages()458 bool MessageReceiver::PrefersSerializedMessages() {
459   return false;
460 }
461 
PassThroughFilter()462 PassThroughFilter::PassThroughFilter() {}
463 
~PassThroughFilter()464 PassThroughFilter::~PassThroughFilter() {}
465 
Accept(Message * message)466 bool PassThroughFilter::Accept(Message* message) {
467   return true;
468 }
469 
SyncMessageResponseContext()470 SyncMessageResponseContext::SyncMessageResponseContext()
471     : outer_context_(current()) {
472   g_tls_sync_response_context.Get().Set(this);
473 }
474 
~SyncMessageResponseContext()475 SyncMessageResponseContext::~SyncMessageResponseContext() {
476   DCHECK_EQ(current(), this);
477   g_tls_sync_response_context.Get().Set(outer_context_);
478 }
479 
480 // static
current()481 SyncMessageResponseContext* SyncMessageResponseContext::current() {
482   return g_tls_sync_response_context.Get().Get();
483 }
484 
ReportBadMessage(const std::string & error)485 void SyncMessageResponseContext::ReportBadMessage(const std::string& error) {
486   GetBadMessageCallback().Run(error);
487 }
488 
GetBadMessageCallback()489 ReportBadMessageCallback SyncMessageResponseContext::GetBadMessageCallback() {
490   DCHECK(!response_.IsNull());
491   return base::BindOnce(&DoNotifyBadMessage, std::move(response_));
492 }
493 
ReadMessage(MessagePipeHandle handle,Message * message)494 MojoResult ReadMessage(MessagePipeHandle handle, Message* message) {
495   ScopedMessageHandle message_handle;
496   MojoResult rv =
497       ReadMessageNew(handle, &message_handle, MOJO_READ_MESSAGE_FLAG_NONE);
498   if (rv != MOJO_RESULT_OK)
499     return rv;
500 
501   *message = Message(std::move(message_handle));
502   return MOJO_RESULT_OK;
503 }
504 
ReportBadMessage(const std::string & error)505 void ReportBadMessage(const std::string& error) {
506   internal::MessageDispatchContext* context =
507       internal::MessageDispatchContext::current();
508   DCHECK(context);
509   context->GetBadMessageCallback().Run(error);
510 }
511 
GetBadMessageCallback()512 ReportBadMessageCallback GetBadMessageCallback() {
513   internal::MessageDispatchContext* context =
514       internal::MessageDispatchContext::current();
515   DCHECK(context);
516   return context->GetBadMessageCallback();
517 }
518 
519 namespace internal {
520 
521 MessageHeaderV2::MessageHeaderV2() = default;
522 
MessageDispatchContext(Message * message)523 MessageDispatchContext::MessageDispatchContext(Message* message)
524     : outer_context_(current()), message_(message) {
525   g_tls_message_dispatch_context.Get().Set(this);
526 }
527 
~MessageDispatchContext()528 MessageDispatchContext::~MessageDispatchContext() {
529   DCHECK_EQ(current(), this);
530   g_tls_message_dispatch_context.Get().Set(outer_context_);
531 }
532 
533 // static
current()534 MessageDispatchContext* MessageDispatchContext::current() {
535   return g_tls_message_dispatch_context.Get().Get();
536 }
537 
GetBadMessageCallback()538 ReportBadMessageCallback MessageDispatchContext::GetBadMessageCallback() {
539   DCHECK(!message_->IsNull());
540   return base::BindOnce(&DoNotifyBadMessage, std::move(*message_));
541 }
542 
543 // static
SetCurrentSyncResponseMessage(Message * message)544 void SyncMessageResponseSetup::SetCurrentSyncResponseMessage(Message* message) {
545   SyncMessageResponseContext* context = SyncMessageResponseContext::current();
546   if (context)
547     context->response_ = std::move(*message);
548 }
549 
550 }  // namespace internal
551 
552 }  // namespace mojo
553