1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for model.fit calls with a Dataset object passed as validation_data."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.python.data.ops import dataset_ops
24from tensorflow.python.keras import keras_parameterized
25from tensorflow.python.keras import testing_utils
26from tensorflow.python.keras.layers import core
27from tensorflow.python.platform import test
28
29
@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
class ValidationDatasetNoLimitTest(keras_parameterized.TestCase):
  """Checks that `fit` validates on the whole `validation_data` Dataset.

  When no validation step count is passed to `fit`, the validation metrics it
  reports should match a standalone `evaluate` call on the same dataset.
  The class decorators presumably parameterize each test over Keras model
  types and execution modes (per their `keras_parameterized` names).
  """

  def create_dataset(self, num_samples, batch_size):
    """Builds a shuffled, batched dataset of `(x, 3 * x)` pairs.

    Args:
      num_samples: Number of examples to generate.
      batch_size: Number of examples per batch.

    Returns:
      A `Dataset` of `(inputs, targets)` batches. Inputs are drawn uniformly
      from [0, 1) with per-example shape `(1,)`; targets are the inputs
      multiplied by 3.
    """
    input_data = np.random.rand(num_samples, 1)
    expected_data = input_data * 3
    dataset = dataset_ops.Dataset.from_tensor_slices((input_data,
                                                      expected_data))
    # Shuffle buffer covers ten batches' worth of examples.
    return dataset.shuffle(10 * batch_size).batch(batch_size)

  def test_validation_dataset_with_no_step_arg(self):
    """Final `val_` MAE from `fit` equals MAE from `evaluate`."""
    # Create a model that learns y=Mx.
    layers = [core.Dense(1)]
    model = testing_utils.get_model_from_layers(layers, input_shape=(1,))
    model.compile(loss="mse", optimizer="adam", metrics=["mean_absolute_error"])

    train_dataset = self.create_dataset(num_samples=200, batch_size=10)
    eval_dataset = self.create_dataset(num_samples=50, batch_size=25)

    history = model.fit(x=train_dataset, validation_data=eval_dataset, epochs=2)
    evaluation = model.evaluate(x=eval_dataset)

    # If the fit call used the entire dataset, then the final val MAE from the
    # fit history should equal the final element of the `evaluate` output
    # (the MAE metric) on the same eval dataset.
    # The eval dataset is reshuffled on every pass, so the float32 metric sums
    # are accumulated in a different order in each run; compare to 5 decimal
    # places instead of the default 7 so rounding noise cannot fail the test.
    self.assertAlmostEqual(history.history["val_mean_absolute_error"][-1],
                           evaluation[-1], places=5)
58
59
if __name__ == "__main__":
  # Hand control to TensorFlow's test runner when executed as a script.
  test.main()
62