1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
17 #define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
18 
19 #include <string>
20 #include <unordered_set>
21 #include <vector>
22 
23 #define EIGEN_USE_THREADS
24 
25 #include "tensorflow/core/framework/tensor.pb.h"
26 #include "tensorflow/core/framework/type_index.h"
27 #include "tensorflow/core/framework/types.h"
28 #include "tensorflow/core/framework/variant.h"
29 #include "tensorflow/core/framework/variant_encode_decode.h"
30 #include "tensorflow/core/lib/gtl/flatmap.h"
31 #include "tensorflow/core/lib/hash/hash.h"
32 #include "tensorflow/core/platform/abi.h"
33 
34 namespace tensorflow {
35 
36 class OpKernelContext;
37 // A global UnaryVariantOpRegistry is used to hold callback functions
38 // for different variant types.  To be used by ShapeOp, RankOp, and
39 // SizeOp, decoding, etc.
40 
// Unary operations for which per-type Variant functions can be registered
// (see UnaryVariantOpRegistry::RegisterUnaryOpFn and
// REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION).  Values are explicit because
// they are also printed in error messages.
enum VariantUnaryOp {
  INVALID_VARIANT_UNARY_OP = 0,
  ZEROS_LIKE_VARIANT_UNARY_OP = 1,
  CONJ_VARIANT_UNARY_OP = 2
};
46 
// Binary operations for which per-type Variant functions can be registered
// (see UnaryVariantOpRegistry::RegisterBinaryOpFn and
// REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION).
enum VariantBinaryOp {
  INVALID_VARIANT_BINARY_OP = 0,
  ADD_VARIANT_BINARY_OP = 1
};
51 
// Direction of a Variant device copy.  Together with the Variant's TypeIndex
// it selects which registered device-copy function is used
// (see UnaryVariantOpRegistry::GetDeviceCopyFn and VariantDeviceCopy).
enum VariantDeviceCopyDirection {
  INVALID_DEVICE_COPY_DIRECTION = 0,
  HOST_TO_DEVICE = 1,
  DEVICE_TO_HOST = 2,
  DEVICE_TO_DEVICE = 3
};
58 
59 class UnaryVariantOpRegistry {
60  public:
61   typedef std::function<bool(Variant*)> VariantDecodeFn;
62   typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
63       VariantUnaryOpFn;
64   typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
65                                Variant*)>
66       VariantBinaryOpFn;
67 
68   // An AsyncTensorDeviceCopyFn is a function provided to
69   // the user-provided DeviceCopyFn callback as the third argument ("copier").
70   //
71   // Expected inputs:
72   //   from: A Tensor on the host (if performing cpu->gpu copy), or
73   //         device (if performing gpu->cpu or gpu->gpu copy).
74   //   to: An empty/uninitialized tensor.  It will be updated upon
75   //       successful return of the function with the correct dtype and shape.
76   //       However, the copied data will not be available until the compute
77   //       stream has been synchronized.
78   //
79   // Returns:
80   //   The status upon memory allocation / initialization of the
81   //   "to" tensor, and enqueue of the copy onto the compute stream.
82   //   Any failure of the copy itself will update the underlying
83   //   stream status and propagate through the runtime independent
84   //   of the caller.
85   typedef std::function<Status(const Tensor& from, Tensor* to)>
86       AsyncTensorDeviceCopyFn;
87 
88   // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn'
89   // expected to be passed to the registration macro
90   // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION.
91   typedef std::function<Status(const Variant& from, Variant* to,
92                                AsyncTensorDeviceCopyFn copy_fn)>
93       AsyncVariantDeviceCopyFn;
94 
95   // Add a decode function to the registry.
96   void RegisterDecodeFn(const string& type_name,
97                         const VariantDecodeFn& decode_fn);
98 
99   // Returns nullptr if no decode function was found for the given TypeName.
100   VariantDecodeFn* GetDecodeFn(StringPiece type_name);
101 
102   // Add a copy-to-GPU function to the registry.
103   void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
104                             const TypeIndex& type_index,
105                             const AsyncVariantDeviceCopyFn& device_copy_fn);
106 
107   // Returns nullptr if no copy function was found for the given
108   // TypeName and direction.
109   AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
110       const VariantDeviceCopyDirection direction, const TypeIndex& type_index);
111 
112   // Add a unary op function to the registry.
113   void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
114                          const TypeIndex& type_index,
115                          const VariantUnaryOpFn& unary_op_fn);
116 
117   // Returns nullptr if no unary op function was found for the given
118   // op, device, and TypeName.
119   VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
120                                  const TypeIndex& type_index);
121 
122   // Add a binary op function to the registry.
123   void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
124                           const TypeIndex& type_index,
125                           const VariantBinaryOpFn& add_fn);
126 
127   // Returns nullptr if no binary op function was found for the given
128   // op, device and TypeName.
129   VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
130                                    const TypeIndex& type_index);
131 
132   // Get a pointer to a global UnaryVariantOpRegistry object
133   static UnaryVariantOpRegistry* Global();
134 
135   // Get a pointer to a global persistent string storage object.
136   // ISO/IEC C++ working draft N4296 clarifies that insertion into an
137   // std::unordered_set does not invalidate memory locations of
138   // *values* inside the set (though it may invalidate existing
139   // iterators).  In other words, one may safely point a StringPiece to
140   // a value in the set without that StringPiece being invalidated by
141   // future insertions.
142   static std::unordered_set<string>* PersistentStringStorage();
143 
144  private:
145   struct TypeIndexHash {
operatorTypeIndexHash146     std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
147   };
148 
149   gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;
150 
151   // Map std::pair<Direction, type_name> to function.
152   struct PairHash {
153     template <typename Direction>
operatorPairHash154     std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
155       // The hash of an enum is just its value as a std::size_t.
156       std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
157       ret = Hash64Combine(ret, std::get<1>(x).hash_code());
158       return ret;
159     }
160   };
161 
162   gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
163                AsyncVariantDeviceCopyFn, PairHash>
164       device_copy_fns;
165 
166   // Map std::tuple<Op, device, type_name> to function.
167 
168   // this breaks by falling victim to "too perfect forwarding"
169   // see https://stackoverflow.com/questions/44475317/variadic-template-issue
170   // and references therein
171   template <typename Op>
172   struct FuncTuple {
FuncTupleFuncTuple173     FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
174         : op_type_(op), device_(dev), type_index_(type_index) {}
175     Op op_type_;
176     StringPiece device_;
177     TypeIndex type_index_;
178   };
179   // friend declaration for operator==
180   // needed for clang
181   template <typename Op>
182   friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
183   struct TupleHash {
184     template <typename Op>
operatorTupleHash185     std::size_t operator()(
186         const std::tuple<Op, StringPiece, TypeIndex>& x) const {
187       // The hash of an enum is just its value as a std::size_t.
188       std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
189       ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
190       ret = Hash64Combine(ret, std::get<2>(x).hash_code());
191       return ret;
192     }
193 
194     template <typename Op>
operatorTupleHash195     std::size_t operator()(const FuncTuple<Op>& x) const {
196       // The hash of an enum is just its value as a std::size_t.
197       std::size_t ret = static_cast<std::size_t>(x.op_type_);
198       ret = Hash64Combine(ret, sp_hasher_(x.device_));
199       ret = Hash64Combine(ret, x.type_index_.hash_code());
200       return ret;
201     }
202     StringPieceHasher sp_hasher_;
203   };
204   gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
205       unary_op_fns;
206   gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
207       binary_op_fns;
208 
209   // Find or insert a string into a persistent string storage
210   // container; return the StringPiece pointing to the permanent string
211   // location.
GetPersistentStringPiece(const string & str)212   static StringPiece GetPersistentStringPiece(const string& str) {
213     const auto string_storage = PersistentStringStorage();
214     auto found = string_storage->find(str);
215     if (found == string_storage->end()) {
216       auto inserted = string_storage->insert(str);
217       return StringPiece(*inserted.first);
218     } else {
219       return StringPiece(*found);
220     }
221   }
222 };
223 template <typename Op>
224 inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
225                        const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
226   return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
227          (lhs.type_index_ == rhs.type_index_);
228 }
229 
// Decodes the Variant whose data_type has a registered decode
// function.  Note that this returns a bool, not a Status.
// NOTE(review): presumably returns true on success and false if the Variant
// does not have a registered decode function, or if the decoding function
// fails (the decode functions registered via
// REGISTER_UNARY_VARIANT_DECODE_FUNCTION return false on failure); confirm
// against the definition.
//
// REQUIRES:
//   variant is not null.
//
bool DecodeUnaryVariant(Variant* variant);
238 
// Copies a variant between CPU<->GPU, or between GPU<->GPU.
// The variant 'from' must have a registered DeviceCopyFn for the
// given direction.  The returned variant 'to' will have
// (some subset of its) tensors stored on destination according to the
// registered DeviceCopyFn function for the given direction.  Returns
// an Internal error if the Variant does not have a registered
// DeviceCopyFn function for the given direction, or if initiating the
// copy fails.
//
// 'copy_fn' is the per-tensor copier (see
// UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn).  It is presumably
// forwarded to the registered DeviceCopyFn as its "copier" argument; confirm
// against the definition.
//
// REQUIRES:
//   'to' is not null.
//
Status VariantDeviceCopy(
    const VariantDeviceCopyDirection direction, const Variant& from,
    Variant* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);
