# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Simple graph matching functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six import string_types

from tensorflow.contrib.graph_editor import select
from tensorflow.python.framework import ops as tf_ops

__all__ = [
    "op_type",
    "OpMatcher",
]


def _make_graph_match(graph_match):
  """Convert to an OpMatcher instance.

  Args:
    graph_match: None, an OpMatcher, or anything accepted by the OpMatcher
      constructor.
  Returns:
    None if graph_match is None, graph_match itself if it already is an
    OpMatcher, otherwise a new OpMatcher built from graph_match.
  """
  if graph_match is None:
    return None
  if not isinstance(graph_match, OpMatcher):
    graph_match = OpMatcher(graph_match)
  return graph_match


def _match_all(ops, op_matches):
  """Check ops against op_matches position by position.

  Args:
    ops: an iterable of operations.
    op_matches: a list of matchers (callables taking an op) or None; a None
      entry accepts any op at that position.
  Returns:
    True if ops and op_matches have the same length and every non-None
    matcher accepts the op at the same position, False otherwise.
  """
  ops = list(ops)
  if len(ops) != len(op_matches):
    return False
  for op, op_match in zip(ops, op_matches):
    if op_match is not None and not op_match(op):
      return False
  return True


def op_type(op_types, op=None):
  """Check if an op is of the given type.

  Args:
    op_types: a single string or a tuple of strings containing the types to
      check against. For instance: "Add" or ("Add", "Const").
    op: the operation to check (or None).
  Returns:
    if op is not None, return True if the op is of the correct type.
    if op is None, return a lambda function which does the type checking.
  """
  if isinstance(op_types, string_types):
    # Must be a 1-tuple: `(op_types)` is just the string itself, and `in` on a
    # string is a substring test (e.g. "Add" in "AddN" would be True).
    op_types = (op_types,)
  if op is None:
    return lambda op: op.node_def.op in op_types
  else:
    return op.node_def.op in op_types


class OpMatcher(object):
  """Graph match class.

  An OpMatcher is called with a tf.Operation and returns True if the op passes
  all of its positive filters and, when they have been set, its input, control
  input and output (consumer) constraints. The constraint setters return self
  so they can be chained, e.g.:
    OpMatcher("^foo").input_ops("^bar", None).output_ops(["^baz"])
  """

  def __init__(self, positive_filter):
    """Graph match constructor.

    Args:
      positive_filter: something select.can_be_regex accepts (matched against
        the op name with regex.search), a tf.Operation (matched by identity),
        a callable taking an op, or True (matches every op).
    Raises:
      ValueError: if positive_filter cannot be converted to a filter.
    """
    # Predicates which must all accept the op.
    self.positive_filters = []
    # None: inputs are not checked. Otherwise one matcher (or None = any op)
    # per input tensor, applied to the op producing that tensor.
    self.input_op_matches = None
    # None: control inputs are not checked. Otherwise one matcher (or None)
    # per control input op.
    self.control_input_op_matches = None
    # None: outputs are not checked. Otherwise one entry per output tensor:
    # either None (that output's consumers are not checked) or a list with one
    # matcher (or None) per consumer op.
    self.output_op_matches = None
    positive_filter = self._finalize_positive_filter(positive_filter)
    self.positive_filters.append(positive_filter)

  def _finalize_positive_filter(self, elem):
    """Convert to a filter function.

    Args:
      elem: see the positive_filter argument of the constructor.
    Returns:
      A callable taking an op and returning a boolean.
    Raises:
      ValueError: if elem cannot be converted to a filter.
    """
    if select.can_be_regex(elem):
      regex_ = select.make_regex(elem)
      # Bind the compiled regex as a default argument to freeze it now.
      return lambda op, regex=regex_: regex.search(op.name) is not None
    elif isinstance(elem, tf_ops.Operation):
      return lambda op, match_op=elem: op is match_op
    elif callable(elem):
      return elem
    elif elem is True:
      return lambda op: True
    else:
      raise ValueError("Cannot finalize the positive filter: {}".format(elem))

  def __call__(self, op):
    """Evaluate if the op matches or not.

    Args:
      op: the tf.Operation to evaluate.
    Returns:
      True if op passes every positive filter and every constraint set with
      input_ops, control_input_ops and output_ops; False otherwise.
    Raises:
      TypeError: if op is not a tf.Operation.
    """
    if not isinstance(op, tf_ops.Operation):
      raise TypeError("Expect tf.Operation, got: {}".format(type(op)))
    for positive_filter in self.positive_filters:
      if not positive_filter(op):
        return False
    if self.input_op_matches is not None:
      # Input matchers are applied to the ops producing the input tensors.
      input_ops = [input_t.op for input_t in op.inputs]
      if not _match_all(input_ops, self.input_op_matches):
        return False
    if self.control_input_op_matches is not None:
      if not _match_all(op.control_inputs, self.control_input_op_matches):
        return False
    if self.output_op_matches is not None:
      if len(op.outputs) != len(self.output_op_matches):
        return False
      for output_t, consumer_op_matches in zip(op.outputs,
                                               self.output_op_matches):
        if consumer_op_matches is None:
          continue
        if not _match_all(output_t.consumers(), consumer_op_matches):
          return False
    return True

  def input_ops(self, *args):
    """Add input matches.

    Args:
      *args: one entry per input of the op, in order. Each entry is None (any
        producer op is accepted), an OpMatcher, or anything accepted by the
        OpMatcher constructor. The op must have exactly len(args) inputs.
    Returns:
      self, to allow chaining.
    Raises:
      ValueError: if the input matches are already set.
    """
    if self.input_op_matches is not None:
      raise ValueError("input_op_matches is already set.")
    # Build the full list before assigning so a failing conversion does not
    # leave the matcher in a partially-set state.
    self.input_op_matches = [_make_graph_match(input_match)
                             for input_match in args]
    return self

  def control_input_ops(self, *args):
    """Add control input matches.

    Args:
      *args: one entry per control input of the op, in order. Each entry is
        None (any op is accepted), an OpMatcher, or anything accepted by the
        OpMatcher constructor. The op must have exactly len(args) control
        inputs.
    Returns:
      self, to allow chaining.
    Raises:
      ValueError: if the control input matches are already set.
    """
    if self.control_input_op_matches is not None:
      raise ValueError("control_input_op_matches is already set.")
    self.control_input_op_matches = [_make_graph_match(input_match)
                                     for input_match in args]
    return self

  def output_ops(self, *args):
    """Add output matches.

    Args:
      *args: one entry per output tensor of the op, in order. Each entry is
        None (the consumers of that output are not checked), a list with one
        matcher (or None = any op) per consumer of that output, or a single
        matcher, which is treated as a one-element list. The op must have
        exactly len(args) outputs.
    Returns:
      self, to allow chaining.
    Raises:
      ValueError: if the output matches are already set.
    """
    if self.output_op_matches is not None:
      raise ValueError("output_op_matches is already set.")
    output_op_matches = []
    for consumer_op_matches in args:
      if consumer_op_matches is None:
        # Do not fall through: appending an extra [None] entry here would
        # misalign every following output and break the output-count check.
        output_op_matches.append(None)
        continue
      if not isinstance(consumer_op_matches, list):
        consumer_op_matches = [consumer_op_matches]
      output_op_matches.append([_make_graph_match(consumer_op_match)
                                for consumer_op_match in consumer_op_matches])
    self.output_op_matches = output_op_matches
    return self