1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities that match patterns in a tf.Graph."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import abc
22import itertools
23
24import six
25
26
@six.add_metaclass(abc.ABCMeta)
class Pattern(object):
  """Abstract base for graph patterns such as OpTypePattern and OneofPattern.

  Concrete subclasses must override `match`; this class cannot be
  instantiated directly because `match` is abstract.
  """

  @abc.abstractmethod
  def match(self, op, tensor):
    """Compares `op` (reached through `tensor`) with this pattern.

    Args:
      op: the operation to test.
      tensor: the output of `op` consumed by the parent pattern's op, or None
        when `op` is the root of the match.

    Returns:
      The outcome of the comparison as defined by the subclass.
    """
    raise NotImplementedError('Method "match" not implemented.')
35
36
class OpTypePattern(Pattern):
  """A tree pattern that matches TF expressions with certain op types."""

  def __init__(self, op_type, name=None, inputs=None, ordered_inputs=True):
    """Initializes an OpTypePattern.

    Args:
      op_type: string that specifies the allowed types of the root. It can be
        (1) an op type, e.g. 'Conv2D',
        (2) '*', i.e. wildcard, or
        (3) multiple op types separated by '|', e.g., 'Relu|Relu6'.
        We could use regex strings, which might be worthwhile when we have many
        similar TF op types.
      name: Optional string. The name of the pattern that can be looked up in
        MatchResult.
      inputs: Optional list of `Pattern`s or strings that specify the
        patterns for the inputs of a matching op. If None, this pattern accepts
        any inputs of a matching op. Entries that are not `Pattern`s are
        wrapped in an `OpTypePattern`.
      ordered_inputs: Defaults to True. If False, will match any op that
        matches a permutation of the inputs.

    Raises:
      ValueError: if more than 8 inputs are provided when ordered_inputs is
        False.
    """
    self._op_type = op_type
    # The '|'-separated alternatives never change, so split them once here
    # rather than on every call to `match`.
    self._op_types = frozenset(op_type.split('|'))
    self._name = name
    if inputs is None:
      inputs = []
    # Unordered matching may try every permutation of the inputs (n! of them),
    # so the input count is capped only in that mode.
    if not ordered_inputs and len(inputs) > 8:
      raise ValueError(
          'Only <= 8 inputs are allowed when ordered_inputs is False.')
    self._inputs = [
        input_pattern
        if isinstance(input_pattern, Pattern) else OpTypePattern(input_pattern)
        for input_pattern in inputs
    ]
    self._ordered_inputs = ordered_inputs

  @property
  def name(self):
    """The optional name under which this pattern is recorded in MatchResult."""
    return self._name

  def match(self, op, tensor):
    """Matches `op` and, recursively, its inputs against this pattern.

    Args:
      op: the operation to match; its `type` and `inputs` are inspected.
      tensor: the output of `op` consumed by the parent pattern's op, or None
        when `op` is the root of the match.

    Returns:
      A `MatchResult` recording this pattern and every input pattern matched
      beneath it, or None if `op` does not match.
    """
    if self._op_type != '*' and op.type not in self._op_types:
      return None

    match_result = MatchResult()
    match_result.add(self, op, tensor)

    if not self._inputs:
      # If pattern.inputs is empty, skips the rest and accepts all the inputs.
      return match_result

    if len(op.inputs) != len(self._inputs):
      return None

    if self._ordered_inputs:
      input_patterns_list = [self._inputs]
    else:
      # Order doesn't matter, so accept the first permutation of the input
      # patterns that matches. Permutations are generated lazily so matching
      # stops as soon as one succeeds.
      input_patterns_list = itertools.permutations(self._inputs)

    for input_patterns in input_patterns_list:
      input_match_results = self._match_inputs(op.inputs, input_patterns)
      if input_match_results is not None:
        # Merge only a fully successful assignment, so partial matches from
        # earlier failed permutations never leak into the result.
        for input_match_result in input_match_results:
          match_result.merge_from(input_match_result)
        return match_result
    return None

  @staticmethod
  def _match_inputs(input_tensors, input_patterns):
    """Matches input tensors pairwise against input patterns.

    Args:
      input_tensors: the inputs of the op being matched.
      input_patterns: patterns to compare with `input_tensors`, position by
        position.

    Returns:
      The list of per-input `MatchResult`s if every pair matches, else None.
    """
    input_match_results = []
    for input_tensor, input_pattern in zip(input_tensors, input_patterns):
      input_match_result = input_pattern.match(input_tensor.op, input_tensor)
      if input_match_result is None:
        return None
      input_match_results.append(input_match_result)
    return input_match_results
111
112
class OneofPattern(Pattern):
  """Matches one of the given sub-patterns."""

  def __init__(self, sub_patterns):
    """Initializes a OneofPattern.

    Args:
      sub_patterns: iterable of `Pattern`s, tried in the given order.
    """
    # Materialize the iterable so one-shot inputs (e.g. generators) are not
    # exhausted after the first call to `match`.
    self._sub_patterns = list(sub_patterns)

  def match(self, op, tensor):
    """Returns the result of the first sub-pattern that matches, else None.

    Args:
      op: the operation to match.
      tensor: the output of `op` consumed by the parent pattern's op, or None
        when `op` is the root of the match.
    """
    for sub_pattern in self._sub_patterns:
      match_result = sub_pattern.match(op, tensor)
      if match_result is not None:
        return match_result
    return None
