1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Classes and functions to transform a subgraph into another.
"""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from copy import deepcopy
23from functools import partial
24from six import iteritems
25from six import string_types
26from six import StringIO
27from tensorflow.contrib.graph_editor import reroute
28from tensorflow.contrib.graph_editor import select
29from tensorflow.contrib.graph_editor import subgraph
30from tensorflow.contrib.graph_editor import util
31from tensorflow.python.framework import ops as tf_ops
32from tensorflow.python.platform import tf_logging as logging
33
34
# Public API of this module: the default transform handlers, the Transformer
# machinery and the high-level copy/replace helpers built on top of it.
__all__ = [
    "replace_t_with_placeholder_handler",
    "keep_t_if_possible_handler",
    "assign_renamed_collections_handler",
    "transform_op_if_inside_handler",
    "copy_op_handler",
    "Transformer",
    "TransformerInfo",
    "copy",
    "copy_with_input_replacements",
    "graph_replace",
]
47
48
def replace_t_with_placeholder_handler(info, t):
  """Substitute a tensor with a freshly created placeholder.

  Typical use: turning an input tensor of the subgraph into a placeholder
  living in the destination graph.

  Args:
    info: Transform._TmpInfo instance.
    t: the tensor to be replaced by a placeholder.
  Returns:
    The output tensor of the new placeholder, created in `info.graph_` under
    the destination scope `info.scope_`.
  """
  dst_graph = info.graph_
  with dst_graph.as_default():
    placeholder_t = util.make_placeholder_from_tensor(t, scope=info.scope_)
  return placeholder_t
64
65
def keep_t_if_possible_handler(info, t):
  """Map a tensor onto itself when possible, else onto a placeholder.

  The identity is only possible when the source and destination graphs are
  the very same object; otherwise the tensor is replaced by a placeholder
  (see `replace_t_with_placeholder_handler`). This handler is typically used
  for the hidden inputs of a subgraph.

  Args:
    info: Transform._TmpInfo instance.
    t: the tensor to transform.
  Returns:
    `t` itself if `info.graph is info.graph_`, otherwise the output tensor of
    a newly created placeholder.
  """
  if info.graph is not info.graph_:
    # A tensor cannot be shared across graphs: fall back to a placeholder.
    return replace_t_with_placeholder_handler(info, t)
  return t
83
84
def assign_renamed_collections_handler(info, elem, elem_):
  """Register the transformed element in the collections of the original one.

  Every collection (of the source graph snapshot `info.collections`) that
  contains `elem` gets `elem_` added to its destination counterpart. The
  names of predefined collections (see `tf.GraphKeys`) are kept as-is; every
  other collection name is mapped through `info.new_name`.

  Args:
    info: Transform._TmpInfo instance.
    elem: the original element (`tf.Tensor` or `tf.Operation`)
    elem_: the transformed element
  """
  predefined_names = util.get_predefined_collection_names()
  for collection_name, members in iteritems(info.collections):
    if elem in members:
      if collection_name in predefined_names:
        dst_name = collection_name
      else:
        dst_name = info.new_name(collection_name)
      info.graph_.add_to_collection(dst_name, elem_)
106
107
def transform_op_if_inside_handler(info, op, keep_if_possible=True):
  """Transform an optional op, but only if it belongs to the subgraph.

  Typically used for original ops and control inputs: ops inside the
  subgraph are mapped to their transformed counterpart, ops outside of it are
  either kept as-is or dropped.

  Args:
    info: Transform._TmpInfo instance.
    op: the optional op to transform (or ignore).
    keep_if_possible: when `op` is outside the subgraph, return `op` itself
      provided the source and destination graphs are the same object.
  Returns:
    The transformed op if `op` is in the subgraph; otherwise `op` itself when
    it can be kept, else None.
  """
  if op in info.sgv.ops:
    return info.transformed_ops[op]
  if keep_if_possible and info.graph is info.graph_:
    return op
  return None
129
130
def copy_op_handler(info, op, new_inputs, copy_shape=False, nodedef_fn=None):
  """Copy a `tf.Operation` into the destination graph.

  Args:
    info: Transform._TmpInfo instance.
    op: the `tf.Operation` to be copied.
    new_inputs: The new inputs for this op.
    copy_shape: if True, also copy the static shape of every output tensor.
    nodedef_fn: If provided, a function that will be run on the NodeDef
      and should return a mutated NodeDef before a new Operation is created.
      This is useful as certain features cannot be set on the Operation and
      must be modified in NodeDef.

  Returns:
    A `(op, op_outputs)` tuple containing the transformed op and its outputs.
  Raises:
    TypeError: if `new_inputs` is a boolean.
  """
  # `new_inputs` was added to this function after the fact; a boolean here
  # most likely means a caller still uses the older signature.
  if isinstance(new_inputs, bool):
    raise TypeError("the `new_inputs` argument must be an iterable.")

  # pylint: disable=protected-access
  dst_graph = info.graph_

  # Private copy of the NodeDef, renamed into the destination scope (made
  # unique in the destination graph), optionally mutated by the caller.
  cloned_node_def = deepcopy(op.node_def)
  cloned_node_def.name = dst_graph.unique_name(info.new_name(op.name))
  if nodedef_fn is not None:
    cloned_node_def = nodedef_fn(cloned_node_def)

  # The new Operation gets its own copies of the type lists and of the OpDef.
  out_types = op._output_types[:]
  in_types = op._input_types[:]
  cloned_op_def = deepcopy(op.op_def)
  new_op = tf_ops.Operation(cloned_node_def, dst_graph, new_inputs, out_types,
                            [], in_types, None, cloned_op_def)

  if copy_shape:
    for src_t, dst_t in zip(op.outputs, new_op.outputs):
      dst_t.set_shape(src_t.get_shape())

  # The original op cannot be finalised yet, but some ops require the
  # attribute to exist: point it at the source original op for now; the
  # transformer replaces it once every op has been copied.
  # TODO(fkp): Stop worrying about _original_op and remove this code?
  if op._original_op:
    new_op._original_op = op._original_op

  return new_op, new_op.outputs
192
193
class TransformerInfo(object):
  """Contains information about the result of a transform operation."""

  def __init__(self, info):
    """Constructor.

    Args:
      info: an instance of Transformer._TmpInfo containing various internal
        information about the transform operation.
    """
    self._graph = info.graph
    self._scope = info.scope
    self._graph_ = info.graph_
    self._scope_ = info.scope_
    self._transformed_ops = info.transformed_ops
    self._transformed_ts = info.transformed_ts

  def _get_transformed_map(self, top):
    """Return the correct container depending on the type of `top`."""
    if isinstance(top, tf_ops.Operation):
      return self._transformed_ops
    elif isinstance(top, tf_ops.Tensor):
      return self._transformed_ts
    else:
      raise TypeError(
          "Expected a tf.Tensor or a tf.Operation, got a {}".format(
              type(top)))

  def _all_transformed_items(self):
    """Yield every `(original, transformed)` pair: ops first, then tensors."""
    for item in iteritems(self._transformed_ops):
      yield item
    for item in iteritems(self._transformed_ts):
      yield item

  def _transformed_elem(self, original_top, missing_fn=None):
    """Return the transformed op/tensor corresponding to the original one.

    Args:
      original_top: the original tensor/operation, or its name.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the transformed tensor/operation (or None if no match is found).
    """
    # Names must be resolved before calling `_get_transformed_map`, which
    # raises a TypeError for anything that is not a tf.Tensor/tf.Operation.
    if isinstance(original_top, string_types):
      for original, transformed in self._all_transformed_items():
        if original.name == original_top:
          return transformed
      return None if missing_fn is None else missing_fn(original_top)
    transformed_map = self._get_transformed_map(original_top)
    if original_top not in transformed_map:
      return None if missing_fn is None else missing_fn(original_top)
    return transformed_map[original_top]

  def _original_elem(self, transformed_top, missing_fn=None):
    """Return the original op/tensor corresponding to the transformed one.

    Args:
      transformed_top: the transformed tensor/operation, or its name.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the original tensor/operation (or None if no match is found).
    """
    # Same as in `_transformed_elem`: names are handled before dispatching on
    # the element type.
    if isinstance(transformed_top, string_types):
      for original, transformed in self._all_transformed_items():
        if transformed.name == transformed_top:
          return original
      return None if missing_fn is None else missing_fn(transformed_top)
    transformed_map = self._get_transformed_map(transformed_top)
    for original, transformed in iteritems(transformed_map):
      if transformed == transformed_top:
        return original
    return None if missing_fn is None else missing_fn(transformed_top)

  def transformed(self, original, missing_fn=None):
    """Return the transformed op/tensor corresponding to the original one.

    Note that the output of this function mimics the hierarchy
    of its input argument `original`.
    Given an iterable, it returns a list. Given an operation or a tensor,
    it will return an operation or a tensor. Given a name (string), it
    returns the transformed element whose original has that name.

    Args:
      original: the original tensor/operation (or its name).
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the transformed tensor/operation (or None if no match is found).
    """
    if isinstance(original, string_types):
      # A bare name is a single element: look it up directly instead of
      # handing it to `util.transform_tree`.
      return self._transformed_elem(original, missing_fn)
    transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn)
    return util.transform_tree(original, transformed_elem)

  def original(self, transformed, missing_fn=None):
    """Return the original op/tensor corresponding to the transformed one.

    Note that the output of this function mimics the hierarchy
    of its input argument `transformed`.
    Given an iterable, it returns a list. Given an operation or a tensor,
    it will return an operation or a tensor. Given a name (string), it
    returns the original element whose transformed counterpart has that name.

    Args:
      transformed: the transformed tensor/operation (or its name).
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the original tensor/operation (or None if no match is found).
    """
    if isinstance(transformed, string_types):
      # A bare name is a single element: look it up directly.
      return self._original_elem(transformed, missing_fn)
    original_elem = partial(self._original_elem, missing_fn=missing_fn)
    return util.transform_tree(transformed, original_elem)

  def __str__(self):
    res = StringIO()
    print("Transform result info:", file=res)
    if self._graph == self._graph_:
      in_place_str = "" if self._scope_ else " IN-PLACE"
      print("  Within graph[{}]{}".format(
          id(self._graph), in_place_str), file=res)
    else:
      print("  graph[{}] => graph[{}]".format(
          id(self._graph), id(self._graph_)), file=res)
    if self._scope:
      print("  Relative to source scope: {}".format(self._scope), file=res)
    if self._scope_:
      print("  Scope destination: {}".format(self._scope_), file=res)
    print("Operations mapping:", file=res)
    for op, op_ in iteritems(self._transformed_ops):
      print("  {} => {}".format(op.name, op_.name), file=res)
    return res.getvalue()
317
318
319class _TmpInfo(object):
320  """Transformer temporary data.
321
322  An instance of this class holds all the information relevant to a call
323  to a transformer instance (that is, a call to __call__). An instance
324  is created for the life-time of the __call__ function and is passed as
325  argument to the handlers.
326  """
327
328  def __init__(self, sgv, dst_graph, dst_scope, src_scope):
329    self.sgv = sgv
330    self.sgv_inputs_set = frozenset(sgv.inputs)
331    self.ops = frozenset(sgv.ops)
332    self.control_outputs = util.ControlOutputs(sgv.graph)
333    self.graph = sgv.graph
334    self.scope = src_scope
335    self.graph_ = dst_graph
336    self.scope_ = dst_scope
337    self.transformed_ops = {}
338    self.transformed_ts = {}
339    self.collections = dict((key, self.graph.get_collection(key))
340                            for key in self.graph.get_all_collection_keys())
341    self.cyclic_ops = []
342    self.transform_original_op_handler = transform_op_if_inside_handler
343    # The graph is transformed op by op, in the same order the original ops
344    # were created. However, this is sometimes not possible due to cycles
345    # (i.e. while loops). So when the transformer creates a new op whose
346    # inputs do not exist yet, temporary placeholders are created and stored
347    # in this `tmp_cyclic_ts` container. During a second pass,
348    # those temporary tensors are replaced by the proper transformed tensors
349    # (see the function `_finalize_cycles`).
350    self.tmp_cyclic_ts = []
351
352  def new_name(self, name):
353    """Compute a destination name from a source name.
354
355    Args:
356      name: the name to be "transformed".
357    Returns:
358      The transformed name.
359    Raises:
360      ValueError: if the source scope is used (that is, not an empty string)
361        and the source name does not belong to the source scope.
362    """
363    scope = self.scope
364    if not name.startswith(scope):
365      raise ValueError("{} does not belong to source scope: {}.".format(
366          name, scope))
367    rel_name = name[len(scope):]
368    name_ = self.scope_ + rel_name
369    return name_
370
371
class Transformer(object):
  """Transform a subgraph into another one.

  By default, the constructor creates a transform which copies a subgraph and
  replaces its inputs with placeholders. This behavior can be modified by
  changing the handlers.
  """

  def __init__(self):
    """Transformer constructor.

    The following members can be modified:
    transform_op_handler: handle the transformation of a `tf.Operation`.
      This handler defaults to a simple copy.
    transform_control_input_handler: handle the transform of the control
      inputs of the copied ops. This handler defaults to transforming a
      control input only if it is in the subgraph; otherwise it is kept if
      the source and destination graphs are the same, and dropped if not.
    assign_collections_handler: handle the assignment of collections.
      This handler defaults to assigning new collections created under the
      given name-scope.
    transform_external_input_handler: handle the transform of the inputs to
      the given subgraph. This handler defaults to creating placeholders
      instead of the ops just before the input tensors of the subgraph.
    transform_external_hidden_input_handler: handle the transform of the
      hidden inputs of the subgraph, that is, the inputs which are not listed
      in sgv.inputs. This handler defaults to a transform which keep the same
      input if the source and destination graphs are the same, otherwise
      use placeholders.
    transform_original_op_handler: handle the transform of original_op. This
      handler defaults to transforming original_op only if they are in the
      subgraph, otherwise they are ignored.
    """

    # handlers
    self.transform_op_handler = copy_op_handler
    self.transform_control_input_handler = transform_op_if_inside_handler
    self.assign_collections_handler = assign_renamed_collections_handler
    self.transform_external_input_handler = replace_t_with_placeholder_handler
    self.transform_external_hidden_input_handler = keep_t_if_possible_handler
    self.transform_original_op_handler = transform_op_if_inside_handler

  def __call__(self,
               sgv,
               dst_graph,
               dst_scope,
               src_scope="",
               reuse_dst_scope=False):
    """Execute the transformation.

    Args:
      sgv: the source subgraph-view.
      dst_graph: the destination graph.
      dst_scope: the destination scope.
      src_scope: the source scope, which specify the path from which the
        relative path of the transformed nodes are computed. For instance, if
        src_scope is a/ and dst_scoped is b/, then the node a/x/y will have a
        relative path of x/y and will be transformed into b/x/y.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).
    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of TransformerInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.
    Raises:
      TypeError: if `dst_graph` is not a `tf.Graph`.
      ValueError: if the arguments are invalid.
    """
    sgv = subgraph.make_view(sgv)
    if not isinstance(dst_graph, tf_ops.Graph):
      raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))

    src_scope = util.scope_finalize(src_scope)
    dst_scope = util.scope_finalize(dst_scope)

    # Potentially create new scope if reuse_dst_scope is False
    if dst_scope and not reuse_dst_scope:
      dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1]))

    # Create temporary info used during this transform call
    info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope)

    self._copy_ops(info)
    self._finalize_cycles(info)
    self._connect_control_inputs(info)

    # Compute information about the transformation
    res_info = TransformerInfo(info)
    sgv_ = self._transform_sgv(info, sgv)
    return sgv_, res_info

  def _copy_ops(self, info):
    """Copy ops without connecting them.

    Ops are processed in creation order (`op._id`). Each copy and each of its
    output tensors is recorded in `info` and passed to the collections
    handler.

    Raises:
      ValueError: if the op handler returns the original op itself.
    """
    sorted_ops = sorted(info.sgv.ops, key=lambda op: op._id)  # pylint: disable=protected-access
    for op in sorted_ops:
      new_inputs = [self._transformed_t(info, t, op) for t in op.inputs]
      op_, op_outputs_ = self.transform_op_handler(info, op, new_inputs)
      if op is op_:
        raise ValueError("In-place transformation not allowed.")

      # Process op.
      info.transformed_ops[op] = op_
      self.assign_collections_handler(info, op, op_)

      # Process output tensors.
      for op_output, op_output_ in zip(op.outputs, op_outputs_):
        info.transformed_ts[op_output] = op_output_
        self.assign_collections_handler(info, op_output, op_output_)

  def _finalize_cycles(self, info):
    """Reconnects the cyclic tensors.

    While copying, an input that was not transformed yet (because of a graph
    cycle) was wired to a temporary tensor recorded in `info.tmp_cyclic_ts`.
    This pass swaps each temporary tensor for the proper transformed tensor.

    Args:
      info: Transform._TmpInfo instance.
    Raises:
      ValueError: if a cyclic tensor or its consumer was never transformed.
    """
    for t, tmp_t_, consumer_op in info.tmp_cyclic_ts:
      if t not in info.transformed_ts:
        raise ValueError("The tensor {} should be transformed by now.".format(
            t.name))
      if consumer_op not in info.transformed_ops:
        raise ValueError("The op {} should be transformed by now.".format(
            consumer_op.name))
      t_ = info.transformed_ts[t]
      consumer_op_ = info.transformed_ops[consumer_op]
      inputs = list(consumer_op.inputs)
      inputs_ = list(consumer_op_.inputs)
      # The copy is built from inputs listed in the same order as the original
      # op's, so the slot to rewire is the one where the original consumes `t`
      # and the copy still holds the temporary tensor. Looking for `tmp_t_`
      # alone is ambiguous: for a `Merge`, the temporary tensor is the
      # transformed first input, which legitimately also sits in slot 0.
      t_index_ = None
      for index, (input_t, input_t_) in enumerate(zip(inputs, inputs_)):
        if input_t is t and input_t_ is tmp_t_:
          t_index_ = index
          break
      if t_index_ is None:
        # Fallback (e.g. a custom op handler that did not keep input order).
        t_index_ = inputs_.index(tmp_t_)
      consumer_op_._update_input(t_index_, t_)  # pylint: disable=protected-access

  def _connect_control_inputs(self, info):
    """Connect the previously copied ops.

    Finalizes the original op of each copy and adds the transformed control
    inputs (those for which the handler returns None are dropped).
    """
    for op in info.sgv.ops:
      logging.debug("Connecting control inputs of op: %s", op.name)
      op_ = info.transformed_ops[op]

      # Finalize original op.
      # TODO(fkp): Stop worrying about _original_op and remove this code?
      # pylint: disable=protected-access
      if op._original_op:
        original_op = self.transform_original_op_handler(info, op._original_op)
        if original_op is None:
          logging.debug("Could not find original op for: %s", op_.name)
        else:
          op_._original_op = original_op
      # pylint: enable=protected-access

      # Finalize control inputs:
      control_inputs_ = [self.transform_control_input_handler(info, ci)
                         for ci in op.control_inputs]
      control_inputs_ = [ci for ci in control_inputs_ if ci is not None]
      reroute.add_control_inputs(op_, control_inputs_)

  def _transform_sgv(self, info, sgv):
    """Transform a subgraph view.

    For convenience, a transform operation returns a subgraph view of the
    transformed graph.

    Args:
      info: Temporary information for this transform call.
      sgv: the subgraph to be transformed.
    Returns:
      The transformed subgraph.
    """
    ops_ = [op_ for _, op_ in iteritems(info.transformed_ops)]
    sgv_ = subgraph.SubGraphView(ops_)
    sgv_inputs_ = sgv_.inputs
    sgv_outputs_ = sgv_.outputs

    # re-order inputs
    input_map_ = []
    for input_t in sgv.inputs:
      if input_t not in info.transformed_ts:
        continue
      input_t_ = info.transformed_ts[input_t]
      if input_t_ not in sgv_inputs_:
        continue
      input_t_index_ = sgv_.input_index(input_t_)
      input_map_.append(input_t_index_)

    # re-order outputs
    output_map_ = []
    for output_t in sgv.outputs:
      if output_t not in info.transformed_ts:
        continue
      output_t_ = info.transformed_ts[output_t]
      if output_t_ not in sgv_outputs_:
        continue
      output_t_index_ = sgv_.output_index(output_t_)
      output_map_.append(output_t_index_)

    return sgv_.remap(input_map_, output_map_)

  def _transformed_t(self, info, t, consumer_op):
    """Return the transformed tensor of `t` (as consumed by `consumer_op`)."""
    if t in info.transformed_ts:
      # If op is in the subgraph, just return its transformed counterpart.
      return info.transformed_ts[t]

    if t in info.sgv_inputs_set:
      # `t` is an input of the subgraph.
      return self.transform_external_input_handler(info, t)
    elif t.op in info.ops:
      # `t` is an internal tensor but is not transformed yet because it
      # belongs to a graph cycle.
      logging.debug("Cyclic tensor: t.name = %s", t.name)
      # Try to find an existing tensor we can use for now,
      # otherwise create one. We'll rewire this later.
      # The Merge shortcut is skipped when `t` is itself the first input:
      # recursing on it would loop forever.
      if consumer_op.type == "Merge" and consumer_op.inputs[0] is not t:
        first_input = consumer_op.inputs[0]
        tmp_t_ = self._transformed_t(info, first_input, consumer_op)
      elif t.op.type == "Enter":
        enter_input = t.op.inputs[0]
        tmp_t_ = self._transformed_t(info, enter_input, consumer_op)
      else:
        with info.graph_.as_default():
          tmp_t_ = util.make_placeholder_from_tensor(t, scope=info.scope_,
                                                     prefix="geph_tmp")
        logging.debug("Created temporary placeholder: %s.", tmp_t_.name)
      # Register as temporary and return.
      info.tmp_cyclic_ts.append((t, tmp_t_, consumer_op))
      return tmp_t_
    else:
      # `t` is a hidden input of the subgraph.
      return self.transform_external_hidden_input_handler(info, t)
589
590
def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
         reuse_dst_scope=False):
  """Copy a subgraph.

  Args:
    sgv: the source subgraph-view. This argument is converted to a subgraph
      using the same rules as the function subgraph.make_view.
    dst_graph: the destination graph (defaults to the graph of `sgv`).
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).
  Returns:
    A tuple `(sgv, info)` where:
      `sgv` is the transformed subgraph view;
      `info` is an instance of TransformerInfo containing
      information about the transform, including mapping between
      original and transformed tensors and operations.
  Raises:
    TypeError: if `dst_graph` is not a `tf.Graph`.
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  sgv = subgraph.make_view(sgv)
  target_graph = sgv.graph if dst_graph is None else dst_graph
  if not isinstance(target_graph, tf_ops.Graph):
    raise TypeError("Expected a tf.Graph, got: {}".format(type(target_graph)))

  # A default Transformer copies the ops and replaces inputs by placeholders.
  return Transformer()(
      sgv, target_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope)
