1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Keras metrics functions."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.eager import context
24from tensorflow.python.keras import backend as K
25from tensorflow.python.keras import metrics
26from tensorflow.python.platform import test
27
28
class KerasFunctionalMetricsTest(test.TestCase):
  """Tests for the function-style accuracy metrics in `keras.metrics`."""

  def test_metrics(self):
    """Checks that element-wise accuracy metrics return one value per sample."""
    with self.cached_session():
      y_a = K.variable(np.random.random((6, 7)))
      y_b = K.variable(np.random.random((6, 7)))
      for metric in [metrics.binary_accuracy, metrics.categorical_accuracy]:
        output = metric(y_a, y_b)
        self.assertEqual(K.eval(output).shape, (6,))

  def test_sparse_categorical_accuracy_int(self):
    """Checks sparse_categorical_accuracy output shape and per-sample values."""
    with self.cached_session():
      metric = metrics.sparse_categorical_accuracy
      y_true = K.variable(np.random.randint(0, 7, (6,)))
      y_pred = K.variable(np.random.random((6, 7)))
      self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,))

      # Every row of y_pred has its highest score at class 0, so only the
      # first sample (label 1) is a miss.
      y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])
      expected = [0., 1., 1., 1.]

      # Test correctness if the shape of y_true is (num_samples,)
      y_true = K.variable([1., 0., 0., 0.])
      self.assertAllEqual(K.eval(metric(y_true, y_pred)), expected)

      # Test correctness if the shape of y_true is (num_samples, 1)
      y_true = K.variable([[1.], [0.], [0.], [0.]])
      self.assertAllEqual(K.eval(metric(y_true, y_pred)), expected)

  def test_sparse_categorical_accuracy_float(self):
    """Checks that float-typed labels still yield one value per sample."""
    with self.cached_session():
      metric = metrics.sparse_categorical_accuracy
      y_true = K.variable(np.random.random((6,)))
      y_pred = K.variable(np.random.random((6, 7)))
      self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,))

  def test_sparse_categorical_accuracy_eager(self):
    """Tests that ints passed in via Eager return results. See b/113504761."""
    with context.eager_mode():
      metric = metrics.sparse_categorical_accuracy
      y_true = np.arange(6).reshape([6, 1])
      # Each row of arange(36).reshape(6, 6) is strictly increasing, so its
      # argmax is 5; only the sample labelled 5 counts as correct.
      y_pred = np.arange(36).reshape([6, 6])
      self.assertAllEqual(metric(y_true, y_pred), [0., 0., 0., 0., 0., 1.])

  def test_sparse_categorical_accuracy_float_eager(self):
    """Tests that floats passed in via Eager return results. See b/113504761."""
    with context.eager_mode():
      metric = metrics.sparse_categorical_accuracy
      y_true = np.arange(6, dtype=np.float32).reshape([6, 1])
      # Same layout as the int test: every row's argmax is 5.
      y_pred = np.arange(36).reshape([6, 6])
      self.assertAllEqual(metric(y_true, y_pred), [0., 0., 0., 0., 0., 1.])

  def test_sparse_top_k_categorical_accuracy(self):
    """Checks sparse_top_k_categorical_accuracy for k = 3, 2, 1."""
    with self.cached_session():
      y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
      # Sample 0's label (1) has the 2nd-highest score and sample 1's label
      # (0) has the 3rd-highest, so accuracy drops from 1 to 0.5 to 0 as k
      # shrinks from 3 to 1.
      label_variants = (
          np.array([[1], [0]]),  # y_true shape (num_samples, 1)
          np.array([1, 0]),  # y_true shape (num_samples,)
      )
      for labels in label_variants:
        y_true = K.variable(labels)
        for k, expected in ((3, 1.), (2, 0.5), (1, 0.)):
          result = K.eval(
              metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=k))
          self.assertEqual(
              result, expected,
              msg='k=%d, y_true shape=%s' % (k, labels.shape))

  def test_top_k_categorical_accuracy(self):
    """Checks top_k_categorical_accuracy with one-hot labels for k = 3, 2, 1."""
    with self.cached_session():
      y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
      # One-hot encoding of labels [1, 0]; see the sparse variant for why
      # the expected accuracies are 1, 0.5 and 0.
      y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]]))
      for k, expected in ((3, 1.), (2, 0.5), (1, 0.)):
        result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=k))
        self.assertEqual(result, expected, msg='k=%d' % k)
119
120
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's test entry point.
  test.main()
123