125
126
class MatchResult(object):
  r"""Encapsulates the result of a match done by GraphMatcher.

  MatchResult contains a map from Pattern to the matching op and tensor.
  When the matching op has multiple output tensors, the matching tensor is the
  output tensor used by the matching op of the parent pattern. E.g., when we
  match graph

      -         +
     / \y0   y1/ \
    x    split    z
          |
          y         (nodes are ops; edges are going up)

  against add_pattern defined as

    y1_pattern = OpTypePattern('*')
    z_pattern = OpTypePattern('*')
    add_pattern = OpTypePattern('+', inputs=[y1_pattern, z_pattern])

  the matching op of `y1_pattern` is `split`, and the matching tensor of
  `y1_pattern` is `y1` not `y0`.
  """

  def __init__(self):
    # Pattern -> (op, tensor) for every pattern recorded in this result.
    self._pattern_to_op_tensor = {}
    # Pattern name -> named Pattern, so results can be looked up by name.
    self._name_to_pattern = {}

  def add(self, pattern, op, tensor):
    """Records that `pattern` matched `op` through output `tensor`.

    Args:
      pattern: the matched pattern; its `name` attribute (possibly None) is
        read.
      op: the matching op.
      tensor: the output of `op` consumed by the parent pattern's op, or None.

    Raises:
      ValueError: if `pattern.name` is already bound to a different pattern.
        In that case this result is left unchanged.
    """
    name = pattern.name
    if (name is not None and name in self._name_to_pattern and
        self._name_to_pattern[name] is not pattern):
      raise ValueError('Name %s is already bound to another pattern' % name)
    # Mutate only after validation so a rejected add has no side effects.
    self._pattern_to_op_tensor[pattern] = op, tensor
    if name is not None:
      self._name_to_pattern[name] = pattern

  def _to_pattern(self, pattern_or_name):
    """Resolves a `Pattern` or a pattern name to a `Pattern`.

    Returns:
      `pattern_or_name` itself if it is a `Pattern`; the pattern bound to that
      name if it is a string; None for an unknown name.

    Raises:
      ValueError: if `pattern_or_name` is neither a `Pattern` nor a string.
    """
    if isinstance(pattern_or_name, Pattern):
      return pattern_or_name

    # six.string_types also accepts `unicode` names under Python 2.
    if isinstance(pattern_or_name, six.string_types):
      return self._name_to_pattern.get(pattern_or_name)

    raise ValueError('pattern_or_name has type %s. Expect Pattern or str.' %
                     type(pattern_or_name))

  def _get_op_tensor(self, pattern_or_name):
    """Returns the recorded (op, tensor) pair, or None if none was recorded."""
    pattern = self._to_pattern(pattern_or_name)
    if pattern is None:
      return None
    return self._pattern_to_op_tensor.get(pattern)

  def get_op(self, pattern_or_name):
    """Returns the op matched by a pattern (or pattern name), or None."""
    op_tensor = self._get_op_tensor(pattern_or_name)
    return None if op_tensor is None else op_tensor[0]

  def get_tensor(self, pattern_or_name):
    """Returns the tensor matched by a pattern (or pattern name), or None."""
    op_tensor = self._get_op_tensor(pattern_or_name)
    return None if op_tensor is None else op_tensor[1]

  def merge_from(self, other_match_result):
    """Copies every entry of `other_match_result` into this result.

    Entries from `other_match_result` overwrite existing entries for the same
    pattern or name; no name-conflict check is performed here.
    """
    # pylint: disable=protected-access
    self._pattern_to_op_tensor.update(other_match_result._pattern_to_op_tensor)
    self._name_to_pattern.update(other_match_result._name_to_pattern)
    # pylint: enable=protected-access
199
200
class GraphMatcher(object):
  """Matches operations of a tf.Graph against a single root `Pattern`."""

  def __init__(self, pattern):
    """Creates a matcher for `pattern`.

    Args:
      pattern: the root `Pattern` that candidate subgraphs are compared with.
    """
    self._pattern = pattern

  def _match_pattern(self, pattern, op, tensor):
    """Matches `op` against `pattern`, accumulating bindings on success.

    When `pattern.match` succeeds, its bindings are merged into
    `self._match_result`.

    Args:
      pattern: the `Pattern` to test.
      op: a `tf.Operation` to test against `pattern`.
      tensor: the output `tf.Tensor` of `op` consumed by the op matched by
        `pattern`'s parent, or None when `pattern` is the root of the tree.

    Returns:
      True when `op` matches `pattern`, False otherwise.
    """
    result = pattern.match(op, tensor)
    matched = result is not None
    if matched:
      self._match_result.merge_from(result)
    return matched

  def match_op(self, op):
    """Compares a single `op` with the root pattern.

    Args:
      op: `tf.Operation` to compare with the pattern.

    Returns:
      A fresh `MatchResult` when `op` matches the pattern, otherwise None.
    """
    self._match_result = MatchResult()
    matched = self._match_pattern(self._pattern, op, tensor=None)
    return self._match_result if matched else None

  def match_ops(self, ops):
    """Lazily compares each operation in `ops` with the root pattern.

    Args:
      ops: iterable of `tf.Operation`s to compare with the pattern.

    Yields:
      A `MatchResult` for every operation that matches, in iteration order.
    """
    candidates = (self.match_op(op) for op in ops)
    for candidate in candidates:
      if candidate:
        yield candidate

  def match_graph(self, graph):
    """Lazily compares every operation of `graph` with the root pattern.

    Args:
      graph: `tf.Graph` whose operations are compared with the pattern.

    Yields:
      A `MatchResult` for every operation of `graph` that matches.
    """
    all_ops = graph.get_operations()
    # Re-yield explicitly rather than with `yield from` (Python 3.3+ only).
    for match_result in self.match_ops(all_ops):
      yield match_result
276