1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""FeatureColumn serialization, deserialization logic."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import six
22
23from tensorflow.python.feature_column import feature_column_v2 as fc_lib
24from tensorflow.python.feature_column import sequence_feature_column as sfc_lib
25from tensorflow.python.ops import init_ops
26from tensorflow.python.util import tf_decorator
27from tensorflow.python.util import tf_inspect
28
29
# Classes that deserialize_feature_column() can resolve by class name without
# the caller supplying them through `custom_objects`.
# NOTE(review): init_ops.TruncatedNormal is not a FeatureColumn; it is
# presumably listed so that nested initializer configs can be resolved by
# name -- confirm before removing.
_FEATURE_COLUMNS = [
    fc_lib.BucketizedColumn,
    fc_lib.CrossedColumn,
    fc_lib.EmbeddingColumn,
    fc_lib.HashedCategoricalColumn,
    fc_lib.IdentityCategoricalColumn,
    fc_lib.IndicatorColumn,
    fc_lib.NumericColumn,
    fc_lib.SequenceCategoricalColumn,
    fc_lib.SequenceDenseColumn,
    fc_lib.SharedEmbeddingColumn,
    fc_lib.VocabularyFileCategoricalColumn,
    fc_lib.VocabularyListCategoricalColumn,
    fc_lib.WeightedCategoricalColumn,
    init_ops.TruncatedNormal,
    sfc_lib.SequenceNumericColumn,
]
39
40
def serialize_feature_column(fc):
  """Converts one FeatureColumn (or raw string key) to its Keras-style config.

  Intended for serializing parent FeatureColumns from inside an implementation
  of FeatureColumn.get_config(); to serialize a whole list of columns prefer
  serialize_feature_columns().

  The result stores the FeatureColumn class name next to its config, so the
  column can be rebuilt later without knowing its class up front. For example:

  a = numeric_column('price')
  a.get_config() gives:
  {
      'key': 'price',
      'shape': (1,),
      'default_value': None,
      'dtype': 'float32',
      'normalizer_fn': None
  }
  While serialize_feature_column(a) gives:
  {
      'class_name': 'NumericColumn',
      'config': {
          'key': 'price',
          'shape': (1,),
          'default_value': None,
          'dtype': 'float32',
          'normalizer_fn': None
      }
  }

  Args:
    fc: A FeatureColumn or raw feature key string.

  Returns:
    A dict with 'class_name' and 'config' entries for a FeatureColumn; a string
    key is returned unchanged.

  Raises:
    ValueError: If `fc` is neither a string nor a FeatureColumn.
  """
  # Raw string keys carry no class information and pass through untouched.
  if isinstance(fc, six.string_types):
    return fc

  if not isinstance(fc, fc_lib.FeatureColumn):
    raise ValueError('Instance: {} is not a FeatureColumn'.format(fc))

  return {'class_name': fc.__class__.__name__, 'config': fc.get_config()}
87
88
def deserialize_feature_column(config,
                               custom_objects=None,
                               columns_by_name=None):
  """Rebuilds a FeatureColumn from the output of `serialize_feature_column`.

  Intended for deserializing parent FeatureColumns from inside an
  implementation of FeatureColumn.from_config(); to deserialize a whole list of
  configs prefer deserialize_feature_columns().
  TODO(b/118939620): Simplify code if Keras utils support object deduping.

  Args:
    config: A Dict with the serialization of a feature column acquired by
      `serialize_feature_column`, or a string representing a raw column.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).
    columns_by_name: A Dict of already deserialized FeatureColumns, keyed by
      '<ClassName>:<column name>', used to avoid duplication. A column built
      here is added to it unless its key is already present.

  Raises:
    ValueError: If `config` has an invalid format (e.g. expected keys missing),
      refers to an unknown class, or resolves to a class that is not a
      FeatureColumn.

  Returns:
    The string itself when `config` is a string; otherwise a FeatureColumn
    corresponding to `config`, which is the instance already stored in
    `columns_by_name` when one exists under the same key.
  """
  # Raw string keys were serialized as-is, so they come back as-is.
  if isinstance(config, six.string_types):
    return config

  if columns_by_name is None:
    columns_by_name = {}

  # Name -> class lookup for every class listed in this module. Classes
  # defined elsewhere have to be supplied through `custom_objects`.
  known_classes = {}
  for known_cls in _FEATURE_COLUMNS:
    known_classes[known_cls.__name__] = known_cls

  column_cls, column_config = _class_and_config_for_serialized_keras_object(
      config,
      module_objects=known_classes,
      custom_objects=custom_objects,
      printable_module_name='feature_column_v2')

  if not issubclass(column_cls, fc_lib.FeatureColumn):
    raise ValueError(
        'Expected FeatureColumn class, instead found: {}'.format(column_cls))

  # The column is always rebuilt, because its name is needed for the lookup
  # below.
  rebuilt = column_cls.from_config(
      column_config,
      custom_objects=custom_objects,
      columns_by_name=columns_by_name)

  # setdefault() hands back a previously registered column with the same key
  # (in which case `rebuilt` is simply discarded) and registers `rebuilt`
  # otherwise.
  dedup_key = _column_name_with_class_name(rebuilt)
  return columns_by_name.setdefault(dedup_key, rebuilt)
