# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for selecting ops in a graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.util import object_identity


def is_differentiable(op):
  """Return True if the gradient registry has an entry for the type of `op`.

  Args:
    op: an object exposing `op_def.name` (usually a `tf.Operation`).
  Returns:
    `True` if looking up `op.op_def.name` in the gradient registry yields a
    non-`None` value; `False` if the lookup yields `None` or raises
    `LookupError`.
  """
  try:
    return ops._gradient_registry.lookup(op.op_def.name) is not None  # pylint: disable=protected-access
  except LookupError:
    return False


def is_iterable(obj):
  """Return true if the object is iterable.

  `tf.Tensor` instances are always reported as non-iterable, whatever `iter`
  would say about them.

  Args:
    obj: any Python object.
  Returns:
    `False` for a `tf.Tensor` or for any object on which `iter(obj)` raises,
    `True` otherwise.
  """
  if isinstance(obj, ops.Tensor):
    return False
  try:
    _ = iter(obj)
  except Exception:  # pylint: disable=broad-except
    return False
  return True


def concatenate_unique(la, lb):
  """Add all the elements of `lb` to `la` if they are not there already.

  The elements added to `la` maintain ordering with respect to `lb`.
  `la` is modified in place; elements must be hashable.

  Args:
    la: List of Python objects.
    lb: List of Python objects.
  Returns:
    `la`: The list `la` with missing elements from `lb`.
  """
  la_set = set(la)
  for item in lb:
    if item not in la_set:
      la.append(item)
      la_set.add(item)
  return la


def get_tensors(graph):
  """Get all the tensors which are the output of an op in the graph.

  Args:
    graph: a `tf.Graph`.
  Returns:
    A list of `tf.Tensor`, in the order of `graph.get_operations()`.
  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if not isinstance(graph, ops.Graph):
    raise TypeError("Expected a graph, got: {}".format(type(graph)))
  ts = []
  for op in graph.get_operations():
    ts.extend(op.outputs)
  return ts


def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by the all the elements in tops.

  Args:
    tops: list of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph.
    check_types: check that the element in tops are of given type(s). If None,
      the types (tf.Operation, tf.Tensor) are used. May be a single type or any
      iterable of types.
    none_if_empty: don't raise an error if tops is an empty list, just return
      None.
  Returns:
    The unique graph used by all the tops.
  Raises:
    TypeError: if tops is not iterable, or if one of its elements is not an
      instance of `check_types`.
    ValueError: if the graph is not unique, or if tops is empty and
      `none_if_empty` is False.
  """
  if isinstance(tops, ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (ops.Operation, ops.Tensor)
  elif not is_iterable(check_types):
    check_types = (check_types,)
  else:
    # `isinstance` only accepts a type or a tuple of types, so normalize any
    # other iterable (e.g. a list) into a tuple.
    check_types = tuple(check_types)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(
          ", ".join(str(t) for t in check_types), type(op)))
    if g is None:
      g = op.graph
    elif g._graph_key != op.graph._graph_key:  # pylint: disable=protected-access
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g


def check_graphs(*args):
  """Check that all the element in args belong to the same graph.

  Elements whose `graph` attribute is `None` are skipped.

  Args:
    *args: a list of object with a obj.graph property.
  Raises:
    ValueError: if all the elements do not belong to the same graph.
  """
  graph = None
  for i, sgv in enumerate(args):
    if graph is None and sgv.graph is not None:
      graph = sgv.graph
    elif sgv.graph is not None and sgv.graph is not graph:
      raise ValueError("Argument[{}]: Wrong graph!".format(i))


