/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/function_ops.h"

#include <deque>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gradients.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

static const char* const kGradientOp = FunctionLibraryDefinition::kGradientOp;

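// An _Arg kernel: fetches the function argument at `index` from the
// enclosing call frame and forwards it as this kernel's single output.
//
// Illustrative only (not from this file): in a FunctionDef body, each
// parameter is materialized as an _Arg node, e.g. _Arg(T=DT_FLOAT, index=0),
// and each result as a matching _Retval node.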
ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void ArgOp::Compute(OpKernelContext* ctx) {
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  const Tensor* val;

  auto validate_type = [this](const Tensor& val) {
    if (val.dtype() == dtype_) {
      return Status::OK();
    } else {
      return errors::InvalidArgument("Type mismatch: actual ",
                                     DataTypeString(val.dtype()),
                                     " vs. expected ", DataTypeString(dtype_));
    }
  };

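  // When the frame allows it, consume the argument so the tensor's buffer
  // moves into the output without a copy; otherwise borrow a reference from
  // the frame.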
  if (frame->CanConsumeArg(index_)) {
    Tensor val;
    frame->ConsumeArg(index_, &val);
    OP_REQUIRES_OK(ctx, validate_type(val));
    ctx->set_output(0, std::move(val));
  } else {
    OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
    OP_REQUIRES_OK(ctx, validate_type(*val));
    ctx->set_output(0, *val);
  }
}

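// A _Retval kernel: stores its single input tensor into slot `index` of the
// enclosing call frame.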
RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void RetvalOp::Compute(OpKernelContext* ctx) {
  const Tensor& val = ctx->input(0);
  OP_REQUIRES(ctx, val.dtype() == dtype_,
              errors::InvalidArgument("Type mismatch: actual ",
                                      DataTypeString(val.dtype()),
                                      " vs. expected ", DataTypeString(dtype_)));
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
}

REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);

// TPU ops are only registered when they are required as part of the larger
// TPU runtime, and do not need to be registered when selective registration
// is turned on.
REGISTER_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_TPU_SYSTEM), RetvalOp);

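// On GPU, int32, string, and resource-handle tensors live in host memory, so
// the _Arg registrations below pin those outputs to HostMemory; _DeviceArg
// keeps int32 on the device.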
#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("output")
                                                   .TypeConstraint<int32>("T"),
                                               ArgOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceArgOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), ArgOp);
#undef REGISTER

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<ResourceHandle>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<tstring>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(
    Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<Variant>("T"), ArgOp);

#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kRetOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
REGISTER(Variant)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kRetOp)
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("input")
                                                   .TypeConstraint<int32>("T"),
                                               RetvalOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceRetOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<ResourceHandle>("T")
                            .HostMemory("input"),
                        RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<tstring>("T")
                            .HostMemory("input"),
                        RetvalOp);
#undef REGISTER

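// Forwards each input tensor unchanged to the corresponding output. Used to
// implement _ListToArray and _ArrayToList, which only reinterpret how a set
// of tensors is grouped.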
class PassOn : public OpKernel {
 public:
  explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
                errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
                                 " vs. ", ctx->num_outputs()));
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      OP_REQUIRES(
          ctx, input_type(i) == output_type(i),
          errors::Internal("Input and output types for position ", i,
                           " do not match: ", DataTypeString(input_type(i)),
                           " vs. ", DataTypeString(output_type(i))));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      ctx->set_output(i, ctx->input(i));
    }
  }
};

REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);

#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ListToArray").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);                                                           \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ArrayToList").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);

#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);

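// Instantiates the symbolic gradient of the function described by this
// node's attributes (via the SymbolicGradient op) through the function
// library runtime, then runs it asynchronously on this kernel's inputs.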
class SymbolicGradientOp : public AsyncOpKernel {
 public:
  explicit SymbolicGradientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}

  ~SymbolicGradientOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);

    FunctionLibraryRuntime::Handle handle;
    OP_REQUIRES_OK_ASYNC(
        ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done);

    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.collective_executor = ctx->collective_executor();
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    opts.stats_collector = ctx->stats_collector();
    opts.step_container = ctx->step_container();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    profiler::TraceMe trace_me("SymbolicGradientOp");
    lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) {
      if (!status.ok()) {
        ctx->SetStatus(status);
      } else if (rets->size() != ctx->num_outputs()) {
        ctx->SetStatus(errors::InvalidArgument(
            "SymGrad expects to return ", ctx->num_outputs(),
            " tensor(s), but got ", rets->size(), " tensor(s) instead."));
      } else {
        for (size_t i = 0; i < rets->size(); ++i) {
          ctx->set_output(i, std::move((*rets)[i]));
        }
      }
      delete rets;
      done();
    });
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp);
};

REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
                        SymbolicGradientOp);
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_GPU),
                        SymbolicGradientOp);

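// Runs a function on a (possibly remote) target device. The device name is
// read from the `target` input tensor, canonicalized against the caller's
// device, and instantiated handles are cached per (target device, FLR) pair.
//
// Illustrative use only (device names assumed, not from this file):
//   tf.raw_ops.RemoteCall(target="/job:worker/replica:0/task:1/cpu:0",
//                         args=[x], Tout=[tf.float32], f=some_func)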
RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
}

void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  FunctionLibraryRuntime* lib = ctx->function_library();
  OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                    errors::Internal("No function library is provided."), done);

  const string& source_device = lib->device()->name();
  const Tensor* target;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);

  FunctionTarget function_target;
  OP_REQUIRES_OK_ASYNC(
      ctx,
      DeviceNameUtils::CanonicalizeDeviceName(
          target->scalar<tstring>()(), source_device, &function_target.first),
      done);
  function_target.second = lib;

  const string& target_device = function_target.first;
  const string& func_name = func_.name();

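  // Reuse a cached handle if this (target device, FLR) pair has already
  // instantiated the function; otherwise instantiate it and cache the result.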
  FunctionLibraryRuntime::Handle handle;
  {
    mutex_lock l(mu_);
    auto cached_entry = handle_cache_.find(function_target);
    if (cached_entry != handle_cache_.end()) {
      handle = cached_entry->second;
    } else {
      VLOG(1) << "Instantiating " << func_name << " on " << target_device;
      profiler::TraceMe activity(
          [&] {
            return strings::StrCat("RemoteCall: Instantiate: ", func_name,
                                   " on ", target_device);
          },
          profiler::TraceMeLevel::kInfo);
      FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
      const auto* config = (ctx->function_library())
                               ? ctx->function_library()->config_proto()
                               : nullptr;
      if (config) {
        instantiate_opts.config_proto = *config;
      }
      instantiate_opts.target = target_device;
      OP_REQUIRES_OK_ASYNC(ctx,
                           lib->Instantiate(func_name, AttrSlice(&func_.attr()),
                                            instantiate_opts, &handle),
                           done);
      auto insert_result = handle_cache_.insert({function_target, handle});
      CHECK(insert_result.second) << "Insert unsuccessful.";
      VLOG(1) << "Instantiated " << func_name << " on " << target_device
              << ", resulting in handle: " << handle << " flr: " << lib;
    }
  }

  OpInputList arguments;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);

  FunctionLibraryRuntime::Options opts;
  opts.runner = nullptr;  // Use default runner at remote device.
  opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
  opts.source_device = source_device;
  if (opts.source_device != target_device) {
    opts.remote_execution = true;
  }
  opts.create_rendezvous = true;
  std::vector<Tensor> args(arguments.begin(), arguments.end());
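  // Types that always live in host memory (e.g. int32 and string) need their
  // argument and return buffers allocated on the host; record that here.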
  opts.args_alloc_attrs.reserve(input_dtypes_.size());
  for (const auto& dtype : input_dtypes_) {
    AllocatorAttributes arg_alloc_attrs;
    arg_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
    opts.args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  opts.rets_alloc_attrs.reserve(output_dtypes_.size());
  for (const auto& dtype : output_dtypes_) {
    AllocatorAttributes ret_alloc_attrs;
    ret_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
    opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
  }
  auto* rets = new std::vector<Tensor>;
  VLOG(1) << "Running " << func_name << " on " << target_device
          << " with handle: " << handle;
  profiler::TraceMe trace_me(
      [&] {
        return absl::StrCat("RemoteCallOp#func_name=", func_name,
                            ",device=", target_device, "#");
      },
      profiler::TraceMeLevel::kInfo);
  lib->Run(
      opts, handle, args, rets,
      [rets, done = std::move(done), func_name, ctx,
       function_step_id = opts.step_id,
       target_device = std::move(function_target.first)](const Status& status) {
        profiler::TraceMe activity(
            [&] {
              return absl::StrCat("RemoteCallOpDone#func_name=", func_name,
                                  ",device=", target_device, "#");
            },
            profiler::TraceMeLevel::kInfo);
        if (!status.ok()) {
          ctx->SetStatus(status);
        } else {
          for (size_t i = 0; i < rets->size(); ++i) {
            ctx->set_output(i, std::move((*rets)[i]));
          }
        }
        delete rets;
        done();
      });
}

string RemoteCallOp::TraceString(const OpKernelContext& ctx,
                                 bool verbose) const {
  string trace_string = profiler::TraceMeOp(
      strings::StrCat(name_view(), "__", func_.name()), type_string_view());
  if (verbose) {
    string shape = ShapeTraceString(ctx);
    if (!shape.empty()) {
      trace_string =
          profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
    }
  }
  return trace_string;
}

REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
}  // namespace tensorflow