1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17 #include "absl/strings/match.h"
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/data_flow_ops.h"
20 #include "tensorflow/cc/ops/function_ops.h"
21 #include "tensorflow/cc/ops/resource_variable_ops.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
25 #include "tensorflow/compiler/tf2xla/type_util.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/compiler/xla/client/client_library.h"
29 #include "tensorflow/compiler/xla/client/local_client.h"
30 #include "tensorflow/compiler/xla/client/xla_builder.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
35 #include "tensorflow/core/common_runtime/function.h"
36 #include "tensorflow/core/framework/common_shape_fns.h"
37 #include "tensorflow/core/framework/function.h"
38 #include "tensorflow/core/framework/function_testlib.h"
39 #include "tensorflow/core/framework/node_def_util.h"
40 #include "tensorflow/core/framework/resource_mgr.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor_testutil.h"
43 #include "tensorflow/core/graph/algorithm.h"
44 #include "tensorflow/core/graph/graph.h"
45 #include "tensorflow/core/graph/graph_constructor.h"
46 #include "tensorflow/core/lib/core/status_test_util.h"
47 #include "tensorflow/core/platform/test.h"
48 #include "tensorflow/core/public/version.h"
49
50 namespace tensorflow {
51
52 class XlaCompilerTest : public ::testing::Test {
53 protected:
SetUp()54 void SetUp() override {
55 client_ = xla::ClientLibrary::LocalClientOrDie();
56
57 XlaOpRegistry::RegisterCompilationKernels();
58
59 FunctionDefLibrary flib;
60 flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
61 }
62
DefaultOptions()63 XlaCompiler::Options DefaultOptions() {
64 XlaCompiler::Options options;
65 options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
66 options.client = client_;
67 options.flib_def = flib_def_.get();
68 return options;
69 }
70
LocalFlibDef(XlaCompiler * compiler)71 FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) {
72 return compiler->local_flib_def_.get();
73 }
74
75 xla::Client* client_;
76 std::unique_ptr<FunctionLibraryDefinition> flib_def_;
77 };
78
79 namespace {
80
81 // Helper class to test the ability to pass resources through to XLA
82 // compiled kernels.
83 class DummyResourceForTest : public ResourceBase {
84 public:
DebugString() const85 string DebugString() const override { return "dummy"; }
Increment()86 void Increment() { ++value_; }
Get()87 int Get() { return value_; }
88
89 private:
90 int value_ = 0;
91 };
92
93 class DummyReadResourceOp : public XlaOpKernel {
94 public:
DummyReadResourceOp(OpKernelConstruction * ctx)95 explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)96 void Compile(XlaOpKernelContext* ctx) override {
97 ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
98 OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
99 DummyResourceForTest* dummy;
100 OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
101 rm->default_container(), "dummy", &dummy));
102 dummy->Increment();
103 dummy->Unref();
104
105 ctx->SetOutput(0, ctx->Input(0));
106 ctx->SetOutput(1, ctx->Input(0));
107 }
108 };
109
110 class DummyReadResourceCC {
111 public:
DummyReadResourceCC(const Scope & scope,const Input & value)112 DummyReadResourceCC(const Scope& scope, const Input& value) {
113 if (!scope.ok()) return;
114 auto _value = ops::AsNodeOut(scope, value);
115 if (!scope.ok()) return;
116 Node* ret;
117 const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
118 auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
119 scope.UpdateBuilder(&builder);
120 scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
121 if (!scope.ok()) return;
122 scope.UpdateStatus(scope.DoShapeInference(ret));
123 if (!scope.ok()) return;
124 this->output1_ = Output(ret, 0);
125 this->output2_ = Output(ret, 1);
126 }
127
128 Output output1_;
129 Output output2_;
130 };
131
// Registers the DummyReadResource op: one int32 input, two int32 outputs,
// shapes left unknown. The ResourceManager test builds it through
// DummyReadResourceCC.
REGISTER_OP("DummyReadResource")
    .Input("input: int32")
    .Output("output1: int32")
    .Output("output2: int32")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
A dummy Op.

input: dummy input.
output1: dummy output.
output2: dummy output.
)doc");

// Registers DummyReadResourceOp as the XLA kernel for DummyReadResource. No
// .Device() filter is given, so presumably it applies to every XLA JIT
// device -- TODO confirm against XlaOpRegistry.
REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
146
147 // DummyDuplicateOp is present purely to test multiple REGISTER_XLA_OP calls
148 // on the same Op name below.
149 class DummyDuplicateOp : public XlaOpKernel {
150 public:
DummyDuplicateOp(OpKernelConstruction * ctx)151 explicit DummyDuplicateOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)152 void Compile(XlaOpKernelContext* ctx) override {
153 ctx->SetOutput(0, ctx->Input(0));
154 }
155 };
156
// Registers the DummyDuplicateOp op: one int32 input, one int32 output.
REGISTER_OP("DummyDuplicateOp")
    .Input("input: int32")
    .Output("output: int32")
    .Doc(R"doc(
A dummy Op.

input: dummy input.
output: dummy output.
)doc");

// Registers the same kernel twice under one op name: once for the CPU and
// once for the GPU XLA JIT device. These duplicate REGISTER_XLA_OP calls are
// what DummyDuplicateOp exists to exercise.
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT),
                DummyDuplicateOp);
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT),
                DummyDuplicateOp);
