1 /*
2  * A type which wraps a socket
3  *
4  * socket_connection.c
5  *
6  * Copyright (c) 2006-2008, R Oudkerk --- see COPYING.txt
7  */
8 
9 #include "multiprocessing.h"
10 
11 #if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
12 #  include "poll.h"
13 #endif
14 
15 #ifdef MS_WINDOWS
16 #  define WRITE(h, buffer, length) send((SOCKET)h, buffer, length, 0)
17 #  define READ(h, buffer, length) recv((SOCKET)h, buffer, length, 0)
18 #  define CLOSE(h) closesocket((SOCKET)h)
19 #else
20 #  define WRITE(h, buffer, length) write(h, buffer, length)
21 #  define READ(h, buffer, length) read(h, buffer, length)
22 #  define CLOSE(h) close(h)
23 #endif
24 
25 /*
26  * Wrapper for PyErr_CheckSignals() which can be called without the GIL
27  */
28 
29 static int
check_signals(void)30 check_signals(void)
31 {
32     PyGILState_STATE state;
33     int res;
34     state = PyGILState_Ensure();
35     res = PyErr_CheckSignals();
36     PyGILState_Release(state);
37     return res;
38 }
39 
40 /*
41  * Send string to file descriptor
42  */
43 
44 static Py_ssize_t
_conn_sendall(HANDLE h,char * string,size_t length)45 _conn_sendall(HANDLE h, char *string, size_t length)
46 {
47     char *p = string;
48     Py_ssize_t res;
49 
50     while (length > 0) {
51         res = WRITE(h, p, length);
52         if (res < 0) {
53             if (errno == EINTR) {
54                 if (check_signals() < 0)
55                     return MP_EXCEPTION_HAS_BEEN_SET;
56                 continue;
57             }
58             return MP_SOCKET_ERROR;
59         }
60         length -= res;
61         p += res;
62     }
63 
64     return MP_SUCCESS;
65 }
66 
67 /*
68  * Receive string of exact length from file descriptor
69  */
70 
71 static Py_ssize_t
_conn_recvall(HANDLE h,char * buffer,size_t length)72 _conn_recvall(HANDLE h, char *buffer, size_t length)
73 {
74     size_t remaining = length;
75     Py_ssize_t temp;
76     char *p = buffer;
77 
78     while (remaining > 0) {
79         temp = READ(h, p, remaining);
80         if (temp < 0) {
81             if (errno == EINTR) {
82                 if (check_signals() < 0)
83                     return MP_EXCEPTION_HAS_BEEN_SET;
84                 continue;
85             }
86             return temp;
87         }
88         else if (temp == 0) {
89             return remaining == length ? MP_END_OF_FILE : MP_EARLY_END_OF_FILE;
90         }
91         remaining -= temp;
92         p += temp;
93     }
94 
95     return MP_SUCCESS;
96 }
97 
98 /*
99  * Send a string prepended by the string length in network byte order
100  */
101 
102 static Py_ssize_t
conn_send_string(ConnectionObject * conn,char * string,size_t length)103 conn_send_string(ConnectionObject *conn, char *string, size_t length)
104 {
105     Py_ssize_t res;
106     /* The "header" of the message is a 32 bit unsigned number (in
107        network order) which specifies the length of the "body".  If
108        the message is shorter than about 16kb then it is quicker to
109        combine the "header" and the "body" of the message and send
110        them at once. */
111     if (length < (16*1024)) {
112         char *message;
113 
114         message = PyMem_Malloc(length+4);
115         if (message == NULL)
116             return MP_MEMORY_ERROR;
117 
118         *(UINT32*)message = htonl((UINT32)length);
119         memcpy(message+4, string, length);
120         Py_BEGIN_ALLOW_THREADS
121         res = _conn_sendall(conn->handle, message, length+4);
122         Py_END_ALLOW_THREADS
123         PyMem_Free(message);
124     } else {
125         UINT32 lenbuff;
126 
127         if (length > MAX_MESSAGE_LENGTH)
128             return MP_BAD_MESSAGE_LENGTH;
129 
130         lenbuff = htonl((UINT32)length);
131         Py_BEGIN_ALLOW_THREADS
132         res = _conn_sendall(conn->handle, (char*)&lenbuff, 4) ||
133             _conn_sendall(conn->handle, string, length);
134         Py_END_ALLOW_THREADS
135     }
136     return res;
137 }
138 
139 /*
140  * Attempts to read into buffer, or failing that into *newbuffer
141  *
142  * Returns number of bytes read.
143  */
144 
145 static Py_ssize_t
conn_recv_string(ConnectionObject * conn,char * buffer,size_t buflength,char ** newbuffer,size_t maxlength)146 conn_recv_string(ConnectionObject *conn, char *buffer,
147                  size_t buflength, char **newbuffer, size_t maxlength)
148 {
149     Py_ssize_t res;
150     UINT32 ulength;
151 
152     *newbuffer = NULL;
153 
154     Py_BEGIN_ALLOW_THREADS
155     res = _conn_recvall(conn->handle, (char*)&ulength, 4);
156     Py_END_ALLOW_THREADS
157     if (res < 0)
158         return res;
159 
160     ulength = ntohl(ulength);
161     if (ulength > maxlength)
162         return MP_BAD_MESSAGE_LENGTH;
163 
164     if (ulength > buflength) {
165         *newbuffer = buffer = PyMem_Malloc((size_t)ulength);
166         if (buffer == NULL)
167             return MP_MEMORY_ERROR;
168     }
169 
170     Py_BEGIN_ALLOW_THREADS
171     res = _conn_recvall(conn->handle, buffer, (size_t)ulength);
172     Py_END_ALLOW_THREADS
173 
174     if (res >= 0) {
175         res = (Py_ssize_t)ulength;
176     } else if (*newbuffer != NULL) {
177         PyMem_Free(*newbuffer);
178         *newbuffer = NULL;
179     }
180     return res;
181 }
182 
183 /*
184  * Check whether any data is available for reading -- neg timeout blocks
185  */
186 
187 static int
conn_poll(ConnectionObject * conn,double timeout,PyThreadState * _save)188 conn_poll(ConnectionObject *conn, double timeout, PyThreadState *_save)
189 {
190 #if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
191     int res;
192     struct pollfd p;
193 
194     p.fd = (int)conn->handle;
195     p.events = POLLIN | POLLPRI;
196     p.revents = 0;
197 
198     if (timeout < 0) {
199         do {
200             res = poll(&p, 1, -1);
201         } while (res < 0 && errno == EINTR);
202     } else {
203         res = poll(&p, 1, (int)(timeout * 1000 + 0.5));
204         if (res < 0 && errno == EINTR) {
205             /* We were interrupted by a signal.  Just indicate a
206                timeout even though we are early. */
207             return FALSE;
208         }
209     }
210 
211     if (res < 0) {
212         return MP_SOCKET_ERROR;
213     } else if (p.revents & (POLLNVAL|POLLERR)) {
214         Py_BLOCK_THREADS
215         PyErr_SetString(PyExc_IOError, "poll() gave POLLNVAL or POLLERR");
216         Py_UNBLOCK_THREADS
217         return MP_EXCEPTION_HAS_BEEN_SET;
218     } else if (p.revents != 0) {
219         return TRUE;
220     } else {
221         assert(res == 0);
222         return FALSE;
223     }
224 #else
225     int res;
226     fd_set rfds;
227 
228     /*
229      * Verify the handle, issue 3321. Not required for windows.
230      */
231     #ifndef MS_WINDOWS
232         if (((int)conn->handle) < 0 || ((int)conn->handle) >= FD_SETSIZE) {
233             Py_BLOCK_THREADS
234             PyErr_SetString(PyExc_IOError, "handle out of range in select()");
235             Py_UNBLOCK_THREADS
236             return MP_EXCEPTION_HAS_BEEN_SET;
237         }
238     #endif
239 
240     FD_ZERO(&rfds);
241     FD_SET((SOCKET)conn->handle, &rfds);
242 
243     if (timeout < 0.0) {
244         do {
245             res = select((int)conn->handle+1, &rfds, NULL, NULL, NULL);
246         } while (res < 0 && errno == EINTR);
247     } else {
248         struct timeval tv;
249         tv.tv_sec = (long)timeout;
250         tv.tv_usec = (long)((timeout - tv.tv_sec) * 1e6 + 0.5);
251         res = select((int)conn->handle+1, &rfds, NULL, NULL, &tv);
252         if (res < 0 && errno == EINTR) {
253             /* We were interrupted by a signal.  Just indicate a
254                timeout even though we are early. */
255             return FALSE;
256         }
257     }
258 
259     if (res < 0) {
260         return MP_SOCKET_ERROR;
261     } else if (FD_ISSET(conn->handle, &rfds)) {
262         return TRUE;
263     } else {
264         assert(res == 0);
265         return FALSE;
266     }
267 #endif
268 }
269 
/*
 * "connection.h" defines the Connection type in terms of the conn_*
 * helpers above and the CONNECTION_NAME / CONNECTION_TYPE macros below.
 */
273 
274 #define CONNECTION_NAME "Connection"
275 #define CONNECTION_TYPE ConnectionType
276 
277 #include "connection.h"
278