1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
17 #define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
18 
#include <atomic>
#include <functional>

#include "tensorflow/core/framework/lookup_interface.h"
#include "tensorflow/core/platform/macros.h"
21 
22 namespace tensorflow {
23 namespace lookup {
24 
25 // Base class for lookup tables that require initialization.
26 class InitializableLookupTable : public LookupInterface {
27  public:
28   class InitTableIterator;
29 
30   // Performs batch lookups, for every element in the key tensor, Find returns
31   // the corresponding value into the values tensor.
32   // If an element is not present in the table, the given default value is used.
33   //
34   // For tables that require initialization, `Find` is available once the table
35   // is marked as initialized.
36   //
37   // Returns the following statuses:
38   // - OK: when the find finishes successfully.
39   // - FailedPrecondition: if the table is not initialized.
40   // - InvalidArgument: if any of the preconditions on the lookup key or value
41   //   fails.
42   // - In addition, other implementations may provide another non-OK status
43   //   specific to their failure modes.
44   Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values,
45               const Tensor& default_value) final;
46 
47   // Returns errors::Unimplemented.
Insert(OpKernelContext * ctx,const Tensor & keys,const Tensor & values)48   Status Insert(OpKernelContext* ctx, const Tensor& keys,
49                 const Tensor& values) final {
50     return errors::Unimplemented(
51         "Insert not supported by InitializableLookupTable implementations");
52   }
53 
54   // Returns errors::Unimplemented.
Remove(OpKernelContext * ctx,const Tensor & keys)55   Status Remove(OpKernelContext* ctx, const Tensor& keys) final {
56     return errors::Unimplemented(
57         "Remove not supported by InitializableLookupTable implementations");
58   }
59 
ExportValues(OpKernelContext * context)60   Status ExportValues(OpKernelContext* context) override {
61     return errors::Unimplemented(
62         "ExportValues not supported by InitializableLookupTable "
63         "implementations");
64   }
65 
66   Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
67                       const Tensor& values) final;
68 
key_shape()69   TensorShape key_shape() const final { return TensorShape(); }
70 
value_shape()71   TensorShape value_shape() const final { return TensorShape(); }
72 
73   // Returns whether the table was initialized and is ready to serve lookups.
is_initialized()74   bool is_initialized() const { return is_initialized_; }
75 
76   // Initializes the table from the given init table iterator.
77   //
78   // Atomically, this operation prepares the table, populates it with the given
79   // iterator, and mark the table as initialized.
80   //
81   // Returns the following statuses:
82   // - OK: when the initialization was successful.
83   // - InvalidArgument: if any of the preconditions on the lookup key or value
84   //   fails.
85   // - FailedPrecondition: if the table is already initialized and
86   //   fail_if_initialized is set to true.
87   // - In addition, other implementations may provide another non-OK status
88   //   specific to their failure modes.
89   Status Initialize(InitTableIterator& iter);
90 
91   // Basic iterator to initialize lookup tables.
92   // It yields a sequence of pairs of `keys()` and `values()` Tensors, so that
93   // the consumer may insert key-value pairs in batches.
94   //
95   // Then the iterator is exhausted, valid returns false and status returns
96   // Status::OutOfRange.
97   //
98   // This class is Thread-unsafe.
99   class InitTableIterator {
100    public:
InitTableIterator()101     InitTableIterator() {}
102 
~InitTableIterator()103     virtual ~InitTableIterator() {}
104 
105     // Prepares the next batch of key and value tensors.
106     virtual void Next() = 0;
107 
108     // Returns true if keys and values point to valid tensors.
109     virtual bool Valid() const = 0;
110 
111     // Returns a tensor that contains the current batch of 'key' values.
112     virtual const Tensor& keys() const = 0;
113 
114     // Returns a tensor that contains the current batch of 'value' values.
115     virtual const Tensor& values() const = 0;
116 
117     // Returns an error if one has occurred, otherwise returns Status::OK.
118     virtual Status status() const = 0;
119 
120     // Returns the total number of elements that the iterator will produce.
121     // It might return -1 in case of error.
122     virtual int64 total_size() const = 0;
123 
124    private:
125     TF_DISALLOW_COPY_AND_ASSIGN(InitTableIterator);
126   };
127 
GetInitializableLookupTable()128   InitializableLookupTable* GetInitializableLookupTable() override {
129     return this;
130   }
131 
132  protected:
133   // Prepares and allocates the underlying data structure to store the given
134   // number of expected elements.
135   virtual Status DoPrepare(size_t expected_num_elements) = 0;
136 
137   // Same as DoPrepare() but derived implementations might choose to skip
138   // calling get_expected_num_elements if size is not needed for DoPrepare.
DoLazyPrepare(std::function<int64 (void)> get_expected_num_elements)139   virtual Status DoLazyPrepare(
140       std::function<int64(void)> get_expected_num_elements) {
141     int64 expected_num_elements = get_expected_num_elements();
142     if (expected_num_elements < 0) {
143       return errors::FailedPrecondition("Got negative expected_num_elements.");
144     }
145     return DoPrepare(expected_num_elements);
146   }
147 
148   // Populates the table in batches given keys and values as tensors into the
149   // underlying data structure.
150   virtual Status DoInsert(const Tensor& keys, const Tensor& values) = 0;
151 
152   // Performs the batch find operation on the underlying data structure.
153   virtual Status DoFind(const Tensor& keys, Tensor* values,
154                         const Tensor& default_value) = 0;
155 
156   mutex mu_;
157   bool is_initialized_ = false;
158 };
159 
160 // Iterator to initialize tables given 'keys' and 'values' tensors.
161 //
162 // The two tensors are returned in the first iteration. It doesn't loop
163 // over each element of the tensor since insertions in the lookup table can
164 // process batches.
165 class KeyValueTensorIterator
166     : public InitializableLookupTable::InitTableIterator {
167  public:
168   // keys and values are not owned by the iterator.
KeyValueTensorIterator(const Tensor * keys,const Tensor * values)169   explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values)
170       : keys_(keys), values_(values), valid_(true), status_(Status::OK()) {
171     TensorShape key_shape = keys_->shape();
172     if (!key_shape.IsSameSize(values_->shape())) {
173       valid_ = false;
174       status_ = errors::InvalidArgument(
175           "keys and values should have the same dimension.",
176           key_shape.DebugString(), " vs ", values_->shape().DebugString());
177     }
178     if (key_shape.num_elements() == 0) {
179       valid_ = false;
180       status_ =
181           errors::InvalidArgument("keys and values cannot be empty tensors.");
182     }
183   }
184 
Valid()185   bool Valid() const override { return valid_; }
186 
Next()187   void Next() override {
188     valid_ = false;
189     status_ = errors::OutOfRange("No more data.");
190   }
191 
keys()192   const Tensor& keys() const override { return *keys_; }
193 
values()194   const Tensor& values() const override { return *values_; }
195 
status()196   Status status() const override { return status_; }
197 
total_size()198   int64 total_size() const override {
199     return keys_ == nullptr ? -1 : keys_->NumElements();
200   }
201 
202  private:
203   TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator);
204 
205   const Tensor* keys_;    // Doesn't own it.
206   const Tensor* values_;  // Doesn't own it.
207   bool valid_;            // true if the iterator points to an existing range.
208   Status status_;
209 };
210 
211 }  // namespace lookup
212 }  // namespace tensorflow
213 
214 #endif  // TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
215