1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Utility classes for testing checkpointing."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.eager import context
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import ops as ops_lib
24from tensorflow.python.ops import gen_lookup_ops
25from tensorflow.python.training import saver as saver_module
26
27
class CheckpointedOp(object):
  """Op with a custom checkpointing implementation.

  Wraps a mutable hash table resource and exposes a custom `SaveableObject`.
  The saveable writes the table's contents to checkpoints as two tensors:
  all keys and all values. It restores them by importing those tensors back
  into the table.

  Defined as part of the test because the MutableHashTable Python code is
  currently in contrib.
  """

  # pylint: disable=protected-access
  def __init__(self, name, table_ref=None, key_dtype=dtypes.string,
               value_dtype=dtypes.float32):
    """Wraps `table_ref`, or creates a new mutable hash table.

    Args:
      name: Name used for a newly created table op (when `table_ref` is
        `None`) and for the saveable. It is also the prefix of the checkpoint
        entry names `<name>-keys` and `<name>-values`.
      table_ref: Optional existing table handle to wrap. When `None`, a new
        `mutable_hash_table_v2` op is created.
      key_dtype: Key dtype of the table. Must match `table_ref` when one is
        given, because it is also used when exporting. Defaults to
        `dtypes.string`.
      value_dtype: Value dtype of the table. Must match `table_ref` when one
        is given. Defaults to `dtypes.float32`.
    """
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    if table_ref is None:
      table_ref = gen_lookup_ops.mutable_hash_table_v2(
          key_dtype=key_dtype, value_dtype=value_dtype, name=name)
    self.table_ref = table_ref
    self._name = name
    # In graph mode the saveable is built once and registered in the graph's
    # SAVEABLE_OBJECTS collection. In eager mode nothing is registered here;
    # the `saveable` property builds a fresh saveable on each access instead.
    if not context.executing_eagerly():
      self._saveable = CheckpointedOp.CustomSaveable(self, name)
      ops_lib.add_to_collection(ops_lib.GraphKeys.SAVEABLE_OBJECTS,
                                self._saveable)

  @property
  def name(self):
    """The name this op was constructed with."""
    return self._name

  @property
  def saveable(self):
    """Returns the `CustomSaveable` describing how to checkpoint this table."""
    if context.executing_eagerly():
      # Built on demand. Presumably this is so the eagerly exported tensors
      # reflect the table's current contents.
      return CheckpointedOp.CustomSaveable(self, self.name)
    return self._saveable

  def insert(self, keys, values):
    """Returns the op/result of inserting `keys` -> `values` into the table."""
    return gen_lookup_ops.lookup_table_insert_v2(self.table_ref, keys, values)

  def lookup(self, keys, default):
    """Looks up `keys`, returning `default` for keys not in the table."""
    return gen_lookup_ops.lookup_table_find_v2(self.table_ref, keys, default)

  def keys(self):
    """Returns a tensor of all keys currently in the table."""
    return self._export()[0]

  def values(self):
    """Returns a tensor of all values, ordered like `keys()`."""
    return self._export()[1]

  def _export(self):
    """Exports the whole table as a `(keys, values)` pair of tensors."""
    return gen_lookup_ops.lookup_table_export_v2(
        self.table_ref, self._key_dtype, self._value_dtype)

  class CustomSaveable(saver_module.BaseSaverBuilder.SaveableObject):
    """A custom saveable for CheckpointedOp.

    Saves the table as two checkpoint entries, `<name>-keys` and
    `<name>-values`. Restores the table by importing them back.
    """

    def __init__(self, table, name):
      """Builds save specs from a snapshot export of `table`.

      Args:
        table: The `CheckpointedOp` to save. It becomes this saveable's `op`.
        name: Saveable name and prefix of the two checkpoint entry names.
      """
      keys, values = table._export()
      save_spec = saver_module.BaseSaverBuilder.SaveSpec
      # The order matters: `restore` expects keys first and values second.
      specs = [
          save_spec(keys, "", name + "-keys"),
          save_spec(values, "", name + "-values"),
      ]
      super(CheckpointedOp.CustomSaveable, self).__init__(table, specs, name)

    def restore(self, restore_tensors, shapes):
      """Imports restored keys/values into the wrapped table.

      Args:
        restore_tensors: Restored tensors, in the same order as the specs
          built in `__init__` (keys first, then values).
        shapes: Unused.

      Returns:
        The result of the `lookup_table_import_v2` op.
      """
      del shapes  # Unused.
      return gen_lookup_ops.lookup_table_import_v2(
          self.op.table_ref, restore_tensors[0], restore_tensors[1])
  # pylint: enable=protected-access
92