def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of `tf.Tensor`.

  Element types are only validated when `check_graph` is `True`; in every case
  the returned list keeps only the elements that are `tf.Tensor` instances.

  Args:
    ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
    check_graph: if `True` check if all the tensors belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`.
  Returns:
    A newly created list of `tf.Tensor`.
  Raises:
    TypeError: if `ts` is a `tf.Graph` and `allow_graph` is `False`, or, if
      `check_graph` is `True`, if an element of `ts` has an unexpected type.
    ValueError: if `check_graph` is `True` and the elements do not all belong
      to the same graph.
  """
  if isinstance(ts, ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ts):
      ts = [ts]
    if not ts:
      return []
    if check_graph:
      # With `ignore_ops`, fall back to the default accepted types
      # (operations and tensors); otherwise only tensors are accepted.
      check_types = None if ignore_ops else ops.Tensor
      get_unique_graph(ts, check_types=check_types)
    return [t for t in ts if isinstance(t, ops.Tensor)]


def get_generating_ops(ts):
  """Return all the generating ops of the tensors in `ts`.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the generating `tf.Operation` of the tensors in `ts`, one
    entry per tensor (duplicates are not removed).
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  return [t.op for t in ts]


def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`, without
    duplicates and in order of first appearance.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  tops = []
  # Companion set so that de-duplication does not rescan `tops` every time.
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        tops.append(op)
  return tops


