1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for special_functions module."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.autograph.lang import special_functions
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import tensor_util
27from tensorflow.python.ops import list_ops
28from tensorflow.python.platform import test
29
30
class SpecialFunctionsTest(test.TestCase):
  """Tests for the helpers exposed by `autograph.lang.special_functions`."""

  def test_match_staging_level(self):
    """match_staging_level follows the staging level of its second argument."""
    some_tensor = constant_op.constant(0)
    tensor_one = special_functions.match_staging_level(1, some_tensor)
    python_one = special_functions.match_staging_level(1, 1)
    with self.cached_session():
      # Matched against a tensor, the value must become a TF type.
      self.assertTrue(tensor_util.is_tf_type(tensor_one))
      self.assertAllEqual(self.evaluate(tensor_one), 1)
      # Matched against a plain Python value, it must stay a Python value.
      self.assertFalse(tensor_util.is_tf_type(python_one))
      self.assertEqual(python_one, 1)

  def test_tensor_list_empty_list(self):
    """Empty list/tuple initializers yield an empty list given dtype/shape."""
    for initializer in ([], ()):
      with self.subTest(initializer_type=type(initializer).__name__):
        l = special_functions.tensor_list(
            initializer, element_dtype=dtypes.int32, element_shape=())
        sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
        with self.cached_session():
          self.assertAllEqual(self.evaluate(sl), [])

  def test_tensor_list_tensor(self):
    """A tensor initializer is accepted without an explicit element_dtype."""
    l = special_functions.tensor_list(
        constant_op.constant([], dtype=dtypes.int32))
    sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    with self.cached_session():
      self.assertAllEqual(self.evaluate(sl), [])

  def test_tensor_list_unsupported_initializer(self):
    """A NumPy array initializer is rejected with a ValueError."""
    with self.assertRaisesRegex(ValueError, 'unknown type'):
      special_functions.tensor_list(np.array([1, 2, 3]))

  def test_tensor_list_empty_list_no_type(self):
    """An empty initializer without dtype/shape cannot be typed and fails."""
    with self.assertRaisesRegex(ValueError,
                                'element_dtype and element_shape are required'):
      special_functions.tensor_list([])

  def test_tensor_list_from_elements(self):
    """A list of tensors is converted to a TensorList holding them in order."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

    l = special_functions.tensor_list(elements)
    sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
    with self.cached_session():
      self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])

  def test_tensor_list_array_from_elements(self):
    """With use_tensor_array=True, the result supports TensorArray.stack()."""
    elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]

    l = special_functions.tensor_list(elements, use_tensor_array=True)
    sl = l.stack()
    with self.cached_session():
      self.assertAllEqual(self.evaluate(sl), [[1, 2], [3, 4]])

  def test_stack(self):
    """stack leaves Python values alone when strict=False; stacks TensorLists."""
    self.assertEqual(special_functions.stack(1, strict=False), 1)
    self.assertListEqual(
        special_functions.stack([1, 2, 3], strict=False), [1, 2, 3])
    # TODO(mdan): This should probably forward to tf.stack.
    self.assertIsInstance(
        special_functions.stack(
            [constant_op.constant(1), constant_op.constant(2)], strict=False),
        list)

    # Without strict=False, stacking a plain Python list is an error.
    with self.assertRaises(ValueError):
      special_functions.stack([1, 2, 3])

    t = constant_op.constant([1.0, 2.0])
    l = list_ops.tensor_list_from_tensor(
        t, element_shape=constant_op.constant([], dtype=dtypes.int32))
    self.assertTrue(
        tensor_util.is_tf_type(
            special_functions.stack(l, element_dtype=dtypes.float32)))
109
110
if __name__ == '__main__':
  # Run this module's tests via tensorflow.python.platform.test when the file
  # is executed directly as a script.
  test.main()
113