1# Copyright 2014 Google Inc. All rights reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#    http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import copy
16import multiprocessing
17import pickle
18import traceback
19
20from typ.host import Host
21
22
def make_pool(host, jobs, callback, context, pre_fn, post_fn):
    """Return a worker pool sized for the requested level of parallelism.

    Validates that the worker arguments can cross a process boundary, then
    picks a multi-process pool when jobs > 1 and an in-process (synchronous)
    pool otherwise. Both pool classes expose the same send/get/close/join
    interface.
    """
    _validate_args(context, pre_fn, post_fn)
    pool_cls = _ProcessPool if jobs > 1 else _AsyncPool
    return pool_cls(host, jobs, callback, context, pre_fn, post_fn)
29
30
31class _MessageType(object):
32    Request = 'Request'
33    Response = 'Response'
34    Close = 'Close'
35    Done = 'Done'
36    Error = 'Error'
37    Interrupt = 'Interrupt'
38
39    values = [Request, Response, Close, Done, Error, Interrupt]
40
41
42def _validate_args(context, pre_fn, post_fn):
43    try:
44        _ = pickle.dumps(context)
45    except Exception as e:
46        raise ValueError('context passed to make_pool is not picklable: %s'
47                         % str(e))
48    try:
49        _ = pickle.dumps(pre_fn)
50    except pickle.PickleError:
51        raise ValueError('pre_fn passed to make_pool is not picklable')
52    try:
53        _ = pickle.dumps(post_fn)
54    except pickle.PickleError:
55        raise ValueError('post_fn passed to make_pool is not picklable')
56
57
class _ProcessPool(object):
    """Pool that fans work out to `jobs` multiprocessing worker processes.

    Requests flow to the workers over a shared request queue and results
    come back over a shared response queue; each worker process runs
    _loop(). Callers use send()/get() to exchange messages, then close()
    and join() to shut down and collect each worker's final context.
    """

    def __init__(self, host, jobs, callback, context, pre_fn, post_fn):
        self.host = host
        self.jobs = jobs
        # Shared queues: requests flow pool -> workers, responses flow back.
        self.requests = multiprocessing.Queue()
        self.responses = multiprocessing.Queue()
        self.workers = []
        # Non-final responses drained (and kept here) during join().
        self.discarded_responses = []
        self.closed = False
        self.erred = False
        # Worker numbers are 1-based.
        for worker_num in range(1, jobs + 1):
            w = multiprocessing.Process(target=_loop,
                                        args=(self.requests, self.responses,
                                              host.for_mp(), worker_num,
                                              callback, context,
                                              pre_fn, post_fn))
            w.start()
            self.workers.append(w)

    def send(self, msg):
        """Queue msg to be handled by the next available worker."""
        self.requests.put((_MessageType.Request, msg))

    def get(self):
        """Block until a worker response is available and return it.

        Raises an Exception if a worker reported an error, or
        KeyboardInterrupt if a worker was interrupted.
        """
        msg_type, resp = self.responses.get()
        if msg_type == _MessageType.Error:
            self._handle_error(resp)
        elif msg_type == _MessageType.Interrupt:
            raise KeyboardInterrupt
        assert msg_type == _MessageType.Response
        return resp

    def close(self):
        """Ask every worker to exit once it drains the queued requests."""
        for _ in self.workers:
            self.requests.put((_MessageType.Close, None))
        self.closed = True

    def join(self):
        """Wait for the workers to finish; return their final contexts.

        If close() was not called first, the pool is assumed to be
        aborting and the workers are terminated instead of drained.
        Re-raises any worker error or interrupt after all workers exit.
        """
        # TODO: one would think that we could close self.requests in close(),
        # above, and close self.responses below, but if we do, we get
        # weird tracebacks in the daemon threads multiprocessing starts up.
        # Instead, we have to hack the innards of multiprocessing. It
        # seems likely that there's a bug somewhere, either in this module or
        # in multiprocessing.
        # pylint: disable=protected-access
        if self.host.is_python3:  # pragma: python3
            multiprocessing.queues.is_exiting = lambda: True
        else:  # pragma: python2
            multiprocessing.util._exiting = True

        if not self.closed:
            # We must be aborting; terminate the workers rather than
            # shutting down cleanly.
            for w in self.workers:
                w.terminate()
                w.join()
            return []

        final_responses = []
        error = None
        interrupted = None
        # Drain responses until each worker has reported Done (or failed);
        # ordinary responses received past this point are kept in
        # self.discarded_responses rather than returned.
        for w in self.workers:
            while True:
                msg_type, resp = self.responses.get()
                if msg_type == _MessageType.Error:
                    error = resp
                    break
                if msg_type == _MessageType.Interrupt:
                    interrupted = True
                    break
                if msg_type == _MessageType.Done:
                    final_responses.append(resp[1])
                    break
                self.discarded_responses.append(resp)

        for w in self.workers:
            w.join()

        # TODO: See comment above at the beginning of the function for
        # why this is commented out.
        # self.responses.close()

        # Defer raising until all workers have been joined so we don't
        # leave orphaned processes behind.
        if error:
            self._handle_error(error)
        if interrupted:
            raise KeyboardInterrupt
        return final_responses

    def _handle_error(self, msg):
        """Mark the pool as erred and re-raise a worker's traceback."""
        worker_num, tb = msg
        self.erred = True
        raise Exception("Error from worker %d (traceback follows):\n%s" %
                        (worker_num, tb))
