1 /*
2  * Copyright (C) 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 /**
18  * @file
19  *   This file includes the implementation for the Socket interface to radio
20  * (RCP).
21  */
22 
23 #include "socket_interface.hpp"
24 
25 #include <errno.h>
26 #include <linux/limits.h>
27 #include <openthread/logging.h>
28 #include <sys/inotify.h>
29 #include <sys/socket.h>
30 #include <sys/stat.h>
31 #include <sys/un.h>
32 #include <sys/wait.h>
33 #include <unistd.h>
34 
35 #include <string>
36 
37 #include "common/code_utils.hpp"
38 #include "openthread/openthread-system.h"
39 #include "platform-posix.h"
40 
41 namespace aidl {
42 namespace android {
43 namespace hardware {
44 namespace threadnetwork {
45 
// Log module name for this class; presumably consumed by the LogCrit/LogWarn/LogInfo
// helpers used throughout this file (confirm in the header).
// NOTE(review): the abbreviated spelling "SocketIntface" may be deliberate (e.g. a length
// limit on module names) -- confirm before "correcting" it.
46 const char SocketInterface::kLogModuleName[] = "SocketIntface";
47 
SocketInterface(const ot::Url::Url & aRadioUrl)48 SocketInterface::SocketInterface(const ot::Url::Url& aRadioUrl)
49     : mReceiveFrameCallback(nullptr),
50       mReceiveFrameContext(nullptr),
51       mReceiveFrameBuffer(nullptr),
52       mSockFd(-1),
53       mRadioUrl(aRadioUrl) {
54     memset(&mInterfaceMetrics, 0, sizeof(mInterfaceMetrics));
55     mInterfaceMetrics.mRcpInterfaceType = kSpinelInterfaceTypeVendor;
56 }
57 
Init(ReceiveFrameCallback aCallback,void * aCallbackContext,RxFrameBuffer & aFrameBuffer)58 otError SocketInterface::Init(ReceiveFrameCallback aCallback, void* aCallbackContext,
59                               RxFrameBuffer& aFrameBuffer) {
60     otError error = OT_ERROR_NONE;
61 
62     VerifyOrExit(mSockFd == -1, error = OT_ERROR_ALREADY);
63 
64     WaitForSocketFileCreated(mRadioUrl.GetPath());
65 
66     mSockFd = OpenFile(mRadioUrl);
67     VerifyOrExit(mSockFd != -1, error = OT_ERROR_FAILED);
68 
69     mReceiveFrameCallback = aCallback;
70     mReceiveFrameContext = aCallbackContext;
71     mReceiveFrameBuffer = &aFrameBuffer;
72 
73 exit:
74     return error;
75 }
76 
~SocketInterface(void)77 SocketInterface::~SocketInterface(void) {
78     Deinit();
79 }
80 
Deinit(void)81 void SocketInterface::Deinit(void) {
82     CloseFile();
83 
84     mReceiveFrameCallback = nullptr;
85     mReceiveFrameContext = nullptr;
86     mReceiveFrameBuffer = nullptr;
87 }
88 
SendFrame(const uint8_t * aFrame,uint16_t aLength)89 otError SocketInterface::SendFrame(const uint8_t* aFrame, uint16_t aLength) {
90     Write(aFrame, aLength);
91 
92     return OT_ERROR_NONE;
93 }
94 
WaitForFrame(uint64_t aTimeoutUs)95 otError SocketInterface::WaitForFrame(uint64_t aTimeoutUs) {
96     otError error = OT_ERROR_NONE;
97     struct timeval timeout;
98     timeout.tv_sec = static_cast<time_t>(aTimeoutUs / US_PER_S);
99     timeout.tv_usec = static_cast<suseconds_t>(aTimeoutUs % US_PER_S);
100 
101     fd_set readFds;
102     fd_set errorFds;
103     int rval;
104 
105     FD_ZERO(&readFds);
106     FD_ZERO(&errorFds);
107     FD_SET(mSockFd, &readFds);
108     FD_SET(mSockFd, &errorFds);
109 
110     rval = TEMP_FAILURE_RETRY(select(mSockFd + 1, &readFds, nullptr, &errorFds, &timeout));
111 
112     if (rval > 0) {
113         if (FD_ISSET(mSockFd, &readFds)) {
114             Read();
115         } else if (FD_ISSET(mSockFd, &errorFds)) {
116             DieNowWithMessage("RCP error", OT_EXIT_FAILURE);
117         } else {
118             DieNow(OT_EXIT_FAILURE);
119         }
120     } else if (rval == 0) {
121         ExitNow(error = OT_ERROR_RESPONSE_TIMEOUT);
122     } else {
123         DieNowWithMessage("wait response", OT_EXIT_FAILURE);
124     }
125 
126 exit:
127     return error;
128 }
129 
UpdateFdSet(void * aMainloopContext)130 void SocketInterface::UpdateFdSet(void* aMainloopContext) {
131     otSysMainloopContext* context = reinterpret_cast<otSysMainloopContext*>(aMainloopContext);
132 
133     assert(context != nullptr);
134 
135     FD_SET(mSockFd, &context->mReadFdSet);
136 
137     if (context->mMaxFd < mSockFd) {
138         context->mMaxFd = mSockFd;
139     }
140 }
141 
Process(const void * aMainloopContext)142 void SocketInterface::Process(const void* aMainloopContext) {
143     const otSysMainloopContext* context =
144             reinterpret_cast<const otSysMainloopContext*>(aMainloopContext);
145 
146     assert(context != nullptr);
147 
148     if (FD_ISSET(mSockFd, &context->mReadFdSet)) {
149         Read();
150     }
151 }
152 
Read(void)153 void SocketInterface::Read(void) {
154     uint8_t buffer[kMaxFrameSize];
155 
156     ssize_t rval = TEMP_FAILURE_RETRY(read(mSockFd, buffer, sizeof(buffer)));
157 
158     if (rval > 0) {
159         ProcessReceivedData(buffer, static_cast<uint16_t>(rval));
160     } else if (rval < 0) {
161         DieNow(OT_EXIT_ERROR_ERRNO);
162     } else {
163         LogCrit("Socket connection is closed by remote.");
164         exit(OT_EXIT_FAILURE);
165     }
166 }
167 
Write(const uint8_t * aFrame,uint16_t aLength)168 void SocketInterface::Write(const uint8_t* aFrame, uint16_t aLength) {
169     ssize_t rval = TEMP_FAILURE_RETRY(write(mSockFd, aFrame, aLength));
170     VerifyOrDie(rval >= 0, OT_EXIT_ERROR_ERRNO);
171     VerifyOrDie(rval > 0, OT_EXIT_FAILURE);
172 }
173 
ProcessReceivedData(const uint8_t * aBuffer,uint16_t aLength)174 void SocketInterface::ProcessReceivedData(const uint8_t* aBuffer, uint16_t aLength) {
175     while (aLength--) {
176         uint8_t byte = *aBuffer++;
177         if (mReceiveFrameBuffer->CanWrite(sizeof(uint8_t))) {
178             IgnoreError(mReceiveFrameBuffer->WriteByte(byte));
179         } else {
180             HandleSocketFrame(this, OT_ERROR_NO_BUFS);
181             return;
182         }
183     }
184     HandleSocketFrame(this, OT_ERROR_NONE);
185 }
186 
HandleSocketFrame(void * aContext,otError aError)187 void SocketInterface::HandleSocketFrame(void* aContext, otError aError) {
188     static_cast<SocketInterface*>(aContext)->HandleSocketFrame(aError);
189 }
190 
HandleSocketFrame(otError aError)191 void SocketInterface::HandleSocketFrame(otError aError) {
192     VerifyOrExit((mReceiveFrameCallback != nullptr) && (mReceiveFrameBuffer != nullptr));
193 
194     if (aError == OT_ERROR_NONE) {
195         mReceiveFrameCallback(mReceiveFrameContext);
196     } else {
197         mReceiveFrameBuffer->DiscardFrame();
198         LogWarn("Process socket frame failed: %s", otThreadErrorToString(aError));
199     }
200 
201 exit:
202     return;
203 }
204 
OpenFile(const ot::Url::Url & aRadioUrl)205 int SocketInterface::OpenFile(const ot::Url::Url& aRadioUrl) {
206     int fd = -1;
207     sockaddr_un serverAddress;
208 
209     VerifyOrExit(sizeof(serverAddress.sun_path) > strlen(aRadioUrl.GetPath()),
210                  LogCrit("Invalid file path length"));
211     strncpy(serverAddress.sun_path, aRadioUrl.GetPath(), sizeof(serverAddress.sun_path));
212     serverAddress.sun_family = AF_UNIX;
213 
214     fd = socket(AF_UNIX, SOCK_SEQPACKET, 0);
215     VerifyOrExit(fd != -1, LogCrit("open(): errno=%s", strerror(errno)));
216 
217     if (connect(fd, reinterpret_cast<struct sockaddr*>(&serverAddress), sizeof(serverAddress)) ==
218         -1) {
219         LogCrit("connect(): errno=%s", strerror(errno));
220         close(fd);
221         fd = -1;
222     }
223 
224 exit:
225     return fd;
226 }
227 
CloseFile(void)228 void SocketInterface::CloseFile(void) {
229     VerifyOrExit(mSockFd != -1);
230 
231     VerifyOrExit(0 == close(mSockFd), LogCrit("close(): errno=%s", strerror(errno)));
232     VerifyOrExit(wait(nullptr) != -1 || errno == ECHILD,
233                  LogCrit("wait(): errno=%s", strerror(errno)));
234 
235     mSockFd = -1;
236 
237 exit:
238     return;
239 }
240 
WaitForSocketFileCreated(const char * aPath)241 void SocketInterface::WaitForSocketFileCreated(const char* aPath) {
242     int inotifyFd;
243     int wd;
244     int lastSlashIdx;
245     std::string folderPath;
246     std::string socketPath(aPath);
247 
248     VerifyOrExit(!IsSocketFileExisted(aPath));
249 
250     inotifyFd = inotify_init();
251     VerifyOrDie(inotifyFd != -1, OT_EXIT_ERROR_ERRNO);
252 
253     lastSlashIdx = socketPath.find_last_of('/');
254     VerifyOrDie(lastSlashIdx != std::string::npos, OT_EXIT_ERROR_ERRNO);
255 
256     folderPath = socketPath.substr(0, lastSlashIdx);
257     wd = inotify_add_watch(inotifyFd, folderPath.c_str(), IN_CREATE);
258     VerifyOrDie(wd != -1, OT_EXIT_ERROR_ERRNO);
259 
260     LogInfo("Waiting for socket file %s be created...", aPath);
261 
262     while (true) {
263         fd_set fds;
264         FD_ZERO(&fds);
265         FD_SET(inotifyFd, &fds);
266         struct timeval timeout = {kMaxSelectTimeMs / MS_PER_S,
267                                   (kMaxSelectTimeMs % MS_PER_S) * MS_PER_S};
268 
269         int rval = select(inotifyFd + 1, &fds, nullptr, nullptr, &timeout);
270         VerifyOrDie(rval >= 0, OT_EXIT_ERROR_ERRNO);
271 
272         if (rval == 0 && IsSocketFileExisted(aPath)) {
273             break;
274         }
275 
276         if (FD_ISSET(inotifyFd, &fds)) {
277             char buffer[sizeof(struct inotify_event) + NAME_MAX + 1];
278             ssize_t bytesRead = read(inotifyFd, buffer, sizeof(buffer));
279 
280             VerifyOrDie(bytesRead >= 0, OT_EXIT_ERROR_ERRNO);
281 
282             struct inotify_event* event = reinterpret_cast<struct inotify_event*>(buffer);
283             if ((event->mask & IN_CREATE) && IsSocketFileExisted(aPath)) {
284                 break;
285             }
286         }
287     }
288 
289     close(inotifyFd);
290 
291 exit:
292     LogInfo("Socket file: %s is created", aPath);
293     return;
294 }
295 
IsSocketFileExisted(const char * aPath)296 bool SocketInterface::IsSocketFileExisted(const char* aPath) {
297     struct stat st;
298     return stat(aPath, &st) == 0 && S_ISSOCK(st.st_mode);
299 }
300 
301 }  // namespace threadnetwork
302 }  // namespace hardware
303 }  // namespace android
304 }  // namespace aidl
305