1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utility functions for the graph_editor.
16"""
17
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import re
from six import iteritems
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import array_ops as tf_array_ops

try:
  from collections import abc as collections_abc  # Python 3.3+
except ImportError:  # Python 2.7: the ABCs live directly in `collections`.
  collections_abc = collections
27
# Names exported by `from <this module> import *`.
__all__ = [
    "make_list_of_op",
    "get_tensors",
    "make_list_of_t",
    "get_generating_ops",
    "get_consuming_ops",
    "ControlOutputs",
    "placeholder_name",
    "make_placeholder_from_tensor",
    "make_placeholder_from_dtype_and_shape",
]


# The graph editor sometimes needs to create placeholders; their names start
# with "geph" (see placeholder_name). "geph" stands for Graph-Editor
# PlaceHolder.
_DEFAULT_PLACEHOLDER_PREFIX = "geph"
44
45
def concatenate_unique(la, lb):
  """Append to `la`, in order, every element of `lb` it does not contain yet.

  `la` is modified in place. Elements taken from `lb` keep their relative
  order, and an element repeated inside `lb` is appended only once.
  Duplicates already present in `la` are left untouched. All elements must
  be hashable.

  Args:
    la: List of Python objects, extended in place.
    lb: List of Python objects to merge into `la`.
  Returns:
    `la` itself, after the missing elements of `lb` have been appended.
  """
  seen = set(la)
  for item in lb:
    if item in seen:
      continue
    seen.add(item)
    la.append(item)
  return la
63
64
# TODO(fkp): very generic code, it should be moved to a more generic place.
class ListView(object):
  """Immutable list wrapper.

  This class is strongly inspired by the one in tf.Operation. It offers
  read-only access to a list: iteration, `len`, truthiness, indexing and
  concatenation. Concatenation (on either side) always returns a new plain
  `list`. The wrapped list is not copied, so changes made to it through
  other references are visible through the view.
  """

  def __init__(self, list_):
    if not isinstance(list_, list):
      raise TypeError("Expected a list, got: {}.".format(type(list_)))
    self._list = list_

  def __iter__(self):
    return iter(self._list)

  def __len__(self):
    return len(self._list)

  def __bool__(self):
    return bool(self._list)

  # Python 3 wants __bool__, Python 2.7 wants __nonzero__
  __nonzero__ = __bool__

  def __getitem__(self, i):
    # Slicing returns a plain list (a copy), never another view.
    return self._list[i]

  def __add__(self, other):
    if not isinstance(other, list):
      other = list(other)
    return list(self) + other

  def __radd__(self, other):
    # Makes `[...] + view` (or `(...) + view`) work like `view + [...]`;
    # without it, `list.__add__` rejects a ListView with a TypeError.
    if not isinstance(other, list):
      other = list(other)
    return other + list(self)

  def __repr__(self):
    return "{}({!r})".format(type(self).__name__, self._list)
96
97
# TODO(fkp): very generic code, it should be moved to a more generic place.
def is_iterable(obj):
  """Tell whether `obj` can be iterated over.

  `tf.Tensor` instances are always reported as not iterable. Any exception
  raised while calling `iter(obj)` is interpreted as "not iterable".

  Args:
    obj: any Python object.
  Returns:
    True if `iter(obj)` succeeds and `obj` is not a `tf.Tensor`, False
    otherwise.
  """
  if isinstance(obj, tf_ops.Tensor):
    return False
  try:
    iter(obj)
  except Exception:  # pylint: disable=broad-except
    return False
  else:
    return True
108
109
def flatten_tree(tree, leaves=None):
  """Flatten a tree into a list.

  Dictionaries contribute their values only (keys are ignored). Strings and
  bytes are treated as leaves rather than as iterables of characters.

  Args:
    tree: iterable or not. If iterable, its elements (child) can also be
      iterable or not.
    leaves: list to which the tree leaves are appended (None by default, in
      which case a new list is created).
  Returns:
    A list of all the leaves in the tree (`leaves` itself when provided).
  """
  if leaves is None:
    leaves = []
  if isinstance(tree, dict):
    for _, child in iteritems(tree):
      flatten_tree(child, leaves)
  elif is_iterable(tree) and not isinstance(tree, (str, bytes)):
    # Strings are excluded: a one-character string iterates to itself, so
    # recursing into it would never terminate.
    for child in tree:
      flatten_tree(child, leaves)
  else:
    leaves.append(tree)
  return leaves
131
132
def transform_tree(tree, fn, iterable_type=tuple):
  """Transform all the nodes of a tree.

  Args:
    tree: iterable or not. If iterable, its elements (child) can also be
      iterable or not. Strings and bytes are treated as leaves.
    fn: function to apply to each leaf.
    iterable_type: type used to construct the resulting tree for unknown
      iterables (anything iterable that is not a dict, a tuple or a
      sequence), typically `list` or `tuple`. It applies at every depth of
      the tree.
  Returns:
    A tree whose leaves have been transformed by `fn`.
    The hierarchy of the output tree mimics the one of the input tree:
    dicts, tuples (including namedtuples) and sequences are rebuilt with
    their own type.
  """
  def _transform(node):
    # Propagate `iterable_type` so nested unknown iterables honor it too.
    return transform_tree(node, fn, iterable_type)

  # str/bytes are Sequences, but rebuilding them via `__new__` + `__init__`
  # below would silently produce ""/b""; they are leaves.
  if not is_iterable(tree) or isinstance(tree, (str, bytes)):
    return fn(tree)
  if isinstance(tree, dict):
    res = tree.__new__(type(tree))
    res.__init__((k, _transform(child)) for k, child in iteritems(tree))
    return res
  if isinstance(tree, tuple):
    if hasattr(tree, "_asdict"):  # NamedTuple: rebuild from its fields.
      return tree.__new__(type(tree), **_transform(tree._asdict()))
    return tree.__new__(type(tree), (_transform(child) for child in tree))
  # `collections.Sequence` was removed in Python 3.10; use the abc module.
  if isinstance(tree, collections_abc.Sequence):
    res = tree.__new__(type(tree))
    res.__init__(_transform(child) for child in tree)
    return res
  return iterable_type(_transform(child) for child in tree)
168
169
def check_graphs(*args):
  """Check that all the elements in args belong to the same graph.

  Elements whose `graph` is None are ignored.

  Args:
    *args: a list of objects with an `obj.graph` property.
  Raises:
    ValueError: if the elements do not all belong to the same graph.
  """
  graph = None
  for index, item in enumerate(args):
    item_graph = item.graph
    if item_graph is None:
      continue
    if graph is None:
      graph = item_graph
    elif item_graph is not graph:
      raise ValueError("Argument[{}]: Wrong graph!".format(index))
184
185
def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by all the elements in tops.

  Args:
    tops: list of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph.
    check_types: check that the elements in tops are of given type(s): a
      single type or an iterable of types. If None, the types
      (tf.Operation, tf.Tensor) are used.
    none_if_empty: don't raise an error if tops is an empty list, just return
      None.
  Returns:
    The unique graph used by all the tops (or `tops` itself if it is a
    tf.Graph).
  Raises:
    TypeError: if tops is not iterable, or if one of its elements is not an
      instance of `check_types`.
    ValueError: if the graph is not unique, or if tops is empty and
      `none_if_empty` is False.
  """
  if isinstance(tops, tf_ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (tf_ops.Operation, tf_ops.Tensor)
  elif is_iterable(check_types):
    # isinstance() only accepts a type or a *tuple* of types: normalize so
    # callers may also pass a list/set of types.
    check_types = tuple(check_types)
  else:
    check_types = (check_types,)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str(
          t) for t in check_types]), type(op)))
    if g is None:
      g = op.graph
    elif g is not op.graph:
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g
222
223
def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Convert ops to a list of `tf.Operation`.

  Args:
    ops: can be an iterable of `tf.Operation` (including a one-shot iterator
      such as a generator), a `tf.Graph` or a single operation.
    check_graph: if `True` check if all the operations belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ts: if True, silently ignore `tf.Tensor`.
  Returns:
    A newly created list of `tf.Operation`. Elements that are not
    `tf.Operation` are dropped.
  Raises:
    TypeError: if `ops` is a `tf.Graph` and `allow_graph` is `False`, or, if
      `check_graph` is `True`, if an element is not a `tf.Operation` (nor a
      `tf.Tensor` when `ignore_ts` is `True`).
    ValueError: if `check_graph` is `True` and the ops do not all belong to
      the same graph.
  """
  if isinstance(ops, tf_ops.Graph):
    if allow_graph:
      return ops.get_operations()
    raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  # Materialize the input exactly once: a one-shot iterator (e.g. a
  # generator) would otherwise be exhausted by the graph check below and the
  # final filtering would return an empty list.
  ops = list(ops) if is_iterable(ops) else [ops]
  if not ops:
    return []
  if check_graph:
    check_types = None if ignore_ts else tf_ops.Operation
    get_unique_graph(ops, check_types=check_types)
  return [op for op in ops if isinstance(op, tf_ops.Operation)]
254
255
# TODO(fkp): move this function into tf.Graph?
def get_tensors(graph):
  """Get all the tensors which are input or output of an op in the graph.

  Only op outputs are collected (inputs are presumably produced as outputs of
  other ops in the same graph), in the order of `graph.get_operations()`.

  Args:
    graph: a `tf.Graph`.
  Returns:
    A list of `tf.Tensor`.
  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if not isinstance(graph, tf_ops.Graph):
    raise TypeError("Expected a graph, got: {}".format(type(graph)))
  return [tensor for op in graph.get_operations() for tensor in op.outputs]