151
152
153# 'Too many arguments' pylint: disable=R0913
154
def _loop(requests, responses, host, worker_num,
          callback, context, pre_fn, post_fn, should_loop=True):
    """Main loop run inside each worker process.

    Runs pre_fn once, then services Request messages from `requests` by
    invoking `callback` and posting each result to `responses`, until a
    Close message arrives (at which point post_fn's result is posted as a
    Done message). Errors and interrupts are reported back to the pool as
    messages rather than raised, so the parent can surface them.

    `should_loop` exists for tests: passing False makes the worker handle
    at most one request before exiting.
    """
    host = host or Host()
    try:
        context_after_pre = pre_fn(host, worker_num, context)
        keep_looping = True
        while keep_looping:
            message_type, args = requests.get(block=True)
            if message_type == _MessageType.Close:
                # Flush per-worker state through post_fn and report the
                # final context back to the pool.
                responses.put((_MessageType.Done,
                               (worker_num, post_fn(context_after_pre))))
                break
            assert message_type == _MessageType.Request
            resp = callback(context_after_pre, args)
            responses.put((_MessageType.Response, resp))
            keep_looping = should_loop
    except KeyboardInterrupt as e:
        responses.put((_MessageType.Interrupt, (worker_num, str(e))))
    except Exception:
        # Bug fix: traceback.format_exc() takes no exception argument --
        # its optional first parameter is `limit`, so the old call
        # `format_exc(e)` passed the exception where an int was expected
        # and broke the formatting. Calling it with no arguments formats
        # the currently-handled exception, which is what we want.
        responses.put((_MessageType.Error,
                       (worker_num, traceback.format_exc())))
176
177
178class _AsyncPool(object):
179
180    def __init__(self, host, jobs, callback, context, pre_fn, post_fn):
181        self.host = host or Host()
182        self.jobs = jobs
183        self.callback = callback
184        self.context = copy.deepcopy(context)
185        self.msgs = []
186        self.closed = False
187        self.post_fn = post_fn
188        self.context_after_pre = pre_fn(self.host, 1, self.context)
189        self.final_context = None
190
191    def send(self, msg):
192        self.msgs.append(msg)
193
194    def get(self):
195        return self.callback(self.context_after_pre, self.msgs.pop(0))
196
197    def close(self):
198        self.closed = True
199        self.final_context = self.post_fn(self.context_after_pre)
200
201    def join(self):
202        if not self.closed:
203            self.close()
204        return [self.final_context]
205