1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Extract parse_example op configuration to a proto."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.core.example import example_parser_configuration_pb2
22from tensorflow.python.framework import tensor_shape
23from tensorflow.python.framework import tensor_util
24
25
def extract_example_parser_configuration(parse_example_op, sess):
  """Builds an ExampleParserConfiguration proto from a ParseExample op.

  Feature counts, dtypes and dense shapes are read from the op's attributes.
  Feature keys and dense default values are obtained by evaluating the op's
  inputs (all inputs after the first two) in `sess`.

  Args:
    parse_example_op: A ParseExample `Operation`.
    sess: A tf.Session used to evaluate the key and default-value inputs.

  Returns:
    An ExampleParserConfiguration proto whose `feature_map` is keyed by the
    fetched feature keys: dense features populate `fixed_len_feature`, sparse
    features populate `var_len_feature`.

  Raises:
    ValueError: If the lengths of the op's attributes, its inputs, or the
      fetched values are inconsistent with each other.
  """
  op = parse_example_op

  n_sparse = op.get_attr("Nsparse")
  n_dense = op.get_attr("Ndense")

  sparse_dtypes = op.get_attr("sparse_types")
  dense_dtypes = op.get_attr("Tdense")
  dense_shape_attrs = op.get_attr("dense_shapes")

  # Every per-feature attribute list must agree with its declared count.
  if n_sparse != len(sparse_dtypes):
    raise ValueError("len(sparse_types) attribute does not match "
                     "Nsparse attribute (%d vs %d)" %
                     (len(sparse_dtypes), n_sparse))

  if n_dense != len(dense_dtypes):
    raise ValueError("len(dense_types) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_dtypes), n_dense))

  if n_dense != len(dense_shape_attrs):
    raise ValueError("len(dense_shapes) attribute does not match "
                     "Ndense attribute (%d vs %d)" %
                     (len(dense_shape_attrs), n_dense))

  # The first two inputs (serialized examples and names) are not needed.
  # What remains is one key per feature (sparse first, then dense) followed
  # by one default value per dense feature.
  tensors_to_fetch = op.inputs[2:]
  expected_fetches = n_sparse + 2 * n_dense

  if len(tensors_to_fetch) != expected_fetches:
    raise ValueError("len(fetch_list) does not match total features + "
                     "num_dense (%d vs %d)" %
                     (len(tensors_to_fetch), expected_fetches))

  values = sess.run(tensors_to_fetch)

  if len(values) != len(tensors_to_fetch):
    raise ValueError("len(fetched) does not match len(fetch_list) "
                     "(%d vs %d)" % (len(values), len(tensors_to_fetch)))

  # Partition the fetched values according to the input layout above.
  sparse_keys = values[:n_sparse]
  dense_keys = values[n_sparse:n_sparse + n_dense]
  dense_defaults = values[n_sparse + n_dense:]

  # Output layout: sparse indices, sparse values, sparse shapes (n_sparse
  # tensors each), followed by the dense values.
  outputs = op.outputs
  sparse_values_offset = n_sparse
  sparse_shapes_offset = 2 * n_sparse
  dense_values_offset = 3 * n_sparse

  config = example_parser_configuration_pb2.ExampleParserConfiguration()

  # Dense features become fixed-length feature entries.
  for i, (key, default, dtype, shape) in enumerate(
      zip(dense_keys, dense_defaults, dense_dtypes, dense_shape_attrs)):
    fixed_len = config.feature_map[key].fixed_len_feature
    # The default arrives as a fetched value; store it as a TensorProto.
    fixed_len.default_value.CopyFrom(tensor_util.make_tensor_proto(default))
    # The shape comes from the op attribute; store it as a TensorShapeProto.
    fixed_len.shape.CopyFrom(tensor_shape.TensorShape(shape).as_proto())
    fixed_len.dtype = dtype.as_datatype_enum
    fixed_len.values_output_tensor_name = outputs[dense_values_offset + i].name

  # Sparse features become variable-length feature entries.
  for i, (key, dtype) in enumerate(zip(sparse_keys, sparse_dtypes)):
    var_len = config.feature_map[key].var_len_feature
    var_len.dtype = dtype.as_datatype_enum
    var_len.indices_output_tensor_name = outputs[i].name
    var_len.values_output_tensor_name = outputs[sparse_values_offset + i].name
    var_len.shapes_output_tensor_name = outputs[sparse_shapes_offset + i].name

  return config
123