1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "common/libs/utils/unix_sockets.h"
18 
19 #include <string>
20 #include <utility>
21 
22 #include <android-base/logging.h>
23 #include <android-base/result.h>
24 #include <gtest/gtest.h>
25 
26 #include "common/libs/fs/shared_buf.h"
27 #include "common/libs/fs/shared_fd.h"
28 
29 namespace cuttlefish {
30 
CreateMemFDWithData(const std::string & data)31 SharedFD CreateMemFDWithData(const std::string& data) {
32   auto memfd = SharedFD::MemfdCreate("");
33   CHECK(WriteAll(memfd, data) == data.size()) << memfd->StrError();
34   CHECK(memfd->LSeek(0, SEEK_SET) == 0);
35   return memfd;
36 }
37 
ReadAllFDData(SharedFD fd)38 std::string ReadAllFDData(SharedFD fd) {
39   std::string data;
40   CHECK(ReadAll(fd, &data) > 0) << fd->StrError();
41   return data;
42 }
43 
// FileDescriptors() must return the descriptors from every control message in
// order. Each descriptor is placed in its own control message so extraction
// across more than one control entry is exercised.
TEST(UnixSocketMessage, ExtractFileDescriptors) {
  auto memfd1 = CreateMemFDWithData("abc");
  auto memfd2 = CreateMemFDWithData("def");

  UnixSocketMessage message;
  auto control1 = ControlMessage::FromFileDescriptors({memfd1});
  ASSERT_TRUE(control1.ok()) << control1.error().Trace();
  message.control.emplace_back(std::move(*control1));
  auto control2 = ControlMessage::FromFileDescriptors({memfd2});
  ASSERT_TRUE(control2.ok()) << control2.error().Trace();
  message.control.emplace_back(std::move(*control2));

  ASSERT_TRUE(message.HasFileDescriptors());
  auto fds = message.FileDescriptors();
  ASSERT_TRUE(fds.ok()) << fds.error().Trace();
  // Guard the indexing below: reading past the end of the vector would be UB
  // instead of a clean test failure.
  ASSERT_EQ(2, fds->size());
  ASSERT_EQ("abc", ReadAllFDData((*fds)[0]));
  ASSERT_EQ("def", ReadAllFDData((*fds)[1]));
}
62 
UnixMessageSocketPair()63 std::pair<UnixMessageSocket, UnixMessageSocket> UnixMessageSocketPair() {
64   SharedFD sock1, sock2;
65   CHECK(SharedFD::SocketPair(AF_UNIX, SOCK_SEQPACKET, 0, &sock1, &sock2));
66   return {UnixMessageSocket(sock1), UnixMessageSocket(sock2)};
67 }
68 
// A message with payload bytes only must be received with identical bytes and
// no control messages.
TEST(UnixMessageSocket, SendPlainMessage) {
  auto sockets = UnixMessageSocketPair();
  auto& sender = sockets.first;
  auto& receiver = sockets.second;

  UnixSocketMessage sent{{1, 2, 3}, {}};
  auto send_status = sender.WriteMessage(sent);
  ASSERT_TRUE(send_status.ok()) << send_status.error().Trace();

  auto received = receiver.ReadMessage();
  ASSERT_TRUE(received.ok()) << received.error().Trace();
  ASSERT_EQ(sent.data, received->data);
  ASSERT_EQ(0, received->control.size());
}
80 
// One descriptor attached as a control message must arrive as a single
// control message holding one descriptor with the same file contents.
TEST(UnixMessageSocket, SendFileDescriptor) {
  auto sockets = UnixMessageSocketPair();
  auto& sender = sockets.first;
  auto& receiver = sockets.second;

  UnixSocketMessage sent{{4, 5, 6}, {}};
  auto fd_control =
      ControlMessage::FromFileDescriptors({CreateMemFDWithData("abc")});
  ASSERT_TRUE(fd_control.ok()) << fd_control.error().Trace();
  sent.control.emplace_back(std::move(*fd_control));

  auto send_status = sender.WriteMessage(sent);
  ASSERT_TRUE(send_status.ok()) << send_status.error().Trace();

  auto received = receiver.ReadMessage();
  ASSERT_TRUE(received.ok()) << received.error().Trace();
  ASSERT_EQ(sent.data, received->data);

  ASSERT_EQ(1, received->control.size());
  auto received_fds = received->control[0].AsSharedFDs();
  ASSERT_TRUE(received_fds.ok()) << received_fds.error().Trace();
  ASSERT_EQ(1, received_fds->size());
  ASSERT_EQ("abc", ReadAllFDData(received_fds->at(0)));
}
102 
// Two descriptors packed into one control message must arrive together, in
// order, in a single received control message.
TEST(UnixMessageSocket, SendTwoFileDescriptors) {
  auto first_memfd = CreateMemFDWithData("abc");
  auto second_memfd = CreateMemFDWithData("def");

  auto sockets = UnixMessageSocketPair();
  auto& sender = sockets.first;
  auto& receiver = sockets.second;

  UnixSocketMessage sent{{7, 8, 9}, {}};
  auto fd_control =
      ControlMessage::FromFileDescriptors({first_memfd, second_memfd});
  ASSERT_TRUE(fd_control.ok()) << fd_control.error().Trace();
  sent.control.emplace_back(std::move(*fd_control));

  auto send_status = sender.WriteMessage(sent);
  ASSERT_TRUE(send_status.ok()) << send_status.error().Trace();

  auto received = receiver.ReadMessage();
  ASSERT_TRUE(received.ok()) << received.error().Trace();
  ASSERT_EQ(sent.data, received->data);

  ASSERT_EQ(1, received->control.size());
  auto received_fds = received->control[0].AsSharedFDs();
  ASSERT_TRUE(received_fds.ok()) << received_fds.error().Trace();
  ASSERT_EQ(2, received_fds->size());

  ASSERT_EQ("abc", ReadAllFDData(received_fds->at(0)));
  ASSERT_EQ("def", ReadAllFDData(received_fds->at(1)));
}
127 
// With credential passing enabled on both ends, explicitly attached
// credentials matching this process must be delivered unchanged.
TEST(UnixMessageSocket, SendCredentials) {
  auto sockets = UnixMessageSocketPair();
  auto& sender = sockets.first;
  auto& receiver = sockets.second;
  auto sender_creds_status = sender.EnableCredentials(true);
  ASSERT_TRUE(sender_creds_status.ok()) << sender_creds_status.error().Trace();
  auto receiver_creds_status = receiver.EnableCredentials(true);
  ASSERT_TRUE(receiver_creds_status.ok())
      << receiver_creds_status.error().Trace();

  ucred sent_creds{};
  sent_creds.pid = getpid();
  sent_creds.uid = getuid();
  sent_creds.gid = getgid();

  UnixSocketMessage sent{{1, 5, 9}, {}};
  sent.control.emplace_back(ControlMessage::FromCredentials(sent_creds));
  auto send_status = sender.WriteMessage(sent);
  ASSERT_TRUE(send_status.ok()) << send_status.error().Trace();

  auto received = receiver.ReadMessage();
  ASSERT_TRUE(received.ok()) << received.error().Trace();
  ASSERT_EQ(sent.data, received->data);

  ASSERT_EQ(1, received->control.size());
  auto received_creds = received->control[0].AsCredentials();
  ASSERT_TRUE(received_creds.ok()) << received_creds.error().Trace();
  ASSERT_EQ(sent_creds.pid, received_creds->pid);
  ASSERT_EQ(sent_creds.uid, received_creds->uid);
  ASSERT_EQ(sent_creds.gid, received_creds->gid);
}
156 
// Attaching credentials that do not belong to this process (pid/uid/gid each
// offset by one) must make WriteMessage fail.
TEST(UnixMessageSocket, BadCredentialsBlocked) {
  // A privileged sender may legitimately send arbitrary credentials, so the
  // rejection checked below only applies to unprivileged callers. Skip rather
  // than report a spurious failure when run as root.
  if (getuid() == 0) {
    GTEST_SKIP() << "Forged credentials are not rejected when running as root";
  }

  auto [writer, reader] = UnixMessageSocketPair();
  auto writer_creds_status = writer.EnableCredentials(true);
  ASSERT_TRUE(writer_creds_status.ok()) << writer_creds_status.error().Trace();
  auto reader_creds_status = reader.EnableCredentials(true);
  ASSERT_TRUE(reader_creds_status.ok()) << reader_creds_status.error().Trace();

  ucred credentials_in{};
  credentials_in.pid = getpid() + 1;
  credentials_in.uid = getuid() + 1;
  credentials_in.gid = getgid() + 1;

  UnixSocketMessage message_in = {{2, 4, 6}, {}};
  auto control_in = ControlMessage::FromCredentials(credentials_in);
  message_in.control.emplace_back(std::move(control_in));
  auto write_result = writer.WriteMessage(message_in);
  // Do not stream write_result.error() here. The message is only evaluated
  // when this assertion fails, which is exactly when write_result holds a
  // success value and has no error to read.
  ASSERT_FALSE(write_result.ok())
      << "WriteMessage unexpectedly accepted forged credentials";
}
176 
// With credential passing enabled on both ends, a message sent without an
// explicit credentials control message still arrives with one control message
// carrying this process's pid/uid/gid.
TEST(UnixMessageSocket, AutoCredentials) {
  auto sockets = UnixMessageSocketPair();
  auto& sender = sockets.first;
  auto& receiver = sockets.second;
  auto sender_creds_status = sender.EnableCredentials(true);
  ASSERT_TRUE(sender_creds_status.ok()) << sender_creds_status.error().Trace();
  auto receiver_creds_status = receiver.EnableCredentials(true);
  ASSERT_TRUE(receiver_creds_status.ok())
      << receiver_creds_status.error().Trace();

  UnixSocketMessage sent{{3, 6, 9}, {}};
  auto send_status = sender.WriteMessage(sent);
  ASSERT_TRUE(send_status.ok()) << send_status.error().Trace();

  auto received = receiver.ReadMessage();
  ASSERT_TRUE(received.ok()) << received.error().Trace();
  ASSERT_EQ(sent.data, received->data);

  ASSERT_EQ(1, received->control.size());
  auto received_creds = received->control[0].AsCredentials();
  ASSERT_TRUE(received_creds.ok()) << received_creds.error().Trace();
  ASSERT_EQ(getpid(), received_creds->pid);
  ASSERT_EQ(getuid(), received_creds->uid);
  ASSERT_EQ(getgid(), received_creds->gid);
}
199 
200 }  // namespace cuttlefish
201