273
274
def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of `tf.Tensor`.

  Args:
    ts: can be an iterable of `tf.Tensor` (including a one-shot iterator such
      as a generator), a `tf.Graph` or a single tensor.
    check_graph: if `True` check if all the tensors belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`.
  Returns:
    A newly created list of `tf.Tensor`. Elements that are not `tf.Tensor`
    are dropped.
  Raises:
    TypeError: if `ts` is a `tf.Graph` and `allow_graph` is `False`, or, if
      `check_graph` is `True`, if an element is not a `tf.Tensor` (nor a
      `tf.Operation` when `ignore_ops` is `True`).
    ValueError: if `check_graph` is `True` and the tensors do not all belong
      to the same graph.
  """
  if isinstance(ts, tf_ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  # Materialize the input exactly once: a one-shot iterator (e.g. a
  # generator) would otherwise be exhausted by the graph check below and the
  # final filtering would return an empty list.
  ts = list(ts) if is_iterable(ts) else [ts]
  if not ts:
    return []
  if check_graph:
    check_types = None if ignore_ops else tf_ops.Tensor
    get_unique_graph(ts, check_types=check_types)
  return [t for t in ts if isinstance(t, tf_ops.Tensor)]
303
304
def get_generating_ops(ts):
  """Return the op producing each tensor of `ts`.

  One op is returned per tensor, in the same order. An op producing several
  tensors of `ts` therefore appears several times.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the generating `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  return [tensor.op for tensor in make_list_of_t(ts, allow_graph=False)]
