1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "tensorflow/core/framework/register_types.h"
19 #include "tensorflow/core/framework/type_index.h"
20 #include "tensorflow/core/framework/variant.h"
21 #include "tensorflow/core/framework/variant_op_registry.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/public/version.h"
24 
25 namespace tensorflow {
26 
PersistentStringStorage()27 std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
28   static std::unordered_set<string>* string_storage =
29       new std::unordered_set<string>();
30   return string_storage;
31 }
32 
33 // static
Global()34 UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
35   static UnaryVariantOpRegistry* global_unary_variant_op_registry =
36       new UnaryVariantOpRegistry;
37   return global_unary_variant_op_registry;
38 }
39 
GetDecodeFn(StringPiece type_name)40 UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
41     StringPiece type_name) {
42   auto found = decode_fns.find(type_name);
43   if (found == decode_fns.end()) return nullptr;
44   return &found->second;
45 }
46 
RegisterDecodeFn(const string & type_name,const VariantDecodeFn & decode_fn)47 void UnaryVariantOpRegistry::RegisterDecodeFn(
48     const string& type_name, const VariantDecodeFn& decode_fn) {
49   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDecode";
50   VariantDecodeFn* existing = GetDecodeFn(type_name);
51   CHECK_EQ(existing, nullptr)
52       << "Unary VariantDecodeFn for type_name: " << type_name
53       << " already registered";
54   decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
55       GetPersistentStringPiece(type_name), decode_fn));
56 }
57 
DecodeUnaryVariant(Variant * variant)58 bool DecodeUnaryVariant(Variant* variant) {
59   UnaryVariantOpRegistry::VariantDecodeFn* decode_fn =
60       UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName());
61   if (decode_fn == nullptr) {
62     return false;
63   }
64   const string type_name = variant->TypeName();
65   bool decoded = (*decode_fn)(variant);
66   if (!decoded) return false;
67   if (variant->TypeName() != type_name) {
68     LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: "
69                << type_name
70                << " but after decoding was: " << variant->TypeName()
71                << ".  Treating this as a failure.";
72     return false;
73   }
74   return true;
75 }
76 
77 // Add some basic registrations for use by others, e.g., for testing.
78 
79 #define REGISTER_VARIANT_DECODE_TYPE(T) \
80   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));
81 
82 // No encode/decode registered for std::complex<> and Eigen::half
83 // objects yet.
84 REGISTER_VARIANT_DECODE_TYPE(int);
85 REGISTER_VARIANT_DECODE_TYPE(float);
86 REGISTER_VARIANT_DECODE_TYPE(bool);
87 REGISTER_VARIANT_DECODE_TYPE(double);
88 
89 #undef REGISTER_VARIANT_DECODE_TYPE
90 
91 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn*
GetDeviceCopyFn(const VariantDeviceCopyDirection direction,const TypeIndex & type_index)92 UnaryVariantOpRegistry::GetDeviceCopyFn(
93     const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
94   auto found = device_copy_fns.find(std::make_pair(direction, type_index));
95   if (found == device_copy_fns.end()) return nullptr;
96   return &found->second;
97 }
98 
RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,const TypeIndex & type_index,const AsyncVariantDeviceCopyFn & device_copy_fn)99 void UnaryVariantOpRegistry::RegisterDeviceCopyFn(
100     const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
101     const AsyncVariantDeviceCopyFn& device_copy_fn) {
102   AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
103   CHECK_EQ(existing, nullptr)
104       << "UnaryVariantDeviceCopy for direction: " << direction
105       << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
106       << " already registered";
107   device_copy_fns.insert(
108       std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
109                 AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index),
110                                           device_copy_fn));
111 }
112 
VariantDeviceCopy(const VariantDeviceCopyDirection direction,const Variant & from,Variant * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy_fn)113 Status VariantDeviceCopy(
114     const VariantDeviceCopyDirection direction, const Variant& from,
115     Variant* to,
116     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
117   UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
118       UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
119                                                         from.TypeId());
120   if (device_copy_fn == nullptr) {
121     return errors::Internal(
122         "No unary variant device copy function found for direction: ",
123         direction, " and Variant type_index: ",
124         port::MaybeAbiDemangle(from.TypeId().name()));
125   }
126   return (*device_copy_fn)(from, to, copy_fn);
127 }
128 
129 namespace {
130 template <typename T>
DeviceCopyPrimitiveType(const T & in,T * out,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copier)131 Status DeviceCopyPrimitiveType(
132     const T& in, T* out,
133     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
134   // Dummy copy, we don't actually bother copying to the device and back for
135   // testing.
136   *out = in;
137   return Status::OK();
138 }
139 }  // namespace
140 
141 #define REGISTER_VARIANT_DEVICE_COPY_TYPE(T)            \
142   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
143       T, VariantDeviceCopyDirection::HOST_TO_DEVICE,    \
144       DeviceCopyPrimitiveType<T>);                      \
145   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
146       T, VariantDeviceCopyDirection::DEVICE_TO_HOST,    \
147       DeviceCopyPrimitiveType<T>);                      \
148   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
149       T, VariantDeviceCopyDirection::DEVICE_TO_DEVICE,  \
150       DeviceCopyPrimitiveType<T>);
151 
152 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
153 REGISTER_VARIANT_DEVICE_COPY_TYPE(int);
154 REGISTER_VARIANT_DEVICE_COPY_TYPE(float);
155 REGISTER_VARIANT_DEVICE_COPY_TYPE(double);
156 REGISTER_VARIANT_DEVICE_COPY_TYPE(bool);
157 
158 #undef REGISTER_VARIANT_DEVICE_COPY_TYPE
159 
160 // Special casing UnaryOpFn per op and per device.
GetUnaryOpFn(VariantUnaryOp op,StringPiece device,const TypeIndex & type_index)161 UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
162     VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) {
163   auto found = unary_op_fns.find({op, device, type_index});
164   if (found == unary_op_fns.end()) return nullptr;
165   return &found->second;
166 }
167 
RegisterUnaryOpFn(VariantUnaryOp op,const string & device,const TypeIndex & type_index,const VariantUnaryOpFn & unary_op_fn)168 void UnaryVariantOpRegistry::RegisterUnaryOpFn(
169     VariantUnaryOp op, const string& device, const TypeIndex& type_index,
170     const VariantUnaryOpFn& unary_op_fn) {
171   VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
172   CHECK_EQ(existing, nullptr)
173       << "Unary VariantUnaryOpFn for type_index: "
174       << port::MaybeAbiDemangle(type_index.name())
175       << " already registered for device type: " << device;
176   unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
177       {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
178 }
179 
180 namespace {
181 template <typename T>
ZerosLikeVariantPrimitiveType(OpKernelContext * ctx,const T & t,T * t_out)182 Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
183                                      T* t_out) {
184   *t_out = T(0);
185   return Status::OK();
186 }
187 }  // namespace
188 
189 #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T)                             \
190   REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
191                                            DEVICE_CPU, T,               \
192                                            ZerosLikeVariantPrimitiveType<T>);
193 
194 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
195 REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
196 REGISTER_VARIANT_ZEROS_LIKE_TYPE(float);
197 REGISTER_VARIANT_ZEROS_LIKE_TYPE(double);
198 REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
199 
200 #undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
201 
202 // Special casing BinaryOpFn per op and per device.
203 UnaryVariantOpRegistry::VariantBinaryOpFn*
GetBinaryOpFn(VariantBinaryOp op,StringPiece device,const TypeIndex & type_index)204 UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
205                                       const TypeIndex& type_index) {
206   auto found = binary_op_fns.find({op, device, type_index});
207   if (found == binary_op_fns.end()) return nullptr;
208   return &found->second;
209 }
210 
RegisterBinaryOpFn(VariantBinaryOp op,const string & device,const TypeIndex & type_index,const VariantBinaryOpFn & add_fn)211 void UnaryVariantOpRegistry::RegisterBinaryOpFn(
212     VariantBinaryOp op, const string& device, const TypeIndex& type_index,
213     const VariantBinaryOpFn& add_fn) {
214   VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
215   CHECK_EQ(existing, nullptr)
216       << "Unary VariantBinaryOpFn for type_index: "
217       << port::MaybeAbiDemangle(type_index.name())
218       << " already registered for device type: " << device;
219   binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
220       {op, GetPersistentStringPiece(device), type_index}, add_fn));
221 }
222 
223 namespace {
224 template <typename T>
AddVariantPrimitiveType(OpKernelContext * ctx,const T & a,const T & b,T * out)225 Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
226                                T* out) {
227   *out = a + b;
228   return Status::OK();
229 }
230 }  // namespace
231 
232 #define REGISTER_VARIANT_ADD_TYPE(T)                                           \
233   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
234                                             T, AddVariantPrimitiveType<T>);
235 
236 // No add registered for std::complex<> or Eigen::half objects yet.
237 REGISTER_VARIANT_ADD_TYPE(int);
238 REGISTER_VARIANT_ADD_TYPE(float);
239 REGISTER_VARIANT_ADD_TYPE(double);
240 REGISTER_VARIANT_ADD_TYPE(bool);
241 
242 #undef REGISTER_VARIANT_ADD_TYPE
243 
244 }  // namespace tensorflow
245