1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Various ways of selecting operations and tensors in a graph."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import re
22
23from six import iteritems
24from six import string_types
25
26from tensorflow.contrib.graph_editor import util
27from tensorflow.python.framework import ops as tf_ops
28
# Public API of this module (the names exported by `from ... import *`).
__all__ = [
    "can_be_regex",
    "make_regex",
    "filter_ts",
    "filter_ts_from_regex",
    "filter_ops",
    "filter_ops_from_regex",
    "get_name_scope_ops",
    "check_cios",
    "get_ops_ios",
    "compute_boundary_ts",
    "get_within_boundary_ops",
    "get_forward_walk_ops",
    "get_backward_walk_ops",
    "get_walks_intersection_ops",
    "get_walks_union_ops",
    "select_ops",
    "select_ts",
    "select_ops_and_ts",
]

# Type of a compiled regular expression, taken from an actual compiled pattern
# (presumably because the pattern class has no stable public name on every
# Python version this file supports). Used by `can_be_regex` in isinstance
# checks.
_RE_TYPE = type(re.compile(""))
51
52
def can_be_regex(obj):
  """Check whether `obj` is usable as a regular expression.

  Args:
    obj: any Python object.
  Returns:
    True if `obj` is a string (any of `six.string_types`) or an already
    compiled regular expression, False otherwise.
  """
  accepted_types = string_types + (_RE_TYPE,)
  return isinstance(obj, accepted_types)
56
57
def make_regex(obj):
  """Turn `obj` into a compiled regular expression.

  Args:
    obj: either a string, which is compiled with `re.compile`, or an already
      compiled regular expression, which is returned unchanged.
  Returns:
    A compiled regular expression.
  Raises:
    ValueError: if obj is neither a string nor a compiled regular expression.
  """
  if can_be_regex(obj):
    return re.compile(obj) if isinstance(obj, string_types) else obj
  raise ValueError("Expected a string or a regex, got: {}".format(type(obj)))
75
76
def _get_input_ts(ops):
  """Collect the distinct input tensors of the given ops.

  Args:
    ops: an object convertible to a list of `tf.Operation`.
  Returns:
    A list of the input tensors of all the ops, without duplicates, ordered by
    first appearance when visiting the ops (and each op's inputs) in order.
  Raises:
    TypeError: if ops cannot be converted to a list of `tf.Operation`.
  """
  seen = set()
  unique_inputs = []
  for op in util.make_list_of_op(ops):
    for tensor in op.inputs:
      if tensor in seen:
        continue
      seen.add(tensor)
      unique_inputs.append(tensor)
  return unique_inputs
96
97
def _get_output_ts(ops):
  """Concatenate the output tensors of the given ops, in op order.

  No explicit deduplication is performed, so an op that appears several times
  in `ops` contributes its outputs several times.

  Args:
    ops: an object convertible to a list of tf.Operation.
  Returns:
    A list of the output tensors of all the ops.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  return [tensor for op in util.make_list_of_op(ops) for tensor in op.outputs]
113
114
def filter_ts(ops, positive_filter):
  """Select the tensors connected (as input or output) to an op in `ops`.

  Args:
    ops: an object convertible to a list of `tf.Operation`.
    positive_filter: a callable taking a `tf.Tensor` and returning a truthy
      value for the tensors to keep, or the literal `True` to keep them all.
  Returns:
    A list of `tf.Tensor`: the distinct input tensors of the ops, followed by
    their output tensors not already in that list, then filtered.
  Raises:
    TypeError: if ops cannot be converted to a list of `tf.Operation`.
  """
  ops = util.make_list_of_op(ops)
  candidates = _get_input_ts(ops)
  # Appends, in place, the outputs that are not already candidates.
  util.concatenate_unique(candidates, _get_output_ts(ops))
  if positive_filter is True:  # pylint: disable=g-explicit-bool-comparison
    return candidates
  return [tensor for tensor in candidates if positive_filter(tensor)]
133
134
def filter_ts_from_regex(ops, regex):
  r"""Select the tensors connected to `ops` whose name matches `regex`.

  Args:
    ops: an object convertible to a list of tf.Operation.
    regex: a string or compiled regular expression searched (`re.search`
      semantics) in each tensor's name. For example, "^foo(/.*)?:\d+$" matches
      all the tensors in the "foo" scope.
  Returns:
    A list of tf.Tensor.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  pattern = make_regex(regex)

  def _name_matches(tensor):
    return pattern.search(tensor.name)

  return filter_ts(ops, positive_filter=_name_matches)