317
318
def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Each op appears once, in order of first appearance when walking the
  consumers of each tensor of `ts` in turn.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  ops = []
  # A set gives O(1) de-duplication instead of an O(n) scan of `ops` per op
  # (ops are hashable: they are also used as dict keys in this module).
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        ops.append(op)
  return ops
336
337
class ControlOutputs(object):
  """The control outputs topology.

  Maps each op of a graph to the list of ops that have it among their
  control inputs. The mapping is cached and only rebuilt by `update()` when
  the graph version has changed.
  """

  def __init__(self, graph):
    """Build the control-output dependencies of `graph`.

    Args:
      graph: a `tf.Graph`.
    Raises:
      TypeError: graph is not a `tf.Graph`.
    """
    if not isinstance(graph, tf_ops.Graph):
      raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
    self._control_outputs = {}
    self._graph = graph
    self._version = None
    self._build()

  def update(self):
    """Rebuild the control outputs if the graph version has changed.

    Returns:
      self, so that calls can be chained.
    """
    if self._version == self._graph.version:
      return self
    self._build()
    return self

  def _build(self):
    """Recompute the control outputs dictionary from the current graph."""
    outputs = self._control_outputs
    outputs.clear()
    for op in self._graph.get_operations():
      for control_input in op.control_inputs:
        consumers = outputs.setdefault(control_input, [])
        # An op listing the same control input twice is recorded only once.
        if op not in consumers:
          consumers.append(op)
    self._version = self._graph.version

  def get_all(self):
    """Return the whole (live, not copied) control outputs dictionary."""
    return self._control_outputs

  def get(self, op):
    """Return the control outputs of op (an empty tuple if it has none)."""
    return self._control_outputs.get(op, ())

  @property
  def graph(self):
    """The graph whose control outputs are tracked."""
    return self._graph
