1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
17 #define TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
18
#include <functional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
22
23 #define EIGEN_USE_THREADS
24
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/framework/variant.h"
27 #include "tensorflow/core/framework/variant_encode_decode.h"
28 #include "tensorflow/core/lib/hash/hash.h"
29
30 namespace tensorflow {
31
// Only pointers to OpKernelContext are used in this header, so a forward
// declaration suffices.
class OpKernelContext;

// The enums below select which kind of callback is stored in the global
// UnaryVariantOpRegistry (declared further down). That registry holds
// per-Variant-type callbacks used by ShapeOp, RankOp, SizeOp, decoding, etc.
36
// Selects the unary operation implemented by a callback registered through
// REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION. Values are spelled out explicitly
// so they stay fixed if enumerators are added.
enum VariantUnaryOp {
  INVALID_VARIANT_UNARY_OP = 0,     // Zero value; presumably "no op".
  ZEROS_LIKE_VARIANT_UNARY_OP = 1,
  CONJ_VARIANT_UNARY_OP = 2
};
42
// Selects the binary operation implemented by a callback registered through
// REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION.
enum VariantBinaryOp {
  INVALID_VARIANT_BINARY_OP = 0,  // Zero value; presumably "no op".
  ADD_VARIANT_BINARY_OP = 1
};
47
// Direction of a Variant copy; part of the key under which device copy
// functions are registered (INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION).
enum VariantDeviceCopyDirection {
  INVALID_DEVICE_COPY_DIRECTION = 0,  // Zero value; presumably "unset".
  HOST_TO_DEVICE = 1,
  DEVICE_TO_HOST = 2,
  DEVICE_TO_DEVICE = 3
};
54
55 class UnaryVariantOpRegistry {
56 public:
57 typedef std::function<Status(const Variant& v, TensorShape*)> VariantShapeFn;
58 typedef std::function<bool(Variant*)> VariantDecodeFn;
59 typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
60 VariantUnaryOpFn;
61 typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
62 Variant*)>
63 VariantBinaryOpFn;
64
65 // An AsyncTensorDeviceCopyFn is a function provided to
66 // the user-provided DeviceCopyFn callback as the third argument ("copier").
67 //
68 // Expected inputs:
69 // from: A Tensor on the host (if performing cpu->gpu copy), or
70 // device (if performing gpu->cpu or gpu->gpu copy).
71 // to: An empty/uninitialized tensor. It will be updated upon
72 // successful return of the function with the correct dtype and shape.
73 // However, the copied data will not be available until the compute
74 // stream has been synchronized.
75 //
76 // Returns:
77 // The status upon memory allocation / initialization of the
78 // "to" tensor, and enqueue of the copy onto the compute stream.
79 // Any failure of the copy itself will update the underlying
80 // stream status and propagate through the runtime independent
81 // of the caller.
82 typedef std::function<Status(const Tensor& from, Tensor* to)>
83 AsyncTensorDeviceCopyFn;
84
85 // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn'
86 // expected to be passed to the registration macro
87 // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION.
88 typedef std::function<Status(const Variant& from, Variant* to,
89 AsyncTensorDeviceCopyFn copy_fn)>
90 AsyncVariantDeviceCopyFn;
91
92 // Add a shape lookup function to the registry.
93 void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
94
95 // Returns nullptr if no shape function was found for the given TypeName.
96 VariantShapeFn* GetShapeFn(StringPiece type_name);
97
98 // Add a decode function to the registry.
99 void RegisterDecodeFn(const string& type_name,
100 const VariantDecodeFn& decode_fn);
101
102 // Returns nullptr if no decode function was found for the given TypeName.
103 VariantDecodeFn* GetDecodeFn(StringPiece type_name);
104
105 // Add a copy-to-GPU function to the registry.
106 void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
107 const string& type_name,
108 const AsyncVariantDeviceCopyFn& device_copy_fn);
109
110 // Returns nullptr if no copy function was found for the given
111 // TypeName and direction.
112 AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
113 const VariantDeviceCopyDirection direction, StringPiece type_name);
114
115 // Add a unary op function to the registry.
116 void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
117 const string& type_name,
118 const VariantUnaryOpFn& unary_op_fn);
119
120 // Returns nullptr if no unary op function was found for the given
121 // op, device, and TypeName.
122 VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
123 StringPiece type_name);
124
125 // Add a binary op function to the registry.
126 void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
127 const string& type_name,
128 const VariantBinaryOpFn& add_fn);
129
130 // Returns nullptr if no binary op function was found for the given
131 // op, device and TypeName.
132 VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
133 StringPiece type_name);
134
135 // Get a pointer to a global UnaryVariantOpRegistry object
136 static UnaryVariantOpRegistry* Global();
137
138 // Get a pointer to a global persistent string storage object.
139 // ISO/IEC C++ working draft N4296 clarifies that insertion into an
140 // std::unordered_set does not invalidate memory locations of
141 // *values* inside the set (though it may invalidate existing
142 // iterators). In other words, one may safely point a StringPiece to
143 // a value in the set without that StringPiece being invalidated by
144 // future insertions.
145 static std::unordered_set<string>* PersistentStringStorage();
146
147 private:
148 std::unordered_map<StringPiece, VariantShapeFn, StringPieceHasher> shape_fns;
149 std::unordered_map<StringPiece, VariantDecodeFn, StringPieceHasher>
150 decode_fns;
151
152 // Map std::pair<Direction, type_name> to function.
153 struct PairHash {
154 template <typename Direction>
operatorPairHash155 std::size_t operator()(const std::pair<Direction, StringPiece>& x) const {
156 // The hash of an enum is just its value as a std::size_t.
157 std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
158 ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
159 return ret;
160 }
161 StringPieceHasher sp_hasher_;
162 };
163
164 std::unordered_map<std::pair<VariantDeviceCopyDirection, StringPiece>,
165 AsyncVariantDeviceCopyFn, PairHash>
166 device_copy_fns;
167
168 // Map std::tuple<Op, device, type_name> to function.
169
170 // this breaks by falling victim to "too perfect forwarding"
171 // see https://stackoverflow.com/questions/44475317/variadic-template-issue
172 // and references therein
173 template <typename Op>
174 struct FuncTuple {
FuncTupleFuncTuple175 FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname)
176 : op_type_(op), device_(dev), typename_(tname){};
177 Op op_type_;
178 StringPiece device_, typename_;
179 };
180 // friend declaration for operator==
181 // needed for clang
182 template <typename Op>
183 friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
184 struct TupleHash {
185 template <typename Op>
operatorTupleHash186 std::size_t operator()(
187 const std::tuple<Op, StringPiece, StringPiece>& x) const {
188 // The hash of an enum is just its value as a std::size_t.
189 std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
190 ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
191 ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x)));
192 return ret;
193 }
194
195 template <typename Op>
operatorTupleHash196 std::size_t operator()(const FuncTuple<Op>& x) const {
197 // The hash of an enum is just its value as a std::size_t.
198 std::size_t ret = static_cast<std::size_t>(x.op_type_);
199 ret = Hash64Combine(ret, sp_hasher_(x.device_));
200 ret = Hash64Combine(ret, sp_hasher_(x.typename_));
201 return ret;
202 }
203 StringPieceHasher sp_hasher_;
204 };
205 std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
206 unary_op_fns;
207 std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
208 binary_op_fns;
209
210 // Find or insert a string into a persistent string storage
211 // container; return the StringPiece pointing to the permanent string
212 // location.
GetPersistentStringPiece(const string & str)213 static StringPiece GetPersistentStringPiece(const string& str) {
214 const auto string_storage = PersistentStringStorage();
215 auto found = string_storage->find(str);
216 if (found == string_storage->end()) {
217 auto inserted = string_storage->insert(str);
218 return StringPiece(*inserted.first);
219 } else {
220 return StringPiece(*found);
221 }
222 }
223 };
224 template <typename Op>
225 inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
226 const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
227 return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
228 (lhs.typename_ == rhs.typename_);
229 }
230 // Gets a TensorShape from a Tensor containing a scalar Variant.
231 // Returns an Internal error if the Variant does not have a registered shape
232 // function, or if it's a serialized Variant that cannot be decoded.
233 //
234 // REQUIRES:
235 // variant_tensor.dtype() == DT_VARIANT
236 // variant_tensor.dims() == 0
237 //
238 Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape);
239
240 // Decodes the Variant whose data_type has a registered decode
241 // function. Returns an Internal error if the Variant does not have a
242 // registered decode function, or if the decoding function fails.
243 //
244 // REQUIRES:
245 // variant is not null.
246 //
247 bool DecodeUnaryVariant(Variant* variant);
248
249 // Copies a variant between CPU<->GPU, or between GPU<->GPU.
250 // The variant 'from' must have a registered DeviceCopyFn for the
251 // given direction. The returned variant 'to' will have
252 // (some subset of its) tensors stored on destination according to the
253 // registered DeviceCopyFn function for the given direction. Returns
254 // an Internal error if the Variant does not have a registered
255 // DeviceCopyFn function for the given direction, or if initiating the
256 // copy fails.
257 //
258 // REQUIRES:
259 // 'to' is not null.
260 //
261 Status VariantDeviceCopy(
262 const VariantDeviceCopyDirection direction, const Variant& from,
263 Variant* to,
264 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);
265
266 // Sets *v_out = unary_op(v). The variant v must have a registered
267 // UnaryOp function for the given Device. Returns an Internal error
268 // if v does not have a registered unary_op function for this device, or if
269 // UnaryOp fails.
270 //
271 // REQUIRES:
272 // v_out is not null.
273 //
274 template <typename Device>
UnaryOpVariant(OpKernelContext * ctx,VariantUnaryOp op,const Variant & v,Variant * v_out)275 Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
276 Variant* v_out) {
277 const string& device = DeviceName<Device>::value;
278 UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
279 UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
280 if (unary_op_fn == nullptr) {
281 return errors::Internal(
282 "No unary variant unary_op function found for unary variant op enum: ",
283 op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
284 }
285 return (*unary_op_fn)(ctx, v, v_out);
286 }
287
288 // Sets *out = binary_op(a, b). The variants a and b must be the same type
289 // and have a registered binary_op function for the given Device. Returns an
290 // Internal error if a and b are not the same type_name or if
291 // if a does not have a registered op function for this device, or if
292 // BinaryOp fails.
293 //
294 // REQUIRES:
295 // out is not null.
296 //
297 template <typename Device>
BinaryOpVariants(OpKernelContext * ctx,VariantBinaryOp op,const Variant & a,const Variant & b,Variant * out)298 Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
299 const Variant& a, const Variant& b, Variant* out) {
300 if (a.TypeName() != b.TypeName()) {
301 return errors::Internal(
302 "BianryOpVariants: Variants a and b have different "
303 "type names: '",
304 a.TypeName(), "' vs. '", b.TypeName(), "'");
305 }
306 const string& device = DeviceName<Device>::value;
307 UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
308 UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName());
309 if (binary_op_fn == nullptr) {
310 return errors::Internal(
311 "No unary variant binary_op function found for binary variant op "
312 "enum: ",
313 op, " Variant type_name: '", a.TypeName(), "' for device type: ",
314 device);
315 }
316 return (*binary_op_fn)(ctx, a, b, out);
317 }
318
319 namespace variant_op_registry_fn_registration {
320
321 template <typename T>
322 class UnaryVariantShapeRegistration {
323 public:
324 typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn;
325
UnaryVariantShapeRegistration(const string & type_name,const LocalVariantShapeFn & shape_fn)326 UnaryVariantShapeRegistration(const string& type_name,
327 const LocalVariantShapeFn& shape_fn) {
328 UnaryVariantOpRegistry::Global()->RegisterShapeFn(
329 type_name,
330 [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status {
331 const T* t = v.get<T>();
332 if (t == nullptr) {
333 return errors::Internal(
334 "VariantShapeFn: Could not access object, type_name: ",
335 type_name);
336 }
337 return shape_fn(*t, s);
338 });
339 }
340 };
341
342 template <typename T>
343 class UnaryVariantDecodeRegistration {
344 public:
UnaryVariantDecodeRegistration(const string & type_name)345 UnaryVariantDecodeRegistration(const string& type_name) {
346 // The Variant is passed by pointer because it should be
347 // mutable: get below may Decode the variant, which
348 // is a self-mutating behavior. The variant is not modified in
349 // any other way.
350 UnaryVariantOpRegistry::Global()->RegisterDecodeFn(
351 type_name, [type_name](Variant* v) -> bool {
352 DCHECK_NE(v, nullptr);
353 VariantTensorDataProto* t = v->get<VariantTensorDataProto>();
354 if (t == nullptr) {
355 return false;
356 }
357 Variant decoded = T();
358 VariantTensorData data(*t);
359 if (!decoded.Decode(data)) {
360 return false;
361 }
362 *v = std::move(decoded);
363 return true;
364 });
365 }
366 };
367
368 template <typename T>
369 class UnaryVariantDeviceCopyRegistration {
370 public:
371 typedef std::function<Status(const T& t, T* t_out,
372 UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
373 LocalVariantDeviceCopyFn;
UnaryVariantDeviceCopyRegistration(const VariantDeviceCopyDirection direction,const string & type_name,const LocalVariantDeviceCopyFn & device_copy_fn)374 UnaryVariantDeviceCopyRegistration(
375 const VariantDeviceCopyDirection direction, const string& type_name,
376 const LocalVariantDeviceCopyFn& device_copy_fn) {
377 UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
378 direction, type_name,
379 [type_name, device_copy_fn](
380 const Variant& from, Variant* to,
381 UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
382 device_copy_tensor_fn) -> Status {
383 DCHECK_NE(to, nullptr);
384 *to = T();
385 if (from.get<T>() == nullptr) {
386 return errors::Internal(
387 "VariantCopyToGPUFn: Could not access object, type_name: ",
388 type_name);
389 }
390 const T& t = *from.get<T>();
391 T* t_out = to->get<T>();
392 return device_copy_fn(t, t_out, device_copy_tensor_fn);
393 });
394 }
395 };
396
397 template <typename T>
398 class UnaryVariantUnaryOpRegistration {
399 typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
400 LocalVariantUnaryOpFn;
401
402 public:
UnaryVariantUnaryOpRegistration(VariantUnaryOp op,const string & device,const string & type_name,const LocalVariantUnaryOpFn & unary_op_fn)403 UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
404 const string& type_name,
405 const LocalVariantUnaryOpFn& unary_op_fn) {
406 UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
407 op, device, type_name,
408 [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
409 Variant* v_out) -> Status {
410 DCHECK_NE(v_out, nullptr);
411 *v_out = T();
412 if (v.get<T>() == nullptr) {
413 return errors::Internal(
414 "VariantUnaryOpFn: Could not access object, type_name: ",
415 type_name);
416 }
417 const T& t = *v.get<T>();
418 T* t_out = v_out->get<T>();
419 return unary_op_fn(ctx, t, t_out);
420 });
421 }
422 };
423
424 template <typename T>
425 class UnaryVariantBinaryOpRegistration {
426 typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
427 T* out)>
428 LocalVariantBinaryOpFn;
429
430 public:
UnaryVariantBinaryOpRegistration(VariantBinaryOp op,const string & device,const string & type_name,const LocalVariantBinaryOpFn & binary_op_fn)431 UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
432 const string& type_name,
433 const LocalVariantBinaryOpFn& binary_op_fn) {
434 UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
435 op, device, type_name,
436 [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
437 const Variant& b, Variant* out) -> Status {
438 DCHECK_NE(out, nullptr);
439 *out = T();
440 if (a.get<T>() == nullptr) {
441 return errors::Internal(
442 "VariantBinaryOpFn: Could not access object 'a', type_name: ",
443 type_name);
444 }
445 if (b.get<T>() == nullptr) {
446 return errors::Internal(
447 "VariantBinaryOpFn: Could not access object 'b', type_name: ",
448 type_name);
449 }
450 const T& t_a = *a.get<T>();
451 const T& t_b = *b.get<T>();
452 T* t_out = out->get<T>();
453 return binary_op_fn(ctx, t_a, t_b, t_out);
454 });
455 }
456 };
457
458 }; // namespace variant_op_registry_fn_registration
459
// Register a unary shape variant function with the signature:
//   Status ShapeFn(const T& t, TensorShape* s);
// to Variants having TypeName type_name.
// The registration type is fully qualified, so the macro does not have to be
// expanded inside namespace tensorflow.
#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function)    \
  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \
                                                    shape_function)

