/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/kernels/case_op.h"

#include "tensorflow/compiler/tf2xla/kernels/if_while_utils.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

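// XlaCaseOp lowers the TF "Case" / "StatelessCase" ops to an xla::Conditional.
// The "branches" attribute carries one branch function per case; "Tin" and
// "Tout" give the types of the non-branch-index inputs and of the outputs.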
XlaCaseOp::XlaCaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
  OP_REQUIRES_OK(ctx, ctx->GetAttr("branches", &unpruned_branches_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tin", &input_types_));
  OP_REQUIRES_OK(ctx, ctx->GetAttr("Tout", &output_types_));
  if (!ctx->GetAttr(kXlaTokenInputNodesAttrName, &token_input_nodes_).ok()) {
    has_token_input_output_ = false;
  } else {
    has_token_input_output_ = !token_input_nodes_.empty();
  }
  if (ctx->HasAttr(kPropagateCompileTimeConsts)) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr(kPropagateCompileTimeConsts,
                                     &propagate_compile_time_consts_));
  }
}

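// Returns the branches to compile and the branch index to dispatch on. When
// the branch index is a compile-time constant, the branch list is pruned down
// to the single branch that will run and the dispatch index collapses to a
// constant zero, so XLA never compiles statically unreachable branches. For
// example, with three branches and a constant index of 7, the index is
// clamped to 2 and only that branch is compiled.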
std::pair<std::vector<NameAttrList>, xla::XlaOp>
XlaCaseOp::GetPrunedBranchesAndIndex(XlaOpKernelContext* ctx) {
  xla::Literal branch_index_literal;
  bool branch_index_is_constant =
      ctx->ConstantInput(0, &branch_index_literal).ok();

  if (!branch_index_is_constant) {
    return {unpruned_branches_, ctx->Input(0)};
  }

  int32 branch_index = branch_index_literal.Get<int32>({});
  // An out-of-range branch index selects the last branch, matching the
  // documented semantics of the Case op.
  if (branch_index < 0 || branch_index >= unpruned_branches_.size()) {
    branch_index = unpruned_branches_.size() - 1;
  }

  std::vector<NameAttrList> pruned_branch = {unpruned_branches_[branch_index]};
  return {pruned_branch, xla::ZerosLike(ctx->Input(0))};
}

// TODO(b/35949885): There is duplication here with the handling of the
// while_op/if_op. Refactor the common code out/rework.
void XlaCaseOp::Compile(XlaOpKernelContext* ctx) {
  OP_REQUIRES(ctx, !unpruned_branches_.empty(),
              errors::InvalidArgument("Must provide at least one case branch"));
  OP_REQUIRES(ctx, input_type(0) == DT_INT32,
              errors::InvalidArgument(
                  "branch_index argument must be an int32 for XLA compilation"));
  OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->InputShape(0)),
              errors::InvalidArgument(
                  "branch_index argument must be scalar for XLA compilation"));

  xla::XlaBuilder* b = ctx->builder();

  // We opportunistically prune out branches if the branch index is a
  // compile-time constant. This is important in the context of the DeviceIndex
  // ops (and other such ops that may come later) since we may have a Case with
  // trivially unselected branches that cannot be compiled into HLO.
  std::vector<NameAttrList> branches;
  xla::XlaOp branch_index;
  std::tie(branches, branch_index) = GetPrunedBranchesAndIndex(ctx);

  int num_branches = branches.size();

  VLOG(1) << "Building Case: " << input_types_.size() << " inputs";

  std::vector<XlaCompiler::Argument> arguments(input_types_.size());
  int num_resource_args = 0;
  for (int i = 0; i < input_types_.size(); ++i) {
    XlaCompiler::Argument& arg = arguments[i];
    DataType type = ctx->input_type(i + 1);

    if (type == DT_RESOURCE) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(i + 1, &resource));
      XlaCompiler::PopulateArgumentFromResource(*resource, &arg);
      OP_REQUIRES(ctx, arg.initialized,
                  errors::Unimplemented("Uninitialized arguments: ", arg.name));
      VLOG(2) << "Resource " << resource->name()
              << " type: " << DataTypeString(arg.type)
              << " shape: " << arg.HumanString()
              << " initialized: " << arg.initialized;

      num_resource_args++;
    } else {
      arg.kind = XlaCompiler::Argument::kParameter;
      arg.type = input_types_[i];
      // Use the xla::Shape for the input instead of ctx->InputShape. This is
      // necessary for forwarding shapes of DT_VARIANTs, e.g. TensorLists.
      auto shape_or = ctx->builder()->GetShape(ctx->Input(i + 1));
      OP_REQUIRES_OK(ctx, shape_or.status());
      arg.shape = shape_or.ValueOrDie();
      VLOG(2) << "Arg type: " << DataTypeString(arg.type)
              << " shape: " << arg.HumanString();
    }
  }

  if (propagate_compile_time_consts_) {
    std::vector<std::vector<bool>> case_branch_must_be_const_nodes(
        num_branches);
    std::vector<const FunctionBody*> case_bodies(num_branches);
    for (int branch_idx = 0; branch_idx < num_branches; branch_idx++) {
      OP_REQUIRES_OK(ctx, FindMustBeConstNodes(
                              ctx, branches[branch_idx],
                              &case_branch_must_be_const_nodes[branch_idx],
                              &case_bodies[branch_idx]));
    }

    // Replaces `kParameter` type args in `arguments` with `kConstant` if
    // the op input corresponding to that arg is a compile-time const. This
    // is necessary to propagate compile time consts to ops in the branch
    // functions.
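    // For example (a hypothetical illustration): if every branch body feeds
    // its k-th argument into the shape operand of a Reshape, that argument
    // must be a compile-time constant in every branch, so it is rewritten
    // from kParameter to kConstant before the branch functions are compiled.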
    auto arg_is_parameter = [&](int arg_idx) {
      if (arguments[arg_idx].kind != XlaCompiler::Argument::kParameter) {
        return false;
      }
      for (int branch_idx = 0; branch_idx < num_branches; branch_idx++) {
        if (!case_branch_must_be_const_nodes
                [branch_idx]
                [case_bodies[branch_idx]->arg_nodes[arg_idx]->id()]) {
          return false;
        }
      }
      return true;
    };
    ConvertCompileTimeConstArgumentsToConst(ctx, &arguments,
                                            /*xla_expression_offset=*/1,
                                            arg_is_parameter);
  }

  // Compile each branch of the conditional.
  XlaCompiler::CompileOptions options;
  options.use_tuple_arg = true;
  options.return_updated_values_for_all_resources = true;
  options.is_entry_computation = false;
  options.add_token_input_output = has_token_input_output_;
  XlaCompiler* compiler = ctx->compiler();

  std::vector<XlaCompiler::CompilationResult> branch_results(num_branches);
  for (int j = 0; j < num_branches; ++j) {
    OP_REQUIRES_OK(ctx,
                   compiler->CompileFunction(options, branches[j], arguments,
                                             &branch_results[j]));
  }

  bool has_tensor_array_gradients = false;
  for (XlaCompiler::CompilationResult& result : branch_results) {
    for (const XlaCompiler::ResourceUpdate& update : result.resource_updates) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx,
                     ctx->GetResourceInput(update.input_index + 1, &resource));
      XlaCompiler::Argument& arg = arguments[update.input_index];

      // Add any TensorArray gradients touched by the branch computations to
      // the enclosing graph.
      for (const string& grad_source : update.tensor_array_gradients_accessed) {
        VLOG(5) << "TensorArray " << resource->name() << " accessed gradient "
                << grad_source;
        XlaResource* gradient;
        OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient(
                                grad_source, b, &gradient));
      }
      // Add all of the TensorArray gradients to the argument. For simplicity,
      // we always pass all known gradients.
      for (const auto& gradient : resource->tensor_array_gradients()) {
        arg.tensor_array_gradients.insert(gradient.first);
      }
      if (!resource->tensor_array_gradients().empty()) {
        has_tensor_array_gradients = true;
      }
    }
  }

  // Recompile the functions to update the argument shapes for tensor arrays.
  if (has_tensor_array_gradients) {
    for (int j = 0; j < num_branches; ++j) {
      branch_results[j] = {};
      OP_REQUIRES_OK(ctx,
                     compiler->CompileFunction(options, branches[j], arguments,
                                               &branch_results[j]));
    }
  }

  xla::Shape branch0_input_shape;
  std::vector<const xla::XlaComputation*> result_computations(num_branches);
  for (int j = 0; j < num_branches; ++j) {
    // Check that all branches have identical input shapes.
    OP_REQUIRES(ctx, branch_results[j].xla_input_shapes.size() == 1,
                errors::FailedPrecondition("Expected one input shape"));
    xla::Shape branch_input_shape = branch_results[j].xla_input_shapes[0];
    if (j == 0) {
      branch0_input_shape = branch_input_shape;
    }
    OP_REQUIRES(ctx, branch_input_shape.IsTuple(),
                errors::FailedPrecondition("Expected tuple shape"));
    OP_REQUIRES(
        ctx,
        xla::ShapeUtil::Compatible(branch0_input_shape, branch_input_shape),
        errors::InvalidArgument(
            "Input shapes of branch 0 and branch ", j, " do not match: ",
            xla::ShapeUtil::HumanString(branch0_input_shape), " vs. ",
            xla::ShapeUtil::HumanString(branch_input_shape)));

    // Check that all branches have identical output shapes.
    OP_REQUIRES(
        ctx,
        xla::ShapeUtil::Compatible(branch_results[0].xla_output_shape,
                                   branch_results[j].xla_output_shape),
        errors::InvalidArgument(
            "Output shapes of branch 0 and branch ", j, " do not match: ",
            xla::ShapeUtil::HumanString(branch_results[0].xla_output_shape),
            " vs. ",
            xla::ShapeUtil::HumanString(branch_results[j].xla_output_shape)));

    if (j == 0) {
      VLOG(2) << "Input shape: "
              << xla::ShapeUtil::HumanString(branch0_input_shape);
      VLOG(2) << "Output shape: "
              << xla::ShapeUtil::HumanString(
                     branch_results[0].xla_output_shape);
    }

    // Check that all branches have the same TensorList output indices.
    for (int output_index = 0; output_index < branch_results[0].outputs.size();
         output_index++) {
      bool is_tensor_list_in_branch_0 =
          branch_results[0].outputs[output_index].is_tensor_list;
      bool is_tensor_list_in_branch_j =
          branch_results[j].outputs[output_index].is_tensor_list;
      OP_REQUIRES(
          ctx, is_tensor_list_in_branch_0 == is_tensor_list_in_branch_j,
          errors::FailedPrecondition("Output #", output_index, " is ",
                                     (is_tensor_list_in_branch_0 ? "" : "not "),
                                     "a TensorList in branch 0, but is ",
                                     (is_tensor_list_in_branch_j ? "" : "not "),
                                     "a TensorList in branch ", j));
    }

    // We set return_updated_values_for_all_resources=true and we pass the
    // same arguments to every branch computation, so the resource update
    // count must match across branches.
    OP_REQUIRES(ctx,
                branch_results[0].resource_updates.size() ==
                    branch_results[j].resource_updates.size(),
                errors::FailedPrecondition(
                    "Different number of resource updates in branch 0 and "
                    "branch ", j));
    for (int i = 0; i < branch_results[0].resource_updates.size(); ++i) {
      const auto& lhs = branch_results[0].resource_updates[i];
      const auto& rhs = branch_results[j].resource_updates[i];
      bool equal = lhs.input_index == rhs.input_index &&
                   lhs.shape == rhs.shape &&
                   lhs.tensor_array_gradients_accessed ==
                       rhs.tensor_array_gradients_accessed;
      OP_REQUIRES(ctx, equal,
                  errors::FailedPrecondition(
                      "Mismatch between branch 0 and branch ", j,
                      " for resource update ", i));
    }
    result_computations[j] = branch_results[j].computation.get();
  }

  // Prepare the input arg Tuple.
  int num_inputs = branch_results[0].input_mapping.size();
  std::vector<xla::XlaOp> inputs(num_inputs);
  for (int i = 0; i < num_inputs; ++i) {
    int input_num = branch_results[0].input_mapping[i] + 1;
    if (has_token_input_output_ && i == num_inputs - 1) {
      // Set token input for this "case" op.
      std::vector<xla::XlaOp> token_inputs;
      for (const string& node_name : token_input_nodes_) {
        auto token_or = compiler->GetNodeToken(node_name);
        OP_REQUIRES_OK(ctx, token_or.status());
        token_inputs.push_back(token_or.ValueOrDie());
      }
      inputs[i] = xla::AfterAll(b, token_inputs);
    } else if (ctx->input_type(input_num) == DT_RESOURCE) {
      XlaResource* resource;
      OP_REQUIRES_OK(ctx, ctx->GetResourceInput(input_num, &resource));
      OP_REQUIRES_OK(ctx, resource->Pack(&inputs[i], b));
    } else {
      inputs[i] = ctx->Input(input_num);
    }
  }
  auto input_tuple = xla::Tuple(b, inputs);

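  // The N-ary xla::Conditional overload dispatches on a scalar int32
  // branch_index, passing branch_operands[i] as the single tuple argument of
  // branch_computations[i]; here every branch receives the same input tuple.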
  xla::XlaOp outputs =
      xla::Conditional(branch_index, absl::MakeSpan(result_computations),
                       std::vector<xla::XlaOp>(num_branches, input_tuple));
  // Sets non-variable outputs.
  for (int i = 0; i < output_types_.size(); ++i) {
    xla::XlaOp output_handle = xla::GetTupleElement(outputs, i);
    if (VLOG_IS_ON(2)) {
      LOG(INFO) << "Setting output " << i;
      auto shape_or = b->GetShape(output_handle);
      if (shape_or.ok()) {
        LOG(INFO) << "Shape for output " << i << ": "
                  << xla::ShapeUtil::HumanString(shape_or.ValueOrDie());
      } else {
        LOG(INFO) << "Shape unknown for output " << i;
      }
    }
    // We have checked above that all branches have the same TensorList
    // output indices.
    if (branch_results[0].outputs[i].is_tensor_list) {
      ctx->SetTensorListOutput(i, output_handle);
    } else {
      ctx->SetOutput(i, output_handle);
    }
  }
  if (has_token_input_output_) {
    // Set the token output for this "Case" op. The token output is the last
    // output of the XLA computation, coming after all "normal" TF outputs and
    // resource updates. For a "Case" node, the number of resource updates
    // equals the number of resource arguments because we set
    // `return_updated_values_for_all_resources` to true in the XlaCompiler
    // options.
    xla::XlaOp token_output =
        xla::GetTupleElement(outputs, output_types_.size() + num_resource_args);
    auto shape_or = b->GetShape(token_output);
    OP_REQUIRES_OK(ctx, shape_or.status());
    OP_REQUIRES(ctx, shape_or.ValueOrDie().IsToken(),
                errors::FailedPrecondition(
                    "Token output is not token type: ",
                    xla::ShapeUtil::HumanString(shape_or.ValueOrDie())));
    OP_REQUIRES_OK(ctx, compiler->SetNodeToken(name(), token_output));
  }

  // Updates the values of any resource variables modified by the conditional
  // bodies.
  for (const XlaCompiler::CompilationResult& result : branch_results) {
    for (int i = 0; i < result.resource_updates.size(); ++i) {
      const XlaCompiler::ResourceUpdate& update = result.resource_updates[i];
      XlaResource* resource;
      OP_REQUIRES_OK(ctx,
                     ctx->GetResourceInput(update.input_index + 1, &resource));
      if (update.modified) {
        int pos = static_cast<int>(result.outputs.size()) + i;
        OP_REQUIRES_OK(ctx,
                       resource->SetFromPack(
                           arguments[update.input_index].tensor_array_gradients,
                           xla::GetTupleElement(outputs, pos), b));
      }
      VLOG(2) << "Case variable: pos: " << update.input_index
              << " name: " << resource->name()
              << " modified: " << update.modified
              << " type: " << DataTypeString(update.type)
              << " shape: " << update.shape.DebugString();
    }
  }
  VLOG(1) << "Done building Case";
}

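// The registrations below make XlaCaseOp the XLA lowering for the "Case" and
// "StatelessCase" ops (e.g. what a tf.switch_case becomes under XLA
// compilation); AllowResourceTypes and AllowVariantTypes permit resource
// variables and variants such as TensorLists as branch inputs.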
REGISTER_XLA_OP(Name("Case").AllowResourceTypes().AllowVariantTypes(),
                XlaCaseOp);
REGISTER_XLA_OP(Name("StatelessCase").AllowResourceTypes().AllowVariantTypes(),
                XlaCaseOp);

}  // namespace tensorflow