391
392
def scope_finalize(scope):
  """Return `scope` with a trailing "/" appended if it is non-empty and lacks one.

  Empty or None scopes are returned unchanged.
  """
  if not scope or scope.endswith("/"):
    return scope
  return scope + "/"
397
398
def scope_dirname(scope):
  """Return everything up to and including the last "/" of `scope` ("" if none)."""
  head, separator, _ = scope.rpartition("/")
  return head + separator
404
405
def scope_basename(scope):
  """Return the part of `scope` after its last "/" (all of it if there is none)."""
  return scope.rpartition("/")[2]
411
412
def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX):
  """Create placeholder name for the graph editor.

  Args:
    t: optional tensor on which the placeholder operation's name will be
      based.
    scope: absolute scope with which to prefix the placeholder's name. None
      means that the scope of t is preserved (no scope at all if t is None).
      "" means the root scope.
    prefix: placeholder name prefix ("geph" by default, which stands for
      Graph Editor PlaceHolder; this convention allows to quickly identify
      the placeholders generated by the Graph Editor).
  Returns:
    A new placeholder name. With a tensor `t`, it has the form
    "<scope><prefix>__<op basename>_<output index>", unless the op basename
    already starts with "<prefix>__", in which case that basename is reused
    as is. Without `t`, it is "<scope><prefix>".
  Raises:
    TypeError: if t is neither None nor a tf.Tensor.
  """
  if scope is not None:
    scope = scope_finalize(scope)
  if t is None:
    return "{}{}".format("" if scope is None else scope, prefix)

  if not isinstance(t, tf_ops.Tensor):
    raise TypeError("Expected a tf.Tensor, got: {}".format(type(t)))
  op_dirname = scope_dirname(t.op.name)
  op_basename = scope_basename(t.op.name)
  if scope is None:
    scope = op_dirname

  # Avoid stacking prefixes when t already comes from a generated placeholder.
  if op_basename.startswith("{}__".format(prefix)):
    ph_name = op_basename
  else:
    ph_name = "{}__{}_{}".format(prefix, op_basename, t.value_index)

  return scope + ph_name
449
450
def make_placeholder_from_tensor(t, scope=None,
                                 prefix=_DEFAULT_PLACEHOLDER_PREFIX):
  """Create a `tf.placeholder` with the dtype and static shape of `t`.

  The calling function is responsible for setting the correct graph scope.

  Args:
    t: a `tf.Tensor` whose dtype, shape (`t.get_shape()`) and name are used
      to create the placeholder (the name is derived with placeholder_name).
    scope: absolute scope within which to create the placeholder. None
      means that the scope of `t` is preserved. `""` means the root scope.
    prefix: placeholder name prefix.
  Returns:
    A newly created `tf.placeholder`.
  Raises:
    TypeError: raised by placeholder_name if `t` is not a `tf.Tensor`
      (objects lacking `dtype` or `get_shape` fail earlier with an
      AttributeError).
  """
  dtype = t.dtype
  shape = t.get_shape()
  name = placeholder_name(t, scope=scope, prefix=prefix)
  return tf_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
