1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "base/bind.h"
6 #include "base/callback.h"
7 #include "base/macros.h"
8 #include "base/run_loop.h"
9 #include "mojo/core/embedder/embedder.h"
10 #include "mojo/public/cpp/bindings/binding.h"
11 #include "mojo/public/cpp/bindings/message.h"
12 #include "mojo/public/cpp/bindings/tests/bindings_test_base.h"
13 #include "mojo/public/interfaces/bindings/tests/test_bad_messages.mojom.h"
14 #include "testing/gtest/include/gtest/gtest.h"
15
16 namespace mojo {
17 namespace test {
18 namespace {
19
20 class TestBadMessagesImpl : public TestBadMessages {
21 public:
TestBadMessagesImpl()22 TestBadMessagesImpl() : binding_(this) {}
~TestBadMessagesImpl()23 ~TestBadMessagesImpl() override {}
24
BindImpl(TestBadMessagesRequest request)25 void BindImpl(TestBadMessagesRequest request) {
26 binding_.Bind(std::move(request));
27 }
28
bad_message_callback()29 ReportBadMessageCallback& bad_message_callback() {
30 return bad_message_callback_;
31 }
32
33 private:
34 // TestBadMessages:
RejectEventually(const RejectEventuallyCallback & callback)35 void RejectEventually(const RejectEventuallyCallback& callback) override {
36 bad_message_callback_ = GetBadMessageCallback();
37 callback.Run();
38 }
39
RequestResponse(const RequestResponseCallback & callback)40 void RequestResponse(const RequestResponseCallback& callback) override {
41 callback.Run();
42 }
43
RejectSync(const RejectSyncCallback & callback)44 void RejectSync(const RejectSyncCallback& callback) override {
45 callback.Run();
46 ReportBadMessage("go away");
47 }
48
RequestResponseSync(const RequestResponseSyncCallback & callback)49 void RequestResponseSync(
50 const RequestResponseSyncCallback& callback) override {
51 callback.Run();
52 }
53
54 ReportBadMessageCallback bad_message_callback_;
55 mojo::Binding<TestBadMessages> binding_;
56
57 DISALLOW_COPY_AND_ASSIGN(TestBadMessagesImpl);
58 };
59
60 class ReportBadMessageTest : public BindingsTestBase {
61 public:
ReportBadMessageTest()62 ReportBadMessageTest() {}
63
SetUp()64 void SetUp() override {
65 mojo::core::SetDefaultProcessErrorCallback(base::Bind(
66 &ReportBadMessageTest::OnProcessError, base::Unretained(this)));
67
68 impl_.BindImpl(MakeRequest(&proxy_));
69 }
70
TearDown()71 void TearDown() override {
72 mojo::core::SetDefaultProcessErrorCallback(
73 mojo::core::ProcessErrorCallback());
74 }
75
proxy()76 TestBadMessages* proxy() { return proxy_.get(); }
77
impl()78 TestBadMessagesImpl* impl() { return &impl_; }
79
SetErrorHandler(const base::Closure & handler)80 void SetErrorHandler(const base::Closure& handler) {
81 error_handler_ = handler;
82 }
83
84 private:
OnProcessError(const std::string & error)85 void OnProcessError(const std::string& error) {
86 if (!error_handler_.is_null())
87 error_handler_.Run();
88 }
89
90 TestBadMessagesPtr proxy_;
91 TestBadMessagesImpl impl_;
92 base::Closure error_handler_;
93 };
94
TEST_P(ReportBadMessageTest, Request) {
  // Basic immediate error reporting: the sync call reports its own request as
  // bad. The call is still expected to succeed, and the error handler is
  // expected to have run by the time it returns.
  bool bad_message_seen = false;
  SetErrorHandler(
      base::Bind([](bool* seen) { *seen = true; }, &bad_message_seen));

  const bool call_succeeded = proxy()->RejectSync();
  EXPECT_TRUE(call_succeeded);
  EXPECT_TRUE(bad_message_seen);
}
102
TEST_P(ReportBadMessageTest, RequestAsync) {
  bool error = false;
  SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error));

  // This should capture a bad message reporting callback in the impl rather
  // than report anything right away.
  base::RunLoop loop;
  proxy()->RejectEventually(loop.QuitClosure());
  loop.Run();

  EXPECT_FALSE(error);

  // Use a fatal gtest assertion instead of DCHECK. DCHECKs are compiled out
  // when DCHECK_IS_ON() is false, and a null callback would then crash on
  // Run() instead of cleanly failing the test.
  ASSERT_FALSE(impl()->bad_message_callback().is_null());

  // Now we can run the callback and it should trigger a bad message report.
  std::move(impl()->bad_message_callback()).Run("bad!");
  EXPECT_TRUE(error);
}
119
TEST_P(ReportBadMessageTest, Response) {
  bool bad_message_seen = false;
  SetErrorHandler(
      base::Bind([](bool* seen) { *seen = true; }, &bad_message_seen));

  // The reply handler itself calls ReportBadMessage(), flagging the response
  // it is handling. That is expected to reach the error handler.
  base::RunLoop run_loop;
  proxy()->RequestResponse(base::Bind(
      [](const base::Closure& done) {
        ReportBadMessage("no way!");
        done.Run();
      },
      run_loop.QuitClosure()));
  run_loop.Run();

  EXPECT_TRUE(bad_message_seen);
}
137
TEST_P(ReportBadMessageTest, ResponseAsync) {
  bool bad_message_seen = false;
  SetErrorHandler(
      base::Bind([](bool* seen) { *seen = true; }, &bad_message_seen));

  // Filled in by the reply handler below. The report itself is deferred until
  // the test runs this callback.
  ReportBadMessageCallback deferred_report;
  base::RunLoop run_loop;
  proxy()->RequestResponse(base::Bind(
      [](const base::Closure& done, ReportBadMessageCallback* out) {
        *out = GetBadMessageCallback();
        done.Run();
      },
      run_loop.QuitClosure(), &deferred_report));
  run_loop.Run();

  // Merely capturing the callback must not report anything.
  EXPECT_FALSE(bad_message_seen);

  // Running the captured callback should report the response as bad and fire
  // the error handler immediately.
  std::move(deferred_report).Run("this message is bad and should feel bad");
  EXPECT_TRUE(bad_message_seen);
}
162
TEST_P(ReportBadMessageTest, ResponseSync) {
  bool error = false;
  SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error));

  // |context| is created before the sync call so it can be used afterwards to
  // report the sync response as bad.
  SyncMessageResponseContext context;
  // Check the sync call's result, as the Request test does for RejectSync().
  // If the call failed there is presumably no response for |context| to
  // report on, and the test would fail later in a less obvious way.
  ASSERT_TRUE(proxy()->RequestResponseSync());

  EXPECT_FALSE(error);
  context.ReportBadMessage("i don't like this response");
  EXPECT_TRUE(error);
}
174
TEST_P(ReportBadMessageTest, ResponseSyncDeferred) {
  bool error = false;
  SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error));

  // Capture a reporting callback for the sync response, then let |context|
  // go out of scope before using it.
  ReportBadMessageCallback bad_message_callback;
  {
    SyncMessageResponseContext context;
    // Check the sync call's result, as the Request test does for RejectSync().
    // If the call failed there is presumably no response to capture a
    // callback for.
    ASSERT_TRUE(proxy()->RequestResponseSync());
    bad_message_callback = context.GetBadMessageCallback();
  }

  // The callback should still report the message after |context| is gone.
  EXPECT_FALSE(error);
  std::move(bad_message_callback).Run("nope nope nope");
  EXPECT_TRUE(error);
}
190
// Instantiates the parameterized ReportBadMessageTest cases (TEST_P above).
// The parameter values come from the INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P
// macro in bindings_test_base.h — presumably the bindings serialization
// modes; confirm there.
INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(ReportBadMessageTest);
192
193 } // namespace
194 } // namespace test
195 } // namespace mojo
196