# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `get_config` backwards compatibility."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.keras import keras_parameterized
from tensorflow.python.keras.engine import sequential
from tensorflow.python.keras.engine import training
from tensorflow.python.keras.tests import get_config_samples
from tensorflow.python.platform import test


@keras_parameterized.run_all_keras_modes
class TestGetConfigBackwardsCompatible(keras_parameterized.TestCase):
  """Rebuilds models from stored config samples and checks their layer counts.

  Every test takes one config constant from `get_config_samples`, passes it
  to the matching `from_config` constructor, and asserts how many entries the
  rebuilt model reports in `model.layers`.
  """

  def _assert_restored_layer_count(self, model_cls, config, expected_count):
    """Deserializes `config` with `model_cls` and checks `len(model.layers)`.

    Not prefixed with `test`, so it is neither collected as a test case nor
    wrapped by the class-level `run_all_keras_modes` decorator.

    Args:
      model_cls: Class whose `from_config` classmethod rebuilds the model.
      config: A serialized model config from `get_config_samples`.
      expected_count: Number of layers the rebuilt model should expose.
    """
    restored = model_cls.from_config(config)
    self.assertLen(restored.layers, expected_count)

  # FUNCTIONAL_* samples are rebuilt through `training.Model.from_config`.

  def test_functional_dnn(self):
    self._assert_restored_layer_count(
        training.Model, get_config_samples.FUNCTIONAL_DNN, 3)

  def test_functional_cnn(self):
    self._assert_restored_layer_count(
        training.Model, get_config_samples.FUNCTIONAL_CNN, 4)

  def test_functional_lstm(self):
    self._assert_restored_layer_count(
        training.Model, get_config_samples.FUNCTIONAL_LSTM, 3)

  # SEQUENTIAL_* samples are rebuilt through
  # `sequential.Sequential.from_config`.

  def test_sequential_dnn(self):
    self._assert_restored_layer_count(
        sequential.Sequential, get_config_samples.SEQUENTIAL_DNN, 2)

  def test_sequential_cnn(self):
    self._assert_restored_layer_count(
        sequential.Sequential, get_config_samples.SEQUENTIAL_CNN, 3)

  def test_sequential_lstm(self):
    self._assert_restored_layer_count(
        sequential.Sequential, get_config_samples.SEQUENTIAL_LSTM, 2)


if __name__ == '__main__':
  test.main()