144
145
def serialize_feature_columns(feature_columns):
  """Serializes every FeatureColumn in `feature_columns`.

  Produces one Keras-style config per input column, in input order. The result
  can be passed to `deserialize_feature_columns` to reconstruct the original
  columns.

  Args:
    feature_columns: A list of FeatureColumns.

  Returns:
    A list with the Keras serialization of each FeatureColumn.

  Raises:
    ValueError: If an element is neither a FeatureColumn nor a string key.
  """
  serialized = []
  for column in feature_columns:
    serialized.append(serialize_feature_column(column))
  return serialized
163
164
def deserialize_feature_columns(configs, custom_objects=None):
  """Rebuilds FeatureColumns from a list of serialized configs.

  Takes the list of config dicts produced by `serialize_feature_columns` and
  returns the matching FeatureColumns, in the same order.

  Args:
    configs: A list of Dicts with the serialization of feature columns acquired
      by `serialize_feature_columns`.
    custom_objects: A Dict from custom_object name to the associated keras
      serializable objects (FeatureColumns, classes or functions).

  Returns:
    A list of FeatureColumn objects corresponding to the input configs.

  Raises:
    ValueError: If a config has an invalid format or refers to an unknown
      class.
  """
  # A single registry is shared by all configs, so columns that resolve to the
  # same dedup key come back as the same object.
  shared_columns = {}
  restored = []
  for serialized in configs:
    restored.append(
        deserialize_feature_column(
            serialized,
            custom_objects=custom_objects,
            columns_by_name=shared_columns))
  return restored
188
189
def _column_name_with_class_name(fc):
  """Builds the deduping key '<ClassName>:<column name>' for a feature column.

  The class name is part of the key because two FeatureColumns can share a
  name when one wraps the other (e.g. an IndicatorColumn wrapping a
  SequenceCategoricalColumn). Keyed by name alone they would collide in
  columns_by_name and the wrong column would be returned on deserialization.

  Args:
    fc: A FeatureColumn.

  Returns:
    The class name and the column name joined by ':', as a string.
  """
  return ':'.join((fc.__class__.__name__, fc.name))
205
206
def _serialize_keras_object(instance):
  """Serializes a Keras object into a JSON-compatible representation.

  Args:
    instance: The object to serialize; any TF decorators are unwrapped first.

  Returns:
    None if the unwrapped object is None; a dict with 'class_name' and
    'config' entries if the object has a `get_config` method; otherwise the
    object's `__name__`.

  Raises:
    ValueError: If the object has neither `get_config` nor `__name__`.
  """
  _, target = tf_decorator.unwrap(instance)
  if target is None:
    return None

  if not hasattr(target, 'get_config'):
    # Objects without a config (e.g. plain functions) are stored by name only.
    if hasattr(target, '__name__'):
      return target.__name__
    raise ValueError('Cannot serialize', target)

  class_name = target.__class__.__name__
  packed_config = {}
  for key, value in target.get_config().items():
    if isinstance(value, six.string_types):
      # Strings are already serializable.
      packed_config[key] = value
    else:
      # Anything else (e.g. custom functions, custom classes) is converted to
      # a string or dict where possible; values that cannot be converted are
      # kept unchanged.
      try:
        packed_value = _serialize_keras_object(value)
        if isinstance(packed_value, dict) and not isinstance(value, dict):
          # Marks dicts that stand in for an object, so deserialization can
          # tell them apart from config values that are genuinely dicts.
          packed_value['__passive_serialization__'] = True
        packed_config[key] = packed_value
      except ValueError:
        packed_config[key] = value

  return {'class_name': class_name, 'config': packed_config}