151
152
def filter_ops(ops, positive_filter):
  """Keep only the ops accepted by `positive_filter`.

  Args:
    ops: an object convertible to a list of tf.Operation.
    positive_filter: a callable deciding whether to keep an operation (truthy
      result means keep), or the literal True to keep every operation.
  Returns:
    A list of selected tf.Operation, in their original order.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  all_ops = util.make_list_of_op(ops)
  if positive_filter is True:  # pylint: disable=g-explicit-bool-comparison
    return all_ops
  return [op for op in all_ops if positive_filter(op)]
169
170
def filter_ops_from_regex(ops, regex):
  """Select the operations whose name matches `regex`.

  Args:
    ops: an object convertible to a list of `tf.Operation`.
    regex: a string or compiled regular expression searched (`re.search`
      semantics) in each operation's name. For example, `"^foo(/.*)?$"`
      matches all the operations in the "foo" scope.
  Returns:
    A list of `tf.Operation`.
  Raises:
    TypeError: if ops cannot be converted to a list of `tf.Operation`.
  """
  candidate_ops = util.make_list_of_op(ops)
  pattern = make_regex(regex)
  return filter_ops(candidate_ops,
                    positive_filter=lambda op: pattern.search(op.name))
187
188
def get_name_scope_ops(ops, scope):
  """Get all the operations under the given scope path.

  An operation is selected if its name is exactly `scope` or starts with
  `scope + "/"`. The scope is matched literally: characters that are special
  in regular expressions (such as ".") are escaped, so the scope "a.b" does not
  also select the operations under "axb".

  Args:
    ops: an object convertible to a list of tf.Operation.
    scope: a scope path. A single trailing "/" is ignored.
  Returns:
    A list of tf.Operation.
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  if scope and scope[-1] == "/":
    scope = scope[:-1]
  # Escape the scope so it is matched as a literal path, not as a pattern.
  return filter_ops_from_regex(ops, "^{}(/.*)?$".format(re.escape(scope)))
203
204
def check_cios(control_inputs=False, control_outputs=None, control_ios=None):
  """Validate and normalize the control inputs/outputs arguments.

  Args:
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of util.ControlOutputs or None. If not None,
      control outputs are enabled.
    control_ios: An instance of util.ControlOutputs or None. If not None, both
      control inputs and control outputs are enabled: this is equivalent to
      passing control_inputs=True and control_outputs=control_ios.
  Returns:
    A tuple `(control_inputs, control_outputs)` where:
      `control_inputs` is a boolean indicating whether to use control inputs.
      `control_outputs` is an instance of util.ControlOutputs or None. When not
      None, its `update()` method has been called.
  Raises:
    TypeError: if control_ios is not None and not a util.ControlOutputs, or if
      control_ios is None and control_outputs is not None and not a
      util.ControlOutputs.
    ValueError: if both control_ios and control_outputs are not None.
  """
  if control_ios is None:
    if (control_outputs is not None and
        not isinstance(control_outputs, util.ControlOutputs)):
      raise TypeError("Expected a util.ControlOutputs, got: {}".format(
          type(control_outputs)))
  else:
    if not isinstance(control_ios, util.ControlOutputs):
      raise TypeError("Expected a util.ControlOutputs, got: {}".format(
          type(control_ios)))
    if control_outputs is not None:
      raise ValueError("control_outputs should be None when using control_ios.")
    control_inputs, control_outputs = True, control_ios

  if control_outputs is not None:
    # NOTE(review): presumably refreshes ControlOutputs' cached mapping (e.g.
    # after graph changes) -- see util.ControlOutputs.update.
    control_outputs.update()
  return control_inputs, control_outputs
