1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/kernels/case_op.h"
17 
18 #include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h"
19 #include "tensorflow/compiler/tf2xla/shape_util.h"
20 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
21 #include "tensorflow/compiler/tf2xla/xla_context.h"
22 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
23 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
24 #include "tensorflow/compiler/xla/client/lib/constants.h"
25 #include "tensorflow/compiler/xla/client/xla_builder.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 
28 namespace tensorflow {
29 
XlaCaseOp(OpKernelConstruction * ctx)30 XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
31   OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &unpruned_branches_));
32   OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
33   OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
34   if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
35     has_token_input_output_ = false;
36   } else {
37     has_token_input_output_ = !token_input_nodes_.empty();
38   }
39   if (ctx->HasAttr(kPropagateCompileTimeConsts)) {
40     OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts,
41                                      &propagate_compile_time_consts_));
42   }
43 }
44 
45 std::pair<std::vector<NameAttrList>, xla::XlaOp>
GetPrunedBranchesAndIndex(XlaOpKernelContext * ctx)46 XlaCaseOp::GetPrunedBranchesAndIndex(XlaOpKernelContext* ctx) {
47   xla::Literal branch_index_literal;
48   bool branch_index_is_constant =
49       ctx->ConstantInput(0, &branch_index_literal).ok();
50 
51   if (!branch_index_is_constant) {
52     return {unpruned_branches_, ctx->Input(0)};
53   }
54 
55   int32 branch_index = branch_index_literal.Get<int32>({});
56   if (branch_index < 0 || branch_index >= unpruned_branches_.size()) {
57     branch_index = unpruned_branches_.size() - 1;
58   }
59 
60   std::vector<NameAttrList> pruned_branch = {unpruned_branches_[branch_index]};
61   return {pruned_branch, xla::ZerosLike(ctx->Input(0))};
62 }
63 
// TODO(b/35949885): There is duplication here with the handling of the
// while_op/if_op. Refactor the common code out/rework.
//
// Lowers the TF Case op into an xla::Conditional. High-level flow:
//   1. Validate the branch-index operand (must be a scalar int32).
//   2. Opportunistically prune branches when the index is a constant.
//   3. Build an XlaCompiler::Argument per op input (resources included).
//   4. Compile every branch; recompile all of them if any branch touched
//      TensorArray gradients (those change the argument shapes).
//   5. Verify all branches agree on input/output shapes, TensorList output
//      positions, and resource updates.
//   6. Emit the Conditional, then forward the regular outputs, the optional
//      token output, and resource-variable updates back to the context.
void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
  OP_REQUIRES(ctx, !unpruned_branches_.empty(),
              errors::InvalidArgument("Must provide at least one case branch"));
  OP_REQUIRES(ctx, input_type(0) == DT_INT32,
              errors::InvalidArgument(
                  "branch_index argument must be a int32 for XLA compilation"));
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)),
              errors::InvalidArgument(
                  "branch_index argument must be scalar for XLA compilation"));

  xla::XlaBuilder* b = ctx->builder();

  // We opportunistically prune out branches if the branch index is a
  // compile-time constant.  This is important in the context of the DeviceIndex
  // ops (and other such ops that may come later) since we may have a Case with
  // trivially unselected branches that cannot be compiled into HLO.
  std::vector<NameAttrList> branches;
  xla::XlaOp branch_index;
  std::tie(branches, branch_index) = GetPrunedBranchesAndIndex(ctx);

  int num_branches = branches.size();

  VLOG(1) << "Building Case: " << input_types_.size() << " inputs";

  // Build one XlaCompiler::Argument per data input. Op input 0 is the branch
  // index, so data input i maps to op input i + 1 throughout this function.
  std::vector<XlaCompiler::Argument> arguments(input_types_.size());
  int num_resource_args = 0;
  for (int i = 0; i < input_types_.size(); ++i) {
    XlaCompiler::Argument& arg = arguments[i];
    DataType type = ctx->input_type(i + 1);

    if (type == DT_RESOURCE) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource));
      XlaCompiler::PopulateArgumentFromResource(*resource, &arg);
      OP_REQUIRES(ctx, arg.initialized,
                  errors::Unimplemented("Uninitialized arguments: ", arg.name));
      VLOG(2) << "Resource " << resource->name()
              << " type: " << DataTypeString(arg.type)
              << " shape: " << arg.HumanString()
              << " initialized: " << arg.initialized;

      num_resource_args++;
    } else {
      arg.kind = XlaCompiler::Argument::kParameter;
      arg.type = input_types_[i];
      // Use the xla::Shape for the input instead of ctx->InputShape. This is
      // necessary for forwarding shapes of DT_VARIANTs, e.g. TensorLists.
      auto shape_or = ctx->builder()->GetShape(ctx->Input(i + 1));
      OP_REQUIRES_OK(ctx, shape_or.status());
      arg.shape = shape_or.ValueOrDie();
      VLOG(2) << "Arg type: " << DataTypeString(arg.type)
              << " shape: " << arg.HumanString();
    }
  }

  if (propagate_compile_time_consts_) {
    // For each branch, find which of its function-body nodes must be
    // compile-time constants.
    std::vector<std::vector<bool>> case_branch_must_be_const_nodes(
        num_branches);
    std::vector<const FunctionBody*> case_bodies(num_branches);
    for (int branch_idx = 0; branch_idx < num_branches; branch_idx++) {
      OP_REQUIRES_OK(ctx, FindMustBeConstNodes(
                              ctx, branches[branch_idx],
                              &case_branch_must_be_const_nodes[branch_idx],
                              &case_bodies[branch_idx]));
    }

    // Replaces `kParameter` type args in `arguments` with `kConstant` if
    // the op input corresponding to that arg is a compile-time const. This
    // is necessary to propagate compile time consts to ops in the branch
    // functions.
    // An arg is eligible only if EVERY branch requires it to be constant.
    auto arg_is_parameter = [&](int arg_idx) {
      if (arguments[arg_idx].kind != XlaCompiler::Argument::kParameter) {
        return false;
      }
      for (int branch_idx = 0; branch_idx < num_branches; branch_idx++) {
        if (!case_branch_must_be_const_nodes
                [branch_idx]
                [case_bodies[branch_idx]->arg_nodes[arg_idx]->id()]) {
          return false;
        }
      }
      return true;
    };
    // Offset of 1 skips the branch-index operand when mapping args to inputs.
    ConvertCompileTimeConstArgumentsToConst(ctx, &arguments,
                                            /*xla_expression_offset=*/1,
                                            arg_is_parameter);
  }

  // Compile each branch of the conditional.
  XlaCompiler::CompileOptions options;
  options.use_tuple_arg = true;
  options.return_updated_values_for_all_resources = true;
  options.is_entry_computation = false;
  options.add_token_input_output = has_token_input_output_;
  XlaCompiler* compiler = ctx->compiler();

  std::vector<XlaCompiler::CompilationResult> branch_results(num_branches);
  for (int j = 0; j < num_branches; ++j) {
    OP_REQUIRES_OK(ctx,
                   compiler->CompileFunction(options, branches[j], arguments,
                                             &branch_results[j]));
  }

  // Collect TensorArray gradients touched by any branch; if any exist, the
  // argument shapes change and all branches must be recompiled below.
  bool has_tensor_array_gradients = false;
  for (XlaCompiler::CompilationResult& result : branch_results) {
    for (const XlaCompiler::ResourceUpdate& update : result.resource_updates) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx,
                     ctx->GetResourceInput(update.input_index + 1, &resource));
      XlaCompiler::Argument& arg = arguments[update.input_index];

      // Add any TensorArray gradients touched by the then/else computation to
      // the enclosing graph.
      for (const string& grad_source : update.tensor_array_gradients_accessed) {
        VLOG(5) << "TensorArray " << resource->name() << " accessed gradient "
                << grad_source;
        XlaResource* gradient;
        OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
                                grad_source, b, &gradient));
      }
      // Add all of the TensorArray gradients to the argument. For simplicity,
      // we always pass all known gradients.
      for (const auto& gradient : resource->tensor_array_gradients()) {
        arg.tensor_array_gradients.insert(gradient.first);
      }
      if (!resource->tensor_array_gradients().empty()) {
        has_tensor_array_gradients = true;
      }
    }
  }

  // Recompile the functions to update the argument shapes for tensor arrays.
  if (has_tensor_array_gradients) {
    for (int j = 0; j < num_branches; ++j) {
      // Reset before recompiling so stale results cannot leak through.
      branch_results[j] = {};
      OP_REQUIRES_OK(ctx,
                     compiler->CompileFunction(options, branches[j], arguments,
                                               &branch_results[j]));
    }
  }

  // Cross-branch consistency checks: xla::Conditional requires every branch
  // computation to take and produce compatible shapes.
  xla::Shape branch0_input_shape;
  std::vector<const xla::XlaComputation*> result_computations(num_branches);
  for (int j = 0; j < num_branches; ++j) {
    // Check that all branches have identical input shapes.
    OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1,
                errors::FailedPrecondition("Expected one input shape"));
    xla::Shape branch_input_shape = branch_results[j].xla_input_shapes[0];
    if (j == 0) {
      branch0_input_shape = branch_input_shape;
    }
    OP_REQUIRES(ctx, branch_input_shape.IsTuple(),
                errors::FailedPrecondition("Expected tuple shape"));
    OP_REQUIRES(
        ctx,
        xla::ShapeUtil::Compatible(branch0_input_shape, branch_input_shape),
        errors::InvalidArgument(
            "Input shapes of 0 and ", j, " branches do not match: ",
            xla::ShapeUtil::HumanString(branch0_input_shape), " vs. ",
            xla::ShapeUtil::HumanString(branch_input_shape)));

    // Check that all branches have identical output shapes.
    OP_REQUIRES(
        ctx,
        xla::ShapeUtil::Compatible(branch_results[0].xla_output_shape,
                                   branch_results[j].xla_output_shape),
        errors::InvalidArgument(
            "Output shapes of 0 and ", j, " branches do not match: ",
            xla::ShapeUtil::HumanString(branch_results[0].xla_output_shape),
            " vs. ",
            xla::ShapeUtil::HumanString(branch_results[j].xla_output_shape)));

    if (j == 0) {
      VLOG(2) << "Input shape: "
              << xla::ShapeUtil::HumanString(branch0_input_shape);
      VLOG(2) << "Output shape: "
              << xla::ShapeUtil::HumanString(
                     branch_results[0].xla_output_shape);
    }

    // Check that all branches have same TensorList output indices.
    for (int output_index = 0; output_index < branch_results[0].outputs.size();
         output_index++) {
      bool is_tensor_list_in_branch_0 =
          branch_results[0].outputs[output_index].is_tensor_list;
      bool is_tensor_list_in_branch_j =
          branch_results[j].outputs[output_index].is_tensor_list;
      OP_REQUIRES(
          ctx, is_tensor_list_in_branch_0 == is_tensor_list_in_branch_j,
          errors::FailedPrecondition("Output #", output_index, " is ",
                                     (is_tensor_list_in_branch_0 ? "" : "not"),
                                     " a TensorList in branch 0, but is ",
                                     (is_tensor_list_in_branch_j ? "" : "not"),
                                     " a TensorList in branch ", j));
    }

    // We set return_updated_values_for_all_resources=true and we pass the same
    // arguments to both computations, so the resource update count must match.
    OP_REQUIRES(ctx,
                branch_results[0].resource_updates.size() ==
                    branch_results[j].resource_updates.size(),
                errors::FailedPrecondition(
                    "Different number of resources in 0 and ", j, " branch"));
    for (int i = 0; i < branch_results[0].resource_updates.size(); ++i) {
      const auto& lhs = branch_results[0].resource_updates[i];
      const auto& rhs = branch_results[j].resource_updates[i];
      bool equal = lhs.input_index == rhs.input_index &&
                   lhs.shape == rhs.shape &&
                   lhs.tensor_array_gradients_accessed ==
                       rhs.tensor_array_gradients_accessed;
      OP_REQUIRES(ctx, equal,
                  errors::FailedPrecondition("Mismatch in resource of 0 and ",
                                             j, " branch for resource ", i));
    }
    result_computations[j] = branch_results[j].computation.get();
  }

  // Prepare the input arg Tuple.
  // Only inputs listed in branch 0's input_mapping are actually consumed by
  // the compiled computations (the +1 again skips the branch index).
  int num_inputs = branch_results[0].input_mapping.size();
  std::vector<xla::XlaOp> inputs(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    int input_num = branch_results[0].input_mapping[i] + 1;
    if (has_token_input_output_ && i == num_inputs - 1) {
      // Set token input for this "case" op.
      // Merge all incoming tokens into one with AfterAll.
      std::vector<xla::XlaOp> token_inputs;
      for (const string& node_name : token_input_nodes_) {
        auto token_or = compiler->GetNodeToken(node_name);
        OP_REQUIRES_OK(ctx, token_or.status());
        token_inputs.push_back(token_or.ValueOrDie());
      }
      inputs[i] = xla::AfterAll(b, token_inputs);
    } else if (ctx->input_type(input_num) == DT_RESOURCE) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
      OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
    } else {
      inputs[i] = ctx->Input(input_num);
    }
  }
  auto input_tuple = xla::Tuple(b, inputs);

  // Emit the N-way conditional; every branch receives the same input tuple.
  xla::XlaOp outputs =
      xla::Conditional(branch_index, absl::MakeSpan(result_computations),
                       std::vector<xla::XlaOp>(num_branches, input_tuple));
  // Sets non-variable outputs.
  for (int i = 0; i < output_types_.size(); ++i) {
    xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
    if (VLOG_IS_ON(2)) {
      LOG(INFO) << "Setting output " << i;
      auto shape_or = b->GetShape(output_handle);
      if (shape_or.ok()) {
        LOG(INFO) << "Shape for output " << i << ": "
                  << xla::ShapeUtil::HumanString(shape_or.ValueOrDie());
      } else {
        LOG(INFO) << "Shape unknown for output " << i;
      }
    }
    // We have checked that all branches have same TensorList output indices.
    if (branch_results[0].outputs[i].is_tensor_list) {
      ctx->SetTensorListOutput(i, output_handle);
    } else {
      ctx->SetOutput(i, output_handle);
    }
  }
  if (has_token_input_output_) {
    // Set token output for this "Case" op. Token output is the last output of
    // XLA computation, which comes after all "normal" TF outputs and resource
    // updates. For "Case" node, num of resource updates equals to number of
    // resource args because we set `return_updated_values_for_all_resources`
    // to true in XlaCompiler option.
    xla::XlaOp token_output =
        xla::GetTupleElement(outputs, output_types_.size() + num_resource_args);
    auto shape_or = b->GetShape(token_output);
    OP_REQUIRES_OK(ctx, shape_or.status());
    OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(),
                errors::FailedPrecondition(
                    "Token output is not token type: ",
                    xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
    OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
  }

  // Updates the values of any resource variables modified by the conditional
  // bodies.
  for (const XlaCompiler::CompilationResult& result : branch_results) {
    for (int i = 0; i < result.resource_updates.size(); ++i) {
      const XlaCompiler::ResourceUpdate& update = result.resource_updates[i];
      XlaResource* resource;
      OP_REQUIRES_OK(ctx,
                     ctx->GetResourceInput(update.input_index + 1, &resource));
      if (update.modified) {
        // Updated resource values follow the regular outputs in the tuple.
        int pos = static_cast<int>(result.outputs.size()) + i;
        OP_REQUIRES_OK(ctx,
                       resource->SetFromPack(
                           arguments[update.input_index].tensor_array_gradients,
                           xla::GetTupleElement(outputs, pos), b));
      }
      VLOG(2) << "Case variable: pos: " << update.input_index
              << " name: " << resource->name()
              << " modified: " << update.modified
              << " type: " << DataTypeString(update.type)
              << " shape: " << update.shape.DebugString();
    }
  }
  VLOG(1) << "Done building Case";
}
371 
// Both Case and StatelessCase are lowered by this kernel; resource- and
// variant-typed (e.g. TensorList) operands are allowed on both.
REGISTER_XLA_OP(Name("Case").AllowResourceTypes().AllowVariantTypes(),
                XlaCaseOp);
REGISTER_XLA_OP(Name("StatelessCase").AllowResourceTypes().AllowVariantTypes(),
                XlaCaseOp);
376 
377 }  // namespace tensorflow
378