1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/c/eager/c_api.h"
17
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <memory>
#include <string>
#include <thread>
#include <vector>
23
24 #include "tensorflow/c/c_api.h"
25 #include "tensorflow/c/c_api_internal.h"
26 #include "tensorflow/c/eager/c_api_internal.h"
27 #include "tensorflow/c/eager/runtime.h"
28 #ifdef TENSORFLOW_EAGER_USE_XLA
29 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
30 #endif // TENSORFLOW_EAGER_USE_XLA
31 #include "tensorflow/core/common_runtime/copy_tensor.h"
32 #include "tensorflow/core/common_runtime/device_factory.h"
33 #include "tensorflow/core/common_runtime/device_mgr.h"
34 #include "tensorflow/core/common_runtime/function.h"
35 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
36 #include "tensorflow/core/framework/rendezvous.h"
37 #include "tensorflow/core/framework/tensor_shape.pb.h"
38 #include "tensorflow/core/framework/types.h"
39 #include "tensorflow/core/lib/core/refcount.h"
40 #include "tensorflow/core/lib/gtl/flatmap.h"
41 #include "tensorflow/core/lib/gtl/map_util.h"
42 #include "tensorflow/core/lib/gtl/stl_util.h"
43 #include "tensorflow/core/platform/mutex.h"
44 #include "tensorflow/core/platform/thread_annotations.h"
45 #include "tensorflow/core/public/version.h"
46
47 using tensorflow::int64;
48 using tensorflow::string;
49
50 namespace {
IsCPU(const tensorflow::Device * d)51 bool IsCPU(const tensorflow::Device* d) {
52 return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
53 }
54
IsXLA(const tensorflow::Device * d)55 bool IsXLA(const tensorflow::Device* d) {
56 if (d == nullptr) return false;
57 const auto& device_type = d->attributes().device_type();
58 return device_type.find("XLA") != std::string::npos;
59 }
60
DeviceName(const tensorflow::Device * d)61 string DeviceName(const tensorflow::Device* d) {
62 return (d == nullptr) ? "cpu:0" : d->name();
63 }
64
65 #ifdef TENSORFLOW_EAGER_USE_XLA
66 std::atomic_int_fast64_t func_id_generator(0);
67 #endif // TENSORFLOW_EAGER_USE_XLA
68 } // namespace
69
70 extern "C" {
71
// Allocates a default-initialized TFE_ContextOptions; the caller owns the
// result and must release it with TFE_DeleteContextOptions.
TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
73
// Applies a serialized config proto (`proto`, `proto_len` bytes) to the
// session options embedded in `options`; parse errors go to `status`.
void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
                                 size_t proto_len, TF_Status* status) {
  TF_SetConfig(&options->session_options, proto, proto_len, status);
}
78
// Sets the context-wide default device placement policy; individual threads
// may later override it via TFE_ContextSetThreadLocalDevicePlacementPolicy.
void TFE_ContextOptionsSetDevicePlacementPolicy(
    TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
  options->policy = policy;
}
83
// Releases a TFE_ContextOptions created by TFE_NewContextOptions.
void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
85
TFE_NewContext(const TFE_ContextOptions * opts,TF_Status * status)86 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
87 TF_Graph* graph = TF_NewGraph();
88 TF_Session* session = TF_NewSession(graph, &opts->session_options, status);
89 if (status->status.ok()) {
90 if (session->device_mgr == nullptr || session->devices.empty()) {
91 status->status = tensorflow::errors::InvalidArgument(
92 "Provided TF_SessionOptions are not compatible with eager execution "
93 "(perhaps the TF_SessionOptions alluded to session execution in a "
94 "remote address space?)");
95 }
96 }
97 if (!status->status.ok()) {
98 TF_DeleteGraph(graph);
99 return nullptr;
100 }
101
102 return new TFE_Context(*opts, session);
103 }
104
// Destroys an eager context: deletes all cached kernels, tears down the
// underlying session and graph, and drops the rendezvous reference.
void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
  status->status = tensorflow::Status::OK();
  {
    // Scoped so the cache lock is released before `ctx` itself is deleted.
    tensorflow::mutex_lock ml(ctx->cache_mu);
    tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
  }
  // Save the graph pointer first: ctx->session is gone after TF_DeleteSession.
  TF_Graph* graph = ctx->session->graph;
  TF_DeleteSession(ctx->session, status);
  TF_DeleteGraph(graph);
  ctx->rendezvous->Unref();
  delete ctx;
}
117
// Lists the devices available to this context. The caller owns the returned
// list and must free it with TF_DeleteDeviceList.
TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
  return TF_SessionListDevices(ctx->session, status);
}
121
// Deletes all cached kernels, forcing them to be rebuilt on next use.
void TFE_ContextClearCaches(TFE_Context* ctx) {
  tensorflow::mutex_lock ml(ctx->cache_mu);
  tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
}
126
// Installs a device placement policy override for the calling thread only,
// keyed by the thread's id; other threads keep the context default.
void TFE_ContextSetThreadLocalDevicePlacementPolicy(
    TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
  tensorflow::mutex_lock ml(ctx->policy_map_mu);
  ctx->thread_local_policies[std::this_thread::get_id()] = policy;
}
132
TFE_ContextGetDevicePlacementPolicy(TFE_Context * ctx)133 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
134 TFE_Context* ctx) {
135 tensorflow::mutex_lock ml(ctx->policy_map_mu);
136 auto policy_map_it =
137 ctx->thread_local_policies.find(std::this_thread::get_id());
138 if (policy_map_it != ctx->thread_local_policies.end()) {
139 return policy_map_it->second;
140 }
141 return ctx->policy;
142 }
143
// Wraps a TF_Tensor in a new tensor handle resident on the local CPU (a null
// device denotes CPU). Returns nullptr if the conversion fails.
TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
  tensorflow::Tensor tensor;
  status->status = tensorflow::TF_TensorToTensor(t, &tensor);
  if (!status->status.ok()) return nullptr;
  return new TFE_TensorHandle(tensor, nullptr);
}
150
// Releases a tensor handle created by this API.
void TFE_DeleteTensorHandle(TFE_TensorHandle* h) { delete h; }
152
// Returns the element type of the handle's tensor as a TF_DataType.
TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
  return static_cast<TF_DataType>(h->t.dtype());
}
156
// Returns the rank (number of dimensions) of the handle's tensor.
int TFE_TensorHandleNumDims(TFE_TensorHandle* h) { return h->t.dims(); }
158
// Returns the size of dimension `dim_index` of the handle's tensor.
int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index) {
  return h->t.dim_size(dim_index);
}
162
// Returns the fully-qualified name of the device holding this tensor; a null
// device pointer denotes the local CPU.
const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h) {
  // TODO(apassos) this will be potentially incorrect in the distributed case as
  // our local device will have a name which depends on the ClusterSpec and
  // hence will require the context to resolve.
  return (h->d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
                           : h->d->name().c_str();
}
170
TFE_TensorHandleResolve(TFE_TensorHandle * h,TF_Status * status)171 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
172 if (!IsCPU(h->d)) {
173 TF_SetStatus(status, TF_UNIMPLEMENTED,
174 tensorflow::strings::StrCat(
175 "TFE_TensorHandle can be resolved iff it is on CPU (this "
176 "handle is on ",
177 h->d->name(),
178 "). Consider using TFE_TensorHandleCopyToDevice to get a "
179 "copy of the tensor on CPU")
180 .c_str());
181 return nullptr;
182 }
183 return tensorflow::TF_TensorFromTensor(h->t, status);
184 }
185
TFE_TensorHandleCopyToDevice(TFE_TensorHandle * h,TFE_Context * ctx,const char * device_name,TF_Status * status)186 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
187 TFE_Context* ctx,
188 const char* device_name,
189 TF_Status* status) {
190 tensorflow::Device* dstd = ctx->devices()[0];
191 if (device_name != nullptr && strlen(device_name) > 0) {
192 status->status = ctx->session->device_mgr->LookupDevice(device_name, &dstd);
193 if (!status->status.ok()) return nullptr;
194 }
195
196 tensorflow::Device* srcd = h->d == nullptr ? ctx->devices()[0] : h->d;
197 bool is_same_device =
198 (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd));
199 const bool dst_cpu = IsCPU(dstd);
200 const bool src_cpu = IsCPU(srcd);
201 // both_on_cpu can be true and yet is_same_device is false, if one of src/dst
202 // has device type XLA_CPU, and the other CPU.
203 const bool both_on_cpu = src_cpu && dst_cpu;
204 if (is_same_device || both_on_cpu) {
205 return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd);
206 }
207 tensorflow::Tensor* src = &(h->t);
208 if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT &&
209 !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) {
210 TF_SetStatus(
211 status, TF_INVALID_ARGUMENT,
212 tensorflow::strings::StrCat("Can't copy Tensor with type ",
213 tensorflow::DataTypeString(src->dtype()),
214 " to device ", DeviceName(dstd), ".")
215 .c_str());
216 return nullptr;
217 }
218 tensorflow::AllocatorAttributes attr;
219 if (src->dtype() == tensorflow::DT_VARIANT) {
220 attr.set_on_host(true);
221 }
222 tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape());
223 if (src->shape().num_elements() == 0) {
224 return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd);
225 }
226 tensorflow::DeviceContext* src_device_context = nullptr;
227 if (!src_cpu) {
228 src_device_context = srcd->tensorflow_gpu_device_info()->default_context;
229 }
230 tensorflow::DeviceContext* dst_device_context = nullptr;
231 if (!dst_cpu) {
232 dst_device_context = dstd->tensorflow_gpu_device_info()->default_context;
233 }
234 // TODO(ashankar): The Sync() call below may be more aggressive than
235 // necessary. It is based on knowledge of implementation details - that
236 // GPU devices are implemented using 3 streams - one for host->device copies,
237 // one for device->host copies and one for sending operations to the GPU.
238 // With that setup, Sync()ing across all 3 streams should be sufficient
239 // but more than necessary (since it waits for operations that might have
240 // nothing to do with this tensor to complete).
241 status->status = srcd->Sync();
242 tensorflow::Notification n;
243 tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context,
244 srcd, dstd, tensorflow::AllocatorAttributes(),
245 tensorflow::AllocatorAttributes(), src, &dst,
246 [status, &n](const tensorflow::Status& s) {
247 status->status = s;
248 n.Notify();
249 });
250 n.WaitForNotification();
251 return (TF_GetCode(status) == TF_OK)
252 ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd)
253 : nullptr;
254 }
255
// Creates a new eager op. The name is first resolved as a primitive op in the
// registry; if that lookup reports NOT_FOUND, the context's function library
// is consulted instead. Returns nullptr (error in `status`) when neither
// resolution succeeds.
TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
                  TF_Status* status) {
  const char* name = op_or_function_name;  // Shorthand
  const tensorflow::AttrTypeMap* types;
  status->status = tensorflow::AttrTypeMapForOp(name, &types);
  if (status->status.ok()) return new TFE_Op(ctx, name, types);
  if (TF_GetCode(status) == TF_NOT_FOUND) {
    tensorflow::mutex_lock l(ctx->functions_mu);
    if (ctx->func_lib_def.Find(name) != nullptr) {
      // It is a registered function: clear the NOT_FOUND from the op lookup.
      status->status = tensorflow::Status::OK();
      return new TFE_Op(ctx, name, nullptr);
    }
  }
  return nullptr;
}
271
// Releases an op created by TFE_NewOp.
void TFE_DeleteOp(TFE_Op* op) { delete op; }
273
TFE_OpSetDevice(TFE_Op * op,const char * device_name,TF_Status * status)274 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
275 tensorflow::Device* d = nullptr;
276 if (device_name != nullptr && strlen(device_name) > 0) {
277 status->status =
278 op->ctx->session->device_mgr->LookupDevice(device_name, &d);
279 if (!status->status.ok()) return;
280 }
281 op->device = d;
282 }
283
TFE_OpGetDevice(TFE_Op * op,TF_Status * status)284 const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
285 tensorflow::Device* device =
286 (op->device == nullptr) ? op->ctx->devices()[0] : op->device;
287 return device->name().c_str();
288 }
289
// Enables or disables XLA compilation for this op. The flag is recorded
// either way, but has an effect only in builds with XLA support compiled in.
void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
  op->use_xla = enable;
