1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14#==============================================================================
15"""Lookup operations."""
16# pylint: disable=g-bad-name
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import functools
23import six
24
25from tensorflow.python.compat import compat as fwd_compat
26from tensorflow.python.eager import context
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import control_flow_ops
35from tensorflow.python.ops import gen_lookup_ops
36from tensorflow.python.ops import math_ops
37from tensorflow.python.ops import string_ops
38# go/tf-wildcard-import
39# pylint: disable=wildcard-import
40from tensorflow.python.ops.gen_lookup_ops import *
41from tensorflow.python.training.saver import BaseSaverBuilder
42# pylint: enable=wildcard-import
43from tensorflow.python.training.tracking import base as trackable_base
44from tensorflow.python.training.tracking import tracking as trackable
45from tensorflow.python.util import compat
46from tensorflow.python.util.deprecation import deprecated
47from tensorflow.python.util.tf_export import tf_export
48
49
@tf_export(v1=["initialize_all_tables"])
@deprecated(None, "Use `tf.tables_initializer` instead.")
def initialize_all_tables(name="init_all_tables"):
  """Deprecated alias of `tables_initializer`.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables of the default graph. When no tables
    have been registered, the returned Op is a NoOp.
  """
  # All of the work lives in the non-deprecated entry point.
  return tables_initializer(name=name)
63
64
@tf_export(v1=["initializers.tables_initializer", "tables_initializer"])
def tables_initializer(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  See the [Low Level Intro](https://www.tensorflow.org/guide/low_level_intro#feature_columns)
  guide for an example of usage.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op grouping every op in the `TABLE_INITIALIZERS` collection. If that
    collection is empty, a NoOp with the same name is returned instead.
  """
  pending_initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
  if not pending_initializers:
    # Nothing registered: still hand back a runnable op.
    return control_flow_ops.no_op(name=name)
  return control_flow_ops.group(*pending_initializers, name=name)
83
84
def _check_table_dtypes(table, key_dtype, value_dtype):
  """Verifies that `key_dtype`/`value_dtype` agree with the dtypes of `table`.

  Only the `base_dtype` of each argument is compared against the table. The
  key is checked before the value, so a key mismatch is reported first.

  Args:
    table: The table whose `key_dtype` and `value_dtype` are the reference.
    key_dtype: The key data type to check.
    value_dtype: The value data type to check.

  Raises:
    TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
      types.
  """
  expectations = (
      (key_dtype, table.key_dtype,
       "Invalid key dtype, expected %s but got %s."),
      (value_dtype, table.value_dtype,
       "Invalid value dtype, expected %s but got %s."),
  )
  for actual, expected, message in expectations:
    if actual.base_dtype != expected:
      raise TypeError(message % (expected, actual))
103
104
class LookupInterface(trackable.TrackableResource):
  """Represent a lookup table that persists across different steps.

  Subclasses are expected to implement `_create_resource`, `name`, `size` and
  `lookup`; the implementations here only raise `NotImplementedError`.
  """

  def __init__(self, key_dtype, value_dtype):
    """Construct a lookup table interface.

    Args:
      key_dtype: The table key type; normalized with `dtypes.as_dtype`.
      value_dtype: The table value type; normalized with `dtypes.as_dtype`.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)
    super(LookupInterface, self).__init__()

  def _create_resource(self):
    raise NotImplementedError

  @property
  def key_dtype(self):
    """The table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The table value dtype."""
    return self._value_dtype

  @property
  def name(self):
    """The name of the table."""
    # Raise (rather than return the exception class) so a subclass that forgets
    # to implement `name` fails loudly instead of producing name scopes such as
    # "<class 'NotImplementedError'>_Size".
    raise NotImplementedError

  def size(self, name=None):
    """Compute the number of elements in this table."""
    raise NotImplementedError

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values."""
    raise NotImplementedError
144
145
class InitializableLookupTableBase(LookupInterface):
  """Initializable lookup table interface.

  An initializable lookup table persists across different steps. Its contents
  are populated by the initializer passed at construction time.
  """

  def __init__(self, default_value, initializer):
    """Construct a table object from a table reference.

    It requires a table initializer object (subclass of `TableInitializerBase`).
    The initializer provides the table key and value types, as well as the op
    to initialize the table. The caller is responsible for executing the
    initialization op.

    Both the table resource and the initialization op are created here, inside
    `ops.init_scope()`.

    Args:
      default_value: The value to use if a key is missing in the table. It is
        converted to a tensor of the table's value dtype and must be a scalar.
      initializer: The table initializer to use.
    """
    super(InitializableLookupTableBase, self).__init__(initializer.key_dtype,
                                                       initializer.value_dtype)
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=self._value_dtype)
    # Used purely as a check: `merge_with` raises if the static shape of the
    # default value is incompatible with a scalar. The merged shape is unused.
    self._default_value.get_shape().merge_with(tensor_shape.scalar())
    if isinstance(initializer, trackable_base.Trackable):
      self._initializer = self._track_trackable(
          initializer, "_initializer")
    else:
      # `_initialize` always reads `self._initializer`; keep a plain, untracked
      # reference for initializers that are not `Trackable`.
      self._initializer = initializer
    with ops.init_scope():
      self._resource_handle = self._create_resource()
      self._init_op = self._initialize()

  def _initialize(self):
    # The initializer builds the op that fills this table.
    return self._initializer.initialize(self)

  @property
  def default_value(self):
    """The default value of the table."""
    return self._default_value

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: A name for the operation (optional).

    Returns:
      A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
    """
    key_tensor = keys
    if isinstance(keys, sparse_tensor.SparseTensor):
      # Only the values of a sparse input are looked up; indices and dense
      # shape are reused for the sparse result below.
      key_tensor = keys.values

    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))

    with ops.name_scope(
        name, "%s_Lookup" % self.name,
        (self.resource_handle, key_tensor, self._default_value)):
      values = gen_lookup_ops.lookup_table_find_v2(
          self.resource_handle, key_tensor, self._default_value)

    values.set_shape(key_tensor.get_shape())
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
    else:
      return values
230
231
class InitializableLookupTableBaseV1(InitializableLookupTableBase):
  """TF1 variant of `InitializableLookupTableBase` exposing `initializer`."""

  @property
  def initializer(self):
    """The op built by `_initialize()` when the table was constructed."""
    init_op = self._init_op
    return init_op
