1import errno
2import os
3import selectors
4import signal
5import socket
6import struct
7import sys
8import threading
9import warnings
10
11from . import connection
12from . import process
13from .context import reduction
14from . import semaphore_tracker
15from . import spawn
16from . import util
17
# Public API: these names are module-level aliases for the methods of the
# shared ForkServer instance created at the bottom of this module.
__all__ = ['ensure_running', 'get_inherited_fds', 'connect_to_new_process',
           'set_forkserver_preload']

#
# Constants
#

# Upper bound on the number of fds exchanged per fork request: the client
# checks it before sending, and main() checks it after receiving.
MAXFDS_TO_SEND = 256
# Native 'q' (long long) encoding used by read_signed()/write_signed() to
# send pids and return codes over pipes.
SIGNED_STRUCT = struct.Struct('q')     # large enough for pid_t
27
28#
29# Forkserver class
30#
31
class ForkServer(object):
    '''Client-side handle on a single fork server process.

    The fork server is a helper interpreter, launched on demand by
    ensure_running(), that listens on an AF_UNIX socket and forks a new
    child for every connection it accepts.  This object records how to
    reach that server (socket address and pid) and holds the write end of
    the "alive" pipe that tells the server when its clients are gone.
    '''

    def __init__(self):
        # AF_UNIX address the fork server listens on (None until launched).
        self._forkserver_address = None
        # Write end of the "alive" pipe; the server exits once it sees EOF
        # on the read end, i.e. when every copy of this fd is closed.
        self._forkserver_alive_fd = None
        # Pid of the fork server if it was launched by this process.
        self._forkserver_pid = None
        # Extra fds handed to this process by the fork server (None when
        # this process was not started by a fork server).
        self._inherited_fds = None
        # Serializes ensure_running() between threads.
        self._lock = threading.Lock()
        # Module names the fork server imports before it starts forking.
        self._preload_modules = ['__main__']

    def set_forkserver_preload(self, modules_names):
        '''Set list of module names to try to load in forkserver process.

        Only affects a fork server launched after this call.

        Raises TypeError if any element of *modules_names* is not a str;
        in that case the previously configured list is kept.
        '''
        # Validate the *new* value (not the currently stored list) so a bad
        # argument is rejected before it replaces the old preload list.
        if not all(type(mod) is str for mod in modules_names):
            raise TypeError('module_names must be a list of strings')
        self._preload_modules = modules_names

    def get_inherited_fds(self):
        '''Return list of fds inherited from parent process.

        This returns None if the current process was not started by fork
        server.
        '''
        return self._inherited_fds

    def connect_to_new_process(self, fds):
        '''Request forkserver to create a child process.

        Returns a pair of fds (status_r, data_w).  The calling process can read
        the child process's pid and (eventually) its returncode from status_r.
        The calling process should write to data_w the pickled preparation and
        process data.

        Raises ValueError if *fds* is too long to fit in a single request.
        '''
        self.ensure_running()
        # Four fds always precede the caller's: child_r, child_w, the alive
        # fd and the semaphore tracker fd.
        if len(fds) + 4 >= MAXFDS_TO_SEND:
            raise ValueError('too many fds')
        with socket.socket(socket.AF_UNIX) as client:
            client.connect(self._forkserver_address)
            parent_r, child_w = os.pipe()
            try:
                child_r, parent_w = os.pipe()
            except BaseException:
                # Don't leak the first pipe if creating the second fails.
                os.close(parent_r)
                os.close(child_w)
                raise
            try:
                allfds = [child_r, child_w, self._forkserver_alive_fd,
                          semaphore_tracker.getfd()]
                allfds += fds
                reduction.sendfds(client, allfds)
                return parent_r, parent_w
            except BaseException:
                os.close(parent_r)
                os.close(parent_w)
                raise
            finally:
                # Only the fork server side needs the child ends; close our
                # copies whether or not sending succeeded.
                os.close(child_r)
                os.close(child_w)

    def ensure_running(self):
        '''Make sure that a fork server is running.

        This can be called from any process.  Note that usually a child
        process will just reuse the forkserver started by its parent, so
        ensure_running() will do nothing.

        Also ensures the semaphore tracker is running.  If a fork server
        previously launched by this process has exited, its recorded state
        is discarded and a new server is launched.
        '''
        with self._lock:
            semaphore_tracker.ensure_running()
            if self._forkserver_pid is not None:
                # forkserver was launched before, is it still running?
                pid, status = os.waitpid(self._forkserver_pid, os.WNOHANG)
                if not pid:
                    # still alive
                    return
                # dead, launch it again
                os.close(self._forkserver_alive_fd)
                self._forkserver_address = None
                self._forkserver_alive_fd = None
                self._forkserver_pid = None

            # Source executed by the new interpreter:
            # main(listener_fd, alive_r, preload_modules, **data)
            cmd = ('from multiprocessing.forkserver import main; ' +
                   'main(%d, %d, %r, **%r)')

            if self._preload_modules:
                # Forward only what main() needs to re-import __main__.
                desired_keys = {'main_path', 'sys_path'}
                data = spawn.get_preparation_data('ignore')
                data = {x: y for x, y in data.items() if x in desired_keys}
            else:
                data = {}

            with socket.socket(socket.AF_UNIX) as listener:
                address = connection.arbitrary_address('AF_UNIX')
                listener.bind(address)
                # Restrict the socket file to the owning user.
                os.chmod(address, 0o600)
                listener.listen()

                # all client processes own the write end of the "alive" pipe;
                # when they all terminate the read end becomes ready.
                alive_r, alive_w = os.pipe()
                try:
                    fds_to_pass = [listener.fileno(), alive_r]
                    cmd %= (listener.fileno(), alive_r, self._preload_modules,
                            data)
                    exe = spawn.get_executable()
                    args = [exe] + util._args_from_interpreter_flags()
                    args += ['-c', cmd]
                    pid = util.spawnv_passfds(exe, args, fds_to_pass)
                except BaseException:
                    os.close(alive_w)
                    raise
                finally:
                    # Only the spawned server needs the read end.
                    os.close(alive_r)
                self._forkserver_address = address
                self._forkserver_alive_fd = alive_w
                self._forkserver_pid = pid