624
625
def copy_with_input_replacements(sgv, replacement_ts,
                                 dst_graph=None, dst_scope="", src_scope="",
                                 reuse_dst_scope=False):
  """Copy a subgraph, replacing some of its inputs.

  Note a replacement only happens if the tensor to be replaced
  is an input of the given subgraph. The inputs of a subgraph can
  be queried using sgv.inputs. Inputs without a replacement are kept as-is
  when possible (same source and destination graph), otherwise replaced by
  placeholders.

  Args:
    sgv: the source subgraph-view. This argument is converted to a subgraph
      using the same rules as the function subgraph.make_view.
    replacement_ts: dictionary mapping from original tensors to the
      replaced one.
    dst_graph: the destination graph (defaults to the graph of `sgv`).
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).
  Returns:
    A tuple `(sgv, info)` where:
      `sgv` is the transformed subgraph view;
      `info` is an instance of TransformerInfo containing
      information about the transform, including mapping between
      original and transformed tensors and operations.
  Raises:
    TypeError: if dst_graph is not a tf.Graph.
    StandardError: if sgv cannot be converted to a SubGraphView using
      the same rules as the function subgraph.make_view.
  """
  sgv = subgraph.make_view(sgv)
  target_graph = sgv.graph if dst_graph is None else dst_graph
  if not isinstance(target_graph, tf_ops.Graph):
    raise TypeError("Expected a tf.Graph, got: {}".format(type(target_graph)))

  copier = Transformer()

  def _use_replacement_if_any(info, t):
    """Return the replacement of `t` if given, else keep `t` if possible."""
    if t not in replacement_ts:
      return keep_t_if_possible_handler(info, t)
    return replacement_ts[t]

  copier.transform_external_input_handler = _use_replacement_if_any
  return copier(
      sgv, target_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope)