471
472
def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None,
                                          prefix=_DEFAULT_PLACEHOLDER_PREFIX):
  """Create a `tf.placeholder` for the Graph Editor from a dtype and shape.

  The calling function is responsible for setting the correct graph scope.
  The placeholder is named with placeholder_name called without a tensor,
  i.e. "<scope><prefix>".

  Args:
    dtype: the tensor type.
    shape: the tensor shape (optional).
    scope: absolute scope within which to create the placeholder. There is
      no tensor whose scope could be preserved: None and `""` both result in
      no scope being prepended to the name.
    prefix: placeholder name prefix.
  Returns:
    A newly created `tf.placeholder`.
  """
  ph_name = placeholder_name(scope=scope, prefix=prefix)
  return tf_array_ops.placeholder(dtype=dtype, shape=shape, name=ph_name)
493
494
# Matches Python "dunder" attribute names such as `__doc__` or `__module__`.
_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$")


def get_predefined_collection_names():
  """Return all the predefined collection names.

  Collects the value of every attribute of `tf.GraphKeys` whose name is not
  a dunder (`__xxx__`) name, in `dir()` order.
  """
  # NOTE(review): single-underscore attributes are not filtered out; confirm
  # that all of them hold collection names.
  graph_keys = tf_ops.GraphKeys
  names = []
  for attr_name in dir(graph_keys):
    if _INTERNAL_VARIABLE_RE.match(attr_name) is None:
      names.append(getattr(graph_keys, attr_name))
  return names
502
503
def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""):
  """Find corresponding op/tensor in a different graph.

  The lookup is by name: `src_scope` (if any) is stripped from the name of
  `target`, `dst_scope` (if any) is prepended, and the result is looked up
  in `dst_graph`.

  Args:
    target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph.
    dst_graph: The graph in which the corresponding graph element must be found.
    dst_scope: A scope which is prepended to the name to look for.
    src_scope: A scope which is removed from the original name of `target`.

  Returns:
    The corresponding `tf.Tensor` or `tf.Operation`.

  Raises:
    ValueError: if the name of `target` does not start with `src_scope`.
    TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`.
    KeyError: If the corresponding graph element cannot be found.
  """
  # Validate first so that invalid targets raise the documented TypeError
  # rather than an AttributeError on `.name`.
  if not isinstance(target, (tf_ops.Tensor, tf_ops.Operation)):
    raise TypeError(
        "Expected tf.Tensor or tf.Operation, got: {}".format(type(target)))

  src_name = target.name
  if src_scope:
    src_scope = scope_finalize(src_scope)
    if not src_name.startswith(src_scope):
      raise ValueError("{} does not start with {}".format(src_name, src_scope))
    src_name = src_name[len(src_scope):]

  dst_name = src_name
  if dst_scope:
    dst_scope = scope_finalize(dst_scope)
    dst_name = dst_scope + dst_name

  if isinstance(target, tf_ops.Tensor):
    return dst_graph.get_tensor_by_name(dst_name)
  return dst_graph.get_operation_by_name(dst_name)
538
539
def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""):
  """Find corresponding ops/tensors in a different graph.

  `targets` is a Python tree, that is, a nested structure of iterables
  (list, tuple, dictionary) whose leaves are instances of `tf.Tensor` or
  `tf.Operation`. Each leaf is mapped with find_corresponding_elem.

  Args:
    targets: A Python tree containing `tf.Tensor` or `tf.Operation`
      belonging to the original graph.
    dst_graph: The graph in which the corresponding graph element must be found.
    dst_scope: A scope which is prepended to the name to look for.
    src_scope: A scope which is removed from the original name of each leaf.

  Returns:
    A Python tree containing the corresponding `tf.Tensor` or `tf.Operation`,
    whose hierarchy mimics the one of `targets`.

  Raises:
    Errors raised by find_corresponding_elem for any leaf propagate:
    ValueError: if the name of a leaf does not start with `src_scope`.
    TypeError: if a leaf is not a `tf.Tensor` or a `tf.Operation`.
    KeyError: If a corresponding graph element cannot be found.
  """
  return transform_tree(
      targets,
      lambda elem: find_corresponding_elem(elem, dst_graph, dst_scope,
                                           src_scope))
565