1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "mojo/core/broker_host.h"
6 
7 #include <utility>
8 
9 #include "base/logging.h"
10 #include "base/memory/platform_shared_memory_region.h"
11 #include "base/memory/ref_counted.h"
12 #include "base/threading/thread_task_runner_handle.h"
13 #include "build/build_config.h"
14 #include "mojo/core/broker_messages.h"
15 #include "mojo/core/platform_handle_utils.h"
16 
17 #if defined(OS_WIN)
18 #include <windows.h>
19 #endif
20 
21 namespace mojo {
22 namespace core {
23 
BrokerHost(base::ProcessHandle client_process,ConnectionParams connection_params,const ProcessErrorCallback & process_error_callback)24 BrokerHost::BrokerHost(base::ProcessHandle client_process,
25                        ConnectionParams connection_params,
26                        const ProcessErrorCallback& process_error_callback)
27     : process_error_callback_(process_error_callback)
28 #if defined(OS_WIN)
29       ,
30       client_process_(ScopedProcessHandle::CloneFrom(client_process))
31 #endif
32 {
33   CHECK(connection_params.endpoint().is_valid() ||
34         connection_params.server_endpoint().is_valid());
35 
36   base::MessageLoopCurrent::Get()->AddDestructionObserver(this);
37 
38   channel_ = Channel::Create(this, std::move(connection_params),
39                              base::ThreadTaskRunnerHandle::Get());
40   channel_->Start();
41 }
42 
~BrokerHost()43 BrokerHost::~BrokerHost() {
44   // We're always destroyed on the creation thread, which is the IO thread.
45   base::MessageLoopCurrent::Get()->RemoveDestructionObserver(this);
46 
47   if (channel_)
48     channel_->ShutDown();
49 }
50 
PrepareHandlesForClient(std::vector<PlatformHandleInTransit> * handles)51 bool BrokerHost::PrepareHandlesForClient(
52     std::vector<PlatformHandleInTransit>* handles) {
53 #if defined(OS_WIN)
54   bool handles_ok = true;
55   for (auto& handle : *handles) {
56     if (!handle.TransferToProcess(client_process_.Clone()))
57       handles_ok = false;
58   }
59   return handles_ok;
60 #else
61   return true;
62 #endif
63 }
64 
SendChannel(PlatformHandle handle)65 bool BrokerHost::SendChannel(PlatformHandle handle) {
66   CHECK(handle.is_valid());
67   CHECK(channel_);
68 
69 #if defined(OS_WIN)
70   InitData* data;
71   Channel::MessagePtr message =
72       CreateBrokerMessage(BrokerMessageType::INIT, 1, 0, &data);
73   data->pipe_name_length = 0;
74 #else
75   Channel::MessagePtr message =
76       CreateBrokerMessage(BrokerMessageType::INIT, 1, nullptr);
77 #endif
78   std::vector<PlatformHandleInTransit> handles(1);
79   handles[0] = PlatformHandleInTransit(std::move(handle));
80 
81   // This may legitimately fail on Windows if the client process is in another
82   // session, e.g., is an elevated process.
83   if (!PrepareHandlesForClient(&handles))
84     return false;
85 
86   message->SetHandles(std::move(handles));
87   channel_->Write(std::move(message));
88   return true;
89 }
90 
91 #if defined(OS_WIN)
92 
SendNamedChannel(const base::StringPiece16 & pipe_name)93 void BrokerHost::SendNamedChannel(const base::StringPiece16& pipe_name) {
94   InitData* data;
95   base::char16* name_data;
96   Channel::MessagePtr message = CreateBrokerMessage(
97       BrokerMessageType::INIT, 0, sizeof(*name_data) * pipe_name.length(),
98       &data, reinterpret_cast<void**>(&name_data));
99   data->pipe_name_length = static_cast<uint32_t>(pipe_name.length());
100   std::copy(pipe_name.begin(), pipe_name.end(), name_data);
101   channel_->Write(std::move(message));
102 }
103 
104 #endif  // defined(OS_WIN)
105 
OnBufferRequest(uint32_t num_bytes)106 void BrokerHost::OnBufferRequest(uint32_t num_bytes) {
107   base::subtle::PlatformSharedMemoryRegion region =
108       base::subtle::PlatformSharedMemoryRegion::CreateWritable(num_bytes);
109 
110   std::vector<PlatformHandleInTransit> handles(2);
111   if (region.IsValid()) {
112     PlatformHandle h[2];
113     ExtractPlatformHandlesFromSharedMemoryRegionHandle(
114         region.PassPlatformHandle(), &h[0], &h[1]);
115     handles[0] = PlatformHandleInTransit(std::move(h[0]));
116     handles[1] = PlatformHandleInTransit(std::move(h[1]));
117 #if !defined(OS_POSIX) || defined(OS_ANDROID) || defined(OS_FUCHSIA) || \
118     (defined(OS_MACOSX) && !defined(OS_IOS))
119     // Non-POSIX systems, as well as Android, Fuchsia, and non-iOS Mac, only use
120     // a single handle to represent a writable region.
121     DCHECK(!handles[1].handle().is_valid());
122     handles.resize(1);
123 #else
124     DCHECK(handles[1].handle().is_valid());
125 #endif
126   }
127 
128   BufferResponseData* response;
129   Channel::MessagePtr message = CreateBrokerMessage(
130       BrokerMessageType::BUFFER_RESPONSE, handles.size(), 0, &response);
131   if (!handles.empty()) {
132     base::UnguessableToken guid = region.GetGUID();
133     response->guid_high = guid.GetHighForSerialization();
134     response->guid_low = guid.GetLowForSerialization();
135     PrepareHandlesForClient(&handles);
136     message->SetHandles(std::move(handles));
137   }
138 
139   channel_->Write(std::move(message));
140 }
141 
OnChannelMessage(const void * payload,size_t payload_size,std::vector<PlatformHandle> handles)142 void BrokerHost::OnChannelMessage(const void* payload,
143                                   size_t payload_size,
144                                   std::vector<PlatformHandle> handles) {
145   if (payload_size < sizeof(BrokerMessageHeader))
146     return;
147 
148   const BrokerMessageHeader* header =
149       static_cast<const BrokerMessageHeader*>(payload);
150   switch (header->type) {
151     case BrokerMessageType::BUFFER_REQUEST:
152       if (payload_size ==
153           sizeof(BrokerMessageHeader) + sizeof(BufferRequestData)) {
154         const BufferRequestData* request =
155             reinterpret_cast<const BufferRequestData*>(header + 1);
156         OnBufferRequest(request->size);
157       }
158       break;
159 
160     default:
161       DLOG(ERROR) << "Unexpected broker message type: " << header->type;
162       break;
163   }
164 }
165 
OnChannelError(Channel::Error error)166 void BrokerHost::OnChannelError(Channel::Error error) {
167   if (process_error_callback_ &&
168       error == Channel::Error::kReceivedMalformedData) {
169     process_error_callback_.Run("Broker host received malformed message");
170   }
171 
172   delete this;
173 }
174 
WillDestroyCurrentMessageLoop()175 void BrokerHost::WillDestroyCurrentMessageLoop() {
176   delete this;
177 }
178 
179 }  // namespace core
180 }  // namespace mojo
181