673
674
675def _add_control_flow_ops(ops, control_ios):
676  """Complete `ops` so that the transformed graph is valid.
677
678  Partially copying a graph can lead to a malformed graph. For instance,
679  copying half of a while construct is likely to result in an invalid graph.
680  This function attempts to add missing ops so that the transformation result
681  in a valid graph.
682
683  Args:
684    ops: list of ops (modifed in-place).
685    control_ios: object created by a call to `util.ControlOutputs`.
686  """
687  # Find while contexts.
688  control_flow_contexts = set()
689  for op in ops:
690    cfc = op._control_flow_context  # pylint: disable=protected-access
691    if cfc:
692      control_flow_contexts.add(cfc)
693  # Find new ops.
694  new_ops = []
695  for cfc in control_flow_contexts:
696    if cfc.IsWhileContext():
697      new_ops += select.get_walks_intersection_ops(
698          [enter_t.op for enter_t in cfc.loop_enters],
699          [exit_t.op for exit_t in cfc.loop_exits],
700          control_ios=control_ios)
701  # Add new ops.
702  new_ops_set = set(new_ops)
703  ops_set = frozenset(ops)
704  for op in new_ops_set:
705    if op not in ops_set:
706      ops.append(op)
707
708
def graph_replace(target_ts, replacement_ts, dst_scope="",
                  src_scope="", reuse_dst_scope=False):
  """Create a new graph which computes the targets from the replaced tensors.

  Args:
    target_ts: a single tf.Tensor or an iterable of tf.Tensor.
    replacement_ts: dictionary mapping from original tensors to replaced tensors
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).
  Returns:
    A single tf.Tensor or a list of target tf.Tensor, depending on
    the type of the input argument `target_ts`.
    The returned tensors are recomputed using the tensors from replacement_ts.
  Raises:
    ValueError: if the targets are not connected to replacement_ts.
  """
  flat_targets = util.flatten_tree(target_ts)
  # The forward control-dependency edges let the walks below also traverse
  # control dependencies.
  graph = util.get_unique_graph(flat_targets, check_types=tf_ops.Tensor)
  control_ios = util.ControlOutputs(graph)
  # Ops reached both by a forward walk from the replaced tensors and by a
  # backward walk from the targets: these are the ops that will change.
  ops = select.get_walks_intersection_ops(
      list(replacement_ts), flat_targets, control_ios=control_ios)
  if not ops:
    raise ValueError("Targets and replacements are not connected!")

  # Complete ops to avoid malformed control flow.
  # TODO(fkp): Consider moving this function deeper (in the transformer?).
  _add_control_flow_ops(ops, control_ios)

  # Copy the relevant subgraph, wiring in the replacement tensors.
  _, info = copy_with_input_replacements(
      ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope)
  # Targets without a transformed counterpart are returned unchanged.
  return info.transformed(target_ts, lambda original_t: original_t)
753