241
242
def get_ops_ios(ops, control_inputs=False, control_outputs=None,
                control_ios=None):
  """List the `tf.Operation` directly connected to any op in `ops`.

  For every op, in order, this collects the producers of its inputs, the
  consumers of its outputs, and optionally its control-output and
  control-input neighbors, without duplicates.

  Args:
    ops: an object convertible to a list of `tf.Operation`.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of `util.ControlOutputs` or `None`. If not
      `None`, control outputs are enabled.
    control_ios: An instance of `util.ControlOutputs` or `None`. If not `None`,
      both control inputs and control outputs are enabled. This is equivalent
      to setting `control_inputs` to `True` and `control_outputs` to the
      `util.ControlOutputs` instance.
  Returns:
    A list of all the `tf.Operation` surrounding the given ops.
  Raises:
    TypeError: if `ops` cannot be converted to a list of `tf.Operation`.
  """
  control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
                                               control_ios)
  neighbors = []
  for op in util.make_list_of_op(ops):
    # Ops producing this op's data inputs.
    util.concatenate_unique(neighbors, [in_t.op for in_t in op.inputs])
    # Ops consuming any of this op's outputs.
    for out_t in op.outputs:
      util.concatenate_unique(neighbors, out_t.consumers())
    # Control-output neighbors, as reported by control_outputs.get(op).
    if control_outputs is not None:
      util.concatenate_unique(neighbors, control_outputs.get(op))
    # Ops listed as this op's control inputs.
    if control_inputs:
      util.concatenate_unique(neighbors, op.control_inputs)
  return neighbors
274
275
def compute_boundary_ts(ops):
  """Classify the tensors connected to a set of ops by their position.

  Every tensor that is an input or an output of some op in `ops` falls into:
  1) input tensors: inputs of `ops` that are not produced by an op in `ops`;
  2) output tensors: outputs of `ops` that are not exclusively consumed by ops
     in `ops` (this includes outputs with no consumer at all);
  3) inside tensors: tensors both produced and consumed by ops in `ops`.

  A tensor produced inside `ops` and consumed both inside and outside of it is
  simultaneously an inside tensor and an output tensor.

  Args:
    ops: an object convertible to a list of tf.Operation.
  Returns:
    A tuple `(outside_input_ts, outside_output_ts, inside_ts)` of Python lists:
      `outside_input_ts` holds the input tensors;
      `outside_output_ts` holds the output tensors;
      `inside_ts` holds the inside tensors.
    `outside_output_ts` and `inside_ts` may intersect (see above).
  Raises:
    TypeError: if ops cannot be converted to a list of tf.Operation.
  """
  ops = util.make_list_of_op(ops)
  input_ts = _get_input_ts(ops)
  output_ts = _get_output_ts(ops)
  produced_inside = frozenset(output_ts)
  op_set = frozenset(ops)

  # Inputs that are also produced by one of the ops never cross the boundary
  # on the way in.
  inside_ts = [t for t in input_ts if t in produced_inside]
  # Among those, the ones whose every consumer is in `ops` never cross the
  # boundary on the way out either.
  only_inside = frozenset(
      t for t in inside_ts if frozenset(t.consumers()) <= op_set)

  inside_set = frozenset(inside_ts)
  outside_output_ts = [t for t in output_ts if t not in only_inside]
  outside_input_ts = [t for t in input_ts if t not in inside_set]
  return outside_input_ts, outside_output_ts, inside_ts
