1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/eager/c_api.h"
17 
18 #include <algorithm>
19 #include <cstddef>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/memory/memory.h"
25 #include "tensorflow/c/c_api.h"
26 #include "tensorflow/c/c_api_internal.h"
27 #include "tensorflow/c/eager/c_api_internal.h"
28 #include "tensorflow/core/platform/host_info.h"
29 #ifdef TENSORFLOW_EAGER_USE_XLA
30 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
31 #endif  // TENSORFLOW_EAGER_USE_XLA
32 #include "tensorflow/core/common_runtime/copy_tensor.h"
33 #include "tensorflow/core/common_runtime/device_factory.h"
34 #include "tensorflow/core/common_runtime/device_mgr.h"
35 #include "tensorflow/core/common_runtime/device_set.h"
36 #include "tensorflow/core/common_runtime/eager/attr_builder.h"
37 #include "tensorflow/core/common_runtime/eager/copy_to_device_node.h"
38 #include "tensorflow/core/common_runtime/eager/execute.h"
39 #include "tensorflow/core/common_runtime/function.h"
40 #include "tensorflow/core/common_runtime/rendezvous_mgr.h"
41 #include "tensorflow/core/distributed_runtime/rpc/eager/grpc_eager_client.h"
42 #include "tensorflow/core/distributed_runtime/rpc/grpc_channel.h"
43 #include "tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h"
44 #include "tensorflow/core/distributed_runtime/server_lib.h"
45 #include "tensorflow/core/distributed_runtime/worker_env.h"
46 #include "tensorflow/core/framework/node_def_util.h"
47 #include "tensorflow/core/framework/rendezvous.h"
48 #include "tensorflow/core/framework/tensor_shape.pb.h"
49 #include "tensorflow/core/framework/types.h"
50 #include "tensorflow/core/lib/core/refcount.h"
51 #include "tensorflow/core/lib/core/stringpiece.h"
52 #include "tensorflow/core/lib/gtl/cleanup.h"
53 #include "tensorflow/core/lib/gtl/flatmap.h"
54 #include "tensorflow/core/lib/gtl/map_util.h"
55 #include "tensorflow/core/lib/gtl/stl_util.h"
56 #include "tensorflow/core/lib/random/random.h"
57 #include "tensorflow/core/platform/env.h"
58 #include "tensorflow/core/platform/mutex.h"
59 #include "tensorflow/core/platform/thread_annotations.h"
60 #include "tensorflow/core/public/version.h"
61 
62 using tensorflow::int64;
63 using tensorflow::string;
64 
65 namespace {
IsCPU(const tensorflow::Device * d)66 bool IsCPU(const tensorflow::Device* d) {
67   return d == nullptr || d->tensorflow_gpu_device_info() == nullptr;
68 }
69 
IsXLA(const tensorflow::Device * d)70 bool IsXLA(const tensorflow::Device* d) {
71   if (d == nullptr) return false;
72   const auto& device_type = d->attributes().device_type();
73   return device_type.find("XLA") != std::string::npos;
74 }
75 
DeviceName(const tensorflow::Device * d)76 string DeviceName(const tensorflow::Device* d) {
77   return (d == nullptr) ? "cpu:0" : d->name();
78 }
79 
GetAllRemoteDevices(const std::vector<string> & remote_workers,tensorflow::WorkerCacheInterface * worker_cache,std::unique_ptr<tensorflow::DeviceMgr> * device_mgr)80 tensorflow::Status GetAllRemoteDevices(
81     const std::vector<string>& remote_workers,
82     tensorflow::WorkerCacheInterface* worker_cache,
83     std::unique_ptr<tensorflow::DeviceMgr>* device_mgr) {
84   std::vector<std::unique_ptr<tensorflow::Device>> remote_devices;
85   tensorflow::Status status;
86   // TODO(nareshmodi) do this in parallel instead of serially.
87   for (const string& remote_worker : remote_workers) {
88     tensorflow::Notification n;
89     tensorflow::NewRemoteDevices(
90         tensorflow::Env::Default(), worker_cache, remote_worker,
91         [&status, &n, &remote_devices](
92             const tensorflow::Status& s,
93             std::vector<tensorflow::Device*>* devices) {
94           status = s;
95           if (s.ok()) {
96             for (tensorflow::Device* d : *devices) {
97               remote_devices.emplace_back(d);
98             }
99           }
100           n.Notify();
101         });
102     n.WaitForNotification();
103   }
104   std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr(
105       new tensorflow::DeviceMgr(std::move(remote_devices)));
106 
107   TF_RETURN_IF_ERROR(status);
108 
109   *device_mgr = std::move(remote_device_mgr);
110   return tensorflow::Status::OK();
111 }
112 
CreateRemoteContexts(const std::vector<string> & remote_workers,int64 rendezvous_id,int keep_alive_secs,const tensorflow::ServerDef & server_def,tensorflow::eager::EagerClientCache * remote_eager_workers,bool async,tensorflow::gtl::FlatMap<string,tensorflow::uint64> * remote_contexts)113 tensorflow::Status CreateRemoteContexts(
114     const std::vector<string>& remote_workers, int64 rendezvous_id,
115     int keep_alive_secs, const tensorflow::ServerDef& server_def,
116     tensorflow::eager::EagerClientCache* remote_eager_workers, bool async,
117     tensorflow::gtl::FlatMap<string, tensorflow::uint64>* remote_contexts) {
118   for (int i = 0; i < remote_workers.size(); i++) {
119     const string& remote_worker = remote_workers[i];
120 
121     tensorflow::eager::CreateContextRequest request;
122     tensorflow::eager::CreateContextResponse response;
123     request.set_rendezvous_id(rendezvous_id);
124     tensorflow::DeviceNameUtils::ParsedName parsed_name;
125     if (!tensorflow::DeviceNameUtils::ParseFullName(remote_worker,
126                                                     &parsed_name)) {
127       return tensorflow::errors::InvalidArgument(
128           "Unable to parse ", remote_worker, " as a device name");
129     }
130     *request.mutable_server_def() = server_def;
131     request.mutable_server_def()->set_job_name(parsed_name.job);
132     request.mutable_server_def()->set_task_index(parsed_name.task);
133     request.set_async(async);
134     request.set_keep_alive_secs(keep_alive_secs);
135     auto* eager_client = remote_eager_workers->GetClient(remote_worker);
136     if (eager_client == nullptr) {
137       return tensorflow::errors::Internal(
138           "Cannot find a client for the given target:", remote_worker);
139     }
140     tensorflow::Notification n;
141     tensorflow::Status status;
142     // TODO(nareshmodi) do this in parallel instead of serially.
143     eager_client->CreateContextAsync(
144         &request, &response, [&status, &n](const tensorflow::Status& s) {
145           status = s;
146           n.Notify();
147         });
148     n.WaitForNotification();
149     TF_RETURN_IF_ERROR(status);
150 
151     remote_contexts->emplace(remote_worker, response.context_id());
152   }
153   return tensorflow::Status::OK();
154 }
155 
UpdateTFE_ContextWithServerDef(int keep_alive_secs,const tensorflow::ServerDef & server_def,TFE_Context * ctx)156 tensorflow::Status UpdateTFE_ContextWithServerDef(
157     int keep_alive_secs, const tensorflow::ServerDef& server_def,
158     TFE_Context* ctx) {
159   // We don't use the TF_RETURN_IF_ERROR macro directly since that destroys the
160   // server object (which currently CHECK-fails) and we miss the error, instead,
161   // we log the error, and then return to allow the user to see the error
162   // message.
163 #define LOG_AND_RETURN_IF_ERROR(...)                    \
164   do {                                                  \
165     const ::tensorflow::Status _status = (__VA_ARGS__); \
166     if (TF_PREDICT_FALSE(!_status.ok())) {              \
167       LOG(ERROR) << _status.error_message();            \
168       return _status;                                   \
169     }                                                   \
170   } while (0);
171 
172   string worker_name =
173       tensorflow::strings::StrCat("/job:", server_def.job_name(),
174                                   "/replica:0/task:", server_def.task_index());
175 
176   std::unique_ptr<tensorflow::ServerInterface> server;
177   LOG_AND_RETURN_IF_ERROR(tensorflow::NewServer(server_def, &server));
178 
179   tensorflow::GrpcServer* grpc_server =
180       dynamic_cast<tensorflow::GrpcServer*>(server.get());
181   if (grpc_server == nullptr) {
182     LOG_AND_RETURN_IF_ERROR(tensorflow::errors::Internal(
183         "Currently, TFE_NewContext only supports tensorflow::GrpcServer."));
184   }
185 
186   LOG_AND_RETURN_IF_ERROR(grpc_server->Start());
187 
188   int64 rendezvous_id = tensorflow::random::New64();
189 
190   std::vector<string> remote_workers;
191   grpc_server->master_env()->worker_cache->ListWorkers(&remote_workers);
192   remote_workers.erase(
193       std::remove(remote_workers.begin(), remote_workers.end(), worker_name),
194       remote_workers.end());
195 
196   std::unique_ptr<tensorflow::DeviceMgr> remote_device_mgr;
197   LOG_AND_RETURN_IF_ERROR(GetAllRemoteDevices(
198       remote_workers, grpc_server->master_env()->worker_cache,
199       &remote_device_mgr));
200 
201   std::shared_ptr<tensorflow::GrpcChannelCache> channel_cache =
202       grpc_server->channel_cache();
203   std::unique_ptr<tensorflow::eager::EagerClientCache> remote_eager_workers(
204       tensorflow::eager::NewGrpcEagerClientCache(channel_cache));
205 
206   // Initialize remote eager workers.
207   tensorflow::gtl::FlatMap<string, tensorflow::uint64> remote_contexts;
208   LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts(
209       remote_workers, rendezvous_id, keep_alive_secs, server_def,
210       remote_eager_workers.get(), ctx->context.Async(), &remote_contexts));
211 
212   tensorflow::RemoteRendezvous* r =
213       grpc_server->worker_env()->rendezvous_mgr->Find(rendezvous_id);
214 
215   auto session_name = tensorflow::strings::StrCat("eager_", rendezvous_id);
216   TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession(
217       session_name, server_def, true));
218 
219   std::shared_ptr<tensorflow::WorkerSession> worker_session;
220   TF_RETURN_IF_ERROR(
221       grpc_server->worker_env()->session_mgr->WorkerSessionForSession(
222           session_name, &worker_session));
223 
224   // Initialize remote tensor communication based on worker session.
225   TF_RETURN_IF_ERROR(r->Initialize(worker_session.get()));
226 
227   auto* device_mgr = grpc_server->worker_env()->device_mgr;
228 
229   return ctx->context.InitializeRemote(
230       std::move(server), std::move(remote_eager_workers),
231       std::move(remote_device_mgr), remote_contexts, r, device_mgr,
232       keep_alive_secs);
233 #undef LOG_AND_RETURN_IF_ERROR
234 }
235 
OpInferSingleInputAttrs(TFE_Op * op,TFE_TensorHandle * input)236 tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,
237                                            TFE_TensorHandle* input) {
238   TFE_OpInferenceContext* ictx = op->inference_ctx.get();
239   const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
240   if (!input_def.number_attr().empty() || !input_def.type_list_attr().empty()) {
241     // Some clients that are still setting their input attributes manually are
242     // adding input list to their op by calling `TFE_OpAddInput` for each of
243     // its elements instead of calling `TFE_OpAddInputList`. When this happens,
244     // we cannot detect the end of such list, thus lose track of the input
245     // arguments in the op definition. To guarantee backward compatibility with
246     // those clients, disable automatic inference in this case.
247     op->inference_ctx.reset(nullptr);
248     return tensorflow::Status::OK();
249   }
250   const std::string& type_attr = input_def.type_attr();
251   if (!type_attr.empty() && ictx->attrs.find(type_attr) == ictx->attrs.end()) {
252     op->operation.MutableAttrs()->Set(type_attr, input->handle->dtype);
253     ictx->attrs.insert(type_attr);
254   }
255   return tensorflow::Status::OK();
256 }
257 
OpInferSingleTypeInputListAttrs(TFE_Op * op,const tensorflow::OpDef::ArgDef & input_def,TFE_TensorHandle ** inputs,int num_inputs)258 void OpInferSingleTypeInputListAttrs(TFE_Op* op,
259                                      const tensorflow::OpDef::ArgDef& input_def,
260                                      TFE_TensorHandle** inputs,
261                                      int num_inputs) {
262   TFE_OpInferenceContext* ictx = op->inference_ctx.get();
263   if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
264     op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
265     ictx->attrs.insert(input_def.number_attr());
266   }
267   if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
268     op->operation.MutableAttrs()->Set(input_def.type_attr(),
269                                       inputs[0]->handle->dtype);
270     ictx->attrs.insert(input_def.type_attr());
271   }
272 }
273 
OpInferMixedTypeInputListAttrs(TFE_Op * op,const tensorflow::OpDef::ArgDef & input_def,TFE_TensorHandle ** inputs,int num_inputs)274 void OpInferMixedTypeInputListAttrs(TFE_Op* op,
275                                     const tensorflow::OpDef::ArgDef& input_def,
276                                     TFE_TensorHandle** inputs, int num_inputs) {
277   TFE_OpInferenceContext* ictx = op->inference_ctx.get();
278   if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
279     std::unique_ptr<tensorflow::DataType[]> dtypes(
280         new tensorflow::DataType[num_inputs]);
281     for (int i = 0; i < num_inputs; ++i) {
282       dtypes[i] = inputs[i]->handle->dtype;
283     }
284     op->operation.MutableAttrs()->Set(
285         input_def.type_list_attr(),
286         tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
287                                                                 num_inputs));
288     ictx->attrs.insert(input_def.type_list_attr());
289   }
290 }
291 
OpInferInputListAttrs(TFE_Op * op,TFE_TensorHandle ** inputs,int num_inputs)292 tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
293                                          int num_inputs) {
294   TFE_OpInferenceContext* ictx = op->inference_ctx.get();
295   const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
296   if (!input_def.type_list_attr().empty()) {
297     OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
298   } else if (!input_def.type_attr().empty() &&
299              !input_def.number_attr().empty()) {
300     OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
301   } else {
302     return tensorflow::errors::InvalidArgument("Invalid input list definition");
303   }
304   return tensorflow::Status::OK();
305 }
306 
307 }  // namespace
308 
309 extern "C" {
310 
TFE_NewContextOptions()311 TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; }
312 
TFE_ContextOptionsSetConfig(TFE_ContextOptions * options,const void * proto,size_t proto_len,TF_Status * status)313 void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto,
314                                  size_t proto_len, TF_Status* status) {
315   TF_SetConfig(&options->session_options, proto, proto_len, status);
316 }
317 
TFE_ContextOptionsSetAsync(TFE_ContextOptions * options,unsigned char enable)318 void TFE_ContextOptionsSetAsync(TFE_ContextOptions* options,
319                                 unsigned char enable) {
320   options->async = enable;
321 }
322 
TFE_ContextOptionsSetDevicePlacementPolicy(TFE_ContextOptions * options,TFE_ContextDevicePlacementPolicy policy)323 void TFE_ContextOptionsSetDevicePlacementPolicy(
324     TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) {
325   options->policy = policy;
326 }
327 
TFE_ContextSetAsyncForThread(TFE_Context * ctx,unsigned char enable,TF_Status * status)328 TF_CAPI_EXPORT extern void TFE_ContextSetAsyncForThread(TFE_Context* ctx,
329                                                         unsigned char enable,
330                                                         TF_Status* status) {
331   status->status = ctx->context.SetAsyncForThread(enable);
332 }
333 
TFE_DeleteContextOptions(TFE_ContextOptions * options)334 void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; }
335 
TFE_NewContext(const TFE_ContextOptions * opts,TF_Status * status)336 TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
337   std::vector<std::unique_ptr<tensorflow::Device>> devices;
338   status->status = tensorflow::DeviceFactory::AddDevices(
339       opts->session_options.options, "/job:localhost/replica:0/task:0",
340       &devices);
341   if (!status->status.ok()) return nullptr;
342   std::unique_ptr<tensorflow::DeviceMgr> device_mgr(
343       new tensorflow::DeviceMgr(std::move(devices)));
344 
345   tensorflow::Rendezvous* r =
346       new tensorflow::IntraProcessRendezvous(device_mgr.get());
347 
348   return new TFE_Context(opts->session_options.options, opts->policy,
349                          opts->async, device_mgr.release(),
350                          /*device_mgr_owned*/ true, r);
351 }
352 
TFE_NewContextFromSession(const TFE_ContextOptions * opts,TF_Session * sess,TF_Status * status)353 TFE_Context* TFE_NewContextFromSession(const TFE_ContextOptions* opts,
354                                        TF_Session* sess, TF_Status* status) {
355   const tensorflow::DeviceMgr* device_mgr = nullptr;
356   status->status = sess->session->LocalDeviceManager(&device_mgr);
357   if (!status->status.ok()) return nullptr;
358   tensorflow::Rendezvous* r =
359       new tensorflow::IntraProcessRendezvous(device_mgr);
360   return new TFE_Context(opts->session_options.options, opts->policy,
361                          opts->async, device_mgr, /*device_mgr_owned*/ false,
362                          r);
363 }
364 
TFE_DeleteContext(TFE_Context * ctx)365 void TFE_DeleteContext(TFE_Context* ctx) { delete ctx; }
366 
TFE_ContextListDevices(TFE_Context * ctx,TF_Status * status)367 TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
368   TF_DeviceList* list = new TF_DeviceList;
369   ctx->context.local_device_mgr()->ListDeviceAttributes(&list->response);
370   if (ctx->context.remote_device_mgr()) {
371     ctx->context.remote_device_mgr()->ListDeviceAttributes(&list->response);
372   }
373   return list;
374 }
375 
TFE_ContextClearCaches(TFE_Context * ctx,TF_Status * status)376 void TFE_ContextClearCaches(TFE_Context* ctx, TF_Status* status) {
377   status->status = ctx->context.ClearCaches();
378 }
379 
380 // Set server_def on the context, possibly updating it.
TFE_ContextSetServerDef(TFE_Context * ctx,int keep_alive_secs,const void * proto,size_t proto_len,TF_Status * status)381 TF_CAPI_EXPORT extern void TFE_ContextSetServerDef(TFE_Context* ctx,
382                                                    int keep_alive_secs,
383                                                    const void* proto,
384                                                    size_t proto_len,
385                                                    TF_Status* status) {
386   tensorflow::ServerDef server_def;
387   if (!server_def.ParseFromArray(proto, proto_len)) {
388     status->status = tensorflow::errors::InvalidArgument(
389         "Invalid tensorflow.ServerDef protocol buffer");
390     return;
391   }
392   status->status =
393       UpdateTFE_ContextWithServerDef(keep_alive_secs, server_def, ctx);
394 }
395 
TFE_ContextSetThreadLocalDevicePlacementPolicy(TFE_Context * ctx,TFE_ContextDevicePlacementPolicy policy)396 void TFE_ContextSetThreadLocalDevicePlacementPolicy(
397     TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) {
398   ctx->context.SetThreadLocalDevicePlacementPolicy(
399       static_cast<tensorflow::ContextDevicePlacementPolicy>(policy));
400 }
401 
402 // Note: this function looks up a thread local policy. So it should be called in
403 // the appropriate client thread. In particular, in async mode, it may not be
404 // safe to call this function from the async EagerExecutor threads.
TFE_ContextGetDevicePlacementPolicy(TFE_Context * ctx)405 extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy(
406     TFE_Context* ctx) {
407   return static_cast<TFE_ContextDevicePlacementPolicy>(
408       ctx->context.GetDevicePlacementPolicy());
409 }
410 
TFE_ContextAsyncWait(TFE_Context * ctx,TF_Status * status)411 void TFE_ContextAsyncWait(TFE_Context* ctx, TF_Status* status) {
412   status->status = ctx->context.AsyncWait();
413 }
414 
TFE_ContextGetStatus(TFE_Context * ctx,TF_Status * status)415 void TFE_ContextGetStatus(TFE_Context* ctx, TF_Status* status) {
416   status->status = ctx->context.GetStatus();
417 }
418 
TFE_ContextAsyncClearError(TFE_Context * ctx)419 void TFE_ContextAsyncClearError(TFE_Context* ctx) {
420   ctx->context.ClearAsyncError();
421 }
422 
TFE_NewTensorHandle(TF_Tensor * t,TF_Status * status)423 TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
424   tensorflow::Tensor tensor;
425   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
426   if (!status->status.ok()) return nullptr;
427   return new TFE_TensorHandle(tensor, nullptr, nullptr);
428 }
429 
TFE_DeleteTensorHandle(TFE_TensorHandle * h)430 void TFE_DeleteTensorHandle(TFE_TensorHandle* h) {
431   if (h == nullptr) return;
432   VLOG(1) << "Deleting tensor handle " << h << " with internal handle "
433           << h->handle;
434   if (h->handle) {
435     h->handle->Unref();
436   }
437   delete h;
438 }
439 
TFE_TensorHandleDataType(TFE_TensorHandle * h)440 TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h) {
441   return static_cast<TF_DataType>(h->handle->dtype);
442 }
443 
TFE_TensorHandleNumDims(TFE_TensorHandle * h,TF_Status * status)444 int TFE_TensorHandleNumDims(TFE_TensorHandle* h, TF_Status* status) {
445   if (h == nullptr || h->handle == nullptr) {
446     status->status = tensorflow::errors::InvalidArgument(
447         "The passed in handle is a nullptr");
448     return -1;
449   }
450   int result;
451   status->status = h->handle->NumDims(&result);
452   return result;
453 }
454 
TFE_TensorHandleNumElements(TFE_TensorHandle * h,TF_Status * status)455 int64_t TFE_TensorHandleNumElements(TFE_TensorHandle* h, TF_Status* status) {
456   if (h == nullptr || h->handle == nullptr) {
457     status->status = tensorflow::errors::InvalidArgument(
458         "The passed in handle is a nullptr");
459     return -1;
460   }
461   tensorflow::int64 result;
462   status->status = h->handle->NumElements(&result);
463   return result;
464 }
465 
TFE_TensorHandleDim(TFE_TensorHandle * h,int dim_index,TF_Status * status)466 int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index,
467                             TF_Status* status) {
468   if (h == nullptr || h->handle == nullptr) {
469     status->status = tensorflow::errors::InvalidArgument(
470         "The passed in handle is a nullptr");
471     return -1;
472   }
473   tensorflow::int64 result;
474   status->status = h->handle->Dim(dim_index, &result);
475   return result;
476 }
477 
TFE_TensorHandleDeviceName(TFE_TensorHandle * h,TF_Status * status)478 const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h, TF_Status* status) {
479   if (h == nullptr || h->handle == nullptr) {
480     status->status = tensorflow::errors::InvalidArgument(
481         "The passed in handle is a nullptr");
482     return nullptr;
483   }
484   tensorflow::Device* d = h->handle->op_device();
485   return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
486                         : d->name().c_str();
487 }
488 
TFE_TensorHandleBackingDeviceName(TFE_TensorHandle * h,TF_Status * status)489 const char* TFE_TensorHandleBackingDeviceName(TFE_TensorHandle* h,
490                                               TF_Status* status) {
491   if (h == nullptr || h->handle == nullptr) {
492     status->status = tensorflow::errors::InvalidArgument(
493         "The passed in handle is a nullptr");
494     return nullptr;
495   }
496   tensorflow::Device* d = h->handle->device();
497   return (d == nullptr) ? "/job:localhost/replica:0/task:0/device:CPU:0"
498                         : d->name().c_str();
499 }
500 
TFE_TensorHandleCopySharingTensor(TFE_TensorHandle * h,TF_Status * status)501 TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopySharingTensor(
502     TFE_TensorHandle* h, TF_Status* status) {
503   if (h == nullptr || h->handle == nullptr) {
504     status->status = tensorflow::errors::InvalidArgument(
505         "The passed in handle is a nullptr");
506     return nullptr;
507   }
508 
509   h->handle->Ref();
510 
511   return new TFE_TensorHandle(h->handle);
512 }
513 
TFE_TensorHandleResolve(TFE_TensorHandle * h,TF_Status * status)514 TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status) {
515   if (h == nullptr || h->handle == nullptr) {
516     status->status = tensorflow::errors::InvalidArgument(
517         "The passed in handle is a nullptr");
518     return nullptr;
519   }
520   // TODO(agarwal): move this implementation inside TFE_TensorHandle.
521   const tensorflow::Tensor* t = nullptr;
522   tensorflow::TensorHandle* h_cpu = nullptr;
523   tensorflow::Device* d = nullptr;
524   tensorflow::Device* op_device = nullptr;
525 
526   if (h->handle->IsRemote()) {
527     status->status = EagerCopyToDevice(
528         h->handle, h->handle->Context(),
529         h->handle->Context()->HostCPU()->name().c_str(), &h_cpu);
530     if (!status->status.ok()) {
531       return nullptr;
532     }
533     status->status = h_cpu->TensorAndDevice(&t, &d, &op_device);
534     if (!status->status.ok()) {
535       h_cpu->Unref();
536       return nullptr;
537     }
538   } else {
539     status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
540     if (!status->status.ok()) return nullptr;
541 
542     if (!IsCPU(d)) {
543       status->status = h->handle->CopyToDevice(
544           h->handle->Context(), h->handle->Context()->HostCPU(), &h_cpu);
545       if (!status->status.ok()) {
546         return nullptr;
547       }
548       status->status = h_cpu->TensorAndDevice(&t, &d, &op_device);
549       if (!status->status.ok()) {
550         h_cpu->Unref();
551         return nullptr;
552       }
553     }
554   }
555   TF_Tensor* retval = tensorflow::TF_TensorFromTensor(*t, status);
556   if (h_cpu != nullptr) {
557     h_cpu->Unref();
558   }
559   return retval;
560 }
561 
TFE_NewOp(TFE_Context * ctx,const char * op_or_function_name,TF_Status * status)562 TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name,
563                   TF_Status* status) {
564   const char* name = op_or_function_name;  // Shorthand
565   const tensorflow::AttrTypeMap* types;
566   bool is_function = false;
567   status->status = tensorflow::AttrTypeMapForOp(name, &types, &is_function);
568   if (!status->status.ok()) {
569     return nullptr;
570   }
571   if (!is_function) {
572     const tensorflow::OpDef* op_def;
573     status->status = tensorflow::OpDefForOp(op_or_function_name, &op_def);
574     if (!status->status.ok()) {
575       return nullptr;
576     }
577     return new TFE_Op(ctx, name, false, types,
578                       new TFE_OpInferenceContext(op_def));
579   }
580   if (!ctx->context.FindFunctionByName(name)) {
581     status->status = tensorflow::errors::NotFound(
582         "'", name,
583         "' is neither a type of a primitive operation nor a name "
584         "of a function registered in binary running on ",
585         tensorflow::port::Hostname(),
586         ". Make sure the operation or function is "
587         "registered in the binary running in this process.");
588     return nullptr;
589   }
590   return new TFE_Op(ctx, name, true, types, nullptr);
591 }
592 
TFE_DeleteOp(TFE_Op * op)593 void TFE_DeleteOp(TFE_Op* op) { delete op; }
594 
TFE_OpSetDevice(TFE_Op * op,const char * device_name,TF_Status * status)595 void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) {
596   status->status = op->operation.SetDevice(device_name);
597 }
598 
TFE_OpGetDevice(TFE_Op * op,TF_Status * status)599 const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) {
600   tensorflow::Device* device = (op->operation.Device() == nullptr)
601                                    ? op->operation.EagerContext()->HostCPU()
602                                    : op->operation.Device();
603   return device->name().c_str();
604 }
605 
TFE_OpSetXLACompilation(TFE_Op * op,unsigned char enable)606 void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) {
607   op->operation.SetUseXla(enable);
608 #ifndef TENSORFLOW_EAGER_USE_XLA
609   LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not "
610                   "built with XLA support.";
611 #endif  // TENSORFLOW_EAGER_USE_XLA
612 }
613 
TFE_OpAddInput(TFE_Op * op,TFE_TensorHandle * input,TF_Status * status)614 void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* input, TF_Status* status) {
615   op->operation.AddInput(input->handle);
616   if (op->inference_ctx) {
617     status->status = OpInferSingleInputAttrs(op, input);
618   }
619 }
620 
TFE_OpAddInputList(TFE_Op * op,TFE_TensorHandle ** inputs,int num_inputs,TF_Status * status)621 void TFE_OpAddInputList(TFE_Op* op, TFE_TensorHandle** inputs, int num_inputs,
622                         TF_Status* status) {
623   for (int i = 0; i < num_inputs; ++i) {
624     op->operation.AddInput(inputs[i]->handle);
625   }
626   if (op->inference_ctx) {
627     status->status = OpInferInputListAttrs(op, inputs, num_inputs);
628   }
629 }
630 
TFE_OpGetAttrType(TFE_Op * op,const char * attr_name,unsigned char * is_list,TF_Status * status)631 TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name,
632                               unsigned char* is_list, TF_Status* status) {
633   TF_AttrType ret;
634   status->status = tensorflow::AttrTypeByName(*op->operation.AttrTypes(),
635                                               attr_name, &ret, is_list);
636   return ret;
637 }
638 
TFE_OpNameGetAttrType(TFE_Context * ctx,const char * op_or_function_name,const char * attr_name,unsigned char * is_list,TF_Status * status)639 TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx,
640                                   const char* op_or_function_name,
641                                   const char* attr_name, unsigned char* is_list,
642                                   TF_Status* status) {
643   TF_AttrType ret;
644   TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status);
645   if (!status->status.ok()) {
646     return TF_ATTR_INT;  // Same dummy return as TFE_OpGetAttrType.
647   }
648   ret = TFE_OpGetAttrType(op, attr_name, is_list, status);
649   TFE_DeleteOp(op);
650   return ret;
651 }
652 
TFE_OpSetAttrString(TFE_Op * op,const char * attr_name,const void * value,size_t length)653 void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const void* value,
654                          size_t length) {
655   op->operation.MutableAttrs()->Set(
656       attr_name,
657       tensorflow::StringPiece(static_cast<const char*>(value), length));
658 }
659 
TFE_OpSetAttrInt(TFE_Op * op,const char * attr_name,int64_t value)660 void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value) {
661   op->operation.MutableAttrs()->Set(attr_name, static_cast<int64>(value));
662 }
663 
TFE_OpSetAttrFloat(TFE_Op * op,const char * attr_name,float value)664 void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value) {
665   op->operation.MutableAttrs()->Set(attr_name, value);
666 }
667 
TFE_OpSetAttrBool(TFE_Op * op,const char * attr_name,unsigned char value)668 void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value) {
669   op->operation.MutableAttrs()->Set(attr_name, (value == 0) ? false : true);
670 }
671 
TFE_OpSetAttrType(TFE_Op * op,const char * attr_name,TF_DataType value)672 void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, TF_DataType value) {
673   op->operation.MutableAttrs()->Set(attr_name,
674                                     static_cast<tensorflow::DataType>(value));
675 }
676 
TFE_OpSetAttrShape(TFE_Op * op,const char * attr_name,const int64_t * dims,const int num_dims,TF_Status * out_status)677 void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, const int64_t* dims,
678                         const int num_dims, TF_Status* out_status) {
679   if (num_dims > tensorflow::TensorShape::MaxDimensions()) {
680     TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
681                  tensorflow::strings::StrCat(
682                      "Value specified for `", attr_name, "` has ", num_dims,
683                      " dimensions which is over the limit of ",
684                      tensorflow::TensorShape::MaxDimensions(), ".")
685                      .c_str());
686     return;
687   }
688   tensorflow::TensorShapeProto proto;
689   if (num_dims < 0) {
690     proto.set_unknown_rank(true);
691   } else {
692     for (int d = 0; d < num_dims; ++d) {
693       proto.add_dim()->set_size(dims[d]);
694     }
695   }
696   op->operation.MutableAttrs()->Set(attr_name, proto);
697 }
698 
TFE_OpSetAttrFunction(TFE_Op * op,const char * attr_name,const TFE_Op * value)699 void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
700                            const TFE_Op* value) {
701   tensorflow::AttrValue attr_value;
702   tensorflow::NameAttrList* func = attr_value.mutable_func();
703   func->set_name(value->operation.Name());
704   value->operation.Attrs().FillAttrValueMap(func->mutable_attr());
705   op->operation.MutableAttrs()->Set(attr_name, attr_value);
706 }
707 
TFE_OpSetAttrFunctionName(TFE_Op * op,const char * attr_name,const char * data,size_t length)708 void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
709                                const char* data, size_t length) {
710   tensorflow::AttrValue attr_value;
711   tensorflow::NameAttrList* func = attr_value.mutable_func();
712   func->set_name(data, length);
713   op->operation.MutableAttrs()->Set(attr_name, attr_value);
714 }
715 
TFE_OpSetAttrTensor(TFE_Op * op,const char * attr_name,TF_Tensor * tensor,TF_Status * status)716 void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
717                          TF_Status* status) {
718   tensorflow::Tensor t;
719   status->status = TF_TensorToTensor(tensor, &t);
720   if (status->status.ok()) op->operation.MutableAttrs()->Set(attr_name, t);
721 }
722 
TFE_OpSetAttrStringList(TFE_Op * op,const char * attr_name,const void * const * values,const size_t * lengths,int num_values)723 void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name,
724                              const void* const* values, const size_t* lengths,
725                              int num_values) {
726   std::vector<tensorflow::StringPiece> v(num_values);
727   for (int i = 0; i < num_values; ++i) {
728     v[i] = tensorflow::StringPiece(static_cast<const char*>(values[i]),
729                                    lengths[i]);
730   }
731   op->operation.MutableAttrs()->Set(attr_name, v);
732 }
733 
TFE_OpSetAttrFloatList(TFE_Op * op,const char * attr_name,const float * values,int num_values)734 void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name,
735                             const float* values, int num_values) {
736   op->operation.MutableAttrs()->Set(
737       attr_name, tensorflow::gtl::ArraySlice<const float>(values, num_values));
738 }
739 
TFE_OpSetAttrIntList(TFE_Op * op,const char * attr_name,const int64_t * values,int num_values)740 void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name,
741                           const int64_t* values, int num_values) {
742   op->operation.MutableAttrs()->Set(
743       attr_name, tensorflow::gtl::ArraySlice<const int64>(
744                      reinterpret_cast<const int64*>(values), num_values));
745 }
746 
TFE_OpSetAttrTypeList(TFE_Op * op,const char * attr_name,const TF_DataType * values,int num_values)747 void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name,
748                            const TF_DataType* values, int num_values) {
749   op->operation.MutableAttrs()->Set(
750       attr_name,
751       tensorflow::gtl::ArraySlice<const tensorflow::DataType>(
752           reinterpret_cast<const tensorflow::DataType*>(values), num_values));
753 }
754 
TFE_OpSetAttrBoolList(TFE_Op * op,const char * attr_name,const unsigned char * values,int num_values)755 void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name,
756                            const unsigned char* values, int num_values) {
757   std::unique_ptr<bool[]> b(new bool[num_values]);
758   for (int i = 0; i < num_values; ++i) {
759     b[i] = values[i];
760   }
761   op->operation.MutableAttrs()->Set(
762       attr_name, tensorflow::gtl::ArraySlice<const bool>(b.get(), num_values));
763 }
764 
TFE_OpSetAttrShapeList(TFE_Op * op,const char * attr_name,const int64_t ** dims,const int * num_dims,int num_values,TF_Status * out_status)765 void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name,
766                             const int64_t** dims, const int* num_dims,
767                             int num_values, TF_Status* out_status) {
768   std::unique_ptr<tensorflow::TensorShapeProto[]> proto(
769       new tensorflow::TensorShapeProto[num_values]);
770   for (int i = 0; i < num_values; ++i) {
771     const auto num_dims_i = num_dims[i];
772 
773     if (num_dims_i > tensorflow::TensorShape::MaxDimensions()) {
774       TF_SetStatus(out_status, TF_INVALID_ARGUMENT,
775                    tensorflow::strings::StrCat(
776                        "Value specified for `", attr_name, "` has ", num_dims_i,
777                        " dimensions which is over the limit of ",
778                        tensorflow::TensorShape::MaxDimensions(), ".")
779                        .c_str());
780       return;
781     }
782     if (num_dims_i < 0) {
783       proto[i].set_unknown_rank(true);
784     } else {
785       const int64_t* dims_i = dims[i];
786       auto proto_i = &proto[i];
787       for (int d = 0; d < num_dims_i; ++d) {
788         proto_i->add_dim()->set_size(dims_i[d]);
789       }
790     }
791   }
792   op->operation.MutableAttrs()->Set(
793       attr_name, tensorflow::gtl::ArraySlice<tensorflow::TensorShapeProto>(
794                      proto.get(), num_values));
795 }
796 
TFE_OpSetAttrFunctionList(TFE_Op * op,const char * attr_name,const TFE_Op ** value,int num_values)797 void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name,
798                                const TFE_Op** value, int num_values) {
799   std::unique_ptr<tensorflow::NameAttrList[]> funcs(
800       new tensorflow::NameAttrList[num_values]);
801   for (int i = 0; i < num_values; i++) {
802     funcs[i].set_name(value[i]->operation.Name());
803     value[i]->operation.Attrs().FillAttrValueMap(funcs[i].mutable_attr());
804   }
805   op->operation.MutableAttrs()->Set(
806       attr_name, tensorflow::gtl::ArraySlice<const tensorflow::NameAttrList>(
807                      funcs.get(), num_values));
808 }
809 
TFE_Execute(TFE_Op * op,TFE_TensorHandle ** retvals,int * num_retvals,TF_Status * status)810 void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
811                  TF_Status* status) {
812   VLOG(1) << "Calling TFE_Execute() on op " << op;
813   tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> handle_retvals(
814       *num_retvals);
815   status->status =
816       tensorflow::EagerExecute(&op->operation, &handle_retvals, num_retvals);
817   if (!status->status.ok()) {
818     return;
819   }
820   for (int i = 0; i < *num_retvals; ++i) {
821     retvals[i] = new TFE_TensorHandle(handle_retvals[i]);
822   }
823 }
824 
TFE_TensorHandleCopyToDevice(TFE_TensorHandle * h,TFE_Context * ctx,const char * device_name,TF_Status * status)825 TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h,
826                                                TFE_Context* ctx,
827                                                const char* device_name,
828                                                TF_Status* status) {
829   tensorflow::TensorHandle* handle;
830   status->status = tensorflow::EagerCopyToDevice(h->handle, &ctx->context,
831                                                  device_name, &handle);
832   if (status->status.ok()) {
833     return new TFE_TensorHandle(handle);
834   }
835   return nullptr;
836 }
837 
TFE_ContextAddFunctionDef(TFE_Context * ctx,const char * serialized_function_def,size_t size,TF_Status * status)838 void TFE_ContextAddFunctionDef(TFE_Context* ctx,
839                                const char* serialized_function_def, size_t size,
840                                TF_Status* status) {
841   tensorflow::FunctionDef function_def;
842   if (!function_def.ParseFromArray(serialized_function_def, size)) {
843     status->status =
844         tensorflow::errors::InvalidArgument("Invalid FunctionDef proto");
845     return;
846   }
847   status->status = ctx->context.AddFunctionDef(function_def);
848 }
849 
TFE_ContextAddFunction(TFE_Context * ctx,TF_Function * function,TF_Status * status)850 void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function,
851                             TF_Status* status) {
852   status->status = ctx->context.AddFunctionDef(function->fdef);
853 }
854 
TFE_ContextHasFunction(TFE_Context * ctx,const char * name)855 unsigned char TFE_ContextHasFunction(TFE_Context* ctx, const char* name) {
856   return ctx->context.FindFunctionDef(name) != nullptr;
857 }
858 
TFE_ContextEnableRunMetadata(TFE_Context * ctx)859 void TFE_ContextEnableRunMetadata(TFE_Context* ctx) {
860   ctx->context.SetShouldStoreGraphs(true);
861   ctx->context.SetShouldStoreStepStats(true);
862 }
863 
TFE_ContextDisableRunMetadata(TFE_Context * ctx)864 void TFE_ContextDisableRunMetadata(TFE_Context* ctx) {
865   ctx->context.SetShouldStoreGraphs(false);
866   ctx->context.SetShouldStoreStepStats(false);
867 }
868 
869 }  // extern "C"
870 
TFE_NewTensorHandle(const tensorflow::Tensor & t)871 TFE_TensorHandle* TFE_NewTensorHandle(const tensorflow::Tensor& t) {
872   return new TFE_TensorHandle(t, nullptr, nullptr);
873 }
874 
TFE_TensorHandleUnderlyingTensorInHostMemory(TFE_TensorHandle * h,TF_Status * status)875 const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory(
876     TFE_TensorHandle* h, TF_Status* status) {
877   if (!h->handle->OnHostCPU()) {
878     status->status = tensorflow::errors::FailedPrecondition(
879         "TFE_TensorHandle is placed in device (not host) memory. Cannot return "
880         "a tensorflow::Tensor");
881     return nullptr;
882   }
883   tensorflow::Device* d = nullptr;
884   tensorflow::Device* op_device = nullptr;
885   const tensorflow::Tensor* t = nullptr;
886   status->status = h->handle->TensorAndDevice(&t, &d, &op_device);
887   if (!status->status.ok()) return nullptr;
888   return t;
889 }
890 
TFE_ContextExportRunMetadata(TFE_Context * ctx,TF_Buffer * buf,TF_Status * status)891 void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
892                                   TF_Status* status) {
893   TFE_ContextAsyncWait(ctx, status);
894   if (!status->status.ok()) return;
895   tensorflow::mutex_lock ml(*ctx->context.MetadataMu());
896   status->status = MessageToBuffer(*ctx->context.RunMetadataProto(), buf);
897   ctx->context.ClearRunMetadata();
898 }
899 
900 namespace {
GetFunc(TFE_Context * ctx,const tensorflow::NameAttrList & func,TF_Status * status)901 TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
902                 TF_Status* status) {
903   TFE_Op* func_op = TFE_NewOp(ctx, func.name().data(), status);
904   for (const auto& attr : func.attr()) {
905     if (TF_GetCode(status) != TF_OK) return nullptr;
906     SetOpAttrValueScalar(ctx, func_op, attr.second, attr.first.data(), status);
907     if (TF_GetCode(status) != TF_OK) return nullptr;
908   }
909   return func_op;
910 }
911 }  // namespace
912 
TFE_ContextStartStep(TFE_Context * ctx)913 void TFE_ContextStartStep(TFE_Context* ctx) { ctx->context.StartStep(); }
914 
TFE_ContextEndStep(TFE_Context * ctx)915 void TFE_ContextEndStep(TFE_Context* ctx) { ctx->context.EndStep(); }
916 
917 namespace tensorflow {
SetOpAttrValueScalar(TFE_Context * ctx,TFE_Op * op,const tensorflow::AttrValue & default_value,const char * attr_name,TF_Status * status)918 void SetOpAttrValueScalar(TFE_Context* ctx, TFE_Op* op,
919                           const tensorflow::AttrValue& default_value,
920                           const char* attr_name, TF_Status* status) {
921   switch (default_value.value_case()) {
922     case tensorflow::AttrValue::kS: {
923       const string& v = default_value.s();
924       TFE_OpSetAttrString(op, attr_name, v.data(), v.size());
925       break;
926     }
927     case tensorflow::AttrValue::kI:
928       TFE_OpSetAttrInt(op, attr_name, static_cast<int64_t>(default_value.i()));
929       break;
930     case tensorflow::AttrValue::kF:
931       TFE_OpSetAttrFloat(op, attr_name, default_value.f());
932       break;
933     case tensorflow::AttrValue::kB:
934       TFE_OpSetAttrBool(op, attr_name, default_value.b());
935       break;
936     case tensorflow::AttrValue::kType:
937       TFE_OpSetAttrType(op, attr_name,
938                         static_cast<TF_DataType>(default_value.type()));
939       break;
940     case tensorflow::AttrValue::kShape: {
941       const auto& tensor_shape = default_value.shape();
942       if (tensor_shape.unknown_rank()) {
943         TFE_OpSetAttrShape(op, attr_name, nullptr, -1, status);
944       } else {
945         const auto num_dims = tensor_shape.dim_size();
946         std::unique_ptr<int64_t[]> dims(new int64_t[num_dims]);
947         for (int i = 0; i < num_dims; ++i) {
948           dims[i] = tensor_shape.dim(i).size();
949         }
950         TFE_OpSetAttrShape(op, attr_name, dims.get(), num_dims, status);
951       }
952     } break;
953     case tensorflow::AttrValue::kFunc: {
954       const auto func_op = GetFunc(ctx, default_value.func(), status);
955       if (TF_GetCode(status) != TF_OK) return;
956       // TODO(nareshmodi): TFE_OpSetAttrFunction and TFE_OpSetAttrFunctionList
957       // require TFE_Op* and just convert it internally a NameAttrValue, so
958       // consider adding an overload to the C API to make this case easier.
959       TFE_OpSetAttrFunction(op, attr_name, func_op);
960     } break;
961     case tensorflow::AttrValue::kList:
962       TF_FALLTHROUGH_INTENDED;
963     case tensorflow::AttrValue::kTensor:
964       TF_FALLTHROUGH_INTENDED;
965     case tensorflow::AttrValue::kPlaceholder:
966       TF_FALLTHROUGH_INTENDED;
967     case tensorflow::AttrValue::VALUE_NOT_SET:
968       TF_SetStatus(
969           status, TF_UNIMPLEMENTED,
970           tensorflow::strings::StrCat("Unable to get setfor default value: ",
971                                       default_value.DebugString())
972               .data());
973   }
974 }
975 }  // namespace tensorflow
976