1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for python.util.protobuf.compare.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import copy 22import re 23import textwrap 24 25import six 26 27from google.protobuf import text_format 28 29from tensorflow.python.platform import googletest 30from tensorflow.python.util.protobuf import compare 31from tensorflow.python.util.protobuf import compare_test_pb2 32 33 34def LargePbs(*args): 35 """Converts ASCII string Large PBs to messages.""" 36 pbs = [] 37 for arg in args: 38 pb = compare_test_pb2.Large() 39 text_format.Merge(arg, pb) 40 pbs.append(pb) 41 42 return pbs 43 44 45class ProtoEqTest(googletest.TestCase): 46 47 def assertNotEquals(self, a, b): 48 """Asserts that ProtoEq says a != b.""" 49 a, b = LargePbs(a, b) 50 googletest.TestCase.assertEqual(self, compare.ProtoEq(a, b), False) 51 52 def assertEqual(self, a, b): 53 """Asserts that ProtoEq says a == b.""" 54 a, b = LargePbs(a, b) 55 googletest.TestCase.assertEqual(self, compare.ProtoEq(a, b), True) 56 57 def testPrimitives(self): 58 googletest.TestCase.assertEqual(self, True, compare.ProtoEq('a', 'a')) 59 googletest.TestCase.assertEqual(self, False, compare.ProtoEq('b', 'a')) 60 61 def testEmpty(self): 62 self.assertEqual('', '') 63 64 def testPrimitiveFields(self): 65 self.assertNotEquals('string_: "a"', '') 66 self.assertEqual('string_: "a"', 'string_: "a"') 67 self.assertNotEquals('string_: "b"', 'string_: "a"') 68 self.assertNotEquals('string_: "ab"', 'string_: "aa"') 69 70 self.assertNotEquals('int64_: 0', '') 71 self.assertEqual('int64_: 0', 'int64_: 0') 72 self.assertNotEquals('int64_: -1', '') 73 self.assertNotEquals('int64_: 1', 'int64_: 0') 74 self.assertNotEquals('int64_: 0', 'int64_: -1') 75 76 self.assertNotEquals('float_: 0.0', '') 77 self.assertEqual('float_: 0.0', 'float_: 0.0') 78 self.assertNotEquals('float_: -0.1', '') 79 self.assertNotEquals('float_: 3.14', 'float_: 0') 80 self.assertNotEquals('float_: 0', 'float_: -0.1') 81 self.assertEqual('float_: -0.1', 'float_: -0.1') 82 83 self.assertNotEquals('bool_: true', '') 84 self.assertNotEquals('bool_: false', '') 85 self.assertNotEquals('bool_: true', 'bool_: false') 86 self.assertEqual('bool_: false', 'bool_: false') 87 self.assertEqual('bool_: true', 'bool_: true') 88 89 self.assertNotEquals('enum_: A', '') 90 self.assertNotEquals('enum_: B', 'enum_: A') 91 self.assertNotEquals('enum_: C', 'enum_: B') 92 self.assertEqual('enum_: C', 'enum_: C') 93 94 def testRepeatedPrimitives(self): 95 self.assertNotEquals('int64s: 0', '') 96 self.assertEqual('int64s: 0', 'int64s: 0') 97 self.assertNotEquals('int64s: 1', 'int64s: 0') 98 self.assertNotEquals('int64s: 0 int64s: 0', '') 99 self.assertNotEquals('int64s: 0 int64s: 0', 'int64s: 0') 100 self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0') 101 self.assertNotEquals('int64s: 0 int64s: 1', 'int64s: 0') 102 self.assertNotEquals('int64s: 1', 'int64s: 0 int64s: 2') 103 self.assertNotEquals('int64s: 2 int64s: 0', 'int64s: 1') 104 self.assertEqual('int64s: 0 int64s: 0', 'int64s: 0 int64s: 0') 105 self.assertEqual('int64s: 0 int64s: 1', 'int64s: 0 int64s: 1') 106 self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0 int64s: 0') 107 self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0 int64s: 1') 108 self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0 int64s: 2') 109 self.assertNotEquals('int64s: 1 int64s: 1', 'int64s: 1 int64s: 0') 110 self.assertNotEquals('int64s: 1 int64s: 1', 'int64s: 1 int64s: 0 int64s: 2') 111 112 def testMessage(self): 113 self.assertNotEquals('small <>', '') 114 self.assertEqual('small <>', 'small <>') 115 self.assertNotEquals('small < strings: "a" >', '') 116 self.assertNotEquals('small < strings: "a" >', 'small <>') 117 self.assertEqual('small < strings: "a" >', 'small < strings: "a" >') 118 self.assertNotEquals('small < strings: "b" >', 'small < strings: "a" >') 119 self.assertNotEquals('small < strings: "a" strings: "b" >', 120 'small < strings: "a" >') 121 122 self.assertNotEquals('string_: "a"', 'small <>') 123 self.assertNotEquals('string_: "a"', 'small < strings: "b" >') 124 self.assertNotEquals('string_: "a"', 'small < strings: "b" strings: "c" >') 125 self.assertNotEquals('string_: "a" small <>', 'small <>') 126 self.assertNotEquals('string_: "a" small <>', 'small < strings: "b" >') 127 self.assertEqual('string_: "a" small <>', 'string_: "a" small <>') 128 self.assertNotEquals('string_: "a" small < strings: "a" >', 129 'string_: "a" small <>') 130 self.assertEqual('string_: "a" small < strings: "a" >', 131 'string_: "a" small < strings: "a" >') 132 self.assertNotEquals('string_: "a" small < strings: "a" >', 133 'int64_: 1 small < strings: "a" >') 134 self.assertNotEquals('string_: "a" small < strings: "a" >', 'int64_: 1') 135 self.assertNotEquals('string_: "a"', 'int64_: 1 small < strings: "a" >') 136 self.assertNotEquals('string_: "a" int64_: 0 small < strings: "a" >', 137 'int64_: 1 small < strings: "a" >') 138 self.assertNotEquals('string_: "a" int64_: 1 small < strings: "a" >', 139 'string_: "a" int64_: 0 small < strings: "a" >') 140 self.assertEqual('string_: "a" int64_: 0 small < strings: "a" >', 141 'string_: "a" int64_: 0 small < strings: "a" >') 142 143 def testNestedMessage(self): 144 self.assertNotEquals('medium <>', '') 145 self.assertEqual('medium <>', 'medium <>') 146 self.assertNotEquals('medium < smalls <> >', 'medium <>') 147 self.assertEqual('medium < smalls <> >', 'medium < smalls <> >') 148 self.assertNotEquals('medium < smalls <> smalls <> >', 149 'medium < smalls <> >') 150 self.assertEqual('medium < smalls <> smalls <> >', 151 'medium < smalls <> smalls <> >') 152 153 self.assertNotEquals('medium < int32s: 0 >', 'medium < smalls <> >') 154 155 self.assertNotEquals('medium < smalls < strings: "a"> >', 156 'medium < smalls <> >') 157 158 def testTagOrder(self): 159 """Tests that different fields are ordered by tag number. 160 161 For reference, here are the relevant tag numbers from compare_test.proto: 162 optional string string_ = 1; 163 optional int64 int64_ = 2; 164 optional float float_ = 3; 165 optional Small small = 8; 166 optional Medium medium = 7; 167 optional Small small = 8; 168 """ 169 self.assertNotEquals('string_: "a" ', 170 ' int64_: 1 ') 171 self.assertNotEquals('string_: "a" int64_: 2 ', 172 ' int64_: 1 ') 173 self.assertNotEquals('string_: "b" int64_: 1 ', 174 'string_: "a" int64_: 2 ') 175 self.assertEqual('string_: "a" int64_: 1 ', 176 'string_: "a" int64_: 1 ') 177 self.assertNotEquals('string_: "a" int64_: 1 float_: 0.0', 178 'string_: "a" int64_: 1 ') 179 self.assertEqual('string_: "a" int64_: 1 float_: 0.0', 180 'string_: "a" int64_: 1 float_: 0.0') 181 self.assertNotEquals('string_: "a" int64_: 1 float_: 0.1', 182 'string_: "a" int64_: 1 float_: 0.0') 183 self.assertNotEquals('string_: "a" int64_: 2 float_: 0.0', 184 'string_: "a" int64_: 1 float_: 0.1') 185 self.assertNotEquals('string_: "a" ', 186 ' int64_: 1 float_: 0.1') 187 self.assertNotEquals('string_: "a" float_: 0.0', 188 ' int64_: 1 ') 189 self.assertNotEquals('string_: "b" float_: 0.0', 190 'string_: "a" int64_: 1 ') 191 192 self.assertNotEquals('string_: "a"', 'small < strings: "a" >') 193 self.assertNotEquals('string_: "a" small < strings: "a" >', 194 'small < strings: "b" >') 195 self.assertNotEquals('string_: "a" small < strings: "b" >', 196 'string_: "a" small < strings: "a" >') 197 self.assertEqual('string_: "a" small < strings: "a" >', 198 'string_: "a" small < strings: "a" >') 199 200 self.assertNotEquals('string_: "a" medium <>', 201 'string_: "a" small < strings: "a" >') 202 self.assertNotEquals('string_: "a" medium < smalls <> >', 203 'string_: "a" small < strings: "a" >') 204 self.assertNotEquals('medium <>', 'small < strings: "a" >') 205 self.assertNotEquals('medium <> small <>', 'small < strings: "a" >') 206 self.assertNotEquals('medium < smalls <> >', 'small < strings: "a" >') 207 self.assertNotEquals('medium < smalls < strings: "a" > >', 208 'small < strings: "b" >') 209 210 211class NormalizeNumbersTest(googletest.TestCase): 212 """Tests for NormalizeNumberFields().""" 213 214 def testNormalizesInts(self): 215 pb = compare_test_pb2.Large() 216 pb.int64_ = 4 217 compare.NormalizeNumberFields(pb) 218 self.assertTrue(isinstance(pb.int64_, six.integer_types)) 219 220 pb.int64_ = 4 221 compare.NormalizeNumberFields(pb) 222 self.assertTrue(isinstance(pb.int64_, six.integer_types)) 223 224 pb.int64_ = 9999999999999999 225 compare.NormalizeNumberFields(pb) 226 self.assertTrue(isinstance(pb.int64_, six.integer_types)) 227 228 def testNormalizesRepeatedInts(self): 229 pb = compare_test_pb2.Large() 230 pb.int64s.extend([1, 400, 999999999999999]) 231 compare.NormalizeNumberFields(pb) 232 self.assertTrue(isinstance(pb.int64s[0], six.integer_types)) 233 self.assertTrue(isinstance(pb.int64s[1], six.integer_types)) 234 self.assertTrue(isinstance(pb.int64s[2], six.integer_types)) 235 236 def testNormalizesFloats(self): 237 pb1 = compare_test_pb2.Large() 238 pb1.float_ = 1.2314352351231 239 pb2 = compare_test_pb2.Large() 240 pb2.float_ = 1.231435 241 self.assertNotEqual(pb1.float_, pb2.float_) 242 compare.NormalizeNumberFields(pb1) 243 compare.NormalizeNumberFields(pb2) 244 self.assertEqual(pb1.float_, pb2.float_) 245 246 def testNormalizesRepeatedFloats(self): 247 pb = compare_test_pb2.Large() 248 pb.medium.floats.extend([0.111111111, 0.111111]) 249 compare.NormalizeNumberFields(pb) 250 for value in pb.medium.floats: 251 self.assertAlmostEqual(0.111111, value) 252 253 def testNormalizesDoubles(self): 254 pb1 = compare_test_pb2.Large() 255 pb1.double_ = 1.2314352351231 256 pb2 = compare_test_pb2.Large() 257 pb2.double_ = 1.2314352 258 self.assertNotEqual(pb1.double_, pb2.double_) 259 compare.NormalizeNumberFields(pb1) 260 compare.NormalizeNumberFields(pb2) 261 self.assertEqual(pb1.double_, pb2.double_) 262 263 def testNormalizesMaps(self): 264 pb = compare_test_pb2.WithMap() 265 pb.value_message[4].strings.extend(['a', 'b', 'c']) 266 pb.value_string['d'] = 'e' 267 compare.NormalizeNumberFields(pb) 268 269 270class AssertTest(googletest.TestCase): 271 """Tests assertProtoEqual().""" 272 273 def assertProtoEqual(self, a, b, **kwargs): 274 if isinstance(a, six.string_types) and isinstance(b, six.string_types): 275 a, b = LargePbs(a, b) 276 compare.assertProtoEqual(self, a, b, **kwargs) 277 278 def assertAll(self, a, **kwargs): 279 """Checks that all possible asserts pass.""" 280 self.assertProtoEqual(a, a, **kwargs) 281 282 def assertSameNotEqual(self, a, b): 283 """Checks that assertProtoEqual() fails.""" 284 self.assertRaises(AssertionError, self.assertProtoEqual, a, b) 285 286 def assertNone(self, a, b, message, **kwargs): 287 """Checks that all possible asserts fail with the given message.""" 288 message = re.escape(textwrap.dedent(message)) 289 self.assertRaisesRegex(AssertionError, message, self.assertProtoEqual, a, b, 290 **kwargs) 291 292 def testCheckInitialized(self): 293 # neither is initialized 294 a = compare_test_pb2.Labeled() 295 a.optional = 1 296 self.assertNone(a, a, 'Initialization errors: ', check_initialized=True) 297 self.assertAll(a, check_initialized=False) 298 299 # a is initialized, b isn't 300 b = copy.deepcopy(a) 301 a.required = 2 302 self.assertNone(a, b, 'Initialization errors: ', check_initialized=True) 303 self.assertNone( 304 a, 305 b, 306 """ 307 - required: 2 308 optional: 1 309 """, 310 check_initialized=False) 311 312 # both are initialized 313 a = compare_test_pb2.Labeled() 314 a.required = 2 315 self.assertAll(a, check_initialized=True) 316 self.assertAll(a, check_initialized=False) 317 318 b = copy.deepcopy(a) 319 b.required = 3 320 message = """ 321 - required: 2 322 ? ^ 323 + required: 3 324 ? ^ 325 """ 326 self.assertNone(a, b, message, check_initialized=True) 327 self.assertNone(a, b, message, check_initialized=False) 328 329 def testAssertEqualWithStringArg(self): 330 pb = compare_test_pb2.Large() 331 pb.string_ = 'abc' 332 pb.float_ = 1.234 333 compare.assertProtoEqual(self, """ 334 string_: 'abc' 335 float_: 1.234 336 """, pb) 337 338 def testNormalizesNumbers(self): 339 pb1 = compare_test_pb2.Large() 340 pb1.int64_ = 4 341 pb2 = compare_test_pb2.Large() 342 pb2.int64_ = 4 343 compare.assertProtoEqual(self, pb1, pb2) 344 345 def testNormalizesFloat(self): 346 pb1 = compare_test_pb2.Large() 347 pb1.double_ = 4.0 348 pb2 = compare_test_pb2.Large() 349 pb2.double_ = 4 350 compare.assertProtoEqual(self, pb1, pb2, normalize_numbers=True) 351 352 def testPrimitives(self): 353 self.assertAll('string_: "x"') 354 self.assertNone('string_: "x"', 'string_: "y"', """ 355 - string_: "x" 356 ? ^ 357 + string_: "y" 358 ? ^ 359 """) 360 361 def testRepeatedPrimitives(self): 362 self.assertAll('int64s: 0 int64s: 1') 363 364 self.assertSameNotEqual('int64s: 0 int64s: 1', 'int64s: 1 int64s: 0') 365 self.assertSameNotEqual('int64s: 0 int64s: 1 int64s: 2', 366 'int64s: 2 int64s: 1 int64s: 0') 367 368 self.assertSameNotEqual('int64s: 0', 'int64s: 0 int64s: 0') 369 self.assertSameNotEqual('int64s: 0 int64s: 1', 370 'int64s: 1 int64s: 0 int64s: 1') 371 372 self.assertNone('int64s: 0', 'int64s: 0 int64s: 2', """ 373 int64s: 0 374 + int64s: 2 375 """) 376 self.assertNone('int64s: 0 int64s: 1', 'int64s: 0 int64s: 2', """ 377 int64s: 0 378 - int64s: 1 379 ? ^ 380 + int64s: 2 381 ? ^ 382 """) 383 384 def testMessage(self): 385 self.assertAll('medium: {}') 386 self.assertAll('medium: { smalls: {} }') 387 self.assertAll('medium: { int32s: 1 smalls: {} }') 388 self.assertAll('medium: { smalls: { strings: "x" } }') 389 self.assertAll( 390 'medium: { smalls: { strings: "x" } } small: { strings: "y" }') 391 392 self.assertSameNotEqual('medium: { smalls: { strings: "x" strings: "y" } }', 393 'medium: { smalls: { strings: "y" strings: "x" } }') 394 self.assertSameNotEqual( 395 'medium: { smalls: { strings: "x" } smalls: { strings: "y" } }', 396 'medium: { smalls: { strings: "y" } smalls: { strings: "x" } }') 397 398 self.assertSameNotEqual( 399 'medium: { smalls: { strings: "x" strings: "y" strings: "x" } }', 400 'medium: { smalls: { strings: "y" strings: "x" } }') 401 self.assertSameNotEqual( 402 'medium: { smalls: { strings: "x" } int32s: 0 }', 403 'medium: { int32s: 0 smalls: { strings: "x" } int32s: 0 }') 404 405 self.assertNone('medium: {}', 'medium: { smalls: { strings: "x" } }', """ 406 medium { 407 + smalls { 408 + strings: "x" 409 + } 410 } 411 """) 412 self.assertNone('medium: { smalls: { strings: "x" } }', 413 'medium: { smalls: {} }', """ 414 medium { 415 smalls { 416 - strings: "x" 417 } 418 } 419 """) 420 self.assertNone('medium: { int32s: 0 }', 'medium: { int32s: 1 }', """ 421 medium { 422 - int32s: 0 423 ? ^ 424 + int32s: 1 425 ? ^ 426 } 427 """) 428 429 def testMsgPassdown(self): 430 self.assertRaisesRegex( 431 AssertionError, 432 'test message passed down', 433 self.assertProtoEqual, 434 'medium: {}', 435 'medium: { smalls: { strings: "x" } }', 436 msg='test message passed down') 437 438 def testRepeatedMessage(self): 439 self.assertAll('medium: { smalls: {} smalls: {} }') 440 self.assertAll('medium: { smalls: { strings: "x" } } medium: {}') 441 self.assertAll('medium: { smalls: { strings: "x" } } medium: { int32s: 0 }') 442 self.assertAll('medium: { smalls: {} smalls: { strings: "x" } } small: {}') 443 444 self.assertSameNotEqual('medium: { smalls: { strings: "x" } smalls: {} }', 445 'medium: { smalls: {} smalls: { strings: "x" } }') 446 447 self.assertSameNotEqual('medium: { smalls: {} }', 448 'medium: { smalls: {} smalls: {} }') 449 self.assertSameNotEqual('medium: { smalls: {} smalls: {} } medium: {}', 450 'medium: {} medium: {} medium: { smalls: {} }') 451 self.assertSameNotEqual( 452 'medium: { smalls: { strings: "x" } smalls: {} }', 453 'medium: { smalls: {} smalls: { strings: "x" } smalls: {} }') 454 455 self.assertNone('medium: {}', 'medium: {} medium { smalls: {} }', """ 456 medium { 457 + smalls { 458 + } 459 } 460 """) 461 self.assertNone('medium: { smalls: {} smalls: { strings: "x" } }', 462 'medium: { smalls: {} smalls: { strings: "y" } }', """ 463 medium { 464 smalls { 465 } 466 smalls { 467 - strings: "x" 468 ? ^ 469 + strings: "y" 470 ? ^ 471 } 472 } 473 """) 474 475 476class MixinTests(compare.ProtoAssertions, googletest.TestCase): 477 478 def testAssertEqualWithStringArg(self): 479 pb = compare_test_pb2.Large() 480 pb.string_ = 'abc' 481 pb.float_ = 1.234 482 self.assertProtoEqual(""" 483 string_: 'abc' 484 float_: 1.234 485 """, pb) 486 487 488if __name__ == '__main__': 489 googletest.main() 490