255 
256 // Sets *v_out = unary_op(v).  The variant v must have a registered
257 // UnaryOp function for the given Device.  Returns an Internal error
258 // if v does not have a registered unary_op function for this device, or if
259 // UnaryOp fails.
260 //
261 // REQUIRES:
262 //   v_out is not null.
263 //
264 template <typename Device>
UnaryOpVariant(OpKernelContext * ctx,VariantUnaryOp op,const Variant & v,Variant * v_out)265 Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
266                       Variant* v_out) {
267   const string& device = DeviceName<Device>::value;
268   UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
269       UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
270   if (unary_op_fn == nullptr) {
271     return errors::Internal(
272         "No unary variant unary_op function found for unary variant op enum: ",
273         op, " Variant type_name: ", v.TypeName(), " for device type: ", device);
274   }
275   return (*unary_op_fn)(ctx, v, v_out);
276 }
277 
278 // Sets *out = binary_op(a, b).  The variants a and b must be the same type
279 // and have a registered binary_op function for the given Device.  Returns an
280 // Internal error if a and b are not the same type_name or if
281 // if a does not have a registered op function for this device, or if
282 // BinaryOp fails.
283 //
284 // REQUIRES:
285 //   out is not null.
286 //
287 template <typename Device>
BinaryOpVariants(OpKernelContext * ctx,VariantBinaryOp op,const Variant & a,const Variant & b,Variant * out)288 Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
289                         const Variant& a, const Variant& b, Variant* out) {
290   if (a.TypeId() != b.TypeId()) {
291     return errors::Internal(
292         "BianryOpVariants: Variants a and b have different "
293         "type ids.  Type names: '",
294         a.TypeName(), "' vs. '", b.TypeName(), "'");
295   }
296   const string& device = DeviceName<Device>::value;
297   UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
298       UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
299   if (binary_op_fn == nullptr) {
300     return errors::Internal(
301         "No unary variant binary_op function found for binary variant op "
302         "enum: ",
303         op, " Variant type_name: '", a.TypeName(), "' for device type: ",
304         device);
305   }
306   return (*binary_op_fn)(ctx, a, b, out);
307 }
308 
309 namespace variant_op_registry_fn_registration {
310 
311 template <typename T>
312 class UnaryVariantDecodeRegistration {
313  public:
UnaryVariantDecodeRegistration(const string & type_name)314   UnaryVariantDecodeRegistration(const string& type_name) {
315     // The Variant is passed by pointer because it should be
316     // mutable: get below may Decode the variant, which
317     // is a self-mutating behavior.  The variant is not modified in
318     // any other way.
319     UnaryVariantOpRegistry::Global()->RegisterDecodeFn(
320         type_name, [type_name](Variant* v) -> bool {
321           DCHECK_NE(v, nullptr);
322           VariantTensorDataProto* t = v->get<VariantTensorDataProto>();
323           if (t == nullptr) {
324             return false;
325           }
326           Variant decoded = T();
327           VariantTensorData data(std::move(*t));
328           if (!decoded.Decode(std::move(data))) {
329             return false;
330           }
331           std::swap(decoded, *v);
332           return true;
333         });
334   }
335 };
336 
337 template <typename T>
338 class UnaryVariantDeviceCopyRegistration {
339  public:
340   typedef std::function<Status(const T& t, T* t_out,
341                                UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
342       LocalVariantDeviceCopyFn;
UnaryVariantDeviceCopyRegistration(const VariantDeviceCopyDirection direction,const TypeIndex & type_index,const LocalVariantDeviceCopyFn & device_copy_fn)343   UnaryVariantDeviceCopyRegistration(
344       const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
345       const LocalVariantDeviceCopyFn& device_copy_fn) {
346     const string type_index_name = port::MaybeAbiDemangle(type_index.name());
347     UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
348         direction, type_index,
349         [type_index_name, device_copy_fn](
350             const Variant& from, Variant* to,
351             UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
352                 device_copy_tensor_fn) -> Status {
353           DCHECK_NE(to, nullptr);
354           *to = T();
355           if (from.get<T>() == nullptr) {
356             return errors::Internal(
357                 "VariantCopyToGPUFn: Could not access object, type_index: ",
358                 type_index_name);
359           }
360           const T& t = *from.get<T>();
361           T* t_out = to->get<T>();
362           return device_copy_fn(t, t_out, device_copy_tensor_fn);
363         });
364   }
365 };
366 
367 template <typename T>
368 class UnaryVariantUnaryOpRegistration {
369   typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
370       LocalVariantUnaryOpFn;
371 
372  public:
UnaryVariantUnaryOpRegistration(VariantUnaryOp op,const string & device,const TypeIndex & type_index,const LocalVariantUnaryOpFn & unary_op_fn)373   UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
374                                   const TypeIndex& type_index,
375                                   const LocalVariantUnaryOpFn& unary_op_fn) {
376     const string type_index_name = port::MaybeAbiDemangle(type_index.name());
377     UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
378         op, device, type_index,
379         [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
380                                        Variant* v_out) -> Status {
381           DCHECK_NE(v_out, nullptr);
382           *v_out = T();
383           if (v.get<T>() == nullptr) {
384             return errors::Internal(
385                 "VariantUnaryOpFn: Could not access object, type_index: ",
386                 type_index_name);
387           }
388           const T& t = *v.get<T>();
389           T* t_out = v_out->get<T>();
390           return unary_op_fn(ctx, t, t_out);
391         });
392   }
393 };
394 
395 template <typename T>
396 class UnaryVariantBinaryOpRegistration {
397   typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
398                                T* out)>
399       LocalVariantBinaryOpFn;
400 
401  public:
UnaryVariantBinaryOpRegistration(VariantBinaryOp op,const string & device,const TypeIndex & type_index,const LocalVariantBinaryOpFn & binary_op_fn)402   UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
403                                    const TypeIndex& type_index,
404                                    const LocalVariantBinaryOpFn& binary_op_fn) {
405     const string type_index_name = port::MaybeAbiDemangle(type_index.name());
406     UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
407         op, device, type_index,
408         [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
409                                         const Variant& b,
410                                         Variant* out) -> Status {
411           DCHECK_NE(out, nullptr);
412           *out = T();
413           if (a.get<T>() == nullptr) {
414             return errors::Internal(
415                 "VariantBinaryOpFn: Could not access object 'a', type_index: ",
416                 type_index_name);
417           }
418           if (b.get<T>() == nullptr) {
419             return errors::Internal(
420                 "VariantBinaryOpFn: Could not access object 'b', type_index: ",
421                 type_index_name);
422           }
423           const T& t_a = *a.get<T>();
424           const T& t_b = *b.get<T>();
425           T* t_out = out->get<T>();
426           return binary_op_fn(ctx, t_a, t_b, t_out);
427         });
428   }
429 };
430 
431 };  // namespace variant_op_registry_fn_registration
432 
// Registers a decode function for Variants of C++ type T under the encoded
// type name `type_name`.  Each use defines a uniquely named static
// registration object whose constructor performs the registration.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)

