1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests metrics correctness using Keras model."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python import tf2
24from tensorflow.python.keras import keras_parameterized
25from tensorflow.python.keras import layers
26from tensorflow.python.keras import metrics
27from tensorflow.python.keras import testing_utils
28from tensorflow.python.platform import test
29
30
@keras_parameterized.run_with_all_model_types(exclude_models=['sequential'])
@keras_parameterized.run_all_keras_modes
class TestMetricsCorrectnessMultiIO(keras_parameterized.TestCase):
  """Checks loss and metric values reported by a two-input/two-output model.

  Every layer uses a `ones` kernel initializer and is non-trainable, so the
  expected values derived in `setUp` hold for every batch and epoch.
  """

  def _get_multi_io_model(self):
    """Returns a compiled model with two inputs and two named outputs.

    Both branches share the hidden `Dense(3)` layer and end in their own
    `Dense(1)` layer (`output_1` and `output_2`). The model is compiled with
    an MSE loss, an unweighted MSE metric and a weighted MSE metric.
    """
    inp_1 = layers.Input(shape=(1,), name='input_1')
    inp_2 = layers.Input(shape=(1,), name='input_2')
    x = layers.Dense(3, kernel_initializer='ones', trainable=False)
    out_1 = layers.Dense(
        1, kernel_initializer='ones', name='output_1', trainable=False)
    out_2 = layers.Dense(
        1, kernel_initializer='ones', name='output_2', trainable=False)

    branch_a = [inp_1, x, out_1]
    branch_b = [inp_2, x, out_2]
    model = testing_utils.get_multi_io_model(branch_a, branch_b)
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        metrics=[metrics.MeanSquaredError(name='mean_squared_error')],
        weighted_metrics=[
            metrics.MeanSquaredError(name='mean_squared_error_2')
        ],
        run_eagerly=testing_utils.should_run_eagerly())
    return model

  def _custom_generator(self):
    """Yields `(inputs, targets, sample_weights)` batches of size 2 forever.

    Batches are sliced from the fixture data created in `setUp` so that the
    generator based tests and the array based tests cannot drift apart. The
    same slice feeds both inputs/outputs; each output gets its own weights.
    """
    batch_size = 2
    num_samples = len(self.x)
    start = 0
    while True:
      end = start + batch_size
      x = [self.x[start:end], self.x[start:end]]
      y = [self.y[start:end], self.y[start:end]]
      w = [self.weights_1[start:end], self.weights_2[start:end]]
      yield x, y, w
      # Wrap around so the generator can be consumed for several epochs.
      start = end % num_samples

  def setUp(self):
    """Creates the fixture data and the expected losses and metrics."""
    super(TestMetricsCorrectnessMultiIO, self).setUp()
    self.x = np.asarray([[1.], [2.], [3.], [4.]])
    self.y = np.asarray([[2.], [4.], [6.], [8.]])
    self.weights_1 = np.asarray([2., 3., 4., 5.])
    self.weights_2 = np.asarray([3.5, 2.5, 1.5, 0.5])
    # Per-output sample weights, keyed by output layer name.
    self.sample_weight = {
        'output_1': self.weights_1,
        'output_2': self.weights_2,
    }

    # y_true = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]]

    # Metric `output_1`, `output_2`:
    #   Total = ((3 - 2)^2 + (6 - 4)^2) + ((9 - 6)^2 + (12 - 8)^2) = 30,
    #   Count = 2 + 2
    #   Result = 7.5

    # Weighted metric `output_1`:
    #   Total = ((3 - 2)^2 * 2  + (6 - 4)^2 * 3) +
    #           ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
    #         = 130
    #   Count = (2 + 3) + (4 + 5)
    #   Result = 9.2857141

    # Weighted metric `output_2`:
    #   Total = ((3 - 2)^2 * 3.5 + (6 - 4)^2 * 2.5) +
    #           ((9 - 6)^2 * 1.5 + (12 - 8)^2 * 0.5)
    #         = 35
    #   Count = (3.5 + 2.5) + (1.5 + 0.5)
    #   Result = 4.375

    # Loss `output_1`:
    #   Total = ((3 - 2)^2 * 2  + (6 - 4)^2 * 3) +
    #           ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
    #         = 130
    #   Count = 2 + 2
    #   Result = 32.5

    # Loss `output_2`:
    #   Total = ((3 - 2)^2 * 3.5 + (6 - 4)^2 * 2.5) +
    #           ((9 - 6)^2 * 1.5 + (12 - 8)^2 * 0.5)
    #         = 35
    #   Count = 2 + 2
    #   Result = 8.75

    # Total loss = 32.5 + 8.75 = 41.25

    # The weighted metric is reported under a `weighted_` prefixed name when
    # TF2 behavior is not enabled.
    wmse = 'mean_squared_error_2'
    if not tf2.enabled():
      wmse = 'weighted_' + wmse
    self.expected_fit_result = {
        'output_1_mean_squared_error': [7.5, 7.5],
        'output_2_mean_squared_error': [7.5, 7.5],
        'output_1_' + wmse: [9.286, 9.286],
        'output_2_' + wmse: [4.375, 4.375],
        'loss': [41.25, 41.25],
        'output_1_loss': [32.5, 32.5],
        'output_2_loss': [8.75, 8.75],
    }

    # In the order: 'loss', 'output_1_loss', 'output_2_loss',
    # 'output_1_mean_squared_error', 'output_1_mean_squared_error_2',
    # 'output_2_mean_squared_error', 'output_2_mean_squared_error_2'
    self.expected_batch_result = [41.25, 32.5, 8.75, 7.5, 9.286, 7.5, 4.375]

  def test_fit(self):
    """`fit` reports the expected losses and metrics for both epochs."""
    model = self._get_multi_io_model()
    history = model.fit([self.x, self.x], [self.y, self.y],
                        sample_weight=self.sample_weight,
                        batch_size=2,
                        epochs=2,
                        shuffle=False)
    for key, value in self.expected_fit_result.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_eval(self):
    """`evaluate` matches the expected values and is batch-size invariant."""
    model = self._get_multi_io_model()
    eval_result = model.evaluate([self.x, self.x], [self.y, self.y],
                                 batch_size=2,
                                 sample_weight=self.sample_weight)
    self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)

    # Verify that metric values are the same for arbitrary inputs and sample
    # weights regardless of the batch size. A seeded generator keeps any
    # failure reproducible.
    rng = np.random.RandomState(1337)
    x = rng.random_sample((50, 1))
    y = rng.random_sample((50, 1))
    w = rng.random_sample((50,))
    result_1 = model.evaluate([x, x], [y, y], sample_weight=[w, w],
                              batch_size=5)
    result_2 = model.evaluate([x, x], [y, y], sample_weight=[w, w],
                              batch_size=10)
    # Entries 3 onwards are the metrics of both outputs (see the ordering
    # documented next to `expected_batch_result`). Comparing all of them also
    # covers the weighted metrics, which are the metrics that actually depend
    # on the sample weights.
    self.assertAllClose(result_1[3:], result_2[3:], 1e-3)

  def test_train_on_batch(self):
    """`train_on_batch` on the full data returns the expected values."""
    model = self._get_multi_io_model()
    result = model.train_on_batch([self.x, self.x], [self.y, self.y],
                                  sample_weight=self.sample_weight)
    self.assertAllClose(result, self.expected_batch_result, 1e-3)

  def test_test_on_batch(self):
    """`test_on_batch` on the full data returns the expected values."""
    model = self._get_multi_io_model()
    result = model.test_on_batch([self.x, self.x], [self.y, self.y],
                                 sample_weight=self.sample_weight)
    self.assertAllClose(result, self.expected_batch_result, 1e-3)

  def test_fit_generator(self):
    """`fit_generator` reports the same values as array based `fit`."""
    model = self._get_multi_io_model()
    history = model.fit_generator(
        self._custom_generator(), steps_per_epoch=2, epochs=2)
    for key, value in self.expected_fit_result.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_eval_generator(self):
    """`evaluate_generator` reports the same values as `evaluate`."""
    model = self._get_multi_io_model()
    eval_result = model.evaluate_generator(self._custom_generator(), steps=2)
    self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)
