1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Keras metrics functions.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.python.eager import context 24from tensorflow.python.keras import backend as K 25from tensorflow.python.keras import metrics 26from tensorflow.python.platform import test 27 28 29class KerasFunctionalMetricsTest(test.TestCase): 30 31 def test_metrics(self): 32 with self.cached_session(): 33 y_a = K.variable(np.random.random((6, 7))) 34 y_b = K.variable(np.random.random((6, 7))) 35 for metric in [metrics.binary_accuracy, metrics.categorical_accuracy]: 36 output = metric(y_a, y_b) 37 self.assertEqual(K.eval(output).shape, (6,)) 38 39 def test_sparse_categorical_accuracy_int(self): 40 with self.cached_session(): 41 metric = metrics.sparse_categorical_accuracy 42 y_true = K.variable(np.random.randint(0, 7, (6,))) 43 y_pred = K.variable(np.random.random((6, 7))) 44 self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,)) 45 46 # Test correctness if the shape of y_true is (num_samples,) 47 y_true = K.variable([1., 0., 0., 0.]) 48 y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]]) 49 print(K.eval(metric(y_true, y_pred))) 50 self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.]) 51 52 # Test correctness if the shape of y_true is (num_samples, 1) 53 y_true = K.variable([[1.], [0.], [0.], [0.]]) 54 y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]]) 55 print(K.eval(metric(y_true, y_pred))) 56 self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.]) 57 58 def test_sparse_categorical_accuracy_float(self): 59 with self.cached_session(): 60 metric = metrics.sparse_categorical_accuracy 61 y_true = K.variable(np.random.random((6,))) 62 y_pred = K.variable(np.random.random((6, 7))) 63 self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,)) 64 65 def test_sparse_categorical_accuracy_eager(self): 66 """Tests that ints passed in via Eager return results. See b/113504761.""" 67 with context.eager_mode(): 68 metric = metrics.sparse_categorical_accuracy 69 y_true = np.arange(6).reshape([6, 1]) 70 y_pred = np.arange(36).reshape([6, 6]) 71 self.assertAllEqual(metric(y_true, y_pred), [0., 0., 0., 0., 0., 1.]) 72 73 def test_sparse_categorical_accuracy_float_eager(self): 74 """Tests that floats passed in via Eager return results. See b/113504761.""" 75 with context.eager_mode(): 76 metric = metrics.sparse_categorical_accuracy 77 y_true = np.arange(6, dtype=np.float32).reshape([6, 1]) 78 y_pred = np.arange(36).reshape([6, 6]) 79 self.assertAllEqual(metric(y_true, y_pred), [0., 0., 0., 0., 0., 1.]) 80 81 def test_sparse_top_k_categorical_accuracy(self): 82 with self.cached_session(): 83 # Test correctness if the shape of y_true is (num_samples, 1) 84 y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) 85 y_true = K.variable(np.array([[1], [0]])) 86 result = K.eval( 87 metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) 88 self.assertEqual(result, 1) 89 result = K.eval( 90 metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) 91 self.assertEqual(result, 0.5) 92 result = K.eval( 93 metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) 94 self.assertEqual(result, 0.) 95 96 # Test correctness if the shape of y_true is (num_samples,) 97 y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) 98 y_true = K.variable(np.array([1, 0])) 99 result = K.eval( 100 metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3)) 101 self.assertEqual(result, 1) 102 result = K.eval( 103 metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2)) 104 self.assertEqual(result, 0.5) 105 result = K.eval( 106 metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1)) 107 self.assertEqual(result, 0.) 108 109 def test_top_k_categorical_accuracy(self): 110 with self.cached_session(): 111 y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]])) 112 y_true = K.variable(np.array([[0, 1, 0], [1, 0, 0]])) 113 result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=3)) 114 self.assertEqual(result, 1) 115 result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=2)) 116 self.assertEqual(result, 0.5) 117 result = K.eval(metrics.top_k_categorical_accuracy(y_true, y_pred, k=1)) 118 self.assertEqual(result, 0.) 119 120 121if __name__ == '__main__': 122 test.main() 123