1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_ 18 19 #include "tensorflow/core/framework/resource_mgr.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/lib/core/status.h" 22 23 namespace tensorflow { 24 25 class OpKernelContext; 26 27 namespace lookup { 28 29 // Forward declaration so we can define GetInitializableLookupTable() in 30 // LookupInterface. 31 class InitializableLookupTable; 32 33 // Lookup interface for batch lookups used by table lookup ops. 34 class LookupInterface : public ResourceBase { 35 public: 36 // Performs batch lookups, for every element in the key tensor, Find returns 37 // the corresponding value into the values tensor. 38 // If an element is not present in the table, the given default value is used. 39 40 // For tables that require initialization, Find is available once the table 41 // is marked as initialized. 42 43 // Returns the following statuses: 44 // - OK: when the find finishes successfully. 45 // - FailedPrecondition: if the table is not initialized. 46 // - InvalidArgument: if any of the preconditions on the lookup key or value 47 // fails. 48 // - In addition, other implementations may provide another non-OK status 49 // specific to their failure modes. 50 virtual Status Find(OpKernelContext* ctx, const Tensor& keys, Tensor* values, 51 const Tensor& default_value) = 0; 52 53 // Inserts elements into the table. Each element of the key tensor is 54 // associated with the corresponding element in the value tensor. 55 // This method is only implemented in mutable tables that can be updated over 56 // the execution of the graph. It returns Status::NotImplemented for read-only 57 // tables that are initialized once before they can be looked up. 58 59 // Returns the following statuses: 60 // - OK: when the insert finishes successfully. 61 // - InvalidArgument: if any of the preconditions on the lookup key or value 62 // fails. 63 // - Unimplemented: if the table does not support insertions. 64 virtual Status Insert(OpKernelContext* ctx, const Tensor& keys, 65 const Tensor& values) = 0; 66 67 // Removes elements from the table. 68 // This method is only implemented in mutable tables that can be updated over 69 // the execution of the graph. It returns Status::NotImplemented for read-only 70 // tables that are initialized once before they can be looked up. 71 72 // Returns the following statuses: 73 // - OK: when the remove finishes successfully. 74 // - InvalidArgument: if any of the preconditions on the lookup key fails. 75 // - Unimplemented: if the table does not support removals. 76 virtual Status Remove(OpKernelContext* ctx, const Tensor& keys) = 0; 77 78 // Returns the number of elements in the table. 79 virtual size_t size() const = 0; 80 81 // Exports the values of the table to two tensors named keys and values. 82 // Note that the shape of the tensors is completely up to the implementation 83 // of the table and can be different than the tensors used for the Insert 84 // function above. 85 virtual Status ExportValues(OpKernelContext* ctx) = 0; 86 87 // Imports previously exported keys and values. 88 // As mentioned above, the shape of the keys and values tensors are determined 89 // by the ExportValues function above and can be different than for the 90 // Insert function. 91 virtual Status ImportValues(OpKernelContext* ctx, const Tensor& keys, 92 const Tensor& values) = 0; 93 94 // Returns the data type of the key. 95 virtual DataType key_dtype() const = 0; 96 97 // Returns the data type of the value. 98 virtual DataType value_dtype() const = 0; 99 100 // Returns the shape of a key in the table. 101 virtual TensorShape key_shape() const = 0; 102 103 // Returns the shape of a value in the table. 104 virtual TensorShape value_shape() const = 0; 105 106 // Check format of the key and value tensors for the Insert function. 107 // Returns OK if all the following requirements are satisfied, otherwise it 108 // returns InvalidArgument: 109 // - DataType of the tensor keys equals to the table key_dtype 110 // - DataType of the tensor values equals to the table value_dtype 111 // - the values tensor has the required shape given keys and the tables's 112 // value shape. 113 virtual Status CheckKeyAndValueTensorsForInsert(const Tensor& keys, 114 const Tensor& values); 115 116 // Similar to the function above but instead checks eligibility for the Import 117 // function. 118 virtual Status CheckKeyAndValueTensorsForImport(const Tensor& keys, 119 const Tensor& values); 120 121 // Check format of the key tensor for the Remove function. 122 // Returns OK if all the following requirements are satisfied, otherwise it 123 // returns InvalidArgument: 124 // - DataType of the tensor keys equals to the table key_dtype 125 virtual Status CheckKeyTensorForRemove(const Tensor& keys); 126 127 // Check the arguments of a find operation. Returns OK if all the following 128 // requirements are satisfied, otherwise it returns InvalidArgument: 129 // - DataType of the tensor keys equals to the table key_dtype 130 // - DataType of the tensor default_value equals to the table value_dtype 131 // - the default_value tensor shape matches the table's value shape. 132 Status CheckFindArguments(const Tensor& keys, const Tensor& default_value); 133 DebugString()134 string DebugString() const override { 135 return strings::StrCat("A lookup table of size: ", size()); 136 } 137 138 // Returns an InitializableLookupTable, a subclass of LookupInterface, if the 139 // current object is an InitializableLookupTable. Otherwise, returns nullptr. GetInitializableLookupTable()140 virtual InitializableLookupTable* GetInitializableLookupTable() { 141 return nullptr; 142 } 143 144 protected: 145 virtual ~LookupInterface() = default; 146 147 // Makes sure that the key and value tensor DataType's match the table 148 // key_dtype and value_dtype. 149 Status CheckKeyAndValueTypes(const Tensor& keys, const Tensor& values); 150 151 // Makes sure that the provided shape is consistent with the table keys shape. 152 Status CheckKeyShape(const TensorShape& shape); 153 154 private: 155 Status CheckKeyAndValueTensorsHelper(const Tensor& keys, 156 const Tensor& values); 157 }; 158 159 } // namespace lookup 160 } // namespace tensorflow 161 162 #endif // TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_ 163