1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <windows.h>
6 
7 #include <limits>
8 #include <utility>
9 
10 #include "base/debug/alias.h"
11 #include "base/memory/platform_shared_memory_region.h"
12 #include "base/numerics/safe_conversions.h"
13 #include "base/strings/string_piece.h"
14 #include "mojo/core/broker.h"
15 #include "mojo/core/broker_messages.h"
16 #include "mojo/core/channel.h"
17 #include "mojo/core/platform_handle_utils.h"
18 #include "mojo/public/cpp/platform/named_platform_channel.h"
19 
20 namespace mojo {
21 namespace core {
22 
23 namespace {
24 
// Upper bound, in bytes, on a single broker message read from the sync pipe.
// It sizes the stack buffer that WaitForBrokerMessage() reads into.
constexpr size_t kMaxBrokerMessageSize = 256;
27 
TakeHandlesFromBrokerMessage(Channel::Message * message,size_t num_handles,PlatformHandle * out_handles)28 bool TakeHandlesFromBrokerMessage(Channel::Message* message,
29                                   size_t num_handles,
30                                   PlatformHandle* out_handles) {
31   if (message->num_handles() != num_handles) {
32     DLOG(ERROR) << "Received unexpected number of handles in broker message";
33     return false;
34   }
35 
36   std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
37   DCHECK_EQ(handles.size(), num_handles);
38   DCHECK(out_handles);
39 
40   for (size_t i = 0; i < num_handles; ++i)
41     out_handles[i] = handles[i].TakeHandle();
42   return true;
43 }
44 
WaitForBrokerMessage(HANDLE pipe_handle,BrokerMessageType expected_type)45 Channel::MessagePtr WaitForBrokerMessage(HANDLE pipe_handle,
46                                          BrokerMessageType expected_type) {
47   char buffer[kMaxBrokerMessageSize];
48   DWORD bytes_read = 0;
49   BOOL result = ::ReadFile(pipe_handle, buffer, kMaxBrokerMessageSize,
50                            &bytes_read, nullptr);
51   if (!result) {
52     // The pipe may be broken if the browser side has been closed, e.g. during
53     // browser shutdown. In that case the ReadFile call will fail and we
54     // shouldn't continue waiting.
55     PLOG(ERROR) << "Error reading broker pipe";
56     return nullptr;
57   }
58 
59   Channel::MessagePtr message =
60       Channel::Message::Deserialize(buffer, static_cast<size_t>(bytes_read));
61   if (!message || message->payload_size() < sizeof(BrokerMessageHeader)) {
62     LOG(ERROR) << "Invalid broker message";
63 
64     base::debug::Alias(&buffer[0]);
65     base::debug::Alias(&bytes_read);
66     CHECK(false);
67     return nullptr;
68   }
69 
70   const BrokerMessageHeader* header =
71       reinterpret_cast<const BrokerMessageHeader*>(message->payload());
72   if (header->type != expected_type) {
73     LOG(ERROR) << "Unexpected broker message type";
74 
75     base::debug::Alias(&buffer[0]);
76     base::debug::Alias(&bytes_read);
77     CHECK(false);
78     return nullptr;
79   }
80 
81   return message;
82 }
83 
84 }  // namespace
85 
Broker(PlatformHandle handle)86 Broker::Broker(PlatformHandle handle) : sync_channel_(std::move(handle)) {
87   CHECK(sync_channel_.is_valid());
88   Channel::MessagePtr message = WaitForBrokerMessage(
89       sync_channel_.GetHandle().Get(), BrokerMessageType::INIT);
90 
91   // If we fail to read a message (broken pipe), just return early. The inviter
92   // handle will be null and callers must handle this gracefully.
93   if (!message)
94     return;
95 
96   PlatformHandle endpoint_handle;
97   if (TakeHandlesFromBrokerMessage(message.get(), 1, &endpoint_handle)) {
98     inviter_endpoint_ = PlatformChannelEndpoint(std::move(endpoint_handle));
99   } else {
100     // If the message has no handles, we expect it to carry pipe name instead.
101     const BrokerMessageHeader* header =
102         static_cast<const BrokerMessageHeader*>(message->payload());
103     CHECK_GE(message->payload_size(),
104              sizeof(BrokerMessageHeader) + sizeof(InitData));
105     const InitData* data = reinterpret_cast<const InitData*>(header + 1);
106     CHECK_EQ(message->payload_size(),
107              sizeof(BrokerMessageHeader) + sizeof(InitData) +
108                  data->pipe_name_length * sizeof(base::char16));
109     const base::char16* name_data =
110         reinterpret_cast<const base::char16*>(data + 1);
111     CHECK(data->pipe_name_length);
112     inviter_endpoint_ = NamedPlatformChannel::ConnectToServer(
113         base::StringPiece16(name_data, data->pipe_name_length).as_string());
114   }
115 }
116 
~Broker()117 Broker::~Broker() {}
118 
GetInviterEndpoint()119 PlatformChannelEndpoint Broker::GetInviterEndpoint() {
120   return std::move(inviter_endpoint_);
121 }
122 
GetWritableSharedMemoryRegion(size_t num_bytes)123 base::WritableSharedMemoryRegion Broker::GetWritableSharedMemoryRegion(
124     size_t num_bytes) {
125   base::AutoLock lock(lock_);
126   BufferRequestData* buffer_request;
127   Channel::MessagePtr out_message = CreateBrokerMessage(
128       BrokerMessageType::BUFFER_REQUEST, 0, 0, &buffer_request);
129   buffer_request->size = base::checked_cast<uint32_t>(num_bytes);
130   DWORD bytes_written = 0;
131   BOOL result =
132       ::WriteFile(sync_channel_.GetHandle().Get(), out_message->data(),
133                   static_cast<DWORD>(out_message->data_num_bytes()),
134                   &bytes_written, nullptr);
135   if (!result ||
136       static_cast<size_t>(bytes_written) != out_message->data_num_bytes()) {
137     PLOG(ERROR) << "Error sending sync broker message";
138     return base::WritableSharedMemoryRegion();
139   }
140 
141   PlatformHandle handle;
142   Channel::MessagePtr response = WaitForBrokerMessage(
143       sync_channel_.GetHandle().Get(), BrokerMessageType::BUFFER_RESPONSE);
144   if (response && TakeHandlesFromBrokerMessage(response.get(), 1, &handle)) {
145     BufferResponseData* data;
146     if (!GetBrokerMessageData(response.get(), &data))
147       return base::WritableSharedMemoryRegion();
148     return base::WritableSharedMemoryRegion::Deserialize(
149         base::subtle::PlatformSharedMemoryRegion::Take(
150             CreateSharedMemoryRegionHandleFromPlatformHandles(std::move(handle),
151                                                               PlatformHandle()),
152             base::subtle::PlatformSharedMemoryRegion::Mode::kWritable,
153             num_bytes,
154             base::UnguessableToken::Deserialize(data->guid_high,
155                                                 data->guid_low)));
156   }
157 
158   return base::WritableSharedMemoryRegion();
159 }
160 
161 }  // namespace core
162 }  // namespace mojo
163