1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 
18 #include "tensorflow/core/framework/register_types.h"
19 #include "tensorflow/core/framework/type_index.h"
20 #include "tensorflow/core/framework/variant.h"
21 #include "tensorflow/core/framework/variant_op_registry.h"
22 #include "tensorflow/core/lib/core/errors.h"
23 #include "tensorflow/core/public/version.h"
24 
25 namespace tensorflow {
26 
PersistentStringStorage()27 std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
28   static std::unordered_set<string>* string_storage =
29       new std::unordered_set<string>();
30   return string_storage;
31 }
32 
33 // static
Global()34 UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
35   static UnaryVariantOpRegistry* global_unary_variant_op_registry =
36       new UnaryVariantOpRegistry;
37   return global_unary_variant_op_registry;
38 }
39 
GetShapeFn(StringPiece type_name)40 UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
41     StringPiece type_name) {
42   auto found = shape_fns.find(type_name);
43   if (found == shape_fns.end()) return nullptr;
44   return &found->second;
45 }
46 
RegisterShapeFn(const string & type_name,const VariantShapeFn & shape_fn)47 void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name,
48                                              const VariantShapeFn& shape_fn) {
49   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape";
50   VariantShapeFn* existing = GetShapeFn(type_name);
51   CHECK_EQ(existing, nullptr)
52       << "Unary VariantShapeFn for type_name: " << type_name
53       << " already registered";
54   shape_fns.insert(std::pair<StringPiece, VariantShapeFn>(
55       GetPersistentStringPiece(type_name), shape_fn));
56 }
57 
GetUnaryVariantShape(const Tensor & variant_tensor,TensorShape * shape)58 Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
59   CHECK_EQ(variant_tensor.dtype(), DT_VARIANT);
60   CHECK_EQ(variant_tensor.dims(), 0);
61   const Variant& v = variant_tensor.scalar<Variant>()();
62   UnaryVariantOpRegistry::VariantShapeFn* shape_fn =
63       UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName());
64   if (shape_fn == nullptr) {
65     return errors::Internal(
66         "No unary variant shape function found for Variant type_name: ",
67         v.TypeName());
68   }
69   return (*shape_fn)(v, shape);
70 }
71 
72 // Add some basic registrations for use by others, e.g., for testing.
73 namespace {
74 template <typename T>
ScalarShape(const T &,TensorShape * shape)75 Status ScalarShape(const T&, TensorShape* shape) {
76   *shape = TensorShape({});
77   return Status::OK();
78 }
79 }  // namespace
80 
81 #define REGISTER_VARIANT_SHAPE_TYPE(T) \
82   REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>);
83 
84 // No encode/shape registered for std::complex<> and Eigen::half
85 // objects yet.
86 REGISTER_VARIANT_SHAPE_TYPE(int);
87 REGISTER_VARIANT_SHAPE_TYPE(float);
88 REGISTER_VARIANT_SHAPE_TYPE(bool);
89 REGISTER_VARIANT_SHAPE_TYPE(double);
90 
91 #undef REGISTER_VARIANT_SHAPE_TYPE
92 
GetDecodeFn(StringPiece type_name)93 UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
94     StringPiece type_name) {
95   auto found = decode_fns.find(type_name);
96   if (found == decode_fns.end()) return nullptr;
97   return &found->second;
98 }
99 
RegisterDecodeFn(const string & type_name,const VariantDecodeFn & decode_fn)100 void UnaryVariantOpRegistry::RegisterDecodeFn(
101     const string& type_name, const VariantDecodeFn& decode_fn) {
102   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDecode";
103   VariantDecodeFn* existing = GetDecodeFn(type_name);
104   CHECK_EQ(existing, nullptr)
105       << "Unary VariantDecodeFn for type_name: " << type_name
106       << " already registered";
107   decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
108       GetPersistentStringPiece(type_name), decode_fn));
109 }
110 
DecodeUnaryVariant(Variant * variant)111 bool DecodeUnaryVariant(Variant* variant) {
112   UnaryVariantOpRegistry::VariantDecodeFn* decode_fn =
113       UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName());
114   if (decode_fn == nullptr) {
115     return false;
116   }
117   const string type_name = variant->TypeName();
118   bool decoded = (*decode_fn)(variant);
119   if (!decoded) return false;
120   if (variant->TypeName() != type_name) {
121     LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: "
122                << type_name
123                << " but after decoding was: " << variant->TypeName()
124                << ".  Treating this as a failure.";
125     return false;
126   }
127   return true;
128 }
129 
130 // Add some basic registrations for use by others, e.g., for testing.
131 
132 #define REGISTER_VARIANT_DECODE_TYPE(T) \
133   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));
134 
135 // No encode/decode registered for std::complex<> and Eigen::half
136 // objects yet.
137 REGISTER_VARIANT_DECODE_TYPE(int);
138 REGISTER_VARIANT_DECODE_TYPE(float);
139 REGISTER_VARIANT_DECODE_TYPE(bool);
140 REGISTER_VARIANT_DECODE_TYPE(double);
141 
142 #undef REGISTER_VARIANT_DECODE_TYPE
143 
144 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn*
GetDeviceCopyFn(const VariantDeviceCopyDirection direction,StringPiece type_name)145 UnaryVariantOpRegistry::GetDeviceCopyFn(
146     const VariantDeviceCopyDirection direction, StringPiece type_name) {
147   auto found = device_copy_fns.find(std::make_pair(direction, type_name));
148   if (found == device_copy_fns.end()) return nullptr;
149   return &found->second;
150 }
151 
RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,const string & type_name,const AsyncVariantDeviceCopyFn & device_copy_fn)152 void UnaryVariantOpRegistry::RegisterDeviceCopyFn(
153     const VariantDeviceCopyDirection direction, const string& type_name,
154     const AsyncVariantDeviceCopyFn& device_copy_fn) {
155   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy";
156   AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name);
157   CHECK_EQ(existing, nullptr)
158       << "UnaryVariantDeviceCopy for direction: " << direction
159       << " and type_name: " << type_name << " already registered";
160   device_copy_fns.insert(
161       std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>,
162                 AsyncVariantDeviceCopyFn>(
163           std::make_pair(direction, GetPersistentStringPiece(type_name)),
164           device_copy_fn));
165 }
166 
VariantDeviceCopy(const VariantDeviceCopyDirection direction,const Variant & from,Variant * to,const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn & copy_fn)167 Status VariantDeviceCopy(
168     const VariantDeviceCopyDirection direction, const Variant& from,
169     Variant* to,
170     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
171   UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
172       UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
173                                                         from.TypeName());
174   if (device_copy_fn == nullptr) {
175     return errors::Internal(
176         "No unary variant device copy function found for direction: ",
177         direction, " and Variant type_name: ", from.TypeName());
178   }
179   return (*device_copy_fn)(from, to, copy_fn);
180 }
181 
182 // Special casing UnaryOpFn per op and per device.
GetUnaryOpFn(VariantUnaryOp op,StringPiece device,StringPiece type_name)183 UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
184     VariantUnaryOp op, StringPiece device, StringPiece type_name) {
185   auto found = unary_op_fns.find({op, device, type_name});
186   if (found == unary_op_fns.end()) return nullptr;
187   return &found->second;
188 }
189 
RegisterUnaryOpFn(VariantUnaryOp op,const string & device,const string & type_name,const VariantUnaryOpFn & unary_op_fn)190 void UnaryVariantOpRegistry::RegisterUnaryOpFn(
191     VariantUnaryOp op, const string& device, const string& type_name,
192     const VariantUnaryOpFn& unary_op_fn) {
193   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp";
194   VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name);
195   CHECK_EQ(existing, nullptr)
196       << "Unary VariantUnaryOpFn for type_name: " << type_name
197       << " already registered for device type: " << device;
198   unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
199       {op, GetPersistentStringPiece(device),
200        GetPersistentStringPiece(type_name)},
201       unary_op_fn));
202 }
203 
204 namespace {
205 template <typename T>
ZerosLikeVariantPrimitiveType(OpKernelContext * ctx,const T & t,T * t_out)206 Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
207                                      T* t_out) {
208   *t_out = T(0);
209   return Status::OK();
210 }
211 }  // namespace
212 
213 #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T)                             \
214   REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
215                                            DEVICE_CPU, T, TF_STR(T),    \
216                                            ZerosLikeVariantPrimitiveType<T>);
217 
218 // No zeros_like registered for std::complex<> or Eigen::half objects yet.
219 REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
220 REGISTER_VARIANT_ZEROS_LIKE_TYPE(float);
221 REGISTER_VARIANT_ZEROS_LIKE_TYPE(double);
222 REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
223 
224 #undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
225 
226 // Special casing BinaryOpFn per op and per device.
227 UnaryVariantOpRegistry::VariantBinaryOpFn*
GetBinaryOpFn(VariantBinaryOp op,StringPiece device,StringPiece type_name)228 UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
229                                       StringPiece type_name) {
230   auto found = binary_op_fns.find({op, device, type_name});
231   if (found == binary_op_fns.end()) return nullptr;
232   return &found->second;
233 }
234 
RegisterBinaryOpFn(VariantBinaryOp op,const string & device,const string & type_name,const VariantBinaryOpFn & add_fn)235 void UnaryVariantOpRegistry::RegisterBinaryOpFn(
236     VariantBinaryOp op, const string& device, const string& type_name,
237     const VariantBinaryOpFn& add_fn) {
238   CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp";
239   VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name);
240   CHECK_EQ(existing, nullptr)
241       << "Unary VariantBinaryOpFn for type_name: " << type_name
242       << " already registered for device type: " << device;
243   binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
244       {op, GetPersistentStringPiece(device),
245        GetPersistentStringPiece(type_name)},
246       add_fn));
247 }
248 
249 namespace {
250 template <typename T>
AddVariantPrimitiveType(OpKernelContext * ctx,const T & a,const T & b,T * out)251 Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
252                                T* out) {
253   *out = a + b;
254   return Status::OK();
255 }
256 }  // namespace
257 
258 #define REGISTER_VARIANT_ADD_TYPE(T)                                           \
259   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
260                                             T, TF_STR(T),                      \
261                                             AddVariantPrimitiveType<T>);
262 
263 // No add registered for std::complex<> or Eigen::half objects yet.
264 REGISTER_VARIANT_ADD_TYPE(int);
265 REGISTER_VARIANT_ADD_TYPE(float);
266 REGISTER_VARIANT_ADD_TYPE(double);
267 REGISTER_VARIANT_ADD_TYPE(bool);
268 
269 #undef REGISTER_VARIANT_ADD_TYPE
270 
271 }  // namespace tensorflow
272