1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
18 
19 #include "tensorflow/core/framework/resource_mgr.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/lib/core/status.h"
22 
23 namespace tensorflow {
24 
25 class OpKernelContext;
26 
27 namespace lookup {
28 
29 // Forward declaration so we can define GetInitializableLookupTable() in
30 // LookupInterface.
31 class InitializableLookupTable;
32 
33 // Lookup interface for batch lookups used by table lookup ops.
34 class LookupInterface : public ResourceBase {
35  public:
36   // Performs batch lookups, for every element in the key tensor, Find returns
37   // the corresponding value into the values tensor.
38   // If an element is not present in the table, the given default value is used.
39 
40   // For tables that require initialization, Find is available once the table
41   // is marked as initialized.
42 
43   // Returns the following statuses:
44   // - OK: when the find finishes successfully.
45   // - FailedPrecondition: if the table is not initialized.
46   // - InvalidArgument: if any of the preconditions on the lookup key or value
47   //   fails.
48   // - In addition, other implementations may provide another non-OK status
49   //   specific to their failure modes.
50   virtual Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values,
51                       const Tensor& default_value) = 0;
52 
53   // Inserts elements into the table. Each element of the key tensor is
54   // associated with the corresponding element in the value tensor.
55   // This method is only implemented in mutable tables that can be updated over
56   // the execution of the graph. It returns Status::NotImplemented for read-only
57   // tables that are initialized once before they can be looked up.
58 
59   // Returns the following statuses:
60   // - OK: when the insert finishes successfully.
61   // - InvalidArgument: if any of the preconditions on the lookup key or value
62   //   fails.
63   // - Unimplemented: if the table does not support insertions.
64   virtual Status Insert(OpKernelContext* ctx, const Tensor& keys,
65                         const Tensor& values) = 0;
66 
67   // Removes elements from the table.
68   // This method is only implemented in mutable tables that can be updated over
69   // the execution of the graph. It returns Status::NotImplemented for read-only
70   // tables that are initialized once before they can be looked up.
71 
72   // Returns the following statuses:
73   // - OK: when the remove finishes successfully.
74   // - InvalidArgument: if any of the preconditions on the lookup key fails.
75   // - Unimplemented: if the table does not support removals.
76   virtual Status Remove(OpKernelContext* ctx, const Tensor& keys) = 0;
77 
78   // Returns the number of elements in the table.
79   virtual size_t size() const = 0;
80 
81   // Exports the values of the table to two tensors named keys and values.
82   // Note that the shape of the tensors is completely up to the implementation
83   // of the table and can be different than the tensors used for the Insert
84   // function above.
85   virtual Status ExportValues(OpKernelContext* ctx) = 0;
86 
87   // Imports previously exported keys and values.
88   // As mentioned above, the shape of the keys and values tensors are determined
89   // by the ExportValues function above and can be different than for the
90   // Insert function.
91   virtual Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
92                               const Tensor& values) = 0;
93 
94   // Returns the data type of the key.
95   virtual DataType key_dtype() const = 0;
96 
97   // Returns the data type of the value.
98   virtual DataType value_dtype() const = 0;
99 
100   // Returns the shape of a key in the table.
101   virtual TensorShape key_shape() const = 0;
102 
103   // Returns the shape of a value in the table.
104   virtual TensorShape value_shape() const = 0;
105 
106   // Check format of the key and value tensors for the Insert function.
107   // Returns OK if all the following requirements are satisfied, otherwise it
108   // returns InvalidArgument:
109   // - DataType of the tensor keys equals to the table key_dtype
110   // - DataType of the tensor values equals to the table value_dtype
111   // - the values tensor has the required shape given keys and the tables's
112   //   value shape.
113   virtual Status CheckKeyAndValueTensorsForInsert(const Tensor& keys,
114                                                   const Tensor& values);
115 
116   // Similar to the function above but instead checks eligibility for the Import
117   // function.
118   virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys,
119                                                   const Tensor& values);
120 
121   // Check format of the key tensor for the Remove function.
122   // Returns OK if all the following requirements are satisfied, otherwise it
123   // returns InvalidArgument:
124   // - DataType of the tensor keys equals to the table key_dtype
125   virtual Status CheckKeyTensorForRemove(const Tensor& keys);
126 
127   // Check the arguments of a find operation. Returns OK if all the following
128   // requirements are satisfied, otherwise it returns InvalidArgument:
129   // - DataType of the tensor keys equals to the table key_dtype
130   // - DataType of the tensor default_value equals to the table value_dtype
131   // - the default_value tensor shape matches the table's value shape.
132   Status CheckFindArguments(const Tensor& keys, const Tensor& default_value);
133 
DebugString()134   string DebugString() const override {
135     return strings::StrCat("A lookup table of size: ", size());
136   }
137 
138   // Returns an InitializableLookupTable, a subclass of LookupInterface, if the
139   // current object is an InitializableLookupTable. Otherwise, returns nullptr.
GetInitializableLookupTable()140   virtual InitializableLookupTable* GetInitializableLookupTable() {
141     return nullptr;
142   }
143 
144  protected:
145   virtual ~LookupInterface() = default;
146 
147   // Makes sure that the key and value tensor DataType's match the table
148   // key_dtype and value_dtype.
149   Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values);
150 
151   // Makes sure that the provided shape is consistent with the table keys shape.
152   Status CheckKeyShape(const TensorShape& shape);
153 
154  private:
155   Status CheckKeyAndValueTensorsHelper(const Tensor& keys,
156                                        const Tensor& values);
157 };
158 
159 }  // namespace lookup
160 }  // namespace tensorflow
161 
162 #endif  // TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
163