/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/graph_compiler.h"

#include <deque>
#include <numeric>
#include <vector>

#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/executor.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/graph_optimizer.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/validate.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/dump_graph.h"

namespace tensorflow {

namespace {
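// Builds the XlaCompiler::Argument vector for a call to the function `func`
// from the XlaExpressions computed for its inputs. Inputs whose values the
// function body needs at compile time are resolved to constants; resource
// and tensor-list inputs are forwarded with their metadata.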
Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
                        const std::vector<const XlaExpression*>& expressions,
                        const NameAttrList& func,
                        std::vector<XlaCompiler::Argument>* args) {
  auto client = ctx->compiler()->client();
  std::vector<bool> arg_must_be_compile_time_constant(expressions.size());

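  // Compute which arguments the function body requires to be compile-time
  // constants, e.g. inputs that flow into shape operands.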
  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(
      *graph, &arg_must_be_compile_time_constant,
      /*compile_time_const_nodes=*/nullptr, ctx->function_library()));

  args->resize(expressions.size());
  for (int i = 0, end = args->size(); i < end; ++i) {
    XlaCompiler::Argument& arg = (*args)[i];
    arg.type = ctx->input_type(i);
    arg.shape = ctx->InputShape(i);

    switch (expressions[i]->kind()) {
      case XlaExpression::Kind::kConstant:
        arg.kind = XlaCompiler::Argument::kConstant;
        arg.constant_value = *expressions[i]->constant_value();
        break;
      case XlaExpression::Kind::kXlaOp:
        if (arg_must_be_compile_time_constant[i]) {
          TF_ASSIGN_OR_RETURN(absl::optional<Tensor> value,
                              expressions[i]->ResolveConstant(client));
          if (!value.has_value()) {
            return errors::InvalidArgument(absl::StrCat(
                "Argument ", i, " to function '", func.name(),
                "' must be a compile-time constant, but ",
                "unable to resolve argument value to a constant."));
          }
          arg.kind = XlaCompiler::Argument::kConstant;
          arg.constant_value = *value;
        } else {
          arg.kind = XlaCompiler::Argument::kParameter;
        }
        break;
      case XlaExpression::Kind::kResource: {
        XlaResource* resource = expressions[i]->resource();
        XlaCompiler::PopulateArgumentFromResource(*resource, &arg);
        break;
      }
      case XlaExpression::Kind::kTensorList: {
        arg.kind = XlaCompiler::Argument::kTensorList;
        const xla::XlaOp& tensor_list = expressions[i]->handle();
        arg.shape = tensor_list.builder()->GetShape(tensor_list).ValueOrDie();
        break;
      }
      case XlaExpression::Kind::kInvalid:
        return errors::InvalidArgument("Invalid function argument");
    }
  }
  return Status::OK();
}
}  // namespace
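
// Compiles the graph by symbolically executing each node in a deterministic
// order, recording each node's outputs (Tensors wrapping XlaExpressions) for
// consumption by its successors.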
Status GraphCompiler::Compile() {
  // Check that the graph has no illegal cycles.
  TF_RETURN_IF_ERROR(graph::ValidateGraphHasNoCycle(*graph_));
  // Maintain a mapping from node id to node outputs.
  using NodeOutputs = std::vector<TensorValue>;
  std::vector<NodeOutputs> output_registry(graph_->num_node_ids());
  auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] {
    for (const NodeOutputs& outputs : output_registry) {
      for (const TensorValue& value : outputs) {
        CHECK(!value.is_ref());
        delete value.tensor;
      }
    }
  });

  // XLA requires determinism, so generate a stable ordering from a DFS.
  std::vector<Node*> topo_sorted_nodes;
  GetReversePostOrder(*graph_, &topo_sorted_nodes,
                      /*stable_comparator=*/NodeComparatorName());

  OpKernelContext::Params params;
  PartiallySetupParams(&params);

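  // Visit the nodes in order, instantiating a kernel for each node and
  // executing it against the recorded outputs of its input nodes.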
  for (Node* n : topo_sorted_nodes) {
    OpKernel* op_kernel_raw = nullptr;
    // The kernel is not actually run for functional ops; we only need it
    // for its metadata.
    Status s = flib_->CreateKernel(n->properties(), &op_kernel_raw);
    // Transfer ownership of the kernel to a local smart pointer.
    std::unique_ptr<OpKernel> op_kernel(op_kernel_raw);

    if (!s.ok()) {
      s = AttachDef(s, *n);
      LOG(ERROR) << "Executor failed to create kernel. " << s;
      return s;
    }

    TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch())
        << "Unsupported node: " << n->DebugString();
    params.op_kernel = op_kernel.get();
    absl::InlinedVector<AllocatorAttributes, 4> output_attr(n->num_outputs());
    params.output_attr_array = output_attr.data();

    // tensor_inputs_ is a buffer reused across the graph traversal. Clear and
    // reinitialize the buffer before visiting each node.
    tensor_inputs_.clear();
    tensor_inputs_.resize(n->num_inputs());

    // Set up inputs from the outputs of previous nodes.
    for (auto* e : n->in_edges()) {
      if (e->IsControlEdge()) continue;
      const Node* src = e->src();
      const int output_registry_size = output_registry.size();
      TF_RET_CHECK(src->id() < output_registry_size);
      const NodeOutputs& src_outputs = output_registry[src->id()];

      tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output());
    }

    OpKernelContext op_context(&params, n->num_outputs());
    VLOG(3) << "Translating " << params.op_kernel->name();
    if (IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n)) {
      TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context));
    } else {
      device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context);
      Status s = op_context.status();
      if (!s.ok()) {
        return AttachDef(s, n->def());
      }
    }

    // Set up outputs, and check that the outputs of the computation are
    // valid.
    NodeOutputs& outputs = output_registry[n->id()];
    outputs.resize(n->num_outputs());
    for (int o = 0; o < n->num_outputs(); ++o) {
      outputs[o] = op_context.release_output(o);
      if (outputs[o].tensor == nullptr) {
        return errors::Internal("Missing xla_context ", o, "-th output from ",
                                FormatNodeForError(*n));
      }
    }
  }
  return Status::OK();
}

namespace {

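// Returns the name and attributes of the function called by `node`. A
// PartitionedCall carries the function in its 'f' attribute; otherwise the
// node's op name either names the function directly or denotes a
// SymbolicGradient.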
Status GetFunctionNameAndAttr(const FunctionLibraryRuntime& flib,
                              const Node& node, NameAttrList* func) {
  if (node.IsPartitionedCall()) {
    const AttrValue* attr_value;
    TF_RETURN_IF_ERROR(
        node.attrs().Find(FunctionLibraryDefinition::kFuncAttr, &attr_value));
    if (!attr_value->has_func()) {
      return errors::InvalidArgument(
          "The attribute value for attribute 'f' in node ", node.DebugString(),
          " does not have 'func' field set");
    }
    *func = attr_value->func();
    return Status::OK();
  }

  if (flib.GetFunctionLibraryDefinition()->Find(node.def().op())) {
    func->set_name(node.type_string());
  } else {
    func->set_name(FunctionLibraryDefinition::kGradientOp);
  }
  *func->mutable_attr() = node.def().attr();
  return Status::OK();
}

}  // namespace

Status GraphCompiler::CompileFunctionalNode(Node* n,
                                            OpKernelContext* op_context) {
  TF_RET_CHECK(IsFunctionCall(*flib_->GetFunctionLibraryDefinition(), *n));
  // For functional nodes, compile the called function with the compiler from
  // the context and then emit a call to it.
  XlaOpKernelContext xla_op_context(op_context);

  XlaContext& context = XlaContext::Get(op_context);
  auto* b = context.builder();

  XlaCompiler* compiler = xla_op_context.compiler();

  NameAttrList func;
  TF_RETURN_IF_ERROR(GetFunctionNameAndAttr(*flib_, *n, &func));

  std::vector<const XlaExpression*> expressions;

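  // Each input tensor is a host-side placeholder whose data buffer holds an
  // XlaExpression; recover the expressions here.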
  for (auto tensor : tensor_inputs_) {
    auto expression =
        reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
    expressions.push_back(expression);
  }

  // Prepare the arguments and compile the function.
  std::vector<XlaCompiler::Argument> arguments;
  const FunctionBody* fbody;
  TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody));

  auto graph = compiler->GetGraph(fbody);

  TF_RETURN_IF_ERROR(PrepareArguments(&xla_op_context, graph.get(), expressions,
                                      func, &arguments));

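  // Thread XLA token inputs/outputs through the call if this function
  // participates in the ordering of side effects (see side_effect_util.h).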
  bool add_token_input_output =
      func.attr().find(kXlaTokenInputNodesAttrName) != func.attr().end();

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = false;
  compile_options.add_token_input_output = add_token_input_output;
  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(
      compiler->CompileFunction(compile_options, func, arguments, &result));

  TF_RET_CHECK(arguments.size() == expressions.size());

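  // Collect the XLA operands for the call. Compile-time constants were folded
  // into the compiled computation and need no operand; resources contribute
  // their current value.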
  std::vector<xla::XlaOp> handles;
  for (int64 i = 0, end = expressions.size(); i < end; ++i) {
    if (arguments[i].kind == XlaCompiler::Argument::kConstant) {
      continue;
    }
    if (arguments[i].kind == XlaCompiler::Argument::kResource) {
      handles.push_back(expressions[i]->resource()->value());
    } else {
      handles.push_back(expressions[i]->handle());
    }
  }
  if (add_token_input_output) {
    std::vector<string> token_input_nodes;
    TF_RETURN_IF_ERROR(GetNodeAttr(AttrSlice(&func.attr()),
                                   kXlaTokenInputNodesAttrName,
                                   &token_input_nodes));
    std::vector<xla::XlaOp> token_inputs;
    for (const string& node_name : token_input_nodes) {
      auto token_or = compiler->GetNodeToken(node_name);
      TF_RETURN_IF_ERROR(token_or.status());
      token_inputs.push_back(token_or.ConsumeValueOrDie());
    }
    xla::XlaOp token_input = xla::AfterAll(b, token_inputs);
    handles.push_back(token_input);
  }

  auto output_handle = xla::Call(b, *result.computation, handles);
  // The output handle of the `Call` computation is a tuple. Unpack it so that
  // the elements can feed into downstream computations.
  int computation_output = 0;
  for (int64 i = 0; i < n->num_outputs(); ++i) {
    if (result.outputs[i].is_constant) {
      xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value);
    } else {
      if (result.outputs[i].is_tensor_list) {
        xla_op_context.SetTensorListOutput(
            i, xla::GetTupleElement(output_handle, computation_output));
      } else {
        xla_op_context.SetOutput(
            i, xla::GetTupleElement(output_handle, computation_output));
      }
      ++computation_output;
    }
  }

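  // Propagate updates the function made to resource arguments back to the
  // corresponding XlaResources; updated values follow the regular outputs in
  // the result tuple.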
  for (int64 i = 0, end = result.resource_updates.size(); i < end; i++) {
    if (result.resource_updates[i].modified) {
      XlaResource* resource =
          expressions[result.resource_updates[i].input_index]->resource();
      xla::XlaOp updated_value =
          xla::GetTupleElement(output_handle, i + n->num_outputs());
      TF_RETURN_IF_ERROR(resource->SetValue(updated_value));
    }
  }

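  // Record the token this node produces so that later side-effecting ops can
  // order themselves after it.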
  if (add_token_input_output) {
    TF_RETURN_IF_ERROR(compiler->SetNodeToken(
        n->name(), xla::GetTupleElement(output_handle, computation_output)));
  }
  return b->first_error();
}

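// Fills in the fields of `params` that are common to every node; per-node
// fields such as op_kernel and output_attr_array are set in Compile().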
void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) {
  params->device = device_;
  params->inputs = &tensor_inputs_;
  params->step_container = step_container_;
  params->resource_manager = device_->resource_manager();
  params->function_library = flib_;
}

}  // namespace tensorflow