1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Utility functions for comparing proto2 messages in Python. 17 18ProtoEq() compares two proto2 messages for equality. 19 20ClearDefaultValuedFields() recursively clears the fields that are set to their 21default values. This is useful for comparing protocol buffers where the 22semantics of unset fields and default valued fields are the same. 23 24assertProtoEqual() is useful for unit tests. It produces much more helpful 25output than assertEqual() for proto2 messages, e.g. this: 26 27 outer { 28 inner { 29- strings: "x" 30? ^ 31+ strings: "y" 32? ^ 33 } 34 } 35 36...compared to the default output from assertEqual() that looks like this: 37 38AssertionError: <my.Msg object at 0x9fb353c> != <my.Msg object at 0x9fb35cc> 39 40Call it inside your unit test's googletest.TestCase subclasses like this: 41 42 from tensorflow.python.util.protobuf import compare 43 44 class MyTest(googletest.TestCase): 45 ... 46 def testXXX(self): 47 ... 48 compare.assertProtoEqual(self, a, b) 49 50Alternatively: 51 52 from tensorflow.python.util.protobuf import compare 53 54 class MyTest(compare.ProtoAssertions, googletest.TestCase): 55 ... 56 def testXXX(self): 57 ... 58 self.assertProtoEqual(a, b) 59""" 60 61from __future__ import absolute_import 62from __future__ import division 63from __future__ import print_function 64 65import difflib 66 67import six 68 69from google.protobuf import descriptor 70from google.protobuf import descriptor_pool 71from google.protobuf import message 72from google.protobuf import text_format 73 74from ..compat import collections_abc 75 76 77def assertProtoEqual(self, a, b, check_initialized=True, # pylint: disable=invalid-name 78 normalize_numbers=False, msg=None): 79 """Fails with a useful error if a and b aren't equal. 80 81 Comparison of repeated fields matches the semantics of 82 unittest.TestCase.assertEqual(), ie order and extra duplicates fields matter. 83 84 Args: 85 self: googletest.TestCase 86 a: proto2 PB instance, or text string representing one. 87 b: proto2 PB instance -- message.Message or subclass thereof. 88 check_initialized: boolean, whether to fail if either a or b isn't 89 initialized. 90 normalize_numbers: boolean, whether to normalize types and precision of 91 numbers before comparison. 92 msg: if specified, is used as the error message on failure. 93 """ 94 pool = descriptor_pool.Default() 95 if isinstance(a, six.string_types): 96 a = text_format.Merge(a, b.__class__(), descriptor_pool=pool) 97 98 for pb in a, b: 99 if check_initialized: 100 errors = pb.FindInitializationErrors() 101 if errors: 102 self.fail('Initialization errors: %s\n%s' % (errors, pb)) 103 if normalize_numbers: 104 NormalizeNumberFields(pb) 105 106 a_str = text_format.MessageToString(a, descriptor_pool=pool) 107 b_str = text_format.MessageToString(b, descriptor_pool=pool) 108 109 # Some Python versions would perform regular diff instead of multi-line 110 # diff if string is longer than 2**16. We substitute this behavior 111 # with a call to unified_diff instead to have easier-to-read diffs. 112 # For context, see: https://bugs.python.org/issue11763. 113 if len(a_str) < 2**16 and len(b_str) < 2**16: 114 self.assertMultiLineEqual(a_str, b_str, msg=msg) 115 else: 116 diff = '\n' + ''.join(difflib.unified_diff(a_str.splitlines(True), 117 b_str.splitlines(True))) 118 self.fail('%s : %s' % (msg, diff)) 119 120 121def NormalizeNumberFields(pb): 122 """Normalizes types and precisions of number fields in a protocol buffer. 123 124 Due to subtleties in the python protocol buffer implementation, it is possible 125 for values to have different types and precision depending on whether they 126 were set and retrieved directly or deserialized from a protobuf. This function 127 normalizes integer values to ints and longs based on width, 32-bit floats to 128 five digits of precision to account for python always storing them as 64-bit, 129 and ensures doubles are floating point for when they're set to integers. 130 131 Modifies pb in place. Recurses into nested objects. 132 133 Args: 134 pb: proto2 message. 135 136 Returns: 137 the given pb, modified in place. 138 """ 139 for desc, values in pb.ListFields(): 140 is_repeated = True 141 if desc.label is not descriptor.FieldDescriptor.LABEL_REPEATED: 142 is_repeated = False 143 values = [values] 144 145 normalized_values = None 146 147 # We force 32-bit values to int and 64-bit values to long to make 148 # alternate implementations where the distinction is more significant 149 # (e.g. the C++ implementation) simpler. 150 if desc.type in (descriptor.FieldDescriptor.TYPE_INT64, 151 descriptor.FieldDescriptor.TYPE_UINT64, 152 descriptor.FieldDescriptor.TYPE_SINT64): 153 normalized_values = [int(x) for x in values] 154 elif desc.type in (descriptor.FieldDescriptor.TYPE_INT32, 155 descriptor.FieldDescriptor.TYPE_UINT32, 156 descriptor.FieldDescriptor.TYPE_SINT32, 157 descriptor.FieldDescriptor.TYPE_ENUM): 158 normalized_values = [int(x) for x in values] 159 elif desc.type == descriptor.FieldDescriptor.TYPE_FLOAT: 160 normalized_values = [round(x, 6) for x in values] 161 elif desc.type == descriptor.FieldDescriptor.TYPE_DOUBLE: 162 normalized_values = [round(float(x), 7) for x in values] 163 164 if normalized_values is not None: 165 if is_repeated: 166 pb.ClearField(desc.name) 167 getattr(pb, desc.name).extend(normalized_values) 168 else: 169 setattr(pb, desc.name, normalized_values[0]) 170 171 if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE or 172 desc.type == descriptor.FieldDescriptor.TYPE_GROUP): 173 if (desc.type == descriptor.FieldDescriptor.TYPE_MESSAGE and 174 desc.message_type.has_options and 175 desc.message_type.GetOptions().map_entry): 176 # This is a map, only recurse if the values have a message type. 177 if (desc.message_type.fields_by_number[2].type == 178 descriptor.FieldDescriptor.TYPE_MESSAGE): 179 for v in six.itervalues(values): 180 NormalizeNumberFields(v) 181 else: 182 for v in values: 183 # recursive step 184 NormalizeNumberFields(v) 185 186 return pb 187 188 189def _IsMap(value): 190 return isinstance(value, collections_abc.Mapping) 191 192 193def _IsRepeatedContainer(value): 194 if isinstance(value, six.string_types): 195 return False 196 try: 197 iter(value) 198 return True 199 except TypeError: 200 return False 201 202 203def ProtoEq(a, b): 204 """Compares two proto2 objects for equality. 205 206 Recurses into nested messages. Uses list (not set) semantics for comparing 207 repeated fields, ie duplicates and order matter. 208 209 Args: 210 a: A proto2 message or a primitive. 211 b: A proto2 message or a primitive. 212 213 Returns: 214 `True` if the messages are equal. 215 """ 216 def Format(pb): 217 """Returns a dictionary or unchanged pb bases on its type. 218 219 Specifically, this function returns a dictionary that maps tag 220 number (for messages) or element index (for repeated fields) to 221 value, or just pb unchanged if it's neither. 222 223 Args: 224 pb: A proto2 message or a primitive. 225 Returns: 226 A dict or unchanged pb. 227 """ 228 if isinstance(pb, message.Message): 229 return dict((desc.number, value) for desc, value in pb.ListFields()) 230 elif _IsMap(pb): 231 return dict(pb.items()) 232 elif _IsRepeatedContainer(pb): 233 return dict(enumerate(list(pb))) 234 else: 235 return pb 236 237 a, b = Format(a), Format(b) 238 239 # Base case 240 if not isinstance(a, dict) or not isinstance(b, dict): 241 return a == b 242 243 # This list performs double duty: it compares two messages by tag value *or* 244 # two repeated fields by element, in order. the magic is in the format() 245 # function, which converts them both to the same easily comparable format. 246 for tag in sorted(set(a.keys()) | set(b.keys())): 247 if tag not in a or tag not in b: 248 return False 249 else: 250 # Recursive step 251 if not ProtoEq(a[tag], b[tag]): 252 return False 253 254 # Didn't find any values that differed, so they're equal! 255 return True 256 257 258class ProtoAssertions(object): 259 """Mix this into a googletest.TestCase class to get proto2 assertions. 260 261 Usage: 262 263 class SomeTestCase(compare.ProtoAssertions, googletest.TestCase): 264 ... 265 def testSomething(self): 266 ... 267 self.assertProtoEqual(a, b) 268 269 See module-level definitions for method documentation. 270 """ 271 272 # pylint: disable=invalid-name 273 def assertProtoEqual(self, *args, **kwargs): 274 return assertProtoEqual(self, *args, **kwargs) 275