1 // Copyright 2016 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include "base/bind.h"
6 #include "base/callback.h"
7 #include "base/macros.h"
8 #include "base/run_loop.h"
9 #include "mojo/core/embedder/embedder.h"
10 #include "mojo/public/cpp/bindings/binding.h"
11 #include "mojo/public/cpp/bindings/message.h"
12 #include "mojo/public/cpp/bindings/tests/bindings_test_base.h"
13 #include "mojo/public/interfaces/bindings/tests/test_bad_messages.mojom.h"
14 #include "testing/gtest/include/gtest/gtest.h"
15 
16 namespace mojo {
17 namespace test {
18 namespace {
19 
20 class TestBadMessagesImpl : public TestBadMessages {
21  public:
TestBadMessagesImpl()22   TestBadMessagesImpl() : binding_(this) {}
~TestBadMessagesImpl()23   ~TestBadMessagesImpl() override {}
24 
BindImpl(TestBadMessagesRequest request)25   void BindImpl(TestBadMessagesRequest request) {
26     binding_.Bind(std::move(request));
27   }
28 
bad_message_callback()29   ReportBadMessageCallback& bad_message_callback() {
30     return bad_message_callback_;
31   }
32 
33  private:
34   // TestBadMessages:
RejectEventually(const RejectEventuallyCallback & callback)35   void RejectEventually(const RejectEventuallyCallback& callback) override {
36     bad_message_callback_ = GetBadMessageCallback();
37     callback.Run();
38   }
39 
RequestResponse(const RequestResponseCallback & callback)40   void RequestResponse(const RequestResponseCallback& callback) override {
41     callback.Run();
42   }
43 
RejectSync(const RejectSyncCallback & callback)44   void RejectSync(const RejectSyncCallback& callback) override {
45     callback.Run();
46     ReportBadMessage("go away");
47   }
48 
RequestResponseSync(const RequestResponseSyncCallback & callback)49   void RequestResponseSync(
50       const RequestResponseSyncCallback& callback) override {
51     callback.Run();
52   }
53 
54   ReportBadMessageCallback bad_message_callback_;
55   mojo::Binding<TestBadMessages> binding_;
56 
57   DISALLOW_COPY_AND_ASSIGN(TestBadMessagesImpl);
58 };
59 
60 class ReportBadMessageTest : public BindingsTestBase {
61  public:
ReportBadMessageTest()62   ReportBadMessageTest() {}
63 
SetUp()64   void SetUp() override {
65     mojo::core::SetDefaultProcessErrorCallback(base::Bind(
66         &ReportBadMessageTest::OnProcessError, base::Unretained(this)));
67 
68     impl_.BindImpl(MakeRequest(&proxy_));
69   }
70 
TearDown()71   void TearDown() override {
72     mojo::core::SetDefaultProcessErrorCallback(
73         mojo::core::ProcessErrorCallback());
74   }
75 
proxy()76   TestBadMessages* proxy() { return proxy_.get(); }
77 
impl()78   TestBadMessagesImpl* impl() { return &impl_; }
79 
SetErrorHandler(const base::Closure & handler)80   void SetErrorHandler(const base::Closure& handler) {
81     error_handler_ = handler;
82   }
83 
84  private:
OnProcessError(const std::string & error)85   void OnProcessError(const std::string& error) {
86     if (!error_handler_.is_null())
87       error_handler_.Run();
88   }
89 
90   TestBadMessagesPtr proxy_;
91   TestBadMessagesImpl impl_;
92   base::Closure error_handler_;
93 };
94 
// Basic immediate error reporting: the sync RejectSync() call still succeeds,
// and the bad message the impl reports reaches the process error handler.
TEST_P(ReportBadMessageTest, Request) {
  bool error_seen = false;
  SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error_seen));

  const bool call_succeeded = proxy()->RejectSync();
  EXPECT_TRUE(call_succeeded);
  EXPECT_TRUE(error_seen);
}
102 
// A bad-message callback captured while a request was being dispatched can be
// run later. The error handler fires only when the callback is run.
TEST_P(ReportBadMessageTest, RequestAsync) {
  bool error = false;
  SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error));

  // This should capture a bad message reporting callback in the impl.
  base::RunLoop loop;
  proxy()->RejectEventually(loop.QuitClosure());
  loop.Run();

  EXPECT_FALSE(error);

  // Use a gtest ASSERT rather than DCHECK. DCHECK is compiled out in release
  // builds, and in debug builds it would crash instead of failing the test.
  // The ASSERT stops the test cleanly if no callback was captured.
  ASSERT_FALSE(impl()->bad_message_callback().is_null());

  // Now we can run the callback and it should trigger a bad message report.
  std::move(impl()->bad_message_callback()).Run("bad!");
  EXPECT_TRUE(error);
}
119 
// Calling ReportBadMessage() from inside a response callback triggers the
// process error handler.
TEST_P(ReportBadMessageTest, Response) {
  bool error_seen = false;
  SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error_seen));

  base::RunLoop run_loop;
  auto on_response = base::Bind(
      [] (const base::Closure& quit_closure) {
        // Reject the response while it is being dispatched.
        ReportBadMessage("no way!");
        quit_closure.Run();
      },
      run_loop.QuitClosure());
  proxy()->RequestResponse(on_response);
  run_loop.Run();

  EXPECT_TRUE(error_seen);
}
137 
// A bad-message callback captured inside a response callback can be run after
// the response has been handled. Running it triggers the error handler.
TEST_P(ReportBadMessageTest, ResponseAsync) {
  bool error = false;
  SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error));

  ReportBadMessageCallback bad_message_callback;
  base::RunLoop loop;
  proxy()->RequestResponse(
      base::Bind([] (const base::Closure& quit,
                     ReportBadMessageCallback* callback) {
        // Capture the bad message callback inside the response callback.
        *callback = GetBadMessageCallback();
        quit.Run();
      },
      loop.QuitClosure(), &bad_message_callback));
  loop.Run();

  EXPECT_FALSE(error);

  // Guard against running a null callback, so that a failed capture shows up
  // as a test failure rather than a crash (mirrors RequestAsync).
  ASSERT_FALSE(bad_message_callback.is_null());

  // Invoking this callback should report a bad message and trigger the error
  // handler immediately.
  std::move(bad_message_callback)
      .Run("this message is bad and should feel bad");
  EXPECT_TRUE(error);
}
162 
// A SyncMessageResponseContext created around a sync call can report that
// call's response as bad after the call returns.
TEST_P(ReportBadMessageTest, ResponseSync) {
  bool error_seen = false;
  SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error_seen));

  // Created before the sync call; presumably it captures the response
  // dispatched during that call (see SyncMessageResponseContext to confirm).
  SyncMessageResponseContext response_context;
  proxy()->RequestResponseSync();
  EXPECT_FALSE(error_seen);

  response_context.ReportBadMessage("i don't like this response");
  EXPECT_TRUE(error_seen);
}
174 
// A bad-message callback taken from a SyncMessageResponseContext stays usable
// after the context has gone out of scope.
TEST_P(ReportBadMessageTest, ResponseSyncDeferred) {
  bool error = false;
  SetErrorHandler(base::Bind([] (bool* flag) { *flag = true; }, &error));

  ReportBadMessageCallback bad_message_callback;
  {
    SyncMessageResponseContext context;
    proxy()->RequestResponseSync();
    bad_message_callback = context.GetBadMessageCallback();
  }

  EXPECT_FALSE(error);

  // Guard against running a null callback, so that a failed capture shows up
  // as a test failure rather than a crash (mirrors RequestAsync).
  ASSERT_FALSE(bad_message_callback.is_null());

  std::move(bad_message_callback).Run("nope nope nope");
  EXPECT_TRUE(error);
}
190 
// Instantiates the parameterized ReportBadMessageTest cases through the mojo
// bindings test helper macro. The parameter set comes from the macro's
// definition.
INSTANTIATE_MOJO_BINDINGS_TEST_CASE_P(ReportBadMessageTest);
192 
193 }  // namespace
194 }  // namespace test
195 }  // namespace mojo
196