1 // Copyright 2014 The Chromium OS Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #include <brillo/dbus/dbus_method_invoker.h>
6 
7 #include <string>
8 
9 #include <brillo/bind_lambda.h>
10 #include <dbus/mock_bus.h>
11 #include <dbus/mock_object_proxy.h>
12 #include <dbus/scoped_dbus_error.h>
13 #include <gmock/gmock.h>
14 #include <gtest/gtest.h>
15 
16 #include "brillo/dbus/test.pb.h"
17 
18 using testing::AnyNumber;
19 using testing::InSequence;
20 using testing::Invoke;
21 using testing::Return;
22 using testing::_;
23 
24 using dbus::MessageReader;
25 using dbus::MessageWriter;
26 using dbus::Response;
27 
28 namespace brillo {
29 namespace dbus_utils {
30 
// Object path, service name and interface of the fake remote object that the
// mock object proxies in this file stand in for.
constexpr char kTestPath[] = "/test/path";
constexpr char kTestServiceName[] = "org.test.Object";
constexpr char kTestInterface[] = "org.test.Object.TestInterface";

// Member names the fixtures' fake method-call handlers dispatch on.
constexpr char kTestMethod1[] = "TestMethod1";
constexpr char kTestMethod2[] = "TestMethod2";
constexpr char kTestMethod3[] = "TestMethod3";
constexpr char kTestMethod4[] = "TestMethod4";
38 
39 class DBusMethodInvokerTest : public testing::Test {
40  public:
SetUp()41   void SetUp() override {
42     dbus::Bus::Options options;
43     options.bus_type = dbus::Bus::SYSTEM;
44     bus_ = new dbus::MockBus(options);
45     // By default, don't worry about threading assertions.
46     EXPECT_CALL(*bus_, AssertOnOriginThread()).Times(AnyNumber());
47     EXPECT_CALL(*bus_, AssertOnDBusThread()).Times(AnyNumber());
48     // Use a mock exported object.
49     mock_object_proxy_ = new dbus::MockObjectProxy(
50         bus_.get(), kTestServiceName, dbus::ObjectPath(kTestPath));
51     EXPECT_CALL(*bus_,
52                 GetObjectProxy(kTestServiceName, dbus::ObjectPath(kTestPath)))
53         .WillRepeatedly(Return(mock_object_proxy_.get()));
54     int def_timeout_ms = dbus::ObjectProxy::TIMEOUT_USE_DEFAULT;
55     EXPECT_CALL(*mock_object_proxy_,
56                 MockCallMethodAndBlockWithErrorDetails(_, def_timeout_ms, _))
57         .WillRepeatedly(Invoke(this, &DBusMethodInvokerTest::CreateResponse));
58   }
59 
TearDown()60   void TearDown() override { bus_ = nullptr; }
61 
CreateResponse(dbus::MethodCall * method_call,int,dbus::ScopedDBusError * dbus_error)62   Response* CreateResponse(dbus::MethodCall* method_call,
63                            int /* timeout_ms */,
64                            dbus::ScopedDBusError* dbus_error) {
65     if (method_call->GetInterface() == kTestInterface) {
66       if (method_call->GetMember() == kTestMethod1) {
67         MessageReader reader(method_call);
68         int v1, v2;
69         // Input: two ints.
70         // Output: sum of the ints converted to string.
71         if (reader.PopInt32(&v1) && reader.PopInt32(&v2)) {
72           auto response = Response::CreateEmpty();
73           MessageWriter writer(response.get());
74           writer.AppendString(std::to_string(v1 + v2));
75           return response.release();
76         }
77       } else if (method_call->GetMember() == kTestMethod2) {
78         method_call->SetSerial(123);
79         dbus_set_error(dbus_error->get(), "org.MyError", "My error message");
80         return nullptr;
81       } else if (method_call->GetMember() == kTestMethod3) {
82         MessageReader reader(method_call);
83         dbus_utils_test::TestMessage msg;
84         if (PopValueFromReader(&reader, &msg)) {
85           auto response = Response::CreateEmpty();
86           MessageWriter writer(response.get());
87           AppendValueToWriter(&writer, msg);
88           return response.release();
89         }
90       } else if (method_call->GetMember() == kTestMethod4) {
91         method_call->SetSerial(123);
92         MessageReader reader(method_call);
93         dbus::FileDescriptor fd;
94         if (reader.PopFileDescriptor(&fd)) {
95           auto response = Response::CreateEmpty();
96           MessageWriter writer(response.get());
97           fd.CheckValidity();
98           writer.AppendFileDescriptor(fd);
99           return response.release();
100         }
101       }
102     }
103 
104     LOG(ERROR) << "Unexpected method call: " << method_call->ToString();
105     return nullptr;
106   }
107 
CallTestMethod(int v1,int v2)108   std::string CallTestMethod(int v1, int v2) {
109     std::unique_ptr<dbus::Response> response =
110         brillo::dbus_utils::CallMethodAndBlock(mock_object_proxy_.get(),
111                                                kTestInterface, kTestMethod1,
112                                                nullptr, v1, v2);
113     EXPECT_NE(nullptr, response.get());
114     std::string result;
115     using brillo::dbus_utils::ExtractMethodCallResults;
116     EXPECT_TRUE(ExtractMethodCallResults(response.get(), nullptr, &result));
117     return result;
118   }
119 
CallProtobufTestMethod(const dbus_utils_test::TestMessage & message)120   dbus_utils_test::TestMessage CallProtobufTestMethod(
121       const dbus_utils_test::TestMessage& message) {
122     std::unique_ptr<dbus::Response> response =
123         brillo::dbus_utils::CallMethodAndBlock(mock_object_proxy_.get(),
124                                                kTestInterface, kTestMethod3,
125                                                nullptr, message);
126     EXPECT_NE(nullptr, response.get());
127     dbus_utils_test::TestMessage result;
128     using brillo::dbus_utils::ExtractMethodCallResults;
129     EXPECT_TRUE(ExtractMethodCallResults(response.get(), nullptr, &result));
130     return result;
131   }
132 
133   // Sends a file descriptor received over D-Bus back to the caller.
EchoFD(const dbus::FileDescriptor & fd_in)134   dbus::FileDescriptor EchoFD(const dbus::FileDescriptor& fd_in) {
135     std::unique_ptr<dbus::Response> response =
136         brillo::dbus_utils::CallMethodAndBlock(mock_object_proxy_.get(),
137                                                kTestInterface, kTestMethod4,
138                                                nullptr, fd_in);
139     EXPECT_NE(nullptr, response.get());
140     dbus::FileDescriptor fd_out;
141     using brillo::dbus_utils::ExtractMethodCallResults;
142     EXPECT_TRUE(ExtractMethodCallResults(response.get(), nullptr, &fd_out));
143     return fd_out;
144   }
145 
146   scoped_refptr<dbus::MockBus> bus_;
147   scoped_refptr<dbus::MockObjectProxy> mock_object_proxy_;
148 };
149 
// Blocking call with integer arguments: the fixture's fake handler answers
// kTestMethod1 with the decimal string of the sum of its two inputs.
TEST_F(DBusMethodInvokerTest, TestSuccess) {
  const std::string two_plus_two = CallTestMethod(2, 2);
  EXPECT_EQ("4", two_plus_two);

  const std::string three_plus_seven = CallTestMethod(3, 7);
  EXPECT_EQ("10", three_plus_seven);

  // A negative operand yields a negative sum.
  const std::string negative_sum = CallTestMethod(13, -17);
  EXPECT_EQ("-4", negative_sum);
}
155 
// Blocking call that fails: the fixture's fake handler answers kTestMethod2
// by setting the D-Bus error "org.MyError". The call must return no response
// and report the error through the brillo::ErrorPtr out-parameter.
TEST_F(DBusMethodInvokerTest, TestFailure) {
  brillo::ErrorPtr error;
  std::unique_ptr<dbus::Response> response =
      brillo::dbus_utils::CallMethodAndBlock(
          mock_object_proxy_.get(), kTestInterface, kTestMethod2, &error);
  EXPECT_EQ(nullptr, response.get());
  // Fatal assertion: the checks below dereference |error|, so a missing error
  // must fail the test here rather than crash the test binary.
  ASSERT_NE(nullptr, error.get());
  EXPECT_EQ(brillo::errors::dbus::kDomain, error->GetDomain());
  EXPECT_EQ("org.MyError", error->GetCode());
  EXPECT_EQ("My error message", error->GetMessage());
}
166 
// Round-trips a protobuf message through kTestMethod3, which the fixture's
// fake handler answers by echoing the message it received.
TEST_F(DBusMethodInvokerTest, TestProtobuf) {
  dbus_utils_test::TestMessage request;
  request.set_foo(123);
  request.set_bar("bar");

  const dbus_utils_test::TestMessage echoed = CallProtobufTestMethod(request);

  // Both fields must come back unchanged.
  EXPECT_EQ(123, echoed.foo());
  EXPECT_EQ("bar", echoed.bar());
}
177 
// Echoes the three standard descriptors through kTestMethod4.
TEST_F(DBusMethodInvokerTest, TestFileDescriptors) {
  // Passing a file descriptor over D-Bus would effectively duplicate the fd.
  // So the resulting file descriptor value would be different but it still
  // should be valid.
  auto expect_echo_is_duplicate = [this](int raw_fd) {
    dbus::FileDescriptor original(raw_fd);
    original.CheckValidity();
    EXPECT_NE(original.value(), EchoFD(original).value());
  };

  // stdin, stdout, stderr -- checked in that order.
  const int kStandardFds[] = {0, 1, 2};
  for (int raw_fd : kStandardFds)
    expect_echo_is_duplicate(raw_fd);
}
192 
193 //////////////////////////////////////////////////////////////////////////////
194 // Asynchronous method invocation support
195 
196 class AsyncDBusMethodInvokerTest : public testing::Test {
197  public:
SetUp()198   void SetUp() override {
199     dbus::Bus::Options options;
200     options.bus_type = dbus::Bus::SYSTEM;
201     bus_ = new dbus::MockBus(options);
202     // By default, don't worry about threading assertions.
203     EXPECT_CALL(*bus_, AssertOnOriginThread()).Times(AnyNumber());
204     EXPECT_CALL(*bus_, AssertOnDBusThread()).Times(AnyNumber());
205     // Use a mock exported object.
206     mock_object_proxy_ = new dbus::MockObjectProxy(
207         bus_.get(), kTestServiceName, dbus::ObjectPath(kTestPath));
208     EXPECT_CALL(*bus_,
209                 GetObjectProxy(kTestServiceName, dbus::ObjectPath(kTestPath)))
210         .WillRepeatedly(Return(mock_object_proxy_.get()));
211     int def_timeout_ms = dbus::ObjectProxy::TIMEOUT_USE_DEFAULT;
212     EXPECT_CALL(*mock_object_proxy_,
213                 CallMethodWithErrorCallback(_, def_timeout_ms, _, _))
214         .WillRepeatedly(Invoke(this, &AsyncDBusMethodInvokerTest::HandleCall));
215   }
216 
TearDown()217   void TearDown() override { bus_ = nullptr; }
218 
HandleCall(dbus::MethodCall * method_call,int,dbus::ObjectProxy::ResponseCallback success_callback,dbus::ObjectProxy::ErrorCallback error_callback)219   void HandleCall(dbus::MethodCall* method_call,
220                   int /* timeout_ms */,
221                   dbus::ObjectProxy::ResponseCallback success_callback,
222                   dbus::ObjectProxy::ErrorCallback error_callback) {
223     if (method_call->GetInterface() == kTestInterface) {
224       if (method_call->GetMember() == kTestMethod1) {
225         MessageReader reader(method_call);
226         int v1, v2;
227         // Input: two ints.
228         // Output: sum of the ints converted to string.
229         if (reader.PopInt32(&v1) && reader.PopInt32(&v2)) {
230           auto response = Response::CreateEmpty();
231           MessageWriter writer(response.get());
232           writer.AppendString(std::to_string(v1 + v2));
233           success_callback.Run(response.get());
234         }
235         return;
236       } else if (method_call->GetMember() == kTestMethod2) {
237         method_call->SetSerial(123);
238         auto error_response = dbus::ErrorResponse::FromMethodCall(
239             method_call, "org.MyError", "My error message");
240         error_callback.Run(error_response.get());
241         return;
242       }
243     }
244 
245     LOG(FATAL) << "Unexpected method call: " << method_call->ToString();
246   }
247 
SuccessCallback(const std::string & in_result,int * in_counter)248   base::Callback<void(const std::string&)> SuccessCallback(
249       const std::string& in_result, int* in_counter) {
250     return base::Bind(
251         [](const std::string& result,
252            int* counter,
253            const std::string& actual_result) {
254           (*counter)++;
255           EXPECT_EQ(result, actual_result);
256         },
257         in_result,
258         base::Unretained(in_counter));
259   }
260 
SuccessCallback(int * in_counter)261   base::Callback<void(const std::string&)> SuccessCallback(int* in_counter) {
262     return base::Bind(
263         [](int* counter, const std::string& actual_result) {
264           (*counter)++;
265           EXPECT_EQ("", actual_result);
266         },
267         base::Unretained(in_counter));
268   }
269 
ErrorCallback(int * in_counter)270   AsyncErrorCallback ErrorCallback(int* in_counter) {
271     return base::Bind(
272         [](int* counter, brillo::Error* error) {
273           (*counter)++;
274           EXPECT_NE(nullptr, error);
275           EXPECT_EQ("", error->GetDomain());
276           EXPECT_EQ("", error->GetCode());
277           EXPECT_EQ("", error->GetMessage());
278         },
279         base::Unretained(in_counter));
280   }
281 
ErrorCallback(const std::string & domain,const std::string & code,const std::string & message,int * in_counter)282   AsyncErrorCallback ErrorCallback(const std::string& domain,
283                                    const std::string& code,
284                                    const std::string& message,
285                                    int* in_counter) {
286     return base::Bind(
287         [](const std::string& domain,
288            const std::string& code,
289            const std::string& message,
290            int* counter,
291            brillo::Error* error) {
292           (*counter)++;
293           EXPECT_NE(nullptr, error);
294           EXPECT_EQ(domain, error->GetDomain());
295           EXPECT_EQ(code, error->GetCode());
296           EXPECT_EQ(message, error->GetMessage());
297         },
298         domain,
299         code,
300         message,
301         base::Unretained(in_counter));
302   }
303 
304   scoped_refptr<dbus::MockBus> bus_;
305   scoped_refptr<dbus::MockObjectProxy> mock_object_proxy_;
306 };
307 
// Asynchronous calls that succeed: each kTestMethod1 call must run its success
// callback exactly once with the stringified sum, and never the error
// callback. The fixture's fake handler runs the callbacks before CallMethod()
// returns, so the counters can be checked immediately afterwards.
TEST_F(AsyncDBusMethodInvokerTest, TestSuccess) {
  int error_count = 0;
  int success_count = 0;
  // SuccessCallback()/ErrorCallback() already return fully bound callbacks,
  // so they are passed as-is; wrapping them in another base::Bind() would add
  // nothing.
  brillo::dbus_utils::CallMethod(
      mock_object_proxy_.get(),
      kTestInterface,
      kTestMethod1,
      SuccessCallback("4", &success_count),
      ErrorCallback(&error_count),
      2, 2);
  brillo::dbus_utils::CallMethod(
      mock_object_proxy_.get(),
      kTestInterface,
      kTestMethod1,
      SuccessCallback("10", &success_count),
      ErrorCallback(&error_count),
      3, 7);
  brillo::dbus_utils::CallMethod(
      mock_object_proxy_.get(),
      kTestInterface,
      kTestMethod1,
      SuccessCallback("-4", &success_count),
      ErrorCallback(&error_count),
      13, -17);
  EXPECT_EQ(0, error_count);
  EXPECT_EQ(3, success_count);
}
335 
// Asynchronous call that fails: kTestMethod2 must run the error callback
// exactly once with the D-Bus error produced by the fixture's fake handler,
// and never the success callback.
TEST_F(AsyncDBusMethodInvokerTest, TestFailure) {
  int error_count = 0;
  int success_count = 0;
  // SuccessCallback()/ErrorCallback() already return fully bound callbacks,
  // so they are passed as-is; wrapping them in another base::Bind() would add
  // nothing.
  brillo::dbus_utils::CallMethod(
      mock_object_proxy_.get(),
      kTestInterface,
      kTestMethod2,
      SuccessCallback(&success_count),
      ErrorCallback(brillo::errors::dbus::kDomain,
                    "org.MyError",
                    "My error message",
                    &error_count),
      2, 2);
  EXPECT_EQ(1, error_count);
  EXPECT_EQ(0, success_count);
}
352 
353 }  // namespace dbus_utils
354 }  // namespace brillo
355