1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A variable which packs a list of variables distributed across devices."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.distribute import device_util
22from tensorflow.python.eager import context
23from tensorflow.python.framework import ops
24from tensorflow.python.ops import math_ops
25from tensorflow.python.ops import resource_variable_ops
26
27
class PackedDistributedVariable(resource_variable_ops.BaseResourceVariable):
  """A variable which packs multiple variables distributed across devices.

  It's only supported when eager execution is enabled.
  For op-by-op execution, use an unpacked handle on the current device; for
  function execution, use the packed handle to reduce the overhead of function
  calls.
  """

  def __init__(self, distributed_variables=None, name=None, **unused_kwargs):
    """Packs a list of variables which are distributed across devices.

    The trainable flag, shape, dtype, synchronization, constraint, aggregation
    and distribution strategy of the packed variable are all taken from the
    first variable in `distributed_variables`.

    Args:
      distributed_variables: A non-empty iterable of resource variables to
        pack. The iterable is copied into a new list, so mutating the argument
        afterwards does not affect this object.
      name: Optional name for the variable. Defaults to `'Variable'` and gets
        uniquified automatically.
      **unused_kwargs: Accepted and ignored.

    Raises:
      ValueError: If not called in eager mode (outside functions), if
        `distributed_variables` is empty or `None`, or if any element is not a
        resource variable.
    """
    if not ops.executing_eagerly_outside_functions():
      raise ValueError(
          "PackedDistributedVariable should be created in eager mode.")
    # Snapshot the input. `_devices` below is computed once, so holding the
    # caller's own list would let a later mutation misalign the two parallel
    # lists used by `get_var_on_device`. Copying also makes one-shot iterables
    # (e.g. generators) safe to traverse several times.
    distributed_variables = list(distributed_variables or ())
    if not distributed_variables:
      raise ValueError("Expect a non-empty list of variables to pack.")
    for i, var in enumerate(distributed_variables):
      if not resource_variable_ops.is_resource_variable(var):
        raise ValueError("Expect a list of ResourceVariables to pack, "
                         "but the %d-th variable is %s" % (i, type(var)))

    self._distributed_variables = distributed_variables
    # Parallel to `_distributed_variables`: `_devices[i]` is the device of
    # `_distributed_variables[i]`.
    self._devices = [v.device for v in distributed_variables]
    with ops.init_scope():
      with ops.name_scope(name, "Variable", skip_on_eager=False) as name:
        # A single handle packing the handles of all component variables.
        handle = ops.pack_eager_tensors(
            [var.handle for var in distributed_variables])
        handle_name = ops.name_from_scope_name(name)
        unique_id = "%s_%d" % (handle_name, ops.uid())
        super(PackedDistributedVariable, self).__init__(
            trainable=distributed_variables[0].trainable,
            shape=distributed_variables[0].shape,
            dtype=distributed_variables[0].dtype,
            handle=handle,
            synchronization=distributed_variables[0].synchronization,
            constraint=distributed_variables[0].constraint,
            aggregation=distributed_variables[0].aggregation,
            distribute_strategy=distributed_variables[0]._distribute_strategy,  # pylint: disable=protected-access
            name=name,
            unique_id=unique_id,
            handle_name=handle_name,
            graph_element=None,
            initial_value=None,
            initializer_op=None,
            is_initialized_op=None,
            cached_value=None,
            caching_device=None,
            is_distributed_variables=True)

  @property
  def devices(self):
    """The devices of the packed variables, in packing order."""
    return self._devices

  def on_device(self, device):
    """Returns a `PackedVarAndDevice` pairing this variable with `device`."""
    return PackedVarAndDevice(self, device)

  def get_var_on_device(self, device):
    """Returns the packed component variable placed on `device`.

    Args:
      device: A device string. It is compared for equality against the
        `device` attribute recorded for each component at construction time.

    Returns:
      The first component variable whose recorded device equals `device`.

    Raises:
      ValueError: If no component variable was recorded on `device`.
    """
    for i, d in enumerate(self._devices):
      if d == device:
        return self._distributed_variables[i]
    raise ValueError("Device %s is not found" % device)

  def get_var_on_current_device(self):
    """Returns the component variable on the canonicalized current device.

    Raises:
      ValueError: If no component variable was recorded on that device.
    """
    current_device = device_util.canonicalize(device_util.current())
    return self.get_var_on_device(current_device)

  def initial_value(self, device):
    """Returns the Tensor used as the initial value for the variable."""
    return self.get_var_on_device(device).initial_value

  @property
  def handle(self):
    """The per-device handle when executing eagerly, else the packed handle."""
    if context.executing_eagerly():
      return self.get_var_on_current_device().handle
    else:
      return self._handle

  @property
  def packed_handle(self):
    """The packed handle, regardless of the execution mode."""
    return self._handle

  def _read_variable_op(self):
    """Reads the value, using the current-device component when eager."""
    if context.executing_eagerly():
      return self.get_var_on_current_device().value()
    else:
      return super(PackedDistributedVariable, self)._read_variable_op()

  def value(self):
    """Returns the result of `_read_variable_op()`."""
    return self._read_variable_op()

  def is_initialized(self, name=None):
    """Returns the logical AND of the initialization state of all components.

    When executing eagerly, each component variable is queried directly.
    Otherwise the base-class check is issued once per recorded device, each
    inside that device's `ops.device` scope.

    Args:
      name: Optional name. It is given to the final `logical_and`; in the
        non-eager branch it is also passed to each base-class call.

    Returns:
      The combined boolean result.
    """
    if context.executing_eagerly():
      result = self._distributed_variables[0].is_initialized()
      # The first and last components are handled outside the loop so that
      # `name` can be attached to the final `logical_and` only.
      for v in self._distributed_variables[1:-1]:
        result = math_ops.logical_and(result, v.is_initialized())
      result = math_ops.logical_and(
          result, self._distributed_variables[-1].is_initialized(), name=name)
    else:
      with ops.device(self._devices[0]):
        result = super(PackedDistributedVariable, self).is_initialized(name)
      for d in self._devices[1:-1]:
        with ops.device(d):
          initialized = super(PackedDistributedVariable,
                              self).is_initialized(name)
        result = math_ops.logical_and(result, initialized)
      with ops.device(self._devices[-1]):
        initialized = super(PackedDistributedVariable,
                            self).is_initialized(name)
      result = math_ops.logical_and(result, initialized, name=name)
    return result

  def _update(self, update_fn, value, **kwargs):
    """Applies `update_fn` to the appropriate underlying variable.

    Args:
      update_fn: Callable invoked as `update_fn(var, value, **kwargs)`.
      value: The value forwarded to `update_fn`.
      **kwargs: Extra keyword arguments forwarded to `update_fn`.

    Returns:
      Whatever `update_fn` returns.
    """
    if context.executing_eagerly():
      return update_fn(self.get_var_on_current_device(), value, **kwargs)
    else:
      # Pass the super proxy so that `update_fn` reaches the base-class
      # method instead of re-entering the override defined on this class.
      return update_fn(super(PackedDistributedVariable, self), value, **kwargs)

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    """Dispatches `assign_sub` through `_update`."""
    assign_sub_fn = lambda var, *a, **kw: var.assign_sub(*a, **kw)
    return self._update(
        update_fn=assign_sub_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    """Dispatches `assign_add` through `_update`."""
    assign_add_fn = lambda var, *a, **kw: var.assign_add(*a, **kw)
    return self._update(
        update_fn=assign_add_fn,
        value=delta,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Dispatches `assign` through `_update`."""
    assign_fn = lambda var, *a, **kw: var.assign(*a, **kw)
    return self._update(
        update_fn=assign_fn,
        value=value,
        use_locking=use_locking,
        name=name,
        read_value=read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_sub` through `_update`."""
    scatter_sub_fn = lambda var, *a, **kw: var.scatter_sub(*a, **kw)
    return self._update(
        update_fn=scatter_sub_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_add` through `_update`."""
    scatter_add_fn = lambda var, *a, **kw: var.scatter_add(*a, **kw)
    return self._update(
        update_fn=scatter_add_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_mul` through `_update`."""
    scatter_mul_fn = lambda var, *a, **kw: var.scatter_mul(*a, **kw)
    return self._update(
        update_fn=scatter_mul_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_div` through `_update`."""
    scatter_div_fn = lambda var, *a, **kw: var.scatter_div(*a, **kw)
    return self._update(
        update_fn=scatter_div_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_min` through `_update`."""
    scatter_min_fn = lambda var, *a, **kw: var.scatter_min(*a, **kw)
    return self._update(
        update_fn=scatter_min_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_max` through `_update`."""
    scatter_max_fn = lambda var, *a, **kw: var.scatter_max(*a, **kw)
    return self._update(
        update_fn=scatter_max_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Dispatches `scatter_update` through `_update`."""
    scatter_update_fn = lambda var, *a, **kw: var.scatter_update(*a, **kw)
    return self._update(
        update_fn=scatter_update_fn,
        value=sparse_delta,
        use_locking=use_locking,
        name=name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts to a tensor via the current-device component when eager.

    Outside eager execution the base-class conversion is used instead.
    """
    if context.executing_eagerly():
      return self.get_var_on_current_device()._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)
    else:
      return super(PackedDistributedVariable, self)._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)
245
246
class PackedVarAndDevice(object):
  """Holds a packed distributed variable and a device.

  The read, assign, scatter and conversion methods below enter
  `ops.device(self._device)` and then delegate to the wrapped variable. Any
  attribute not defined on this class is forwarded to the wrapped variable.
  """

  def __init__(self, var, device):
    """Pairs `var` with `device`.

    Args:
      var: The packed variable that calls are delegated to.
      device: The device scope entered around the delegated calls.
    """
    self._var = var
    self._device = device

  def __getattr__(self, name):
    """Forwards attributes not found on this wrapper to the wrapped variable.

    Raises:
      AttributeError: If `name` is one of the wrapper's own fields (which
        means the instance has not been initialized), or if the wrapped
        variable does not have the attribute.
    """
    # `__getattr__` only runs after normal lookup has failed. If `_var` is
    # itself missing (e.g. on the bare instance that `copy` / `pickle` create
    # through `__new__` before restoring state), reading `self._var` below
    # would re-enter this method endlessly and end in a RecursionError rather
    # than the AttributeError that `hasattr` and friends expect.
    if name in ("_var", "_device"):
      raise AttributeError(
          "%r object has no attribute %r" % (type(self).__name__, name))
    return getattr(self._var, name)

  def var(self):
    """Returns the wrapped variable."""
    return self._var

  def value(self):
    """Returns `value()` of the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.value()

  def read_value(self):
    """Returns `read_value()` of the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.read_value()

  @property
  def initial_value(self):
    """The wrapped variable's initial value for this wrapper's device."""
    return self._var.initial_value(self._device)

  def initialized_value(self):
    """Returns the wrapped `initialized_value()` under the device scope."""
    with ops.device(self._device):
      return self._var.initialized_value()

  @property
  def device(self):
    """The device this wrapper was created with."""
    return self._device

  @property
  def handle(self):
    """The wrapped variable's `handle`, read under the device scope."""
    with ops.device(self._device):
      return self._var.handle

  def on_device_handle(self):
    """Returns the handle of the component variable on this device."""
    with ops.device(self._device):
      return self._var.get_var_on_current_device().handle

  @property
  def op(self):
    """The wrapped variable's `op`, read under the device scope."""
    with ops.device(self._device):
      return self._var.op

  def assign_sub(self, delta, use_locking=None, name=None, read_value=True):
    """Delegates `assign_sub` to the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.assign_sub(delta, use_locking, name, read_value)

  def assign_add(self, delta, use_locking=None, name=None, read_value=True):
    """Delegates `assign_add` to the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.assign_add(delta, use_locking, name, read_value)

  def assign(self, value, use_locking=None, name=None, read_value=True):
    """Delegates `assign` to the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.assign(value, use_locking, name, read_value)

  def scatter_sub(self, sparse_delta, use_locking=False, name=None):
    """Delegates `scatter_sub` to the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.scatter_sub(sparse_delta, use_locking, name)

  def scatter_add(self, sparse_delta, use_locking=False, name=None):
    """Delegates `scatter_add` to the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.scatter_add(sparse_delta, use_locking, name)

  def scatter_mul(self, sparse_delta, use_locking=False, name=None):
    """Delegates `scatter_mul` to the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.scatter_mul(sparse_delta, use_locking, name)

  def scatter_div(self, sparse_delta, use_locking=False, name=None):
    """Delegates `scatter_div` to the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.scatter_div(sparse_delta, use_locking, name)

  def scatter_min(self, sparse_delta, use_locking=False, name=None):
    """Delegates `scatter_min` to the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.scatter_min(sparse_delta, use_locking, name)

  def scatter_max(self, sparse_delta, use_locking=False, name=None):
    """Delegates `scatter_max` to the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.scatter_max(sparse_delta, use_locking, name)

  def scatter_update(self, sparse_delta, use_locking=False, name=None):
    """Delegates `scatter_update` to the wrapped variable under the device scope."""
    with ops.device(self._device):
      return self._var.scatter_update(sparse_delta, use_locking, name)

  def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
    """Converts the wrapped variable to a tensor under the device scope."""
    with ops.device(self._device):
      return self._var._dense_var_to_tensor(  # pylint: disable=protected-access
          dtype=dtype,
          name=name,
          as_ref=as_ref)

  def _as_graph_element(self):
    """Delegates to the wrapped variable; no device scope is entered."""
    return self._var._as_graph_element()  # pylint: disable=protected-access
343
344
345def _tensor_conversion_packed_var_and_device(var,
346                                             dtype=None,
347                                             name=None,
348                                             as_ref=False):
349  return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
350
351
# Import-time side effect: register the conversion function above as the
# tensor conversion function for `PackedVarAndDevice` instances.
ops.register_tensor_conversion_function(
    PackedVarAndDevice,
    _tensor_conversion_packed_var_and_device)
354