1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Various functions for graph rerouting."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.contrib.graph_editor import subgraph as _subgraph
22from tensorflow.contrib.graph_editor import util as _util
23from tensorflow.python.framework import ops as _tf_ops
24
25from tensorflow.python.util.all_util import remove_undocumented
26
# Names of the public rerouting helpers defined in this module. The list is
# handed to `remove_undocumented` at the bottom of the file.
_allowed_symbols = [
    "swap_ts",
    "reroute_ts",
    "swap_inputs",
    "reroute_inputs",
    "swap_outputs",
    "reroute_outputs",
    "swap_ios",
    "reroute_ios",
    "remove_control_inputs",
    "add_control_inputs",
]
39
40
def _check_ts_compatibility(ts0, ts1):
  """Verify that two lists of tensors are pairwise compatible.

  Tensors are paired by position; each pair must agree on dtype and on shape.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
  Raises:
    ValueError: if the two lists do not have the same length, or if any pair
      of tensors (same index in ts0 and ts1) has a dtype or a shape which is
      not compatible.
  """
  tensors0 = _util.make_list_of_t(ts0)
  tensors1 = _util.make_list_of_t(ts1)
  size0, size1 = len(tensors0), len(tensors1)
  if size0 != size1:
    raise ValueError(
        "ts0 and ts1 have different sizes: {} != {}".format(size0, size1))
  for left, right in zip(tensors0, tensors1):
    # The dtype is validated first, then the shape.
    left_dtype = left.dtype
    right_dtype = right.dtype
    if not left_dtype.is_compatible_with(right_dtype):
      raise ValueError(
          "Dtypes {} and {} are not compatible.".format(left_dtype,
                                                        right_dtype))
    left_shape = left.get_shape()
    right_shape = right.get_shape()
    if not left_shape.is_compatible_with(right_shape):
      raise ValueError(
          "Shapes {} and {} are not compatible.".format(left_shape,
                                                        right_shape))
67
68
class _RerouteMode(object):
  """Integer constants naming the supported rerouting modes.

  swap: the ends of tensors a and b are exchanged (both directions apply).
  a2b:  only the a -> b direction applies (the end of b is left dangling).
  b2a:  only the b -> a direction applies (the end of a is left dangling).
  """
  swap = 0
  a2b = 1
  b2a = 2

  @classmethod
  def check(cls, mode):
    """Translate a mode into the pair of directions it enables.

    Args:
      mode: an integer representing one of the modes.
    Returns:
      A tuple `(a2b, b2a)` of booleans indicating what rerouting needs doing.
    Raises:
      ValueError: if mode is not one of the known modes.
    """
    # Candidates are compared with `==`, in the order swap, b2a, a2b.
    known_modes = (
        (cls.swap, (True, True)),
        (cls.b2a, (False, True)),
        (cls.a2b, (True, False)),
    )
    for candidate, directions in known_modes:
      if mode == candidate:
        return directions
    raise ValueError("Unknown _RerouteMode: {}".format(mode))
99
100
def _find_reroute_sites(t, consumers, can_modify=None, cannot_modify=None):
  """Locate every input slot of `consumers` that is currently fed by `t`.

  Args:
    t: the tensor whose uses are looked up (compared by identity).
    consumers: iterable of operations to inspect. It is never modified.
    can_modify: collection of operations which can be modified, or None for
      no restriction. Operations outside can_modify are skipped.
    cannot_modify: collection of operations which cannot be modified, or None.
      Operations within cannot_modify are skipped.
  Returns:
    A list of `(op, index)` pairs such that `op.inputs[index] is t`.
  """
  # Work on a private copy so that the caller's collection is left intact.
  candidates = set(consumers)
  if can_modify is not None:
    candidates.intersection_update(can_modify)
  if cannot_modify is not None:
    candidates.difference_update(cannot_modify)
  return [(op, index)
          for op in candidates
          for index, op_input in enumerate(op.inputs)
          if op_input is t]


def _reroute_t(t0, t1, consumers1, can_modify=None, cannot_modify=None):
  """Reroute the end of the tensors (t0,t1).

  Every input slot of the selected consumers which reads t1 is updated to
  read t0 instead.

  Warning: this function is directly manipulating the internals of the
  `tf.Graph`.

  Args:
    t0: a tf.Tensor.
    t1: a tf.Tensor.
    consumers1: The consumers of t1 which needs to be rerouted. This
      collection is not modified by the function.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  """
  # All the slots are resolved before the first update is applied.
  sites = _find_reroute_sites(t1, consumers1, can_modify, cannot_modify)
  for op, index in sites:
    op._update_input(index, t0)  # pylint: disable=protected-access
  return len(sites)


