# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extract parse_example op configuration to a proto."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.example import example_parser_configuration_pb2
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util


def extract_example_parser_configuration(parse_example_op, sess):
  """Builds an `ExampleParserConfiguration` proto describing a ParseExample op.

  The op's inputs are expected in the order
  `[serialized, names, sparse_keys..., dense_keys..., dense_defaults...]`.
  Its outputs are expected in the order
  `[sparse_indices..., sparse_values..., sparse_shapes..., dense_values...]`.
  There are `Nsparse` sparse entries per sparse group and `Ndense` dense
  entries per dense group.

  Args:
    parse_example_op: A ParseExample `Operation`.
    sess: A tf.Session used to evaluate the feature-key and dense-default
      input tensors of `parse_example_op`.

  Returns:
    An `ExampleParserConfiguration` proto whose `feature_map` has one entry
    per feature key. Dense keys get a `fixed_len_feature` and sparse keys get
    a `var_len_feature`.

  Raises:
    ValueError: If the per-feature attribute lists disagree with the
      `Nsparse`/`Ndense` counts, or if the number of inputs or fetched values
      does not match the expected layout.
  """
  config = example_parser_configuration_pb2.ExampleParserConfiguration()

  # Feature counts first, then the per-feature attribute lists.
  num_sparse = parse_example_op.get_attr("Nsparse")
  num_dense = parse_example_op.get_attr("Ndense")
  sparse_types = parse_example_op.get_attr("sparse_types")
  dense_types = parse_example_op.get_attr("Tdense")
  dense_shapes = parse_example_op.get_attr("dense_shapes")

  # Each attribute list must contain exactly one entry per declared feature.
  length_checks = (
      (sparse_types, num_sparse,
       "len(sparse_types) attribute does not match "
       "Nsparse attribute (%d vs %d)"),
      (dense_types, num_dense,
       "len(dense_types) attribute does not match "
       "Ndense attribute (%d vs %d)"),
      (dense_shapes, num_dense,
       "len(dense_shapes) attribute does not match "
       "Ndense attribute (%d vs %d)"),
  )
  for attr_values, declared_count, message in length_checks:
    if len(attr_values) != declared_count:
      raise ValueError(message % (len(attr_values), declared_count))

  # Inputs 0 and 1 (serialized examples and their names) are not needed.
  # The remaining inputs are sparse keys, dense keys and dense defaults.
  fetch_list = parse_example_op.inputs[2:]
  expected_fetches = num_sparse + 2 * num_dense
  if len(fetch_list) != expected_fetches:
    raise ValueError("len(fetch_list) does not match total features + "
                     "num_dense (%d vs %d)" %
                     (len(fetch_list), expected_fetches))

  fetched = sess.run(fetch_list)
  if len(fetched) != len(fetch_list):
    raise ValueError("len(fetched) does not match len(fetch_list) "
                     "(%d vs %d)" % (len(fetched), len(fetch_list)))

  # Split the fetched values into their three consecutive groups.
  sparse_keys = fetched[:num_sparse]
  dense_keys = fetched[num_sparse:num_sparse + num_dense]
  dense_defaults = fetched[num_sparse + num_dense:]

  outputs = parse_example_op.outputs
  # Dense value outputs follow the three sparse output groups.
  dense_values_offset = 3 * num_sparse

  # Dense features are recorded as fixed_len_feature entries.
  for i, key in enumerate(dense_keys):
    fixed_len = config.feature_map[key].fixed_len_feature
    # The session returns the default as a numpy array; store it as a
    # TensorProto.
    fixed_len.default_value.CopyFrom(
        tensor_util.make_tensor_proto(dense_defaults[i]))
    # The shape attribute is converted to a TensorShapeProto.
    fixed_len.shape.CopyFrom(
        tensor_shape.TensorShape(dense_shapes[i]).as_proto())
    fixed_len.dtype = dense_types[i].as_datatype_enum
    fixed_len.values_output_tensor_name = outputs[dense_values_offset + i].name

  # Sparse features are recorded as var_len_feature entries. Each one has a
  # tensor in the indices, values and shapes output groups.
  for i, key in enumerate(sparse_keys):
    var_len = config.feature_map[key].var_len_feature
    var_len.dtype = sparse_types[i].as_datatype_enum
    var_len.indices_output_tensor_name = outputs[i].name
    var_len.values_output_tensor_name = outputs[num_sparse + i].name
    var_len.shapes_output_tensor_name = outputs[2 * num_sparse + i].name

  return config