# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A `Predictor` constructed from a core (`tensorflow.python.estimator`) `Estimator`."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.predictor import predictor
from tensorflow.python.estimator import model_fn
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import monitored_session


def _get_signature_def(
    serving_input_receiver, estimator, output_key=None):
  """Construct a `SignatureDef` proto for one export output of `estimator`.

  Calls `estimator.model_fn` in `PREDICT` mode on the receiver's features
  (passing `None` for labels) and converts the selected export output into a
  signature whose inputs are the receiver tensors.

  Args:
    serving_input_receiver: object exposing `features` (fed to the model
      function) and `receiver_tensors` (used as the signature inputs).
    estimator: object exposing a callable `model_fn` and a `config`.
    output_key: Optional key selecting which export output to use. If `None`,
      `signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used.

  Returns:
    The value returned by the selected export output's `as_signature_def`.

  Raises:
    KeyError: if `output_key` is not among the export outputs produced by the
      model function (including when it produces no export outputs at all).
  """
  if output_key is None:
    output_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
  estimator_spec = estimator.model_fn(
      serving_input_receiver.features, None, model_fn.ModeKeys.PREDICT,
      estimator.config)
  # A spec may carry no export outputs; treat that as an empty mapping so the
  # caller gets the descriptive KeyError below instead of an AttributeError.
  export_outputs = estimator_spec.export_outputs or {}
  export_output = export_outputs.get(output_key)
  if export_output is None:
    # `list(...)` keeps the message readable on Python 3, where `.keys()`
    # would render as `dict_keys([...])`.
    raise KeyError('output_key must be one of {}; got {}'.format(
        list(export_outputs), output_key))
  return export_output.as_signature_def(serving_input_receiver.receiver_tensors)


class CoreEstimatorPredictor(predictor.Predictor):
  """A `Predictor` constructed from a core `Estimator`.

  Construction builds the serving graph by calling the estimator's model
  function in `PREDICT` mode, opens a `MonitoredSession` pointed at the
  estimator's `model_dir`, and records the feed/fetch tensors named by the
  resulting signature.
  """

  def __init__(self,
               estimator,
               serving_input_receiver_fn,
               output_key=None,
               graph=None,
               config=None):
    """Initialize a `CoreEstimatorPredictor`.

    Args:
      estimator: an instance of a core `Estimator` (its `model_fn`, `config`
        and `model_dir` attributes are used).
      serving_input_receiver_fn: a function that takes no arguments and returns
        an instance of `ServingInputReceiver` compatible with `estimator`.
      output_key: Optional string specifying the export output to use. If
        `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used.
      graph: Optional. The Tensorflow `graph` in which prediction should be
        done. A new `Graph` is created when omitted.
      config: `ConfigProto` proto used to configure the session.

    Raises:
      KeyError: if `output_key` does not name one of the estimator's export
        outputs.
    """
    self._graph = graph or ops.Graph()
    with self._graph.as_default():
      # Build the serving inputs and the model inside `self._graph` so that
      # the tensor names in the signature can be resolved against it below.
      serving_input_receiver = serving_input_receiver_fn()
      signature_def = _get_signature_def(
          serving_input_receiver, estimator, output_key)
      checkpoint_dir = estimator.model_dir
      self._session = monitored_session.MonitoredSession(
          session_creator=monitored_session.ChiefSessionCreator(
              config=config,
              checkpoint_dir=checkpoint_dir))

    # Map each signature input/output key to the live tensor in the graph.
    # NOTE(review): these attributes are presumably consumed by the
    # `Predictor` base class when the predictor is called -- confirm there.
    feed_tensor_info = signature_def.inputs
    self._feed_tensors = {k: self._graph.get_tensor_by_name(v.name)
                          for k, v in feed_tensor_info.items()}
    fetch_tensor_info = signature_def.outputs
    self._fetch_tensors = {k: self._graph.get_tensor_by_name(v.name)
                           for k, v in fetch_tensor_info.items()}