1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Simple graph matching functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from six import string_types
22
23from tensorflow.contrib.graph_editor import select
24from tensorflow.python.framework import ops as tf_ops
25
# Public API of this module; `_make_graph_match` is an internal helper.
__all__ = [
    "op_type",
    "OpMatcher",
]
30
31
32def _make_graph_match(graph_match):
33  """Convert to a OpMatcher instance."""
34  if graph_match is None:
35    return None
36  if not isinstance(graph_match, OpMatcher):
37    graph_match = OpMatcher(graph_match)
38  return graph_match
39
40
def op_type(op_types, op=None):
  """Check if an op is of the given type.

  Args:
    op_types: a single type string (e.g. "Add") or a tuple of strings
      containing the types to check against. For instance: ("Add", "Const")
    op: the operation to check (or None).
  Returns:
    if op is not None, return True if op.node_def.op is one of op_types.
    if op is None, return a lambda function which does the type checking.
  """
  if isinstance(op_types, string_types):
    # Wrap a bare string in a 1-tuple. `(op_types)` would only be a
    # parenthesized string, turning the membership test below into a
    # substring check (e.g. an "Add" op would match op_type("AddN")).
    op_types = (op_types,)
  if op is None:
    return lambda op: op.node_def.op in op_types
  return op.node_def.op in op_types
58
59
class OpMatcher(object):
  """Graph match class.

  An OpMatcher is called with a tf.Operation and returns True when the op
  passes every positive filter and every constraint registered through
  input_ops, control_input_ops and output_ops. Each of those methods may be
  called at most once and returns self, so constraints can be chained.
  """

  def __init__(self, positive_filter):
    """Graph match constructor.

    Args:
      positive_filter: one of
        - a value for which select.can_be_regex is true; it is converted with
          select.make_regex and searched for in the op name,
        - a tf.Operation, matched by identity,
        - a callable taking an op and returning a truthy/falsy value,
        - True, which matches every op.
    Raises:
      ValueError: if positive_filter is none of the above.
    """
    self.positive_filters = []
    # None means "no constraint"; once set to a list, the op's inputs /
    # control inputs / outputs must match that list element by element.
    self.input_op_matches = None
    self.control_input_op_matches = None
    self.output_op_matches = None
    positive_filter = self._finalize_positive_filter(positive_filter)
    self.positive_filters.append(positive_filter)

  def _finalize_positive_filter(self, elem):
    """Convert a positive-filter specification into a filter function.

    Args:
      elem: see the positive_filter argument of the constructor.
    Returns:
      A callable taking an op.
    Raises:
      ValueError: if elem cannot be converted.
    """
    if select.can_be_regex(elem):
      regex_ = select.make_regex(elem)
      # Bind the regex as a default argument so the lambda does not depend on
      # the enclosing scope.
      return lambda op, regex=regex_: regex.search(op.name) is not None
    elif isinstance(elem, tf_ops.Operation):
      return lambda op, match_op=elem: op is match_op
    elif callable(elem):
      return elem
    elif elem is True:
      return lambda op: True
    else:
      raise ValueError("Cannot finalize the positive filter: {}".format(elem))

  def __call__(self, op):
    """Evaluate if the op matches or not.

    Args:
      op: the tf.Operation to test.
    Returns:
      True if op satisfies every positive filter and every input, control
      input and output constraint that has been set, False otherwise.
    Raises:
      TypeError: if op is not a tf.Operation.
    """
    if not isinstance(op, tf_ops.Operation):
      raise TypeError("Expect tf.Operation, got: {}".format(type(op)))
    for positive_filter in self.positive_filters:
      if not positive_filter(op):
        return False
    if self.input_op_matches is not None:
      # The input count must match exactly; a None matcher accepts whatever
      # op produces the input at that position.
      if len(op.inputs) != len(self.input_op_matches):
        return False
      for input_t, input_op_match in zip(op.inputs, self.input_op_matches):
        if input_op_match is None:
          continue
        if not input_op_match(input_t.op):
          return False
    if self.control_input_op_matches is not None:
      if len(op.control_inputs) != len(self.control_input_op_matches):
        return False
      for cinput_op, cinput_op_match in zip(op.control_inputs,
                                            self.control_input_op_matches):
        if cinput_op_match is None:
          continue
        if not cinput_op_match(cinput_op):
          return False
    if self.output_op_matches is not None:
      if len(op.outputs) != len(self.output_op_matches):
        return False
      for output_t, output_op_matches in zip(op.outputs,
                                             self.output_op_matches):
        if output_op_matches is None:
          continue
        # Consumers are compared positionally, in the order returned by
        # consumers(), and their count must match exactly.
        consumers = output_t.consumers()
        if len(consumers) != len(output_op_matches):
          return False
        for consumer_op, consumer_op_match in zip(consumers,
                                                  output_op_matches):
          if consumer_op_match is None:
            continue
          if not consumer_op_match(consumer_op):
            return False
    return True

  def input_ops(self, *args):
    """Add input matches.

    Args:
      *args: one specification per input of the op, in order. Each is None
        (any producing op), an OpMatcher, or any value accepted by the
        OpMatcher constructor.
    Returns:
      self, to allow chaining.
    Raises:
      ValueError: if input matches were already set.
    """
    if self.input_op_matches is not None:
      raise ValueError("input_op_matches is already set.")
    self.input_op_matches = [_make_graph_match(input_match)
                             for input_match in args]
    return self

  def control_input_ops(self, *args):
    """Add control input matches.

    Args:
      *args: one specification per control input of the op, in order. Each is
        None (any op), an OpMatcher, or any value accepted by the OpMatcher
        constructor.
    Returns:
      self, to allow chaining.
    Raises:
      ValueError: if control input matches were already set.
    """
    if self.control_input_op_matches is not None:
      raise ValueError("control_input_op_matches is already set.")
    self.control_input_op_matches = [_make_graph_match(input_match)
                                     for input_match in args]
    return self

  def output_ops(self, *args):
    """Add output matches.

    Args:
      *args: one specification per output of the op, in order. Each is
        - None: the consumers of that output are not constrained,
        - a list of specifications: the output must have exactly that many
          consumers, each matching the specification at the same position
          (None entries accept any consumer),
        - any other single specification: treated as a one-element list.
    Returns:
      self, to allow chaining.
    Raises:
      ValueError: if output matches were already set.
    """
    if self.output_op_matches is not None:
      raise ValueError("output_op_matches is already set.")
    self.output_op_matches = []
    for consumer_op_matches in args:
      if consumer_op_matches is None:
        # Exactly one entry per output: without this `continue` a None
        # argument also appended [None], breaking the output-count check.
        self.output_op_matches.append(None)
        continue
      if not isinstance(consumer_op_matches, list):
        consumer_op_matches = [consumer_op_matches]
      self.output_op_matches.append(
          [_make_graph_match(consumer_op_match)
           for consumer_op_match in consumer_op_matches])
    return self
159