1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for Keras metrics serialization."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import shutil
23
24from absl.testing import parameterized
25import numpy as np
26
27from tensorflow.python import keras
28from tensorflow.python.keras import keras_parameterized
29from tensorflow.python.keras import layers
30from tensorflow.python.keras import metrics
31from tensorflow.python.keras import optimizer_v2
32from tensorflow.python.keras import testing_utils
33from tensorflow.python.keras.utils import generic_utils
34from tensorflow.python.ops import math_ops
35from tensorflow.python.platform import test
36from tensorflow.python.util import nest
37
38try:
39  import h5py  # pylint:disable=g-import-not-at-top
40except ImportError:
41  h5py = None
42
43
44# Custom metric
class MyMeanAbsoluteError(metrics.MeanMetricWrapper):
  """Class-based custom metric wrapping `_my_mae` in a `MeanMetricWrapper`."""

  def __init__(self, name='my_mae', dtype=None):
    super(MyMeanAbsoluteError, self).__init__(
        _my_mae, name=name, dtype=dtype)
49
50
51# Custom metric function
def _my_mae(y_true, y_pred):
  """Returns the mean of `|y_pred - y_true|` taken over the last axis."""
  abs_error = math_ops.abs(y_pred - y_true)
  return keras.backend.mean(abs_error, axis=-1)
54
55
def _get_multi_io_model():
  """Builds a two-input, two-output model that shares one `Bias` layer.

  The same `testing_utils.Bias` instance (named 'output') is applied to each
  input in turn, and the two results are the model outputs.

  Returns:
    A `keras.Model` mapping `[input_1, input_2]` to the two `Bias` outputs.
  """
  model_inputs = [
      layers.Input(shape=(1,), name='input_1'),
      layers.Input(shape=(1,), name='input_2'),
  ]
  shared_bias = testing_utils.Bias(name='output')
  # Apply the shared layer in input order so the call order matches the inputs.
  model_outputs = [shared_bias(tensor) for tensor in model_inputs]
  return keras.Model(model_inputs, model_outputs)