// The *_UNIQ_HELPER indirection makes __COUNTER__ expand to a number before
// it is token-pasted (##) into a unique variable name by *_UNIQ.
#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \
                                                          shape_function)    \
  REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function)

#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name,          \
                                                   shape_function)             \
  static ::tensorflow::variant_op_registry_fn_registration::                   \
      UnaryVariantShapeRegistration<T>                                         \
          register_unary_variant_op_shape_registration_fn_##ctr(type_name,     \
                                                                shape_function)
476
// Register a unary decode variant function for the given type.
// The registration type is fully qualified, so the macro does not have to be
// expanded inside namespace tensorflow.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)

// The *_UNIQ_HELPER indirection makes __COUNTER__ expand to a number before
// it is token-pasted (##) into a unique variable name by *_UNIQ.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(ctr, T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name) \
  static ::tensorflow::variant_op_registry_fn_registration::          \
      UnaryVariantDecodeRegistration<T>                               \
          register_unary_variant_op_decoder_fn_##ctr(type_name)
488
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
//
// Register a device copy variant function for the given copy
// direction and type; where direction is the enum
// VariantDeviceCopyDirection, and the device_copy_fn has signature:
//
//   Status device_copy_fn(
//     const T& t, T* t_out,
//     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier);
//
// And device_copy_fn calls copier 0 or more times.  For details on
// the behavior of the copier function, see the comments at the
// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn.
//
// Note, the device_copy_fn may choose to keep some tensors
// on host, e.g. by assigning t_out->tensor = t.tensor (assuming
// t.tensor is already on host); or by setting
//   t_out->tensor = Tensor(cpu_allocator(), ...)
// and manually updating its values.
//
// If this is the case, the CopyFns for HOST_TO_DEVICE,
// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host
// copies in a consistent manner.  For example, one must always
// manually copy any "always on host" tensors in all directions instead of e.g.
//   - performing a host-to-host copy in one direction,
//   - using the provided copier function in the reverse direction.
// Doing the latter will cause program failures.
//
// The registration type is fully qualified, so the macro does not have to be
// expanded inside namespace tensorflow.
//
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(       \
    T, direction, type_name, device_copy_fn)                        \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
      __COUNTER__, T, direction, type_name, device_copy_fn)

