• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
4 # you may not use this file except in compliance with the License.
5 # You may obtain a copy of the License at
6 #
7 #     http://www.apache.org/licenses/LICENSE-2.0
8 #
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS,
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 # See the License for the specific language governing permissions and
13 # limitations under the License.
14 # ==============================================================================
15 """Tests for python.util.protobuf.compare."""
16 
17 from __future__ import absolute_import
18 from __future__ import division
19 from __future__ import print_function
20 
21 import copy
22 import re
23 import textwrap
24 
25 import six
26 
27 from google.protobuf import text_format
28 
29 from tensorflow.python.platform import googletest
30 from tensorflow.python.util.protobuf import compare
31 from tensorflow.python.util.protobuf import compare_test_pb2
32 
33 
def LargePbs(*args):
  """Parses each text-format argument into a fresh `compare_test_pb2.Large`.

  Args:
    *args: Text-format (ASCII) representations of `Large` messages.

  Returns:
    A list holding one parsed `Large` message per argument, in argument order.
  """

  def _parse(text):
    # Merge into a brand-new message so each argument is parsed independently.
    message = compare_test_pb2.Large()
    text_format.Merge(text, message)
    return message

  return [_parse(text) for text in args]
43 
44 
class ProtoEqTest(googletest.TestCase):
  """Tests for compare.ProtoEq().

  The helper assertions below take text-format `Large` messages, parse them
  with LargePbs() and check the boolean returned by compare.ProtoEq().
  `assertEqual` is overridden with that meaning in this class, so the stock
  unittest assertion is invoked explicitly as googletest.TestCase.assertEqual.
  """

  def assertNotEquals(self, a, b):
    """Asserts that ProtoEq says a != b.

    Args:
      a: Text-format representation of a `Large` message.
      b: Text-format representation of a `Large` message.
    """
    a, b = LargePbs(a, b)
    googletest.TestCase.assertEqual(self, compare.ProtoEq(a, b), False)

  def assertEqual(self, a, b):
    """Asserts that ProtoEq says a == b.

    Args:
      a: Text-format representation of a `Large` message.
      b: Text-format representation of a `Large` message.
    """
    a, b = LargePbs(a, b)
    googletest.TestCase.assertEqual(self, compare.ProtoEq(a, b), True)

  def testPrimitives(self):
    """ProtoEq also compares plain (non-message) values."""
    googletest.TestCase.assertEqual(self, True, compare.ProtoEq('a', 'a'))
    googletest.TestCase.assertEqual(self, False, compare.ProtoEq('b', 'a'))

  def testEmpty(self):
    """Two empty messages compare equal."""
    self.assertEqual('', '')

  def testPrimitiveFields(self):
    """Singular scalar fields: both presence and value affect equality."""
    self.assertNotEquals('string_: "a"', '')
    self.assertEqual('string_: "a"', 'string_: "a"')
    self.assertNotEquals('string_: "b"', 'string_: "a"')
    self.assertNotEquals('string_: "ab"', 'string_: "aa"')

    self.assertNotEquals('int64_: 0', '')
    self.assertEqual('int64_: 0', 'int64_: 0')
    self.assertNotEquals('int64_: -1', '')
    self.assertNotEquals('int64_: 1', 'int64_: 0')
    self.assertNotEquals('int64_: 0', 'int64_: -1')

    self.assertNotEquals('float_: 0.0', '')
    self.assertEqual('float_: 0.0', 'float_: 0.0')
    self.assertNotEquals('float_: -0.1', '')
    self.assertNotEquals('float_: 3.14', 'float_: 0')
    self.assertNotEquals('float_: 0', 'float_: -0.1')
    self.assertEqual('float_: -0.1', 'float_: -0.1')

    self.assertNotEquals('bool_: true', '')
    self.assertNotEquals('bool_: false', '')
    self.assertNotEquals('bool_: true', 'bool_: false')
    self.assertEqual('bool_: false', 'bool_: false')
    self.assertEqual('bool_: true', 'bool_: true')

    self.assertNotEquals('enum_: A', '')
    self.assertNotEquals('enum_: B', 'enum_: A')
    self.assertNotEquals('enum_: C', 'enum_: B')
    self.assertEqual('enum_: C', 'enum_: C')

  def testRepeatedPrimitives(self):
    """Repeated scalars: length, element order and values all matter."""
    self.assertNotEquals('int64s: 0', '')
    self.assertEqual('int64s: 0', 'int64s: 0')
    self.assertNotEquals('int64s: 1', 'int64s: 0')
    self.assertNotEquals('int64s: 0 int64s: 0', '')
    self.assertNotEquals('int64s: 0 int64s: 0', 'int64s: 0')
    self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0')
    self.assertNotEquals('int64s: 0 int64s: 1', 'int64s: 0')
    self.assertNotEquals('int64s: 1', 'int64s: 0 int64s: 2')
    self.assertNotEquals('int64s: 2 int64s: 0', 'int64s: 1')
    self.assertEqual('int64s: 0 int64s: 0', 'int64s: 0 int64s: 0')
    self.assertEqual('int64s: 0 int64s: 1', 'int64s: 0 int64s: 1')
    self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0 int64s: 0')
    self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0 int64s: 1')
    self.assertNotEquals('int64s: 1 int64s: 0', 'int64s: 0 int64s: 2')
    self.assertNotEquals('int64s: 1 int64s: 1', 'int64s: 1 int64s: 0')
    self.assertNotEquals('int64s: 1 int64s: 1', 'int64s: 1 int64s: 0 int64s: 2')

  def testMessage(self):
    """Singular sub-message fields, alone and mixed with scalar fields."""
    self.assertNotEquals('small <>', '')
    self.assertEqual('small <>', 'small <>')
    self.assertNotEquals('small < strings: "a" >', '')
    self.assertNotEquals('small < strings: "a" >', 'small <>')
    self.assertEqual('small < strings: "a" >', 'small < strings: "a" >')
    self.assertNotEquals('small < strings: "b" >', 'small < strings: "a" >')
    self.assertNotEquals('small < strings: "a" strings: "b" >',
                         'small < strings: "a" >')

    self.assertNotEquals('string_: "a"', 'small <>')
    self.assertNotEquals('string_: "a"', 'small < strings: "b" >')
    self.assertNotEquals('string_: "a"', 'small < strings: "b" strings: "c" >')
    self.assertNotEquals('string_: "a" small <>', 'small <>')
    self.assertNotEquals('string_: "a" small <>', 'small < strings: "b" >')
    self.assertEqual('string_: "a" small <>', 'string_: "a" small <>')
    self.assertNotEquals('string_: "a" small < strings: "a" >',
                         'string_: "a" small <>')
    self.assertEqual('string_: "a" small < strings: "a" >',
                     'string_: "a" small < strings: "a" >')
    self.assertNotEquals('string_: "a" small < strings: "a" >',
                         'int64_: 1 small < strings: "a" >')
    self.assertNotEquals('string_: "a" small < strings: "a" >', 'int64_: 1')
    self.assertNotEquals('string_: "a"', 'int64_: 1 small < strings: "a" >')
    self.assertNotEquals('string_: "a" int64_: 0 small < strings: "a" >',
                         'int64_: 1 small < strings: "a" >')
    self.assertNotEquals('string_: "a" int64_: 1 small < strings: "a" >',
                         'string_: "a" int64_: 0 small < strings: "a" >')
    self.assertEqual('string_: "a" int64_: 0 small < strings: "a" >',
                     'string_: "a" int64_: 0 small < strings: "a" >')

  def testNestedMessage(self):
    """Messages nested two levels deep (medium -> smalls)."""
    self.assertNotEquals('medium <>', '')
    self.assertEqual('medium <>', 'medium <>')
    self.assertNotEquals('medium < smalls <> >', 'medium <>')
    self.assertEqual('medium < smalls <> >', 'medium < smalls <> >')
    self.assertNotEquals('medium < smalls <> smalls <> >',
                         'medium < smalls <> >')
    self.assertEqual('medium < smalls <> smalls <> >',
                     'medium < smalls <> smalls <> >')

    self.assertNotEquals('medium < int32s: 0 >', 'medium < smalls <> >')

    self.assertNotEquals('medium < smalls < strings: "a"> >',
                         'medium < smalls <> >')

  def testTagOrder(self):
    """Tests that different fields are ordered by tag number.

    For reference, here are the relevant tag numbers from compare_test.proto:
      optional string string_ = 1;
      optional int64 int64_ = 2;
      optional float float_ = 3;
      optional Medium medium = 7;
      optional Small small = 8;
    """
    self.assertNotEquals('string_: "a"                      ',
                         '             int64_: 1            ')
    self.assertNotEquals('string_: "a" int64_: 2            ',
                         '             int64_: 1            ')
    self.assertNotEquals('string_: "b" int64_: 1            ',
                         'string_: "a" int64_: 2            ')
    self.assertEqual('string_: "a" int64_: 1            ',
                     'string_: "a" int64_: 1            ')
    self.assertNotEquals('string_: "a" int64_: 1 float_: 0.0',
                         'string_: "a" int64_: 1            ')
    self.assertEqual('string_: "a" int64_: 1 float_: 0.0',
                     'string_: "a" int64_: 1 float_: 0.0')
    self.assertNotEquals('string_: "a" int64_: 1 float_: 0.1',
                         'string_: "a" int64_: 1 float_: 0.0')
    self.assertNotEquals('string_: "a" int64_: 2 float_: 0.0',
                         'string_: "a" int64_: 1 float_: 0.1')
    self.assertNotEquals('string_: "a"                      ',
                         '             int64_: 1 float_: 0.1')
    self.assertNotEquals('string_: "a"           float_: 0.0',
                         '             int64_: 1            ')
    self.assertNotEquals('string_: "b"           float_: 0.0',
                         'string_: "a" int64_: 1            ')

    self.assertNotEquals('string_: "a"', 'small < strings: "a" >')
    self.assertNotEquals('string_: "a" small < strings: "a" >',
                         'small < strings: "b" >')
    self.assertNotEquals('string_: "a" small < strings: "b" >',
                         'string_: "a" small < strings: "a" >')
    self.assertEqual('string_: "a" small < strings: "a" >',
                     'string_: "a" small < strings: "a" >')

    self.assertNotEquals('string_: "a" medium <>',
                         'string_: "a" small < strings: "a" >')
    self.assertNotEquals('string_: "a" medium < smalls <> >',
                         'string_: "a" small < strings: "a" >')
    self.assertNotEquals('medium <>', 'small < strings: "a" >')
    self.assertNotEquals('medium <> small <>', 'small < strings: "a" >')
    self.assertNotEquals('medium < smalls <> >', 'small < strings: "a" >')
    self.assertNotEquals('medium < smalls < strings: "a" > >',
                         'small < strings: "b" >')