def make_list_of_op(tops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Convert ops to a list of `tf.Operation`.

  Element types are only validated when `check_graph` is `True`; in every case
  the returned list keeps only the elements that are `tf.Operation` instances.

  Args:
    tops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
      operation.
    check_graph: if `True` check if all the operations belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ts: if True, silently ignore `tf.Tensor`.
  Returns:
    A newly created list of `tf.Operation`.
  Raises:
    TypeError: if `tops` is a `tf.Graph` and `allow_graph` is `False`, or, if
      `check_graph` is `True`, if an element of `tops` has an unexpected type.
    ValueError: if `check_graph` is `True` and the elements do not all belong
      to the same graph.
  """
  if isinstance(tops, ops.Graph):
    if allow_graph:
      return tops.get_operations()
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(tops):
      tops = [tops]
    if not tops:
      return []
    if check_graph:
      # With `ignore_ts`, fall back to the default accepted types
      # (operations and tensors); otherwise only operations are accepted.
      check_types = None if ignore_ts else ops.Operation
      get_unique_graph(tops, check_types=check_types)
    return [op for op in tops if isinstance(op, ops.Operation)]


def _get_inputs(op, only_differentiable):
  """Return the inputs of `op` to follow during a backward walk.

  Args:
    op: the operation whose inputs are requested.
    only_differentiable: if True, an op for which `is_differentiable` is False
      contributes no inputs.
  Returns:
    `op.inputs`, or an empty list when `only_differentiable` is True and `op`
    is not differentiable.
  """
  op_inputs = op.inputs
  if only_differentiable:
    return op_inputs if is_differentiable(op) else []
  else:
    return op_inputs


def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False,
                          only_differentiable=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    inclusive: if True the given seed_ops are also part of the resulting set.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
      Ignored when `only_differentiable` is True.
    only_differentiable: if True, only traverse ops which are differentiable.
      This includes natively differentiable ops, or ops with custom gradients.
  Returns:
    A Python list of all the `tf.Operation` behind `seed_ops`, without
    duplicates: the seed ops first (unless `inclusive` is False), followed by
    the ops discovered by each successive wave of the walk.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  control_inputs = control_inputs and (not only_differentiable)

  if not is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize once: the first element is inspected by index below, which
  # would otherwise fail on non-subscriptable iterables (sets, generators).
  seed_ops = list(seed_ops)
  if not seed_ops:
    return []
  # The type of the first element decides whether the seeds are treated as
  # tensors (walk starts at their generating ops) or as operations.
  if isinstance(seed_ops[0], ops.Tensor):
    ts = make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = get_generating_ops(ts)
  else:
    seed_ops = make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = object_identity.ObjectIdentitySet(make_list_of_t(stop_at_ts))
  seed_ops = object_identity.ObjectIdentitySet(make_list_of_op(seed_ops))
  if within_ops:
    within_ops = make_list_of_op(within_ops, allow_graph=False)
    within_ops = object_identity.ObjectIdentitySet(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return (within_ops is None or op in within_ops) and (
        within_ops_fn is None or within_ops_fn(op))

  result = list(seed_ops)
  # Mirrors the contents of `result` for constant-time membership tests.
  visited = set(result)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in _get_inputs(op, only_differentiable=only_differentiable):
        if new_t in stop_at_ts:
          continue
        if new_t.op not in visited and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # Every op in `new_wave` was checked against `visited`, and `new_wave` is
    # itself a set, so a plain extend cannot introduce duplicates.
    result.extend(new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result


class UnliftableError(Exception):
  """Raised if a Tensor cannot be lifted from the graph."""

  # Prevent autograph from rewriting this error.
  ag_pass_through = True


def _as_operation(op_or_tensor):
  """Return the op of a `tf.Tensor`; return any other argument unchanged."""
  if isinstance(op_or_tensor, ops.Tensor):
    return op_or_tensor.op
  return op_or_tensor


def graph_inputs(op):
  """Return the ops producing the inputs of `op`, then its control inputs.

  Args:
    op: a `tf.Operation`.
  Returns:
    A new list holding the op of each tensor in `op.inputs` (duplicates kept),
    followed by the elements of `op.control_inputs`.
  """
  return [x.op for x in op.inputs] + list(op.control_inputs)


def _path_from(from_op, tensor, sources):
  """Find one path from `from_op` to `tensor`, ignoring `sources`.

  The search starts at `tensor` and moves backward through `graph_inputs`
  until `from_op` is reached.

  Args:
    from_op: A `tf.Operation`.
    tensor: A `tf.Operation` or `tf.Tensor`.
    sources: A list of `tf.Tensor`.

  Returns:
    A python string containing the path, or "??" if none is found.
  """
  if isinstance(from_op, ops.Tensor):
    from_op = from_op.op

  # Ops generating a source are marked as visited up front so that the walk
  # never goes through them.
  visited_ops = {x.op for x in sources}
  ops_to_visit = [_as_operation(tensor)]
  # Maps an op to one of the ops consuming it on the way back from `tensor`;
  # used to rebuild the path once `from_op` is reached.
  some_op_output = {}
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)
    if op == from_op:
      path_op = op
      path = [path_op]
      final_op = _as_operation(tensor)
      while path_op != final_op:
        path_op = some_op_output[path_op]
        path.append(path_op)
      return " <- ".join("%s (%s)" % (x.name, x.type) for x in reversed(path))
    else:
      for inp in graph_inputs(op):
        if inp not in visited_ops and inp not in sources:
          some_op_output[inp] = op
          ops_to_visit.append(inp)
  return "??"


# TODO(jmenick) - there is considerable duplication of functionality between
# this function and get_backward_walk_ops(). Need to deduplicate.
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                 op_outputs, add_sources):
  """Walk a Graph and capture the subgraph between init_tensor and sources.

  Note: This function mutates visited_ops and op_outputs.

  Args:
    init_tensor:  A Tensor or Operation where the subgraph terminates.
    sources:  A set of Tensors where subgraph extraction should stop.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    visited_ops: A set of operations which were visited in a prior pass.
    op_outputs: A defaultdict containing the outputs of an op which are to be
      copied into the new subgraph.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.

  Returns:
    The set of placeholders upon which init_tensor depends and are not in
    sources.

  Raises:
    UnliftableError: if init_tensor depends on a placeholder which is not in
      sources and add_sources is False.
  """
  ops_to_visit = [_as_operation(init_tensor)]
  extra_sources = object_identity.ObjectIdentitySet()
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)

    should_raise = False
    if disallowed_placeholders is not None and op in disallowed_placeholders:
      should_raise = True
    elif op.type == "Placeholder":
      if disallowed_placeholders is None and not add_sources:
        should_raise = True
      extra_sources.update(op.outputs)

    if should_raise:
      raise UnliftableError(
          "Unable to lift tensor %s because it depends transitively on "
          "placeholder %s via at least one path, e.g.: %s"
          % (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources)))
    for inp in graph_inputs(op):
      op_outputs[inp].add(op)
      # NOTE(review): `sources or extra_sources` evaluates to `sources` whenever
      # `sources` is truthy, so `extra_sources` is only consulted when `sources`
      # is falsy (presumably empty). Also, `inp` is an op while `sources` is
      # documented as a set of Tensors. Confirm both are intended before
      # changing this condition.
      if inp not in visited_ops and inp not in (sources or extra_sources):
        ops_to_visit.append(inp)

  return extra_sources