326
327
def get_within_boundary_ops(ops,
                            seed_ops,
                            boundary_ops=(),
                            inclusive=True,
                            control_inputs=False,
                            control_outputs=None,
                            control_ios=None):
  """Return all the `tf.Operation` reachable from the seeds within a boundary.

  Starting from `seed_ops`, neighbors (see `get_ops_ios`) are expanded wave by
  wave. Boundary ops are never expanded. The expansion itself is not confined
  to `ops`; `ops` only filters and orders the final result.

  Args:
    ops: an object convertible to a list of `tf.Operation`. Those ops define
      the set in which to perform the operation (if a `tf.Graph` is given, it
      will be converted to the list of all its operations).
    seed_ops: the operations from which to start expanding.
    boundary_ops: the ops forming the boundary.
    inclusive: if `True`, the result will also include the boundary ops that
      were reached.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of `util.ControlOutputs` or `None`. If not
      `None`, control outputs are enabled.
    control_ios: An instance of `util.ControlOutputs` or `None`. If not
      `None`, both control inputs and control outputs are enabled. This is
      equivalent to setting control_inputs to True and control_outputs to
      the `util.ControlOutputs` instance.
  Returns:
    The list of the ops of `ops` (in their order) that were reached.
  Raises:
    TypeError: if `ops` or `seed_ops` cannot be converted to a list of
      `tf.Operation`.
    ValueError: if the boundary is intersecting with the seeds.
  """
  control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
                                               control_ios)
  ops = util.make_list_of_op(ops)
  seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)
  boundary_ops = set(util.make_list_of_op(boundary_ops))
  visited = set(seed_ops)
  if not visited.isdisjoint(boundary_ops):
    raise ValueError("Boundary is intersecting with the seeds.")

  frontier = set(seed_ops)
  while frontier:
    next_frontier = set()
    for neighbor in get_ops_ios(frontier, control_inputs, control_outputs):
      if neighbor in visited:
        continue
      if neighbor not in boundary_ops:
        next_frontier.add(neighbor)
      elif inclusive:
        # Boundary ops may be kept, but they are never expanded further.
        visited.add(neighbor)
    visited |= next_frontier
    frontier = next_frontier
  return [op for op in ops if op in visited]