// Extra expansion level: __COUNTER__ is expanded to a number here, before
// the ## token paste in the _UNIQ macro below.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(uniq_id, T, \
                                                           type_name)  \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(uniq_id, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(uniq_id, T, type_name) \
  static variant_op_registry_fn_registration::                             \
      UnaryVariantDecodeRegistration<T>                                    \
          register_unary_variant_op_decoder_fn_##uniq_id(type_name)
444 
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
//
// Register a device copy variant function for the given copy
// direction and type; where direction is the enum
// VariantDeviceCopyDirection, and the device_copy_fn has signature:
//
//   Status device_copy_fn(
//     const T& t, T* t_out,
//     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier);
//
// And device_copy_fn calls copier 0 or more times.  For details on
// the behavior of the copier function, see the comments at the
// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn.
//
// Note, the device_copy_fn may choose to keep some tensors
// on host, e.g. by assigning to->tensor = from.tensor (assuming
// from.tensor is already on host); or by setting
//   to->tensor = Tensor(cpu_allocator(), ...)
// and manually updating its values.
//
// If this is the case, the CopyFns for HOST_TO_DEVICE,
// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host
// copies in a consistent manner.  For example, one must always
// manually copy any "always on host" tensors in all directions instead of e.g.
//   - performing a host-to-host copy in one direction,
//   - using the provided copier function in the reverse direction.
// Doing the latter will cause program failures.
//
// Each use defines a uniquely named static registration object keyed by
// MakeTypeIndex<T>().
//
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction,   \
                                                             device_copy_fn) \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER(          \
      __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn)