209 
210 
class NormalizeNumbersTest(googletest.TestCase):
  """Tests for NormalizeNumberFields()."""

  def testNormalizesInts(self):
    """int64 values keep their value and end up as a Python integer type."""
    pb = compare_test_pb2.Large()
    # A small value and one that does not fit in 32 bits.
    for value in (4, 9999999999999999):
      pb.int64_ = value
      compare.NormalizeNumberFields(pb)
      self.assertIsInstance(pb.int64_, six.integer_types)
      self.assertEqual(value, pb.int64_)

  def testNormalizesRepeatedInts(self):
    """Every element of a repeated int64 field is normalized in place."""
    values = [1, 400, 999999999999999]
    pb = compare_test_pb2.Large()
    pb.int64s.extend(values)
    compare.NormalizeNumberFields(pb)
    # Normalization must neither drop nor alter any element.
    self.assertEqual(values, list(pb.int64s))
    for value in pb.int64s:
      self.assertIsInstance(value, six.integer_types)

  def testNormalizesFloats(self):
    """Floats differing only in low-order digits become equal."""
    pb1 = compare_test_pb2.Large()
    pb1.float_ = 1.2314352351231
    pb2 = compare_test_pb2.Large()
    pb2.float_ = 1.231435
    self.assertNotEqual(pb1.float_, pb2.float_)
    compare.NormalizeNumberFields(pb1)
    compare.NormalizeNumberFields(pb2)
    self.assertEqual(pb1.float_, pb2.float_)

  def testNormalizesRepeatedFloats(self):
    """Each element of a nested repeated float field is normalized."""
    pb = compare_test_pb2.Large()
    pb.medium.floats.extend([0.111111111, 0.111111])
    compare.NormalizeNumberFields(pb)
    self.assertEqual(2, len(pb.medium.floats))
    for value in pb.medium.floats:
      self.assertAlmostEqual(0.111111, value)

  def testNormalizesDoubles(self):
    """Doubles differing only in low-order digits become equal."""
    pb1 = compare_test_pb2.Large()
    pb1.double_ = 1.2314352351231
    pb2 = compare_test_pb2.Large()
    pb2.double_ = 1.2314352
    self.assertNotEqual(pb1.double_, pb2.double_)
    compare.NormalizeNumberFields(pb1)
    compare.NormalizeNumberFields(pb2)
    self.assertEqual(pb1.double_, pb2.double_)

  def testNormalizesMaps(self):
    """Map fields with message and string values survive normalization."""
    pb = compare_test_pb2.WithMap()
    pb.value_message[4].strings.extend(['a', 'b', 'c'])
    pb.value_string['d'] = 'e'
    compare.NormalizeNumberFields(pb)
    # Normalization must not raise, and it must leave map contents intact.
    # Check the keys first so the lookup below cannot create a new entry.
    self.assertEqual([4], list(pb.value_message.keys()))
    self.assertEqual(['a', 'b', 'c'], list(pb.value_message[4].strings))
    self.assertEqual({'d': 'e'}, dict(pb.value_string))
