1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "mojo/core/channel.h"
6
7 #include <stdint.h>
8 #include <windows.h>
9
10 #include <algorithm>
11 #include <limits>
12 #include <memory>
13
14 #include "base/bind.h"
15 #include "base/containers/queue.h"
16 #include "base/location.h"
17 #include "base/macros.h"
18 #include "base/memory/ref_counted.h"
19 #include "base/message_loop/message_loop_current.h"
20 #include "base/message_loop/message_pump_for_io.h"
21 #include "base/process/process_handle.h"
22 #include "base/synchronization/lock.h"
23 #include "base/task_runner.h"
24 #include "base/win/scoped_handle.h"
25 #include "base/win/win_util.h"
26
27 namespace mojo {
28 namespace core {
29
30 namespace {
31
32 class ChannelWin : public Channel,
33 public base::MessageLoopCurrent::DestructionObserver,
34 public base::MessagePumpForIO::IOHandler {
35 public:
ChannelWin(Delegate * delegate,ConnectionParams connection_params,scoped_refptr<base::TaskRunner> io_task_runner)36 ChannelWin(Delegate* delegate,
37 ConnectionParams connection_params,
38 scoped_refptr<base::TaskRunner> io_task_runner)
39 : Channel(delegate), self_(this), io_task_runner_(io_task_runner) {
40 if (connection_params.server_endpoint().is_valid()) {
41 handle_ = connection_params.TakeServerEndpoint()
42 .TakePlatformHandle()
43 .TakeHandle();
44 needs_connection_ = true;
45 } else {
46 handle_ =
47 connection_params.TakeEndpoint().TakePlatformHandle().TakeHandle();
48 }
49
50 CHECK(handle_.IsValid());
51 }
52
Start()53 void Start() override {
54 io_task_runner_->PostTask(
55 FROM_HERE, base::BindOnce(&ChannelWin::StartOnIOThread, this));
56 }
57
ShutDownImpl()58 void ShutDownImpl() override {
59 // Always shut down asynchronously when called through the public interface.
60 io_task_runner_->PostTask(
61 FROM_HERE, base::BindOnce(&ChannelWin::ShutDownOnIOThread, this));
62 }
63
Write(MessagePtr message)64 void Write(MessagePtr message) override {
65 if (remote_process().is_valid()) {
66 // If we know the remote process handle, we transfer all outgoing handles
67 // to the process now rewriting them in the message.
68 std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
69 for (auto& handle : handles) {
70 if (handle.handle().is_valid())
71 handle.TransferToProcess(remote_process().Clone());
72 }
73 message->SetHandles(std::move(handles));
74 }
75
76 bool write_error = false;
77 {
78 base::AutoLock lock(write_lock_);
79 if (reject_writes_)
80 return;
81
82 bool write_now = !delay_writes_ && outgoing_messages_.empty();
83 outgoing_messages_.emplace_back(std::move(message));
84 if (write_now && !WriteNoLock(outgoing_messages_.front()))
85 reject_writes_ = write_error = true;
86 }
87 if (write_error) {
88 // Do not synchronously invoke OnWriteError(). Write() may have been
89 // called by the delegate and we don't want to re-enter it.
90 io_task_runner_->PostTask(FROM_HERE,
91 base::BindOnce(&ChannelWin::OnWriteError, this,
92 Error::kDisconnected));
93 }
94 }
95
LeakHandle()96 void LeakHandle() override {
97 DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
98 leak_handle_ = true;
99 }
100
GetReadPlatformHandles(const void * payload,size_t payload_size,size_t num_handles,const void * extra_header,size_t extra_header_size,std::vector<PlatformHandle> * handles,bool * deferred)101 bool GetReadPlatformHandles(const void* payload,
102 size_t payload_size,
103 size_t num_handles,
104 const void* extra_header,
105 size_t extra_header_size,
106 std::vector<PlatformHandle>* handles,
107 bool* deferred) override {
108 DCHECK(extra_header);
109 if (num_handles > std::numeric_limits<uint16_t>::max())
110 return false;
111 using HandleEntry = Channel::Message::HandleEntry;
112 size_t handles_size = sizeof(HandleEntry) * num_handles;
113 if (handles_size > extra_header_size)
114 return false;
115 handles->reserve(num_handles);
116 const HandleEntry* extra_header_handles =
117 reinterpret_cast<const HandleEntry*>(extra_header);
118 for (size_t i = 0; i < num_handles; i++) {
119 HANDLE handle_value =
120 base::win::Uint32ToHandle(extra_header_handles[i].handle);
121 if (remote_process().is_valid()) {
122 // If we know the remote process's handle, we assume it doesn't know
123 // ours; that means any handle values still belong to that process, and
124 // we need to transfer them to this process.
125 handle_value = PlatformHandleInTransit::TakeIncomingRemoteHandle(
126 handle_value, remote_process().get())
127 .ReleaseHandle();
128 }
129 handles->emplace_back(base::win::ScopedHandle(std::move(handle_value)));
130 }
131 return true;
132 }
133
134 private:
135 // May run on any thread.
~ChannelWin()136 ~ChannelWin() override {}
137
StartOnIOThread()138 void StartOnIOThread() {
139 base::MessageLoopCurrent::Get()->AddDestructionObserver(this);
140 base::MessageLoopCurrentForIO::Get()->RegisterIOHandler(handle_.Get(),
141 this);
142
143 if (needs_connection_) {
144 BOOL ok = ::ConnectNamedPipe(handle_.Get(), &connect_context_.overlapped);
145 if (ok) {
146 PLOG(ERROR) << "Unexpected success while waiting for pipe connection";
147 OnError(Error::kConnectionFailed);
148 return;
149 }
150
151 const DWORD err = GetLastError();
152 switch (err) {
153 case ERROR_PIPE_CONNECTED:
154 break;
155 case ERROR_IO_PENDING:
156 is_connect_pending_ = true;
157 AddRef();
158 return;
159 case ERROR_NO_DATA:
160 default:
161 OnError(Error::kConnectionFailed);
162 return;
163 }
164 }
165
166 // Now that we have registered our IOHandler, we can start writing.
167 {
168 base::AutoLock lock(write_lock_);
169 if (delay_writes_) {
170 delay_writes_ = false;
171 WriteNextNoLock();
172 }
173 }
174
175 // Keep this alive in case we synchronously run shutdown, via OnError(),
176 // as a result of a ReadFile() failure on the channel.
177 scoped_refptr<ChannelWin> keep_alive(this);
178 ReadMore(0);
179 }
180
ShutDownOnIOThread()181 void ShutDownOnIOThread() {
182 base::MessageLoopCurrent::Get()->RemoveDestructionObserver(this);
183
184 // TODO(https://crbug.com/583525): This function is expected to be called
185 // once, and |handle_| should be valid at this point.
186 CHECK(handle_.IsValid());
187 CancelIo(handle_.Get());
188 if (leak_handle_)
189 ignore_result(handle_.Take());
190 else
191 handle_.Close();
192
193 // Allow |this| to be destroyed as soon as no IO is pending.
194 self_ = nullptr;
195 }
196
197 // base::MessageLoopCurrent::DestructionObserver:
WillDestroyCurrentMessageLoop()198 void WillDestroyCurrentMessageLoop() override {
199 DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
200 if (self_)
201 ShutDownOnIOThread();
202 }
203
204 // base::MessageLoop::IOHandler:
OnIOCompleted(base::MessagePumpForIO::IOContext * context,DWORD bytes_transfered,DWORD error)205 void OnIOCompleted(base::MessagePumpForIO::IOContext* context,
206 DWORD bytes_transfered,
207 DWORD error) override {
208 if (error != ERROR_SUCCESS) {
209 if (context == &write_context_) {
210 {
211 base::AutoLock lock(write_lock_);
212 reject_writes_ = true;
213 }
214 OnWriteError(Error::kDisconnected);
215 } else {
216 OnError(Error::kDisconnected);
217 }
218 } else if (context == &connect_context_) {
219 DCHECK(is_connect_pending_);
220 is_connect_pending_ = false;
221 ReadMore(0);
222
223 base::AutoLock lock(write_lock_);
224 if (delay_writes_) {
225 delay_writes_ = false;
226 WriteNextNoLock();
227 }
228 } else if (context == &read_context_) {
229 OnReadDone(static_cast<size_t>(bytes_transfered));
230 } else {
231 CHECK(context == &write_context_);
232 OnWriteDone(static_cast<size_t>(bytes_transfered));
233 }
234 Release();
235 }
236
OnReadDone(size_t bytes_read)237 void OnReadDone(size_t bytes_read) {
238 DCHECK(is_read_pending_);
239 is_read_pending_ = false;
240
241 if (bytes_read > 0) {
242 size_t next_read_size = 0;
243 if (OnReadComplete(bytes_read, &next_read_size)) {
244 ReadMore(next_read_size);
245 } else {
246 OnError(Error::kReceivedMalformedData);
247 }
248 } else if (bytes_read == 0) {
249 OnError(Error::kDisconnected);
250 }
251 }
252
OnWriteDone(size_t bytes_written)253 void OnWriteDone(size_t bytes_written) {
254 if (bytes_written == 0)
255 return;
256
257 bool write_error = false;
258 {
259 base::AutoLock lock(write_lock_);
260
261 DCHECK(is_write_pending_);
262 is_write_pending_ = false;
263 DCHECK(!outgoing_messages_.empty());
264
265 Channel::MessagePtr message = std::move(outgoing_messages_.front());
266 outgoing_messages_.pop_front();
267
268 // Invalidate all the scoped handles so we don't attempt to close them.
269 std::vector<PlatformHandleInTransit> handles = message->TakeHandles();
270 for (auto& handle : handles)
271 handle.CompleteTransit();
272
273 // Overlapped WriteFile() to a pipe should always fully complete.
274 if (message->data_num_bytes() != bytes_written)
275 reject_writes_ = write_error = true;
276 else if (!WriteNextNoLock())
277 reject_writes_ = write_error = true;
278 }
279 if (write_error)
280 OnWriteError(Error::kDisconnected);
281 }
282
ReadMore(size_t next_read_size_hint)283 void ReadMore(size_t next_read_size_hint) {
284 DCHECK(!is_read_pending_);
285
286 size_t buffer_capacity = next_read_size_hint;
287 char* buffer = GetReadBuffer(&buffer_capacity);
288 DCHECK_GT(buffer_capacity, 0u);
289
290 BOOL ok =
291 ::ReadFile(handle_.Get(), buffer, static_cast<DWORD>(buffer_capacity),
292 NULL, &read_context_.overlapped);
293 if (ok || GetLastError() == ERROR_IO_PENDING) {
294 is_read_pending_ = true;
295 AddRef();
296 } else {
297 OnError(Error::kDisconnected);
298 }
299 }
300
301 // Attempts to write a message directly to the channel. If the full message
302 // cannot be written, it's queued and a wait is initiated to write the message
303 // ASAP on the I/O thread.
WriteNoLock(const Channel::MessagePtr & message)304 bool WriteNoLock(const Channel::MessagePtr& message) {
305 BOOL ok = WriteFile(handle_.Get(), message->data(),
306 static_cast<DWORD>(message->data_num_bytes()), NULL,
307 &write_context_.overlapped);
308 if (ok || GetLastError() == ERROR_IO_PENDING) {
309 is_write_pending_ = true;
310 AddRef();
311 return true;
312 }
313 return false;
314 }
315
WriteNextNoLock()316 bool WriteNextNoLock() {
317 if (outgoing_messages_.empty())
318 return true;
319 return WriteNoLock(outgoing_messages_.front());
320 }
321
OnWriteError(Error error)322 void OnWriteError(Error error) {
323 DCHECK(io_task_runner_->RunsTasksInCurrentSequence());
324 DCHECK(reject_writes_);
325
326 if (error == Error::kDisconnected) {
327 // If we can't write because the pipe is disconnected then continue
328 // reading to fetch any in-flight messages, relying on end-of-stream to
329 // signal the actual disconnection.
330 if (is_read_pending_ || is_connect_pending_)
331 return;
332 }
333
334 OnError(error);
335 }
336
337 // Keeps the Channel alive at least until explicit shutdown on the IO thread.
338 scoped_refptr<Channel> self_;
339
340 // The pipe handle this Channel uses for communication.
341 base::win::ScopedHandle handle_;
342
343 // Indicates whether |handle_| must wait for a connection.
344 bool needs_connection_ = false;
345
346 const scoped_refptr<base::TaskRunner> io_task_runner_;
347
348 base::MessagePumpForIO::IOContext connect_context_;
349 base::MessagePumpForIO::IOContext read_context_;
350 bool is_connect_pending_ = false;
351 bool is_read_pending_ = false;
352
353 // Protects all fields potentially accessed on multiple threads via Write().
354 base::Lock write_lock_;
355 base::MessagePumpForIO::IOContext write_context_;
356 base::circular_deque<Channel::MessagePtr> outgoing_messages_;
357 bool delay_writes_ = true;
358 bool reject_writes_ = false;
359 bool is_write_pending_ = false;
360
361 bool leak_handle_ = false;
362
363 DISALLOW_COPY_AND_ASSIGN(ChannelWin);
364 };
365
366 } // namespace
367
368 // static
Create(Delegate * delegate,ConnectionParams connection_params,scoped_refptr<base::TaskRunner> io_task_runner)369 scoped_refptr<Channel> Channel::Create(
370 Delegate* delegate,
371 ConnectionParams connection_params,
372 scoped_refptr<base::TaskRunner> io_task_runner) {
373 return new ChannelWin(delegate, std::move(connection_params), io_task_runner);
374 }
375
376 } // namespace core
377 } // namespace mojo
378