# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for Estimator (deprecated).

This module and all its submodules are deprecated. See
[contrib/learn/README.md](https://www.tensorflow.org/code/tensorflow/contrib/learn/README.md)
for migration instructions.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.util import tf_inspect


def assert_estimator_contract(tester, estimator_class):
  """Checks that an estimator class exposes the expected member names.

  This is a smoke check, not a full contract verification: it only confirms
  that each expected member name is present on the class, so that a pre-canned
  Estimator does not silently forget to provide one of them. It does not
  inspect signatures or behavior.

  Args:
    tester: A tf.test.TestCase; its `assertTrue` is called once per expected
      member name.
    estimator_class: 'type' object of the pre-canned estimator to inspect.
  """
  # Checked in this fixed order; the first missing name is the one reported.
  required_members = (
      'config',
      'evaluate',
      'export',
      'fit',
      'get_variable_names',
      'get_variable_value',
      'model_dir',
      'predict',
  )
  member_names = [
      name for name, _ in tf_inspect.getmembers(estimator_class)
  ]
  for required in required_members:
    tester.assertTrue(required in member_names)


def assert_in_range(min_value, max_value, key, metrics):
  """Raises if `metrics[key]` falls outside `[min_value, max_value]`.

  Both bounds are inclusive: a value equal to either bound is accepted.

  Args:
    min_value: Smallest acceptable value.
    max_value: Largest acceptable value.
    key: Key used to look up the value in `metrics`.
    metrics: Mapping-like object indexed by `key`.

  Raises:
    ValueError: If the looked-up value is less than `min_value` or greater
      than `max_value`.
  """
  observed = metrics[key]
  if observed < min_value:
    message = '%s: %s < %s.' % (key, observed, min_value)
    raise ValueError(message)
  elif observed > max_value:
    message = '%s: %s > %s.' % (key, observed, max_value)
    raise ValueError(message)