237
238
@tf_export("lookup.StaticHashTable", v1=[])
class StaticHashTable(InitializableLookupTableBase):
  """A generic hash table implementation.

  The contents of the table come entirely from the initializer passed to the
  constructor; the table is immutable after initialization.

  Example usage:

  ```python
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys, values), -1)
  out = table.lookup(input_tensor)
  ```

  The initialization op is built when the table is constructed. When building
  a graph, that op must be run before the table is used, e.g. via
  `tf.compat.v1.tables_initializer()`.
  """

  def __init__(self, initializer, default_value, name=None):
    """Creates a `StaticHashTable` object.

    Creates a table, the type of its keys and values are specified by the
    initializer. After initialization the table is immutable.

    Args:
      initializer: The table initializer to use. See `HashTable` kernel for
        supported key and value types.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional). Defaults to "hash_table".
    """
    self._initializer = initializer
    self._default_value = default_value
    # These attributes are read by `_create_resource`, which runs inside the
    # base-class constructor, so they must be set before calling it.
    self._shared_name = self._initializer._shared_name  # pylint: disable=protected-access
    self._name = name or "hash_table"
    self._table_name = None
    super(StaticHashTable, self).__init__(default_value, initializer)
    self._value_shape = self._default_value.get_shape()

  def _create_resource(self):
    table_ref = gen_lookup_ops.hash_table_v2(
        shared_name=self._shared_name,
        key_dtype=self._initializer.key_dtype,
        value_dtype=self._initializer.value_dtype,
        name=self._name)
    # Op names only exist in graph mode; there the table name is the last
    # component of the op name. In eager mode the table has no name.
    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensors containing all values in the table.
    """
    with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]):
      exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
          self.resource_handle, self._key_dtype, self._value_dtype)

    exported_values.set_shape(exported_keys.get_shape().concatenate(
        self._value_shape))
    return exported_keys, exported_values
312
313
@tf_export(v1=["lookup.StaticHashTable"])
class StaticHashTableV1(StaticHashTable):
  """TF1 variant of `StaticHashTable` that exposes its `initializer` op."""

  @property
  def initializer(self):
    """The op built by `_initialize()` when the table was constructed."""
    init_op = self._init_op
    return init_op
320
321
# Kept only for backwards compatibility; slated for removal in TF 2.0.
class HashTable(StaticHashTableV1):
  """Legacy name for `StaticHashTableV1` that additionally exposes `init`."""

  @property
  def init(self):
    """Alias of `initializer`, retained for older call sites."""
    init_op = self.initializer
    return init_op
328
329
class TableInitializerBase(trackable_base.Trackable):
  """Base class for lookup table initializers.

  Subclasses implement `initialize(table)`, which returns the op that fills
  `table`.
  """

  def __init__(self, key_dtype, value_dtype):
    """Construct a table initializer object.

    Args:
      key_dtype: Type of the table keys; normalized with `dtypes.as_dtype`.
      value_dtype: Type of the table values; normalized with `dtypes.as_dtype`.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)

  @property
  def key_dtype(self):
    """The expected table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The expected table value dtype."""
    return self._value_dtype

  def initialize(self, table):
    """Returns the table initialization op."""
    raise NotImplementedError

  @property
  def _shared_name(self):
    """Returns a shared name to be used by the table.

    Empty in graph mode; a fresh unique id string when executing eagerly.
    """
    if not context.executing_eagerly():
      return ""
    # Ensure a unique name when eager execution is enabled to avoid spurious
    # sharing issues.
    # TODO(rohanj): Use context.shared_name() instead.
    return str(ops.uid())
367
368
@tf_export("lookup.KeyValueTensorInitializer")
class KeyValueTensorInitializer(TableInitializerBase):
  """Table initializers given `keys` and `values` tensors."""

  def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
    """Constructs a table initializer object based on keys and values tensors.

    Args:
      keys: The tensor for the keys.
      values: The tensor for the values.
      key_dtype: The `keys` data type. Used when `keys` is a python array.
      value_dtype: The `values` data type. Used when `values` is a python array.
      name: A name for the operation (optional). Defaults to "key_value_init".
    """
    # Convert inside `init_scope` so the tensors are created outside of any
    # function currently being built.
    with ops.init_scope():
      self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
      self._values = ops.convert_to_tensor(
          values, dtype=value_dtype, name="values")
    scope_name = "key_value_init" if name is None else name
    if context.executing_eagerly():
      # Ensure a unique name when eager execution is enabled to avoid spurious
      # sharing issues.
      # TODO(rohanj): Use context.shared_name() instead.
      scope_name += str(ops.uid())
    self._name = scope_name

    super(KeyValueTensorInitializer, self).__init__(self._keys.dtype,
                                                    self._values.dtype)

  def initialize(self, table):
    """Initializes the given `table` with `keys` and `values` tensors.

    The returned op is also added to the `TABLE_INITIALIZERS` collection.

    Args:
      table: The table to initialize.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    _check_table_dtypes(table, self._keys.dtype, self._values.dtype)
    scope_inputs = (table.resource_handle, self._keys, self._values)
    with ops.name_scope(self._name, values=scope_inputs):
      if not fwd_compat.forward_compatible(2018, 9, 19):
        # To maintain forward compatibility, use the old implementation.
        init_op = gen_lookup_ops.initialize_table_v2(table.resource_handle,
                                                     self._keys, self._values)
      else:
        init_op = gen_lookup_ops.lookup_table_import_v2(
            table.resource_handle, self._keys, self._values)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op
422
423
class TextFileIndex(object):
  """Special column indices understood by `TextFileInitializer`.

  Non-negative indices select a column of the delimiter-split line; these
  negative sentinels select something other than a column.
  """
  # Use the whole line content as the key or value.
  WHOLE_LINE = -2
  # Use the zero-based line number as the key or value.
  LINE_NUMBER = -1
427
428
@tf_export("lookup.TextFileInitializer")
class TextFileInitializer(TableInitializerBase):
  """Table initializers from a text file.

  This initializer assigns one entry in the table for each line in the file.

  The key and value type of the table to initialize is given by `key_dtype` and
  `value_dtype`.

  The key and value content to get from each line is specified by
  the `key_index` and `value_index`.

  * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string (for keys, integer data types are accepted as well).
  * A value `>=0` means use the index (starting at zero) of the split line based
    on `delimiter`.

  For example if we have a file with the following content:

  ```
  emerson 10
  lake 20
  palmer 30
  ```

  The following snippet initializes a table with the first column as keys and
  second column as values:

  * `emerson -> 10`
  * `lake -> 20`
  * `palmer -> 30`

  ```python
  table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
      "test.txt", tf.string, 0, tf.int64, 1, delimiter=" "), -1)
  ...
  out = table.lookup(input_tensor)
  ```

  Similarly to initialize the whole line as keys and the line number as values.

  * `emerson 10 -> 0`
  * `lake 20 -> 1`
  * `palmer 30 -> 2`

  ```python
  table = tf.lookup.StaticHashTable(tf.lookup.TextFileInitializer(
      "test.txt", tf.string, tf.lookup.TextFileIndex.WHOLE_LINE,
      tf.int64, tf.lookup.TextFileIndex.LINE_NUMBER, delimiter=" "), -1)
  ...
  out = table.lookup(input_tensor)
  ```

  When building a graph, the initialization op created by `initialize` (which
  is added to the `TABLE_INITIALIZERS` collection) must be run before the table
  is used, e.g. via `tf.compat.v1.tables_initializer()`.
  """

  def __init__(self,
               filename,
               key_dtype,
               key_index,
               value_dtype,
               value_index,
               vocab_size=None,
               delimiter="\t",
               name=None):
    """Constructs a table initializer object to populate from a text file.

    It generates one key-value pair per line. The type of table key and
    value are specified by `key_dtype` and `value_dtype`, respectively.
    Similarly the content of the key and value are specified by the key_index
    and value_index.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string (for keys, integer data types are accepted as well).
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization.
        The path must be accessible from wherever the graph is initialized
        (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
      key_dtype: The `key` data type.
      key_index: the index that represents information of a line to get the
        table 'key' values from.
      value_dtype: The `value` data type.
      value_index: the index that represents information of a line to get the
        table 'value' values from.
      vocab_size: The number of elements in the file, if known. Must be
        positive when given.
      delimiter: The delimiter to separate fields in a line.
      name: A name for the operation (optional).

    Raises:
      ValueError: when the filename is empty, when `key_index`, `value_index`
        or `vocab_size` is invalid, or when the table key and value data types
        do not match the expected data types.
    """
    if not isinstance(filename, ops.Tensor) and not filename:
      raise ValueError("Filename required for %s." % name)

    self._filename_arg = filename
    key_dtype = dtypes.as_dtype(key_dtype)
    value_dtype = dtypes.as_dtype(value_dtype)

    if key_index < -2:
      raise ValueError("Invalid key index %s." % (key_index))

    if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64:
      raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." %
                       (dtypes.int64, key_dtype))
    if ((key_index == TextFileIndex.WHOLE_LINE) and
        (not key_dtype.is_integer) and (key_dtype != dtypes.string)):
      raise ValueError(
          "Signature mismatch. Keys must be integer or string, got %s." %
          key_dtype)
    if value_index < -2:
      raise ValueError("Invalid value index %s." % (value_index))

    if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
      raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
                       (dtypes.int64, value_dtype))
    if value_index == TextFileIndex.WHOLE_LINE and value_dtype != dtypes.string:
      raise ValueError("Signature mismatch. Values must be dtype %s, got %s." %
                       (dtypes.string, value_dtype))

    if (vocab_size is not None) and (vocab_size <= 0):
      raise ValueError("Invalid vocab_size %s." % vocab_size)

    self._key_index = key_index
    self._value_index = value_index
    self._vocab_size = vocab_size
    self._delimiter = delimiter
    self._name = name
    # Track the file as an asset so it is recorded alongside the object graph.
    self._filename = self._track_trackable(
        trackable.TrackableAsset(filename),
        "_filename")

    super(TextFileInitializer, self).__init__(key_dtype, value_dtype)

  def initialize(self, table):
    """Initializes the table from a text file.

    The returned op is also added to the `TABLE_INITIALIZERS` collection.

    Args:
      table: The table to be initialized.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
      key and value data types.
    """
    _check_table_dtypes(table, self.key_dtype, self.value_dtype)
    with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)):
      filename = ops.convert_to_tensor(
          self._filename, dtypes.string, name="asset_filepath")
      init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
          table.resource_handle, filename, self._key_index, self._value_index,
          -1 if self._vocab_size is None else self._vocab_size, self._delimiter)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    # If the filename tensor is anything other than a string constant (e.g.,
    # if it is a placeholder) then it does not make sense to track it as an
    # asset.
    if not context.executing_eagerly() and constant_op.is_constant(filename):
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
    return init_op

  @property
  def _shared_name(self):
    """Returns a shared name derived from the file and column configuration."""
    if self._vocab_size:
      # Keep the shared_name:
      # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>
      shared_name = "hash_table_%s_%d_%s_%s" % (
          self._filename_arg, self._vocab_size, self._key_index,
          self._value_index)
    else:
      # Keep the shared_name
      # <table_type>_<filename>_<key_index>_<value_index>
      shared_name = "hash_table_%s_%s_%s" % (self._filename_arg,
                                             self._key_index, self._value_index)
    return shared_name
610
611
class TextFileStringTableInitializer(TextFileInitializer):
  """Table initializer for `int64` IDs to string tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.LINE_NUMBER,
               value_column_index=TextFileIndex.WHOLE_LINE,
               vocab_size=None,
               delimiter="\t",
               name="text_file_string_table_init"):
    """Constructs an initializer for an id-to-string table from a text file.

    It populates a table whose key and value types are int64 and string,
    respectively. It generates one key-value pair per line.
    The content of the key and value are specified by `key_column_index`
    and `value_column_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization.
        The path must be accessible from wherever the graph is initialized
        (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the keys
        from. The default is to use the line number, starting from zero.
      value_column_index: The column index from the text file to get the
        values from. The default is to use the whole line content.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.

    Raises:
      ValueError: when the filename is empty, when a column index or
        `vocab_size` is invalid, or when the table key and value data types do
        not match the expected data types.
    """
    super(TextFileStringTableInitializer, self).__init__(
        filename,
        dtypes.int64,
        key_column_index,
        dtypes.string,
        value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)
