1# Copyright 2018 gRPC authors.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16import logging
17import os
18import threading
19
_LOGGER = logging.getLogger(__name__)

# Upper bound, in seconds, that the prefork handler waits for the active
# thread count to reach zero before giving up and logging an error.
_AWAIT_THREADS_TIMEOUT_SECONDS = 5

# Spellings accepted as "enabled" for boolean environment variables. The fork
# flag below lower-cases its input first, so only the lower-case entries matter
# for it; the mixed-case ones are kept in case other code compares raw values.
_TRUE_VALUES = ['yes',  'Yes',  'YES', 'true', 'True', 'TRUE', '1']

# This flag enables experimental support within gRPC Python for applications
# that will fork() without exec(). When enabled, gRPC Python will attempt to
# pause all of its internally created threads before the fork syscall proceeds.
#
# For this to be successful, the application must not have multiple threads of
# its own calling into gRPC when fork is invoked. Any callbacks from gRPC
# Python-spawned threads into user code (e.g., callbacks for asynchronous RPCs)
# must not block and should execute quickly.
#
# This flag is not supported on Windows.
_GRPC_ENABLE_FORK_SUPPORT = os.environ.get(
    'GRPC_ENABLE_FORK_SUPPORT', '0').lower() in _TRUE_VALUES
39
cdef void __prefork() nogil:
    # pthread_atfork "prepare" handler (registered by
    # fork_handlers_and_grpc_init()). Raises the fork-in-progress flag so that
    # threads entering block_if_fork_in_progress() park themselves, then waits
    # up to _AWAIT_THREADS_TIMEOUT_SECONDS for the active thread count to hit
    # zero. On timeout it only logs; the fork still proceeds.
    with gil:
        with _fork_state.fork_in_progress_condition:
            _fork_state.fork_in_progress = True
        threads_quiesced = _fork_state.active_thread_count.await_zero_threads(
            _AWAIT_THREADS_TIMEOUT_SECONDS)
        if not threads_quiesced:
            _LOGGER.error(
                'Failed to shutdown gRPC Python threads prior to fork. '
                'Behavior after fork will be undefined.')
49
50
cdef void __postfork_parent() nogil:
    # pthread_atfork "parent" handler: runs in the parent once fork() has
    # happened. Clears the fork-in-progress flag and wakes every thread waiting
    # in block_if_fork_in_progress().
    with gil:
        with _fork_state.fork_in_progress_condition:
            _fork_state.fork_in_progress = False
            # States queued during this fork are only meaningful to the child,
            # which received its own copy of the list. Dropping them here keeps
            # the parent from accumulating (and keeping alive) stale entries
            # across repeated forks. Entries are appended only while
            # fork_in_progress is True and this lock is held, so nothing can
            # race with this reset.
            _fork_state.postfork_states_to_reset = []
            _fork_state.fork_in_progress_condition.notify_all()
56
57
cdef void __postfork_child() nogil:
    # pthread_atfork "child" handler: runs in the child right after fork(),
    # where only the forking thread exists. Rebuilds synchronization
    # primitives other threads may have held, resets queued per-call state,
    # bumps the fork epoch, closes registered channels and clears the
    # fork-in-progress flag. Exits the child if any of that fails, or if gRPC
    # Core is still initialized afterwards.
    with gil:
        try:
            # Thread could be holding the fork_in_progress_condition inside of
            # block_if_fork_in_progress() when fork occurs. Reset the lock here.
            _fork_state.fork_in_progress_condition = threading.Condition()
            # A thread in return_from_user_request_generator() may hold this
            # lock when fork occurs.
            _fork_state.active_thread_count = _ActiveThreadCount()
            for state_to_reset in _fork_state.postfork_states_to_reset:
                state_to_reset.reset_postfork_child()
            # These states are now reset; don't reset them again if this
            # process forks later.
            _fork_state.postfork_states_to_reset = []
            _fork_state.fork_epoch += 1
            # Iterate over a snapshot in case closing a channel ends up
            # unregistering it from the set.
            for channel in list(_fork_state.channels):
                channel._close_on_fork()
            # TODO(ericgribkoff) Check and abort if core is not shutdown
            with _fork_state.fork_in_progress_condition:
                _fork_state.fork_in_progress = False
        except BaseException:
            # An exception cannot propagate out of this cdef void handler; it
            # would be printed and discarded, leaving the child half-reset
            # (e.g. fork_in_progress still True). Fail fast instead.
            _LOGGER.exception('Exiting child due to raised exception')
            os._exit(os.EX_USAGE)
    if grpc_is_initialized() > 0:
        with gil:
            _LOGGER.error('Failed to shutdown gRPC Core after fork()')
            os._exit(os.EX_USAGE)
78
79
def fork_handlers_and_grpc_init():
    """Initialize gRPC Core and, if fork support is on, install fork handlers.

    The pthread_atfork handlers are registered at most once, guarded by
    _fork_state.fork_handler_registered under its lock.
    """
    grpc_init()
    if not _GRPC_ENABLE_FORK_SUPPORT:
        return
    with _fork_state.fork_handler_registered_lock:
        if _fork_state.fork_handler_registered:
            return
        pthread_atfork(&__prefork, &__postfork_parent, &__postfork_child)
        _fork_state.fork_handler_registered = True
