1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""Utility to lift subgraphs."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23
24from tensorflow.python.framework import func_graph
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import op_selector
28from tensorflow.python.ops import resource_variable_ops
29from tensorflow.python.util import compat
30from tensorflow.python.util import object_identity
31from tensorflow.python.util.tf_export import tf_export
32
33
34UnliftableError = op_selector.UnliftableError
35
36
37def _as_operation(op_or_tensor):
38  if isinstance(op_or_tensor, ops.Tensor):
39    return op_or_tensor.op
40  return op_or_tensor
41
42
43def _constant_inputs(op_or_tensor):
44  return all(_as_operation(i).type == u"Const"
45             and not _as_operation(i).control_inputs
46             for i in op_selector.graph_inputs(_as_operation(op_or_tensor)))
47
48
49# Represents an input to `copied_op` which must be updated once
50# `old_graph_tensor` has been copied.
51_InputMutation = collections.namedtuple(
52    "_InputMutation",
53    ["copied_op", "input_index", "old_graph_tensor"])
54
55
56# Represents a control input to `copied_op` which must be added once
57# `old_graph_op` has been copied.
58_ControlMutation = collections.namedtuple(
59    "_ControlMutation",
60    ["copied_op", "old_graph_op"])
61
62
63def _copy_non_source(op, graph, op_map, base_graph):
64  """Copy an op directly to a given graph.
65
66  Generally `op`'s inputs should already have been copied. If this is not the
67  case, for example with v1 while_loops, then `_copy_non_source` inserts
68  placeholders for the unavailable Tensors and returns a list of required
69  mutations.
70
71  Args:
72    op: The op to be copied.
73    graph: The destination graph.
74    op_map: A dict mapping ops and tensors in the old graph to the new one.
75    base_graph: The graph we're copying from, for any necessary functions.
76  Returns:
77    A tuple of (required_inputs, required_control_inputs):
78      required_inputs:
79        A list of `_InputMutation` tuples containing inputs to `copied_op` which
80        must be updated once `old_graph_tensor` has been copied.
81      required_control_inputs:
82        A list of `_ControlMutation` tuples containing control inputs to
83        `copied_op` which must be added once `old_graph_op` has been copied.
84  """
85  input_mutations = []
86  control_mutations = []
87  copied_inputs = []
88  for input_index, original_input in enumerate(op.inputs):
89    copied_input = op_map.get(original_input, None)
90    if copied_input is None:
91      # An input for this op is missing due to a loop in the graph. We'll insert
92      # a placeholder for now and return information about the required post-hoc
93      # mutation.
94      copied_input = array_ops.placeholder(
95          name="unused_control_flow_input",
96          shape=original_input.shape,
97          dtype=original_input.dtype)
98      input_mutations.append(
99          # `copied_op` is filled in below, after we've created it.
100          _InputMutation(copied_op=None,
101                         input_index=input_index,
102                         old_graph_tensor=original_input))
103    copied_inputs.append(copied_input)
104
105  copied_control_inputs = []
106  for original_control_input in op.control_inputs:
107    copied_control_input = op_map.get(original_control_input, None)
108    if copied_control_input is None:
109      control_mutations.append(
110          _ControlMutation(copied_op=None,
111                           old_graph_op=original_control_input))
112    else:
113      copied_control_inputs.append(copied_control_input)
114
115  # Don't copy over nodes with _tpu_replicate attribute. This attributed is used
116  # to signal that the op was built inside a tpu_replicate context; if we're
117  # lifting it to another graph we're similarly lifting it into another context.
118  with ops.control_dependencies(copied_control_inputs), ops.device(op.device):
119    # pylint: disable=protected-access
120    f = base_graph._functions.get(op.type, None)
121    if f is not None and compat.as_str(f.name) not in graph._functions:
122      f.add_to_graph(graph)
123    # pylint: enable=protected-access
124
125    # Create a new op in the destination graph if it doesn't exist before.
126    copied_op = graph.create_op(
127        op_type=op.type,
128        inputs=copied_inputs,
129        dtypes=[x.dtype for x in op.outputs],
130        attrs={
131            key: value for key, value in op.node_def.attr.items()
132            if not key.startswith("_class") and
133            not key.startswith("_tpu_replicate")
134        },  # b/128981532.
135        name=op.name)
136  op_map[op] = copied_op
137  for i, o in enumerate(op.outputs):
138    op_map[o] = copied_op.outputs[i]
139
140  return ([mutation._replace(copied_op=copied_op)
141           for mutation in input_mutations],
142          [mutation._replace(copied_op=copied_op)
143           for mutation in control_mutations])
144
145
146def _copy_source(s, graph, op_map, handle_captures, inverse_captures,
147                 base_graph):
148  """Create a source in a graph based on a Tensor from a different graph.
149
150  This function creates a placeholder analog of `s` in a graph with the
151  following behavior:
152
153  1) If s is a captured Tensor or Variable and handle_captures is set to True,
154     simply capture it in the new graph as well.
155
156  2) If s is a PlaceholderWithDefault whose default is a constant, preserve
157     said default in the new graph.
158
159  3) When applicable, copy resource variable metadata from `s` to the newly
160     created placeholder.
161
162  Args:
163    s: The source of interest.
164    graph: The destination graph.
165    op_map: A dict mapping ops and tensors in the old graph to the new one.
166    handle_captures: A boolean indicating whether to re-capture s in the new
167      graph or simply create a vanilla placeholder.
168    inverse_captures: A dict mapping s back to the Tensor or Variable that it
169      captures.
170    base_graph: The graph being copied from.
171  """
172  if handle_captures and s in inverse_captures:
173    copied_placeholder = graph.capture(inverse_captures[s], name=s.op.name)
174  elif s.op.type == "PlaceholderWithDefault" and _constant_inputs(s):
175    # Copy the default value to the graph.
176    default_value = s.op.inputs[0]
177    unavailable_inputs, unavailable_control_inputs = _copy_non_source(
178        op=default_value.op, graph=graph, op_map=op_map,
179        base_graph=base_graph)
180    if unavailable_inputs or unavailable_control_inputs:
181      raise AssertionError(
182          "Could not copy source node {} because it has inputs."
183          .format(default_value))
184
185    with ops.device(s.op.device):
186      copied_placeholder = array_ops.placeholder_with_default(
187          input=op_map[default_value], shape=s.shape, name=s.op.name)
188  else:
189    with ops.device(s.op.device):
190      copied_placeholder = array_ops.placeholder(
191          dtype=s.dtype, shape=s.shape, name=s.op.name)
192
193  base_handle = resource_variable_ops.get_resource_handle_data(s)
194  if base_handle.shape_and_type:
195    resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
196        copied_placeholder,
197        base_handle,
198        graph_mode=True)
199
200  op_map[s] = copied_placeholder
201  # Add an entry for the op of the source tensor so that if there are any nodes
202  # depending on that op via control dependencies it can work correctly.
203  op_map[s.op] = copied_placeholder.op
204
205
206@tf_export("__internal__.lift_to_graph", v1=[])
207def lift_to_graph(tensors,
208                  graph,
209                  sources=None,
210                  disallowed_placeholders=None,
211                  add_sources=False,
212                  handle_captures=False,
213                  base_graph=None,
214                  op_map=None):
215  """Copies the tensor and all its inputs recursively to the outer graph.
216
217  Args:
218    tensors: The Tensors to lift.
219    graph: The graph to lift to.
220    sources: Optional sequence of nodes to start from. If omitted the whole
221      subgraph which feeds into `init_tensor` is lifted.
222    disallowed_placeholders: An optional set of ops which may not appear in the
223      lifted graph. Defaults to all placeholders.
224    add_sources: A boolean indicating whether placeholders which are not in
225      sources should be allowed.
226    handle_captures: A boolean indicating whether to re-capture s in the new
227      graph or simply create a vanilla placeholder.
228    base_graph: The graph from which to lift ops. This will be inferred if not
229      specified.
230    op_map: A map contains all the existing nodes that have been lifted to the
231      destination graph, so they won't be lifted and copied again.
232
233  Returns:
234    A mapping from ops in the current default graph to ops in `graph`.
235
236  Raises:
237    UnliftableError: If a placeholder blocks lifting.
238  """
239  variable_init_tensors = []
240  init_tensors = []
241  for tensor in tensors:
242    if isinstance(tensor, resource_variable_ops.ResourceVariable):
243      variable_init_tensors.append(tensor)
244    else:
245      init_tensors.append(tensor)
246  base_graph = base_graph or init_tensors[0].graph
247  op_map = op_map or object_identity.ObjectIdentityDictionary()
248
249  # Check that the initializer does not depend on any placeholders.
250  sources = object_identity.ObjectIdentitySet(sources or [])
251  visited_ops = set(x.op for x in sources)
252  op_outputs = collections.defaultdict(set)
253
254  # First we extract the subgraph between init_tensors and sources.
255  for init_tensor in init_tensors:
256    sources.update(op_selector.map_subgraph(
257        init_tensor=init_tensor,
258        sources=sources,
259        disallowed_placeholders=disallowed_placeholders,
260        visited_ops=visited_ops,
261        op_outputs=op_outputs,
262        add_sources=add_sources))
263
264  # Try to topologically sort the nodes we've extracted. Now we know how many of
265  # their outputs are part of this subgraph.
266  ops_to_copy = []
267  marked_ops = set([])
268  ops_to_visit = [_as_operation(t) for t in init_tensors
269                  if not op_outputs[_as_operation(t)]]
270  unvisited_ops = set(ops_to_visit)
271  while unvisited_ops:
272    while ops_to_visit:
273      op = ops_to_visit.pop()
274      if op in marked_ops:
275        continue
276      marked_ops.add(op)
277      ops_to_copy.append(op)
278      for inp in op_selector.graph_inputs(op):
279        # Don't lift the TPUReplicateMetadata nodes out of the function, because
280        # it has no registered kernels.
281        if inp.type == "TPUReplicateMetadata":
282          continue
283        unvisited_ops.add(inp)
284        if (all(x in marked_ops for x in op_outputs[inp]) and
285            inp not in sources):
286          ops_to_visit.append(inp)
287    unvisited_ops.difference_update(marked_ops)
288    if unvisited_ops:
289      # `unvisited_ops` should only have elements if the graph has a loop. In
290      # this case we want to keep copying and there's no topological ordering;
291      # we'll do ugly post-hoc mutations instead.
292      ops_to_visit.append(next(iter(unvisited_ops)))
293
294  # When lifting from one FuncGraph to another, we will need to capture the
295  # relevant tensors as well.
296  captures = []
297  inverse_captures = object_identity.ObjectIdentityDictionary()
298  internal_captures = []
299  if (isinstance(base_graph, func_graph.FuncGraph) and
300      isinstance(graph, func_graph.FuncGraph)):
301    captures = base_graph.captures
302    for external_capture, internal_capture in captures:
303      inverse_captures[internal_capture] = external_capture
304    internal_captures = base_graph.internal_captures
305
306  # ops_to_copy now holds a reverse topologically sorted list of ops which
307  # ends in the initializer. We copy those to the outermost graph and
308  # build the initialization op there.
309  with graph.as_default():
310    for i in variable_init_tensors:
311      op_map[i] = i
312    source_ops = set()
313    # Add the sources in the same order as the original graph.
314    for s in internal_captures:
315      if s in sources:
316        sources.remove(s)
317        source_ops.add(s.op)
318        _copy_source(
319            s=s,
320            graph=graph,
321            op_map=op_map,
322            handle_captures=handle_captures,
323            inverse_captures=inverse_captures,
324            base_graph=base_graph)
325    for s in sources:
326      source_ops.add(s.op)
327      _copy_source(
328          s=s,
329          graph=graph,
330          op_map=op_map,
331          handle_captures=handle_captures,
332          inverse_captures=inverse_captures,
333          base_graph=base_graph)
334
335    input_mutations = []
336    control_mutations = []
337    for op in reversed(ops_to_copy):
338      if op in source_ops or op in op_map:
339        continue
340      new_input_mutations, new_control_mutations = _copy_non_source(
341          op=op, graph=graph, op_map=op_map, base_graph=base_graph)
342      input_mutations.extend(new_input_mutations)
343      control_mutations.extend(new_control_mutations)
344
345    # Mutate the new graph to insert any loops which existed in the source
346    # graph due to v1 while_loops.
347    #
348    # pylint: disable=protected-access
349    with graph._mutation_lock():
350      for mutation in input_mutations:
351        mutation.copied_op._update_input(
352            mutation.input_index, op_map[mutation.old_graph_tensor])
353      for mutation in control_mutations:
354        # Don't lift the TPUReplicateMetadata nodes out of the function, because
355        # it has no registered kernels.
356        if mutation.old_graph_op.type == "TPUReplicateMetadata":
357          continue
358        mutation.copied_op._add_control_input(op_map[mutation.old_graph_op])
359    # pylint: enable=protected-access
360
361    return op_map
362