236
237
def _deserialize_keras_object(identifier,
                              module_objects=None,
                              custom_objects=None,
                              printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  Args:
    identifier: None, a Keras config dict (with 'class_name' and 'config'
      entries), the registered name of an object, or a function.
    module_objects: Optional Dict from name to object, consulted when a name is
      not found in `custom_objects`.
    custom_objects: Optional Dict from name to object; takes precedence over
      `module_objects`.
    printable_module_name: Label for the kind of object, used in error
      messages.

  Returns:
    None if `identifier` is None. For a config dict, the result of
    `cls.from_config(...)` when the resolved class has `from_config`, else
    `cls(**config)`. For a name, the registered object, instantiated with no
    arguments when it is a class. A function is returned unchanged.

  Raises:
    ValueError: If a config dict is malformed or names an unknown class, if a
      name is not registered, or if `identifier` has an unsupported type.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    (cls, cls_config) = _class_and_config_for_serialized_keras_object(
        identifier, module_objects, custom_objects, printable_module_name)

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      if 'custom_objects' in arg_spec.args:
        # Hand over a copy so from_config cannot alter the caller's dict.
        return cls.from_config(
            cls_config,
            custom_objects=dict((custom_objects or {}).items()))
      return cls.from_config(cls_config)

    # Then `cls` may be a function returning a class. In this case by
    # convention `config` holds the kwargs of the function.
    return cls(**cls_config)

  if isinstance(identifier, six.string_types):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    else:
      # `module_objects` defaults to None; treat that like an empty registry so
      # an unknown name raises the ValueError below, not an AttributeError.
      obj = module_objects.get(object_name) if module_objects else None
      if obj is None:
        raise ValueError(
            'Unknown ' + printable_module_name + ': ' + object_name)
    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj

  if tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier

  raise ValueError('Could not interpret serialized %s: %s' %
                   (printable_module_name, identifier))
288
289
def _class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class and the resolved config for a serialized keras object.

  Args:
    config: A Dict with 'class_name' and 'config' entries. It is not modified.
    module_objects: Optional Dict from name to object, consulted when the class
      name is not found in `custom_objects`.
    custom_objects: Optional Dict from name to object; takes precedence over
      `module_objects` and is also used to resolve function names found among
      the config values.
    printable_module_name: Label for the kind of object, used in error
      messages.

  Returns:
    A tuple `(cls, cls_config)`. `cls` is the object registered under
    `config['class_name']`. `cls_config` is a shallow copy of
    `config['config']` in which passively serialized objects and names of
    custom functions have been replaced by the objects they stand for.

  Raises:
    ValueError: If `config` is not a dict with 'class_name' and 'config'
      entries, or if the class name is not registered.
  """
  if (not isinstance(config, dict) or 'class_name' not in config or
      'config' not in config):
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = _get_registered_object(class_name, custom_objects=custom_objects,
                               module_objects=module_objects)
  if cls is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + class_name)

  serialized_config = config['config']
  # Resolve into a shallow copy: writing the deserialized objects back into
  # the caller's dict would leave their serialized config holding live
  # objects (no longer JSON-compatible) after a single deserialization.
  cls_config = serialized_config.copy()

  for key, item in serialized_config.items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      cls_config[key] = _deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
    elif isinstance(item, six.string_types):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the custom
      # objects during deserialization, we convert the string back to the
      # original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      registered = _get_registered_object(item, custom_objects)
      if tf_inspect.isfunction(registered):
        cls_config[key] = registered

  return (cls, cls_config)
331
332
def _get_registered_object(name, custom_objects=None, module_objects=None):
  """Looks up `name` among the custom objects first, then the module objects.

  Args:
    name: The key to look up.
    custom_objects: Optional Dict from name to object; checked first.
    module_objects: Optional Dict from name to object; checked second.

  Returns:
    The value stored under `name` in the first non-empty dict that contains
    it, or None if neither contains it.
  """
  # Order matters: an entry in `custom_objects` shadows one with the same name
  # in `module_objects`.
  for registry in (custom_objects, module_objects):
    if registry and name in registry:
      return registry[name]
  return None
339
340