/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/function_ops.h"

#include <deque>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/gradients.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/memory_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {

static const char* const kGradientOp = FunctionLibraryDefinition::kGradientOp;

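// ArgOp fetches the index_-th argument of the enclosing function call from
// the call frame and forwards it as this kernel's single output, consuming
// the argument in place when the frame allows it.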
ArgOp::ArgOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void ArgOp::Compute(OpKernelContext* ctx) {
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  const Tensor* val;

  auto validate_type = [this](const Tensor& val) {
    if (val.dtype() == dtype_) {
      return Status::OK();
    } else {
      return errors::InvalidArgument("Type mismatch: actual ",
                                     DataTypeString(val.dtype()),
                                     " vs. expect ", DataTypeString(dtype_));
    }
  };

  if (frame->CanConsumeArg(index_)) {
    Tensor val;
    frame->ConsumeArg(index_, &val);
    OP_REQUIRES_OK(ctx, validate_type(val));
    ctx->set_output(0, std::move(val));
  } else {
    OP_REQUIRES_OK(ctx, frame->GetArg(index_, &val));
    OP_REQUIRES_OK(ctx, validate_type(*val));
    ctx->set_output(0, *val);
  }
}

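// RetvalOp records its single input as the index_-th return value of the
// enclosing function call by storing it into the call frame.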
RetvalOp::RetvalOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("index", &index_));
}

void RetvalOp::Compute(OpKernelContext* ctx) {
  const Tensor& val = ctx->input(0);
  OP_REQUIRES(ctx, val.dtype() == dtype_,
              errors::InvalidArgument("Type mismatch: actual ",
                                      DataTypeString(val.dtype()),
                                      " vs. expect ", DataTypeString(dtype_)));
  auto frame = ctx->call_frame();
  OP_REQUIRES(ctx, frame != nullptr, errors::Internal("no call frame"));
  OP_REQUIRES_OK(ctx, frame->SetRetval(index_, val));
}

REGISTER_SYSTEM_KERNEL_BUILDER(Name(kArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceArgOp).Device(DEVICE_CPU), ArgOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_CPU), RetvalOp);
REGISTER_SYSTEM_KERNEL_BUILDER(Name(kDeviceRetOp).Device(DEVICE_CPU), RetvalOp);

// TPU ops are only registered when they are required as part of the larger
// TPU runtime, and do not need to be registered when selective registration
// is turned on.
REGISTER_KERNEL_BUILDER(Name(kRetOp).Device(DEVICE_TPU_SYSTEM), RetvalOp);

#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), ArgOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kArgOp)
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("output")
                                                   .TypeConstraint<int32>("T"),
                                               ArgOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceArgOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), ArgOp);
#undef REGISTER

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<ResourceHandle>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(Name(kArgOp)
                            .Device(DEVICE_GPU)
                            .HostMemory("output")
                            .TypeConstraint<tstring>("T"),
                        ArgOp);

REGISTER_KERNEL_BUILDER(
    Name(kArgOp).Device(DEVICE_GPU).TypeConstraint<Variant>("T"), ArgOp);

#define REGISTER(type)     \
  REGISTER_KERNEL_BUILDER( \
      Name(kRetOp).Device(DEVICE_GPU).TypeConstraint<type>("T"), RetvalOp);
TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER)
TF_CALL_QUANTIZED_TYPES(REGISTER)
REGISTER(Variant)
TF_CALL_bool(REGISTER) REGISTER_KERNEL_BUILDER(Name(kRetOp)
                                                   .Device(DEVICE_GPU)
                                                   .HostMemory("input")
                                                   .TypeConstraint<int32>("T"),
                                               RetvalOp);
REGISTER_KERNEL_BUILDER(
    Name(kDeviceRetOp).Device(DEVICE_GPU).TypeConstraint<int32>("T"), RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<ResourceHandle>("T")
                            .HostMemory("input"),
                        RetvalOp);

REGISTER_KERNEL_BUILDER(Name(kRetOp)
                            .Device(DEVICE_GPU)
                            .TypeConstraint<tstring>("T")
                            .HostMemory("input"),
                        RetvalOp);
#undef REGISTER

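// PassOn forwards each input tensor unchanged to the corresponding output.
// It backs the internal _ListToArray and _ArrayToList ops registered below.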
class PassOn : public OpKernel {
 public:
  explicit PassOn(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES(ctx, ctx->num_inputs() == ctx->num_outputs(),
                errors::Internal("#inputs != #outputs : ", ctx->num_inputs(),
                                 " vs. ", ctx->num_outputs()));
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      OP_REQUIRES(
          ctx, input_type(i) == output_type(i),
          errors::Internal("Input and output types for position ", i,
                           " do not match: ", DataTypeString(input_type(i)),
                           " vs. ", DataTypeString(output_type(i))));
    }
  }

  void Compute(OpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      ctx->set_output(i, ctx->input(i));
    }
  }
};

REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ListToArray").Device(DEVICE_CPU), PassOn);
REGISTER_SYSTEM_KERNEL_BUILDER(Name("_ArrayToList").Device(DEVICE_CPU), PassOn);

#define REGISTER_GPU_KERNELS(type)                                       \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ListToArray").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);                                                           \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("_ArrayToList").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      PassOn);

REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);

#undef REGISTER_GPU_KERNELS

REGISTER_KERNEL_BUILDER(Name("_ListToArray")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);
REGISTER_KERNEL_BUILDER(Name("_ArrayToList")
                            .Device(DEVICE_GPU)
                            .HostMemory("input")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        PassOn);

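// SymbolicGradientOp instantiates the gradient function described by this
// node's attributes via the FunctionLibraryRuntime and runs it
// asynchronously, passing the kernel's inputs as arguments and emitting the
// function's results as outputs.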
class SymbolicGradientOp : public AsyncOpKernel {
 public:
  explicit SymbolicGradientOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {}

  ~SymbolicGradientOp() override {}

  void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
    FunctionLibraryRuntime* lib = ctx->function_library();
    OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                      errors::Internal("No function library is provided."),
                      done);

    FunctionLibraryRuntime::Handle handle;
    OP_REQUIRES_OK_ASYNC(
        ctx, lib->Instantiate(kGradientOp, AttrSlice(def()), &handle), done);

    FunctionLibraryRuntime::Options opts;
    opts.rendezvous = ctx->rendezvous();
    opts.cancellation_manager = ctx->cancellation_manager();
    opts.collective_executor = ctx->collective_executor();
    opts.runner = ctx->runner();
    opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
    opts.stats_collector = ctx->stats_collector();
    opts.step_container = ctx->step_container();
    std::vector<Tensor> args;
    args.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      args.push_back(ctx->input(i));
    }
    std::vector<Tensor>* rets = new std::vector<Tensor>;
    profiler::TraceMe trace_me("SymbolicGradientOp");
    lib->Run(opts, handle, args, rets, [ctx, done, rets](const Status& status) {
      if (!status.ok()) {
        ctx->SetStatus(status);
      } else if (rets->size() != ctx->num_outputs()) {
        ctx->SetStatus(errors::InvalidArgument(
            "SymGrad expects to return ", ctx->num_outputs(),
            " tensor(s), but get ", rets->size(), " tensor(s) instead."));
      } else {
        for (size_t i = 0; i < rets->size(); ++i) {
          ctx->set_output(i, std::move((*rets)[i]));
        }
      }
      delete rets;
      done();
    });
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientOp);
};

REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_CPU),
                        SymbolicGradientOp);
REGISTER_KERNEL_BUILDER(Name(kGradientOp).Device(DEVICE_GPU),
                        SymbolicGradientOp);

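// RemoteCallOp runs the function named in its function attribute on the
// device given by the "target" input, caching one instantiation handle per
// (target device, function library) pair.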
RemoteCallOp::RemoteCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
  OP_REQUIRES_OK(ctx,
                 ctx->GetAttr(FunctionLibraryDefinition::kFuncAttr, &func_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_dtypes_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_dtypes_));
}

void RemoteCallOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
  FunctionLibraryRuntime* lib = ctx->function_library();
  OP_REQUIRES_ASYNC(ctx, lib != nullptr,
                    errors::Internal("No function library is provided."), done);

  const string& source_device = lib->device()->name();
  const Tensor* target;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input("target", &target), done);

  FunctionTarget function_target;
  OP_REQUIRES_OK_ASYNC(
      ctx,
      DeviceNameUtils::CanonicalizeDeviceName(
          target->scalar<tstring>()(), source_device, &function_target.first),
      done);
  function_target.second = lib;

  const string& target_device = function_target.first;
  const string& func_name = func_.name();

  FunctionLibraryRuntime::Handle handle;
  {
    mutex_lock l(mu_);
    auto cached_entry = handle_cache_.find(function_target);
    if (cached_entry != handle_cache_.end()) {
      handle = cached_entry->second;
    } else {
      VLOG(1) << "Instantiating " << func_name << " on " << target_device;
      profiler::TraceMe activity(
          [&] {
            return strings::StrCat("RemoteCall: Instantiate: ", func_name,
                                   " on ", target_device);
          },
          profiler::TraceMeLevel::kInfo);
      FunctionLibraryRuntime::InstantiateOptions instantiate_opts;
      const auto* config = (ctx->function_library())
                               ? ctx->function_library()->config_proto()
                               : nullptr;
      if (config) {
        instantiate_opts.config_proto = *config;
      }
      instantiate_opts.target = target_device;
      OP_REQUIRES_OK_ASYNC(ctx,
                           lib->Instantiate(func_name, AttrSlice(&func_.attr()),
                                            instantiate_opts, &handle),
                           done);
      auto insert_result = handle_cache_.insert({function_target, handle});
      CHECK(insert_result.second) << "Insert unsuccessful.";
      VLOG(1) << "Instantiated " << func_name << " on " << target_device
              << ", resulting in handle: " << handle << " flr: " << lib;
    }
  }

  OpInputList arguments;
  OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);

  FunctionLibraryRuntime::Options opts;
  opts.runner = nullptr;  // Use default runner at remote device.
  opts.run_all_kernels_inline = ctx->run_all_kernels_inline();
  opts.source_device = source_device;
  if (opts.source_device != target_device) {
    opts.remote_execution = true;
  }
  opts.create_rendezvous = true;
  std::vector<Tensor> args(arguments.begin(), arguments.end());
  opts.args_alloc_attrs.reserve(input_dtypes_.size());
  for (const auto& dtype : input_dtypes_) {
    AllocatorAttributes arg_alloc_attrs;
    arg_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
    opts.args_alloc_attrs.push_back(arg_alloc_attrs);
  }
  opts.rets_alloc_attrs.reserve(output_dtypes_.size());
  for (const auto& dtype : output_dtypes_) {
    AllocatorAttributes ret_alloc_attrs;
    ret_alloc_attrs.set_on_host(DataTypeAlwaysOnHost(dtype));
    opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
  }
  auto* rets = new std::vector<Tensor>;
  VLOG(1) << "Running " << func_name << " on " << target_device
          << " with handle: " << handle;
  profiler::TraceMe trace_me(
      [&] {
        return absl::StrCat("RemoteCallOp#func_name=", func_name,
                            ",device=", target_device, "#");
      },
      profiler::TraceMeLevel::kInfo);
  lib->Run(
      opts, handle, args, rets,
      [rets, done = std::move(done), func_name, ctx,
       function_step_id = opts.step_id,
       target_device = std::move(function_target.first)](const Status& status) {
        profiler::TraceMe activity(
            [&] {
              return absl::StrCat("RemoteCallOpDone#func_name=", func_name,
                                  ",device=", target_device, "#");
            },
            profiler::TraceMeLevel::kInfo);
        if (!status.ok()) {
          ctx->SetStatus(status);
        } else {
          for (size_t i = 0; i < rets->size(); ++i) {
            ctx->set_output(i, std::move((*rets)[i]));
          }
        }
        delete rets;
        done();
      });
}

string RemoteCallOp::TraceString(const OpKernelContext& ctx,
                                 bool verbose) const {
  string trace_string = profiler::TraceMeOp(
      strings::StrCat(name_view(), "__", func_.name()), type_string_view());
  if (verbose) {
    string shape = ShapeTraceString(ctx);
    if (!shape.empty()) {
      trace_string =
          profiler::TraceMeEncode(std::move(trace_string), {{"shape", shape}});
    }
  }
  return trace_string;
}

REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_CPU).HostMemory("target"), RemoteCallOp);
REGISTER_KERNEL_BUILDER(
    Name("RemoteCall").Device(DEVICE_GPU).HostMemory("target"), RemoteCallOp);
}  // namespace tensorflow