171
172 // Tests compilation and execution of an empty graph.
TEST_F(XlaCompilerTest,EmptyReturnValues)173 TEST_F(XlaCompilerTest, EmptyReturnValues) {
174 XlaCompiler compiler(DefaultOptions());
175
176 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
177 XlaCompiler::CompilationResult result;
178 TF_ASSERT_OK(compiler.CompileGraph(
179 XlaCompiler::CompileOptions(), "add", std::move(graph),
180 /*args=*/{}, /*user_aliases=*/{}, &result));
181
182 TF_ASSERT_OK(client_->Execute(*result.computation, {}).status());
183 }
184
185 // Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest,Simple)186 TEST_F(XlaCompilerTest, Simple) {
187 // Builds a graph that adds two Tensors.
188 Scope scope = Scope::NewRootScope().ExitOnError();
189 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
190 auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
191 auto c = ops::Add(scope.WithOpName("C"), a, b);
192 auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
193 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
194 TF_ASSERT_OK(scope.ToGraph(graph.get()));
195
196 // Builds a description of the arguments.
197 std::vector<XlaCompiler::Argument> args(2);
198 args[0].kind = XlaCompiler::Argument::kParameter;
199 args[0].type = DT_INT32;
200 args[0].shape = TensorShape({2});
201 args[1].kind = XlaCompiler::Argument::kParameter;
202 args[1].type = DT_INT32;
203 args[1].shape = TensorShape({2});
204
205 // Compiles the graph.
206 XlaCompiler compiler(DefaultOptions());
207
208 XlaCompiler::CompilationResult result;
209 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
210 std::move(graph), args,
211 /*user_aliases=*/{}, &result));
212
213 // Tests that the generated computation works.
214 xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
215 xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
216 std::unique_ptr<xla::GlobalData> param0_data =
217 client_->TransferToServer(param0_literal).ConsumeValueOrDie();
218 std::unique_ptr<xla::GlobalData> param1_data =
219 client_->TransferToServer(param1_literal).ConsumeValueOrDie();
220
221 std::unique_ptr<xla::GlobalData> actual =
222 client_
223 ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
224 .ConsumeValueOrDie();
225 xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
226
227 xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
228 xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
229 EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
230 }
231
// Tests compilation of a graph where the _Retval node is not necessarily last
// amongst the graph nodes in construction order, and always_return_tuple is
// false. Regression test for bug where the wrong value was returned.
TEST_F(XlaCompilerTest, OutOfOrderGraph) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
  // The _Retval node is not last in construction order. D returns A
  // unchanged; C = A + B is built afterwards and is never returned.
  // Do not reorder these statements: the construction order is the point of
  // the test.
  auto d = ops::_Retval(scope.WithOpName("D"), a, 0);
  auto c = ops::Add(scope.WithOpName("C"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments: two int32[2] parameters.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});
  args[1].kind = XlaCompiler::Argument::kParameter;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2});

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompileOptions compile_options;
  compile_options.always_return_tuple = false;
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, /*user_aliases=*/{}, &result));

  // Tests that the generated computation works.
  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> param1_data =
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
      client_
          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

  // With always_return_tuple == false the single result is not tuple-wrapped,
  // so it is compared directly against input A.
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
}
280
281 // Tests that the compiler can correctly propagate the layout assigned by
282 // shape_representation_fn_ to return types.
TEST_F(XlaCompilerTest,HonorShapeRepresentationFnForRetVal)283 TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
284 Scope scope = Scope::NewRootScope().ExitOnError();
285 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
286 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
287 // Adds an identity op around the resource to make sure identity ops propagate
288 // resources correctly.
289 auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
290 auto write = ops::AssignAddVariableOp(scope, identity, a);
291 auto read = ops::ReadVariableOp(
292 scope.WithControlDependencies(std::vector<Operation>{write}), var,
293 DT_INT32);
294 auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
295 auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
296 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
297 TF_ASSERT_OK(scope.ToGraph(graph.get()));
298
299 // Builds a description of the arguments.
300 std::vector<XlaCompiler::Argument> args(2);
301 args[0].kind = XlaCompiler::Argument::kParameter;
302 args[0].type = DT_INT32;
303 args[0].shape = TensorShape({2, 3});
304 args[1].kind = XlaCompiler::Argument::kResource;
305 args[1].resource_kind = XlaResource::kVariable;
306 args[1].initialized = true;
307 args[1].type = DT_INT32;
308 args[1].shape = TensorShape({2, 3});
309
310 auto options = DefaultOptions();
311 options.shape_representation_fn =
312 [](const TensorShape& shape, DataType dt) -> xla::StatusOr<xla::Shape> {
313 xla::Shape xla_shape;
314 TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
315 *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
316 return xla_shape;
317 };
318 // Compiles the graph.
319 XlaCompiler compiler(options);
320
321 XlaCompiler::CompilationResult result;
322 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
323 std::move(graph), args,
324 /*user_aliases=*/{}, &result));
325 xla::Shape transposed =
326 xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
327 // Check that the return shapes are correctly tranposed.
328 EXPECT_EQ(result.xla_output_shape,
329 xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
330 }
331
332 // The layout of resource variable shouldn't change after transpose
TEST_F(XlaCompilerTest,TransposeVariables)333 TEST_F(XlaCompilerTest, TransposeVariables) {
334 Scope scope = Scope::NewRootScope().ExitOnError();
335 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
336 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
337 // Adds an identity op around the resource to make sure identity ops propagate
338 // resources correctly.
339 auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
340 auto write = ops::AssignAddVariableOp(scope, identity, a);
341 auto read = ops::ReadVariableOp(
342 scope.WithControlDependencies(std::vector<Operation>{write}), var,
343 DT_INT32);
344 auto transposed_read = ops::Transpose(scope, read, {1, 0});
345 auto reshape = ops::Reshape(scope, transposed_read, {2, 3});
346 auto d = ops::_Retval(scope.WithOpName("D"), reshape, 0);
347 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
348 TF_ASSERT_OK(scope.ToGraph(graph.get()));
349
350 // Builds a description of the arguments.
351 std::vector<XlaCompiler::Argument> args(2);
352 args[0].kind = XlaCompiler::Argument::kParameter;
353 args[0].type = DT_INT32;
354 args[0].shape = TensorShape({2, 3});
355 args[1].kind = XlaCompiler::Argument::kResource;
356 args[1].resource_kind = XlaResource::kVariable;
357 args[1].initialized = true;
358 args[1].type = DT_INT32;
359 args[1].shape = TensorShape({2, 3});
360 // Compiles the graph.
361 XlaCompiler compiler(DefaultOptions());
362
363 XlaCompiler::CompilationResult result;
364 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
365 std::move(graph), args,
366 /*user_aliases=*/{}, &result));
367 xla::Shape transposed =
368 xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
369 // Check that the return shapes are correctly tranposed.
370 EXPECT_EQ(result.xla_output_shape,
371 xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
372 }
373
374 // Tests that the compiler doesn't reorder the parameters.
TEST_F(XlaCompilerTest,MixedOrderArguments)375 TEST_F(XlaCompilerTest, MixedOrderArguments) {
376 for (bool swap_order : {false, true}) {
377 Scope scope = Scope::NewRootScope().ExitOnError();
378 auto var =
379 ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, swap_order ? 0 : 1);
380 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, swap_order ? 1 : 0);
381 // Adds an identity op around the resource to make sure identity ops
382 // propagate resources correctly.
383 auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
384 auto write = ops::AssignAddVariableOp(scope, identity, a);
385 auto read = ops::ReadVariableOp(
386 scope.WithControlDependencies(std::vector<Operation>{write}), var,
387 DT_INT32);
388 auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
389 auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
390 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
391 TF_ASSERT_OK(scope.ToGraph(graph.get()));
392
393 // Builds a description of the arguments.
394 std::vector<XlaCompiler::Argument> args(2);
395 args[0].kind = XlaCompiler::Argument::kParameter;
396 args[0].type = DT_INT32;
397 args[0].shape = TensorShape({2});
398 args[1].kind = XlaCompiler::Argument::kResource;
399 args[1].resource_kind = XlaResource::kVariable;
400 args[1].initialized = true;
401 args[1].type = DT_INT32;
402 args[1].shape = TensorShape({2});
403
404 if (swap_order) {
405 // Even after swapping arguments, the compiler should maintain the new
406 // ordering of parameters.
407 std::swap(args[0], args[1]);
408 }
409 // Compiles the graph.
410 XlaCompiler compiler(DefaultOptions());
411
412 XlaCompiler::CompileOptions compile_options;
413 compile_options.always_return_tuple = false;
414 XlaCompiler::CompilationResult result;
415 TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
416 args, /*user_aliases=*/{}, &result));
417
418 EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1));
419 }
420 }
421
// Checks the error produced when Reshape's target shape is a runtime
// parameter instead of a compile-time constant. Compilation is expected to
// fail with a message that explains the problem and names the offending node.
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
  // Graph: D = Reshape(A, B), where B is a graph argument.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
  auto c = ops::Reshape(scope.WithOpName("C"), a, b);
  auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Both arguments are int32[2] parameters.
  std::vector<XlaCompiler::Argument> args(2);
  for (XlaCompiler::Argument& arg : args) {
    arg.kind = XlaCompiler::Argument::kParameter;
    arg.type = DT_INT32;
    arg.shape = TensorShape({2});
  }

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  const Status status = compiler.CompileGraph(XlaCompiler::CompileOptions(),
                                              "reshape", std::move(graph),
                                              args, /*user_aliases=*/{},
                                              &result);
  EXPECT_FALSE(status.ok());

  const string& message = status.error_message();
  for (const char* fragment :
       {"depends on a parameter", "{{node C}}",
        "must be a compile-time constant"}) {
    EXPECT_TRUE(absl::StrContains(message, fragment)) << message;
  }
}
459
460 // Tests handling of compile-time constant outputs.
TEST_F(XlaCompilerTest,ConstantOutputs)461 TEST_F(XlaCompilerTest, ConstantOutputs) {
462 // Builds a graph with one compile-time constant output and one data-dependent
463 // output, i.e.,
464 // func(a) { b=7; c=-a; return b, c; }
465 Scope scope = Scope::NewRootScope().ExitOnError();
466 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
467 auto b = ops::Const<int32>(scope.WithOpName("B"), 7);
468 auto c = ops::Neg(scope.WithOpName("C"), a);
469 auto d = ops::_Retval(scope.WithOpName("D"), b, 0);
470 auto e = ops::_Retval(scope.WithOpName("E"), c, 1);
471 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
472 TF_ASSERT_OK(scope.ToGraph(graph.get()));
473
474 // Builds a description of the arguments.
475 std::vector<XlaCompiler::Argument> args(1);
476 args[0].kind = XlaCompiler::Argument::kParameter;
477 args[0].type = DT_INT32;
478 args[0].shape = TensorShape({2});
479
480 XlaCompiler::Options options = DefaultOptions();
481 XlaCompiler compiler(options);
482 {
483 // Compiles the graph, with resolve_compile_time_constants enabled.
484
485 std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
486 CopyGraph(*graph, graph_copy.get());
487
488 XlaCompiler::CompileOptions compile_options;
489 compile_options.resolve_compile_time_constants = true;
490 XlaCompiler::CompilationResult result;
491 TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
492 std::move(graph_copy), args,
493 /*user_aliases=*/{}, &result));
494
495 ASSERT_EQ(2, result.outputs.size());
496 EXPECT_TRUE(result.outputs[0].is_constant);
497 test::ExpectTensorEqual<int32>(result.outputs[0].constant_value,
498 test::AsScalar(7));
499 EXPECT_FALSE(result.outputs[1].is_constant);
500
501 // Tests that the generated computation works.
502 xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
503 std::unique_ptr<xla::GlobalData> param0_data =
504 client_->TransferToServer(param0_literal).ConsumeValueOrDie();
505
506 std::unique_ptr<xla::GlobalData> actual =
507 client_->Execute(*result.computation, {param0_data.get()})
508 .ConsumeValueOrDie();
509 xla::Literal actual_literal =
510 client_->Transfer(*actual).ConsumeValueOrDie();
511
512 xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
513 xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
514 EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
515 }
516
517 {
518 // Compiles the graph, with resolve_compile_time_constants disabled.
519 std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
520 CopyGraph(*graph, graph_copy.get());
521
522 XlaCompiler::CompileOptions compile_options;
523 compile_options.resolve_compile_time_constants = false;
524 XlaCompiler::CompilationResult result;
525 TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
526 std::move(graph_copy), args,
527 /*user_aliases=*/{}, &result));
528
529 ASSERT_EQ(2, result.outputs.size());
530 EXPECT_FALSE(result.outputs[0].is_constant);
531 EXPECT_FALSE(result.outputs[1].is_constant);
532
533 // Tests that the generated computation works.
534 xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
535 std::unique_ptr<xla::GlobalData> param0_data =
536 client_->TransferToServer(param0_literal).ConsumeValueOrDie();
537
538 std::unique_ptr<xla::GlobalData> actual =
539 client_->Execute(*result.computation, {param0_data.get()})
540 .ConsumeValueOrDie();
541 xla::Literal actual_literal =
542 client_->Transfer(*actual).ConsumeValueOrDie();
543
544 xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
545 xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
546 xla::Literal expected =
547 xla::LiteralUtil::MakeTuple({&expected0, &expected1});
548 EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
549 }
550 }
551
// Checks that a compile-time constant output of a (non-inlined) function call
// node is still detected when resolve_compile_time_constants is enabled.
TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) {
  // Define a function with one compile-time constant output and one
  // data-dependent output.
  // @function.Defun(noinline=True)
  // foo(a) {b=7; return b, a; }
  const Tensor seven = test::AsScalar<int>(7);
  FunctionDef fdef = FunctionDefHelper::Create(
      "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {},
      {
          {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}},
      },
      {{"a", "a_0"}, {"const", "Const:output:0"}});
  (*fdef.mutable_attr())["_noinline"].set_b(true);
  FunctionDefLibrary fdef_lib;
  *(fdef_lib.add_function()) = fdef;
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    Scope scope = Scope::NewRootScope().ExitOnError();
    // Everything below depends on "foo" being known to the graph, so abort
    // the test instead of continuing if the library cannot be added.
    TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
    auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0);
    NodeDef foo;
    foo.set_name("foo");
    foo.set_op("foo");
    *foo.add_input() = "input_arg";
    Status status;
    scope.graph()->AddNode(foo, &status);
    TF_ASSERT_OK(status);

    // Adds an int32 _Retval node called `name` that returns `input` as
    // output number `index`.
    auto add_retval = [&scope](const string& name, const string& input,
                               int index) -> Status {
      NodeDef retval;
      retval.set_name(name);
      retval.set_op(FunctionLibraryDefinition::kRetOp);
      *retval.add_input() = input;
      (*retval.mutable_attr())["T"].set_type(DT_INT32);
      (*retval.mutable_attr())["index"].set_i(index);
      Status add_status;
      scope.graph()->AddNode(retval, &add_status);
      return add_status;
    };
    TF_ASSERT_OK(add_retval("retval_0", "foo", 0));
    TF_ASSERT_OK(add_retval("retval_1", "foo:1", 1));
    TF_ASSERT_OK(scope.ToGraph(graph.get()));
  }

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({1});

  XlaCompiler::Options options = DefaultOptions();
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
  options.flib_def = &flib_def;
  XlaCompiler compiler(options);

  XlaCompiler::CompileOptions compile_options;
  compile_options.resolve_compile_time_constants = true;
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // Output 0 (foo's constant) is resolved to 7; output 1 is data-dependent.
  ASSERT_EQ(2, result.outputs.size());
  EXPECT_TRUE(result.outputs[0].is_constant);
  test::ExpectTensorEqual<int32>(result.outputs[0].constant_value,
                                 test::AsScalar(7));
  EXPECT_FALSE(result.outputs[1].is_constant);
}
622
623 // Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest,ResourceManager)624 TEST_F(XlaCompilerTest, ResourceManager) {
625 // Builds a graph that calls the dummy resource Op.
626 Scope scope = Scope::NewRootScope().ExitOnError();
627 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
628 auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
629 auto c = ops::Add(scope.WithOpName("C"), b.output2_, b.output1_);
630 auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
631 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
632 TF_ASSERT_OK(scope.ToGraph(graph.get()));
633
634 // Builds a description of the argument.
635 std::vector<XlaCompiler::Argument> args(1);
636 args[0].kind = XlaCompiler::Argument::kParameter;
637 args[0].type = DT_INT32;
638 args[0].shape = TensorShape({2});
639
640 DummyResourceForTest* resource = new DummyResourceForTest();
641
642 // Compiles the graph.
643 auto options = DefaultOptions();
644 std::function<Status(ResourceMgr*)> populate_function =
645 [resource](ResourceMgr* rm) {
646 resource->Ref();
647 return rm->Create(rm->default_container(), "dummy", resource);
648 };
649 options.populate_resource_manager = &populate_function;
650 XlaCompiler compiler(options);
651
652 EXPECT_EQ(0, resource->Get());
653
654 XlaCompiler::CompilationResult result;
655 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
656 std::move(graph), args,
657 /*user_aliases=*/{}, &result));
658
659 EXPECT_EQ(1, resource->Get());
660
661 resource->Unref();
662 }
663
664 // Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest,DeterministicCompilation)665 TEST_F(XlaCompilerTest, DeterministicCompilation) {
666 // Builds a graph that contains a node with two output edges. The compiler
667 // should always traverse them in the same order.
668 const int64 test_count = 2;
669
670 std::vector<XlaCompiler::CompilationResult> results(test_count);
671
672 for (int64 i = 0; i < test_count; ++i) {
673 Scope scope = Scope::NewRootScope().ExitOnError();
674 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
675 auto b = ops::Neg(scope.WithOpName("B"), a);
676 auto c = ops::Neg(scope.WithOpName("C"), a);
677 auto d = ops::Add(scope.WithOpName("D"), b, c);
678 auto e = ops::_Retval(scope.WithOpName("E"), d, 0);
679 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
680 TF_ASSERT_OK(scope.ToGraph(graph.get()));
681
682 // Builds a description of the argument.
683 std::vector<XlaCompiler::Argument> args(1);
684 args[0].kind = XlaCompiler::Argument::kParameter;
685 args[0].type = DT_INT32;
686 args[0].shape = TensorShape({2});
687
688 // Compiles the graph.
689 auto options = DefaultOptions();
690 XlaCompiler compiler(options);
691
692 TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
693 std::move(graph), args,
694 /*user_aliases=*/{}, &results[i]));
695 }
696
697 for (int64 i = 1; i < test_count; ++i) {
698 const auto& m1 = results[i - 1].computation->proto();
699 const auto& m2 = results[i].computation->proto();
700 ASSERT_EQ(m1.computations_size(), m2.computations_size());
701 // Check if every hlo computation is the same.
702 for (int k = 0; k < m1.computations_size(); k++) {
703 const auto& c1 = m1.computations(k);
704 const auto& c2 = m2.computations(k);
705 ASSERT_EQ(c1.instructions_size(), c2.instructions_size());
706 for (int j = 0; j < c1.instructions_size(); j++) {
707 auto instr1 = c1.instructions(j);
708 auto instr2 = c2.instructions(j);
709 instr1.clear_name();
710 instr1.clear_id();
711 instr1.clear_operand_ids();
712 instr2.clear_name();
713 instr2.clear_id();
714 instr2.clear_operand_ids();
715 // The names of instructions were uniquified by the XlaBuilder and the
716 // unique ids may be different, the rest of the fields should be
717 // identical.
718 string str1, str2;
719 LOG(INFO) << "instr1 = " << instr1.DebugString();
720 LOG(INFO) << "instr2 = " << instr2.DebugString();
721 instr1.AppendPartialToString(&str1);
722 instr2.AppendPartialToString(&str2);
723 EXPECT_EQ(str1, str2);
724 }
725 }
726 }
727 }
728
729 // Tests a computation that receives a TensorArray resource as input and
730 // updates it.
TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
  // Graph: look up gradients "grad1" and "grad2" of the TensorArray argument,
  // write the value 1 at index 1 of "grad1", then read index 1 of the base
  // array and return it.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto ta_handle = ops::_Arg(root.WithOpName("arg"), DT_RESOURCE, 0);
  auto initial_flow = ops::Const<float>(root, {});
  auto first_grad =
      ops::TensorArrayGrad(root, ta_handle, initial_flow, "grad1");
  auto second_grad =
      ops::TensorArrayGrad(root, ta_handle, first_grad.flow_out, "grad2");
  auto one = ops::Const<int32>(root, 1);
  auto write_op = ops::TensorArrayWrite(root, first_grad.grad_handle, one, one,
                                        second_grad.flow_out);
  auto read_op =
      ops::TensorArrayRead(root, ta_handle, one, write_op.flow_out, DT_INT32);
  auto ret = ops::_Retval(root.WithOpName("retval"), read_op, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // One int32 TensorArray of size 2 that arrives with a "grad2" gradient.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& ta_arg = args[0];
  ta_arg.kind = XlaCompiler::Argument::kResource;
  ta_arg.resource_kind = XlaResource::kTensorArray;
  ta_arg.initialized = true;
  ta_arg.type = DT_INT32;
  ta_arg.shape = TensorShape({});
  ta_arg.max_array_size = 2;
  ta_arg.tensor_array_gradients = {"grad2"};

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // The TensorArray is reported as the single resource update, with both
  // gradients recorded as accessed.
  ASSERT_EQ(1, result.resource_updates.size());
  const XlaCompiler::ResourceUpdate& ta_update = result.resource_updates[0];
  EXPECT_EQ(0, ta_update.input_index);
  EXPECT_EQ(DT_INT32, ta_update.type);
  EXPECT_EQ((std::set<string>{"grad1", "grad2"}),
            ta_update.tensor_array_gradients_accessed);

  // Runs the computation; its parameter is the tuple (base array, "grad2").
  xla::Literal base_in = xla::LiteralUtil::CreateR1<int32>({7, 42});
  xla::Literal grad2_in = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  xla::Literal packed_input =
      xla::LiteralUtil::MakeTuple({&base_in, &grad2_in});
  std::unique_ptr<xla::GlobalData> input_data =
      client_->TransferToServer(packed_input).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> output_data =
      client_->Execute(*result.computation, {input_data.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal =
      client_->Transfer(*output_data).ConsumeValueOrDie();

  // Expected: (value read, (base, grad1, grad2)), where grad1 holds the 1
  // written at index 1 and 0 elsewhere.
  xla::Literal read_out = xla::LiteralUtil::CreateR0<int32>(42);
  xla::Literal base_out = xla::LiteralUtil::CreateR1<int32>({7, 42});
  xla::Literal grad1_out = xla::LiteralUtil::CreateR1<int32>({0, 1});
  xla::Literal grad2_out = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  xla::Literal resource_out =
      xla::LiteralUtil::MakeTuple({&base_out, &grad1_out, &grad2_out});
  xla::Literal expected_literal =
      xla::LiteralUtil::MakeTuple({&read_out, &resource_out});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
792
// Tests that a TensorArray gradient which already exists on input and is only
// looked up (never written) does not make the TensorArray a resource update.
TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
  // Looks up the pre-existing "grad1" gradient of the TensorArray argument and
  // reads element 1 of the base array; nothing is written.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto ta_handle = ops::_Arg(root.WithOpName("arg"), DT_RESOURCE, 0);
  auto initial_flow = ops::Const<float>(root, {});
  auto grad = ops::TensorArrayGrad(root, ta_handle, initial_flow, "grad1");
  auto one = ops::Const<int32>(root, 1);
  auto element =
      ops::TensorArrayRead(root, ta_handle, one, grad.flow_out, DT_INT32);
  auto ret = ops::_Retval(root.WithOpName("retval"), element, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // One int32 TensorArray of size 2 that already carries "grad1".
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& ta_arg = args[0];
  ta_arg.kind = XlaCompiler::Argument::kResource;
  ta_arg.resource_kind = XlaResource::kTensorArray;
  ta_arg.initialized = true;
  ta_arg.type = DT_INT32;
  ta_arg.shape = TensorShape({});
  ta_arg.max_array_size = 2;
  ta_arg.tensor_array_gradients = {"grad1"};

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // Nothing changed, so no resource is returned as an update.
  EXPECT_EQ(0, result.resource_updates.size());
}
825
// Tests that looking up a TensorArray gradient that was not present on input
// makes the TensorArray a resource update (computation output).
TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto flow = ops::Const<float>(scope, {});
  // "grad2" is not among the argument's pre-existing gradients (only "grad1"
  // is, see below), so this lookup introduces a new gradient.
  auto grad2 = ops::TensorArrayGrad(scope, arg, flow, "grad2");
  auto index = ops::Const<int32>(scope, 1);
  auto read = ops::TensorArrayRead(scope, arg, index, grad2.flow_out, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments: one int32 TensorArray of size 2
  // that arrives with a single gradient, "grad1".
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({});
  args[0].max_array_size = 2;
  args[0].tensor_array_gradients = {"grad1"};

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // The TensorArray gained a gradient, so it must be returned as an update.
  // ASSERT (fatal) on the count so the element access below is always valid.
  ASSERT_EQ(1, result.resource_updates.size());
  const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
  EXPECT_EQ(0, update.input_index);
  EXPECT_EQ(DT_INT32, update.type);
  EXPECT_EQ(1, update.tensor_array_gradients_accessed.count("grad2"));
}
858
859 // Tests CompileFunction with undefined function fails.
TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  NameAttrList name_attr;
  name_attr.set_name("Function_NotDefined_");
  Status status =
      compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
                               /*args=*/{}, &result);
  // Fatal: the error-message check below only makes sense for a failure.
  ASSERT_FALSE(status.ok());
  EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
      << status.error_message();
}
874
FillFn()875 FunctionDef FillFn() {
876 return FunctionDefHelper::Define(
877 // Name
878 "FillFn",
879 // Args
880 {"x: T", "dims: int32"},
881 // Return values
882 {"y: T"},
883 // Attr def
884 {"T: {float, double, int32, int64}"},
885 // Nodes
886 {{{"y"}, "Fill", {"dims", "x"}, {{"T", "$T"}}}});
887 }
888
TEST_F(XlaCompilerTest, FunctionCallWithConstants) {
  // Certain operations in a function, "Fill" for example, require the
  // operator's argument to be a compile-time constant instead of a parameter.
  // This test checks that XlaCompiler can handle such operators inside
  // function calls.
  XlaCompiler compiler(DefaultOptions());

  FunctionDefLibrary flib;
  *flib.add_function() = FillFn();

  TF_ASSERT_OK(flib_def_->AddFunctionDef(FillFn()));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Scope scope = Scope::NewRootScope().ExitOnError();
  auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
  auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
  // Fatal, like the rest of the graph setup: the "FillFn" call node added
  // below depends on this library being present.
  TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));

  NodeDef def;
  TF_ASSERT_OK(NodeDefBuilder("fill", "FillFn", flib_def_.get())
                   .Input(value.name(), 0, DT_INT32)
                   .Input(shape.name(), 1, DT_INT32)
                   .Finalize(&def));
  Status status;
  Node* fill = scope.graph()->AddNode(def, &status);
  TF_ASSERT_OK(status);
  TF_ASSERT_OK(scope.DoShapeInference(fill));
  scope.graph()->AddEdge(value.node(), 0, fill, 0);
  scope.graph()->AddEdge(shape.node(), 0, fill, 1);

  auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);

  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // The computation takes no arguments: both Fill inputs are constants.
  std::vector<XlaCompiler::Argument> args;

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // The constant dims {5} must have been used at compile time, giving the
  // single int32 output the static shape [5] (wrapped in the default result
  // tuple). Compatible() ignores layout.
  EXPECT_TRUE(xla::ShapeUtil::Compatible(
      result.xla_output_shape,
      xla::ShapeUtil::MakeTupleShape(
          {xla::ShapeUtil::MakeShape(xla::S32, {5})})));
}
932
933 // Tests CompileFunction with a local function lookup failing, fails with
934 // informative error about both lookups.
TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
  XlaCompiler compiler(DefaultOptions());

  // XTimesTwo exists only in the compiler's local library, and it is compiled
  // below without the attrs it needs.
  auto local_flib_def = LocalFlibDef(&compiler);
  TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo()));

  XlaCompiler::CompilationResult result;
  NameAttrList name_attr;
  name_attr.set_name("XTimesTwo");
  Status status =
      compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
                               /*args=*/{}, &result);

  ASSERT_FALSE(status.ok());
  // Flib lookup failure.
  EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
      << status.error_message();
  // Local flib lookup failure.
  EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found"))
      << status.error_message();
}
957
RunAndCheckVariablesComputation(xla::Client * client,const XlaCompiler::CompilationResult & result)958 void RunAndCheckVariablesComputation(
959 xla::Client* client, const XlaCompiler::CompilationResult& result) {
960 xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
961 xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
962 std::unique_ptr<xla::GlobalData> param0_data =
963 client->TransferToServer(param0_literal).ConsumeValueOrDie();
964 std::unique_ptr<xla::GlobalData> param1_data =
965 client->TransferToServer(param1_literal).ConsumeValueOrDie();
966
967 std::unique_ptr<xla::GlobalData> actual =
968 client
969 ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
970 .ConsumeValueOrDie();
971 xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie();
972
973 xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
974 xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
975 xla::Literal expected_literal =
976 xla::LiteralUtil::MakeTuple({&expected0, &expected1});
977 EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
978 }
979
980 // Tests a simple graph that reads and writes a variable.
TEST_F(XlaCompilerTest, Variables) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto increment = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto variable = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 1);
  // Routes the resource through an Identity op to make sure identity ops
  // propagate resources correctly.
  auto variable_alias = ops::Identity(root.WithOpName("VIdentity"), variable);
  auto assign_add = ops::AssignAddVariableOp(root, variable_alias, increment);
  // The read is ordered after the update by a control dependency.
  auto value = ops::ReadVariableOp(
      root.WithControlDependencies(
          std::vector<Operation>{assign_add.operation}),
      variable, DT_INT32);
  auto value_plus_one = ops::Add(root, value, ops::Const<int32>(root, 1));
  auto ret = ops::_Retval(root.WithOpName("D"), value_plus_one, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Argument 0: int32[2] parameter. Argument 1: initialized int32[2] variable.
  std::vector<XlaCompiler::Argument> args(2);
  XlaCompiler::Argument& increment_arg = args[0];
  increment_arg.kind = XlaCompiler::Argument::kParameter;
  increment_arg.type = DT_INT32;
  increment_arg.shape = TensorShape({2});
  XlaCompiler::Argument& variable_arg = args[1];
  variable_arg.kind = XlaCompiler::Argument::kResource;
  variable_arg.resource_kind = XlaResource::kVariable;
  variable_arg.initialized = true;
  variable_arg.type = DT_INT32;
  variable_arg.shape = TensorShape({2});

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));
  RunAndCheckVariablesComputation(client_, result);
}
1017
TEST_F(XlaCompilerTest, ResultLayoutSingle) {
  // Identity graph: retval "RET" is argument "A".
  Scope root = Scope::NewRootScope().ExitOnError();
  auto input = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto output = ops::_Retval(root.WithOpName("RET"), input, 0);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& input_arg = args.front();
  input_arg.kind = XlaCompiler::Argument::kParameter;
  input_arg.type = DT_INT32;
  input_arg.shape = TensorShape({2, 3});

  // Keeps each tensor's shape but forces the non-default minor-to-major
  // layout {0, 1}.
  XlaCompiler::Options options = DefaultOptions();
  options.shape_representation_fn =
      [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
    xla::Shape representation;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &representation));
    *representation.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    return representation;
  };
  XlaCompiler compiler(options);

  // Requests the bare result rather than a one-element tuple.
  XlaCompiler::CompileOptions compile_options;
  compile_options.always_return_tuple = false;

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph),
                                     args, /*user_aliases=*/{}, &result));
  const xla::Shape expected_shape =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
  EXPECT_TRUE(xla::ShapeUtil::Equal(result.xla_output_shape, expected_shape));
}
1054
TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
  // Argument "A" is returned twice, as retvals "RET1" and "RET2".
  Scope root = Scope::NewRootScope().ExitOnError();
  auto input = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto first_out = ops::_Retval(root.WithOpName("RET1"), input, 0);
  auto second_out = ops::_Retval(root.WithOpName("RET2"), input, 1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& input_arg = args.front();
  input_arg.kind = XlaCompiler::Argument::kParameter;
  input_arg.type = DT_INT32;
  input_arg.shape = TensorShape({2, 3});

  // Keeps each tensor's shape but forces the non-default minor-to-major
  // layout {0, 1}.
  XlaCompiler::Options options = DefaultOptions();
  options.shape_representation_fn =
      [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
    xla::Shape representation;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &representation));
    *representation.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    return representation;
  };
  XlaCompiler compiler(options);

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // Both tuple elements carry the forced layout.
  const xla::Shape element_shape =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
  const xla::Shape expected_shape =
      xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
  EXPECT_TRUE(xla::ShapeUtil::Equal(result.xla_output_shape, expected_shape));
}
1094
// Tests a graph whose only retval is its resource (variable) handle argument;
// the handle is not an XLA output, so the computation returns an empty tuple.
TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
  // The graph simply forwards its resource argument to the only retval.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto resource = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 0);
  auto ret = ops::_Retval(root.WithOpName("D"), resource, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // One initialized int32[2] variable.
  std::vector<XlaCompiler::Argument> args(1);
  XlaCompiler::Argument& variable_arg = args[0];
  variable_arg.kind = XlaCompiler::Argument::kResource;
  variable_arg.resource_kind = XlaResource::kVariable;
  variable_arg.initialized = true;
  variable_arg.type = DT_INT32;
  variable_arg.shape = TensorShape({2});

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // Runs the computation with the variable's value as its only parameter.
  xla::Literal variable_value = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(variable_value).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> output =
      client_->Execute(*result.computation, {param0_data.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal = client_->Transfer(*output).ConsumeValueOrDie();

  // Nothing is written and the handle is not an XLA output: empty tuple.
  xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
1132
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
  Scope root = Scope::NewRootScope().ExitOnError();
  auto increment = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
  auto variable = ops::_Arg(root.WithOpName("V"), DT_RESOURCE, 1);
  // Routes the resource through an Identity op to make sure identity ops
  // propagate resources correctly.
  auto variable_alias = ops::Identity(root.WithOpName("VIdentity"), variable);
  auto assign_add = ops::AssignAddVariableOp(root, variable_alias, increment);
  // The read is ordered after the update by a control dependency.
  auto value = ops::ReadVariableOp(
      root.WithControlDependencies(
          std::vector<Operation>{assign_add.operation}),
      variable, DT_INT32);
  auto value_plus_one = ops::Add(root, value, ops::Const<int32>(root, 1));
  // Retval 0 is the resource handle itself; retval 1 is the computed value.
  auto handle_out = ops::_Retval(root.WithOpName("R"), variable, 0);
  auto value_out = ops::_Retval(root.WithOpName("D"), value_plus_one, 1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Argument 0: int32[2] parameter. Argument 1: initialized int32[2] variable.
  std::vector<XlaCompiler::Argument> args(2);
  XlaCompiler::Argument& increment_arg = args[0];
  increment_arg.kind = XlaCompiler::Argument::kParameter;
  increment_arg.type = DT_INT32;
  increment_arg.shape = TensorShape({2});
  XlaCompiler::Argument& variable_arg = args[1];
  variable_arg.kind = XlaCompiler::Argument::kResource;
  variable_arg.resource_kind = XlaResource::kVariable;
  variable_arg.initialized = true;
  variable_arg.type = DT_INT32;
  variable_arg.shape = TensorShape({2});

  XlaCompiler compiler(DefaultOptions());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));
  RunAndCheckVariablesComputation(client_, result);
}
1171
BuildTestGraph()1172 xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
1173 Scope scope = Scope::NewRootScope().ExitOnError();
1174 auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
1175 auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
1176 auto write = ops::AssignAddVariableOp(scope, var, a);
1177 auto read = ops::ReadVariableOp(
1178 scope.WithControlDependencies(std::vector<Operation>{write}), var,
1179 DT_INT32);
1180 auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
1181 auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
1182 std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1183 TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
1184 return std::move(graph);
1185 }
1186
1187 // Tests a simple graph that reads and writes a variable, with a
1188 // shape_representation_fn passed to the compiler that flattens all
1189 // variable tensors to vectors.
TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());

  // Both arguments are int32[2,2]: argument 0 is a plain parameter and
  // argument 1 an initialized variable.
  std::vector<XlaCompiler::Argument> args(2);
  for (XlaCompiler::Argument& arg : args) {
    arg.type = DT_INT32;
    arg.shape = TensorShape({2, 2});
  }
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[1].kind = XlaCompiler::Argument::kResource;
  args[1].resource_kind = XlaResource::kVariable;
  args[1].initialized = true;

  // Represents every tensor as a rank-1 array with the same element count.
  auto flatten_fn = [](const TensorShape& shape,
                       DataType type) -> xla::StatusOr<xla::Shape> {
    xla::PrimitiveType element_type;
    TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &element_type));
    return xla::ShapeUtil::MakeShape(element_type, {shape.num_elements()});
  };
  XlaCompiler::Options options = DefaultOptions();
  options.shape_representation_fn = flatten_fn;
  XlaCompiler compiler(options);

  // As a non-entry computation, only variables are reshaped.
  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = false;

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, /*user_aliases=*/{}, &result));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
                          client_->GetComputationShape(*result.computation));

  const xla::Shape matrix_shape = xla::ShapeUtil::MakeShape(xla::S32, {2, 2});
  const xla::Shape flat_shape = xla::ShapeUtil::MakeShape(xla::S32, {4});
  ASSERT_EQ(program_shape->parameters_size(), 2);
  EXPECT_TRUE(
      xla::ShapeUtil::Compatible(program_shape->parameters(0), matrix_shape));
  EXPECT_TRUE(
      xla::ShapeUtil::Compatible(program_shape->parameters(1), flat_shape));
  EXPECT_TRUE(xla::ShapeUtil::Compatible(
      program_shape->result(),
      xla::ShapeUtil::MakeTupleShape({matrix_shape, flat_shape})));

  // Runs the computation: the parameter keeps its 2x2 form while the variable
  // is passed flattened.
  xla::Literal increment =
      xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
  xla::Literal initial_value =
      xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
  std::unique_ptr<xla::GlobalData> increment_data =
      client_->TransferToServer(increment).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> initial_value_data =
      client_->TransferToServer(initial_value).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> output =
      client_
          ->Execute(*result.computation,
                    {increment_data.get(), initial_value_data.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal = client_->Transfer(*output).ConsumeValueOrDie();

  xla::Literal expected_retval =
      xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
  xla::Literal expected_variable =
      xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
  xla::Literal expected_literal =
      xla::LiteralUtil::MakeTuple({&expected_retval, &expected_variable});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
1259
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());

  // Both arguments are int32[2,2]: argument 0 is a plain parameter and
  // argument 1 an initialized variable.
  std::vector<XlaCompiler::Argument> args(2);
  for (XlaCompiler::Argument& arg : args) {
    arg.type = DT_INT32;
    arg.shape = TensorShape({2, 2});
  }
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[1].kind = XlaCompiler::Argument::kResource;
  args[1].resource_kind = XlaResource::kVariable;
  args[1].initialized = true;

  // Represents every tensor as a rank-1 array with the same element count.
  auto flatten_fn = [](const TensorShape& shape,
                       DataType type) -> xla::StatusOr<xla::Shape> {
    xla::PrimitiveType element_type;
    TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &element_type));
    return xla::ShapeUtil::MakeShape(element_type, {shape.num_elements()});
  };
  XlaCompiler::Options options = DefaultOptions();
  options.shape_representation_fn = flatten_fn;
  XlaCompiler compiler(options);

  // As an entry computation, args and retvals are reshaped too.
  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, /*user_aliases=*/{}, &result));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
                          client_->GetComputationShape(*result.computation));

  const xla::Shape flat_shape = xla::ShapeUtil::MakeShape(xla::S32, {4});
  ASSERT_EQ(program_shape->parameters_size(), 2);
  EXPECT_TRUE(
      xla::ShapeUtil::Compatible(program_shape->parameters(0), flat_shape));
  EXPECT_TRUE(
      xla::ShapeUtil::Compatible(program_shape->parameters(1), flat_shape));
  EXPECT_TRUE(xla::ShapeUtil::Compatible(
      program_shape->result(),
      xla::ShapeUtil::MakeTupleShape({flat_shape, flat_shape})));

  // Runs the computation with both inputs passed flattened.
  xla::Literal increment = xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
  xla::Literal initial_value =
      xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
  std::unique_ptr<xla::GlobalData> increment_data =
      client_->TransferToServer(increment).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> initial_value_data =
      client_->TransferToServer(initial_value).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> output =
      client_
          ->Execute(*result.computation,
                    {increment_data.get(), initial_value_data.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal = client_->Transfer(*output).ConsumeValueOrDie();

  xla::Literal expected_retval =
      xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
  xla::Literal expected_variable =
      xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
  xla::Literal expected_literal =
      xla::LiteralUtil::MakeTuple({&expected_retval, &expected_variable});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
1327
1328 // Tests a graph which has a function with an invalid op.
TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
  XlaCompiler compiler(DefaultOptions());

  // FillFn plus two extra nodes: an unsupported op and a control-flow Switch.
  FunctionDef fn = FillFn();
  auto append_node = [&fn](const string& node_name, const string& op_name) {
    NodeDef* added = fn.add_node_def();
    added->set_name(node_name);
    added->set_op(op_name);
  };
  append_node("Invalid", "InvalidOp");  // unsupported op
  append_node("Switch", "Switch");      // control flow node

  FunctionDefLibrary flib;
  *flib.add_function() = fn;
  TF_ASSERT_OK(flib_def_->AddFunctionDef(fn));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  // Main graph: a single "fill_fn" call node fed by two constants.
  Scope root = Scope::NewRootScope().ExitOnError();
  auto value = ops::Const<int32>(root.WithOpName("value"), 1, {});
  auto shape = ops::Const<int32>(root.WithOpName("shape"), {5}, {1});
  TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib));

  NodeDef call_def;
  TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get())
                   .Input(value.name(), 0, DT_INT32)
                   .Input(shape.name(), 1, DT_INT32)
                   .Finalize(&call_def));
  Status status;
  Node* call_node = root.graph()->AddNode(call_def, &status);
  TF_ASSERT_OK(status);
  TF_ASSERT_OK(root.DoShapeInference(call_node));
  root.graph()->AddEdge(value.node(), 0, call_node, 0);
  root.graph()->AddEdge(shape.node(), 0, call_node, 1);

  auto retval = ops::_Retval(root.WithOpName("retval"), Output(call_node), 0);
  TF_ASSERT_OK(root.ToGraph(graph.get()));

  // Compilation must fail, naming both the bad op and the calling node.
  std::vector<XlaCompiler::Argument> args;
  XlaCompiler::CompilationResult result;
  status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
                                 std::move(graph), args, /*user_aliases=*/{},
                                 &result);
  ASSERT_FALSE(status.ok());
  const string message = status.error_message();
  EXPECT_TRUE(absl::StrContains(message, "InvalidOp")) << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node fill_fn}}")) << message;
}
1378
1379 // Tests a graph which has a node with invalid data type.
TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
  // A lone Shape node whose out_type is DT_BOOL, which Shape does not allow.
  NodeDef shape_def;
  shape_def.set_name("Shape");
  shape_def.set_op("Shape");
  auto& attrs = *shape_def.mutable_attr();
  attrs["T"].set_type(DT_INT32);
  attrs["out_type"].set_type(DT_BOOL);  // invalid type

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Status status;
  Node* shape_node = graph->AddNode(shape_def, &status);
  TF_ASSERT_OK(status);
  graph->AddControlEdge(graph->source_node(), shape_node);

  XlaCompiler compiler(DefaultOptions());
  std::vector<XlaCompiler::Argument> args;
  XlaCompiler::CompilationResult result;
  status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
                                 std::move(graph), args, /*user_aliases=*/{},
                                 &result);

  // The error must explain the bad attr value and name the offending node.
  ASSERT_FALSE(status.ok());
  const string message = status.error_message();
  EXPECT_TRUE(
      absl::StrContains(message, "is not in the list of allowed values"))
      << message;
  EXPECT_TRUE(absl::StrContains(message, "{{node Shape}}")) << message;
}
1405
TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
  // A graph containing only a NoOp node.
  NodeDef no_op_def;
  no_op_def.set_name("NoOp");
  no_op_def.set_op("NoOp");
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  Status add_status;
  graph->AddNode(no_op_def, &add_status);
  TF_ASSERT_OK(add_status);

  std::vector<XlaCompiler::Argument> no_args;
  XlaCompiler compiler(DefaultOptions());
  {
    // The NoOp is left without control edges to the source/sink nodes;
    // compiling a copy of the graph must still succeed.
    std::unique_ptr<Graph> copied_graph(new Graph(OpRegistry::Global()));
    CopyGraph(*graph, copied_graph.get());
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
                                       std::move(copied_graph), no_args,
                                       /*user_aliases=*/{}, &result));
  }
}
1427
1428 class DummySideEffectingOp : public XlaOpKernel {
1429 public:
DummySideEffectingOp(OpKernelConstruction * ctx)1430 explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)1431 void Compile(XlaOpKernelContext* ctx) override {
1432 OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(
1433 name(), xla::CreateToken(ctx->builder())));
1434 }
1435 };
1436
1437 REGISTER_OP("DummySideEffectingOp");
1438
1439 REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
1440
TEST_F(XlaCompilerTest, TokenInputAndOutput) {
  // A single DummySideEffectingOp node whose token-input attr lists the token
  // argument node.
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  NodeDef side_effecting_op;
  side_effecting_op.set_name("DummySideEffectingOp");
  side_effecting_op.set_op("DummySideEffectingOp");
  AddNodeAttr(kXlaTokenInputNodesAttrName,
              std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
  Status status;
  graph->AddNode(side_effecting_op, &status);
  TF_ASSERT_OK(status);
  EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));

  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kVariable;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2, 2});

  {
    // The case for entry computation: we don't add token input/output. Instead,
    // we use CreateToken HLO to create the entry token.
    XlaCompiler::CompileOptions options;
    options.is_entry_computation = true;
    options.add_token_input_output = false;
    options.return_updated_values_for_all_resources = true;
    XlaCompiler compiler(DefaultOptions());

    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
    CopyGraph(*graph, graph_copy.get());
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
                                       args, /*user_aliases=*/{}, &result));
    EXPECT_EQ(1, result.xla_input_shapes.size());
    // Fatal: TupleElementCount below requires a tuple shape.
    ASSERT_TRUE(result.xla_output_shape.IsTuple());
    EXPECT_EQ(1, xla::ShapeUtil::TupleElementCount(result.xla_output_shape));
  }
  {
    // The case for non-entry computation (e.g. while loop body). We add token
    // input/output.
    XlaCompiler::CompileOptions options;
    options.is_entry_computation = false;
    options.add_token_input_output = true;
    options.return_updated_values_for_all_resources = true;
    XlaCompiler compiler(DefaultOptions());

    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
    CopyGraph(*graph, graph_copy.get());
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
                                       args, /*user_aliases=*/{}, &result));
    // Fatal checks on the counts: the element-1 accesses that follow would be
    // out of range (not merely wrong) if these did not hold.
    ASSERT_EQ(2, result.xla_input_shapes.size());
    EXPECT_TRUE(result.xla_input_shapes[1].IsToken());
    ASSERT_TRUE(result.xla_output_shape.IsTuple());
    ASSERT_EQ(2, xla::ShapeUtil::TupleElementCount(result.xla_output_shape));
    EXPECT_TRUE(
        xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1)
            .IsToken());
  }
}
1500
1501 } // namespace
1502 } // namespace tensorflow
1503