def _reroute_pair(t0, t1, consumers0, consumers1, a2b, b2a,
                  can_modify=None, cannot_modify=None):
  """Reroute one pair of tensors in the requested direction(s).

  Args:
    t0: a tf.Tensor.
    t1: a tf.Tensor.
    consumers0: the consumers of t0 to consider.
    consumers1: the consumers of t1 to consider.
    a2b: if True, the selected consumers of t1 are updated to read t0.
    b2a: if True, the selected consumers of t0 are updated to read t1.
    can_modify: collection (re-iterable) of operations which can be modified,
      or None for no restriction.
    cannot_modify: collection (re-iterable) of operations which cannot be
      modified, or None.
  Returns:
    The number of individual modifications made by the function.
  """
  # Both directions are planned before the graph is touched. Otherwise, when
  # swapping, an op reading both t0 and t1 would have its t1 slots turned into
  # t0 slots and then turned straight back by the second direction, leaving
  # it reading t1 everywhere.
  updates = []
  if a2b:
    updates.extend(
        (op, index, t0) for op, index in _find_reroute_sites(
            t1, consumers1, can_modify, cannot_modify))
  if b2a:
    updates.extend(
        (op, index, t1) for op, index in _find_reroute_sites(
            t0, consumers0, can_modify, cannot_modify))
  for op, index, tensor in updates:
    op._update_input(index, tensor)  # pylint: disable=protected-access
  return len(updates)


def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None):
  """Reroute the end of the tensors in each pair (t0,t1) in ts0 x ts1.

  This function is the back-bone of the Graph-Editor. It is essentially a thin
  wrapper on top of the tf.Operation._update_input.

  Given a pair of tensor t0, t1 in ts0 x ts1, this function re-route the end
  of t0 and t1 in three possible ways:
  1) The reroute mode is "a<->b" or "b<->a": the tensors' end are swapped. After
  this operation, the previous consumers of t0 are now consumers of t1 and
  vice-versa.
  2) The reroute mode is "a->b": the tensors' end of t0 are re-routed to the
  tensors's end of t1 (which are left dangling). After this operation, the
  previous consumers of t0 are still consuming t0 but the previous consumers of
  t1 are now also consuming t0. The tensor t1 has no consumer.
  3) The reroute mode is "b->a": this mode is the symmetric of the "a->b" mode.

  Note that this function is re-routing the end of two tensors, not the start.
  Re-routing the start of two tensors is not supported by this library. The
  reason for that is the following: TensorFlow, by design, creates a strong bond
  between an op and its output tensor. This Graph editor follows this design and
  treats an operation A and its generating tensors {t_i} as an entity which
  cannot be broken. In other words, an op cannot be detached from any of its
  output tensors, ever. But it is possible to detach an op from its input
  tensors, which is what this function concerns itself with.

  Warning: this function is directly manipulating the internals of the tf.Graph.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    mode: a `_RerouteMode` value: `_RerouteMode.swap` for swapping ("a<->b"),
      `_RerouteMode.a2b` ("a->b") or `_RerouteMode.b2a` ("b->a") for one
      direction re-routing.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if `ts0` or `ts1` cannot be converted to a list of `tf.Tensor`.
    TypeError: if `can_modify` or `cannot_modify` is not `None` and cannot be
      converted to a list of `tf.Operation`.
    ValueError: if `mode` is not a known `_RerouteMode`, or if `ts0` and `ts1`
      are not compatible (see `_check_ts_compatibility`).
  """
  a2b, b2a = _RerouteMode.check(mode)
  ts0 = _util.make_list_of_t(ts0)
  ts1 = _util.make_list_of_t(ts1)
  _check_ts_compatibility(ts0, ts1)
  if cannot_modify is not None:
    cannot_modify = frozenset(_util.make_list_of_op(cannot_modify))
  if can_modify is not None:
    can_modify = frozenset(_util.make_list_of_op(can_modify))
  # precompute consumers to avoid issue with repeated tensors:
  precomputed_consumers = [(set(t0.consumers()), set(t1.consumers()))
                           for t0, t1 in zip(ts0, ts1)]
  nb_update_inputs = 0
  for t0, t1, consumers in zip(ts0, ts1, precomputed_consumers):
    if t0 is t1:
      continue  # Silently ignore identical tensors.
    consumers0, consumers1 = consumers
    nb_update_inputs += _reroute_pair(t0, t1, consumers0, consumers1, a2b, b2a,
                                      can_modify, cannot_modify)
  return nb_update_inputs