268 
269 
class AssertTest(googletest.TestCase):
  """Tests assertProtoEqual().

  The local assertProtoEqual() wrapper parses text-format `Large` messages
  before delegating to compare.assertProtoEqual().
  """

  def assertProtoEqual(self, a, b, **kwargs):
    """Delegates to compare.assertProtoEqual, parsing text arguments first.

    Args:
      a: A protobuf message or a text-format string.
      b: A protobuf message or a text-format string.
      **kwargs: Forwarded unchanged to compare.assertProtoEqual.

    Both arguments are parsed as `Large` messages only when both are strings;
    otherwise they are passed through as given.
    """
    if all(isinstance(arg, six.string_types) for arg in (a, b)):
      a, b = LargePbs(a, b)
    compare.assertProtoEqual(self, a, b, **kwargs)

  def assertAll(self, a, **kwargs):
    """Checks that `a` compares equal to itself (all asserts pass)."""
    self.assertProtoEqual(a, a, **kwargs)

  def assertSameNotEqual(self, a, b):
    """Checks that assertProtoEqual() raises AssertionError for a vs. b."""
    with self.assertRaises(AssertionError):
      self.assertProtoEqual(a, b)

  def assertNone(self, a, b, message, **kwargs):
    """Checks that assertProtoEqual() fails with an error containing `message`.

    `message` is dedented and regex-escaped, so its text is searched for
    literally in the AssertionError output.
    """
    expected = re.escape(textwrap.dedent(message))
    with self.assertRaisesRegex(AssertionError, expected):
      self.assertProtoEqual(a, b, **kwargs)

  def testCheckInitialized(self):
    # Neither message is initialized.
    first = compare_test_pb2.Labeled()
    first.optional = 1
    self.assertNone(first, first, 'Initialization errors: ',
                    check_initialized=True)
    self.assertAll(first, check_initialized=False)

    # `first` is initialized, `second` is not.
    second = copy.deepcopy(first)
    first.required = 2
    self.assertNone(first, second, 'Initialization errors: ',
                    check_initialized=True)
    self.assertNone(
        first,
        second,
        """
                    - required: 2
                      optional: 1
                    """,
        check_initialized=False)

    # Both messages are initialized.
    first = compare_test_pb2.Labeled()
    first.required = 2
    for checked in (True, False):
      self.assertAll(first, check_initialized=checked)

    second = copy.deepcopy(first)
    second.required = 3
    message = """
              - required: 2
              ?           ^
              + required: 3
              ?           ^
              """
    for checked in (True, False):
      self.assertNone(first, second, message, check_initialized=checked)

  def testAssertEqualWithStringArg(self):
    pb = compare_test_pb2.Large(string_='abc', float_=1.234)
    compare.assertProtoEqual(self, """
          string_: 'abc'
          float_: 1.234
        """, pb)

  def testNormalizesNumbers(self):
    pb1, pb2 = [compare_test_pb2.Large(int64_=4) for _ in range(2)]
    compare.assertProtoEqual(self, pb1, pb2)

  def testNormalizesFloat(self):
    as_float = compare_test_pb2.Large(double_=4.0)
    as_int = compare_test_pb2.Large(double_=4)
    compare.assertProtoEqual(self, as_float, as_int, normalize_numbers=True)

  def testPrimitives(self):
    self.assertAll('string_: "x"')
    expected_diff = """
                    - string_: "x"
                    ?           ^
                    + string_: "y"
                    ?           ^
                    """
    self.assertNone('string_: "x"', 'string_: "y"', expected_diff)

  def testRepeatedPrimitives(self):
    self.assertAll('int64s: 0 int64s: 1')

    # Both the order and the multiplicity of repeated elements matter.
    unequal_pairs = (
        ('int64s: 0 int64s: 1', 'int64s: 1 int64s: 0'),
        ('int64s: 0 int64s: 1 int64s: 2', 'int64s: 2 int64s: 1 int64s: 0'),
        ('int64s: 0', 'int64s: 0 int64s: 0'),
        ('int64s: 0 int64s: 1', 'int64s: 1 int64s: 0 int64s: 1'),
    )
    for lhs, rhs in unequal_pairs:
      self.assertSameNotEqual(lhs, rhs)

    appended_diff = """
                      int64s: 0
                    + int64s: 2
                    """
    self.assertNone('int64s: 0', 'int64s: 0 int64s: 2', appended_diff)

    changed_diff = """
                      int64s: 0
                    - int64s: 1
                    ?         ^
                    + int64s: 2
                    ?         ^
                    """
    self.assertNone('int64s: 0 int64s: 1', 'int64s: 0 int64s: 2',
                    changed_diff)

  def testMessage(self):
    self_equal = (
        'medium: {}',
        'medium: { smalls: {} }',
        'medium: { int32s: 1 smalls: {} }',
        'medium: { smalls: { strings: "x" } }',
        'medium: { smalls: { strings: "x" } } small: { strings: "y" }',
    )
    for text in self_equal:
      self.assertAll(text)

    unequal_pairs = (
        ('medium: { smalls: { strings: "x" strings: "y" } }',
         'medium: { smalls: { strings: "y" strings: "x" } }'),
        ('medium: { smalls: { strings: "x" } smalls: { strings: "y" } }',
         'medium: { smalls: { strings: "y" } smalls: { strings: "x" } }'),
        ('medium: { smalls: { strings: "x" strings: "y" strings: "x" } }',
         'medium: { smalls: { strings: "y" strings: "x" } }'),
        ('medium: { smalls: { strings: "x" } int32s: 0 }',
         'medium: { int32s: 0 smalls: { strings: "x" } int32s: 0 }'),
    )
    for lhs, rhs in unequal_pairs:
      self.assertSameNotEqual(lhs, rhs)

    added_diff = """
                      medium {
                    +   smalls {
                    +     strings: "x"
                    +   }
                      }
                    """
    self.assertNone('medium: {}', 'medium: { smalls: { strings: "x" } }',
                    added_diff)

    removed_diff = """
                      medium {
                        smalls {
                    -     strings: "x"
                        }
                      }
                    """
    self.assertNone('medium: { smalls: { strings: "x" } }',
                    'medium: { smalls: {} }', removed_diff)

    changed_diff = """
                      medium {
                    -   int32s: 0
                    ?           ^
                    +   int32s: 1
                    ?           ^
                      }
                    """
    self.assertNone('medium: { int32s: 0 }', 'medium: { int32s: 1 }',
                    changed_diff)

  def testMsgPassdown(self):
    with self.assertRaisesRegex(AssertionError, 'test message passed down'):
      self.assertProtoEqual(
          'medium: {}',
          'medium: { smalls: { strings: "x" } }',
          msg='test message passed down')

  def testRepeatedMessage(self):
    self_equal = (
        'medium: { smalls: {} smalls: {} }',
        'medium: { smalls: { strings: "x" } } medium: {}',
        'medium: { smalls: { strings: "x" } } medium: { int32s: 0 }',
        'medium: { smalls: {} smalls: { strings: "x" } } small: {}',
    )
    for text in self_equal:
      self.assertAll(text)

    unequal_pairs = (
        ('medium: { smalls: { strings: "x" } smalls: {} }',
         'medium: { smalls: {} smalls: { strings: "x" } }'),
        ('medium: { smalls: {} }',
         'medium: { smalls: {} smalls: {} }'),
        ('medium: { smalls: {} smalls: {} } medium: {}',
         'medium: {} medium: {} medium: { smalls: {} }'),
        ('medium: { smalls: { strings: "x" } smalls: {} }',
         'medium: { smalls: {} smalls: { strings: "x" } smalls: {} }'),
    )
    for lhs, rhs in unequal_pairs:
      self.assertSameNotEqual(lhs, rhs)

    added_diff = """
                      medium {
                    +   smalls {
                    +   }
                      }
                    """
    self.assertNone('medium: {}', 'medium: {} medium { smalls: {} }',
                    added_diff)

    changed_diff = """
                      medium {
                        smalls {
                        }
                        smalls {
                    -     strings: "x"
                    ?               ^
                    +     strings: "y"
                    ?               ^
                        }
                      }
                    """
    self.assertNone('medium: { smalls: {} smalls: { strings: "x" } }',
                    'medium: { smalls: {} smalls: { strings: "y" } }',
                    changed_diff)
474 
475 
class MixinTests(compare.ProtoAssertions, googletest.TestCase):
  """Exercises assertProtoEqual() as provided via compare.ProtoAssertions."""

  def testAssertEqualWithStringArg(self):
    # The expected side is given as text; the actual side is a real message.
    expected_text = """
          string_: 'abc'
          float_: 1.234
        """
    actual = compare_test_pb2.Large(string_='abc', float_=1.234)
    self.assertProtoEqual(expected_text, actual)
486 
487 
# Run this module's tests through TensorFlow's googletest entry point when the
# file is executed directly as a script.
if __name__ == '__main__':
  googletest.main()
490