/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/graph_compiler.h"

#include <deque>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {
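// Builds the XlaCompiler::Argument descriptors for a function call from the
// input expressions: compile-time constants are resolved to literal values,
// while resources and tensor lists are described by kind and shape.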
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
                        const std::vector<const XlaExpression*>& expressions,
                        const NameAttrList& func,
                        std::vector<XlaCompiler::Argument>* args) {
  auto client = ctx->compiler()->client();
  std::vector<bool> arg_must_be_compile_time_constant(expressions.size());

  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
      *graph, &arg_must_be_compile_time_constant,
      /*compile_time_const_nodes=*/nullptr, ctx->function_library()));

  args->resize(expressions.size());
  for (int i = 0, end = args->size(); i < end; ++i) {
    XlaCompiler::Argument& arg = (*args)[i];
    arg.type = ctx->input_type(i);
    arg.shape = ctx->InputShape(i);

    switch (expressions[i]->kind()) {
      case XlaExpression::Kind::kConstant:
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = *expressions[i]->constant_value();
        break;
      case XlaExpression::Kind::kXlaOp:
        if (arg_must_be_compile_time_constant[i]) {
          TF_ASSIGN_OR_RETURN(absl::optional<Tensor> value,
                              expressions[i]->ResolveConstant(client));
          if (!value.has_value()) {
            return errors::InvalidArgument(absl::StrCat(
                "Argument ", i, " to function '", func.name(),
                "' must be a compile-time constant, but ",
                "unable to resolve argument value to a constant."));
          }
          arg.kind = XlaCompiler::Argument::kConstant;
          arg.constant_value = *value;
        } else {
          arg.kind = XlaCompiler::Argument::kParameter;
        }
        break;
      case XlaExpression::Kind::kResource: {
        XlaResource* resource = expressions[i]->resource();
        XlaCompiler::PopulateArgumentFromResource(*resource, &arg);
        break;
      }
      case XlaExpression::Kind::kTensorList: {
        arg.kind = XlaCompiler::Argument::kTensorList;
        const xla::XlaOp& tensor_list = expressions[i]->handle();
        arg.shape = tensor_list.builder()->GetShape(tensor_list).ValueOrDie();
        break;
      }
      case XlaExpression::Kind::kInvalid:
        return errors::InvalidArgument("Invalid function argument");
    }
  }
  return Status::OK();
}
}  // namespace
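
// Compiles the graph one node at a time in a deterministic reverse post-order,
// symbolically executing each kernel and recording its outputs so they can be
// wired into the inputs of downstream nodes.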
Status GraphCompiler::Compile() {
  // Check that the graph has no illegal cycles.
  TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_));
  // Maintain a mapping from node id to node outputs.
  using NodeOutputs = std::vector<TensorValue>;
  std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
  auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] {
    for (const NodeOutputs& outputs : output_registry) {
      for (const TensorValue& value : outputs) {
        CHECK(!value.is_ref());
        delete value.tensor;
      }
    }
  });

  // XLA requires determinism; generate a stable ordering with a reverse
  // post-order DFS.
  std::vector<Node*> topo_sorted_nodes;
  GetReversePostOrder(*graph_, &topo_sorted_nodes,
                      /*stable_comparator=*/NodeComparatorName());

  OpKernelContext::Params params;
  PartiallySetupParams(&params);

  for (Node* n : topo_sorted_nodes) {
    OpKernel* op_kernel_raw = nullptr;
    // The kernel is not actually run for functional ops; we just need it
    // for metadata.
    Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw);
    // Transfer ownership of the kernel to a local smart pointer.
    std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);

    if (!s.ok()) {
      s = AttachDef(s, *n);
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      return s;
    }

    TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
        << "Unsupported node: " << n->DebugString();
    params.op_kernel = op_kernel.get();
    absl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
    params.output_attr_array = output_attr.data();

    // tensor_inputs_ is a buffer reused across the graph traversal. Clear and
    // reinitialize the buffer before visiting a new node.
    tensor_inputs_.clear();
    tensor_inputs_.resize(n->num_inputs());

    // Set up inputs from outputs of previous nodes.
    for (auto* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      const Node* src = e->src();
      const int output_registry_size = output_registry.size();
      TF_RET_CHECK(src->id() < output_registry_size);
      const NodeOutputs& src_outputs = output_registry[src->id()];

      tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
    }

    OpKernelContext op_context(&params, n->num_outputs());
    VLOG(3) << "Translating " << params.op_kernel->name();
    if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) {
      TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
    } else {
      device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
      Status s = op_context.status();
      if (!s.ok()) {
        return AttachDef(s, n->def());
      }
    }

    // Set up outputs. Also check that the outputs from the previous
    // computation are valid.
    NodeOutputs& outputs = output_registry[n->id()];
    outputs.resize(n->num_outputs());
    for (int o = 0; o < n->num_outputs(); ++o) {
      outputs[o] = op_context.release_output(o);
      if (outputs[o].tensor == nullptr) {
        return errors::Internal("Missing xla_context ", o, "-th output from ",
                                FormatNodeForError(*n));
      }
    }
  }
  return Status::OK();
}

namespace {

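// Extracts the name and attributes of the function called by `node`: for
// PartitionedCall ops the callee is stored in the 'f' attribute; otherwise
// the node's op type names the function directly, or a gradient of one.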
Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib,
                              const Node& node, NameAttrList* func) {
  if (node.IsPartitionedCall()) {
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(
        node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
    if (!attr_value->has_func()) {
      return errors::InvalidArgument(
          "The attribute value for attribute 'f' in node ", node.DebugString(),
          " does not have 'func' field set");
    }
    *func = attr_value->func();
    return Status::OK();
  }

  if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) {
    func->set_name(node.type_string());
  } else {
    func->set_name(FunctionLibraryDefinition::kGradientOp);
  }
  *func->mutable_attr() = node.def().attr();
  return Status::OK();
}

}  // namespace

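// Compiles a function-call node: the callee is compiled into its own XLA
// computation, which is invoked with xla::Call; the tuple result is then
// unpacked into the node's outputs and resource updates.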
Status GraphCompiler::CompileFunctionalNode(Node* n,
                                            OpKernelContext* op_context) {
  TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n));
  // For functional nodes, compile the callee using the compiler from the
  // context and call into the compiled function.
  XlaOpKernelContext xla_op_context(op_context);

  XlaContext& context = XlaContext::Get(op_context);
  auto* b = context.builder();

  XlaCompiler* compiler = xla_op_context.compiler();

  NameAttrList func;
  TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func));

  std::vector<const XlaExpression*> expressions;

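  // Each input tensor produced by an XLA op kernel stores an XlaExpression in
  // its data buffer; recover it by reinterpreting the raw tensor bytes.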
  for (auto tensor : tensor_inputs_) {
    auto expression =
        reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
    expressions.push_back(expression);
  }

  // Prepare the arguments and compile the function.
  std::vector<XlaCompiler::Argument> arguments;
  const FunctionBody* fbody;
  TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody));

  auto graph = compiler->GetGraph(fbody);

  TF_RETURN_IF_ERROR(PrepareArguments(&xla_op_context, graph.get(), expressions,
                                      func, &arguments));

  bool add_token_input_output =
      func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end();

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = false;
  compile_options.add_token_input_output = add_token_input_output;
  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(
      compiler->CompileFunction(compile_options, func, arguments, &result));

  TF_RET_CHECK(arguments.size() == expressions.size());

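  // Gather the inputs to pass to xla::Call. Compile-time constants were baked
  // into the compiled computation and are skipped; resources contribute their
  // current value.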
  std::vector<xla::XlaOp> handles;
  for (int64 i = 0, end = expressions.size(); i < end; ++i) {
    if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
      continue;
    }
    if (arguments[i].kind == XlaCompiler::Argument::kResource) {
      handles.push_back(expressions[i]->resource()->value());
    } else {
      handles.push_back(expressions[i]->handle());
    }
  }
  if (add_token_input_output) {
    std::vector<string> token_input_nodes;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
                                   kXlaTokenInputNodesAttrName,
                                   &token_input_nodes));
    std::vector<xla::XlaOp> token_inputs;
    for (const string& node_name : token_input_nodes) {
      auto token_or = compiler->GetNodeToken(node_name);
      TF_RETURN_IF_ERROR(token_or.status());
      token_inputs.push_back(token_or.ConsumeValueOrDie());
    }
    xla::XlaOp token_input = xla::AfterAll(b, token_inputs);
    handles.push_back(token_input);
  }

  auto output_handle = xla::Call(b, *result.computation, handles);
  // The output handle of the `Call` computation is a tuple. Unpack it so the
  // elements can feed into downstream computations.
  int computation_output = 0;
  for (int64 i = 0; i < n->num_outputs(); ++i) {
    if (result.outputs[i].is_constant) {
      xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
    } else {
      if (result.outputs[i].is_tensor_list) {
        xla_op_context.SetTensorListOutput(
            i, xla::GetTupleElement(output_handle, computation_output));
      } else {
        xla_op_context.SetOutput(
            i, xla::GetTupleElement(output_handle, computation_output));
      }
      ++computation_output;
    }
  }

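  // Apply resource updates: modified resource values appear in the result
  // tuple after the regular outputs.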
  for (int64 i = 0, end = result.resource_updates.size(); i < end; i++) {
    if (result.resource_updates[i].modified) {
      XlaResource* resource =
          expressions[result.resource_updates[i].input_index]->resource();
      xla::XlaOp updated_value =
          xla::GetTupleElement(output_handle, i + n->num_outputs());
      TF_RETURN_IF_ERROR(resource->SetValue(updated_value));
    }
  }

  if (add_token_input_output) {
    TF_RETURN_IF_ERROR(compiler->SetNodeToken(
        n->name(), xla::GetTupleElement(output_handle, computation_output)));
  }
  return b->first_error();
}
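// Fills in the OpKernelContext::Params fields that are common to all nodes;
// per-node fields (op_kernel, output_attr_array) are set during traversal.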
void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) {
  params->device = device_;
  params->inputs = &tensor_inputs_;
  params->step_container = step_container_;
  params->resource_manager = device_->resource_manager();
  params->function_library = flib_;
}

}  // namespace tensorflow