206
207
def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """Exchange the ends of the tensors of each pair (t0,t1).

      B0 B1     B0 B1
      |  |    =>  X
      A0 A1     A0 A1

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  nb_modifications = _reroute_ts(
      ts0, ts1, _RerouteMode.swap,
      can_modify=can_modify, cannot_modify=cannot_modify)
  return nb_modifications
231
232
def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None):
  """Make the consumers of each t1 consume the matching t0 instead.

      B0 B1     B0 B1
      |  |    => |/
      A0 A1     A0 A1

  The end of the tensors in ts1 are left dangling.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    can_modify: iterable of operations which can be modified. Any operation
      outside can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified. Any
      operation within cannot_modify will be left untouched by this function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor.
    TypeError: if can_modify or cannot_modify is not None and cannot be
      converted to a list of tf.Operation.
  """
  nb_modifications = _reroute_ts(
      ts0, ts1, _RerouteMode.a2b,
      can_modify=can_modify, cannot_modify=cannot_modify)
  return nb_modifications
257
258
def _reroute_sgv_remap(sgv0, sgv1, mode):
  """Update two subgraph views in place so they reflect an input reroute.

  The remapping is computed on copies of the views; the results are then
  assigned back to sgv0 and sgv1. This function is meant to be used by
  reroute_inputs only.

  Args:
    sgv0: the first subgraph to have its inputs remapped.
    sgv1: the second subgraph to have its inputs remapped.
    mode: reroute mode, see _reroute_ts(...).
  Raises:
    TypeError: if sgv0 or sgv1 are not SubGraphView.
    ValueError: if sgv0 and sgv1 do not belong to the same graph.
  """
  forward, backward = _RerouteMode.check(mode)
  for candidate in (sgv0, sgv1):
    if not isinstance(candidate, _subgraph.SubGraphView):
      raise TypeError(
          "Expected a SubGraphView, got {}".format(type(candidate)))
  _util.check_graphs(sgv0, sgv1)
  view0 = sgv0.copy()
  view1 = sgv1.copy()

  # pylint: disable=protected-access
  def refresh_passthrough_outputs(source, target):
    # Each output of `target` found among the passthroughs of `source` is
    # replaced by the input of `target` sitting at the position that tensor
    # has in the inputs of `source`.
    for position, tensor in enumerate(target._output_ts):
      if tensor not in source._passthrough_ts:
        continue
      slot = source._input_ts.index(tensor)
      target._output_ts[position] = target._input_ts[slot]

  if forward and backward:
    # Swap: the two views exchange their inputs and passthroughs.
    view0._input_ts, view1._input_ts = view1._input_ts, view0._input_ts
    view0._passthrough_ts, view1._passthrough_ts = (view1._passthrough_ts,
                                                    view0._passthrough_ts)
  elif forward:
    # a->b: the second view receives copies of the first view's lists.
    view1._input_ts = view0._input_ts[:]
    view1._passthrough_ts = view0._passthrough_ts[:]
  elif backward:
    # b->a: the first view receives copies of the second view's lists.
    view0._input_ts = view1._input_ts[:]
    view0._passthrough_ts = view1._passthrough_ts[:]

  # Update the passthrough outputs as well.
  if forward:
    refresh_passthrough_outputs(view0, view1)
  if backward:
    refresh_passthrough_outputs(view1, view0)

  # Write the remapped state back into the caller's views (in-place).
  sgv0._assign_from(view0)
  sgv1._assign_from(view1)
  # pylint: enable=protected-access
312
313
def _reroute_sgv_inputs(sgv0, sgv1, mode):
  """Re-route all the inputs of two subgraphs.

  Only the ops of the two subgraphs, plus the ops consuming their passthrough
  tensors, are allowed to be modified.

  Args:
    sgv0: the first subgraph to have its inputs rerouted. This argument is
      converted to a subgraph using the same rules than the function
      subgraph.make_view.
    sgv1: the second subgraph to have its inputs rerouted. This argument is
      converted to a subgraph using the same rules than the function
      subgraph.make_view.
    mode: reroute mode, see _reroute_ts(...).
  Returns:
    A tuple `(sgv0, sgv1)` of subgraph views with their inputs rerouted.
      The views are remapped in place by _reroute_sgv_remap.
  Raises:
    StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
      the same rules than the function subgraph.make_view.
  """
  sgv0 = _subgraph.make_view(sgv0)
  sgv1 = _subgraph.make_view(sgv1)
  _util.check_graphs(sgv0, sgv1)
  editable_ops = sgv0.ops + sgv1.ops
  # also allow consumers of passthrough to be modified:
  for view in (sgv0, sgv1):
    editable_ops += _util.get_consuming_ops(view.passthroughs)
  _reroute_ts(sgv0.inputs, sgv1.inputs, mode, can_modify=editable_ops)
  _reroute_sgv_remap(sgv0, sgv1, mode)
  return sgv0, sgv1
342
343
def _reroute_sgv_outputs(sgv0, sgv1, mode):
  """Re-route all the outputs of two subgraphs.

  The ops belonging to either subgraph are never modified; only other
  consumers of the output tensors are rerouted.

  Args:
    sgv0: the first subgraph to have its outputs rerouted. This argument is
      converted to a subgraph using the same rules than the function
      subgraph.make_view.
    sgv1: the second subgraph to have its outputs rerouted. This argument is
      converted to a subgraph using the same rules than the function
      subgraph.make_view.
    mode: reroute mode, see _reroute_ts(...).
  Returns:
    A tuple `(sgv0, sgv1)` of the subgraph views obtained from the arguments.
  Raises:
    StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
      the same rules than the function subgraph.make_view.
  """
  view0, view1 = _subgraph.make_view(sgv0), _subgraph.make_view(sgv1)
  _util.check_graphs(view0, view1)
  frozen_ops = view0.ops + view1.ops
  _reroute_ts(view0.outputs, view1.outputs, mode, cannot_modify=frozen_ops)
  return view0, view1
368
369
def _reroute_sgv(sgv0, sgv1, mode):
  """Re-route both the inputs and the outputs of the two subgraph views.

  The outputs are rerouted first, then the inputs.

  Args:
    sgv0: the first subgraph to be rerouted. This argument is converted to a
      subgraph using the same rules than the function subgraph.make_view.
    sgv1: the second subgraph to be rerouted. This argument is converted to a
      subgraph using the same rules than the function subgraph.make_view.
    mode: reroute mode, see _reroute_ts(...).
  Returns:
    A tuple `(sgv0, sgv1)` of subgraph views with their outputs and inputs
      rerouted, as returned by _reroute_sgv_inputs.
  Raises:
    StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
      the same rules than the function subgraph.make_view.
  """
  # Keep the views returned by the helpers: when the arguments are not
  # already subgraph views, returning the raw arguments would hand back
  # unconverted objects instead of the documented subgraph views.
  sgv0, sgv1 = _reroute_sgv_outputs(sgv0, sgv1, mode)
  sgv0, sgv1 = _reroute_sgv_inputs(sgv0, sgv1, mode)
  return sgv0, sgv1
392
393
def swap_inputs(sgv0, sgv1):
  """Exchange all the inputs of sgv0 and sgv1 (see reroute_inputs).

  Args:
    sgv0: the first subgraph; converted using the rules of subgraph.make_view.
    sgv1: the second subgraph; converted using the rules of
      subgraph.make_view.
  Returns:
    A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped.
  """
  return _reroute_sgv_inputs(sgv0, sgv1, mode=_RerouteMode.swap)
397
398
def reroute_inputs(sgv0, sgv1):
  """Re-route all the inputs of two subgraphs, in the a->b direction.

  Args:
    sgv0: the first subgraph to have its inputs rerouted. This argument is
      converted to a subgraph using the same rules than the function
      subgraph.make_view.
    sgv1: the second subgraph to have its inputs rerouted. This argument is
      converted to a subgraph using the same rules than the function
      subgraph.make_view.
  Returns:
    A tuple `(sgv0, sgv1)` of subgraph views with their inputs rerouted.
      Note that the function argument sgv0 and sgv1 are also modified in place.
  Raises:
    StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
      the same rules than the function subgraph.make_view.
  """
  return _reroute_sgv_inputs(sgv0, sgv1, mode=_RerouteMode.a2b)
417
418
def swap_outputs(sgv0, sgv1):
  """Exchange all the outputs of sgv0 and sgv1 (see reroute_outputs).

  Args:
    sgv0: the first subgraph; converted using the rules of subgraph.make_view.
    sgv1: the second subgraph; converted using the rules of
      subgraph.make_view.
  Returns:
    A tuple `(sgv0, sgv1)` of subgraph views, as returned by the underlying
      output rerouting.
  """
  return _reroute_sgv_outputs(sgv0, sgv1, mode=_RerouteMode.swap)
422
423
def reroute_outputs(sgv0, sgv1):
  """Re-route all the outputs of two subgraphs, in the a->b direction.

  Args:
    sgv0: the first subgraph to have its outputs rerouted. This argument is
      converted to a subgraph using the same rules than the function
      subgraph.make_view.
    sgv1: the second subgraph to have its outputs rerouted. This argument is
      converted to a subgraph using the same rules than the function
      subgraph.make_view.
  Returns:
    A tuple `(sgv0, sgv1)` of subgraph views, as returned by the underlying
      output rerouting.
  Raises:
    StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
      the same rules than the function subgraph.make_view.
  """
  return _reroute_sgv_outputs(sgv0, sgv1, mode=_RerouteMode.a2b)
442
443
def swap_ios(sgv0, sgv1):
  """Exchange the inputs and outputs of sgv0 and sgv1 (see _reroute_sgv).

  Args:
    sgv0: the first subgraph; converted using the rules of subgraph.make_view.
    sgv1: the second subgraph; converted using the rules of
      subgraph.make_view.
  Returns:
    The tuple `(sgv0, sgv1)` returned by _reroute_sgv.
  """
  return _reroute_sgv(sgv0, sgv1, mode=_RerouteMode.swap)
447
448
def reroute_ios(sgv0, sgv1):
  """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute_sgv).

  Args:
    sgv0: the first subgraph; converted using the rules of subgraph.make_view.
    sgv1: the second subgraph; converted using the rules of
      subgraph.make_view.
  Returns:
    The tuple `(sgv0, sgv1)` returned by _reroute_sgv.
  """
  return _reroute_sgv(sgv0, sgv1, mode=_RerouteMode.a2b)
452
453
def remove_control_inputs(op, cops):
  """Remove the control inputs cops from op.

  Warning: this function is directly manipulating the internals of the
  `tf.Graph`.

  Args:
    op: a `tf.Operation` from which to remove the control inputs.
    cops: an object convertible to a list of `tf.Operation`.
  Raises:
    TypeError: if op is not a `tf.Operation`.
    ValueError: if any cop in cops is not a control input of op.
  """
  if not isinstance(op, _tf_ops.Operation):
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = _util.make_list_of_op(cops, allow_graph=False)
  # Read the control inputs once; nothing modifies them before the rebuild.
  current_control_inputs = op.control_inputs
  for cop in cops:
    if cop not in current_control_inputs:
      raise ValueError("{} is not a control_input of {}".format(cop.name,
                                                                op.name))
  remaining = [existing for existing in current_control_inputs
               if existing not in cops]
  # Control inputs are rebuilt from scratch: all of them are dropped, then
  # the ones to keep are added back.
  # pylint: disable=protected-access
  op._remove_all_control_inputs()
  op._add_control_inputs(remaining)
  # pylint: enable=protected-access
479
480
def add_control_inputs(op, cops):
  """Add the control inputs cops to op.

  Warning: this function is directly manipulating the internals of the tf.Graph.

  Args:
    op: a tf.Operation to which the control inputs are added.
    cops: an object convertible to a list of `tf.Operation`.
  Raises:
    TypeError: if op is not a tf.Operation
    ValueError: if any cop in cops is already a control input of op.
  """
  if not isinstance(op, _tf_ops.Operation):
    raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
  cops = _util.make_list_of_op(cops, allow_graph=False)
  # Read the control inputs once; they are not modified during validation.
  current_control_inputs = op.control_inputs
  for cop in cops:
    if cop in current_control_inputs:
      raise ValueError("{} is already a control_input of {}".format(cop.name,
                                                                    op.name))
  op._add_control_inputs(cops)  # pylint: disable=protected-access
501
# Presumably strips from this module every attribute not listed in
# _allowed_symbols -- NOTE(review): confirm against all_util.remove_undocumented.
remove_undocumented(__name__, _allowed_symbols)
503