661
662
class TextFileIdTableInitializer(TextFileInitializer):
  """Table initializer for string to `int64` IDs tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.WHOLE_LINE,
               value_column_index=TextFileIndex.LINE_NUMBER,
               vocab_size=None,
               delimiter="\t",
               name="text_file_id_table_init",
               key_dtype=dtypes.string):
    """Constructs an initializer for a string-to-id table from a text file.

    It populates a table whose key type is `key_dtype` (string by default) and
    whose value type is int64. It generates one key-value pair per line.
    The content of the key and value are specified by `key_column_index`
    and `value_column_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string (for keys, integer data types are accepted as well).
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization.
        The path must be accessible from wherever the graph is initialized
        (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the `key`
        values from. The default is to use the whole line content.
      value_column_index: The column index from the text file to get the `value`
        values from. The default is to use the line number, starting from zero.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.
      key_dtype: The `key` data type. Defaults to `tf.string`.

    Raises:
      ValueError: when the filename is empty, when a column index or
        `vocab_size` is invalid, or when the table key and value data types do
        not match the expected data types.
    """
    super(TextFileIdTableInitializer, self).__init__(
        filename,
        key_dtype,
        key_column_index,
        dtypes.int64,
        value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)
714
715
class HasherSpec(collections.namedtuple("HasherSpec", ("hasher", "key"))):
  """Describes the hashing function used to assign keys to hash buckets.

  Fields:
    hasher: Name of the hashing function to use, e.g. "fasthash" or
      "stronghash".
    key: Key for the hashing function when it takes one (currently only the
      strong hash does); may be None for unkeyed hashers.
  """
  # No per-instance __dict__: instances remain plain immutable tuples.
  __slots__ = ()
729
730
# Default spec: the unkeyed "fasthash" hashing function.
FastHashSpec = HasherSpec(hasher="fasthash", key=None)  # pylint: disable=invalid-name
732
733
class StrongHashSpec(HasherSpec):
  """A structure to specify a key of the strong keyed hash spec.

  The strong hash requires a `key`, which is a list of 2 unsigned integer
  numbers. These should be non-zero; random numbers generated from random.org
  would be a fine choice.

  Fields:
    key: The key to be used by the keyed hashing function.
  """
  __slots__ = ()

  def __new__(cls, key):
    if len(key) != 2:
      raise ValueError("key must have size 2, got %s." % len(key))

    if not isinstance(key[0], compat.integral_types) or not isinstance(
        key[1], compat.integral_types):
      # Wrap `key` in a 1-tuple: a tuple key would otherwise be unpacked by `%`
      # and fail with "not all arguments converted" instead of this message.
      raise TypeError("Invalid key %s. Must be unsigned integer values." %
                      (key,))

    # `super(StrongHashSpec, cls)` is the correct argument order; the reversed
    # form breaks as soon as `cls` is a subclass of StrongHashSpec.
    return super(StrongHashSpec, cls).__new__(cls, "stronghash", key)
755
756
def _as_string(tensor):
  """Returns `tensor` unchanged if its base dtype is already string.

  Any other tensor is converted with `string_ops.as_string`.
  """
  already_string = dtypes.string == tensor.dtype.base_dtype
  return tensor if already_string else string_ops.as_string(tensor)
761
762
763class IdTableWithHashBuckets(LookupInterface):
764  """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
765
766  For example, if an instance of `IdTableWithHashBuckets` is initialized with a
767  string-to-id table that maps:
768
769  * `emerson -> 0`
770  * `lake -> 1`
771  * `palmer -> 2`
772
773  The `IdTableWithHashBuckets` object will performs the following mapping:
774
775  * `emerson -> 0`
776  * `lake -> 1`
777  * `palmer -> 2`
778  * `<other term> -> bucket_id`, where bucket_id will be between `3` and
779  `3 + num_oov_buckets - 1`, calculated by:
780  `hash(<term>) % num_oov_buckets + vocab_size`
781
782  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,
783  the lookup result is `[0, 1, 2, 4, 7]`.
784
785  If `table` is None, only out-of-vocabulary buckets are used.
786
787  Example usage:
788
789  ```python
790  num_oov_buckets = 3
791  input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimnson"])
792  table = tf.IdTableWithHashBuckets(
793      tf.StaticHashTable(tf.TextFileIdTableInitializer(filename),
794                         default_value),
795      num_oov_buckets)
796  out = table.lookup(input_tensor).
797  table.init.run()
798  print(out.eval())
799  ```
800
801  The hash function used for generating out-of-vocabulary buckets ID is handled
802  by `hasher_spec`.
803  """
804
805  def __init__(self,
806               table,
807               num_oov_buckets,
808               hasher_spec=FastHashSpec,
809               name=None,
810               key_dtype=None):
811    """Construct a `IdTableWithHashBuckets` object.
812
813    Args:
814      table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
815      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
816      hasher_spec: A `HasherSpec` to specify the hash function to use for
817        assignation of out-of-vocabulary buckets  (optional).
818      name: A name for the operation (optional).
819      key_dtype: Data type of keys passed to `lookup`. Defaults to
820        `table.key_dtype` if `table` is specified, otherwise `tf.string`.
821        Must be string or integer, and must be castable to `table.key_dtype`.
822
823    Raises:
824      ValueError: when `table` in None and `num_oov_buckets` is not positive.
825      TypeError: when `hasher_spec` is invalid.
826    """
827    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
828    # characters to use as table name.
829    if name:
830      name = name.rstrip("/")
831    if table:
832      if key_dtype is None:
833        key_dtype = table.key_dtype
834      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
835      if table.key_dtype not in supported_table_key_dtypes:
836        raise TypeError("Invalid key dtype, expected one of %s, but got %s." %
837                        (supported_table_key_dtypes, key_dtype))
838      if table.key_dtype.is_integer != key_dtype.is_integer:
839        raise TypeError("Invalid key dtype, expected %s but got %s." %
840                        ("integer" if key_dtype.is_integer else "non-integer",
841                         table.key_dtype))
842      if table.value_dtype != dtypes.int64:
843        raise TypeError("Invalid value dtype, expected %s but got %s." %
844                        (dtypes.int64, table.value_dtype))
845      self._table = table
846      name = name or self._table.name
847    else:
848      if num_oov_buckets <= 0:
849        raise ValueError("oov_buckets must be > 0 if no table is supplied.")
850      key_dtype = dtypes.string if key_dtype is None else key_dtype
851      self._table = None
852      name = name or "hash_bucket"
853    if (not key_dtype.is_integer) and (dtypes.string != key_dtype):
854      raise TypeError(
855          "Invalid key_dtype, expected integer or string, got %s." % key_dtype)
856    self._num_oov_buckets = num_oov_buckets
857
858    if not isinstance(hasher_spec, HasherSpec):
859      raise TypeError(
860          "hasher_spec must be of type HasherSpec, got %s" % hasher_spec)
861    self._hasher_spec = hasher_spec
862    if name:
863      self._table_name = name.split("/")[-1]
864    else:
865      self._table_name = None
866    super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64)
867
868  def _create_resource(self):
869    if self._table is not None:
870      return self._table._create_resource()  # pylint: disable=protected-access
871    return None
872
873  def _initialize(self):
874    if self._table is not None:
875      return self._table._initialize()  # pylint: disable=protected-access
876    with ops.name_scope(None, "init"):
877      return control_flow_ops.no_op()
878
879  @property
880  def initializer(self):
881    if self._table is not None:
882      return self._table._init_op  # pylint: disable=protected-access
883    with ops.name_scope(None, "init"):
884      return control_flow_ops.no_op()
885
886  @property
887  @deprecated("2018-12-15", "Use `initializer` instead.")
888  def init(self):
889    return self.initializer
890
891  @property
892  def resource_handle(self):
893    if self._table is not None:
894      return self._table.resource_handle
895    return None
896
897  @property
898  def name(self):
899    return self._table_name
900
901  def size(self, name=None):
902    """Compute the number of elements in this table."""
903    with ops.name_scope(name, "%s_Size" % self.name):
904      if self._table:
905        tsize = self._table.size()
906      else:
907        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
908      return tsize + self._num_oov_buckets
909
910  def _get_string_to_hash_bucket_fn(self, hasher_spec):
911    """Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
912    if not isinstance(hasher_spec, HasherSpec):
913      raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec)
914    if hasher_spec.hasher == "fasthash":
915      return string_ops.string_to_hash_bucket_fast
916    if hasher_spec.hasher == "legacy":
917      return string_ops.string_to_hash_bucket
918    if hasher_spec.hasher == "stronghash":
919      return functools.partial(
920          string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
921    raise ValueError("Unknown hasher %s" % hasher_spec.hasher)
922
923  def lookup(self, keys, name=None):
924    """Looks up `keys` in the table, outputs the corresponding values.
925
926    It assigns out-of-vocabulary keys to buckets based in their hashes.
927
928    Args:
929      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
930      name: Optional name for the op.
931
932    Returns:
933      A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.
934
935    Raises:
936      TypeError: when `keys` doesn't match the table key data type.
937    """
938    if keys.dtype.base_dtype != self._key_dtype:
939      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
940                      (self._key_dtype, keys.dtype))
941    values = keys
942    if isinstance(keys, sparse_tensor.SparseTensor):
943      values = keys.values
944    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
945      values = math_ops.cast(values, dtypes.int64)
946
947    if self._num_oov_buckets == 0:
948      ids = self._table.lookup(values, name=name)
949    else:
950      # TODO(yleon): Consider moving this functionality to its own kernel.
951      with ops.name_scope(name, "%s_Lookup" % self.name):
952        str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
953            self._hasher_spec)
954        buckets = str_to_hash_bucket(
955            _as_string(values),
956            num_buckets=self._num_oov_buckets,
957            name="hash_bucket")
958        if self._table:
959          ids = self._table.lookup(values)
960          buckets = math_ops.add(buckets, self._table.size())
961          is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
962          ids = array_ops.where(is_id_non_default, ids, buckets)
963        else:
964          ids = buckets
965    if isinstance(keys, sparse_tensor.SparseTensor):
966      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
967    return ids
968
969
@tf_export("lookup.StaticVocabularyTable", v1=[])
class StaticVocabularyTable(LookupInterface):
  """String to Id table wrapper that assigns out-of-vocabulary keys to buckets.

  For example, if an instance of `StaticVocabularyTable` is initialized with a
  string-to-id initializer that maps:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`

  The `Vocabulary` object will perform the following mapping:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  * `<other term> -> bucket_id`, where bucket_id will be between `3` and
  `3 + num_oov_buckets - 1`, calculated by:
  `hash(<term>) % num_oov_buckets + vocab_size`

  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]` and
  `num_oov_buckets` is at least 5, the lookup result could be
  `[0, 1, 2, 4, 7]`; the exact ids of the out-of-vocabulary terms depend on
  their hashes.

  If `initializer` is None, only out-of-vocabulary buckets are used.

  Example usage:

  ```python
  num_oov_buckets = 3
  input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimson"])
  table = tf.lookup.StaticVocabularyTable(
      tf.TextFileIdTableInitializer(filename), num_oov_buckets)
  out = table.lookup(input_tensor)
  print(out)
  ```

  The hash function used for generating out-of-vocabulary buckets ID is
  Fingerprint64.
  """

  def __init__(self,
               initializer,
               num_oov_buckets,
               lookup_key_dtype=None,
               name=None):
    """Construct a `StaticVocabularyTable` object.

    Args:
      initializer: A TableInitializerBase object that contains the data used to
        initialize the table. If None, then we only use out-of-vocab buckets.
      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys. Must
        be greater than zero.
      lookup_key_dtype: Data type of keys passed to `lookup`. Defaults to
        `initializer.key_dtype` if `initializer` is specified, otherwise
        `tf.string`. Must be string or integer, and must be castable to
        `initializer.key_dtype`.
      name: A name for the operation (optional).

    Raises:
      ValueError: when `num_oov_buckets` is not positive.
      TypeError: when lookup_key_dtype or initializer.key_dtype are not
        integer or string, when exactly one of them is an integer type, or
        when initializer.value_dtype != int64.
    """
    if num_oov_buckets <= 0:
      raise ValueError("oov_buckets must be > 0.")
    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    # characters to use as table name.
    if name:
      name = name.rstrip("/")
    if initializer:
      if lookup_key_dtype is None:
        lookup_key_dtype = initializer.key_dtype
      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
      if initializer.key_dtype not in supported_table_key_dtypes:
        raise TypeError("Invalid key dtype, expected one of %s, but got %s." %
                        (supported_table_key_dtypes, initializer.key_dtype))
      if initializer.key_dtype.is_integer != lookup_key_dtype.is_integer:
        raise TypeError(
            "Invalid key dtype, expected %s but got %s." %
            ("integer" if lookup_key_dtype.is_integer else "non-integer",
             initializer.key_dtype))
      if initializer.value_dtype != dtypes.int64:
        raise TypeError("Invalid value dtype, expected %s but got %s." %
                        (dtypes.int64, initializer.value_dtype))
      # -1 marks "not found" so `lookup` can fall back to a hash bucket.
      self._table = HashTable(initializer, default_value=-1)
      name = name or self._table.name
    else:
      # Without an initializer every key is hashed into a bucket. Honor an
      # explicitly requested lookup dtype (validated below); default to
      # strings as documented.
      if lookup_key_dtype is None:
        lookup_key_dtype = dtypes.string
      self._table = None
      name = name or "hash_bucket"
    if (not lookup_key_dtype.is_integer) and (dtypes.string !=
                                              lookup_key_dtype):
      raise TypeError("Invalid key_dtype, expected integer or string, got %s." %
                      lookup_key_dtype)
    self._num_oov_buckets = num_oov_buckets

    self._table_name = None
    if name is not None:
      self._table_name = name.split("/")[-1]
    super(StaticVocabularyTable, self).__init__(lookup_key_dtype, dtypes.int64)

  def _create_resource(self):
    if self._table is not None:
      return self._table._create_resource()  # pylint: disable=protected-access
    return None

  def _initialize(self):
    if self._table is not None:
      return self._table._initialize()  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def resource_handle(self):
    if self._table is not None:
      return self._table.resource_handle
    return None

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table."""
    with ops.name_scope(name, "%s_Size" % self.name):
      if self._table:
        tsize = self._table.size()
      else:
        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
      return tsize + self._num_oov_buckets

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based on their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys.dtype))
    values = keys
    if isinstance(keys, sparse_tensor.SparseTensor):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    # TODO(yleon): Consider moving this functionality to its own kernel.
    with ops.name_scope(name, "%s_Lookup" % self.name):
      buckets = string_ops.string_to_hash_bucket_fast(
          _as_string(values),
          num_buckets=self._num_oov_buckets,
          name="hash_bucket")
      if self._table:
        ids = self._table.lookup(values)
        # Shift bucket ids so they come after all in-vocabulary ids.
        buckets = math_ops.add(buckets, self._table.size())
        is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
        ids = array_ops.where(is_id_non_default, ids, buckets)
      else:
        ids = buckets
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    return ids
1142
1143
@tf_export(v1=["lookup.StaticVocabularyTable"])
class StaticVocabularyTableV1(StaticVocabularyTable):
  """TF1 flavour of `StaticVocabularyTable` that exposes an `initializer`."""

  @property
  def initializer(self):
    """The underlying table's init op, or a fresh no-op without a table."""
    if self._table is None:
      with ops.name_scope(None, "init"):
        return control_flow_ops.no_op()
    return self._table._init_op  # pylint: disable=protected-access
1153
1154
def index_table_from_file(vocabulary_file=None,
                          num_oov_buckets=0,
                          vocab_size=None,
                          default_value=-1,
                          hasher_spec=FastHashSpec,
                          key_dtype=dtypes.string,
                          name=None,
                          key_column_index=TextFileIndex.WHOLE_LINE,
                          value_column_index=TextFileIndex.LINE_NUMBER,
                          delimiter="\t"):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert tensor of strings (or
  integers, see `key_dtype`) into int64 IDs. The mapping can be initialized
  from a vocabulary file specified in `vocabulary_file`, where the whole line
  is the key and the zero-based line number is the ID.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.tables_initializer())` or `session.run(table.init)` once.

  To specify multi-column vocabulary files, use key_column_index and
  value_column_index and delimiter.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  table = tf.lookup.index_table_from_file(
      vocabulary_file="test.txt", num_oov_buckets=1)
  ids = table.lookup(features)
  ...
  tf.tables_initializer().run()

  ids.eval()  ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    key_dtype: The `key` data type. Integer keys are stored in the table as
      int64.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the whole line content.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the line number, starting from zero.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_file` is not set.
    ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
      than zero.
    TypeError: If `key_dtype` is neither an integer nor a string type.
  """
  if vocabulary_file is None or (
      isinstance(vocabulary_file, six.string_types) and not vocabulary_file):
    raise ValueError("vocabulary_file must be specified and must not be empty.")
  if num_oov_buckets < 0:
    raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
                     % num_oov_buckets)
  if vocab_size is not None and vocab_size < 1:
    # Include the file name in the message when it is statically known.
    vocab_file_value = vocabulary_file
    if isinstance(vocabulary_file, ops.Tensor):
      vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?"
    raise ValueError("vocab_size must be greater than 0, got %d. "
                     "vocabulary_file: %s" % (vocab_size, vocab_file_value))
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")

  with ops.name_scope(name, "string_to_index"):
    with ops.name_scope(None, "hash_table"):
      init = TextFileIdTableInitializer(
          vocabulary_file,
          vocab_size=vocab_size,
          key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
          name="table_init",
          key_column_index=key_column_index,
          value_column_index=value_column_index,
          delimiter=delimiter)

      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      # Wrap the table so misses are routed to hash buckets instead of
      # `default_value`.
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=key_dtype)

    return table
1272
1273
def index_table_from_tensor(vocabulary_list,
                            num_oov_buckets=0,
                            default_value=-1,
                            hasher_spec=FastHashSpec,
                            dtype=dtypes.string,
                            name=None):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert tensor of strings (or
  integers, see `dtype`) into int64 IDs. The mapping can be initialized from a
  string `vocabulary_list` 1-D tensor where each element is a key and
  corresponding index within the tensor is the value.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`. The bucket ID range is
  `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.tables_initializer())` or `session.run(table.init)` once.

  Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
  the table initializer op, it will throw a `FailedPreconditionError`.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  table = tf.lookup.index_table_from_tensor(
      vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  ids = table.lookup(features)
  ...
  tf.tables_initializer().run()

  ids.eval()  ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
  ```

  Args:
    vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to
      indices. The type of this object must be castable to `dtype`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignment of out-of-vocabulary buckets.
    dtype: The type of values passed to `lookup`. Only string and integers are
      supported.
    name: A name for this op (optional).

  Returns:
    The lookup table to map an input `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_list` is invalid, including when its dtype is
      not compatible with `dtype`.
    ValueError: If `num_oov_buckets` is negative.
    TypeError: If `dtype` is neither an integer nor a string type.
  """
  if vocabulary_list is None:
    raise ValueError("vocabulary_list must be specified.")

  if num_oov_buckets < 0:
    raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
                     % num_oov_buckets)

  if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
    raise TypeError("Only integer and string keys are supported.")

  with ops.name_scope(name, "string_to_index"):
    keys = ops.convert_to_tensor(vocabulary_list)
    if keys.dtype.is_integer != dtype.is_integer:
      raise ValueError("Expected %s, got %s." %
                       ("integer"
                        if dtype.is_integer else "non-integer", keys.dtype))
    if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
      raise ValueError("Expected %s, got %s." % (dtype, keys.dtype))
    # Each key maps to its position in `vocabulary_list`.
    num_elements = array_ops.size(keys)
    values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    with ops.name_scope(None, "hash_table"):
      # Integer vocabularies are stored as int64 keys.
      table_keys = math_ops.cast(
          keys, dtypes.int64) if keys.dtype.is_integer else keys
      init = KeyValueTensorInitializer(
          table_keys,
          values,
          table_keys.dtype.base_dtype,
          dtypes.int64,
          name="table_init")
      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=dtype)
    return table
1369
1370
def index_to_string_table_from_file(vocabulary_file,
                                    vocab_size=None,
                                    default_value="UNK",
                                    name=None,
                                    key_column_index=TextFileIndex.LINE_NUMBER,
                                    value_column_index=TextFileIndex.WHOLE_LINE,
                                    delimiter="\t"):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The table is initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the value and the
  zero-based line number is the index.

  Any input which does not have a corresponding index in the vocabulary file
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `session.run(tf.tables_initializer())` or `session.run(table.init)` once.

  To specify multi-column vocabulary files, use key_column_index and
  value_column_index and delimiter.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_file(
      vocabulary_file="test.txt", default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the line number, starting from zero.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the whole line content.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table to map a string values associated to a given index `int64`
    `Tensors`.

  Raises:
    ValueError: when `vocabulary_file` is empty.
    ValueError: when `vocab_size` is invalid.
  """
  if vocabulary_file is None or (
      isinstance(vocabulary_file, six.string_types) and not vocabulary_file):
    raise ValueError("vocabulary_file must be specified and must not be empty.")

  if vocab_size is not None and vocab_size < 1:
    raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)

  with ops.name_scope(name, "index_to_string"):
    init = TextFileStringTableInitializer(
        vocabulary_file,
        vocab_size=vocab_size,
        name="table_init",
        key_column_index=key_column_index,
        value_column_index=value_column_index,
        delimiter=delimiter)

    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(init, default_value)
1459
1460
def index_to_string_table_from_tensor(vocabulary_list,
                                      default_value="UNK",
                                      name=None):
  """Returns a lookup table that maps a `Tensor` of indices into strings.

  This operation constructs a lookup table to map int64 indices into string
  values. The mapping is initialized from a string `vocabulary_list` 1-D
  `Tensor` where each element is a value and the corresponding index within the
  tensor is the key.

  Any input which does not have a corresponding index in 'vocabulary_list'
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `session.run(tf.tables_initializer())` or `session.run(table.init)` once.

  Elements in `vocabulary_list` cannot have duplicates, otherwise when executing
  the table initializer op, it will throw a `FailedPreconditionError`.

  Sample Usages:

  ```python
  vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
  indices = tf.constant([1, 5], tf.int64)
  table = tf.lookup.index_to_string_table_from_tensor(
      vocabulary_list, default_value="UNKNOWN")
  values = table.lookup(indices)
  ...
  tf.tables_initializer().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    vocabulary_list: A 1-D string `Tensor` that specifies the strings to map
      from indices.
    default_value: The value to use for out-of-vocabulary indices.
    name: A name for this op (optional).

  Returns:
    The lookup table to map a string values associated to a given index `int64`
    `Tensors`.

  Raises:
    ValueError: when `vocabulary_list` is not set.
  """

  if vocabulary_list is None:
    raise ValueError("vocabulary_list must be specified.")

  with ops.name_scope(name, "index_to_string"):
    vocabulary_list = ops.convert_to_tensor(vocabulary_list, dtypes.string)
    # The key of each string is its position in `vocabulary_list`.
    num_elements = array_ops.size(vocabulary_list)
    keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64)

    init = KeyValueTensorInitializer(
        keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init")
    # TODO(yleon): Use a more efficient structure.
    return StaticHashTableV1(init, default_value)
1520
1521
1522class MutableHashTable(LookupInterface):
1523  """A generic mutable hash table implementation.
1524
1525  Data can be inserted by calling the insert method and removed by calling the
1526  remove method. It does not support initialization via the init method.
1527
1528  Example usage:
1529
1530  ```python
1531  table = tf.lookup.MutableHashTable(key_dtype=tf.string, value_dtype=tf.int64,
1532                                     default_value=-1)
1533  sess.run(table.insert(keys, values))
1534  out = table.lookup(query_keys)
1535  print(out.eval())
1536  ```
1537  """
1538
1539  def __init__(self,
1540               key_dtype,
1541               value_dtype,
1542               default_value,
1543               name="MutableHashTable",
1544               checkpoint=True):
1545    """Creates an empty `MutableHashTable` object.
1546
1547    Creates a table, the type of its keys and values are specified by key_dtype
1548    and value_dtype, respectively.
1549
1550    Args:
1551      key_dtype: the type of the key tensors.
1552      value_dtype: the type of the value tensors.
1553      default_value: The value to use if a key is missing in the table.
1554      name: A name for the operation (optional).
1555      checkpoint: if True, the contents of the table are saved to and restored
1556        from checkpoints. If `shared_name` is empty for a checkpointed table, it
1557        is shared using the table node name.
1558
1559    Returns:
1560      A `MutableHashTable` object.
1561
1562    Raises:
1563      ValueError: If checkpoint is True and no name was specified.
1564    """
1565    self._default_value = ops.convert_to_tensor(
1566        default_value, dtype=value_dtype)
1567    self._value_shape = self._default_value.get_shape()
1568    self._checkpoint = checkpoint
1569    self._key_dtype = key_dtype
1570    self._value_dtype = value_dtype
1571    self._name = name
1572
1573    self._shared_name = None
1574    if context.executing_eagerly():
1575      # TODO(allenl): This will leak memory due to kernel caching by the
1576      # shared_name attribute value (but is better than the alternative of
1577      # sharing everything by default when executing eagerly; hopefully creating
1578      # tables in a loop is uncommon).
1579      # TODO(rohanj): Use context.shared_name() instead.
1580      self._shared_name = "table_%d" % (ops.uid(),)
1581    super(MutableHashTable, self).__init__(key_dtype, value_dtype)
1582
1583    self._resource_handle = self._create_resource()
1584    if checkpoint:
1585      saveable = MutableHashTable._Saveable(self, name)
1586      if not context.executing_eagerly():
1587        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
1588
1589  def _create_resource(self):
1590    # The table must be shared if checkpointing is requested for multi-worker
1591    # training to work correctly. Use the node name if no shared_name has been
1592    # explicitly specified.
1593    use_node_name_sharing = self._checkpoint and self._shared_name is None
1594    if self._default_value.get_shape().ndims == 0:
1595      table_ref = gen_lookup_ops.mutable_hash_table_v2(
1596          shared_name=self._shared_name,
1597          use_node_name_sharing=use_node_name_sharing,
1598          key_dtype=self._key_dtype,
1599          value_dtype=self._value_dtype,
1600          name=self._name)
1601    else:
1602      table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
1603          shared_name=self._shared_name,
1604          use_node_name_sharing=use_node_name_sharing,
1605          key_dtype=self._key_dtype,
1606          value_dtype=self._value_dtype,
1607          value_shape=self._default_value.get_shape(),
1608          name=self._name)
1609
1610    if context.executing_eagerly():
1611      self._table_name = None
1612    else:
1613      self._table_name = table_ref.op.name.split("/")[-1]
1614    return table_ref
1615
1616  @property
1617  def name(self):
1618    return self._table_name
1619
1620  def size(self, name=None):
1621    """Compute the number of elements in this table.
1622
1623    Args:
1624      name: A name for the operation (optional).
1625
1626    Returns:
1627      A scalar tensor containing the number of elements in this table.
1628    """
1629    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
1630      with ops.colocate_with(self.resource_handle):
1631        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)
1632
1633  def remove(self, keys, name=None):
1634    """Removes `keys` and its associated values from the table.
1635
1636    If a key is not present in the table, it is silently ignored.
1637
1638    Args:
1639      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
1640        key type.
1641      name: A name for the operation (optional).
1642
1643    Returns:
1644      The created Operation.
1645
1646    Raises:
1647      TypeError: when `keys` do not match the table data types.
1648    """
1649    if keys.dtype != self._key_dtype:
1650      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
1651                      (self._key_dtype, keys.dtype))
1652
1653    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
1654                        (self.resource_handle, keys, self._default_value)):
1655      op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)
1656
1657    return op
1658
1659  def lookup(self, keys, name=None):
1660    """Looks up `keys` in a table, outputs the corresponding values.
1661
1662    The `default_value` is used for keys not present in the table.
1663
1664    Args:
1665      keys: Keys to look up. Can be a tensor of any shape. Must match the
1666        table's key_dtype.
1667      name: A name for the operation (optional).
1668
1669    Returns:
1670      A tensor containing the values in the same shape as `keys` using the
1671        table's value type.
1672
1673    Raises:
1674      TypeError: when `keys` do not match the table data types.
1675    """
1676    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
1677                        (self.resource_handle, keys, self._default_value)):
1678      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
1679      with ops.colocate_with(self.resource_handle):
1680        values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
1681                                                     self._default_value)
1682    return values
1683
1684  def insert(self, keys, values, name=None):
1685    """Associates `keys` with `values`.
1686
1687    Args:
1688      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
1689        key type.
1690      values: Values to be associated with keys. Must be a tensor of the same
1691        shape as `keys` and match the table's value type.
1692      name: A name for the operation (optional).
1693
1694    Returns:
1695      The created Operation.
1696
1697    Raises:
1698      TypeError: when `keys` or `values` doesn't match the table data
1699        types.
1700    """
1701    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
1702                        [self.resource_handle, keys, values]):
1703      keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
1704      values = ops.convert_to_tensor(values, self._value_dtype, name="values")
1705      with ops.colocate_with(self.resource_handle):
1706        # pylint: disable=protected-access
1707        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
1708                                                   values)
1709    return op
1710
1711  def export(self, name=None):
1712    """Returns tensors of all keys and values in the table.
1713
1714    Args:
1715      name: A name for the operation (optional).
1716
1717    Returns:
1718      A pair of tensors with the first tensor containing all keys and the
1719        second tensors containing all values in the table.
1720    """
1721    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
1722                        [self.resource_handle]):
1723      with ops.colocate_with(self.resource_handle):
1724        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
1725            self.resource_handle, self._key_dtype, self._value_dtype)
1726    return exported_keys, exported_values
1727
1728  def _gather_saveables_for_checkpoint(self):
1729    """For object-based checkpointing."""
1730    return {"table": functools.partial(MutableHashTable._Saveable, table=self)}
1731
1732  class _Saveable(BaseSaverBuilder.SaveableObject):
1733    """SaveableObject implementation for MutableHashTable."""
1734
1735    def __init__(self, table, name):
1736      tensors = table.export()
1737      specs = [
1738          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
1739          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
1740      ]
1741      # pylint: disable=protected-access
1742      super(MutableHashTable._Saveable, self).__init__(table, specs, name)
1743
1744    def restore(self, restored_tensors, restored_shapes, name=None):
1745      del restored_shapes  # unused
1746      # pylint: disable=protected-access
1747      with ops.name_scope(name, "%s_table_restore" % self.name):
1748        with ops.colocate_with(self.op.resource_handle):
1749          return gen_lookup_ops.lookup_table_import_v2(
1750              self.op.resource_handle, restored_tensors[0], restored_tensors[1])
1751
1752
@tf_export("lookup.experimental.DenseHashTable")
class DenseHashTable(LookupInterface):
  """A generic mutable hash table implementation using tensors as backing store.

  Data can be inserted by calling the insert method and removed by calling the
  remove method. It does not support initialization via the init method.

  It uses "open addressing" with quadratic reprobing to resolve collisions.
  Compared to `MutableHashTable` the insert, remove and lookup operations in a
  `DenseHashTable` are typically faster, but memory usage can be higher.
  However, `DenseHashTable` does not require additional memory for
  temporary tensors created during checkpointing and restore operations.

  Example usage:

  ```python
  table = tf.lookup.experimental.DenseHashTable(key_dtype=tf.int64,
                                                value_dtype=tf.int64,
                                                default_value=-1,
                                                empty_key=0,
                                                deleted_key=-1)

  sess.run(table.insert(keys, values))
  out = table.lookup(query_keys)
  print(out.eval())
  ```
  """

  # TODO(andreasst): consider extracting common code with MutableHashTable into
  # a common superclass.
  def __init__(self,
               key_dtype,
               value_dtype,
               default_value,
               empty_key,
               deleted_key,
               initial_num_buckets=None,
               name="MutableDenseHashTable",
               checkpoint=True):
    """Creates an empty `DenseHashTable` object.

    Creates a table, the type of its keys and values are specified by key_dtype
    and value_dtype, respectively.

    Args:
      key_dtype: the type of the key tensors.
      value_dtype: the type of the value tensors.
      default_value: The value to use if a key is missing in the table.
      empty_key: the key to use to represent empty buckets internally. Must not
        be used in insert, remove or lookup operations.
      deleted_key: the key to use to represent deleted buckets internally. Must
        not be used in insert, remove or lookup operations and be different from
        the empty_key.
      initial_num_buckets: the initial number of buckets.
      name: A name for the operation (optional). Required when `checkpoint` is
        True, since it is used to name the saved tensors.
      checkpoint: if True, the contents of the table are saved to and restored
        from checkpoints. When not executing eagerly, a checkpointed table is
        shared using the table node name.

    Returns:
      A `DenseHashTable` object.

    Raises:
      ValueError: If checkpoint is True and no name was specified.
    """
    # Fail fast: the checkpoint saveable builds its tensor names from `name`,
    # and would otherwise fail later with an opaque TypeError on `None + str`,
    # after the table resource had already been created.
    if checkpoint and name is None:
      raise ValueError("A name must be specified when checkpoint is True.")

    self._default_value = ops.convert_to_tensor(
        default_value, dtype=value_dtype, name="default_value")
    self._key_dtype = key_dtype
    self._value_dtype = value_dtype
    self._initial_num_buckets = initial_num_buckets
    self._value_shape = self._default_value.get_shape()
    self._checkpoint = checkpoint
    self._name = name

    self._empty_key = ops.convert_to_tensor(
        empty_key, dtype=key_dtype, name="empty_key")
    self._deleted_key = ops.convert_to_tensor(
        deleted_key, dtype=key_dtype, name="deleted_key")
    self._shared_name = None
    if context.executing_eagerly():
      # TODO(allenl): This will leak memory due to kernel caching by the
      # shared_name attribute value (but is better than the alternative of
      # sharing everything by default when executing eagerly; hopefully creating
      # tables in a loop is uncommon).
      # TODO(rohanj): Use context.shared_name() instead.
      self._shared_name = "table_%d" % (ops.uid(),)
    super(DenseHashTable, self).__init__(key_dtype, value_dtype)

    self._resource_handle = self._create_resource()
    if checkpoint:
      saveable = DenseHashTable._Saveable(self, name)
      if not context.executing_eagerly():
        ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)

  def _create_resource(self):
    """Creates the dense hash table op and records its graph name."""
    # The table must be shared if checkpointing is requested for multi-worker
    # training to work correctly. Use the node name if no shared_name has been
    # explicitly specified.
    use_node_name_sharing = self._checkpoint and self._shared_name is None
    table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
        empty_key=self._empty_key,
        deleted_key=self._deleted_key,
        shared_name=self._shared_name,
        use_node_name_sharing=use_node_name_sharing,
        value_dtype=self._value_dtype,
        value_shape=self._value_shape,
        initial_num_buckets=self._initial_num_buckets,
        name=self._name)
    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    """The last path component of the table op's name, or None when eager."""
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. Can be a tensor of any shape. Must match the
        table's key_dtype.
      name: A name for the operation (optional).

    Returns:
      A tensor containing the values in the same shape as `keys` using the
        table's value type.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    with ops.name_scope(name, "%s_lookup_table_find" % self.name,
                        [self.resource_handle, keys]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      with ops.colocate_with(self.resource_handle):
        values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
                                                     self._default_value)

    return values

  def insert_or_assign(self, keys, values, name=None):
    """Associates `keys` with `values`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
                        [self.resource_handle, keys, values]):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      values = ops.convert_to_tensor(
          values, dtype=self._value_dtype, name="values")
      with ops.colocate_with(self.resource_handle):
        op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
                                                   values)
    return op

  def insert(self, keys, values, name=None):
    """Associates `keys` with `values`; an alias of `insert_or_assign`.

    Args:
      keys: Keys to insert. Can be a tensor of any shape. Must match the table's
        key type.
      values: Values to be associated with keys. Must be a tensor of the same
        shape as `keys` and match the table's value type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` or `values` doesn't match the table data
        types.
    """
    return self.insert_or_assign(keys, values, name)

  def erase(self, keys, name=None):
    """Removes `keys` and its associated values from the table.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape, or a value (e.g. a
        Python list) convertible to a tensor of the table's key type. Must
        match the table's key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    # Inputs that already carry a dtype (tensors) must match the key dtype
    # exactly; dtype-less inputs such as Python lists are converted below,
    # matching how `lookup` and `insert_or_assign` accept their keys.
    keys_dtype = getattr(keys, "dtype", None)
    if keys_dtype is not None and keys_dtype != self._key_dtype:
      raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
                      (self._key_dtype, keys_dtype))

    with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
                        (self.resource_handle, keys, self._default_value)):
      keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
      # Colocate with the table resource like the other table operations.
      with ops.colocate_with(self.resource_handle):
        op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)

    return op

  def remove(self, keys, name=None):
    """Removes `keys` and its associated values; an alias of `erase`.

    If a key is not present in the table, it is silently ignored.

    Args:
      keys: Keys to remove. Can be a tensor of any shape. Must match the table's
        key type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.

    Raises:
      TypeError: when `keys` do not match the table data types.
    """
    return self.erase(keys, name)

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
        second tensors containing all values in the table.
    """
    with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
                        [self.resource_handle]):
      with ops.colocate_with(self.resource_handle):
        exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
            self.resource_handle, self._key_dtype, self._value_dtype)

    return exported_keys, exported_values

  def _gather_saveables_for_checkpoint(self):
    """For object-based checkpointing."""
    return {"table": functools.partial(DenseHashTable._Saveable, table=self)}

  class _Saveable(BaseSaverBuilder.SaveableObject):
    """SaveableObject implementation for DenseHashTable."""

    def __init__(self, table, name):
      tensors = table.export()
      specs = [
          BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
          BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
      ]
      # pylint: disable=protected-access
      super(DenseHashTable._Saveable, self).__init__(table, specs, name)

    def restore(self, restored_tensors, restored_shapes, name=None):
      """Imports saved keys (index 0) and values (index 1) into the table."""
      del restored_shapes  # unused
      # pylint: disable=protected-access
      with ops.name_scope(name, "%s_table_restore" % self.name):
        with ops.colocate_with(self.op.resource_handle):
          return gen_lookup_ops.lookup_table_import_v2(
              self.op.resource_handle, restored_tensors[0], restored_tensors[1])
2042
2043
# Register every lookup-table op type (V1 and V2 variants) as having no
# gradient.
for _lookup_op_type in (
    "LookupTableFind",
    "LookupTableFindV2",
    "LookupTableInsert",
    "LookupTableInsertV2",
    "LookupTableSize",
    "LookupTableSizeV2",
    "HashTable",
    "HashTableV2",
    "InitializeTable",
    "InitializeTableV2",
    "InitializeTableFromTextFile",
    "InitializeTableFromTextFileV2",
    "MutableDenseHashTable",
    "MutableDenseHashTableV2",
    "MutableHashTable",
    "MutableHashTableV2",
    "MutableHashTableOfTensors",
    "MutableHashTableOfTensorsV2",
):
  ops.NotDifferentiable(_lookup_op_type)
# Do not leave the loop variable behind as a module attribute.
del _lookup_op_type
2062