# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import multiprocessing
import pickle
import traceback

from typ.host import Host


def make_pool(host, jobs, callback, context, pre_fn, post_fn):
    """Create a worker pool that fans requests out to `callback`.

    Args:
        host: a Host object (or None); forwarded to the workers.
        jobs: number of workers. jobs > 1 uses real subprocesses
            (_ProcessPool); jobs <= 1 runs everything in-process
            (_AsyncPool) with the same interface.
        callback: called as callback(context_after_pre, msg) for each
            message sent to the pool.
        context: arbitrary picklable state shared with every worker.
        pre_fn: called once per worker as pre_fn(host, worker_num, context)
            before any requests; its return value becomes the context that
            `callback` and `post_fn` receive.
        post_fn: called once per worker at close() as
            post_fn(context_after_pre).

    Raises:
        ValueError: if context, pre_fn, or post_fn cannot be pickled
            (they must cross a process boundary when jobs > 1).
    """
    _validate_args(context, pre_fn, post_fn)
    if jobs > 1:
        return _ProcessPool(host, jobs, callback, context, pre_fn, post_fn)
    else:
        return _AsyncPool(host, jobs, callback, context, pre_fn, post_fn)


class _MessageType(object):
    # Tags for the tuples passed over the request/response queues.
    Request = 'Request'      # parent -> worker: invoke callback(ctx, args)
    Response = 'Response'    # worker -> parent: callback's return value
    Close = 'Close'          # parent -> worker: run post_fn and exit
    Done = 'Done'            # worker -> parent: (worker_num, post_fn result)
    Error = 'Error'          # worker -> parent: (worker_num, traceback str)
    Interrupt = 'Interrupt'  # worker -> parent: worker got KeyboardInterrupt

    values = [Request, Response, Close, Done, Error, Interrupt]


def _validate_args(context, pre_fn, post_fn):
    """Fail fast with a clear ValueError if the args can't be pickled.

    Everything handed to a _ProcessPool must cross a process boundary, so
    an unpicklable value would otherwise fail later with a confusing error
    deep inside multiprocessing.
    """
    try:
        _ = pickle.dumps(context)
    except Exception as e:
        raise ValueError('context passed to make_pool is not picklable: %s'
                         % str(e))
    # Note: pickling an unpicklable callable (e.g. a lambda or a nested
    # function) raises AttributeError/TypeError on Python 3, not just
    # pickle.PickleError, so we must catch Exception broadly here too.
    try:
        _ = pickle.dumps(pre_fn)
    except Exception:
        raise ValueError('pre_fn passed to make_pool is not picklable')
    try:
        _ = pickle.dumps(post_fn)
    except Exception:
        raise ValueError('post_fn passed to make_pool is not picklable')


class _ProcessPool(object):
    """A pool of `jobs` subprocesses fed over multiprocessing Queues."""

    def __init__(self, host, jobs, callback, context, pre_fn, post_fn):
        self.host = host
        self.jobs = jobs
        self.requests = multiprocessing.Queue()
        self.responses = multiprocessing.Queue()
        self.workers = []
        # Responses received while draining in join() that were neither
        # Done/Error/Interrupt; kept for callers that want to inspect them.
        self.discarded_responses = []
        self.closed = False
        self.erred = False
        # Worker numbers are 1-based by convention (0 could be the parent).
        for worker_num in range(1, jobs + 1):
            w = multiprocessing.Process(target=_loop,
                                        args=(self.requests, self.responses,
                                              host.for_mp(), worker_num,
                                              callback, context,
                                              pre_fn, post_fn))
            w.start()
            self.workers.append(w)

    def send(self, msg):
        """Queue `msg` for some worker's callback. Non-blocking."""
        self.requests.put((_MessageType.Request, msg))

    def get(self):
        """Block until one response is available and return it.

        Raises:
            KeyboardInterrupt: if a worker was interrupted.
            Exception: if a worker raised; the worker's traceback is
                embedded in the message (see _handle_error).
        """
        msg_type, resp = self.responses.get()
        if msg_type == _MessageType.Error:
            self._handle_error(resp)
        elif msg_type == _MessageType.Interrupt:
            raise KeyboardInterrupt
        assert msg_type == _MessageType.Response
        return resp

    def close(self):
        """Tell every worker to run post_fn and exit (one Close each)."""
        for _ in self.workers:
            self.requests.put((_MessageType.Close, None))
        self.closed = True

    def join(self):
        """Wait for all workers and return their post_fn results.

        If close() was not called first we assume an abort and terminate
        the workers instead of shutting down cleanly.

        Returns:
            A list of each worker's post_fn return value (empty on abort).

        Raises:
            KeyboardInterrupt: if any worker was interrupted.
            Exception: if any worker reported an error while draining.
        """
        # TODO: one would think that we could close self.requests in close(),
        # above, and close self.responses below, but if we do, we get
        # weird tracebacks in the daemon threads multiprocessing starts up.
        # Instead, we have to hack the innards of multiprocessing. It
        # seems likely that there's a bug somewhere, either in this module or
        # in multiprocessing.
        # pylint: disable=protected-access
        if self.host.is_python3:  # pragma: python3
            multiprocessing.queues.is_exiting = lambda: True
        else:  # pragma: python2
            multiprocessing.util._exiting = True

        if not self.closed:
            # We must be aborting; terminate the workers rather than
            # shutting down cleanly.
            for w in self.workers:
                w.terminate()
                w.join()
            return []

        final_responses = []
        error = None
        interrupted = None
        # Drain one terminal message (Done/Error/Interrupt) per worker;
        # any ordinary Responses still in flight are set aside.
        for w in self.workers:
            while True:
                msg_type, resp = self.responses.get()
                if msg_type == _MessageType.Error:
                    error = resp
                    break
                if msg_type == _MessageType.Interrupt:
                    interrupted = True
                    break
                if msg_type == _MessageType.Done:
                    final_responses.append(resp[1])
                    break
                self.discarded_responses.append(resp)

        for w in self.workers:
            w.join()

        # TODO: See comment above at the beginning of the function for
        # why this is commented out.
        # self.responses.close()

        if error:
            self._handle_error(error)
        if interrupted:
            raise KeyboardInterrupt
        return final_responses

    def _handle_error(self, msg):
        """Re-raise a worker's failure in the parent, with its traceback."""
        worker_num, tb = msg
        self.erred = True
        raise Exception("Error from worker %d (traceback follows):\n%s" %
                        (worker_num, tb))


