# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for model.fit calls with a Dataset object passed as validation_data."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras import testing_utils
from tensorflow.python.keras.layers import core
from tensorflow.python.platform import test


@keras_parameterized.run_with_all_model_types
@keras_parameterized.run_all_keras_modes
class ValidationDatasetNoLimitTest(keras_parameterized.TestCase):
  """Checks fit() validation on a Dataset when no step limit is given."""

  def create_dataset(self, num_samples, batch_size):
    """Builds a shuffled, batched dataset of `(x, 3 * x)` pairs.

    Args:
      num_samples: Number of rows drawn from `np.random.rand(num_samples, 1)`.
      batch_size: Rows per batch; the shuffle buffer holds ten batches' worth.

    Returns:
      A `Dataset` yielding `(features, targets)` batches.
    """
    features = np.random.rand(num_samples, 1)
    targets = features * 3
    shuffle_buffer = batch_size * 10
    return (dataset_ops.Dataset
            .from_tensor_slices((features, targets))
            .shuffle(shuffle_buffer)
            .batch(batch_size))

  def test_validation_dataset_with_no_step_arg(self):
    """The final val MAE from fit() must match evaluate() on the same data."""
    # A single Dense unit is enough to learn the linear target y = 3x.
    model = testing_utils.get_model_from_layers([core.Dense(1)],
                                                input_shape=(1,))
    model.compile(
        optimizer="adam", loss="mse", metrics=["mean_absolute_error"])

    # Build the training data first, then the validation data.
    train_ds = self.create_dataset(num_samples=200, batch_size=10)
    val_ds = self.create_dataset(num_samples=50, batch_size=25)

    fit_history = model.fit(x=train_ds, validation_data=val_ds, epochs=2)
    eval_results = model.evaluate(x=val_ds)

    # With no validation step limit, fit() should consume the entire
    # validation dataset. The val MAE recorded after the last epoch should
    # therefore equal the last element of evaluate()'s output (the single
    # compiled metric) on that same dataset.
    last_val_mae = fit_history.history["val_mean_absolute_error"][-1]
    self.assertAlmostEqual(last_val_mae, eval_results[-1])


if __name__ == "__main__":
  test.main()