1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for Keras metrics serialization.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import shutil 23 24from absl.testing import parameterized 25import numpy as np 26 27from tensorflow.python import keras 28from tensorflow.python.keras import keras_parameterized 29from tensorflow.python.keras import layers 30from tensorflow.python.keras import metrics 31from tensorflow.python.keras import optimizer_v2 32from tensorflow.python.keras import testing_utils 33from tensorflow.python.keras.utils import generic_utils 34from tensorflow.python.ops import math_ops 35from tensorflow.python.platform import test 36from tensorflow.python.util import nest 37 38try: 39 import h5py # pylint:disable=g-import-not-at-top 40except ImportError: 41 h5py = None 42 43 44# Custom metric 45class MyMeanAbsoluteError(metrics.MeanMetricWrapper): 46 47 def __init__(self, name='my_mae', dtype=None): 48 super(MyMeanAbsoluteError, self).__init__(_my_mae, name, dtype=dtype) 49 50 51# Custom metric function 52def _my_mae(y_true, y_pred): 53 return keras.backend.mean(math_ops.abs(y_pred - y_true), axis=-1) 54 55 56def _get_multi_io_model(): 57 inp_1 = layers.Input(shape=(1,), name='input_1') 58 inp_2 = layers.Input(shape=(1,), name='input_2') 59 d = testing_utils.Bias(name='output') 60 out_1 = d(inp_1) 61 out_2 = d(inp_2) 62 return keras.Model([inp_1, inp_2], [out_1, out_2]) 63 64 65@keras_parameterized.run_all_keras_modes 66@parameterized.named_parameters( 67 dict(testcase_name='string', value=['mae']), 68 dict(testcase_name='built_in_fn', value=[metrics.mae]), 69 dict(testcase_name='built_in_class', value=[metrics.MeanAbsoluteError]), 70 dict(testcase_name='custom_fn', value=[_my_mae]), 71 dict(testcase_name='custom_class', value=[MyMeanAbsoluteError]), 72 dict( 73 testcase_name='list_of_built_in_fn_and_list', 74 value=[metrics.mae, [metrics.mae]]), 75 dict( 76 testcase_name='list_of_built_in_class_and_list', 77 value=[metrics.MeanAbsoluteError, [metrics.MeanAbsoluteError]]), 78 dict( 79 testcase_name='list_of_custom_fn_and_list', value=[_my_mae, [_my_mae]]), 80 dict( 81 testcase_name='list_of_custom_class_and_list', 82 value=[MyMeanAbsoluteError, [MyMeanAbsoluteError]]), 83 dict( 84 testcase_name='list_of_lists_of_custom_fns', 85 value=[[_my_mae], [_my_mae, 'mae']]), 86 dict( 87 testcase_name='list_of_lists_of_custom_classes', 88 value=[[MyMeanAbsoluteError], [MyMeanAbsoluteError, 'mae']]), 89 dict( 90 testcase_name='dict_of_list_of_string', 91 value={ 92 'output': ['mae'], 93 'output_1': ['mae'], 94 }), 95 dict( 96 testcase_name='dict_of_list_of_built_in_fn', 97 value={ 98 'output': [metrics.mae], 99 'output_1': [metrics.mae], 100 }), 101 dict( 102 testcase_name='dict_of_list_of_built_in_class', 103 value={ 104 'output': [metrics.MeanAbsoluteError], 105 'output_1': [metrics.MeanAbsoluteError], 106 }), 107 dict( 108 testcase_name='dict_of_list_of_custom_fn', 109 value={ 110 'output': [_my_mae], 111 'output_1': [_my_mae], 112 }), 113 dict( 114 testcase_name='dict_of_list_of_custom_class', 115 value={ 116 'output': [MyMeanAbsoluteError], 117 'output_1': [MyMeanAbsoluteError], 118 }), 119 dict( 120 testcase_name='dict_of_string', 121 value={ 122 'output': 'mae', 123 'output_1': 'mae', 124 }), 125 dict( 126 testcase_name='dict_of_built_in_fn', 127 value={ 128 'output': metrics.mae, 129 'output_1': metrics.mae, 130 }), 131 dict( 132 testcase_name='dict_of_built_in_class', 133 value={ 134 'output': metrics.MeanAbsoluteError, 135 'output_1': metrics.MeanAbsoluteError, 136 }), 137 dict( 138 testcase_name='dict_of_custom_fn', 139 value={ 140 'output': _my_mae, 141 'output_1': _my_mae 142 }), 143 dict( 144 testcase_name='dict_of_custom_class', 145 value={ 146 'output': MyMeanAbsoluteError, 147 'output_1': MyMeanAbsoluteError, 148 }), 149) 150class MetricsSerialization(keras_parameterized.TestCase): 151 152 def setUp(self): 153 super(MetricsSerialization, self).setUp() 154 tmpdir = self.get_temp_dir() 155 self.addCleanup(shutil.rmtree, tmpdir) 156 self.model_filename = os.path.join(tmpdir, 'tmp_model_metric.h5') 157 self.x = np.array([[0.], [1.], [2.]], dtype='float32') 158 self.y = np.array([[0.5], [2.], [3.5]], dtype='float32') 159 self.w = np.array([1.25, 0.5, 1.25], dtype='float32') 160 161 def test_serializing_model_with_metric_with_custom_object_scope(self, value): 162 163 def get_instance(x): 164 if isinstance(x, str): 165 return x 166 if isinstance(x, type) and issubclass(x, metrics.Metric): 167 return x() 168 return x 169 170 metric_input = nest.map_structure(get_instance, value) 171 weighted_metric_input = nest.map_structure(get_instance, value) 172 173 with generic_utils.custom_object_scope({ 174 'MyMeanAbsoluteError': MyMeanAbsoluteError, 175 '_my_mae': _my_mae, 176 'Bias': testing_utils.Bias, 177 }): 178 model = _get_multi_io_model() 179 model.compile( 180 optimizer_v2.gradient_descent.SGD(0.1), 181 'mae', 182 metrics=metric_input, 183 weighted_metrics=weighted_metric_input, 184 run_eagerly=testing_utils.should_run_eagerly()) 185 history = model.fit([self.x, self.x], [self.y, self.y], 186 batch_size=3, 187 epochs=3, 188 sample_weight=[self.w, self.w]) 189 190 # Assert training. 191 self.assertAllClose(history.history['loss'], [2., 1.6, 1.2], 1e-3) 192 eval_results = model.evaluate([self.x, self.x], [self.y, self.y], 193 sample_weight=[self.w, self.w]) 194 195 if h5py is None: 196 return 197 model.save(self.model_filename) 198 loaded_model = keras.models.load_model(self.model_filename) 199 loaded_model.predict([self.x, self.x]) 200 loaded_eval_results = loaded_model.evaluate( 201 [self.x, self.x], [self.y, self.y], sample_weight=[self.w, self.w]) 202 203 # Assert all evaluation results are the same. 204 self.assertAllClose(eval_results, loaded_eval_results, 1e-9) 205 206 def test_serializing_model_with_metric_with_custom_objects(self, value): 207 208 def get_instance(x): 209 if isinstance(x, str): 210 return x 211 if isinstance(x, type) and issubclass(x, metrics.Metric): 212 return x() 213 return x 214 215 metric_input = nest.map_structure(get_instance, value) 216 weighted_metric_input = nest.map_structure(get_instance, value) 217 218 model = _get_multi_io_model() 219 model.compile( 220 optimizer_v2.gradient_descent.SGD(0.1), 221 'mae', 222 metrics=metric_input, 223 weighted_metrics=weighted_metric_input, 224 run_eagerly=testing_utils.should_run_eagerly()) 225 history = model.fit([self.x, self.x], [self.y, self.y], 226 batch_size=3, 227 epochs=3, 228 sample_weight=[self.w, self.w]) 229 230 # Assert training. 231 self.assertAllClose(history.history['loss'], [2., 1.6, 1.2], 1e-3) 232 eval_results = model.evaluate([self.x, self.x], [self.y, self.y], 233 sample_weight=[self.w, self.w]) 234 235 if h5py is None: 236 return 237 model.save(self.model_filename) 238 loaded_model = keras.models.load_model( 239 self.model_filename, 240 custom_objects={ 241 'MyMeanAbsoluteError': MyMeanAbsoluteError, 242 '_my_mae': _my_mae, 243 'Bias': testing_utils.Bias, 244 }) 245 loaded_model.predict([self.x, self.x]) 246 loaded_eval_results = loaded_model.evaluate([self.x, self.x], 247 [self.y, self.y], 248 sample_weight=[self.w, self.w]) 249 250 # Assert all evaluation results are the same. 251 self.assertAllClose(eval_results, loaded_eval_results, 1e-9) 252 253 254if __name__ == '__main__': 255 test.main() 256