87
88
class ForkManagedThread(object):
    """A threading.Thread wrapper that takes part in fork accounting.

    When fork support is enabled, the thread is counted in
    _fork_state.active_thread_count from start() until its target returns or
    raises. That lets the prefork handler's wait-for-zero-threads account for it.
    """

    def __init__(self, target, args=()):
        if _GRPC_ENABLE_FORK_SUPPORT:
            def managed_target(*args):
                try:
                    target(*args)
                finally:
                    _fork_state.active_thread_count.decrement()
            self._thread = threading.Thread(target=managed_target, args=args)
        else:
            self._thread = threading.Thread(target=target, args=args)

    def setDaemon(self, daemonic):
        """Set the daemon flag of the wrapped thread."""
        self._thread.daemon = daemonic

    def start(self):
        """Start the wrapped thread, counting it as active if enabled.

        Raises:
          Whatever threading.Thread.start raises (e.g. RuntimeError); the
          active-thread count is restored before re-raising.
        """
        if not _GRPC_ENABLE_FORK_SUPPORT:
            self._thread.start()
            return
        # Count the thread before it can run, so it is never running uncounted.
        _fork_state.active_thread_count.increment()
        try:
            self._thread.start()
        except BaseException:
            # managed_target never ran, so its finally-clause decrement will not
            # happen. Undo the increment, or every later prefork would wait out
            # its full timeout for a thread that does not exist.
            _fork_state.active_thread_count.decrement()
            raise

    def join(self, timeout=None):
        """Wait for the wrapped thread; ``timeout`` as in threading.Thread.join."""
        self._thread.join(timeout)
111
112
def block_if_fork_in_progress(postfork_state_to_reset=None):
    """Park the calling thread while a fork is in progress.

    Does nothing unless fork support is enabled and a fork is underway.
    Otherwise the optional ``postfork_state_to_reset`` is queued so that the
    child handler calls its reset_postfork_child(). The caller stops counting
    as an active thread while waiting, so the prefork handler is not held up by
    it, and counts again once woken.
    """
    if not _GRPC_ENABLE_FORK_SUPPORT:
        return
    with _fork_state.fork_in_progress_condition:
        if _fork_state.fork_in_progress:
            if postfork_state_to_reset is not None:
                _fork_state.postfork_states_to_reset.append(
                    postfork_state_to_reset)
            _fork_state.active_thread_count.decrement()
            _fork_state.fork_in_progress_condition.wait()
            _fork_state.active_thread_count.increment()
123
124
def enter_user_request_generator():
    """Stop counting the calling thread as active before running user code.

    Paired with return_from_user_request_generator(). Presumably this exists
    because a user request generator may block arbitrarily long; verify
    against callers.
    """
    if not _GRPC_ENABLE_FORK_SUPPORT:
        return
    _fork_state.active_thread_count.decrement()
128
129
def return_from_user_request_generator():
    """Count the calling thread as active again, then honor a pending fork."""
    if not _GRPC_ENABLE_FORK_SUPPORT:
        return
    _fork_state.active_thread_count.increment()
    block_if_fork_in_progress()
134
135
def get_fork_epoch():
    """Return the fork epoch, which the child fork handler bumps on every fork."""
    epoch = _fork_state.fork_epoch
    return epoch
138
139
def is_fork_support_enabled():
    """Return whether GRPC_ENABLE_FORK_SUPPORT was enabled at import time."""
    enabled = _GRPC_ENABLE_FORK_SUPPORT
    return enabled
142
143
def fork_register_channel(channel):
    """Track ``channel`` so the child fork handler can close it after fork."""
    if not _GRPC_ENABLE_FORK_SUPPORT:
        return
    _fork_state.channels.add(channel)
147
148
def fork_unregister_channel(channel):
    """Stop tracking ``channel`` for post-fork closing.

    Idempotent: unregistering a channel that is not (or no longer) tracked is
    a no-op instead of raising KeyError. A caller that unregisters the same
    channel twice (e.g. on a repeated close) therefore does not fail.
    """
    if _GRPC_ENABLE_FORK_SUPPORT:
        _fork_state.channels.discard(channel)
152
153
class _ActiveThreadCount(object):
    """Condition-protected counter of threads the prefork handler waits on."""

    def __init__(self):
        self._num_active_threads = 0
        self._condition = threading.Condition()

    def increment(self):
        """Add one active thread."""
        with self._condition:
            self._num_active_threads += 1

    def decrement(self):
        """Remove one active thread; wake all waiters if the count reaches 0."""
        with self._condition:
            self._num_active_threads -= 1
            reached_zero = self._num_active_threads == 0
            if reached_zero:
                self._condition.notify_all()

    def await_zero_threads(self, timeout_secs):
        """Wait up to ``timeout_secs`` seconds for the count to become zero.

        Returns:
          True if the count was zero when checked, False once the deadline
          has passed with a non-zero count.
        """
        deadline = time.time() + timeout_secs
        remaining = timeout_secs
        with self._condition:
            while self._num_active_threads != 0:
                if self._num_active_threads > 0:
                    self._condition.wait(remaining)
                    if self._num_active_threads == 0:
                        break
                # The count may have gone back up before this thread re-took
                # the lock after a notify; keep waiting until the deadline.
                remaining = deadline - time.time()
                if remaining <= 0:
                    return False
            return True
184
185
class _ForkState(object):
    """Process-wide mutable state shared by the fork handlers and helpers."""

    def __init__(self):
        # Set while a fork is underway; threads wait on the condition in
        # block_if_fork_in_progress().
        self.fork_in_progress = False
        self.fork_in_progress_condition = threading.Condition()
        # Objects whose reset_postfork_child() is invoked in the child.
        self.postfork_states_to_reset = []
        # Guards one-time pthread_atfork registration.
        self.fork_handler_registered = False
        self.fork_handler_registered_lock = threading.Lock()
        # Threads the prefork handler must wait for.
        self.active_thread_count = _ActiveThreadCount()
        # Incremented in each forked child.
        self.fork_epoch = 0
        # Channels to close in the child after fork.
        self.channels = set()


# Module-level singleton used by the fork handlers and helpers in this file.
_fork_state = _ForkState()
199