63
64
@keras_parameterized.run_all_keras_modes
@parameterized.named_parameters(
    dict(testcase_name='string', value=['mae']),
    dict(testcase_name='built_in_fn', value=[metrics.mae]),
    dict(testcase_name='built_in_class', value=[metrics.MeanAbsoluteError]),
    dict(testcase_name='custom_fn', value=[_my_mae]),
    dict(testcase_name='custom_class', value=[MyMeanAbsoluteError]),
    dict(
        testcase_name='list_of_built_in_fn_and_list',
        value=[metrics.mae, [metrics.mae]]),
    dict(
        testcase_name='list_of_built_in_class_and_list',
        value=[metrics.MeanAbsoluteError, [metrics.MeanAbsoluteError]]),
    dict(
        testcase_name='list_of_custom_fn_and_list', value=[_my_mae, [_my_mae]]),
    dict(
        testcase_name='list_of_custom_class_and_list',
        value=[MyMeanAbsoluteError, [MyMeanAbsoluteError]]),
    dict(
        testcase_name='list_of_lists_of_custom_fns',
        value=[[_my_mae], [_my_mae, 'mae']]),
    dict(
        testcase_name='list_of_lists_of_custom_classes',
        value=[[MyMeanAbsoluteError], [MyMeanAbsoluteError, 'mae']]),
    dict(
        testcase_name='dict_of_list_of_string',
        value={
            'output': ['mae'],
            'output_1': ['mae'],
        }),
    dict(
        testcase_name='dict_of_list_of_built_in_fn',
        value={
            'output': [metrics.mae],
            'output_1': [metrics.mae],
        }),
    dict(
        testcase_name='dict_of_list_of_built_in_class',
        value={
            'output': [metrics.MeanAbsoluteError],
            'output_1': [metrics.MeanAbsoluteError],
        }),
    dict(
        testcase_name='dict_of_list_of_custom_fn',
        value={
            'output': [_my_mae],
            'output_1': [_my_mae],
        }),
    dict(
        testcase_name='dict_of_list_of_custom_class',
        value={
            'output': [MyMeanAbsoluteError],
            'output_1': [MyMeanAbsoluteError],
        }),
    dict(
        testcase_name='dict_of_string',
        value={
            'output': 'mae',
            'output_1': 'mae',
        }),
    dict(
        testcase_name='dict_of_built_in_fn',
        value={
            'output': metrics.mae,
            'output_1': metrics.mae,
        }),
    dict(
        testcase_name='dict_of_built_in_class',
        value={
            'output': metrics.MeanAbsoluteError,
            'output_1': metrics.MeanAbsoluteError,
        }),
    dict(
        testcase_name='dict_of_custom_fn',
        value={
            'output': _my_mae,
            'output_1': _my_mae
        }),
    dict(
        testcase_name='dict_of_custom_class',
        value={
            'output': MyMeanAbsoluteError,
            'output_1': MyMeanAbsoluteError,
        }),
)
class MetricsSerialization(keras_parameterized.TestCase):
  """Checks that compiled metrics give equal results after save and reload.

  Every parameterized case supplies `value`, a metrics specification (a
  string, a function, a `metrics.Metric` subclass, or a list/dict nesting of
  those). It is passed to `Model.compile` as both `metrics` and
  `weighted_metrics`; the model is trained, evaluated, saved to
  `self.model_filename`, reloaded and evaluated again.
  """

  def setUp(self):
    """Creates the model file path and the fixed training arrays."""
    super(MetricsSerialization, self).setUp()
    tmpdir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, tmpdir)
    self.model_filename = os.path.join(tmpdir, 'tmp_model_metric.h5')
    self.x = np.array([[0.], [1.], [2.]], dtype='float32')
    self.y = np.array([[0.5], [2.], [3.5]], dtype='float32')
    self.w = np.array([1.25, 0.5, 1.25], dtype='float32')

  def _custom_objects(self):
    """Returns a fresh name -> object mapping of this file's custom objects."""
    return {
        'MyMeanAbsoluteError': MyMeanAbsoluteError,
        '_my_mae': _my_mae,
        'Bias': testing_utils.Bias,
    }

  def _instantiate_metrics(self, value):
    """Returns `value` with every `metrics.Metric` subclass instantiated.

    Strings, functions and any other leaves are passed through unchanged.
    Each call creates new metric objects.

    Args:
      value: A (possibly nested) metrics specification.

    Returns:
      A structure shaped like `value`.
    """

    def get_instance(x):
      if isinstance(x, type) and issubclass(x, metrics.Metric):
        return x()
      return x

    return nest.map_structure(get_instance, value)

  def _fit_and_evaluate(self, metric_input, weighted_metric_input):
    """Builds, compiles, trains and evaluates the multi-IO model.

    Also asserts that the per-epoch training loss matches the expected values.

    Args:
      metric_input: Structure passed to `compile` as `metrics`.
      weighted_metric_input: Structure passed to `compile` as
        `weighted_metrics`.

    Returns:
      A `(model, eval_results)` tuple, where `eval_results` is what
      `model.evaluate` returned after training.
    """
    model = _get_multi_io_model()
    model.compile(
        optimizer_v2.gradient_descent.SGD(0.1),
        'mae',
        metrics=metric_input,
        weighted_metrics=weighted_metric_input,
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit([self.x, self.x], [self.y, self.y],
                        batch_size=3,
                        epochs=3,
                        sample_weight=[self.w, self.w])

    # Assert training.
    self.assertAllClose(history.history['loss'], [2., 1.6, 1.2], 1e-3)
    eval_results = model.evaluate([self.x, self.x], [self.y, self.y],
                                  sample_weight=[self.w, self.w])
    return model, eval_results

  def _assert_loaded_model_matches(self, loaded_model, eval_results):
    """Asserts `loaded_model` evaluates to the same results as before saving.

    Args:
      loaded_model: The model returned by `keras.models.load_model`.
      eval_results: Evaluation results of the model before it was saved.
    """
    loaded_model.predict([self.x, self.x])
    loaded_eval_results = loaded_model.evaluate(
        [self.x, self.x], [self.y, self.y], sample_weight=[self.w, self.w])

    # Assert all evaluation results are the same.
    self.assertAllClose(eval_results, loaded_eval_results, 1e-9)

  def test_serializing_model_with_metric_with_custom_object_scope(self, value):
    """Round-trips the model with custom objects supplied through a scope."""
    # Two separate calls so `metrics` and `weighted_metrics` never share
    # metric instances.
    metric_input = self._instantiate_metrics(value)
    weighted_metric_input = self._instantiate_metrics(value)

    with generic_utils.custom_object_scope(self._custom_objects()):
      model, eval_results = self._fit_and_evaluate(metric_input,
                                                   weighted_metric_input)

      # Without h5py the HDF5 file cannot be written, so only the training
      # half of the test runs.
      if h5py is None:
        return
      model.save(self.model_filename)
      loaded_model = keras.models.load_model(self.model_filename)
      self._assert_loaded_model_matches(loaded_model, eval_results)

  def test_serializing_model_with_metric_with_custom_objects(self, value):
    """Round-trips the model with custom objects passed to `load_model`."""
    # Two separate calls so `metrics` and `weighted_metrics` never share
    # metric instances.
    metric_input = self._instantiate_metrics(value)
    weighted_metric_input = self._instantiate_metrics(value)

    model, eval_results = self._fit_and_evaluate(metric_input,
                                                 weighted_metric_input)

    # Without h5py the HDF5 file cannot be written, so only the training
    # half of the test runs.
    if h5py is None:
      return
    model.save(self.model_filename)
    loaded_model = keras.models.load_model(
        self.model_filename, custom_objects=self._custom_objects())
    self._assert_loaded_model_matches(loaded_model, eval_results)
252
253
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  test.main()
256