141
142#
143#
144#
145
def main(listener_fd, alive_r, preload, main_path=None, sys_path=None):
    '''Run forkserver.

    Entry point of the fork server process; ForkServer.ensure_running()
    launches it with ``python -c`` and passes these arguments.

    listener_fd -- fd of the bound, listening AF_UNIX socket that clients
                   connect to in order to request a new child.
    alive_r     -- read end of the "alive" pipe; EOF on it means no client
                   holds the write end any more, and the server exits.
    preload     -- module names to import before serving requests; if it
                   contains '__main__', *main_path* is imported as well.
    main_path   -- path of the parent's __main__ module, or None.
    sys_path    -- accepted but not used in this function body.

    Loops until alive_r reaches EOF and then raises SystemExit; OSErrors
    other than ECONNABORTED propagate.
    '''
    if preload:
        if '__main__' in preload and main_path is not None:
            # Flag the process as "inheriting" while the parent's main module
            # is imported; NOTE(review): presumably so spawn-related guards
            # behave as in a child -- confirm against process/spawn.
            process.current_process()._inheriting = True
            try:
                spawn.import_main_path(main_path)
            finally:
                del process.current_process()._inheriting
        # Preloading is best-effort: modules that fail to import are skipped.
        for modname in preload:
            try:
                __import__(modname)
            except ImportError:
                pass

    util._close_stdin()

    # Self-pipe used as the signal wakeup fd.  Both ends are non-blocking so
    # writing a wakeup byte never blocks and draining sig_r never hangs.
    sig_r, sig_w = os.pipe()
    os.set_blocking(sig_r, False)
    os.set_blocking(sig_w, False)

    def sigchld_handler(*_unused):
        # Dummy signal handler, doesn't do anything
        pass

    handlers = {
        # unblocking SIGCHLD allows the wakeup fd to notify our event loop
        signal.SIGCHLD: sigchld_handler,
        # protect the process from ^C
        signal.SIGINT: signal.SIG_IGN,
        }
    # Keep the previous handlers so forked children can restore them.
    old_handlers = {sig: signal.signal(sig, val)
                    for (sig, val) in handlers.items()}

    # calling os.write() in the Python signal handler is racy
    signal.set_wakeup_fd(sig_w)

    # map child pids to client fds
    pid_to_fd = {}

    with socket.socket(socket.AF_UNIX, fileno=listener_fd) as listener, \
         selectors.DefaultSelector() as selector:
        # Record our own listening address on the module-level ForkServer.
        _forkserver._forkserver_address = listener.getsockname()

        # Wake up on: new client connections, EOF on the alive pipe, and
        # signal notifications written to the wakeup fd.
        selector.register(listener, selectors.EVENT_READ)
        selector.register(alive_r, selectors.EVENT_READ)
        selector.register(sig_r, selectors.EVENT_READ)

        while True:
            try:
                # Block until at least one registered fd is ready.
                while True:
                    rfds = [key.fileobj for (key, events) in selector.select()]
                    if rfds:
                        break

                if alive_r in rfds:
                    # EOF because no more client processes left
                    # NOTE(review): under -O the assert (and its read) is
                    # stripped; SystemExit is still raised.
                    assert os.read(alive_r, 1) == b'', "Not at EOF?"
                    raise SystemExit

                if sig_r in rfds:
                    # Got SIGCHLD
                    os.read(sig_r, 65536)  # exhaust
                    while True:
                        # Scan for child processes; reap every exited child
                        # without blocking until none are left.
                        try:
                            pid, sts = os.waitpid(-1, os.WNOHANG)
                        except ChildProcessError:
                            break
                        if pid == 0:
                            break
                        child_w = pid_to_fd.pop(pid, None)
                        if child_w is not None:
                            # Negative return code means "killed by signal".
                            if os.WIFSIGNALED(sts):
                                returncode = -os.WTERMSIG(sts)
                            else:
                                if not os.WIFEXITED(sts):
                                    raise AssertionError(
                                        "Child {0:n} status is {1:n}".format(
                                            pid,sts))
                                returncode = os.WEXITSTATUS(sts)
                            # Send exit code to client process
                            try:
                                write_signed(child_w, returncode)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            os.close(child_w)
                        else:
                            # This shouldn't happen really
                            warnings.warn('forkserver: waitpid returned '
                                          'unexpected pid %d' % pid)

                if listener in rfds:
                    # Incoming fork request
                    with listener.accept()[0] as s:
                        # Receive fds from client; ask for one more than the
                        # limit so an oversized request can be detected.
                        fds = reduction.recvfds(s, MAXFDS_TO_SEND + 1)
                        if len(fds) > MAXFDS_TO_SEND:
                            raise RuntimeError(
                                "Too many ({0:n}) fds to send".format(
                                    len(fds)))
                        # Layout sent by ForkServer.connect_to_new_process:
                        # child_r, child_w, then alive fd, semaphore tracker
                        # fd and the caller's own fds.
                        child_r, child_w, *fds = fds
                        # Close the connection before forking so the child
                        # does not inherit it.
                        s.close()
                        pid = os.fork()
                        if pid == 0:
                            # Child: drop the server's fds and run the
                            # requested process.  os._exit() in the finally
                            # clause ensures it never returns into this loop;
                            # exit code 1 is used if _serve_one() raises.
                            code = 1
                            try:
                                listener.close()
                                selector.close()
                                unused_fds = [alive_r, child_w, sig_r, sig_w]
                                unused_fds.extend(pid_to_fd.values())
                                code = _serve_one(child_r, fds,
                                                  unused_fds,
                                                  old_handlers)
                            except Exception:
                                sys.excepthook(*sys.exc_info())
                                sys.stderr.flush()
                            finally:
                                os._exit(code)
                        else:
                            # Send pid to client process
                            try:
                                write_signed(child_w, pid)
                            except BrokenPipeError:
                                # client vanished
                                pass
                            # Keep child_w to report the exit code later; the
                            # remaining fds are only needed by the child.
                            pid_to_fd[pid] = child_w
                            os.close(child_r)
                            for fd in fds:
                                os.close(fd)

            except OSError as e:
                # An aborted client connection is tolerated; any other OSError
                # terminates the server.
                if e.errno != errno.ECONNABORTED:
                    raise
