1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for logical module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.autograph.operators import logical
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import test_util
24from tensorflow.python.platform import test
25
26
class LogicalOperatorsTest(test.TestCase):
  """Tests for autograph's `logical.and_`, `logical.or_` and `logical.not_`.

  `and_`/`or_` receive their operands as zero-argument callables so that the
  right-hand side can be evaluated lazily; the tests below check both the
  resulting value and (for Python values) that short-circuiting happens.
  """

  def assertNotCalled(self):
    """Fails the test if invoked.

    Passed as the right-hand operand when the left operand alone decides the
    result, to verify that the right operand is never evaluated.
    """
    self.fail('this should not be called')

  def _tf_true(self):
    """Returns a scalar boolean tensor holding True."""
    return constant_op.constant(True)

  def _tf_false(self):
    """Returns a scalar boolean tensor holding False."""
    return constant_op.constant(False)

  def test_and_python(self):
    self.assertTrue(logical.and_(lambda: True, lambda: True))
    self.assertTrue(logical.and_(lambda: [1], lambda: True))
    # Like Python's `and`, the value of the deciding operand is returned.
    self.assertListEqual(logical.and_(lambda: True, lambda: [1]), [1])

    self.assertFalse(logical.and_(lambda: True, lambda: False))
    self.assertFalse(logical.and_(lambda: False, lambda: True))
    # A falsy left operand must short-circuit the right one.
    self.assertFalse(logical.and_(lambda: False, self.assertNotCalled))

  @test_util.run_deprecated_v1
  def test_and_tf(self):
    with self.cached_session():
      t = logical.and_(self._tf_true, self._tf_true)
      self.assertEqual(self.evaluate(t), True)
      t = logical.and_(self._tf_true, lambda: True)
      self.assertEqual(self.evaluate(t), True)
      # A False right operand must make the conjunction False.
      t = logical.and_(self._tf_true, self._tf_false)
      self.assertEqual(self.evaluate(t), False)
      t = logical.and_(self._tf_false, lambda: True)
      self.assertEqual(self.evaluate(t), False)
      # TODO(mdan): Add a test for ops with side effects.

  def test_or_python(self):
    self.assertFalse(logical.or_(lambda: False, lambda: False))
    self.assertFalse(logical.or_(lambda: [], lambda: False))
    # Like Python's `or`, the value of the deciding operand is returned.
    self.assertListEqual(logical.or_(lambda: False, lambda: [1]), [1])

    self.assertTrue(logical.or_(lambda: False, lambda: True))
    # A truthy left operand must short-circuit the right one.
    self.assertTrue(logical.or_(lambda: True, self.assertNotCalled))

  @test_util.run_deprecated_v1
  def test_or_tf(self):
    with self.cached_session():
      # Without a False case, an `or_` that always yields True would pass.
      t = logical.or_(self._tf_false, self._tf_false)
      self.assertEqual(self.evaluate(t), False)
      t = logical.or_(self._tf_false, self._tf_true)
      self.assertEqual(self.evaluate(t), True)
      t = logical.or_(self._tf_false, lambda: True)
      self.assertEqual(self.evaluate(t), True)
      t = logical.or_(self._tf_true, lambda: True)
      self.assertEqual(self.evaluate(t), True)
      # TODO(mdan): Add a test for ops with side effects.

  def test_not_python(self):
    self.assertFalse(logical.not_(True))
    self.assertTrue(logical.not_(False))
    self.assertFalse(logical.not_([1]))
    self.assertTrue(logical.not_([]))

  def test_not_tf(self):
    with self.cached_session():
      t = logical.not_(self._tf_false())
      self.assertEqual(self.evaluate(t), True)
      t = logical.not_(self._tf_true())
      self.assertEqual(self.evaluate(t), False)
85
86
if __name__ == '__main__':
  # Executed directly (not imported): hand control to TensorFlow's test main.
  test.main()
89