1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Locking related utils."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import threading
22
23
class GroupLock(object):
  """A lock to allow many members of a group to access a resource exclusively.

  This lock provides a way to allow access to a resource by multiple threads
  belonging to a logical group at the same time, while restricting access to
  threads from all other groups. You can think of this as an extension of a
  reader-writer lock, where you allow multiple writers at the same time. We
  made it generic to support multiple groups instead of just two - readers and
  writers.

  Simple usage example with two groups accessing the same resource:

  ```python
  lock = GroupLock(num_groups=2)

  # In a member of group 0:
  with lock.group(0):
    # do stuff, access the resource
    # ...

  # In a member of group 1:
  with lock.group(1):
    # do stuff, access the resource
    # ...
  ```

  Using as a context manager with `.group(group_id)` is the easiest way. You
  can also use the `acquire` and `release` methods directly, in which case
  every `acquire(group_id)` must be paired with exactly one
  `release(group_id)`.

  Note that the lock makes no fairness guarantee: a new member of the group
  that currently holds the lock is always admitted immediately, so a waiting
  group only gets in once every member of the active group has released.
  """

  __slots__ = ["_ready", "_num_groups", "_group_member_counts"]

  def __init__(self, num_groups=2):
    """Initialize a group lock.

    Args:
      num_groups: The number of groups that will be accessing the resource under
        consideration. Should be a positive number.

    Raises:
      ValueError: If num_groups is less than 1.
    """
    if num_groups < 1:
      raise ValueError("num_groups must be a positive integer, got {}".format(
          num_groups))
    # Condition guarding `_group_member_counts`; waiters are woken whenever a
    # group's member count drops back to zero.
    self._ready = threading.Condition(threading.Lock())
    self._num_groups = num_groups
    # `_group_member_counts[g]` is the number of outstanding acquisitions
    # currently held on behalf of group `g`.
    self._group_member_counts = [0] * self._num_groups

  def group(self, group_id):
    """Enter a context where the lock is with group `group_id`.

    Args:
      group_id: The group for which to acquire and release the lock.

    Returns:
      A context manager which will acquire the lock for `group_id` on entry
      and release it on exit.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
    """
    self._validate_group_id(group_id)
    return self._Context(self, group_id)

  def acquire(self, group_id):
    """Acquire the group lock for a specific group `group_id`.

    Blocks until no other group holds the lock, then registers the caller as
    one more member of `group_id`.

    Args:
      group_id: The group on whose behalf the lock is acquired.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
    """
    self._validate_group_id(group_id)

    # Use the condition as a context manager so that its underlying mutex is
    # released even if `wait()` is interrupted by an exception; otherwise every
    # later `acquire`/`release` call would block forever.
    with self._ready:
      while self._another_group_active(group_id):
        self._ready.wait()
      self._group_member_counts[group_id] += 1

  def release(self, group_id):
    """Release the group lock for a specific group `group_id`.

    Must be called once for each earlier `acquire(group_id)`. The counts are
    not checked, so an unbalanced call leaves the bookkeeping inconsistent.

    Args:
      group_id: The group on whose behalf the lock was acquired.

    Raises:
      ValueError: If `group_id` is not in `[0, num_groups)`.
    """
    self._validate_group_id(group_id)

    with self._ready:
      self._group_member_counts[group_id] -= 1
      if self._group_member_counts[group_id] == 0:
        # The last member of this group has left: wake every waiter so that
        # threads from the other groups can re-check whether they may enter.
        self._ready.notify_all()

  def _another_group_active(self, group_id):
    """Return True if any group other than `group_id` has active members."""
    return any(
        c > 0 for g, c in enumerate(self._group_member_counts) if g != group_id)

  def _validate_group_id(self, group_id):
    """Raise ValueError unless `0 <= group_id < num_groups`."""
    if group_id < 0 or group_id >= self._num_groups:
      raise ValueError(
          "group_id={} should be between 0 and num_groups={}".format(
              group_id, self._num_groups))

  class _Context(object):
    """Context manager helper for `GroupLock`.

    Acquires the owning lock for a fixed group on entry and releases it on
    exit, whether or not the body raised.
    """

    __slots__ = ["_lock", "_group_id"]

    def __init__(self, lock, group_id):
      self._lock = lock
      self._group_id = group_id

    def __enter__(self):
      # Nothing is returned, so `with lock.group(i) as x` binds `x` to None.
      self._lock.acquire(self._group_id)

    def __exit__(self, type_arg, value_arg, traceback_arg):
      # Exception details are unused; returning None lets exceptions from the
      # `with` body propagate after the lock has been released.
      del type_arg, value_arg, traceback_arg
      self._lock.release(self._group_id)
133