282
283
def _serve_one(child_r, fds, unused_fds, handlers):
    '''Body of a child freshly forked by the fork server.

    Undoes the server's signal setup, closes server-only fds, installs the
    fds received from the client, then runs the process object whose data
    is read from *child_r*.

    child_r    -- fd from which spawn._main() reads the process data.
    fds        -- [alive_fd, semaphore_tracker_fd, *inherited_fds].
    unused_fds -- server-owned fds that must not stay open in the child.
    handlers   -- mapping of signal number -> handler to reinstall.

    Returns the value of spawn._main(), which the caller uses as the exit
    code.
    '''
    # Stop routing signals to the server's wakeup pipe and put back the
    # handlers that were active before the server replaced them.
    signal.set_wakeup_fd(-1)
    for signum, handler in handlers.items():
        signal.signal(signum, handler)

    for unused in unused_fds:
        os.close(unused)

    alive_fd, tracker_fd, *inherited = fds
    _forkserver._forkserver_alive_fd = alive_fd
    semaphore_tracker._semaphore_tracker._fd = tracker_fd
    _forkserver._inherited_fds = inherited

    # Run process object received over pipe
    return spawn._main(child_r)
300
301
302#
303# Read and write signed numbers
304#
305
def read_signed(fd):
    '''Read one SIGNED_STRUCT-encoded integer from *fd*.

    Keeps reading until the whole struct has arrived, because os.read()
    may return fewer bytes than requested.

    Raises EOFError if *fd* reaches end-of-file before a full value is read.
    '''
    size = SIGNED_STRUCT.size
    chunks = []
    received = 0
    while received < size:
        chunk = os.read(fd, size - received)
        if not chunk:
            raise EOFError('unexpected EOF')
        chunks.append(chunk)
        received += len(chunk)
    return SIGNED_STRUCT.unpack(b''.join(chunks))[0]
315
def write_signed(fd, n):
    '''Write integer *n* to *fd* encoded with SIGNED_STRUCT.

    Loops until every byte has been written, because os.write() may do a
    partial write.  Raises RuntimeError if os.write() reports 0 bytes.
    '''
    payload = SIGNED_STRUCT.pack(n)
    offset = 0
    while offset < len(payload):
        written = os.write(fd, payload[offset:])
        if written == 0:
            raise RuntimeError('should not get here')
        offset += written
323
324#
325#
326#
327
# Module-level singleton.  The public functions listed in __all__ are bound
# methods of this one ForkServer instance, so all callers in a process share
# the same fork server state; main() and _serve_one() also set attributes on
# it directly.
_forkserver = ForkServer()
ensure_running = _forkserver.ensure_running
get_inherited_fds = _forkserver.get_inherited_fds
connect_to_new_process = _forkserver.connect_to_new_process
set_forkserver_preload = _forkserver.set_forkserver_preload
333