1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=unidiomatic-typecheck
16"""Utility to lift subgraphs."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import collections
23import six
24
25from tensorflow.python.framework import func_graph
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import array_ops
28from tensorflow.python.ops import resource_variable_ops
29
30
31def _graph_inputs(op):
32  return [x.op for x in op.inputs] + list(op.control_inputs)
33
34
def _as_operation(op_or_tensor):
  """Maps a `tf.Tensor` to the op producing it; anything else is returned as-is."""
  is_tensor = isinstance(op_or_tensor, ops.Tensor)
  return op_or_tensor.op if is_tensor else op_or_tensor
39
40
class UnliftableError(Exception):
  """Signals that a tensor cannot be lifted out of its graph.

  Raised in this module when the subgraph feeding a tensor reaches a
  placeholder that is not permitted in the lifted graph.
  """
44
45
def _constant_inputs(op_or_tensor):
  """Checks whether every direct dependency of an op is a plain constant.

  Args:
    op_or_tensor: A `tf.Operation`, or a `tf.Tensor` whose producing op is
      inspected.

  Returns:
    True if each data/control dependency of the op is a "Const" op with no
    control inputs of its own; True as well when the op has no dependencies.
  """
  for dependency in _graph_inputs(_as_operation(op_or_tensor)):
    dependency_op = _as_operation(dependency)
    if dependency_op.type != u"Const" or dependency_op.control_inputs:
      return False
  return True
50
51
def _path_from(from_op, tensor, sources):
  """Describes one dependency path from `tensor` back to `from_op`.

  Runs a depth-first search over the data and control inputs of `tensor`'s
  op, never expanding the ops that produce `sources`, and stops as soon as
  `from_op` is reached.

  Args:
    from_op: A `tf.Operation` to search for.
    tensor: A `tf.Operation` or `tf.Tensor` where the search starts.
    sources: A list of `tf.Tensor` whose producing ops are not traversed.

  Returns:
    A python string of the form "x (TypeX) <- y (TypeY) <- ...", starting at
    `tensor`'s op and ending at `from_op`, or "??" if none is found.
  """
  target_op = _as_operation(tensor)
  seen = set(src.op for src in sources)
  # Remembers through which already-visited op each pending op was discovered,
  # so the path can be rebuilt by walking back towards `target_op`.
  discovered_from = {}
  stack = [target_op]
  while stack:
    current = stack.pop()
    if current in seen:
      continue
    seen.add(current)
    if current == from_op:
      chain = [current]
      while chain[-1] != target_op:
        chain.append(discovered_from[chain[-1]])
      return " <- ".join(
          "%s (%s)" % (step.name, step.type) for step in reversed(chain))
    for parent in _graph_inputs(current):
      if parent in seen or parent in sources:
        continue
      discovered_from[parent] = current
      stack.append(parent)
  return "??"
85
86
def _map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                  op_outputs, add_sources):
  """Walk a Graph and capture the subgraph between init_tensor and sources.

  Performs a depth-first walk over data and control inputs, starting at the op
  of `init_tensor`. Ops already in `visited_ops`, and the ops producing
  `sources`, are not expanded.

  Note: This function mutates visited_ops and op_outputs.

  Arguments:
    init_tensor:  A Tensor or Operation where the subgraph terminates.
    sources:  A set of Tensors where subgraph extraction should stop. May be
      None or empty.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. When None, every op of type "Placeholder" is disallowed
      unless `add_sources` is True.
    visited_ops: A set of operations which were visited in a prior pass. Every
      op visited by this walk is added to it.
    op_outputs: A defaultdict(set) mapping each op to the ops of the subgraph
      that consume it (data or control). Updated in place.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.

  Returns:
    The set of output tensors of the "Placeholder" ops reached by the walk,
    i.e. placeholders upon which init_tensor depends that are not in sources.

  Raises:
    UnliftableError: if the walk reaches an op in `disallowed_placeholders`,
      or, when `disallowed_placeholders` is None, a placeholder while
      `add_sources` is False.
  """
  if sources is None:
    sources = ()
  # `sources` holds tensors while the walk moves over operations, so the stop
  # condition must compare against the ops that produce them. Placeholders
  # discovered below are added to `visited_ops` before their outputs join
  # `extra_sources`, so they are never expanded twice either.
  source_ops = set(t.op for t in sources)
  ops_to_visit = [_as_operation(init_tensor)]
  extra_sources = set()
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)

    should_raise = False
    if disallowed_placeholders is not None and op in disallowed_placeholders:
      should_raise = True
    elif op.type == "Placeholder":
      if disallowed_placeholders is None and not add_sources:
        should_raise = True
      extra_sources.update(op.outputs)

    if should_raise:
      raise UnliftableError(
          "Unable to lift tensor %s because it depends transitively on "
          "placeholder %s via at least one path, e.g.: %s"
          % (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources)))
    for inp in _graph_inputs(op):
      # Record every consumer edge (including edges out of source ops) so the
      # caller can tell when all in-subgraph consumers of an op are handled.
      op_outputs[inp].add(op)
      if inp not in visited_ops and inp not in source_ops:
        ops_to_visit.append(inp)

  return extra_sources
139
140
def _copy_non_source(op, graph, op_map):
  """Recreates `op` inside `graph` and records the mapping in `op_map`.

  Every data and control input of `op` must already have an entry in
  `op_map`; a missing entry surfaces as a `KeyError`. The new op is created
  with the original op's type, attributes, name and output dtypes, inside a
  device scope for `op.device` and a control-dependency scope on the copies
  of `op.control_inputs`. Both scopes are entered through the `ops` module,
  so they presumably act on the default graph; callers in this file invoke
  this under `graph.as_default()`.

  Args:
    op: The `tf.Operation` to copy.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors of the source graph to their
      counterparts in `graph`. Updated in place with `op` and its outputs.
  """
  new_inputs = [op_map[tensor] for tensor in op.inputs]
  new_control_inputs = [op_map[dep] for dep in op.control_inputs]
  output_dtypes = [output.dtype for output in op.outputs]
  with ops.control_dependencies(new_control_inputs):
    with ops.device(op.device):
      new_op = graph.create_op(
          op_type=op.type,
          inputs=new_inputs,
          dtypes=output_dtypes,
          attrs=op.node_def.attr,
          name=op.name)
  op_map[op] = new_op
  # The copy was created with one dtype per original output, so both output
  # lists have the same length.
  for old_output, new_output in zip(op.outputs, new_op.outputs):
    op_map[old_output] = new_output
164
165
def _copy_source(s, graph, op_map, handle_captures, inverse_captures):
  """Creates a stand-in for the source tensor `s` inside `graph`.

  Exactly one of the following produces the stand-in:

  1) If `handle_captures` is True and `s` appears in `inverse_captures`, the
     original captured value is captured again via `graph.capture`.
  2) Otherwise, if `s` comes from a PlaceholderWithDefault whose inputs are
     all plain constants, the default's op is copied first and a
     `placeholder_with_default` wrapping the copy is created.
  3) Otherwise a plain placeholder with the dtype and shape of `s` is created.

  Cases 2 and 3 place the new placeholder under `s.op.device`. In every case,
  resource handle shape/type metadata found on `s` is copied onto the
  stand-in.

  Args:
    s: The source tensor being recreated.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
      Receives the entry for `s` (and, in case 2, for the default's op).
    handle_captures: Whether a captured `s` should be re-captured rather than
      replaced by a fresh placeholder.
    inverse_captures: A dict mapping `s` back to the Tensor or Variable that
      it captures.
  """
  source_name = s.op.name
  if handle_captures and s in inverse_captures:
    stand_in = graph.capture(inverse_captures[s], name=source_name)
  elif s.op.type == "PlaceholderWithDefault" and _constant_inputs(s):
    default_tensor = s.op.inputs[0]
    # Copied outside the device scope below, so the default keeps its own
    # device assignment.
    _copy_non_source(op=default_tensor.op, graph=graph, op_map=op_map)
    with ops.device(s.op.device):
      stand_in = array_ops.placeholder_with_default(
          input=op_map[default_tensor], shape=s.shape, name=source_name)
  else:
    with ops.device(s.op.device):
      stand_in = array_ops.placeholder(
          dtype=s.dtype, shape=s.shape, name=source_name)

  handle_data = resource_variable_ops.get_resource_handle_data(s)
  if handle_data.shape_and_type:
    resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
        stand_in, handle_data, graph_mode=True)

  op_map[s] = stand_in
213
214
def lift_to_graph(init_tensors, graph, sources=None,
                  disallowed_placeholders=None, add_sources=False,
                  handle_captures=False, base_graph=None):
  """Copies the tensor and all its inputs recursively to the outer graph.

  Args:
    init_tensors: The Tensors (or Operations) to lift. `ResourceVariable`
      entries are not copied; they map to themselves in the result.
    graph: The graph to lift to.
    sources: Optional sequence of tensors at which lifting stops. They are
      recreated in `graph` as placeholders (or re-captured) instead of being
      copied, and nothing upstream of them is lifted. If omitted the whole
      subgraph which feeds into `init_tensors` is lifted.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.
    handle_captures: A boolean indicating whether to re-capture s in the new
      graph or simply create a vanilla placeholder.
    base_graph: The graph from which to lift ops. If not specified it is
      inferred from a non-variable entry of `init_tensors`.

  Returns:
    A dict mapping ops and tensors of the source graph (plus the passed
    through variables) to their counterparts in `graph`.

  Raises:
    UnliftableError: If a placeholder blocks lifting.
  """
  variable_init_tensors = {i for i in init_tensors if isinstance(
      i, resource_variable_ops.ResourceVariable)}
  init_tensors = set(init_tensors).difference(variable_init_tensors)
  # Guard the inference: if only variables were passed there is no tensor to
  # infer the graph from (previously an IndexError).
  if base_graph is None and init_tensors:
    base_graph = next(iter(init_tensors)).graph

  sources = set(sources or [])
  visited_ops = set([x.op for x in sources])
  op_outputs = collections.defaultdict(set)

  # First we extract the subgraph between init_tensors and sources, checking
  # that it does not depend on any disallowed placeholders.
  for init_tensor in init_tensors:
    sources.update(_map_subgraph(
        init_tensor=init_tensor,
        sources=sources,
        disallowed_placeholders=disallowed_placeholders,
        visited_ops=visited_ops,
        op_outputs=op_outputs,
        add_sources=add_sources))

  # `sources` holds tensors while the walks below move over operations, so
  # membership tests must use the producing ops. Source ops are recreated by
  # `_copy_source`, never copied, and their inputs are never expanded (their
  # inputs are not recorded in `op_outputs`, so expanding them would drag the
  # whole upstream graph into the copy).
  source_ops = set(s.op for s in sources)

  # Topologically sort the nodes we've extracted. An op is scheduled only
  # once all of its consumers inside the subgraph have been marked.
  ops_to_copy = []
  marked_ops = set()
  ops_to_visit = [_as_operation(t) for t in init_tensors
                  if not op_outputs[_as_operation(t)]]
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in marked_ops or op in source_ops:
      continue
    marked_ops.add(op)
    ops_to_copy.append(op)
    for inp in _graph_inputs(op):
      if (inp not in source_ops and
          all(x in marked_ops for x in op_outputs[inp])):
        ops_to_visit.append(inp)

  # When lifting from one FuncGraph to another, we will need to capture the
  # relevant tensors as well.
  captures = collections.OrderedDict()
  if (isinstance(base_graph, func_graph.FuncGraph) and
      isinstance(graph, func_graph.FuncGraph)):
    captures = base_graph.captures
  inverse_captures = {v: k for k, v in captures.items()}

  # ops_to_copy now holds a reverse topologically sorted list of non-source
  # ops which ends in the initializer. We copy those to the target graph
  # after recreating the sources there.
  with graph.as_default():
    op_map = {i: i for i in variable_init_tensors}  # Pass through variables.
    # Add the captured sources in the same order as the original graph.
    for s in six.itervalues(captures):
      if s in sources:
        sources.remove(s)
        _copy_source(
            s=s,
            graph=graph,
            op_map=op_map,
            handle_captures=handle_captures,
            inverse_captures=inverse_captures)
    for s in sources:
      _copy_source(
          s=s,
          graph=graph,
          op_map=op_map,
          handle_captures=handle_captures,
          inverse_captures=inverse_captures)

    # Source ops were excluded during the sort, so everything here is copied.
    for op in reversed(ops_to_copy):
      _copy_non_source(op=op, graph=graph, op_map=op_map)

    return op_map
318