1 // Copyright 2017 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
#include "mojo/core/channel.h"

#include "base/bind.h"
#include "base/memory/ptr_util.h"
#include "base/message_loop/message_loop.h"
#include "base/run_loop.h"
#include "base/threading/thread.h"
#include "mojo/core/platform_handle_utils.h"
#include "mojo/public/cpp/platform/platform_channel.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
15 
16 namespace mojo {
17 namespace core {
18 namespace {
19 
20 class TestChannel : public Channel {
21  public:
TestChannel(Channel::Delegate * delegate)22   TestChannel(Channel::Delegate* delegate) : Channel(delegate) {}
23 
GetReadBufferTest(size_t * buffer_capacity)24   char* GetReadBufferTest(size_t* buffer_capacity) {
25     return GetReadBuffer(buffer_capacity);
26   }
27 
OnReadCompleteTest(size_t bytes_read,size_t * next_read_size_hint)28   bool OnReadCompleteTest(size_t bytes_read, size_t* next_read_size_hint) {
29     return OnReadComplete(bytes_read, next_read_size_hint);
30   }
31 
32   MOCK_METHOD7(GetReadPlatformHandles,
33                bool(const void* payload,
34                     size_t payload_size,
35                     size_t num_handles,
36                     const void* extra_header,
37                     size_t extra_header_size,
38                     std::vector<PlatformHandle>* handles,
39                     bool* deferred));
40   MOCK_METHOD0(Start, void());
41   MOCK_METHOD0(ShutDownImpl, void());
42   MOCK_METHOD0(LeakHandle, void());
43 
Write(MessagePtr message)44   void Write(MessagePtr message) override {}
45 
46  protected:
~TestChannel()47   ~TestChannel() override {}
48 };
49 
50 // Not using GMock as I don't think it supports movable types.
51 class MockChannelDelegate : public Channel::Delegate {
52  public:
MockChannelDelegate()53   MockChannelDelegate() {}
54 
GetReceivedPayloadSize() const55   size_t GetReceivedPayloadSize() const { return payload_size_; }
56 
GetReceivedPayload() const57   const void* GetReceivedPayload() const { return payload_.get(); }
58 
59  protected:
OnChannelMessage(const void * payload,size_t payload_size,std::vector<PlatformHandle> handles)60   void OnChannelMessage(const void* payload,
61                         size_t payload_size,
62                         std::vector<PlatformHandle> handles) override {
63     payload_.reset(new char[payload_size]);
64     memcpy(payload_.get(), payload, payload_size);
65     payload_size_ = payload_size;
66   }
67 
68   // Notify that an error has occured and the Channel will cease operation.
OnChannelError(Channel::Error error)69   void OnChannelError(Channel::Error error) override {}
70 
71  private:
72   size_t payload_size_ = 0;
73   std::unique_ptr<char[]> payload_;
74 };
75 
CreateDefaultMessage(bool legacy_message)76 Channel::MessagePtr CreateDefaultMessage(bool legacy_message) {
77   const size_t payload_size = 100;
78   Channel::MessagePtr message = std::make_unique<Channel::Message>(
79       payload_size, 0,
80       legacy_message ? Channel::Message::MessageType::NORMAL_LEGACY
81                      : Channel::Message::MessageType::NORMAL);
82   char* payload = static_cast<char*>(message->mutable_payload());
83   for (size_t i = 0; i < payload_size; i++) {
84     payload[i] = static_cast<char>(i);
85   }
86   return message;
87 }
88 
TestMemoryEqual(const void * data1,size_t data1_size,const void * data2,size_t data2_size)89 void TestMemoryEqual(const void* data1,
90                      size_t data1_size,
91                      const void* data2,
92                      size_t data2_size) {
93   ASSERT_EQ(data1_size, data2_size);
94   const unsigned char* data1_char = static_cast<const unsigned char*>(data1);
95   const unsigned char* data2_char = static_cast<const unsigned char*>(data2);
96   for (size_t i = 0; i < data1_size; i++) {
97     // ASSERT so we don't log tons of errors if the data is different.
98     ASSERT_EQ(data1_char[i], data2_char[i]);
99   }
100 }
101 
TestMessagesAreEqual(Channel::Message * message1,Channel::Message * message2,bool legacy_messages)102 void TestMessagesAreEqual(Channel::Message* message1,
103                           Channel::Message* message2,
104                           bool legacy_messages) {
105   // If any of the message is null, this is probably not what you wanted to
106   // test.
107   ASSERT_NE(nullptr, message1);
108   ASSERT_NE(nullptr, message2);
109 
110   ASSERT_EQ(message1->payload_size(), message2->payload_size());
111   EXPECT_EQ(message1->has_handles(), message2->has_handles());
112 
113   TestMemoryEqual(message1->payload(), message1->payload_size(),
114                   message2->payload(), message2->payload_size());
115 
116   if (legacy_messages)
117     return;
118 
119   ASSERT_EQ(message1->extra_header_size(), message2->extra_header_size());
120   TestMemoryEqual(message1->extra_header(), message1->extra_header_size(),
121                   message2->extra_header(), message2->extra_header_size());
122 }
123 
// Serializes a legacy-format message, deserializes its bytes, and checks the
// round-tripped message against the original (extra header not compared).
TEST(ChannelTest, LegacyMessageDeserialization) {
  const bool kLegacy = true;
  Channel::MessagePtr original = CreateDefaultMessage(kLegacy);
  const void* wire_bytes = original->data();
  const size_t wire_size = original->data_num_bytes();
  Channel::MessagePtr round_tripped =
      Channel::Message::Deserialize(wire_bytes, wire_size);
  TestMessagesAreEqual(original.get(), round_tripped.get(), kLegacy);
}
131 
// Serializes a current-format message, deserializes its bytes, and checks the
// round-tripped message against the original, including its extra header.
TEST(ChannelTest, NonLegacyMessageDeserialization) {
  const bool kLegacy = false;
  Channel::MessagePtr original = CreateDefaultMessage(kLegacy);
  const void* wire_bytes = original->data();
  const size_t wire_size = original->data_num_bytes();
  Channel::MessagePtr round_tripped =
      Channel::Message::Deserialize(wire_bytes, wire_size);
  TestMessagesAreEqual(original.get(), round_tripped.get(), kLegacy);
}
140 
TEST(ChannelTest,OnReadLegacyMessage)141 TEST(ChannelTest, OnReadLegacyMessage) {
142   size_t buffer_size = 100 * 1024;
143   Channel::MessagePtr message = CreateDefaultMessage(true /* legacy_message */);
144 
145   MockChannelDelegate channel_delegate;
146   scoped_refptr<TestChannel> channel = new TestChannel(&channel_delegate);
147   char* read_buffer = channel->GetReadBufferTest(&buffer_size);
148   ASSERT_LT(message->data_num_bytes(),
149             buffer_size);  // Bad test. Increase buffer
150                            // size.
151   memcpy(read_buffer, message->data(), message->data_num_bytes());
152 
153   size_t next_read_size_hint = 0;
154   EXPECT_TRUE(channel->OnReadCompleteTest(message->data_num_bytes(),
155                                           &next_read_size_hint));
156 
157   TestMemoryEqual(message->payload(), message->payload_size(),
158                   channel_delegate.GetReceivedPayload(),
159                   channel_delegate.GetReceivedPayloadSize());
160 }
161 
TEST(ChannelTest,OnReadNonLegacyMessage)162 TEST(ChannelTest, OnReadNonLegacyMessage) {
163   size_t buffer_size = 100 * 1024;
164   Channel::MessagePtr message =
165       CreateDefaultMessage(false /* legacy_message */);
166 
167   MockChannelDelegate channel_delegate;
168   scoped_refptr<TestChannel> channel = new TestChannel(&channel_delegate);
169   char* read_buffer = channel->GetReadBufferTest(&buffer_size);
170   ASSERT_LT(message->data_num_bytes(),
171             buffer_size);  // Bad test. Increase buffer
172                            // size.
173   memcpy(read_buffer, message->data(), message->data_num_bytes());
174 
175   size_t next_read_size_hint = 0;
176   EXPECT_TRUE(channel->OnReadCompleteTest(message->data_num_bytes(),
177                                           &next_read_size_hint));
178 
179   TestMemoryEqual(message->payload(), message->payload_size(),
180                   channel_delegate.GetReceivedPayload(),
181                   channel_delegate.GetReceivedPayloadSize());
182 }
183 
184 class ChannelTestShutdownAndWriteDelegate : public Channel::Delegate {
185  public:
ChannelTestShutdownAndWriteDelegate(PlatformChannelEndpoint endpoint,scoped_refptr<base::TaskRunner> task_runner,scoped_refptr<Channel> client_channel,std::unique_ptr<base::Thread> client_thread,base::RepeatingClosure quit_closure)186   ChannelTestShutdownAndWriteDelegate(
187       PlatformChannelEndpoint endpoint,
188       scoped_refptr<base::TaskRunner> task_runner,
189       scoped_refptr<Channel> client_channel,
190       std::unique_ptr<base::Thread> client_thread,
191       base::RepeatingClosure quit_closure)
192       : quit_closure_(std::move(quit_closure)),
193         client_channel_(std::move(client_channel)),
194         client_thread_(std::move(client_thread)) {
195     channel_ = Channel::Create(this, ConnectionParams(std::move(endpoint)),
196                                std::move(task_runner));
197     channel_->Start();
198   }
~ChannelTestShutdownAndWriteDelegate()199   ~ChannelTestShutdownAndWriteDelegate() override { channel_->ShutDown(); }
200 
201   // Channel::Delegate implementation
OnChannelMessage(const void * payload,size_t payload_size,std::vector<PlatformHandle> handles)202   void OnChannelMessage(const void* payload,
203                         size_t payload_size,
204                         std::vector<PlatformHandle> handles) override {
205     ++message_count_;
206 
207     // If |client_channel_| exists then close it and its thread.
208     if (client_channel_) {
209       // Write a fresh message, making our channel readable again.
210       Channel::MessagePtr message = CreateDefaultMessage(false);
211       client_thread_->task_runner()->PostTask(
212           FROM_HERE, base::BindOnce(&Channel::Write, client_channel_,
213                                     base::Passed(&message)));
214 
215       // Close the channel and wait for it to shutdown.
216       client_channel_->ShutDown();
217       client_channel_ = nullptr;
218 
219       client_thread_->Stop();
220       client_thread_ = nullptr;
221     }
222 
223     // Write a message to the channel, to verify whether this triggers an
224     // OnChannelError callback before all messages were read.
225     Channel::MessagePtr message = CreateDefaultMessage(false);
226     channel_->Write(std::move(message));
227   }
228 
OnChannelError(Channel::Error error)229   void OnChannelError(Channel::Error error) override {
230     EXPECT_EQ(2, message_count_);
231     quit_closure_.Run();
232   }
233 
234   base::RepeatingClosure quit_closure_;
235   int message_count_ = 0;
236   scoped_refptr<Channel> channel_;
237 
238   scoped_refptr<Channel> client_channel_;
239   std::unique_ptr<base::Thread> client_thread_;
240 };
241 
// Checks that when the client end of the pipe shuts down while its messages
// are still pending, the server side does not report OnChannelError until it
// has received both messages the client wrote (see the delegate above).
TEST(ChannelTest, PeerShutdownDuringRead) {
  base::MessageLoop message_loop(base::MessageLoop::TYPE_IO);
  PlatformChannel channel;

  // Create a "client" Channel with one end of the pipe, and Start() it.
  std::unique_ptr<base::Thread> client_thread =
      std::make_unique<base::Thread>("clientio_thread");
  client_thread->StartWithOptions(
      base::Thread::Options(base::MessageLoop::TYPE_IO, 0));

  scoped_refptr<Channel> client_channel =
      Channel::Create(nullptr, ConnectionParams(channel.TakeRemoteEndpoint()),
                      client_thread->task_runner());
  client_channel->Start();

  // On the "client" IO thread, write a message. BindOnce moves the move-only
  // message into the callback.
  Channel::MessagePtr message = CreateDefaultMessage(false);
  client_thread->task_runner()->PostTask(
      FROM_HERE,
      base::BindOnce(&Channel::Write, client_channel, std::move(message)));

  // Create a "server" Channel with the other end of the pipe, and process the
  // messages from it. The |server_delegate| will ShutDown the client end of
  // the pipe after the first message, and quit the RunLoop when OnChannelError
  // is received.
  base::RunLoop run_loop;
  ChannelTestShutdownAndWriteDelegate server_delegate(
      channel.TakeLocalEndpoint(), message_loop.task_runner(),
      std::move(client_channel), std::move(client_thread),
      run_loop.QuitClosure());

  run_loop.Run();
}
275 
276 }  // namespace
277 }  // namespace core
278 }  // namespace mojo
279