381
382
def get_forward_walk_ops(seed_ops,
                         inclusive=True,
                         within_ops=None,
                         within_ops_fn=None,
                         stop_at_ts=(),
                         control_outputs=None):
  """Do a forward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors. A single operation or tensor is
      also accepted.
    inclusive: if True the given seed_ops are also part of the resulting list.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted (seeds outside of it are dropped too). If `within_ops` is
      `None`, the search is performed within the whole graph; an empty
      iterable restricts it to nothing.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops. Note that it
      is not applied to the seeds.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_outputs: a `util.ControlOutputs` instance or None.
      If not `None`, it will be used while walking the graph forward.
  Returns:
    A Python list, without duplicates, of all the `tf.Operation` ahead of
    `seed_ops`: the (possibly restricted) seeds first, then each newly reached
    wave of ops. The order within a wave is unspecified.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  _, control_outputs = check_cios(False, control_outputs)
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize the iterable: sets and generators support neither indexing
  # (`[0]` below) nor a meaningful emptiness test.
  seed_ops = list(seed_ops)
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_consuming_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  seed_ops = frozenset(seed_ops)
  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  # Only None means "whole graph": an explicit (even empty) within_ops
  # restricts both the seeds and the walk.
  if within_ops is not None:
    within_ops = frozenset(
        util.make_list_of_op(within_ops, allow_graph=False))
    seed_ops &= within_ops

  def is_within(op):
    return (within_ops is None or op in within_ops) and (
        within_ops_fn is None or within_ops_fn(op))

  result = list(seed_ops)
  # Set mirror of `result` for O(1) membership tests (a list lookup per edge
  # made the walk quadratic in the number of visited ops).
  visited = set(result)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.outputs:
        if new_t in stop_at_ts:
          continue
        for new_op in new_t.consumers():
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
      if control_outputs is not None:
        for new_op in control_outputs.get(op):
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # new_wave only holds unvisited ops, so extending keeps `result` unique.
    result.extend(new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
453
454
def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors. A single operation or tensor is
      also accepted.
    inclusive: if True the given seed_ops are also part of the resulting list.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted (seeds outside of it are dropped too). If `within_ops` is
      `None`, the search is performed within the whole graph; an empty
      iterable restricts it to nothing.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops. Note that it
      is not applied to the seeds.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
  Returns:
    A Python list, without duplicates, of all the `tf.Operation` behind
    `seed_ops`: the (possibly restricted) seeds first, then each newly reached
    wave of ops. The order within a wave is unspecified.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  if not util.is_iterable(seed_ops):
    seed_ops = [seed_ops]
  # Materialize the iterable: sets and generators support neither indexing
  # (`[0]` below) nor a meaningful emptiness test.
  seed_ops = list(seed_ops)
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], tf_ops.Tensor):
    ts = util.make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = util.get_generating_ops(ts)
  else:
    seed_ops = util.make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts))
  seed_ops = frozenset(seed_ops)
  # Only None means "whole graph": an explicit (even empty) within_ops
  # restricts both the seeds and the walk.
  if within_ops is not None:
    within_ops = frozenset(
        util.make_list_of_op(within_ops, allow_graph=False))
    seed_ops &= within_ops

  def is_within(op):
    return (within_ops is None or op in within_ops) and (
        within_ops_fn is None or within_ops_fn(op))

  result = list(seed_ops)
  # Set mirror of `result` for O(1) membership tests (a list lookup per edge
  # made the walk quadratic in the number of visited ops).
  visited = set(result)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in op.inputs:
        if new_t in stop_at_ts:
          continue
        if new_t.op not in visited and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # new_wave only holds unvisited ops, so extending keeps `result` unique.
    result.extend(new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
522
523
def get_walks_intersection_ops(forward_seed_ops,
                               backward_seed_ops,
                               forward_inclusive=True,
                               backward_inclusive=True,
                               within_ops=None,
                               within_ops_fn=None,
                               control_inputs=False,
                               control_outputs=None,
                               control_ios=None):
  """Return the intersection of a forward and a backward walk.

  Args:
    forward_seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors.
    backward_seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    forward_inclusive: if True the given forward_seed_ops are also part of the
      resulting list.
    backward_inclusive: if True the given backward_seed_ops are also part of
      the resulting list.
    within_ops: an iterable of tf.Operation within which the search is
      restricted. If within_ops is None, the search is performed within
      the whole graph. It is converted to a list once, so a one-shot iterable
      (e.g. a generator) restricts both walks.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of util.ControlOutputs or None. If not None,
      control outputs are enabled.
    control_ios: An instance of util.ControlOutputs or None. If not None, both
      control inputs and control outputs are enabled. This is equivalent to
      setting control_inputs to True and control_outputs to the
      util.ControlOutputs instance.
  Returns:
    A Python list of all the tf.Operation in the intersection of a forward and
    a backward walk, in the order of the forward walk.
  Raises:
    TypeError: if `forward_seed_ops` or `backward_seed_ops` or `within_ops`
      cannot be converted to a list of `tf.Operation`.
  """
  control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
                                               control_ios)
  # Convert once: both walks consume within_ops, and a generator would
  # otherwise be exhausted by the forward walk.
  if within_ops is not None:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
  forward_ops = get_forward_walk_ops(
      forward_seed_ops,
      inclusive=forward_inclusive,
      within_ops=within_ops,
      within_ops_fn=within_ops_fn,
      control_outputs=control_outputs)
  backward_ops = get_backward_walk_ops(
      backward_seed_ops,
      inclusive=backward_inclusive,
      within_ops=within_ops,
      within_ops_fn=within_ops_fn,
      control_inputs=control_inputs)
  # Set for O(1) membership; a list lookup per op made this O(n * m).
  backward_ops_set = frozenset(backward_ops)
  return [op for op in forward_ops if op in backward_ops_set]
581
582
def get_walks_union_ops(forward_seed_ops,
                        backward_seed_ops,
                        forward_inclusive=True,
                        backward_inclusive=True,
                        within_ops=None,
                        within_ops_fn=None,
                        control_inputs=False,
                        control_outputs=None,
                        control_ios=None):
  """Return the union of a forward and a backward walk.

  Args:
    forward_seed_ops: an iterable of operations from which the forward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the consumers of those tensors.
    backward_seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    forward_inclusive: if True the given forward_seed_ops are also part of the
      resulting list.
    backward_inclusive: if True the given backward_seed_ops are also part of
      the resulting list.
    within_ops: restrict the search within those operations. If within_ops is
      None, the search is done within the whole graph. It is converted to a
      list once, so a one-shot iterable (e.g. a generator) restricts both
      walks.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops.
    control_inputs: A boolean indicating whether control inputs are enabled.
    control_outputs: An instance of util.ControlOutputs or None. If not None,
      control outputs are enabled.
    control_ios: An instance of util.ControlOutputs or None. If not None, both
      control inputs and control outputs are enabled. This is equivalent to
      setting control_inputs to True and control_outputs to the
      util.ControlOutputs instance.
  Returns:
    A Python list of all the tf.Operation in the union of a forward and a
    backward walk: the forward walk's ops, then the backward walk's ops not
    already present.
  Raises:
    TypeError: if forward_seed_ops or backward_seed_ops or within_ops cannot be
      converted to a list of tf.Operation.
  """
  control_inputs, control_outputs = check_cios(control_inputs, control_outputs,
                                               control_ios)
  # Convert once: both walks consume within_ops, and a generator would
  # otherwise be exhausted by the forward walk.
  if within_ops is not None:
    within_ops = util.make_list_of_op(within_ops, allow_graph=False)
  forward_ops = get_forward_walk_ops(
      forward_seed_ops,
      inclusive=forward_inclusive,
      within_ops=within_ops,
      within_ops_fn=within_ops_fn,
      control_outputs=control_outputs)
  backward_ops = get_backward_walk_ops(
      backward_seed_ops,
      inclusive=backward_inclusive,
      within_ops=within_ops,
      within_ops_fn=within_ops_fn,
      control_inputs=control_inputs)
  return util.concatenate_unique(forward_ops, backward_ops)
639
640
def select_ops(*args, **kwargs):
  """Helper to select operations.

  Args:
    *args: list of 1) regular expressions (compiled or not) or 2) (array of)
      `tf.Operation`. `tf.Tensor` instances are silently ignored, and so are
      regular expressions starting with the substring "(?#ts)".
    **kwargs: 'graph': `tf.Graph` in which to perform the regex query. This is
      required when using regex.
      'positive_filter': an element is selected only if
        `positive_filter(elem)` is truthy. This is optional.
      'restrict_ops_regex': if truthy, a regular expression is ignored if it
        doesn't start with the substring "(?#ops)".
      'restrict_ts_regex': accepted and ignored, so that the same keyword
        arguments can also be given to `select_ts`.
  Returns:
    A list of `tf.Operation`, without duplicates, in order of first selection.
  Raises:
    TypeError: if the optional keyword argument graph is not a `tf.Graph`
      or if an argument in args is not an (array of) `tf.Operation`
      or an (array of) `tf.Tensor` (silently ignored) or a string
      or a regular expression.
    ValueError: if one of the keyword arguments is unexpected or if a regular
      expression is used without passing a graph as a keyword argument.
  """
  # get keywords arguments
  graph = None
  positive_filter = None
  restrict_ops_regex = False
  for k, v in iteritems(kwargs):
    if k == "graph":
      graph = v
      if graph is not None and not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
    elif k == "positive_filter":
      positive_filter = v
    elif k == "restrict_ops_regex":
      restrict_ops_regex = v
    elif k == "restrict_ts_regex":
      pass
    else:
      raise ValueError("Wrong keywords argument: {}.".format(k))

  ops = []
  # Set mirror of `ops`: O(1) membership tests, and an op repeated within one
  # argument (not only across arguments) is returned once.
  seen = set()

  for arg in args:
    if can_be_regex(arg):
      if graph is None:
        raise ValueError("Use the keyword argument 'graph' to use regex.")
      regex = make_regex(arg)
      if regex.pattern.startswith("(?#ts)"):
        continue
      if restrict_ops_regex and not regex.pattern.startswith("(?#ops)"):
        continue
      candidates = filter_ops_from_regex(graph, regex)
    else:
      candidates = util.make_list_of_op(arg, ignore_ts=True)
    for op in candidates:
      if op in seen:
        continue
      if positive_filter is None or positive_filter(op):
        seen.add(op)
        ops.append(op)

  return ops
705
706
def select_ts(*args, **kwargs):
  """Helper to select tensors.

  Args:
    *args: list of 1) regular expressions (compiled or not) or 2) (array of)
      `tf.Tensor`. `tf.Operation` instances are silently ignored, and so are
      regular expressions starting with the substring "(?#ops)".
    **kwargs: 'graph': `tf.Graph` in which to perform the regex query. This is
      required when using regex.
      'positive_filter': an element is selected only if
        `positive_filter(elem)` is truthy. This is optional.
      'restrict_ts_regex': if truthy, a regular expression is ignored if it
        doesn't start with the substring "(?#ts)".
      'restrict_ops_regex': accepted and ignored, so that the same keyword
        arguments can also be given to `select_ops`.
  Returns:
    A list of `tf.Tensor`, without duplicates, in order of first selection.
  Raises:
    TypeError: if the optional keyword argument graph is not a `tf.Graph`
      or if an argument in args is not an (array of) `tf.Tensor`
      or an (array of) `tf.Operation` (silently ignored) or a string
      or a regular expression.
    ValueError: if one of the keyword arguments is unexpected or if a regular
      expression is used without passing a graph as a keyword argument.
  """
  # get keywords arguments
  graph = None
  positive_filter = None
  restrict_ts_regex = False
  for k, v in iteritems(kwargs):
    if k == "graph":
      graph = v
      if graph is not None and not isinstance(graph, tf_ops.Graph):
        raise TypeError("Expected a tf.Graph, got {}".format(type(graph)))
    elif k == "positive_filter":
      positive_filter = v
    elif k == "restrict_ts_regex":
      restrict_ts_regex = v
    elif k == "restrict_ops_regex":
      pass
    else:
      raise ValueError("Wrong keywords argument: {}.".format(k))

  ts = []
  # Set mirror of `ts`: O(1) membership tests, and a tensor repeated within
  # one argument (not only across arguments) is returned once.
  seen = set()

  for arg in args:
    if can_be_regex(arg):
      if graph is None:
        raise ValueError("Use the keyword argument 'graph' to use regex.")
      regex = make_regex(arg)
      if regex.pattern.startswith("(?#ops)"):
        continue
      if restrict_ts_regex and not regex.pattern.startswith("(?#ts)"):
        continue
      candidates = filter_ts_from_regex(graph, regex)
    else:
      candidates = util.make_list_of_t(arg, ignore_ops=True)
    for t in candidates:
      if t in seen:
        continue
      if positive_filter is None or positive_filter(t):
        seen.add(t)
        ts.append(t)

  return ts
771
772
def select_ops_and_ts(*args, **kwargs):
  """Select both operations and tensors from the same arguments.

  Operations are selected with `select_ops` (every regex not marked "(?#ts)"
  is treated as an operation regex). Tensors are selected with `select_ts`
  restricted to regexes marked "(?#ts)".

  Args:
    *args: list of 1) regular expressions (compiled or not) or 2) (array of)
      `tf.Operation` 3) (array of) tf.Tensor. Regular expressions matching
      tensors must start with the comment `"(?#ts)"`, for instance:
      `"(?#ts)^foo/.*"`.
    **kwargs: 'graph': `tf.Graph` in which to perform the regex query. This is
      required when using regex.
      'positive_filter': an element is selected only if
        `positive_filter(elem)` is truthy. This is optional.
  Returns:
    A tuple `(ops, ts)` where:
      `ops` is a list of `tf.Operation`, and
      `ts` is a list of `tf.Tensor`
  Raises:
    TypeError: if the optional keyword argument graph is not a `tf.Graph`
      or if an argument in args is not an (array of) `tf.Tensor`
      or an (array of) `tf.Operation` or a string or a regular expression.
    ValueError: if one of the keyword arguments is unexpected or if a regular
      expression is used without passing a graph as a keyword argument.
  """
  selected_ops = select_ops(*args, restrict_ops_regex=False, **kwargs)
  selected_ts = select_ts(*args, restrict_ts_regex=True, **kwargs)
  return selected_ops, selected_ts
799