# 'Too many arguments' pylint: disable=R0913

def _loop(requests, responses, host, worker_num,
          callback, context, pre_fn, post_fn, should_loop=True):
    """Worker main loop: serve Requests until a Close message arrives.

    Runs pre_fn once, then repeatedly pulls (message_type, args) tuples
    off `requests`. Requests are answered with Response messages; Close
    runs post_fn and answers with Done. Failures are reported back to the
    parent rather than raised (the parent re-raises in _handle_error).

    `should_loop=False` is a test hook: process at most one request.
    """
    host = host or Host()
    try:
        context_after_pre = pre_fn(host, worker_num, context)
        keep_looping = True
        while keep_looping:
            message_type, args = requests.get(block=True)
            if message_type == _MessageType.Close:
                responses.put((_MessageType.Done,
                               (worker_num, post_fn(context_after_pre))))
                break
            assert message_type == _MessageType.Request
            resp = callback(context_after_pre, args)
            responses.put((_MessageType.Response, resp))
            keep_looping = should_loop
    except KeyboardInterrupt as e:
        responses.put((_MessageType.Interrupt, (worker_num, str(e))))
    except Exception:
        # BUG FIX: traceback.format_exc() takes an optional integer limit,
        # not an exception; passing `e` raised a TypeError on Python 3 and
        # destroyed the real traceback. Called with no args it formats the
        # exception currently being handled, which is what we want.
        responses.put((_MessageType.Error,
                       (worker_num, traceback.format_exc())))


class _AsyncPool(object):
    """In-process stand-in for _ProcessPool used when jobs <= 1.

    Presents the same send/get/close/join interface but runs callback
    synchronously in get(); messages are simply buffered in a list.
    """

    def __init__(self, host, jobs, callback, context, pre_fn, post_fn):
        self.host = host or Host()
        self.jobs = jobs
        self.callback = callback
        # Deep-copy so mutations here can't leak back to the caller's
        # context, mirroring the isolation pickling gives _ProcessPool.
        self.context = copy.deepcopy(context)
        self.msgs = []
        self.closed = False
        self.post_fn = post_fn
        self.context_after_pre = pre_fn(self.host, 1, self.context)
        self.final_context = None

    def send(self, msg):
        """Buffer `msg`; it is processed lazily by get()."""
        self.msgs.append(msg)

    def get(self):
        """Run the callback on the oldest buffered message and return it."""
        return self.callback(self.context_after_pre, self.msgs.pop(0))

    def close(self):
        """Run post_fn and record its result for join()."""
        self.closed = True
        self.final_context = self.post_fn(self.context_after_pre)

    def join(self):
        """Return [post_fn result], closing first if needed.

        Unlike _ProcessPool.join(), an un-closed _AsyncPool shuts down
        cleanly here, since there are no subprocesses to abort.
        """
        if not self.closed:
            self.close()
        return [self.final_context]