1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SubGraphView: a subgraph view on an existing tf.Graph.
16"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import copy
23
24import six
25from six import iteritems
26from six import StringIO
27
28from tensorflow.contrib.graph_editor import select
29from tensorflow.contrib.graph_editor import util
30from tensorflow.python.framework import ops as tf_ops
31
# Names exported by `from <this module> import *`; helpers prefixed with an
# underscore are intentionally left out.
__all__ = [
    "SubGraphView",
    "make_view",
    "make_view_from_scope",
]
37
38
def _finalize_index(index_or_t, ts):
  """Resolve `index_or_t` to a position.

  Args:
    index_or_t: either an integer, which is returned unchanged (no range
      check is performed here), or an element to look up in `ts`.
    ts: the sequence searched with `ts.index` when `index_or_t` is not an
      integer.
  Returns:
    The integer itself, or the position of `index_or_t` in `ts`.
  """
  if not isinstance(index_or_t, six.integer_types):
    # Not an integer: treat it as an element and look up its position.
    return ts.index(index_or_t)
  return index_or_t
45
46
def _finalize_indices(list_of_index_or_t, ts):
  """Resolve every entry of `list_of_index_or_t` with `_finalize_index`.

  Args:
    list_of_index_or_t: an iterable whose entries are integers or elements
      of `ts`.
    ts: the sequence used to resolve the non-integer entries.
  Returns:
    A new list holding one resolved entry per input entry, in the same order.
  """
  resolved = []
  for entry in list_of_index_or_t:
    resolved.append(_finalize_index(entry, ts))
  return resolved
