1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef INCLUDE_PERFETTO_BASE_UNIX_SOCKET_H_
18 #define INCLUDE_PERFETTO_BASE_UNIX_SOCKET_H_
19 
20 #include <stdint.h>
21 #include <sys/types.h>
22 
23 #include <memory>
24 #include <string>
25 
26 #include "perfetto/base/logging.h"
27 #include "perfetto/base/scoped_file.h"
28 #include "perfetto/base/utils.h"
29 #include "perfetto/base/weak_ptr.h"
30 
31 struct msghdr;
32 
33 namespace perfetto {
34 namespace base {
35 
36 class TaskRunner;
37 
// Socket flavours understood by the wrappers in this header.
// The enumerators deliberately start at an arbitrarily high value so that no
// code can accidentally rely on them matching the sysroot's SOCK_xxx defines:
// the translation has to go through GetUnixSockType() instead.
enum class SockType {
  kStream = 100,
  kDgram = 101,
  kSeqPacket = 102,
};
42 
43 // UnixSocketRaw is a basic wrapper around UNIX sockets. It exposes wrapper
44 // methods that take care of most common pitfalls (e.g., marking fd as
45 // O_CLOEXEC, avoiding SIGPIPE, properly handling partial writes). It is used as
46 // a building block for the more sophisticated UnixSocket class.
47 class UnixSocketRaw {
48  public:
49   // Creates a new unconnected unix socket.
CreateMayFail(SockType t)50   static UnixSocketRaw CreateMayFail(SockType t) { return UnixSocketRaw(t); }
51 
52   // Crates a pair of connected sockets.
53   static std::pair<UnixSocketRaw, UnixSocketRaw> CreatePair(SockType);
54 
55   // Creates an uninitialized unix socket.
56   UnixSocketRaw();
57 
58   // Creates a unix socket adopting an existing file descriptor. This is
59   // typically used to inherit fds from init via environment variables.
60   UnixSocketRaw(ScopedFile, SockType);
61 
62   ~UnixSocketRaw() = default;
63   UnixSocketRaw(UnixSocketRaw&&) noexcept = default;
64   UnixSocketRaw& operator=(UnixSocketRaw&&) = default;
65 
66   bool Bind(const std::string& socket_name);
67   bool Listen();
68   bool Connect(const std::string& socket_name);
69   bool SetTxTimeout(uint32_t timeout_ms);
70   bool SetRxTimeout(uint32_t timeout_ms);
71   void Shutdown();
72   void SetBlocking(bool);
73   bool IsBlocking() const;
74   void RetainOnExec();
type()75   SockType type() const { return type_; }
fd()76   int fd() const { return *fd_; }
77   explicit operator bool() const { return !!fd_; }
78 
ReleaseFd()79   ScopedFile ReleaseFd() { return std::move(fd_); }
80 
81   ssize_t Send(const void* msg,
82                size_t len,
83                const int* send_fds = nullptr,
84                size_t num_fds = 0);
85 
86   // Re-enter sendmsg until all the data has been sent or an error occurs.
87   // TODO(fmayer): Figure out how to do timeouts here for heapprofd.
88   ssize_t SendMsgAll(struct msghdr* msg);
89 
90   ssize_t Receive(void* msg,
91                   size_t len,
92                   ScopedFile* fd_vec = nullptr,
93                   size_t max_files = 0);
94 
95   // Exposed for testing only.
96   // Update msghdr so subsequent sendmsg will send data that remains after n
97   // bytes have already been sent.
98   static void ShiftMsgHdr(size_t n, struct msghdr* msg);
99 
100  private:
101   explicit UnixSocketRaw(SockType);
102 
103   UnixSocketRaw(const UnixSocketRaw&) = delete;
104   UnixSocketRaw& operator=(const UnixSocketRaw&) = delete;
105 
106   ScopedFile fd_;
107   SockType type_{SockType::kStream};
108 };
109 
110 // A non-blocking UNIX domain socket. Allows also to transfer file descriptors.
111 // None of the methods in this class are blocking.
112 // The main design goal is making strong guarantees on the EventListener
113 // callbacks, in order to avoid ending in some undefined state.
114 // In case of any error it will aggressively just shut down the socket and
115 // notify the failure with OnConnect(false) or OnDisconnect() depending on the
116 // state of the socket (see below).
117 // EventListener callbacks stop happening as soon as the instance is destroyed.
118 //
119 // Lifecycle of a client socket:
120 //
121 //                           Connect()
122 //                               |
123 //            +------------------+------------------+
124 //            | (success)                           | (failure or Shutdown())
125 //            V                                     V
126 //     OnConnect(true)                         OnConnect(false)
127 //            |
128 //            V
129 //    OnDataAvailable()
130 //            |
131 //            V
132 //     OnDisconnect()  (failure or shutdown)
133 //
134 //
135 // Lifecycle of a server socket:
136 //
137 //                          Listen()  --> returns false in case of errors.
138 //                             |
139 //                             V
140 //              OnNewIncomingConnection(new_socket)
141 //
142 //          (|new_socket| inherits the same EventListener)
143 //                             |
144 //                             V
145 //                     OnDataAvailable()
146 //                             | (failure or Shutdown())
147 //                             V
148 //                       OnDisconnect()
149 class UnixSocket {
150  public:
151   class EventListener {
152    public:
153     virtual ~EventListener();
154 
155     // After Listen().
156     virtual void OnNewIncomingConnection(
157         UnixSocket* self,
158         std::unique_ptr<UnixSocket> new_connection);
159 
160     // After Connect(), whether successful or not.
161     virtual void OnConnect(UnixSocket* self, bool connected);
162 
163     // After a successful Connect() or OnNewIncomingConnection(). Either the
164     // other endpoint did disconnect or some other error happened.
165     virtual void OnDisconnect(UnixSocket* self);
166 
167     // Whenever there is data available to Receive(). Note that spurious FD
168     // watch events are possible, so it is possible that Receive() soon after
169     // OnDataAvailable() returns 0 (just ignore those).
170     virtual void OnDataAvailable(UnixSocket* self);
171   };
172 
173   enum class State {
174     kDisconnected = 0,  // Failed connection, peer disconnection or Shutdown().
175     kConnecting,  // Soon after Connect(), before it either succeeds or fails.
176     kConnected,   // After a successful Connect().
177     kListening    // After Listen(), until Shutdown().
178   };
179 
180   enum class BlockingMode { kNonBlocking, kBlocking };
181 
182   // Creates a Unix domain socket and starts listening. If |socket_name|
183   // starts with a '@', an abstract socket will be created (Linux/Android only).
184   // Returns always an instance. In case of failure (e.g., another socket
185   // with the same name is  already listening) the returned socket will have
186   // is_listening() == false and last_error() will contain the failure reason.
187   static std::unique_ptr<UnixSocket> Listen(const std::string& socket_name,
188                                             EventListener*,
189                                             TaskRunner*,
190                                             SockType = SockType::kStream);
191 
192   // Attaches to a pre-existing socket. The socket must have been created in
193   // SOCK_STREAM mode and the caller must have called bind() on it.
194   static std::unique_ptr<UnixSocket> Listen(ScopedFile,
195                                             EventListener*,
196                                             TaskRunner*,
197                                             SockType = SockType::kStream);
198 
199   // Creates a Unix domain socket and connects to the listening endpoint.
200   // Returns always an instance. EventListener::OnConnect(bool success) will
201   // be called always, whether the connection succeeded or not.
202   static std::unique_ptr<UnixSocket> Connect(const std::string& socket_name,
203                                              EventListener*,
204                                              TaskRunner*,
205                                              SockType = SockType::kStream);
206 
207   // Constructs a UnixSocket using the given connected socket.
208   static std::unique_ptr<UnixSocket> AdoptConnected(
209       ScopedFile fd,
210       EventListener* event_listener,
211       TaskRunner* task_runner,
212       SockType sock_type);
213 
214   UnixSocket(const UnixSocket&) = delete;
215   UnixSocket& operator=(const UnixSocket&) = delete;
216   // Cannot be easily moved because of tasks from the FileDescriptorWatch.
217   UnixSocket(UnixSocket&&) = delete;
218   UnixSocket& operator=(UnixSocket&&) = delete;
219 
220   // This class gives the hard guarantee that no callback is called on the
221   // passed EventListener immediately after the object has been destroyed.
222   // Any queued callback will be silently dropped.
223   ~UnixSocket();
224 
225   // Shuts down the current connection, if any. If the socket was Listen()-ing,
226   // stops listening. The socket goes back to kNotInitialized state, so it can
227   // be reused with Listen() or Connect().
228   void Shutdown(bool notify);
229 
230   // Returns true is the message was queued, false if there was no space in the
231   // output buffer, in which case the client should retry or give up.
232   // If any other error happens the socket will be shutdown and
233   // EventListener::OnDisconnect() will be called.
234   // If the socket is not connected, Send() will just return false.
235   // Does not append a null string terminator to msg in any case.
236   //
237   // DO NOT PASS kNonBlocking, it is broken.
238   bool Send(const void* msg,
239             size_t len,
240             const int* send_fds,
241             size_t num_fds,
242             BlockingMode blocking = BlockingMode::kNonBlocking);
243 
244   inline bool Send(const void* msg,
245                    size_t len,
246                    int send_fd = -1,
247                    BlockingMode blocking = BlockingMode::kNonBlocking) {
248     if (send_fd != -1)
249       return Send(msg, len, &send_fd, 1, blocking);
250     return Send(msg, len, nullptr, 0, blocking);
251   }
252 
253   inline bool Send(const std::string& msg,
254                    BlockingMode blocking = BlockingMode::kNonBlocking) {
255     return Send(msg.c_str(), msg.size() + 1, -1, blocking);
256   }
257 
258   // Returns the number of bytes (<= |len|) written in |msg| or 0 if there
259   // is no data in the buffer to read or an error occurs (in which case a
260   // EventListener::OnDisconnect() will follow).
261   // If the ScopedFile pointer is not null and a FD is received, it moves the
262   // received FD into that. If a FD is received but the ScopedFile pointer is
263   // null, the FD will be automatically closed.
264   size_t Receive(void* msg, size_t len, ScopedFile*, size_t max_files = 1);
265 
Receive(void * msg,size_t len)266   inline size_t Receive(void* msg, size_t len) {
267     return Receive(msg, len, nullptr, 0);
268   }
269 
270   // Only for tests. This is slower than Receive() as it requires a heap
271   // allocation and a copy for the std::string. Guarantees that the returned
272   // string is null terminated even if the underlying message sent by the peer
273   // is not.
274   std::string ReceiveString(size_t max_length = 1024);
275 
is_connected()276   bool is_connected() const { return state_ == State::kConnected; }
is_listening()277   bool is_listening() const { return state_ == State::kListening; }
fd()278   int fd() const { return sock_raw_.fd(); }
last_error()279   int last_error() const { return last_error_; }
280 
281   // User ID of the peer, as returned by the kernel. If the client disconnects
282   // and the socket goes into the kDisconnected state, it retains the uid of
283   // the last peer.
peer_uid()284   uid_t peer_uid() const {
285     PERFETTO_DCHECK(!is_listening() && peer_uid_ != kInvalidUid);
286     return peer_uid_;
287   }
288 
289 #if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
290     PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
291   // Process ID of the peer, as returned by the kernel. If the client
292   // disconnects and the socket goes into the kDisconnected state, it
293   // retains the pid of the last peer.
294   //
295   // This is only available on Linux / Android.
peer_pid()296   pid_t peer_pid() const {
297     PERFETTO_DCHECK(!is_listening() && peer_pid_ != kInvalidPid);
298     return peer_pid_;
299   }
300 #endif
301 
302   // This makes the UnixSocket unusable.
303   UnixSocketRaw ReleaseSocket();
304 
305  private:
306   UnixSocket(EventListener*, TaskRunner*, SockType);
307   UnixSocket(EventListener*, TaskRunner*, ScopedFile, State, SockType);
308 
309   // Called once by the corresponding public static factory methods.
310   void DoConnect(const std::string& socket_name);
311   void ReadPeerCredentials();
312 
313   void OnEvent();
314   void NotifyConnectionState(bool success);
315 
316   UnixSocketRaw sock_raw_;
317   State state_ = State::kDisconnected;
318   int last_error_ = 0;
319   uid_t peer_uid_ = kInvalidUid;
320 #if PERFETTO_BUILDFLAG(PERFETTO_OS_LINUX) || \
321     PERFETTO_BUILDFLAG(PERFETTO_OS_ANDROID)
322   pid_t peer_pid_ = kInvalidPid;
323 #endif
324   EventListener* const event_listener_;
325   TaskRunner* const task_runner_;
326   WeakPtrFactory<UnixSocket> weak_ptr_factory_;  // Keep last.
327 };
328 
329 }  // namespace base
330 }  // namespace perfetto
331 
332 #endif  // INCLUDE_PERFETTO_BASE_UNIX_SOCKET_H_
333