1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tf upgrader."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import tensorflow.compat.v1 as tf
22from tensorflow.python.framework import test_util
23from tensorflow.python.platform import test as test_lib
24
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably read by external tooling that consumes this file — confirm
# before renaming or removing.
_TEST_VERSION = 1
26
27
class TestUpgrade(test_util.TensorFlowTestCase):
  """Test various APIs that have been changed in 2.0.

  NOTE(review): the test bodies intentionally call API spellings that changed
  in 2.0 (e.g. `tf.rsqrt`, the `dimension=` keyword of `tf.argmax` /
  `tf.argmin`, `tf.nn.softmax_cross_entropy_with_logits`,
  `tf.initializers.uniform_unit_scaling`). Presumably these are the inputs an
  upgrade tool is expected to handle, so do not modernize them by hand —
  confirm with the owning tooling before changing any call.
  """

  @classmethod
  def setUpClass(cls):
    """Runs base-class class setup, then records the loaded TF API flavor."""
    super(TestUpgrade, cls).setUpClass()
    # 1 when the imported `tf` module still exposes `contrib`, 2 otherwise.
    # Not read by any test visible in this file.
    cls._tf_api_version = 1 if hasattr(tf, 'contrib') else 2

  def setUp(self):
    """Enables v2 behavior, then runs TensorFlowTestCase's per-test setup."""
    # Switch modes first so the base-class setup (which presumably resets
    # per-test state such as RNG seeds and cached sessions — see
    # test_util.TensorFlowTestCase.setUp) runs in the same mode the test
    # bodies use. Several tests call `.numpy()` on op results, which
    # presumably relies on the eager execution that v2 behavior enables.
    tf.compat.v1.enable_v2_behavior()
    super(TestUpgrade, self).setUp()

  def testRenames(self):
    """Checks acos(0.5) == pi/3 and rsqrt(4.0) == 0.5."""
    self.assertAllClose(1.04719755, tf.acos(0.5))
    self.assertAllClose(0.5, tf.rsqrt(4.0))

  def testSerializeSparseTensor(self):
    """Serializes a dense_shape=[2] SparseTensor holding one value.

    Expects a 3-element result whose first element is non-empty.
    """
    sp_input = tf.SparseTensor(
        indices=tf.constant([[1]], dtype=tf.int64),
        values=tf.constant([2], dtype=tf.int64),
        dense_shape=[2])

    with self.cached_session():
      serialized_sp = tf.serialize_sparse(sp_input, 'serialize_name', tf.string)
      self.assertEqual((3,), serialized_sp.shape)
      self.assertTrue(serialized_sp[0].numpy())  # check non-empty

  def testSerializeManySparse(self):
    """Serializes a dense_shape=[1, 2] SparseTensor; expects shape (1, 3)."""
    sp_input = tf.SparseTensor(
        indices=tf.constant([[0, 1]], dtype=tf.int64),
        values=tf.constant([2], dtype=tf.int64),
        dense_shape=[1, 2])

    with self.cached_session():
      serialized_sp = tf.serialize_many_sparse(
          sp_input, 'serialize_name', tf.string)
      self.assertEqual((1, 3), serialized_sp.shape)

  def testArgMaxMin(self):
    """Checks argmax over dims 1 and 0, and argmin over dim 1.

    The reduction axis is passed through the `dimension=` keyword.
    """
    self.assertAllClose(
        [1],
        tf.argmax([[1, 3, 2]], name='abc', dimension=1))
    self.assertAllClose(
        [0, 0, 0],
        tf.argmax([[1, 3, 2]], dimension=0))
    self.assertAllClose(
        [0],
        tf.argmin([[1, 3, 2]], name='abc', dimension=1))

  def testSoftmaxCrossEntropyWithLogits(self):
    """Checks both softmax cross-entropy variants on the same inputs.

    For logits [0.1, 0.8] and one-hot labels [0, 1] the loss is
    log(1 + exp(-0.7)) ~= 0.40318608.
    """
    out = tf.nn.softmax_cross_entropy_with_logits(
        logits=[0.1, 0.8], labels=[0, 1])
    self.assertAllClose(out, 0.40318608)
    out = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=[0.1, 0.8], labels=[0, 1])
    self.assertAllClose(out, 0.40318608)

  def testUniformUnitScalingInitializer(self):
    """Checks the seeded uniform_unit_scaling initializer for shape (2,).

    The expected values depend on the random seeds in effect; `seed=1` fixes
    the op-level seed.
    """
    init = tf.initializers.uniform_unit_scaling(0.5, seed=1)
    self.assertArrayNear(
        [-0.45200047, 0.72815341],
        init((2,)).numpy(),
        err=1e-6)
89
90
if __name__ == "__main__":
  # Only when executed as a script: hand off to the TensorFlow test runner.
  test_lib.main()
93