50
51
52def _check_within_range(mapping, n, repetition):
53  """Check is the mapping is valid.
54
55  Args:
56    mapping: an iterable of integer.
57    n: define the input domain as [0, n-1]. Note that the mapping can be
58      under-complete, that is, it can only contain a subset of the integers on
59      [0, n-1].
60    repetition: if True repetition are allowed (the function is surjective)
61      otherwise repetition are not allowed (the function is injective).
62  Raises:
63    ValueError: if the mapping is out of range ot if repetition is False and
64      the mapping has some repetition.
65  """
66  for i in mapping:
67    if not 0 <= i < n:
68      raise ValueError("Out of [0, {}[ range: {}".format(n, i))
69  if not repetition and len(set(mapping)) != len(mapping):
70    raise ValueError("Found repetition in mapping: {}".format(mapping))
71
72
class SubGraphView(object):
  """A subgraph view on an existing `tf.Graph`.

  An instance of this class is a subgraph view on an existing `tf.Graph`.
  "subgraph" means that it can represent part of the whole `tf.Graph`.
  "view" means that it only provides a passive observation and does not act
  on the `tf.Graph`. Note that in this documentation, the term "subgraph" is
  often used as a substitute for "subgraph view".

  A subgraph contains:

  * a list of input tensors, accessible via the `inputs` property.
  * a list of output tensors, accessible via the `outputs` property.
  * and the operations in between, accessible via the `ops` property.

  A subgraph can be seen as a function F(i0, i1, ...) -> o0, o1, ... It is a
  function which takes as input some input tensors and returns as output some
  output tensors. The computation that the function performs is encoded in the
  operations of the subgraph.

  The tensors (input or output) can be of two kinds:

  - connected: a connected tensor connects to at least one operation contained
  in the subgraph. One example is a subgraph representing a single operation
  and its inputs and outputs: all the input and output tensors of the op
  are "connected".
  - passthrough: a passthrough tensor does not connect to any operation
  contained in the subgraph. One example is a subgraph representing a
  single tensor: this tensor is passthrough. By default a passthrough tensor is
  present both in the input and output tensors of the subgraph. It can however
  be remapped to appear as an input (or output) only.

  The input and output tensors can be remapped. For instance, some input tensor
  can be omitted. For instance, a subgraph representing an operation with two
  inputs can be remapped to only take one input. Note that this does not change
  at all the underlying `tf.Graph` (remember, it is a view). It means that
  the other input is being ignored, or is being treated as "given".
  The analogy with functions can be extended like this: F(x,y) is the original
  function. Remapping the inputs from [x, y] to just [x] means that the subgraph
  now represents the function F_y(x) (y is "given").

  The output tensors can also be remapped. For instance, some output tensor can
  be omitted. Other output tensors can be duplicated as well. As mentioned
  before, this does not change at all the underlying `tf.Graph`.
  The analogy with functions can be extended like this: F(...)->x,y is the
  original function. Remapping the outputs from [x, y] to just [y,y] means that
  the subgraph now represents the function M(F(...)) where M is the function
  M(a,b)->b,b.

  It is useful to describe a few other kinds of tensors:

  * internal: an internal tensor is a tensor connecting operations contained
    in the subgraph. One example is the subgraph representing the two
    operations A and B connected sequentially: -> A -> B ->. The middle arrow
    is an internal tensor.
  * actual input: an input tensor of the subgraph, regardless of whether it is
    listed in "inputs" or not (masked-out).
  * actual output: an output tensor of the subgraph, regardless of whether it is
    listed in "outputs" or not (masked-out).
  * hidden input: an actual input which has been masked-out using an
    input remapping. In other words, a hidden input is a non-internal tensor
    not listed as an input tensor and one of whose consumers belongs to
    the subgraph.
  * hidden output: an actual output which has been masked-out using an output
    remapping. In other words, a hidden output is a non-internal tensor
    not listed as an output and one of whose generating operations belongs to
    the subgraph.

  Here are some useful guarantees about an instance of a SubGraphView:

  * the input (or output) tensors are not internal.
  * the input (or output) tensors are either "connected" or "passthrough".
  * the passthrough tensors are not connected to any of the operations of
  the subgraph.

  Note that there is no guarantee that an operation in a subgraph contributes
  at all to its inputs or outputs. For instance, remapping both the inputs and
  outputs to empty lists will produce a subgraph which still contains all the
  original operations. However, the remove_unused_ops function can be used to
  make a new subgraph view whose operations are connected to at least one of
  the input or output tensors.

  An instance of this class is meant to be a lightweight object which is not
  modified in-place by the user. Rather, the user can create new modified
  instances of a given subgraph. In that sense, the class SubGraphView is meant
  to be used like an immutable python object.

  A common problem when using views is that they can get out-of-sync with the
  data they observe (in this case, a `tf.Graph`). It is up to the user to
  ensure that this doesn't happen. To keep on the safe side, it is recommended
  that the lifetime of subgraph views is kept very short. One way to achieve
  this is to use subgraphs within a "with make_sgv(...) as sgv:" Python context.

  To alleviate the out-of-sync problem, some functions are granted the right to
  modify subgraph views in place. This is typically the case of graph
  manipulation functions which, given some subgraphs as arguments, can modify
  the underlying `tf.Graph`. Since this modification is likely to render the
  subgraph view invalid, those functions can modify the argument in place to
  reflect the change. For instance, calling the function
  swap_inputs(sgv0, sgv1) will modify sgv0 and sgv1 in place to reflect the
  fact that their inputs have now been swapped.
  """

  def __init__(self, inside_ops=(), passthrough_ts=()):
    """Create a subgraph containing the given ops and the "passthrough" tensors.

    Args:
      inside_ops: an object convertible to a list of `tf.Operation`. This list
        defines all the operations in the subgraph.
      passthrough_ts: an object convertible to a list of `tf.Tensor`. This list
        defines all the "passthrough" tensors. A passthrough tensor is a tensor
        which goes directly from the input of the subgraph to its output,
        without any intermediate operations. All the non passthrough tensors
        are silently ignored.
    Raises:
      TypeError: if inside_ops cannot be converted to a list of `tf.Operation`
        or if `passthrough_ts` cannot be converted to a list of `tf.Tensor`.
    """

    inside_ops = util.make_list_of_op(inside_ops)
    passthrough_ts = util.make_list_of_t(passthrough_ts)
    ops_and_ts = inside_ops + passthrough_ts
    if ops_and_ts:
      self._graph = util.get_unique_graph(ops_and_ts)
      self._ops = inside_ops

      # Compute inside and outside tensors.
      inputs, outputs, insides = select.compute_boundary_ts(inside_ops)

      # Compute passthrough tensors, silently ignoring the non-passthrough ones
      # (i.e. the candidates which are already a boundary or inside tensor).
      all_tensors = frozenset(inputs + outputs + list(insides))
      self._passthrough_ts = [t for t in passthrough_ts if t not in all_tensors]

      # Set inputs and outputs: by default the passthrough tensors are
      # appended to both lists.
      self._input_ts = inputs + self._passthrough_ts
      self._output_ts = outputs + self._passthrough_ts
    else:
      # Empty view: no graph is attached, which makes the view falsy.
      self._graph = None
      self._passthrough_ts = []
      self._input_ts = []
      self._output_ts = []
      self._ops = []

  def __copy__(self):
    """Create a copy of this subgraph.

    Note that this class is a "view", copying it only creates another view and
    does not copy the underlying part of the `tf.Graph`.

    Returns:
      A new identical instance of the original subgraph view.
    """
    cls = self.__class__
    # Bypass __init__: the attributes are transferred one by one below.
    result = cls.__new__(cls)
    for k, v in iteritems(self.__dict__):
      if k == "_graph":
        # The graph is shared between the views, never duplicated.
        setattr(result, k, v)
      else:
        setattr(result, k, list(v))  # copy the list
    return result

  def _assign_from(self, other):
    """Assign other to itself, in place.

    The lists of `other` are copied, so the two views do not share any list
    afterwards; the graph reference is shared.

    Args:
      other: another subgraph-view.
    Raises:
      TypeError: if other is not a SubGraphView.
    """
    if not isinstance(other, SubGraphView):
      raise TypeError("Expected SubGraphView, got: {}".format(type(other)))
    # pylint: disable=protected-access
    self._graph = other._graph
    self._ops = list(other._ops)
    self._passthrough_ts = list(other._passthrough_ts)
    self._input_ts = list(other._input_ts)
    self._output_ts = list(other._output_ts)
    # pylint: enable=protected-access

  def copy(self):
    """Return a copy of itself.

    Note that this class is a "view", copying it only creates another view and
    does not copy the underlying part of the tf.Graph.

    Returns:
      A new instance identical to the original one.
    """
    return copy.copy(self)

  def _remap_default(self, remove_input_map=True, remove_output_map=True):
    """Remap in place the inputs and/or outputs to the default mapping.

    Args:
      remove_input_map: if True the input map is reset to the default one.
      remove_output_map: if True the output map is reset to the default one.
    """
    if not remove_input_map and not remove_output_map:
      return

    # Recompute the boundary tensors from the ops of the view.
    inputs, outputs, _ = select.compute_boundary_ts(self._ops)
    if remove_input_map:
      self._input_ts = list(inputs) + self._passthrough_ts
    if remove_output_map:
      self._output_ts = list(outputs) + self._passthrough_ts

  def remap_default(self, remove_input_map=True, remove_output_map=True):
    """Remap the inputs and/or outputs to the default mapping.

    Args:
      remove_input_map: if True the input map is reset to the default one.
      remove_output_map: if True the output map is reset to the default one.
    Returns:
      A new modified instance of the original subgraph view with its
        input and/or output mapping reset to the default one.
    """
    res = self.copy()
    res._remap_default(remove_input_map, remove_output_map)  # pylint: disable=protected-access
    return res

  def _remap_inputs(self, new_input_indices):
    """Remap the inputs of the subgraph in-place.

    Args:
      new_input_indices: an iterable of integers or tensors of the current
        input list; repetitions are rejected.
    Raises:
      ValueError: if an index is out of range, if a tensor is not a current
        input, or if the mapping contains a repetition.
    """
    new_input_indices = _finalize_indices(new_input_indices, self._input_ts)
    _check_within_range(
        new_input_indices, len(self._input_ts), repetition=False)
    self._input_ts = [self._input_ts[i] for i in new_input_indices]

  def _remap_outputs(self, new_output_indices):
    """Remap the outputs of the subgraph in-place.

    Args:
      new_output_indices: an iterable of integers or tensors of the current
        output list; repetitions are allowed.
    Raises:
      ValueError: if an index is out of range or if a tensor is not a current
        output.
    """
    new_output_indices = _finalize_indices(new_output_indices, self._output_ts)
    _check_within_range(
        new_output_indices, len(self._output_ts), repetition=True)
    self._output_ts = [self._output_ts[i] for i in new_output_indices]

  def _remap_outputs_make_unique(self):
    """Remap the outputs in place so that all the tensors appear only once."""
    output_ts = list(self._output_ts)
    self._output_ts = []
    util.concatenate_unique(self._output_ts, output_ts)

  def _remap_outputs_to_consumers(self):
    """Remap the outputs in place to match the number of consumers.

    The outputs are first made unique, then each output tensor is repeated
    once per entry of its `consumers()` list. A tensor without any consumer
    is therefore dropped from the outputs.
    """
    self._remap_outputs_make_unique()
    output_ts = list(self._output_ts)
    self._output_ts = []
    for t in output_ts:
      self._output_ts += [t] * len(t.consumers())

  def remap_outputs_make_unique(self):
    """Remap the outputs so that all the tensors appear only once.

    Returns:
      A new modified instance of the original subgraph view.
    """
    res = self.copy()
    res._remap_outputs_make_unique()  # pylint: disable=protected-access
    return res

  def remap_outputs_to_consumers(self):
    """Remap the outputs to match the number of consumers.

    Returns:
      A new modified instance of the original subgraph view.
    """
    res = self.copy()
    res._remap_outputs_to_consumers()  # pylint: disable=protected-access
    return res

  def _remove_unused_ops(self, control_inputs=True):
    """Remove unused ops in place.

    Only the ops returned by `select.get_walks_union_ops`, called with the
    connected inputs and connected outputs of this view, are kept. The
    relative order of the remaining ops is preserved.

    Args:
      control_inputs: if True, control inputs are used to detect used ops.
    """
    # Build a set once so that each membership test below is a hash lookup
    # rather than a scan of the walk result.
    used_ops = frozenset(
        select.get_walks_union_ops(
            self.connected_inputs,
            self.connected_outputs,
            within_ops=self._ops,
            control_inputs=control_inputs))
    self._ops = [op for op in self._ops if op in used_ops]

  def remove_unused_ops(self, control_inputs=True):
    """Remove unused ops.

    Args:
      control_inputs: if True, control inputs are used to detect used ops.
    Returns:
      A new subgraph view which only contains used operations.
    """
    res = self.copy()
    res._remove_unused_ops(control_inputs)  # pylint: disable=protected-access
    return res

  def remap_inputs(self, new_input_indices):
    """Remap the inputs of the subgraph.

    If the inputs of the original subgraph are [t0, t1, t2], remapping to [2,0]
    will create a new instance whose inputs is [t2, t0].

    Note that this is only modifying the view: the underlying `tf.Graph` is not
    affected.

    Args:
      new_input_indices: an iterable of integers or tf.Tensors
        representing a mapping between the old inputs and the new ones.
        Integers must be non-negative and smaller than the number of old
        inputs.
        tf.Tensors must belong to the old list of inputs.
        This mapping can be under-complete and must be without repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
        inputs.
    """
    res = self.copy()
    res._remap_inputs(new_input_indices)  # pylint: disable=protected-access
    return res

  def remap_outputs(self, new_output_indices):
    """Remap the output of the subgraph.

    If the output of the original subgraph are [t0, t1, t2], remapping to
    [1,1,0] will create a new instance whose outputs is [t1, t1, t0].

    Note that this is only modifying the view: the underlying tf.Graph is not
    affected.

    Args:
      new_output_indices: an iterable of integers or tf.Tensors
        representing a mapping between the old outputs and the new ones.
        Integers must be non-negative and smaller than the number of old
        outputs.
        tf.Tensors must belong to the old list of outputs.
        This mapping can be under-complete and can have repetitions.
    Returns:
      A new modified instance of the original subgraph view with remapped
        outputs.
    """
    res = self.copy()
    res._remap_outputs(new_output_indices)  # pylint: disable=protected-access
    return res

  def remap(self, new_input_indices=None, new_output_indices=None):
    """Remap the inputs and outputs of the subgraph.

    Note that this is only modifying the view: the underlying tf.Graph is not
    affected.

    Args:
      new_input_indices: an iterable of integers or tf.Tensors
        representing a mapping between the old inputs and the new ones.
        Integers must be non-negative and smaller than the number of old
        inputs.
        tf.Tensors must belong to the old list of inputs.
        This mapping can be under-complete and must be without repetitions.
        If None, the inputs are left untouched.
      new_output_indices: an iterable of integers or tf.Tensors
        representing a mapping between the old outputs and the new ones.
        Integers must be non-negative and smaller than the number of old
        outputs.
        tf.Tensors must belong to the old list of outputs.
        This mapping can be under-complete and can have repetitions.
        If None, the outputs are left untouched.
    Returns:
      A new modified instance of the original subgraph view with remapped
        inputs and outputs.
    """
    res = self.copy()
    if new_input_indices is not None:
      res._remap_inputs(new_input_indices)  # pylint: disable=protected-access
    if new_output_indices is not None:
      res._remap_outputs(new_output_indices)  # pylint: disable=protected-access
    return res

  def find_op_by_name(self, op_name):
    """Return the op named op_name.

    Args:
      op_name: the name to search for
    Returns:
      The op named op_name.
    Raises:
      ValueError: if the op_name could not be found.
      AssertionError: if the name was found multiple times.
    """
    res = [op for op in self._ops if op.name == op_name]
    if not res:
      raise ValueError("{} not in subgraph.".format(op_name))
    if len(res) > 1:
      raise AssertionError("More than 1 op named: {}!".format(op_name))
    return res[0]

  def __str__(self):
    """Return a multi-line description listing the ops, inputs and outputs.

    Passthrough tensors are flagged with a trailing " *".
    """
    if not self:
      return "SubGraphView: empty"

    def op_name(op):
      return op.name

    def tensor_name(t):
      if t in self._passthrough_ts:
        return "{} *".format(t.name)
      return t.name

    lines = ["SubGraphView (graphid={}):".format(id(self.graph))]

    def add_section(name, elems, get_name):
      if elems:
        lines.append("** {}[{}]:".format(name, len(elems)))
        lines.extend("  {}".format(get_name(elem)) for elem in elems)
      else:
        lines.append("** {}: empty".format(name))

    add_section("ops", self._ops, op_name)
    add_section("inputs", self._input_ts, tensor_name)
    add_section("outputs", self._output_ts, tensor_name)
    # Every line, including the last one, is newline-terminated.
    return "\n".join(lines) + "\n"

  @property
  def graph(self):
    """The underlying `tf.Graph`."""
    return self._graph

  @property
  def ops(self):
    """The operations in this subgraph view."""
    return self._ops

  @property
  def inputs(self):
    """The input tensors of this subgraph view."""
    return util.ListView(self._input_ts)

  @property
  def connected_inputs(self):
    """The connected input tensors of this subgraph view."""
    return [t for t in self._input_ts if t not in self._passthrough_ts]

  @property
  def outputs(self):
    """The output tensors of this subgraph view."""
    return util.ListView(self._output_ts)

  @property
  def connected_outputs(self):
    """The connected output tensors of this subgraph view."""
    return [t for t in self._output_ts if t not in self._passthrough_ts]

  @property
  def passthroughs(self):
    """The passthrough tensors, going straight from input to output."""
    return util.ListView(self._passthrough_ts)

  def __bool__(self):
    """Allows for implicit boolean conversion: False for an empty view."""
    return self._graph is not None

  # Python 3 wants __bool__, Python 2.7 wants __nonzero__
  __nonzero__ = __bool__

  def op(self, op_id):
    """Get an op by its index."""
    return self._ops[op_id]

  def is_passthrough(self, t):
    """Check whether a tensor is passthrough."""
    return t in self._passthrough_ts

  def __enter__(self):
    """Allow Python context to minimize the life time of a subgraph view.

    A subgraph view is meant to be a lightweight and transient object. A short
    lifetime will alleviate the "out-of-sync" issue mentioned earlier. For that
    reason, a SubGraphView instance can be used within a Python context. For
    example:

    from tensorflow.contrib import graph_editor as ge
    with ge.make_sgv(...) as sgv:
      print(sgv)

    Returns:
      Itself.
    """
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    """Leave the context: nothing to release, exceptions are not suppressed."""
    pass

  def input_index(self, t):
    """Find the input index corresponding to the given input tensor t.

    Args:
      t: the input tensor of this subgraph view.
    Returns:
      The index in the self.inputs list.
    Raises:
      ValueError: if t is not an input tensor.
    """
    try:
      return self._input_ts.index(t)
    except ValueError:
      # Only the failed lookup is translated; the message must not rely on
      # attributes this class does not define.
      raise ValueError("Can't find {} in inputs of subgraph.".format(
          getattr(t, "name", t)))

  def output_index(self, t):
    """Find the output index corresponding to given output tensor t.

    Args:
      t: the output tensor of this subgraph view.
    Returns:
      The index in the self.outputs list.
    Raises:
      ValueError: if t is not an output tensor.
    """
    try:
      return self._output_ts.index(t)
    except ValueError:
      # Only the failed lookup is translated; the message must not rely on
      # attributes this class does not define.
      raise ValueError("Can't find {} in outputs of subgraph.".format(
          getattr(t, "name", t)))

  def consumers(self):
    """Return a list of all the consumers of this subgraph view.

    A consumer of a subgraph view is a tf.Operation which is a consumer
    of one of the output tensors and is not in the subgraph.

    Returns:
      A list of `tf.Operation` which are the consumers of this subgraph view,
      accumulated with `util.concatenate_unique`.
    """
    ops_set = frozenset(self._ops)
    res = []
    for output in self._output_ts:
      consumers = [op for op in output.consumers() if op not in ops_set]
      util.concatenate_unique(res, consumers)
    return res
