/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/eager/c_api.h"

#include <algorithm>
#include <cstddef>
#include <memory>
#include <string>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/eager/c_api_internal.h"
#include "tensorflow/c/eager/runtime.h"
#ifdef TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#endif  // TENSORFLOW_EAGER_USE_XLA
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"

using tensorflow::int64;
using tensorflow::string;

namespace {
bool IsCPU(const tensorflow::Device* d) {
  return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
}

bool IsXLA(const tensorflow::Device* d) {
  if (d == nullptr) return false;
  const auto& device_type = d->attributes().device_type();
  return device_type.find("XLA") != std::string::npos;
}

string DeviceName(const tensorflow::Device* d) {
  return (d == nullptr) ? "cpu:0" : d->name();
}

#ifdef TENSORFLOW_EAGER_USE_XLA
std::atomic_int_fast64_t func_id_generator(0);
#endif  // TENSORFLOW_EAGER_USE_XLA
}  // namespace

extern "C" {

TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }

void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}

void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->policy = policy;
}

void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }

TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
  TF_Graph* graph = TF_NewGraph();
  TF_Session* session = TF_NewSession(graph, &opts->session_options, status);
  if (status->status.ok()) {
    if (session->device_mgr == nullptr || session->devices.empty()) {
      status->status = tensorflow::errors::InvalidArgument(
          "Provided TF_SessionOptions are not compatible with eager execution "
          "(perhaps the TF_SessionOptions alluded to session execution in a "
          "remote address space?)");
    }
  }
  if (!status->status.ok()) {
    TF_DeleteGraph(graph);
    return nullptr;
  }

  return new TFE_Context(*opts, session);
}

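// A minimal usage sketch of the context lifecycle (illustrative only; error
// handling is elided, and TF_NewStatus/TF_GetCode/TF_DeleteStatus come from
// c_api.h):
//
//   TF_Status* status = TF_NewStatus();
//   TFE_ContextOptions* opts = TFE_NewContextOptions();
//   TFE_Context* ctx = TFE_NewContext(opts, status);
//   TFE_DeleteContextOptions(opts);
//   if (TF_GetCode(status) == TF_OK) {
//     // ... build ops and call TFE_Execute ...
//     TFE_DeleteContext(ctx, status);
//   }
//   TF_DeleteStatus(status);
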
void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
  status->status = tensorflow::Status::OK();
  {
    tensorflow::mutex_lock ml(ctx->cache_mu);
    tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
  }
  TF_Graph* graph = ctx->session->graph;
  TF_DeleteSession(ctx->session, status);
  TF_DeleteGraph(graph);
  ctx->rendezvous->Unref();
  delete ctx;
}

TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  return TF_SessionListDevices(ctx->session, status);
}

void TFE_ContextClearCaches(TFE_Context* ctx) {
  tensorflow::mutex_lock ml(ctx->cache_mu);
  tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
}

void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  tensorflow::mutex_lock ml(ctx->policy_map_mu);
  ctx->thread_local_policies[std::this_thread::get_id()] = policy;
}

extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
    TFE_Context* ctx) {
  tensorflow::mutex_lock ml(ctx->policy_map_mu);
  auto policy_map_it =
      ctx->thread_local_policies.find(std::this_thread::get_id());
  if (policy_map_it != ctx->thread_local_policies.end()) {
    return policy_map_it->second;
  }
  return ctx->policy;
}

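// Illustrative sketch: a thread can loosen the placement policy before issuing
// ops that mix devices; the override applies only to the calling thread and
// does not affect the context-wide default:
//
//   TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx,
//                                                  TFE_DEVICE_PLACEMENT_SILENT);
//   // Subsequent TFE_Execute calls on this thread may copy inputs silently.
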
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
  tensorflow::Tensor tensor;
  status->status = tensorflow::TF_TensorToTensor(t, &tensor);
  if (!status->status.ok()) return nullptr;
  return new TFE_TensorHandle(tensor, nullptr);
}

void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }

TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return static_cast<TF_DataType>(h->t.dtype());
}

int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); }

int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) {
  return h->t.dim_size(dim_index);
}

const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) {
  // TODO(apassos) this will be potentially incorrect in the distributed case as
  // our local device will have a name which depends on the ClusterSpec and
  // hence will require the context to resolve.
  return (h->d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                           : h->d->name().c_str();
}

TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
  if (!IsCPU(h->d)) {
    TF_SetStatus(status, TF_UNIMPLEMENTED,
                 tensorflow::strings::StrCat(
                     "TFE_TensorHandle can be resolved iff it is on CPU (this "
                     "handle is on ",
                     h->d->name(),
                     "). Consider using TFE_TensorHandleCopyToDevice to get a "
                     "copy of the tensor on CPU")
                     .c_str());
    return nullptr;
  }
  return tensorflow::TF_TensorFromTensor(h->t, status);
}

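// Illustrative sketch: a handle that lives on an accelerator must be copied to
// CPU before it can be resolved into a TF_Tensor (assumes `h`, `ctx`, and
// `status` already exist; the device string is the conventional local CPU
// name):
//
//   TFE_TensorHandle* h_cpu = TFE_TensorHandleCopyToDevice(
//       h, ctx, "/job:localhost/replica:0/task:0/device:CPU:0", status);
//   if (TF_GetCode(status) == TF_OK) {
//     TF_Tensor* t = TFE_TensorHandleResolve(h_cpu, status);
//     // ... use t, then TF_DeleteTensor(t) and TFE_DeleteTensorHandle(h_cpu).
//   }
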
TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
                                               TFE_Context* ctx,
                                               const char* device_name,
                                               TF_Status* status) {
  tensorflow::Device* dstd = ctx->devices()[0];
  if (device_name != nullptr && strlen(device_name) > 0) {
    status->status = ctx->session->device_mgr->LookupDevice(device_name, &dstd);
    if (!status->status.ok()) return nullptr;
  }

  tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d;
  bool is_same_device =
      (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
  const bool dst_cpu = IsCPU(dstd);
  const bool src_cpu = IsCPU(srcd);
  // both_on_cpu can be true and yet is_same_device is false, if one of src/dst
  // has device type XLA_CPU, and the other CPU.
  const bool both_on_cpu = src_cpu && dst_cpu;
  if (is_same_device || both_on_cpu) {
    return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
  }
  tensorflow::Tensor* src = &(h->t);
  if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
                   !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
    TF_SetStatus(
        status, TF_INVALID_ARGUMENT,
        tensorflow::strings::StrCat("Can't copy Tensor with type ",
                                    tensorflow::DataTypeString(src->dtype()),
                                    " to device ", DeviceName(dstd), ".")
            .c_str());
    return nullptr;
  }
  tensorflow::AllocatorAttributes attr;
  if (src->dtype() == tensorflow::DT_VARIANT) {
    attr.set_on_host(true);
  }
  tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
  if (src->shape().num_elements() == 0) {
    return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd);
  }
  tensorflow::DeviceContext* src_device_context = nullptr;
  if (!src_cpu) {
    src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
  }
  tensorflow::DeviceContext* dst_device_context = nullptr;
  if (!dst_cpu) {
    dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
  }
  // TODO(ashankar): The Sync() call below may be more aggressive than
  // necessary. It is based on knowledge of implementation details - that
  // GPU devices are implemented using 3 streams - one for host->device copies,
  // one for device->host copies and one for sending operations to the GPU.
  // With that setup, Sync()ing across all 3 streams should be sufficient
  // but more than necessary (since it waits for operations that might have
  // nothing to do with this tensor to complete).
  status->status = srcd->Sync();
  tensorflow::Notification n;
  tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
                                 srcd, dstd, tensorflow::AllocatorAttributes(),
                                 tensorflow::AllocatorAttributes(), src, &dst,
                                 [status, &n](const tensorflow::Status& s) {
                                   status->status = s;
                                   n.Notify();
                                 });
  n.WaitForNotification();
  return (TF_GetCode(status) == TF_OK)
             ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd)
             : nullptr;
}

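// Illustrative sketch: picking a destination device from the context's device
// list rather than hard-coding a name (TF_DeviceListCount/TF_DeviceListName/
// TF_DeleteDeviceList come from c_api.h; index 1 is assumed to be a second
// device such as a GPU, if one is present):
//
//   TF_DeviceList* devices = TFE_ContextListDevices(ctx, status);
//   if (TF_GetCode(status) == TF_OK && TF_DeviceListCount(devices) > 1) {
//     const char* name = TF_DeviceListName(devices, 1, status);
//     TFE_TensorHandle* copied =
//         TFE_TensorHandleCopyToDevice(h, ctx, name, status);
//     // ... use copied, then TFE_DeleteTensorHandle(copied).
//   }
//   TF_DeleteDeviceList(devices);
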
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  const char* name = op_or_function_name;  // Shorthand
  const tensorflow::AttrTypeMap* types;
  status->status = tensorflow::AttrTypeMapForOp(name, &types);
  if (status->status.ok()) return new TFE_Op(ctx, name, types);
  if (TF_GetCode(status) == TF_NOT_FOUND) {
    tensorflow::mutex_lock l(ctx->functions_mu);
    if (ctx->func_lib_def.Find(name) != nullptr) {
      status->status = tensorflow::Status::OK();
      return new TFE_Op(ctx, name, nullptr);
    }
  }
  return nullptr;
}

void TFE_DeleteOp(TFE_Op* op) { delete op; }

void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
  tensorflow::Device* d = nullptr;
  if (device_name != nullptr && strlen(device_name) > 0) {
    status->status =
        op->ctx->session->device_mgr->LookupDevice(device_name, &d);
    if (!status->status.ok()) return;
  }
  op->device = d;
}

const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
  tensorflow::Device* device =
      (op->device == nullptr) ? op->ctx->devices()[0] : op->device;
  return device->name().c_str();
}

void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
  op->use_xla = enable;
#ifndef TENSORFLOW_EAGER_USE_XLA
  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
                  "built with XLA support.";
#endif  // TENSORFLOW_EAGER_USE_XLA
}

void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
  // Questionable heuristic ...
  //
  // Motivation: After an 'op' is placed on GPU because some of its earlier
  // inputs are on GPU, we want to keep the 'op' there, even if some later
  // inputs of it are not on GPU.
  if (IsCPU(op->device) && !IsCPU(h->d)) {
    op->device = h->d;
  }
  if (!status->status.ok()) return;
  op->inputs.push_back(h->t);
  op->input_devices.push_back(h->d);
  op->attrs.NumInputs(op->inputs.size());
}

TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
                              unsigned char* is_list, TF_Status* status) {
  TF_AttrType ret;
  if (op->is_function()) {
    status->status = tensorflow::errors::Unimplemented(
        "TODO(apassos): Support for attributes for TensorFlow functions is not "
        "ready yet.");
    return TF_ATTR_INT;  // The compiler requires that we return something.
  }
  status->status =
      tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
  return ret;
}

TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
                                  const char* op_or_function_name,
                                  const char* attr_name, unsigned char* is_list,
                                  TF_Status* status) {
  TF_AttrType ret;
  TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
  if (!status->status.ok()) {
    return TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
  }
  ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
  TFE_DeleteOp(op);
  return ret;
}

void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
  op->attrs.Set(attr_name, value);
}

void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  op->attrs.Set(attr_name, static_cast<int64>(value));
}

void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  op->attrs.Set(attr_name, value);
}

void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
  op->attrs.Set(attr_name, (value == 0) ? false : true);
}

void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
}

void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
                        const int num_dims, TF_Status* out_status) {
  if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
    TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                 tensorflow::strings::StrCat(
                     "Value specified for `", attr_name, "` has ", num_dims,
                     " dimensions which is over the limit of ",
                     tensorflow::TensorShape::MaxDimensions(), ".")
                     .c_str());
    return;
  }
  tensorflow::TensorShapeProto proto;
  if (num_dims < 0) {
    proto.set_unknown_rank(true);
  } else {
    for (int d = 0; d < num_dims; ++d) {
      proto.add_dim()->set_size(dims[d]);
    }
  }
  op->attrs.Set(attr_name, proto);
}

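// Illustrative sketch of the shape-attr conventions above: a non-negative
// num_dims records the dimensions verbatim (a -1 entry marks an unknown
// dimension), while num_dims < 0 records an unknown-rank shape. The attr name
// "shape" is just an example:
//
//   int64_t dims[] = {2, -1};
//   TFE_OpSetAttrShape(op, "shape", dims, 2, status);      // shape [2, ?]
//   TFE_OpSetAttrShape(op, "shape", nullptr, -1, status);  // unknown rank
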
void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
                           const TFE_Op* value) {
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(value->name);
  value->attrs.FillAttrValueMap(func->mutable_attr());
  op->attrs.Set(attr_name, attr_value);
}

#define TFE_OP_SET_ATTR_LIST(fn, type)                                \
  void fn(TFE_Op* op, const char* attr_name, const type* values,      \
          int num_values) {                                           \
    op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
                                 values, num_values));                \
  }
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
#undef TFE_OP_SET_ATTR_LIST

void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                          const int64_t* values, int num_values) {
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const int64>(
                    reinterpret_cast<const int64*>(values), num_values));
}

void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                           const TF_DataType* values, int num_values) {
  op->attrs.Set(
      attr_name,
      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}

void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
                           const unsigned char* values, int num_values) {
  std::unique_ptr<bool[]> b(new bool[num_values]);
  for (int i = 0; i < num_values; ++i) {
    b[i] = values[i];
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
}

void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                   tensorflow::strings::StrCat(
                       "Value specified for `", attr_name, "` has ", num_dims_i,
                       " dimensions which is over the limit of ",
                       tensorflow::TensorShape::MaxDimensions(), ".")
                       .c_str());
      return;
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                    proto.get(), num_values));
}

void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
                               const TFE_Op** value, int num_values) {
  std::unique_ptr<tensorflow::NameAttrList[]> funcs(
      new tensorflow::NameAttrList[num_values]);
  for (int i = 0; i < num_values; i++) {
    funcs[i].set_name(value[i]->name);
    value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr());
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
                    funcs.get(), num_values));
}

namespace {

tensorflow::Status ValidateInputTypeAndPlacement(
    TFE_Context* ctx, tensorflow::Device* host_device,
    tensorflow::Device* op_device, TFE_Op* op,
    const tensorflow::OpKernel* kernel,
    std::vector<TFE_TensorHandle*>* copied_tensors) {
  const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
  if (memtypes.size() != op->inputs.size()) {
    return tensorflow::errors::InvalidArgument(
        "expected ", memtypes.size(), " inputs, got ", op->inputs.size());
  }
  for (int i = 0; i < op->inputs.size(); ++i) {
    const tensorflow::Device* expected_device =
        memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
    const tensorflow::Device* actual_device =
        op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
    if (expected_device != actual_device) {
      switch (TFE_ContextGetDevicePlacementPolicy(ctx)) {
        case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32:
          // TODO(xpan): See if we could bubble python related error up
          // to python level.
          if (op->inputs[i].dtype() == tensorflow::DT_INT32) {
            // Note: enabling silent copies of int32 tensors to match behavior
            // of graph mode.
            break;
          }
          TF_FALLTHROUGH_INTENDED;
        case TFE_DEVICE_PLACEMENT_EXPLICIT:
          return tensorflow::errors::InvalidArgument(
              "Tensors on conflicting devices:"
              " cannot compute ",
              op->name, " as input #", i, " was expected to be on ",
              expected_device->name(), " but is actually on ",
              actual_device->name(), " (operation running on ",
              op_device->name(), ")",
              " Tensors can be copied explicitly using .gpu() or .cpu(),"
              " or transparently copied by using tfe.enable_eager_execution("
              "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices"
              " may slow down your model");
        case TFE_DEVICE_PLACEMENT_WARN:
          LOG(WARNING) << "before computing " << op->name << " input #" << i
                       << " was expected to be on " << expected_device->name()
                       << " but is actually on " << actual_device->name()
                       << " (operation running on " << op_device->name()
                       << "). This triggers a copy which can be a performance "
                          "bottleneck.";
          break;
        case TFE_DEVICE_PLACEMENT_SILENT:  // Do nothing.
          break;
      }
      // We are only here if the policy is warn or silent copies, so we should
      // trigger a copy.
      TFE_TensorHandle original{op->inputs[i], op->input_devices[i]};
      TF_Status* s = TF_NewStatus();
      TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
          &original, ctx, expected_device->name().c_str(), s);
      if (!s->status.ok()) {
        tensorflow::Status status = s->status;
        delete s;
        return tensorflow::errors::Internal(
            "Failed copying input tensor from ", actual_device->name(), " to ",
            expected_device->name(), " in order to run ", op->name, ": ",
            status.error_message());
      }
      op->inputs[i] = copied_tensor->t;
      copied_tensors->push_back(copied_tensor);
      op->input_devices[i] = copied_tensor->d;
      delete s;
    }
    if (op->inputs[i].dtype() != kernel->input_type(i)) {
      return tensorflow::errors::InvalidArgument(
          "cannot compute ", op->name, " as input #", i,
          " was expected to be a ",
          tensorflow::DataTypeString(kernel->input_type(i)),
          " tensor but is a ",
          tensorflow::DataTypeString(op->inputs[i].dtype()), " tensor");
    }
  }
  return tensorflow::Status::OK();
}

#ifdef TENSORFLOW_EAGER_USE_XLA
// Synthesizes and returns a wrapper function over `op`, which must be a
// primitive op (e.g. matmul).
//
// The wrapper function conforms to the function signature expected by
// _XlaLaunchOp, with input params ordered by <constants, (variable) args and
// resources>. For example, if the op has input params <Const1, Arg2, Const3,
// Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
// Resource4> as the input params to the synthesized function.
//
// It populates `const_input_types`, `arg_input_types` and
// `op_input_to_func_input` based on the reordering results, so that the caller
// can use them to build an _XlaLaunchOp. On error, it returns NULL, and sets
// `status` accordingly.
const tensorflow::FunctionDef* OpToFunction(
    TFE_Op* op, std::vector<TF_DataType>* const_input_types,
    std::vector<TF_DataType>* arg_input_types,
    tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input,
    TF_Status* status) {
  DCHECK(!op->is_function());

  tensorflow::FunctionDef fdef;

  // Get the OpDef of the op we are trying to encapsulate.
  TFE_Context* ctx = op->ctx;
  const tensorflow::OpRegistrationData* op_data;
  {
    tensorflow::tf_shared_lock l(ctx->functions_mu);
    status->status = ctx->func_lib_def.LookUp(op->name, &op_data);
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::OpDef& op_def = op_data->op_def;

  tensorflow::OpDef* signature = fdef.mutable_signature();

  // Handle constant inputs.
  const std::unordered_set<string> const_inputs(
      *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name));

  // First add placeholders for the input args, so that we can refer to them by
  // position in the next loop. Also tally up the resource inputs.
  int num_resource_inputs = 0;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) {
      ++num_resource_inputs;
    }
    signature->add_input_arg();
  }

  // Now we map the input params from `op_def` to `signature`, where the param
  // ordering for `signature` is: <constants, args, resources>.
  int const_index = 0;
  int arg_index = const_inputs.size();
  int resource_index = op_def.input_arg_size() - num_resource_inputs;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i);
    tensorflow::OpDef::ArgDef* func_input_arg = nullptr;
    if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) {
      VLOG(1) << "For const input, mapping op input " << i << " to func input "
              << const_index;
      (*op_input_to_func_input)[i] = const_index;
      func_input_arg = signature->mutable_input_arg(const_index++);
      const_input_types->push_back(
          static_cast<TF_DataType>(op->inputs[i].dtype()));
    } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
      VLOG(1) << "For resource input, mapping op input " << i
              << " to func input " << resource_index;
      (*op_input_to_func_input)[i] = resource_index;
      func_input_arg = signature->mutable_input_arg(resource_index++);
    } else {
      VLOG(1) << "For arg input, mapping op input " << i << " to func input "
              << arg_index;
      (*op_input_to_func_input)[i] = arg_index;
      func_input_arg = signature->mutable_input_arg(arg_index++);
      arg_input_types->push_back(
          static_cast<TF_DataType>(op->inputs[i].dtype()));
    }

    func_input_arg->set_name(op_input_arg.name());
    func_input_arg->set_type(op->inputs[i].dtype());
  }
  VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();

  // Resources args are at the end of the function input params, and we should
  // have iterated over all of them.
  DCHECK_EQ(signature->input_arg_size(), resource_index);

  // Make the synthesized function's name unique.
  signature->set_name(tensorflow::strings::StrCat(
      op_def.name(), func_id_generator.fetch_add(1)));

  // Add the node def and set its input names to match op_def's names.
  const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
  DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
  *fdef.add_node_def() = ndef;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name());
  }
  VLOG(1) << "Added NodeDef: " << fdef.DebugString();

  // Fix the output names and set output types.
  for (int i = 0; i < op_def.output_arg_size(); ++i) {
    tensorflow::OpDef::ArgDef* arg = signature->add_output_arg();
    const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i);
    const string& out_tensor_name = tensorflow::strings::StrCat(
        ndef.name(), ":", op_def_arg.name(), ":", 0);
    arg->set_name(op_def_arg.name());
    (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name;
    const string& type_attr = op_def_arg.type_attr();
    if (!type_attr.empty()) {
      auto i = ndef.attr().find(type_attr);
      if (i == ndef.attr().end()) {
        status->status = tensorflow::errors::InvalidArgument(
            tensorflow::strings::StrCat("Could not find attr ", type_attr,
                                        " in NodeDef ", ndef.DebugString()));
        return nullptr;
      }
      arg->set_type(i->second.type());
    }
  }
  VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();

  tensorflow::mutex_lock l(ctx->functions_mu);
  status->status = ctx->func_lib_def.AddFunctionDef(fdef);
  if (!status->status.ok()) return nullptr;
  const auto ret = ctx->func_lib_def.Find(signature->name());
  DCHECK(ret != nullptr);
  return ret;
}

// Builds an _XlaLaunchOp as a wrapper over 'op', so that 'op' can be executed
// via XLA.
std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
  VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name;
  auto launch_op =
      std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  if (op->device) {
    TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }

  const tensorflow::FunctionDef* fdef;
  {
    tensorflow::tf_shared_lock l(op->ctx->functions_mu);
    fdef = op->ctx->func_lib_def.Find(op->name);
  }
  std::vector<TF_DataType> const_input_types;
  std::vector<TF_DataType> arg_input_types;
  tensorflow::gtl::FlatMap<int, int> op_input_to_func_input;
  if (fdef == nullptr) {
    // See if this is a primitive op, and if so create a function for it, so
    // that _XlaLaunchOp can access it.
    fdef = OpToFunction(op, &const_input_types, &arg_input_types,
                        &op_input_to_func_input, status);
    if (!status->status.ok()) return nullptr;
  } else {
    // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for
    // functions, so we need to find another way to handle constant inputs.
    for (int i = const_input_types.size();
         i < fdef->signature().input_arg_size(); ++i) {
      VLOG(1) << "Adding Targs from input arg " << i;
      const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i);
      arg_input_types.push_back(static_cast<TF_DataType>(arg.type()));
    }
  }
  DCHECK(fdef != nullptr);

  // Copy inputs and their devices.
  // Since input param reordering may have occurred between `op` and `launch_op`
  // via `op_input_to_func_input`, adjust the actual inputs accordingly.
  launch_op->inputs = op->inputs;
  launch_op->input_devices = op->input_devices;
  if (!op_input_to_func_input.empty()) {
    DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
    if (!op->input_devices.empty()) {
      DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size());
    }
    for (int i = 0; i < op_input_to_func_input.size(); ++i) {
      VLOG(1) << "mapping op input " << i << " to func input "
              << op_input_to_func_input[i];

      launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
      if (!op->input_devices.empty()) {
        launch_op->input_devices[op_input_to_func_input[i]] =
            op->input_devices[i];
      }
    }
  }
  launch_op->attrs.NumInputs(op->inputs.size());

  TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(),
                        const_input_types.size());

  // Set Targs and Nresources attrs.
  TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
                        arg_input_types.size());
  const int num_resource_inputs = fdef->signature().input_arg_size() -
                                  const_input_types.size() -
                                  arg_input_types.size();
  TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);

  // Set Tresults attr.
  std::vector<TF_DataType> tresults;
  for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) {
    tresults.push_back(static_cast<TF_DataType>(arg.type()));
  }
  TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
                        tresults.size());

  // Set function attr.
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(fdef->signature().name());
  launch_op->attrs.Set("function", attr_value);

  return launch_op;
}
#endif  // TENSORFLOW_EAGER_USE_XLA
}  // namespace

void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  TFE_Context* ctx = op->ctx;
  // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU
  tensorflow::Device* device =
      (op->device == nullptr) ? ctx->devices()[0] : op->device;

#ifdef TENSORFLOW_EAGER_USE_XLA
  std::unique_ptr<TFE_Op> xla_launch_op;
  if (op->use_xla && op->name != "_XlaLaunch") {
    xla_launch_op = BuildXlaLaunch(op, status);
    if (!status->status.ok()) {
      return;
    }
    op = xla_launch_op.get();
  }
#endif  // TENSORFLOW_EAGER_USE_XLA

  std::vector<tensorflow::Tensor> outputs(1);
  const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
  tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
  tensorflow::KernelAndDevice* kernel;
  {
    tensorflow::tf_shared_lock l(ctx->cache_mu);
    kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
  }
  if (kernel == nullptr) {
    const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
    kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
    // Knowledge of the implementation of Init (and in-turn
    // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
    // will be accessed, so grab on to the lock.
    // See WARNING comment below - would be nice to rework to avoid this
    // subtlety.
    tensorflow::tf_shared_lock l(ctx->functions_mu);
    status->status =
        tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
    if (!status->status.ok()) {
      delete kernel;
      return;
    }
    tensorflow::mutex_lock ml(ctx->cache_mu);
    tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
  }
  std::vector<TFE_TensorHandle*> copied_tensors;
  status->status = ValidateInputTypeAndPlacement(
      ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors);
  output_memory_types = &kernel->kernel()->output_memory_types();
  if (!status->status.ok()) {
    for (auto* t : copied_tensors) {
      TFE_DeleteTensorHandle(t);
    }
    return;
  }
  std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
  if (ctx->should_store_metadata.load()) {
    maybe_stats.reset(new tensorflow::NodeExecStats);
    maybe_stats->set_node_name(op->name);
    maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
    maybe_stats->set_op_start_rel_micros(0);
    maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
    // TODO(apassos) track referenced tensors
  }
  // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
  // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
  // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
  // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
  // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
  // This is quite subtle. Re-work things to make this better?  (Would it make
  // sense for FunctionLibraryRuntime to ensure thread-safe access to
  // FunctionLibraryDefinition?).  TODO(apassos) figure out how to record stats
  // for ops which are a part of functions.
  status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get());
  for (auto* t : copied_tensors) {
    TFE_DeleteTensorHandle(t);
  }
  if (!status->status.ok()) return;
  if (maybe_stats != nullptr) {
    maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
                                       maybe_stats->all_start_micros());
    tensorflow::mutex_lock ml(ctx->metadata_mu);
    if (ctx->should_store_metadata.load()) {
      auto* step_stats = ctx->run_metadata.mutable_step_stats();
      // Lazily initialize the RunMetadata with information about all devices if
      // this is the first call.
      while (step_stats->dev_stats_size() < ctx->devices().size()) {
        step_stats->add_dev_stats();
      }
      // Find the current device's index.
      int device_idx = 0;
      for (int i = 0; i < ctx->devices().size(); ++i) {
        if (ctx->devices()[i] == device) {
          device_idx = i;
          break;
        }
      }
      // Populate the device stats for this device.
      auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
      dev_stats->set_device(device->name());
      *dev_stats->add_node_stats() = *maybe_stats;
    }
  }
  *num_retvals = std::min<int>(*num_retvals, outputs.size());
  for (int i = 0; i < *num_retvals; ++i) {
    tensorflow::Device* d = IsCPU(device) ? nullptr : device;
    if (d != nullptr && output_memory_types != nullptr &&
        (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
      d = nullptr;
    }
    retvals[i] = new TFE_TensorHandle(outputs[i], d);
  }
}

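// Illustrative end-to-end sketch of executing a single op, assuming an
// existing context `ctx`, status `status`, and a CPU tensor handle `m` of a
// dtype MatMul supports (all error checks elided):
//
//   TFE_Op* matmul = TFE_NewOp(ctx, "MatMul", status);
//   TFE_OpAddInput(matmul, m, status);
//   TFE_OpAddInput(matmul, m, status);
//   TFE_OpSetAttrBool(matmul, "transpose_a", 0);
//   TFE_OpSetAttrBool(matmul, "transpose_b", 0);
//   TFE_OpSetAttrType(matmul, "T", TFE_TensorHandleDataType(m));
//   TFE_TensorHandle* retvals[1];
//   int num_retvals = 1;
//   TFE_Execute(matmul, &retvals[0], &num_retvals, status);
//   TFE_DeleteOp(matmul);
//   // retvals[0] now holds the product; release it with
//   // TFE_DeleteTensorHandle when done.
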
void TFE_ContextAddFunctionDef(TFE_Context* ctx,
                               const char* serialized_function_def, size_t size,
                               TF_Status* status) {
  tensorflow::FunctionDef function_def;
  if (!function_def.ParseFromArray(serialized_function_def, size)) {
    status->status =
        tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
    return;
  }
  tensorflow::mutex_lock l(ctx->functions_mu);
  status->status = ctx->func_lib_def.AddFunctionDef(function_def);
}

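// Illustrative sketch: a function registered from a serialized FunctionDef can
// then be invoked by name through TFE_NewOp/TFE_Execute. `fdef_proto`,
// `fdef_len`, and the function name "MyFunction" are hypothetical:
//
//   TFE_ContextAddFunctionDef(ctx, fdef_proto, fdef_len, status);
//   if (TF_GetCode(status) == TF_OK) {
//     TFE_Op* call = TFE_NewOp(ctx, "MyFunction", status);
//     // ... add inputs and TFE_Execute as with any other op.
//   }
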
void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
                            TF_Status* status) {
  tensorflow::mutex_lock l(ctx->functions_mu);
  status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
}

}  // extern "C"

TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
  return new TFE_TensorHandle(t, nullptr);
}

const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
    TFE_TensorHandle* h, TF_Status* status) {
  if (h->d != nullptr) {
    status->status = tensorflow::errors::FailedPrecondition(
        "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
        "a tensorflow::Tensor");
    return nullptr;
  }
  return &h->t;
}

void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
  ctx->should_store_metadata.store(true);
}

void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
  tensorflow::mutex_lock ml(ctx->metadata_mu);
  ctx->should_store_metadata.store(false);
  ctx->run_metadata.Clear();
}

void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
                                  TF_Status* status) {
  tensorflow::mutex_lock ml(ctx->metadata_mu);
  status->status = MessageToBuffer(ctx->run_metadata, buf);
  ctx->run_metadata.Clear();
}