// Extra expansion level so __COUNTER__ becomes a number before the ##
// token paste in the _UNIQ macro.
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
    uniq_id, T, direction, type_idx, device_copy_fn)                      \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
      uniq_id, T, direction, type_idx, device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
    uniq_id, T, direction, type_idx, device_copy_fn)               \
  static variant_op_registry_fn_registration::                     \
      UnaryVariantDeviceCopyRegistration<T>                        \
          register_unary_variant_op_device_copy_fn_##uniq_id(      \
              direction, type_idx, device_copy_fn)
494 
// Register a unary_op variant function with the signature:
//    Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// to Variants of C++ type T (keyed by MakeTypeIndex<T>()), for device
// string `device`, for the VariantUnaryOp enum value `op`.  Each use defines
// a uniquely named static registration object whose constructor performs the
// registration.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T,     \
                                                 unary_op_function) \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(             \
      __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function)

// Extra expansion level so __COUNTER__ becomes a number before the ##
// token paste in the _UNIQ macro.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(       \
    ctr, op, device, T, type_index, unary_op_function)              \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
                                                type_index, unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(                         \
    ctr, op, device, T, type_index, unary_op_function)                         \
  static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
      T>                                                                       \
      register_unary_variant_op_unary_op_fn_##ctr(op, device, type_index,      \
                                                  unary_op_function)
515 
// Register a binary_op variant function with the signature:
//    Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// to Variants of C++ type T (keyed by MakeTypeIndex<T>()), for device
// string `device`, for the VariantBinaryOp enum value `op`.  Each use defines
// a uniquely named static registration object whose constructor performs the
// registration.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T,      \
                                                  binary_op_function) \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(              \
      __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function)

// Extra expansion level so __COUNTER__ becomes a number before the ##
// token paste in the _UNIQ macro.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_index, binary_op_function)        \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_index, binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                        \
    ctr, op, device, T, type_index, binary_op_function)                        \
  static variant_op_registry_fn_registration::                                 \
      UnaryVariantBinaryOpRegistration<T>                                      \
          register_unary_variant_op_binary_op_fn_##ctr(op, device, type_index, \
                                                       binary_op_function)
536 
537 }  // end namespace tensorflow
538 
539 #endif  // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
540