602
603
def _check_graph(sgv, graph):
  """Check if sgv belongs to the given graph.

  The check is skipped (and `sgv` returned as is) when `graph` is None or
  when `sgv` has no graph attached, i.e. when `sgv.graph` is falsy.

  Args:
    sgv: a SubGraphView.
    graph: a graph or None.
  Returns:
    The SubGraphView sgv.
  Raises:
    TypeError: if sgv is not a SubGraphView or if graph is not None and not
      a tf.Graph.
    ValueError: if the graph of sgv and the given graph are not None and
      different.
  """
  if not isinstance(sgv, SubGraphView):
    # Report the type of the offending argument (sgv), not of `graph`.
    raise TypeError("Expected a SubGraphView, got: {}".format(type(sgv)))
  if graph is None or not sgv.graph:
    return sgv
  if not isinstance(graph, tf_ops.Graph):
    raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
  # Graphs are compared by identity.
  if sgv.graph is not graph:
    raise ValueError("Graph mismatch.")
  return sgv
627
628
def make_view(*args, **kwargs):
  """Build a SubGraphView out of a selection of ops and passthrough tensors.

  When the only positional argument is already a SubGraphView, that same
  instance is checked against the optional graph and returned; no new view
  is created.

  Args:
    *args: list of 1) regular expressions (compiled or not) or 2) (array of)
      `tf.Operation` 3) (array of) `tf.Tensor`. Those objects will be converted
      into a list of operations and a list of candidates for passthrough
      tensors.
    **kwargs: keyword graph is used 1) to check that the ops and ts are from
      the correct graph 2) for regular expression query. All the keyword
      arguments are forwarded to `select.select_ops_and_ts`.
  Returns:
    A subgraph view.
  Raises:
    TypeError: if the optional keyword argument graph is not a `tf.Graph`
      or if an argument in args is not an (array of) `tf.Tensor`
      or an (array of) `tf.Operation` or a string or a regular expression.
    ValueError: if the graph of the view differs from the given graph, or
      (raised by the selection helper, not visible here) if one of the
      keyword arguments is unexpected.
  """
  graph = kwargs.get("graph")

  is_single_view = len(args) == 1 and isinstance(args[0], SubGraphView)
  if is_single_view:
    # Already a view: reuse it as is.
    view = args[0]
  else:
    ops, ts = select.select_ops_and_ts(*args, **kwargs)
    view = SubGraphView(ops, ts)
  return _check_graph(view, graph)
656
657
def make_view_from_scope(scope, graph):
  """Build a subgraph view from the operations of a name scope.

  The operations are obtained from `select.get_name_scope_ops(graph, scope)`;
  no passthrough tensor is added to the view.

  Args:
    scope: the name of the scope.
    graph: the `tf.Graph`.
  Returns:
    A subgraph view representing the given scope.
  """
  scope_ops = select.get_name_scope_ops(graph, scope)
  view = SubGraphView(scope_ops)
  return view
669