# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for logical module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.autograph.operators import logical
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


class LogicalOperatorsTest(test.TestCase):
  """Tests for `logical.and_`, `logical.or_` and `logical.not_`.

  `and_` and `or_` receive their operands as zero-argument callables, so the
  second operand can be skipped when the first already decides the result.
  The Python-value tests check this by passing `assertNotCalled` as the
  second operand.
  """

  def assertNotCalled(self):
    """Fails the test if invoked; used as an operand that must not run."""
    self.fail('this should not be called')

  def _tf_true(self):
    """Returns a boolean constant tensor holding True."""
    return constant_op.constant(True)

  def _tf_false(self):
    """Returns a boolean constant tensor holding False."""
    return constant_op.constant(False)

  def test_and_python(self):
    """`and_` on Python values returns the second value when the first is truthy."""
    self.assertTrue(logical.and_(lambda: True, lambda: True))
    self.assertTrue(logical.and_(lambda: [1], lambda: True))
    self.assertListEqual(logical.and_(lambda: True, lambda: [1]), [1])

    self.assertFalse(logical.and_(lambda: False, lambda: True))
    # A falsy first operand must short-circuit: the second is never called.
    self.assertFalse(logical.and_(lambda: False, self.assertNotCalled))

  @test_util.run_deprecated_v1
  def test_and_tf(self):
    """`and_` with a tensor first operand yields a tensor with the logical AND."""
    with self.cached_session():
      t = logical.and_(self._tf_true, self._tf_true)
      self.assertEqual(self.evaluate(t), True)
      t = logical.and_(self._tf_true, lambda: True)
      self.assertEqual(self.evaluate(t), True)
      t = logical.and_(self._tf_false, lambda: True)
      self.assertEqual(self.evaluate(t), False)
      # A true first operand defers to the second; check a False result too.
      t = logical.and_(self._tf_true, self._tf_false)
      self.assertEqual(self.evaluate(t), False)
      # TODO(mdan): Add a test for ops with side effects.

  def test_or_python(self):
    """`or_` on Python values returns the second value when the first is falsy."""
    self.assertFalse(logical.or_(lambda: False, lambda: False))
    self.assertFalse(logical.or_(lambda: [], lambda: False))
    self.assertListEqual(logical.or_(lambda: False, lambda: [1]), [1])

    self.assertTrue(logical.or_(lambda: False, lambda: True))
    # A truthy first operand must short-circuit: the second is never called.
    self.assertTrue(logical.or_(lambda: True, self.assertNotCalled))

  @test_util.run_deprecated_v1
  def test_or_tf(self):
    """`or_` with a tensor first operand yields a tensor with the logical OR."""
    with self.cached_session():
      t = logical.or_(self._tf_false, self._tf_true)
      self.assertEqual(self.evaluate(t), True)
      t = logical.or_(self._tf_false, lambda: True)
      self.assertEqual(self.evaluate(t), True)
      t = logical.or_(self._tf_true, lambda: True)
      self.assertEqual(self.evaluate(t), True)
      # Without this case, no assertion checks that `or_` can produce False.
      t = logical.or_(self._tf_false, self._tf_false)
      self.assertEqual(self.evaluate(t), False)
      # TODO(mdan): Add a test for ops with side effects.

  def test_not_python(self):
    """`not_` on Python values follows Python truthiness."""
    self.assertFalse(logical.not_(True))
    self.assertFalse(logical.not_([1]))
    self.assertTrue(logical.not_([]))

  def test_not_tf(self):
    """`not_` on a boolean tensor yields its negation."""
    with self.cached_session():
      t = logical.not_(self._tf_false())
      self.assertEqual(self.evaluate(t), True)
      t = logical.not_(self._tf_true())
      self.assertEqual(self.evaluate(t), False)


if __name__ == '__main__':
  test.main()