199
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
class TestMetricsCorrectnessSingleIO(keras_parameterized.TestCase):
  """Checks loss and metric values reported by a single-output model.

  Every layer uses a `ones` kernel initializer and is non-trainable, so the
  expected values derived in `setUp` hold for every batch and epoch.
  """

  def _get_model(self):
    """Returns a compiled `Dense(3)` -> `Dense(1)` model taking one feature.

    The model is compiled with an MSE loss, an unweighted MSE metric and a
    weighted MSE metric.
    """
    x = layers.Dense(3, kernel_initializer='ones', trainable=False)
    out = layers.Dense(
        1, kernel_initializer='ones', name='output', trainable=False)
    model = testing_utils.get_model_from_layers([x, out], input_shape=(1,))
    model.compile(
        optimizer='rmsprop',
        loss='mse',
        metrics=[metrics.MeanSquaredError(name='mean_squared_error')],
        weighted_metrics=[
            metrics.MeanSquaredError(name='mean_squared_error_2')
        ],
        run_eagerly=testing_utils.should_run_eagerly())
    return model

  def _custom_generator(self):
    """Yields `(inputs, targets, sample_weights)` batches of size 2 forever.

    Batches are sliced from the fixture data created in `setUp` so that the
    generator based tests and the array based tests cannot drift apart.
    """
    batch_size = 2
    num_samples = len(self.x)
    start = 0
    while True:
      end = start + batch_size
      yield self.x[start:end], self.y[start:end], self.weights[start:end]
      # Wrap around so the generator can be consumed for several epochs.
      start = end % num_samples

  def setUp(self):
    """Creates the fixture data and the expected losses and metrics."""
    super(TestMetricsCorrectnessSingleIO, self).setUp()
    self.x = np.asarray([[1.], [2.], [3.], [4.]])
    self.y = np.asarray([[2.], [4.], [6.], [8.]])
    self.weights = np.asarray([2., 3., 4., 5.])

    # y_true = [[2.], [4.], [6.], [8.]], y_pred = [[3.], [6.], [9.], [12.]]

    # Metric:
    #   Total = ((3 - 2)^2 + (6 - 4)^2) + ((9 - 6)^2 + (12 - 8)^2) = 30,
    #   Count = 2 + 2
    #   Result = 7.5

    # Weighted metric:
    #   Total = ((3 - 2)^2 * 2  + (6 - 4)^2 * 3) +
    #           ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
    #         = 130
    #   Count = (2 + 3) + (4 + 5)
    #   Result = 9.2857141

    # Total loss:
    #   Total = ((3 - 2)^2 * 2  + (6 - 4)^2 * 3) +
    #           ((9 - 6)^2 * 4 + (12 - 8)^2 * 5)
    #         = 130,
    #   Count = 2 + 2
    #   Result = 32.5

    # The weighted metric is reported under a `weighted_` prefixed name when
    # TF2 behavior is not enabled.
    wmse = 'mean_squared_error_2'
    if not tf2.enabled():
      wmse = 'weighted_' + wmse
    self.expected_fit_result = {
        'mean_squared_error': [7.5, 7.5],
        wmse: [9.286, 9.286],
        'loss': [32.5, 32.5]
    }

    # In the order: 'loss', 'mean_squared_error', 'mean_squared_error_2'
    self.expected_batch_result = [32.5, 7.5, 9.286]

  def test_fit(self):
    """`fit` reports the expected loss and metrics for both epochs."""
    model = self._get_model()
    history = model.fit(
        self.x,
        self.y,
        sample_weight=self.weights,
        batch_size=2,
        epochs=2,
        shuffle=False)
    for key, value in self.expected_fit_result.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_eval(self):
    """`evaluate` matches the expected values and is batch-size invariant."""
    model = self._get_model()
    eval_result = model.evaluate(
        self.x, self.y, batch_size=2, sample_weight=self.weights)
    self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)

    # Verify that metric values are the same for arbitrary inputs and sample
    # weights regardless of the batch size. A seeded generator keeps any
    # failure reproducible.
    rng = np.random.RandomState(1337)
    x = rng.random_sample((50, 1))
    y = rng.random_sample((50, 1))
    w = rng.random_sample((50,))
    result_1 = model.evaluate(x, y, sample_weight=w, batch_size=5)
    result_2 = model.evaluate(x, y, sample_weight=w, batch_size=10)
    # Entries 1 and 2 are the unweighted and the weighted metric (see the
    # ordering documented next to `expected_batch_result`). Comparing both
    # also covers the metric that actually depends on the sample weights.
    self.assertAllClose(result_1[1:], result_2[1:], 1e-3)

  def test_train_on_batch(self):
    """`train_on_batch` on the full data returns the expected values."""
    model = self._get_model()
    result = model.train_on_batch(self.x, self.y, sample_weight=self.weights)
    self.assertAllClose(result, self.expected_batch_result, 1e-3)

  def test_test_on_batch(self):
    """`test_on_batch` on the full data returns the expected values."""
    model = self._get_model()
    result = model.test_on_batch(self.x, self.y, sample_weight=self.weights)
    self.assertAllClose(result, self.expected_batch_result, 1e-3)

  def test_fit_generator(self):
    """`fit_generator` reports the same values as array based `fit`."""
    model = self._get_model()
    history = model.fit_generator(
        self._custom_generator(), steps_per_epoch=2, epochs=2)
    for key, value in self.expected_fit_result.items():
      self.assertAllClose(history.history[key], value, 1e-3)

  def test_eval_generator(self):
    """`evaluate_generator` reports the same values as `evaluate`."""
    model = self._get_model()
    eval_result = model.evaluate_generator(self._custom_generator(), steps=2)
    self.assertAllClose(eval_result, self.expected_batch_result, 1e-3)
320
if __name__ == '__main__':
  # Run every test case in this module through the TensorFlow test runner
  # when the file is executed as a script.
  test.main()
323