// The *_UNIQ_HELPER indirection makes __COUNTER__ expand to a number before
// it is token-pasted (##) into a unique variable name by *_UNIQ.
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
    ctr, T, direction, type_name, device_copy_fn)                         \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
      ctr, T, direction, type_name, device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(          \
    ctr, T, direction, type_name, device_copy_fn)                           \
  static ::tensorflow::variant_op_registry_fn_registration::                \
      UnaryVariantDeviceCopyRegistration<T>                                 \
          register_unary_variant_op_device_copy_fn_##ctr(direction,        \
                                                         type_name,        \
                                                         device_copy_fn)
538
// Register a unary unary_op variant function with the signature:
//    Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// to Variants having TypeName type_name, for device string device,
// for VariantUnaryOp enum op.
// The registration type is fully qualified, so the macro does not have to be
// expanded inside namespace tensorflow.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
                                                 unary_op_function)        \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(                    \
      __COUNTER__, op, device, T, type_name, unary_op_function)

// The *_UNIQ_HELPER indirection makes __COUNTER__ expand to a number before
// it is token-pasted (##) into a unique variable name by *_UNIQ.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(             \
    ctr, op, device, T, type_name, unary_op_function)                     \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T,       \
                                                type_name, unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(                 \
    ctr, op, device, T, type_name, unary_op_function)                  \
  static ::tensorflow::variant_op_registry_fn_registration::           \
      UnaryVariantUnaryOpRegistration<T>                               \
          register_unary_variant_op_unary_op_fn_##ctr(op, device,      \
                                                      type_name,       \
                                                      unary_op_function)
559
// Register a binary_op variant function with the signature:
//    Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// to Variants having TypeName type_name, for device string device,
// for VariantBinaryOp enum op.
// The registration type is fully qualified, so the macro does not have to be
// expanded inside namespace tensorflow.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \
                                                  binary_op_function)       \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(                    \
      __COUNTER__, op, device, T, type_name, binary_op_function)

// The *_UNIQ_HELPER indirection makes __COUNTER__ expand to a number before
// it is token-pasted (##) into a unique variable name by *_UNIQ.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_name, binary_op_function)         \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_name, binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                  \
    ctr, op, device, T, type_name, binary_op_function)                   \
  static ::tensorflow::variant_op_registry_fn_registration::             \
      UnaryVariantBinaryOpRegistration<T>                                \
          register_unary_variant_op_binary_op_fn_##ctr(op, device,       \
                                                       type_name,        \
                                                       binary_op_function)
580
581 } // end namespace tensorflow
582
583 #endif // TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
584