1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <string>
17
18 #include "tensorflow/core/framework/register_types.h"
19 #include "tensorflow/core/framework/type_index.h"
20 #include "tensorflow/core/framework/variant.h"
21 #include "tensorflow/core/framework/variant_op_registry.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/public/version.h"
24
25 namespace tensorflow {
26
PersistentStringStorage()27 std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
28 static std::unordered_set<string>* string_storage =
29 new std::unordered_set<string>();
30 return string_storage;
31 }
32
33 // static
Global()34 UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
35 static UnaryVariantOpRegistry* global_unary_variant_op_registry =
36 new UnaryVariantOpRegistry;
37 return global_unary_variant_op_registry;
38 }
39
GetShapeFn(StringPiece type_name)40 UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
41 StringPiece type_name) {
42 auto found = shape_fns.find(type_name);
43 if (found == shape_fns.end()) return nullptr;
44 return &found->second;
45 }
46
RegisterShapeFn(const string & type_name,const VariantShapeFn & shape_fn)47 void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name,
48 const VariantShapeFn& shape_fn) {
49 CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape";
50 VariantShapeFn* existing = GetShapeFn(type_name);
51 CHECK_EQ(existing, nullptr)
52 << "Unary VariantShapeFn for type_name: " << type_name
53 << " already registered";
54 shape_fns.insert(std::pair<StringPiece, VariantShapeFn>(
55 GetPersistentStringPiece(type_name), shape_fn));
56 }
57
GetUnaryVariantShape(const Tensor & variant_tensor,TensorShape * shape)58 Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
59 CHECK_EQ(variant_tensor.dtype(), DT_VARIANT);
60 CHECK_EQ(variant_tensor.dims(), 0);
61 const Variant& v = variant_tensor.scalar<Variant>()();
62 UnaryVariantOpRegistry::VariantShapeFn* shape_fn =
63 UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName());
64 if (shape_fn == nullptr) {
65 return errors::Internal(
66 "No unary variant shape function found for Variant type_name: ",
67 v.TypeName());
68 }
69 return (*shape_fn)(v, shape);
70 }
71
72 // Add some basic registrations for use by others, e.g., for testing.
73 namespace {
74 template <typename T>
ScalarShape(const T &,TensorShape * shape)75 Status ScalarShape(const T&, TensorShape* shape) {
76 *shape = TensorShape({});
77 return Status::OK();
78 }
79 } // namespace
80
81 #define REGISTER_VARIANT_SHAPE_TYPE(T) \
82 REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>);
83
84 // No encode/shape registered for std::complex<> and Eigen::half
85 // objects yet.
86 REGISTER_VARIANT_SHAPE_TYPE(int);
87 REGISTER_VARIANT_SHAPE_TYPE(float);
88 REGISTER_VARIANT_SHAPE_TYPE(bool);
89 REGISTER_VARIANT_SHAPE_TYPE(double);
90
91 #undef REGISTER_VARIANT_SHAPE_TYPE
92
GetDecodeFn(StringPiece type_name)93 UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
94 StringPiece type_name) {
95 auto found = decode_fns.find(type_name);
96 if (found == decode_fns.end()) return nullptr;
97 return &found->second;
98 }
99
RegisterDecodeFn(const string & type_name,const VariantDecodeFn & decode_fn)100 void UnaryVariantOpRegistry::RegisterDecodeFn(
101 const string& type_name, const VariantDecodeFn& decode_fn) {
102 CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDecode";
103 VariantDecodeFn* existing = GetDecodeFn(type_name);
104 CHECK_EQ(existing, nullptr)
105 << "Unary VariantDecodeFn for type_name: " << type_name
106 << " already registered";
107 decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
108 GetPersistentStringPiece(type_name), decode_fn));
109 }
110
DecodeUnaryVariant(Variant * variant)111 bool DecodeUnaryVariant(Variant* variant) {
112 UnaryVariantOpRegistry::VariantDecodeFn* decode_fn =
113 UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName());
114 if (decode_fn == nullptr) {
115 return false;
116 }
117 const string type_name = variant->TypeName();
118 bool decoded = (*decode_fn)(variant);
119 if (!decoded) return false;
120 if (variant->TypeName() != type_name) {
121 LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: "
122 << type_name
123 << " but after decoding was: " << variant->TypeName()
124 << ". Treating this as a failure.";
125 return false;
126 }
127 return true;
128 }
129
130 // Add some basic registrations for use by others, e.g., for testing.
131
132 #define REGISTER_VARIANT_DECODE_TYPE(T) \
133 REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));
134
135 // No encode/decode registered for std::complex<> and Eigen::half
136 // objects yet.
137 REGISTER_VARIANT_DECODE_TYPE(int);
138 REGISTER_VARIANT_DECODE_TYPE(float);
139 REGISTER_VARIANT_DECODE_TYPE(bool);
140 REGISTER_VARIANT_DECODE_TYPE(double);
141
142 #undef REGISTER_VARIANT_DECODE_TYPE
143
144 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn*
GetDeviceCopyFn(const VariantDeviceCopyDirection direction,StringPiece type_name)145 UnaryVariantOpRegistry::GetDeviceCopyFn(
146 const VariantDeviceCopyDirection direction, StringPiece type_name) {
147 auto found = device_copy_fns.find(std::make_pair(direction, type_name));
148 if (found == device_copy_fns.end()) return nullptr;
149 return &found->second;
150 }
151
RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,const string & type_name,const AsyncVariantDeviceCopyFn & device_copy_fn)152 void UnaryVariantOpRegistry::RegisterDeviceCopyFn(
153 const VariantDeviceCopyDirection direction, const string& type_name,
154 const AsyncVariantDeviceCopyFn& device_copy_fn) {
155 CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy";
156 AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name);
157 CHECK_EQ(existing, nullptr)
158 << "UnaryVariantDeviceCopy for direction: " << direction
159 << " and type_name: " << type_name << " already registered";
160 device_copy_fns.insert(
161 std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>,
162 AsyncVariantDeviceCopyFn>(
163 std::make_pair(direction, GetPersistentStringPiece(type_name)),
164 device_copy_fn));
165 }
166
VariantDeviceCopy(const VariantDeviceCopyDirection direction,const Variant & from,Variant * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy_fn)167 Status VariantDeviceCopy(
168 const VariantDeviceCopyDirection direction, const Variant& from,
169 Variant* to,
170 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
171 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
172 UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
173 from.TypeName());
174 if (device_copy_fn == nullptr) {
175 return errors::Internal(
176 "No unary variant device copy function found for direction: ",
177 direction, " and Variant type_name: ", from.TypeName());
178 }
179 return (*device_copy_fn)(from, to, copy_fn);
180 }
181
182 // Special casing UnaryOpFn per op and per device.
GetUnaryOpFn(VariantUnaryOp op,StringPiece device,StringPiece type_name)183 UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
184 VariantUnaryOp op, StringPiece device, StringPiece type_name) {
185 auto found = unary_op_fns.find({op, device, type_name});
186 if (found == unary_op_fns.end()) return nullptr;
187 return &found->second;
188 }
189
RegisterUnaryOpFn(VariantUnaryOp op,const string & device,const string & type_name,const VariantUnaryOpFn & unary_op_fn)190 void UnaryVariantOpRegistry::RegisterUnaryOpFn(
191 VariantUnaryOp op, const string& device, const string& type_name,
192 const VariantUnaryOpFn& unary_op_fn) {
193 CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp";
194 VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name);
195 CHECK_EQ(existing, nullptr)
196 << "Unary VariantUnaryOpFn for type_name: " << type_name
197 << " already registered for device type: " << device;
198 unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
199 {op, GetPersistentStringPiece(device),
200 GetPersistentStringPiece(type_name)},
201 unary_op_fn));
202 }
203
204 namespace {
205 template <typename T>
ZerosLikeVariantPrimitiveType(OpKernelContext * ctx,const T & t,T * t_out)206 Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
207 T* t_out) {
208 *t_out = T(0);
209 return Status::OK();
210 }
211 } // namespace
212
213 #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
214 REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
215 DEVICE_CPU, T, TF_STR(T), \
216 ZerosLikeVariantPrimitiveType<T>);
217
218 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
219 REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
220 REGISTER_VARIANT_ZEROS_LIKE_TYPE(float);
221 REGISTER_VARIANT_ZEROS_LIKE_TYPE(double);
222 REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
223
224 #undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
225
226 // Special casing BinaryOpFn per op and per device.
227 UnaryVariantOpRegistry::VariantBinaryOpFn*
GetBinaryOpFn(VariantBinaryOp op,StringPiece device,StringPiece type_name)228 UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
229 StringPiece type_name) {
230 auto found = binary_op_fns.find({op, device, type_name});
231 if (found == binary_op_fns.end()) return nullptr;
232 return &found->second;
233 }
234
RegisterBinaryOpFn(VariantBinaryOp op,const string & device,const string & type_name,const VariantBinaryOpFn & add_fn)235 void UnaryVariantOpRegistry::RegisterBinaryOpFn(
236 VariantBinaryOp op, const string& device, const string& type_name,
237 const VariantBinaryOpFn& add_fn) {
238 CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp";
239 VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name);
240 CHECK_EQ(existing, nullptr)
241 << "Unary VariantBinaryOpFn for type_name: " << type_name
242 << " already registered for device type: " << device;
243 binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
244 {op, GetPersistentStringPiece(device),
245 GetPersistentStringPiece(type_name)},
246 add_fn));
247 }
248
249 namespace {
250 template <typename T>
AddVariantPrimitiveType(OpKernelContext * ctx,const T & a,const T & b,T * out)251 Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
252 T* out) {
253 *out = a + b;
254 return Status::OK();
255 }
256 } // namespace
257
258 #define REGISTER_VARIANT_ADD_TYPE(T) \
259 REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
260 T, TF_STR(T), \
261 AddVariantPrimitiveType<T>);
262
263 // No add registered for std::complex<> or Eigen::half objects yet.
264 REGISTER_VARIANT_ADD_TYPE(int);
265 REGISTER_VARIANT_ADD_TYPE(float);
266 REGISTER_VARIANT_ADD_TYPE(double);
267 REGISTER_VARIANT_ADD_TYPE(bool);
268
269 #undef REGISTER_VARIANT_ADD_TYPE
270
271 } // namespace tensorflow
272