#ifndef TENSORFLOW_EAGER_USE_XLA
  LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
                  "built with XLA support.";
#endif  // TENSORFLOW_EAGER_USE_XLA
}
297
// Appends `h` as the next positional input of `op`, possibly re-placing the
// op onto the input's device first.
void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) {
  // Questionable heuristic ...
  //
  // Motivation: After an 'op' is placed on GPU because some of its earlier
  // inputs are on GPU, we want to keep the 'op' there, even if some later
  // inputs of it are not on GPU.
  if (IsCPU(op->device) && !IsCPU(h->d)) {
    op->device = h->d;
  }
  if (!status->status.ok()) return;
  op->inputs.push_back(h->t);
  op->input_devices.push_back(h->d);
  // Keep the attr builder's input count in sync with the recorded inputs.
  op->attrs.NumInputs(op->inputs.size());
}
312
TFE_OpGetAttrType(TFE_Op * op,const char * attr_name,unsigned char * is_list,TF_Status * status)313 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
314 unsigned char* is_list, TF_Status* status) {
315 TF_AttrType ret;
316 if (op->is_function()) {
317 status->status = tensorflow::errors::Unimplemented(
318 "TODO(apassos): Support for attributes for TensorFlow functions is not "
319 "ready yet.");
320 return TF_ATTR_INT; // The compiler requires that we return something.
321 }
322 status->status =
323 tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list);
324 return ret;
325 }
326
TFE_OpNameGetAttrType(TFE_Context * ctx,const char * op_or_function_name,const char * attr_name,unsigned char * is_list,TF_Status * status)327 TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
328 const char* op_or_function_name,
329 const char* attr_name, unsigned char* is_list,
330 TF_Status* status) {
331 TF_AttrType ret;
332 TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
333 if (!status->status.ok()) {
334 return TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType.
335 }
336 ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
337 TFE_DeleteOp(op);
338 return ret;
339 }
340
// Sets a string-valued attr on `op`.
void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) {
  op->attrs.Set(attr_name, value);
}
344
// Sets an int-valued attr; the C API's int64_t is converted to the
// tensorflow::int64 the attr builder expects.
void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
  op->attrs.Set(attr_name, static_cast<int64>(value));
}
348
// Sets a float-valued attr on `op`.
void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
  op->attrs.Set(attr_name, value);
}
352
// Sets a bool-valued attr; any nonzero `value` is treated as true.
void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
  op->attrs.Set(attr_name, (value == 0) ? false : true);
}
356
// Sets a dtype-valued attr; TF_DataType values map directly onto
// tensorflow::DataType.
void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
  op->attrs.Set(attr_name, static_cast<tensorflow::DataType>(value));
}
360
TFE_OpSetAttrShape(TFE_Op * op,const char * attr_name,const int64_t * dims,const int num_dims,TF_Status * out_status)361 void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
362 const int num_dims, TF_Status* out_status) {
363 if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
364 TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
365 tensorflow::strings::StrCat(
366 "Value specified for `", attr_name, "` has ", num_dims,
367 " dimensions which is over the limit of ",
368 tensorflow::TensorShape::MaxDimensions(), ".")
369 .c_str());
370 return;
371 }
372 tensorflow::TensorShapeProto proto;
373 if (num_dims < 0) {
374 proto.set_unknown_rank(true);
375 } else {
376 for (int d = 0; d < num_dims; ++d) {
377 proto.add_dim()->set_size(dims[d]);
378 }
379 }
380 op->attrs.Set(attr_name, proto);
381 }
382
TFE_OpSetAttrFunction(TFE_Op * op,const char * attr_name,const TFE_Op * value)383 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
384 const TFE_Op* value) {
385 tensorflow::AttrValue attr_value;
386 tensorflow::NameAttrList* func = attr_value.mutable_func();
387 func->set_name(value->name);
388 value->attrs.FillAttrValueMap(func->mutable_attr());
389 op->attrs.Set(attr_name, attr_value);
390 }
391
// Defines a TFE_OpSetAttr*List setter that forwards a (pointer, length) pair
// to the attr builder as an ArraySlice. Instantiated only for element types
// that need no conversion (string, float); the int/type/bool/shape/function
// list setters below require per-element handling and are written out by
// hand.
#define TFE_OP_SET_ATTR_LIST(fn, type)                                \
  void fn(TFE_Op* op, const char* attr_name, const type* values,      \
          int num_values) {                                           \
    op->attrs.Set(attr_name, tensorflow::gtl::ArraySlice<const type>( \
                                 values, num_values));                \
  }
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrStringList, char*)
TFE_OP_SET_ATTR_LIST(TFE_OpSetAttrFloatList, float)
#undef TFE_OP_SET_ATTR_LIST
401
// Sets an int-list attr. int64_t (C API) and tensorflow::int64 are distinct
// types with the same representation, hence the reinterpret_cast of the
// buffer rather than an element-by-element copy.
void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
                          const int64_t* values, int num_values) {
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<const int64>(
                    reinterpret_cast<const int64*>(values), num_values));
}
408
// Sets a dtype-list attr. TF_DataType values mirror tensorflow::DataType
// one-for-one, so the buffer is reinterpreted in place.
void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
                           const TF_DataType* values, int num_values) {
  op->attrs.Set(
      attr_name,
      tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
          reinterpret_cast<const tensorflow::DataType*>(values), num_values));
}
416
TFE_OpSetAttrBoolList(TFE_Op * op,const char * attr_name,const unsigned char * values,int num_values)417 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
418 const unsigned char* values, int num_values) {
419 std::unique_ptr<bool[]> b(new bool[num_values]);
420 for (int i = 0; i < num_values; ++i) {
421 b[i] = values[i];
422 }
423 op->attrs.Set(attr_name,
424 tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
425 }
426
// Sets a shape-list attr. `dims[i]` holds `num_dims[i]` dimension sizes for
// the i-th shape; a negative `num_dims[i]` denotes unknown rank. Any shape
// exceeding the TensorShape rank limit aborts the whole call with
// INVALID_ARGUMENT in `out_status` and sets nothing.
void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
                            const int64_t** dims, const int* num_dims,
                            int num_values, TF_Status* out_status) {
  std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
      new tensorflow::TensorShapeProto[num_values]);
  for (int i = 0; i < num_values; ++i) {
    const auto num_dims_i = num_dims[i];

    if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
      TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
                   tensorflow::strings::StrCat(
                       "Value specified for `", attr_name, "` has ", num_dims_i,
                       " dimensions which is over the limit of ",
                       tensorflow::TensorShape::MaxDimensions(), ".")
                       .c_str());
      return;
    }
    if (num_dims_i < 0) {
      proto[i].set_unknown_rank(true);
    } else {
      const int64_t* dims_i = dims[i];
      auto proto_i = &proto[i];
      for (int d = 0; d < num_dims_i; ++d) {
        proto_i->add_dim()->set_size(dims_i[d]);
      }
    }
  }
  op->attrs.Set(attr_name,
                tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
                    proto.get(), num_values));
}
458
TFE_OpSetAttrFunctionList(TFE_Op * op,const char * attr_name,const TFE_Op ** value,int num_values)459 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
460 const TFE_Op** value, int num_values) {
461 std::unique_ptr<tensorflow::NameAttrList[]> funcs(
462 new tensorflow::NameAttrList[num_values]);
463 for (int i = 0; i < num_values; i++) {
464 funcs[i].set_name(value[i]->name);
465 value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr());
466 }
467 op->attrs.Set(attr_name,
468 tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
469 funcs.get(), num_values));
470 }
471
472 namespace {
473
// Checks each input of `op` against the memory space and dtype expected by
// `kernel`. Depending on the context's device placement policy, a misplaced
// input is rejected, warned about and copied, or silently copied to the
// expected device. Handles wrapping copied tensors are appended to
// `copied_tensors`; the caller owns (and must delete) them.
tensorflow::Status ValidateInputTypeAndPlacement(
    TFE_Context* ctx, tensorflow::Device* host_device,
    tensorflow::Device* op_device, TFE_Op* op,
    const tensorflow::OpKernel* kernel,
    std::vector<TFE_TensorHandle*>* copied_tensors) {
  const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types();
  if (memtypes.size() != op->inputs.size()) {
    return tensorflow::errors::InvalidArgument(
        "expected ", memtypes.size(), " inputs, got ", op->inputs.size());
  }
  for (int i = 0; i < op->inputs.size(); ++i) {
    // HOST_MEMORY inputs must live on the host device regardless of where
    // the op itself runs; a null recorded device denotes the host.
    const tensorflow::Device* expected_device =
        memtypes[i] == tensorflow::HOST_MEMORY ? host_device : op_device;
    const tensorflow::Device* actual_device =
        op->input_devices[i] == nullptr ? host_device : op->input_devices[i];
    if (expected_device != actual_device) {
      switch (TFE_ContextGetDevicePlacementPolicy(ctx)) {
        case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32:
          // TODO(xpan): See if we could bubble python related error up
          // to python level.
          if (op->inputs[i].dtype() == tensorflow::DT_INT32) {
            // Note: enabling silent copies of int32 tensors to match behavior
            // of graph mode.
            break;
          }
          // Non-int32 tensors fall through to the EXPLICIT rejection below.
          TF_FALLTHROUGH_INTENDED;
        case TFE_DEVICE_PLACEMENT_EXPLICIT:
          return tensorflow::errors::InvalidArgument(
              "Tensors on conflicting devices:"
              " cannot compute ",
              op->name, " as input #", i, " was expected to be on ",
              expected_device->name(), " but is actually on ",
              actual_device->name(), " (operation running on ",
              op_device->name(), ")",
              " Tensors can be copied explicitly using .gpu() or .cpu(),"
              " or transparently copied by using tfe.enable_eager_execution("
              "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices"
              " may slow down your model");
        case TFE_DEVICE_PLACEMENT_WARN:
          LOG(WARNING) << "before computing " << op->name << " input #" << i
                       << " was expected to be on " << expected_device->name()
                       << " but is actually on " << actual_device->name()
                       << " (operation running on " << op_device->name()
                       << "). This triggers a copy which can be a performance "
                          "bottleneck.";
          break;
        case TFE_DEVICE_PLACEMENT_SILENT:  // Do nothing.
          break;
      }
      // We are only here if the policy is warn or silent copies, so we should
      // trigger a copy.
      TFE_TensorHandle original{op->inputs[i], op->input_devices[i]};
      TF_Status* s = TF_NewStatus();
      TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice(
          &original, ctx, expected_device->name().c_str(), s);
      if (!s->status.ok()) {
        tensorflow::Status status = s->status;
        delete s;
        return tensorflow::errors::Internal(
            "Failed copying input tensor from ", actual_device->name(), " to ",
            expected_device->name(), " in order to run ", op->name, ": ",
            status.error_message());
      }
      // Swap the op's recorded input for the copy; the handle itself is kept
      // alive via copied_tensors for the caller to release.
      op->inputs[i] = copied_tensor->t;
      copied_tensors->push_back(copied_tensor);
      op->input_devices[i] = copied_tensor->d;
      delete s;
    }
    if (op->inputs[i].dtype() != kernel->input_type(i)) {
      return tensorflow::errors::InvalidArgument(
          "cannot compute ", op->name, " as input #", i,
          " was expected to be a ",
          tensorflow::DataTypeString(kernel->input_type(i)),
          " tensor but is a ",
          tensorflow::DataTypeString(op->inputs[i].dtype()), " tensor");
    }
  }
  return tensorflow::Status::OK();
}
553
554 #ifdef TENSORFLOW_EAGER_USE_XLA
555 // Synthesizes and returns a wrapper function over `op`, which must be a
556 // primitive op (e.g. matmul).
557 //
558 // The wrapper function conforms to the function signature expected by
559 // _XlaLaunchOp, with input params ordered by <constants, (variable) args and
560 // resources>. For example, if the op has input params <Const1, Arg2, Const3,
561 // Resource4, Arg5>, they will be reordered to <Const1, Const3, Arg2, Arg5,
562 // Resource4> as the input params to the synthesized function.
563 //
564 // It populates `const_input_types`, `arg_input_types` and
565 // `op_input_to_func_input` based on the reordering results, that the caller can
566 // use them to build an _XlaLaunchOp. On error, it returns NULL, and sets
567 // `status` accordingly.
const tensorflow::FunctionDef* OpToFunction(
    TFE_Op* op, std::vector<TF_DataType>* const_input_types,
    std::vector<TF_DataType>* arg_input_types,
    tensorflow::gtl::FlatMap<int, int>* op_input_to_func_input,
    TF_Status* status) {
  // Only primitive ops can be wrapped; functions are handled by the caller.
  DCHECK(!op->is_function());

  tensorflow::FunctionDef fdef;

  // Get the OpDef of the op we are trying to encapsulate.
  TFE_Context* ctx = op->ctx;
  const tensorflow::OpRegistrationData* op_data;
  {
    tensorflow::tf_shared_lock l(ctx->functions_mu);
    status->status = ctx->func_lib_def.LookUp(op->name, &op_data);
    if (!status->status.ok()) {
      return nullptr;
    }
  }
  const tensorflow::OpDef& op_def = op_data->op_def;

  tensorflow::OpDef* signature = fdef.mutable_signature();

  // Handle constant inputs.
  const std::unordered_set<string> const_inputs(
      *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name));

  // First add place holders for the input args, so that we can refer to them
  // by position in the next loop. Also tally up the resource inputs.
  int num_resource_inputs = 0;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) {
      ++num_resource_inputs;
    }
    signature->add_input_arg();
  }

  // Now we map the input params from `op_def` to `signature`, where the param
  // ordering for `signature` is: <constants, args, resources>.
  // The three indices below track the next free slot in each segment.
  int const_index = 0;
  int arg_index = const_inputs.size();
  int resource_index = op_def.input_arg_size() - num_resource_inputs;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i);
    tensorflow::OpDef::ArgDef* func_input_arg = nullptr;
    if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) {
      VLOG(1) << "For const input, mapping op input " << i << " to func input "
              << const_index;
      (*op_input_to_func_input)[i] = const_index;
      func_input_arg = signature->mutable_input_arg(const_index++);
      const_input_types->push_back(
          static_cast<TF_DataType>(op->inputs[i].dtype()));
    } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) {
      VLOG(1) << "For resource input, mapping op input " << i
              << " to func input " << resource_index;
      (*op_input_to_func_input)[i] = resource_index;
      func_input_arg = signature->mutable_input_arg(resource_index++);
    } else {
      VLOG(1) << "For arg input, mapping op input " << i << " to func input "
              << arg_index;
      (*op_input_to_func_input)[i] = arg_index;
      func_input_arg = signature->mutable_input_arg(arg_index++);
      arg_input_types->push_back(
          static_cast<TF_DataType>(op->inputs[i].dtype()));
    }

    func_input_arg->set_name(op_input_arg.name());
    func_input_arg->set_type(op->inputs[i].dtype());
  }
  VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString();

  // Resources args are at the end of the function input params, and we should
  // have iterated over all of them.
  DCHECK_EQ(signature->input_arg_size(), resource_index);

  // Make the synthesized function's name unique.
  signature->set_name(tensorflow::strings::StrCat(
      op_def.name(), func_id_generator.fetch_add(1)));

  // Add the node def and set its input names to match op_def's names.
  const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
  DCHECK_EQ(signature->input_arg_size(), ndef.input_size());
  *fdef.add_node_def() = ndef;
  for (int i = 0; i < op_def.input_arg_size(); ++i) {
    fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name());
  }
  VLOG(1) << "Added NodeDef: " << fdef.DebugString();

  // Fix the output names and set output types.
  for (int i = 0; i < op_def.output_arg_size(); ++i) {
    tensorflow::OpDef::ArgDef* arg = signature->add_output_arg();
    const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i);
    const string& out_tensor_name = tensorflow::strings::StrCat(
        ndef.name(), ":", op_def_arg.name(), ":", 0);
    arg->set_name(op_def_arg.name());
    (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name;
    const string& type_attr = op_def_arg.type_attr();
    if (!type_attr.empty()) {
      // The output's concrete type comes from the corresponding attr on the
      // node def; its absence means the node def is malformed.
      auto i = ndef.attr().find(type_attr);
      if (i == ndef.attr().end()) {
        status->status = tensorflow::errors::InvalidArgument(
            tensorflow::strings::StrCat("Could not find attr ", type_attr,
                                        " in NodeDef ", ndef.DebugString()));
        return nullptr;
      }
      arg->set_type(i->second.type());
    }
  }
  VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString();

  // Register the synthesized function and return the library's copy of it.
  tensorflow::mutex_lock l(ctx->functions_mu);
  status->status = ctx->func_lib_def.AddFunctionDef(fdef);
  if (!status->status.ok()) return nullptr;
  const auto ret = ctx->func_lib_def.Find(signature->name());
  DCHECK(ret != nullptr);
  return ret;
}
685
686 // Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed
687 // via XLA.
std::unique_ptr<TFE_Op> BuildXlaLaunch(TFE_Op* op, TF_Status* status) {
  VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name;
  auto launch_op =
      std::unique_ptr<TFE_Op>(TFE_NewOp(op->ctx, "_XlaLaunch", status));
  if (TF_GetCode(status) != TF_OK) return nullptr;
  if (op->device) {
    // Keep the wrapper on the same device the wrapped op was pinned to.
    TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status);
    if (TF_GetCode(status) != TF_OK) return nullptr;
  }

  const tensorflow::FunctionDef* fdef;
  {
    tensorflow::tf_shared_lock l(op->ctx->functions_mu);
    fdef = op->ctx->func_lib_def.Find(op->name);
  }
  std::vector<TF_DataType> const_input_types;
  std::vector<TF_DataType> arg_input_types;
  tensorflow::gtl::FlatMap<int, int> op_input_to_func_input;
  if (fdef == nullptr) {
    // See if this is a primitive op, and if so create a function for it, so
    // that _XlaLaunchOp can access it.
    fdef = OpToFunction(op, &const_input_types, &arg_input_types,
                        &op_input_to_func_input, status);
    if (!status->status.ok()) return nullptr;
  } else {
    // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for
    // functions, so we need to find another way to handle constant inputs.
    for (int i = const_input_types.size();
         i < fdef->signature().input_arg_size(); ++i) {
      VLOG(1) << "Adding Targs from input arg " << i;
      const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i);
      arg_input_types.push_back(static_cast<TF_DataType>(arg.type()));
    }
  }
  DCHECK(fdef != nullptr);

  // Copy inputs and their devices.
  // Since input param reordering may have occurred between `op` and `launch_op`
  // via `op_input_to_func_input`, adjust the actual inputs accordingly.
  launch_op->inputs = op->inputs;
  launch_op->input_devices = op->input_devices;
  if (!op_input_to_func_input.empty()) {
    DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size());
    if (!op->input_devices.empty()) {
      DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size());
    }
    for (int i = 0; i < op_input_to_func_input.size(); ++i) {
      VLOG(1) << "mapping op input " << i << " to func input "
              << op_input_to_func_input[i];

      launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i];
      if (!op->input_devices.empty()) {
        launch_op->input_devices[op_input_to_func_input[i]] =
            op->input_devices[i];
      }
    }
  }
  launch_op->attrs.NumInputs(op->inputs.size());

  TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(),
                        const_input_types.size());

  // Set Targs and Nresources attrs.
  TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(),
                        arg_input_types.size());
  // Resources are whatever inputs are neither constants nor args.
  const int num_resource_inputs = fdef->signature().input_arg_size() -
                                  const_input_types.size() -
                                  arg_input_types.size();
  TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs);

  // Set Tresults attr.
  std::vector<TF_DataType> tresults;
  for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) {
    tresults.push_back(static_cast<TF_DataType>(arg.type()));
  }
  TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(),
                        tresults.size());

  // Set function attr.
  tensorflow::AttrValue attr_value;
  tensorflow::NameAttrList* func = attr_value.mutable_func();
  func->set_name(fdef->signature().name());
  launch_op->attrs.Set("function", attr_value);

  return launch_op;
}
774 #endif // TENSORFLOW_EAGER_USE_XLA
775 } // namespace
776
// Executes `op` eagerly: resolves a kernel (from the per-context cache, or by
// building one), validates/copies inputs onto the right device, runs the
// kernel, optionally records NodeExecStats into the context's RunMetadata,
// and wraps the outputs into up to `*num_retvals` TFE_TensorHandles.
// `*num_retvals` is updated to the number of handles actually produced.
void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
                 TF_Status* status) {
  TFE_Context* ctx = op->ctx;
  // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU
  tensorflow::Device* device =
      (op->device == nullptr) ? ctx->devices()[0] : op->device;

#ifdef TENSORFLOW_EAGER_USE_XLA
  // When XLA is requested, transparently replace `op` with a _XlaLaunch
  // wrapper (unless it already is one). xla_launch_op owns the replacement
  // for the remainder of this call.
  std::unique_ptr<TFE_Op> xla_launch_op;
  if (op->use_xla && op->name != "_XlaLaunch") {
    xla_launch_op = BuildXlaLaunch(op, status);
    if (!status->status.ok()) {
      return;
    }
    op = xla_launch_op.get();
  }
#endif  // TENSORFLOW_EAGER_USE_XLA

  std::vector<tensorflow::Tensor> outputs(1);
  const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
  // Kernels are cached per (attrs, device) fingerprint.
  tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
  tensorflow::KernelAndDevice* kernel;
  {
    tensorflow::tf_shared_lock l(ctx->cache_mu);
    kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
  }
  if (kernel == nullptr) {
    // Cache miss: build the kernel and insert it into the cache. The cache
    // retains ownership of `kernel` once inserted.
    const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
    kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
    // Knowledge of the implementation of Init (and in-turn
    // FunctionLibraryRuntime::CreateKernel) tells us that ctx->func_lib_def
    // will be accessed, so grab on to the lock.
    // See WARNING comment below - would be nice to rework to avoid this
    // subtlety.
    tensorflow::tf_shared_lock l(ctx->functions_mu);
    status->status =
        tensorflow::KernelAndDevice::Init(ndef, ctx->func_lib(device), kernel);
    if (!status->status.ok()) {
      delete kernel;
      return;
    }
    tensorflow::mutex_lock ml(ctx->cache_mu);
    tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
  }
  // Tensors copied across devices during placement validation; they must be
  // released on every exit path below.
  std::vector<TFE_TensorHandle*> copied_tensors;
  status->status = ValidateInputTypeAndPlacement(
      ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors);
  output_memory_types = &kernel->kernel()->output_memory_types();
  if (!status->status.ok()) {
    for (auto* t : copied_tensors) {
      TFE_DeleteTensorHandle(t);
    }
    return;
  }
  // Optionally collect per-node execution stats for RunMetadata. The flag is
  // read without metadata_mu here; it is re-checked under the lock before the
  // stats are actually stored (see below).
  std::unique_ptr<tensorflow::NodeExecStats> maybe_stats;
  if (ctx->should_store_metadata.load()) {
    maybe_stats.reset(new tensorflow::NodeExecStats);
    maybe_stats->set_node_name(op->name);
    maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros());
    maybe_stats->set_op_start_rel_micros(0);
    maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros());
    // TODO(apassos) track referenced tensors
  }
  // WARNING: kernel->Run utilizes the FunctionLibraryRuntime
  // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def,
  // which is GUARDED_BY(ctx->functions_mu). But knowledge of the implementation
  // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by
  // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here.
  // This is quite subtle. Re-work things to make this better? (Would it make
  // sense for FunctionLibraryRuntime to ensure thread-safe access to
  // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats
  // for ops which are a part of functions.
  status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get());
  for (auto* t : copied_tensors) {
    TFE_DeleteTensorHandle(t);
  }
  if (!status->status.ok()) return;
  if (maybe_stats != nullptr) {
    maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() -
                                       maybe_stats->all_start_micros());
    tensorflow::mutex_lock ml(ctx->metadata_mu);
    // Re-check under the lock: metadata collection may have been disabled
    // (and run_metadata cleared) while the kernel was running.
    if (ctx->should_store_metadata.load()) {
      auto* step_stats = ctx->run_metadata.mutable_step_stats();
      // Lazily initialize the RunMetadata with information about all devices if
      // this is the first call.
      while (step_stats->dev_stats_size() < ctx->devices().size()) {
        step_stats->add_dev_stats();
      }
      // Find the current device's index.
      int device_idx = 0;
      for (int i = 0; i < ctx->devices().size(); ++i) {
        if (ctx->devices()[i] == device) {
          device_idx = i;
          break;
        }
      }
      // Populate the device stats for this device.
      auto* dev_stats = step_stats->mutable_dev_stats(device_idx);
      dev_stats->set_device(device->name());
      *dev_stats->add_node_stats() = *maybe_stats;
    }
  }
  // Wrap outputs in handles. Outputs the kernel produced in host memory get
  // a null device (i.e. they are treated as CPU-resident).
  *num_retvals = std::min<int>(*num_retvals, outputs.size());
  for (int i = 0; i < *num_retvals; ++i) {
    tensorflow::Device* d = IsCPU(device) ? nullptr : device;
    if (d != nullptr && output_memory_types != nullptr &&
        (*output_memory_types)[i] == tensorflow::HOST_MEMORY) {
      d = nullptr;
    }
    retvals[i] = new TFE_TensorHandle(outputs[i], d);
  }
}
889
TFE_ContextAddFunctionDef(TFE_Context * ctx,const char * serialized_function_def,size_t size,TF_Status * status)890 void TFE_ContextAddFunctionDef(TFE_Context* ctx,
891 const char* serialized_function_def, size_t size,
892 TF_Status* status) {
893 tensorflow::FunctionDef function_def;
894 if (!function_def.ParseFromArray(serialized_function_def, size)) {
895 status->status =
896 tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
897 return;
898 }
899 tensorflow::mutex_lock l(ctx->functions_mu);
900 status->status = ctx->func_lib_def.AddFunctionDef(function_def);
901 }
902
TFE_ContextAddFunction(TFE_Context * ctx,TF_Function * function,TF_Status * status)903 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
904 TF_Status* status) {
905 tensorflow::mutex_lock l(ctx->functions_mu);
906 status->status = ctx->func_lib_def.AddFunctionDef(function->fdef);
907 }
908
909 } // extern "C"
910
TFE_NewTensorHandle(const tensorflow::Tensor & t)911 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
912 return new TFE_TensorHandle(t, nullptr);
913 }
914
TFE_TensorHandleUnderlyingTensorInHostMemory(TFE_TensorHandle * h,TF_Status * status)915 const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
916 TFE_TensorHandle* h, TF_Status* status) {
917 if (h->d != nullptr) {
918 status->status = tensorflow::errors::FailedPrecondition(
919 "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
920 "a tensorflow::Tensor");
921 return nullptr;
922 }
923 return &h->t;
924 }
925
TFE_ContextEnableRunMetadata(TFE_Context * ctx)926 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
927 ctx->should_store_metadata.store(true);
928 }
929
TFE_ContextDisableRunMetadata(TFE_Context * ctx)930 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
931 tensorflow::mutex_lock ml(ctx->metadata_mu);
932 ctx->should_store_metadata.store(false);
933 ctx->run_metadata.Clear();
934 }
935
TFE_ContextExportRunMetadata(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)936 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
937 TF_Status* status) {
938 tensorflow::mutex_lock ml(ctx->metadata_mu);
939 status->status = MessageToBuffer(ctx->run_metadata, buf);
940 ctx->run_metadata.Clear();
941 }
942