1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
17 #include "absl/strings/match.h"
18 #include "tensorflow/cc/framework/ops.h"
19 #include "tensorflow/cc/ops/data_flow_ops.h"
20 #include "tensorflow/cc/ops/function_ops.h"
21 #include "tensorflow/cc/ops/resource_variable_ops.h"
22 #include "tensorflow/cc/ops/standard_ops.h"
23 #include "tensorflow/compiler/tf2xla/shape_util.h"
24 #include "tensorflow/compiler/tf2xla/side_effect_util.h"
25 #include "tensorflow/compiler/tf2xla/type_util.h"
26 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
27 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
28 #include "tensorflow/compiler/xla/client/client_library.h"
29 #include "tensorflow/compiler/xla/client/local_client.h"
30 #include "tensorflow/compiler/xla/client/xla_builder.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/status_macros.h"
34 #include "tensorflow/compiler/xla/tests/literal_test_util.h"
35 #include "tensorflow/core/common_runtime/function.h"
36 #include "tensorflow/core/framework/common_shape_fns.h"
37 #include "tensorflow/core/framework/function.h"
38 #include "tensorflow/core/framework/function_testlib.h"
39 #include "tensorflow/core/framework/node_def_util.h"
40 #include "tensorflow/core/framework/resource_mgr.h"
41 #include "tensorflow/core/framework/tensor.h"
42 #include "tensorflow/core/framework/tensor_testutil.h"
43 #include "tensorflow/core/graph/algorithm.h"
44 #include "tensorflow/core/graph/graph.h"
45 #include "tensorflow/core/graph/graph_constructor.h"
46 #include "tensorflow/core/lib/core/status_test_util.h"
47 #include "tensorflow/core/platform/test.h"
48 #include "tensorflow/core/public/version.h"
49 
50 namespace tensorflow {
51 
// Test fixture shared by all XlaCompiler tests: provides a local XLA client,
// ensures the XLA compilation kernels are registered, and owns an empty
// function library that DefaultOptions() wires into the compiler options.
class XlaCompilerTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // Process-wide local client used to execute compiled computations below.
    client_ = xla::ClientLibrary::LocalClientOrDie();

    XlaOpRegistry::RegisterCompilationKernels();

    // Each test starts with a fresh, empty function library.
    FunctionDefLibrary flib;
    flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib));
  }

  // Returns compiler options targeting the XLA_CPU JIT device, pointing at
  // the client and function library created in SetUp().
  XlaCompiler::Options DefaultOptions() {
    XlaCompiler::Options options;
    options.device_type = DeviceType(DEVICE_CPU_XLA_JIT);
    options.client = client_;
    options.flib_def = flib_def_.get();
    return options;
  }

  // Test-only accessor for the compiler's internal local_flib_def_, letting
  // tests inspect functions the compiler created for itself.
  FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) {
    return compiler->local_flib_def_.get();
  }

  xla::Client* client_;
  std::unique_ptr<FunctionLibraryDefinition> flib_def_;
};
78 
79 namespace {
80 
81 // Helper class to test the ability to pass resources through to XLA
82 // compiled kernels.
83 class DummyResourceForTest : public ResourceBase {
84  public:
DebugString() const85   string DebugString() const override { return "dummy"; }
Increment()86   void Increment() { ++value_; }
Get()87   int Get() { return value_; }
88 
89  private:
90   int value_ = 0;
91 };
92 
// XLA kernel for the test-only DummyReadResource op. It increments the
// DummyResourceForTest registered under "dummy" in the step's resource
// manager (proving resources are reachable from compiled kernels) and then
// forwards the single input to both outputs.
class DummyReadResourceOp : public XlaOpKernel {
 public:
  explicit DummyReadResourceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  void Compile(XlaOpKernelContext* ctx) override {
    ResourceMgr* rm = ctx->op_kernel_context()->resource_manager();
    OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
    DummyResourceForTest* dummy;
    OP_REQUIRES_OK(ctx, rm->Lookup<DummyResourceForTest>(
                            rm->default_container(), "dummy", &dummy));
    dummy->Increment();
    // Lookup handed us a reference; release it once the side effect is done.
    dummy->Unref();

    // Pass the input through to both declared outputs.
    ctx->SetOutput(0, ctx->Input(0));
    ctx->SetOutput(1, ctx->Input(0));
  }
};
109 
110 class DummyReadResourceCC {
111  public:
DummyReadResourceCC(const Scope & scope,const Input & value)112   DummyReadResourceCC(const Scope& scope, const Input& value) {
113     if (!scope.ok()) return;
114     auto _value = ops::AsNodeOut(scope, value);
115     if (!scope.ok()) return;
116     Node* ret;
117     const auto unique_name = scope.GetUniqueNameForOp("DummyReadResource");
118     auto builder = NodeBuilder(unique_name, "DummyReadResource").Input(_value);
119     scope.UpdateBuilder(&builder);
120     scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
121     if (!scope.ok()) return;
122     scope.UpdateStatus(scope.DoShapeInference(ret));
123     if (!scope.ok()) return;
124     this->output1_ = Output(ret, 0);
125     this->output2_ = Output(ret, 1);
126   }
127 
128   Output output1_;
129   Output output2_;
130 };
131 
// Graph-level registration of the test-only DummyReadResource op: one int32
// input fanned out to two int32 outputs, with unknown output shapes.
REGISTER_OP("DummyReadResource")
    .Input("input: int32")
    .Output("output1: int32")
    .Output("output2: int32")
    .SetShapeFn(shape_inference::UnknownShape)
    .Doc(R"doc(
A dummy Op.

input: dummy input.
output1: dummy output.
output2: dummy output.
)doc");

// XLA kernel registration with no device constraint.
REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp);
146 
// DummyDuplicateOp is present purely to test multiple REGISTER_XLA_OP calls
// on the same Op name below.
class DummyDuplicateOp : public XlaOpKernel {
 public:
  explicit DummyDuplicateOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
  // Identity: forwards the single input to the single output.
  void Compile(XlaOpKernelContext* ctx) override {
    ctx->SetOutput(0, ctx->Input(0));
  }
};
156 
// Graph-level registration for DummyDuplicateOp: int32 identity signature.
REGISTER_OP("DummyDuplicateOp")
    .Input("input: int32")
    .Output("output: int32")
    .Doc(R"doc(
A dummy Op.

input: dummy input.
output: dummy output.
)doc");

// Two separate XLA kernel registrations for the same op name, one per JIT
// device type, exercising the duplicate-registration code path.
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT),
                DummyDuplicateOp);
REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT),
                DummyDuplicateOp);
171 
// Tests compilation and execution of an empty graph.
TEST_F(XlaCompilerTest, EmptyReturnValues) {
  XlaCompiler compiler(DefaultOptions());

  // An empty graph: no arguments and no return values.
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(
      XlaCompiler::CompileOptions(), "add", std::move(graph),
      /*args=*/{}, /*user_aliases=*/{}, &result));

  // Executing the resulting computation with no parameters should succeed.
  TF_ASSERT_OK(client_->Execute(*result.computation, {}).status());
}
184 
185 // Tests compilation and execution of a graph that adds two tensors.
TEST_F(XlaCompilerTest,Simple)186 TEST_F(XlaCompilerTest, Simple) {
187   // Builds a graph that adds two Tensors.
188   Scope scope = Scope::NewRootScope().ExitOnError();
189   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
190   auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
191   auto c = ops::Add(scope.WithOpName("C"), a, b);
192   auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
193   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
194   TF_ASSERT_OK(scope.ToGraph(graph.get()));
195 
196   // Builds a description of the arguments.
197   std::vector<XlaCompiler::Argument> args(2);
198   args[0].kind = XlaCompiler::Argument::kParameter;
199   args[0].type = DT_INT32;
200   args[0].shape = TensorShape({2});
201   args[1].kind = XlaCompiler::Argument::kParameter;
202   args[1].type = DT_INT32;
203   args[1].shape = TensorShape({2});
204 
205   // Compiles the graph.
206   XlaCompiler compiler(DefaultOptions());
207 
208   XlaCompiler::CompilationResult result;
209   TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
210                                      std::move(graph), args,
211                                      /*user_aliases=*/{}, &result));
212 
213   // Tests that the generated computation works.
214   xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
215   xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
216   std::unique_ptr<xla::GlobalData> param0_data =
217       client_->TransferToServer(param0_literal).ConsumeValueOrDie();
218   std::unique_ptr<xla::GlobalData> param1_data =
219       client_->TransferToServer(param1_literal).ConsumeValueOrDie();
220 
221   std::unique_ptr<xla::GlobalData> actual =
222       client_
223           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
224           .ConsumeValueOrDie();
225   xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
226 
227   xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({4, 143});
228   xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
229   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
230 }
231 
// Tests compilation of a graph where the _Retval node is not necessarily last
// amongst the graph nodes in construction order, and always_return_tuple is
// false. Regression test for bug where the wrong value was returned.
TEST_F(XlaCompilerTest, OutOfOrderGraph) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
  // The _Retval node is not last in construction order. The graph returns the
  // first argument unchanged; the Add's result is dead.
  auto d = ops::_Retval(scope.WithOpName("D"), a, 0);
  auto c = ops::Add(scope.WithOpName("C"), a, b);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments: two int32 vectors of length 2.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});
  args[1].kind = XlaCompiler::Argument::kParameter;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2});

  // Compiles the graph with an untupled (single-value) return.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompileOptions compile_options;
  compile_options.always_return_tuple = false;
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, /*user_aliases=*/{}, &result));

  // Tests that the generated computation works.
  xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> param1_data =
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
      client_
          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

  // With always_return_tuple=false the output is the bare first parameter.
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(param0_literal, actual_literal));
}
280 
// Tests that the compiler can correctly propagate the layout assigned by
// shape_representation_fn_ to return types.
TEST_F(XlaCompilerTest, HonorShapeRepresentationFnForRetVal) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
  // Adds an identity op around the resource to make sure identity ops propagate
  // resources correctly.
  auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
  auto write = ops::AssignAddVariableOp(scope, identity, a);
  // The control dependency ensures the read observes the updated variable.
  auto read = ops::ReadVariableOp(
      scope.WithControlDependencies(std::vector<Operation>{write}), var,
      DT_INT32);
  auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
  auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments: one int32 parameter and one
  // initialized int32 resource variable, both of shape [2, 3].
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2, 3});
  args[1].kind = XlaCompiler::Argument::kResource;
  args[1].resource_kind = XlaResource::kVariable;
  args[1].initialized = true;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2, 3});

  // Force a non-default minor-to-major layout of {0, 1} on every shape the
  // compiler produces via shape_representation_fn.
  auto options = DefaultOptions();
  options.shape_representation_fn =
      [](const TensorShape& shape, DataType dt) -> xla::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dt, shape, &xla_shape));
    *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    return xla_shape;
  };
  // Compiles the graph.
  XlaCompiler compiler(options);

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));
  xla::Shape transposed =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});
  // Check that both return shapes (the retval and the updated variable) carry
  // the layout chosen by shape_representation_fn.
  EXPECT_EQ(result.xla_output_shape,
            xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
}
331 
// The layout of resource variable shouldn't change after transpose
TEST_F(XlaCompilerTest, TransposeVariables) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
  // Adds an identity op around the resource to make sure identity ops propagate
  // resources correctly.
  auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
  auto write = ops::AssignAddVariableOp(scope, identity, a);
  auto read = ops::ReadVariableOp(
      scope.WithControlDependencies(std::vector<Operation>{write}), var,
      DT_INT32);
  // Transpose then reshape the read value before returning it.
  auto transposed_read = ops::Transpose(scope, read, {1, 0});
  auto reshape = ops::Reshape(scope, transposed_read, {2, 3});
  auto d = ops::_Retval(scope.WithOpName("D"), reshape, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments: one int32 parameter and one
  // initialized int32 resource variable, both of shape [2, 3].
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2, 3});
  args[1].kind = XlaCompiler::Argument::kResource;
  args[1].resource_kind = XlaResource::kVariable;
  args[1].initialized = true;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2, 3});
  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "transpose",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));
  xla::Shape transposed =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {1, 0});
  // Check that the return shapes are correctly transposed.
  EXPECT_EQ(result.xla_output_shape,
            xla::ShapeUtil::MakeTupleShape({transposed, transposed}));
}
373 
// Tests that the compiler doesn't reorder the parameters.
TEST_F(XlaCompilerTest, MixedOrderArguments) {
  // Run once with the resource argument second, once with it first.
  for (bool swap_order : {false, true}) {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto var =
        ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, swap_order ? 0 : 1);
    auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, swap_order ? 1 : 0);
    // Adds an identity op around the resource to make sure identity ops
    // propagate resources correctly.
    auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
    auto write = ops::AssignAddVariableOp(scope, identity, a);
    auto read = ops::ReadVariableOp(
        scope.WithControlDependencies(std::vector<Operation>{write}), var,
        DT_INT32);
    auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
    auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(scope.ToGraph(graph.get()));

    // Builds a description of the arguments.
    std::vector<XlaCompiler::Argument> args(2);
    args[0].kind = XlaCompiler::Argument::kParameter;
    args[0].type = DT_INT32;
    args[0].shape = TensorShape({2});
    args[1].kind = XlaCompiler::Argument::kResource;
    args[1].resource_kind = XlaResource::kVariable;
    args[1].initialized = true;
    args[1].type = DT_INT32;
    args[1].shape = TensorShape({2});

    if (swap_order) {
      // Even after swapping arguments, the compiler should maintain the new
      // ordering of parameters.
      std::swap(args[0], args[1]);
    }
    // Compiles the graph.
    XlaCompiler compiler(DefaultOptions());

    XlaCompiler::CompileOptions compile_options;
    compile_options.always_return_tuple = false;
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                       args, /*user_aliases=*/{}, &result));

    // In both iterations the arguments must map to parameters in order.
    EXPECT_THAT(result.input_mapping, ::testing::ElementsAre(0, 1));
  }
}
421 
TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) {
  // Builds a graph that reshapes a tensor, but with the target shape coming
  // from a runtime parameter rather than a compile-time constant.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
  auto c = ops::Reshape(scope.WithOpName("C"), a, b);
  auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});
  args[1].kind = XlaCompiler::Argument::kParameter;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2});

  // Compiles the graph; this must fail because the Reshape's shape operand
  // depends on a parameter.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  Status status = compiler.CompileGraph(XlaCompiler::CompileOptions(),
                                        "reshape", std::move(graph), args,
                                        /*user_aliases=*/{}, &result);
  EXPECT_FALSE(status.ok());
  // The error message should explain the problem and name the failing node.
  EXPECT_TRUE(
      absl::StrContains(status.error_message(), "depends on a parameter"))
      << status.error_message();
  EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node C}}"))
      << status.error_message();
  EXPECT_TRUE(absl::StrContains(status.error_message(),
                                "must be a compile-time constant"))
      << status.error_message();
}
459 
// Tests handling of compile-time constant outputs.
TEST_F(XlaCompilerTest, ConstantOutputs) {
  // Builds a graph with one compile-time constant output and one data-dependent
  // output, i.e.,
  // func(a) { b=7; c=-a; return b, c; }
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::Const<int32>(scope.WithOpName("B"), 7);
  auto c = ops::Neg(scope.WithOpName("C"), a);
  auto d = ops::_Retval(scope.WithOpName("D"), b, 0);
  auto e = ops::_Retval(scope.WithOpName("E"), c, 1);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});

  XlaCompiler::Options options = DefaultOptions();
  XlaCompiler compiler(options);
  {
    // Compiles the graph, with resolve_compile_time_constants enabled: the
    // constant output should be folded into result.outputs instead of being
    // materialized in the computation's return tuple.

    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
    CopyGraph(*graph, graph_copy.get());

    XlaCompiler::CompileOptions compile_options;
    compile_options.resolve_compile_time_constants = true;
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
                                       std::move(graph_copy), args,
                                       /*user_aliases=*/{}, &result));

    // Output 0 is the resolved constant 7; output 1 stays data-dependent.
    ASSERT_EQ(2, result.outputs.size());
    EXPECT_TRUE(result.outputs[0].is_constant);
    test::ExpectTensorEqual<int32>(result.outputs[0].constant_value,
                                   test::AsScalar(7));
    EXPECT_FALSE(result.outputs[1].is_constant);

    // Tests that the generated computation works: the return tuple contains
    // only the non-constant output, -a.
    xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
    std::unique_ptr<xla::GlobalData> param0_data =
        client_->TransferToServer(param0_literal).ConsumeValueOrDie();

    std::unique_ptr<xla::GlobalData> actual =
        client_->Execute(*result.computation, {param0_data.get()})
            .ConsumeValueOrDie();
    xla::Literal actual_literal =
        client_->Transfer(*actual).ConsumeValueOrDie();

    xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
    xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected0});
    EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
  }

  {
    // Compiles the graph, with resolve_compile_time_constants disabled: both
    // outputs are materialized in the return tuple.
    std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
    CopyGraph(*graph, graph_copy.get());

    XlaCompiler::CompileOptions compile_options;
    compile_options.resolve_compile_time_constants = false;
    XlaCompiler::CompilationResult result;
    TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
                                       std::move(graph_copy), args,
                                       /*user_aliases=*/{}, &result));

    // Neither output is reported as a constant in this mode.
    ASSERT_EQ(2, result.outputs.size());
    EXPECT_FALSE(result.outputs[0].is_constant);
    EXPECT_FALSE(result.outputs[1].is_constant);

    // Tests that the generated computation works: the tuple now holds both
    // the (runtime-computed) constant 7 and -a.
    xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
    std::unique_ptr<xla::GlobalData> param0_data =
        client_->TransferToServer(param0_literal).ConsumeValueOrDie();

    std::unique_ptr<xla::GlobalData> actual =
        client_->Execute(*result.computation, {param0_data.get()})
            .ConsumeValueOrDie();
    xla::Literal actual_literal =
        client_->Transfer(*actual).ConsumeValueOrDie();

    xla::Literal expected0 = xla::LiteralUtil::CreateR0<int32>(7);
    xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({-7, -42});
    xla::Literal expected =
        xla::LiteralUtil::MakeTuple({&expected0, &expected1});
    EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected, actual_literal));
  }
}
551 
// Tests that a compile-time constant flowing out of a non-inlined function
// call node is still resolved as a constant output.
TEST_F(XlaCompilerTest, ConstantOutputsOfFunctionalNode) {
  // Define a function with one compile-time constant output and one
  // data-dependent output.
  // @function.Defun(noinline=True)
  // foo(a) {b=7; return b, a; }
  const Tensor seven = test::AsScalar<int>(7);
  FunctionDef fdef = FunctionDefHelper::Create(
      "foo", {"a_0:int32"}, {"const:int32", "a:int32"}, {},
      {
          {{"Const"}, "Const", {}, {{"dtype", DT_INT32}, {"value", seven}}},
      },
      {{"a", "a_0"}, {"const", "Const:output:0"}});
  // Mark the function noinline so the call node survives into compilation.
  (*fdef.mutable_attr())["_noinline"].set_b(true);
  FunctionDefLibrary fdef_lib;
  *(fdef_lib.add_function()) = fdef;
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  {
    // Build the calling graph by hand: arg -> foo -> (retval_0, retval_1).
    Scope scope = Scope::NewRootScope().ExitOnError();
    TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(fdef_lib));
    auto arg = ops::_Arg(scope.WithOpName("input_arg"), DT_INT32, 0);
    NodeDef foo;
    foo.set_name("foo");
    foo.set_op("foo");
    *foo.add_input() = "input_arg";
    Status status;
    scope.graph()->AddNode(foo, &status);
    TF_ASSERT_OK(status);
    // retval_0 returns foo's first output (the constant).
    NodeDef retval_1;
    retval_1.set_name("retval_0");
    retval_1.set_op(FunctionLibraryDefinition::kRetOp);
    *retval_1.add_input() = "foo";
    (*retval_1.mutable_attr())["T"].set_type(DT_INT32);
    (*retval_1.mutable_attr())["index"].set_i(0);
    scope.graph()->AddNode(retval_1, &status);
    TF_ASSERT_OK(status);
    // retval_1 returns foo's second output (the pass-through argument).
    NodeDef retval_2;
    retval_2.set_name("retval_1");
    retval_2.set_op(FunctionLibraryDefinition::kRetOp);
    *retval_2.add_input() = "foo:1";
    (*retval_2.mutable_attr())["T"].set_type(DT_INT32);
    (*retval_2.mutable_attr())["index"].set_i(1);
    scope.graph()->AddNode(retval_2, &status);
    TF_ASSERT_OK(status);
    TF_ASSERT_OK(scope.ToGraph(graph.get()));
  }

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({1});

  // The compiler needs a function library that contains "foo".
  XlaCompiler::Options options = DefaultOptions();
  FunctionLibraryDefinition flib_def(OpRegistry::Global(), fdef_lib);
  options.flib_def = &flib_def;
  XlaCompiler compiler(options);

  XlaCompiler::CompileOptions compile_options;
  compile_options.resolve_compile_time_constants = true;
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "constants",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // The first output resolves to the constant 7 even though it flows through
  // the function call node; the second remains data-dependent.
  ASSERT_EQ(2, result.outputs.size());
  EXPECT_TRUE(result.outputs[0].is_constant);
  test::ExpectTensorEqual<int32>(result.outputs[0].constant_value,
                                 test::AsScalar(7));
  EXPECT_FALSE(result.outputs[1].is_constant);
}
622 
// Tests that populate_resource_manager is invoked during compilation and that
// op kernels can reach the populated resource from their Compile() method.
TEST_F(XlaCompilerTest, ResourceManager) {
  // Builds a graph that calls the dummy resource Op.
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = DummyReadResourceCC(scope.WithOpName("B"), a);
  auto c = ops::Add(scope.WithOpName("C"), b.output2_, b.output1_);
  auto d = ops::_Retval(scope.WithOpName("D"), c, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the argument.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});

  DummyResourceForTest* resource = new DummyResourceForTest();

  // Compiles the graph, registering the resource under "dummy" via the
  // populate_resource_manager hook. Create() consumes a reference, so Ref()
  // first to keep ownership of `resource` in the test.
  auto options = DefaultOptions();
  std::function<Status(ResourceMgr*)> populate_function =
      [resource](ResourceMgr* rm) {
        resource->Ref();
        return rm->Create(rm->default_container(), "dummy", resource);
      };
  options.populate_resource_manager = &populate_function;
  XlaCompiler compiler(options);

  EXPECT_EQ(0, resource->Get());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // DummyReadResourceOp::Compile should have incremented the counter once.
  EXPECT_EQ(1, resource->Get());

  resource->Unref();
}
663 
// Tests that compiling the same graph repeatedly produces identical HLO
// (modulo instruction names and ids), i.e. compilation is deterministic.
TEST_F(XlaCompilerTest, DeterministicCompilation) {
  // Builds a graph that contains a node with two output edges. The compiler
  // should always traverse them in the same order.
  const int64 test_count = 2;

  std::vector<XlaCompiler::CompilationResult> results(test_count);

  for (int64 i = 0; i < test_count; ++i) {
    Scope scope = Scope::NewRootScope().ExitOnError();
    auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
    auto b = ops::Neg(scope.WithOpName("B"), a);
    auto c = ops::Neg(scope.WithOpName("C"), a);
    auto d = ops::Add(scope.WithOpName("D"), b, c);
    auto e = ops::_Retval(scope.WithOpName("E"), d, 0);
    std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
    TF_ASSERT_OK(scope.ToGraph(graph.get()));

    // Builds a description of the argument.
    std::vector<XlaCompiler::Argument> args(1);
    args[0].kind = XlaCompiler::Argument::kParameter;
    args[0].type = DT_INT32;
    args[0].shape = TensorShape({2});

    // Compiles the graph.
    auto options = DefaultOptions();
    XlaCompiler compiler(options);

    TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy",
                                       std::move(graph), args,
                                       /*user_aliases=*/{}, &results[i]));
  }

  // Compare each compilation result's HLO proto against the previous one.
  for (int64 i = 1; i < test_count; ++i) {
    const auto& m1 = results[i - 1].computation->proto();
    const auto& m2 = results[i].computation->proto();
    ASSERT_EQ(m1.computations_size(), m2.computations_size());
    // Check if every hlo computation is the same.
    for (int k = 0; k < m1.computations_size(); k++) {
      const auto& c1 = m1.computations(k);
      const auto& c2 = m2.computations(k);
      ASSERT_EQ(c1.instructions_size(), c2.instructions_size());
      for (int j = 0; j < c1.instructions_size(); j++) {
        auto instr1 = c1.instructions(j);
        auto instr2 = c2.instructions(j);
        instr1.clear_name();
        instr1.clear_id();
        instr1.clear_operand_ids();
        instr2.clear_name();
        instr2.clear_id();
        instr2.clear_operand_ids();
        // The names of instructions were uniquified by the XlaBuilder and the
        // unique ids may be different, the rest of the fields should be
        // identical.
        string str1, str2;
        LOG(INFO) << "instr1 = " << instr1.DebugString();
        LOG(INFO) << "instr2 = " << instr2.DebugString();
        instr1.AppendPartialToString(&str1);
        instr2.AppendPartialToString(&str2);
        EXPECT_EQ(str1, str2);
      }
    }
  }
}
728 
// Tests a computation that receives a TensorArray resource as input and
// updates it.
TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto flow = ops::Const<float>(scope, {});
  // "grad2" already exists on the input (see args below); "grad1" is created
  // here and written to, so both gradients are accessed by the computation.
  auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
  auto grad2 = ops::TensorArrayGrad(scope, arg, grad1.flow_out, "grad2");
  auto index = ops::Const<int32>(scope, 1);
  auto write = ops::TensorArrayWrite(scope, grad1.grad_handle, index, index,
                                     grad2.flow_out);
  auto read = ops::TensorArrayRead(scope, arg, index, write.flow_out, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments: an initialized TensorArray of two
  // scalars that already carries a "grad2" gradient.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({});
  args[0].max_array_size = 2;
  args[0].tensor_array_gradients = {"grad2"};

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // The TensorArray is mutated, so it must appear as a resource update that
  // records both gradients touched by the computation.
  ASSERT_EQ(1, result.resource_updates.size());
  const XlaCompiler::ResourceUpdate& update = result.resource_updates[0];
  EXPECT_EQ(0, update.input_index);
  EXPECT_EQ(DT_INT32, update.type);
  EXPECT_EQ((std::set<string>{"grad1", "grad2"}),
            update.tensor_array_gradients_accessed);

  // Tests that the generated computation works.
  // Input is a tuple of (base values, grad2 values).
  xla::Literal input_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
  xla::Literal input_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  xla::Literal input = xla::LiteralUtil::MakeTuple({&input_base, &input_grad2});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(input).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
      client_->Execute(*result.computation, {param0_data.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

  // Expected: the read value at index 1, plus the updated resource as a tuple
  // of (base, grad1, grad2). grad1 starts zeroed and has `index` (1) written
  // at position 1; grad2 is passed through unchanged.
  xla::Literal output_read = xla::LiteralUtil::CreateR0<int32>(42);
  xla::Literal output_base = xla::LiteralUtil::CreateR1<int32>({7, 42});
  xla::Literal output_grad1 = xla::LiteralUtil::CreateR1<int32>({0, 1});
  xla::Literal output_grad2 = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  xla::Literal output_resource =
      xla::LiteralUtil::MakeTuple({&output_base, &output_grad1, &output_grad2});
  xla::Literal expected_literal =
      xla::LiteralUtil::MakeTuple({&output_read, &output_resource});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
792 
// Tests that a TensorArray gradient that is only read (never written) does
// not cause the TensorArray to appear in the computation's resource updates.
TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto flow = ops::Const<float>(scope, {});
  // "grad1" already exists on the input argument (see args below), so this
  // lookup does not create a new gradient.
  auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad1");
  auto index = ops::Const<int32>(scope, 1);
  auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({});
  args[0].max_array_size = 2;
  args[0].tensor_array_gradients = {"grad1"};

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // Nothing was modified, so no resource update should be emitted.
  EXPECT_EQ(0, result.resource_updates.size());
}
825 
// Tests that a TensorArray gradient created inside the computation (i.e. one
// not listed in the argument's tensor_array_gradients) is reflected back as a
// resource update.
TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto arg = ops::_Arg(scope.WithOpName("arg"), DT_RESOURCE, 0);
  auto flow = ops::Const<float>(scope, {});
  // The argument only carries "grad1" (see args below), so asking for "grad2"
  // creates a brand-new gradient inside the computation.
  auto grad1 = ops::TensorArrayGrad(scope, arg, flow, "grad2");
  auto index = ops::Const<int32>(scope, 1);
  auto read = ops::TensorArrayRead(scope, arg, index, grad1.flow_out, DT_INT32);
  auto retval = ops::_Retval(scope.WithOpName("retval"), read, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kTensorArray;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({});
  args[0].max_array_size = 2;
  args[0].tensor_array_gradients = {"grad1"};

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // Creating the new gradient mutates the resource, so an update is expected.
  EXPECT_EQ(1, result.resource_updates.size());
}
858 
859 // Tests CompileFunction with undefined function fails.
TEST_F(XlaCompilerTest,UndefinedFunctionFails)860 TEST_F(XlaCompilerTest, UndefinedFunctionFails) {
861   XlaCompiler compiler(DefaultOptions());
862 
863   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
864   XlaCompiler::CompilationResult result;
865   NameAttrList name_attr;
866   name_attr.set_name("Function_NotDefined_");
867   Status status =
868       compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
869                                /*args=*/{}, &result);
870   EXPECT_FALSE(status.ok());
871   EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
872       << status.error_message();
873 }
874 
// Returns a FunctionDef wrapping a single "Fill" node. "Fill" requires its
// dims input to be a compile-time constant, which makes this function a
// convenient fixture for testing constant propagation into function calls.
FunctionDef FillFn() {
  return FunctionDefHelper::Define(
      // Name
      "FillFn",
      // Args
      {"x: T", "dims: int32"},
      // Return values
      {"y: T"},
      // Attr def
      {"T: {float, double, int32, int64}"},
      // Nodes
      {{{"y"}, "Fill", {"dims", "x"}, {{"T", "$T"}}}});
}
888 
TEST_F(XlaCompilerTest, FunctionCallWithConstants) {
  // Certain operations in a function, "Fill" for example, requires the
  // operator's argument to be a compile-time constant instead of a parameter.
  // This testcase tests if XlaCompiler can handle such operators inside
  // function calls.
  XlaCompiler compiler(DefaultOptions());

  FunctionDefLibrary flib;
  *flib.add_function() = FillFn();

  TF_ASSERT_OK(flib_def_->AddFunctionDef(FillFn()));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Scope scope = Scope::NewRootScope().ExitOnError();
  auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
  auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
  TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib));

  // Manually builds and inserts the FillFn call node, since function-call
  // nodes have no generated C++ op wrapper.
  // NOTE(review): the second .Input uses src_index 1 although the "shape"
  // Const has a single output; the AddEdge below uses output 0 — confirm
  // whether the index here is intentional.
  NodeDef def;
  TF_ASSERT_OK(NodeDefBuilder("fill", "FillFn", flib_def_.get())
                   .Input(value.name(), 0, DT_INT32)
                   .Input(shape.name(), 1, DT_INT32)
                   .Finalize(&def));
  Status status;
  Node* fill = scope.graph()->AddNode(def, &status);
  TF_ASSERT_OK(status);
  TF_ASSERT_OK(scope.DoShapeInference(fill));
  scope.graph()->AddEdge(value.node(), 0, fill, 0);
  scope.graph()->AddEdge(shape.node(), 0, fill, 1);

  auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);

  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the argument.
  std::vector<XlaCompiler::Argument> args;

  // Compilation itself is the assertion: it must succeed even though Fill's
  // dims input comes through a function-call boundary.
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));
}
932 
// Tests CompileFunction with a local function lookup failing, fails with
// informative error about both lookups.
TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) {
  XlaCompiler compiler(DefaultOptions());

  // Registers XTimesTwo only in the compiler's local library; the call below
  // supplies no attrs, so instantiation of the local copy also fails.
  auto local_flib_def = LocalFlibDef(&compiler);
  TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo()));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  XlaCompiler::CompilationResult result;
  NameAttrList name_attr;
  name_attr.set_name("XTimesTwo");
  Status status =
      compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr,
                               /*args=*/{}, &result);

  ASSERT_FALSE(status.ok());
  // Flib lookup failure.
  EXPECT_TRUE(absl::StrContains(status.error_message(), "is not defined."))
      << status.error_message();
  // Local flib lookup failure.
  EXPECT_TRUE(absl::StrContains(status.error_message(), "Attr T is not found"))
      << status.error_message();
}
957 
RunAndCheckVariablesComputation(xla::Client * client,const XlaCompiler::CompilationResult & result)958 void RunAndCheckVariablesComputation(
959     xla::Client* client, const XlaCompiler::CompilationResult& result) {
960   xla::Literal param0_literal = xla::LiteralUtil::CreateR1<int32>({7, 42});
961   xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
962   std::unique_ptr<xla::GlobalData> param0_data =
963       client->TransferToServer(param0_literal).ConsumeValueOrDie();
964   std::unique_ptr<xla::GlobalData> param1_data =
965       client->TransferToServer(param1_literal).ConsumeValueOrDie();
966 
967   std::unique_ptr<xla::GlobalData> actual =
968       client
969           ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
970           .ConsumeValueOrDie();
971   xla::Literal actual_literal = client->Transfer(*actual).ConsumeValueOrDie();
972 
973   xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({5, 144});
974   xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({4, 143});
975   xla::Literal expected_literal =
976       xla::LiteralUtil::MakeTuple({&expected0, &expected1});
977   EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
978 }
979 
// Tests a simple graph that reads and writes a variable.
TEST_F(XlaCompilerTest, Variables) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
  // Adds an identity op around the resource to make sure identity ops propagate
  // resources correctly.
  auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
  auto write = ops::AssignAddVariableOp(scope, identity, a);
  // Control dependency orders the read after the write.
  auto read = ops::ReadVariableOp(
      scope.WithControlDependencies(std::vector<Operation>{write}), var,
      DT_INT32);
  auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
  auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments: a plain parameter and an
  // initialized variable resource.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});
  args[1].kind = XlaCompiler::Argument::kResource;
  args[1].resource_kind = XlaResource::kVariable;
  args[1].initialized = true;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2});

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));
  RunAndCheckVariablesComputation(client_, result);
}
1017 
// Tests that a shape_representation_fn returning a non-default layout is
// honored for a single (non-tuple) result.
TEST_F(XlaCompilerTest, ResultLayoutSingle) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Retval(scope.WithOpName("RET"), a, 0);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2, 3});

  auto options = DefaultOptions();
  // Sets the representation function to return a non-default layout.
  options.shape_representation_fn =
      [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
    *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    return xla_shape;
  };

  // Compiles the graph.
  XlaCompiler compiler(options);

  XlaCompiler::CompilationResult result;
  auto compile_options = XlaCompiler::CompileOptions();
  // Disable tupling so the output shape is the bare array, not a 1-tuple.
  compile_options.always_return_tuple = false;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "id", std::move(graph),
                                     args, /*user_aliases=*/{}, &result));
  // The output must carry the {0, 1} layout chosen above.
  EXPECT_TRUE(xla::ShapeUtil::Equal(
      result.xla_output_shape,
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1})));
}
1054 
// Tests that a shape_representation_fn returning a non-default layout is
// honored for every element of a tuple result.
TEST_F(XlaCompilerTest, ResultLayoutMultiple) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto b = ops::_Retval(scope.WithOpName("RET1"), a, 0);
  auto c = ops::_Retval(scope.WithOpName("RET2"), a, 1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2, 3});

  auto options = DefaultOptions();
  // Sets the representation function to return a non-default layout.
  options.shape_representation_fn =
      [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
    xla::Shape xla_shape;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(type, shape, &xla_shape));
    *xla_shape.mutable_layout() = xla::LayoutUtil::MakeLayout({0, 1});
    return xla_shape;
  };

  // Compiles the graph.
  XlaCompiler compiler(options);

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "id",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));
  xla::Shape result_shape =
      xla::ShapeUtil::MakeShapeWithLayout(xla::S32, {2, 3}, {0, 1});

  // Both tuple elements must carry the {0, 1} layout chosen above.
  EXPECT_TRUE(xla::ShapeUtil::Equal(
      result.xla_output_shape,
      xla::ShapeUtil::MakeTupleShape({result_shape, result_shape})));
}
1094 
// Tests a graph whose only retval is an unmodified resource handle: handles
// carry no data and the variable is never written, so the computation's
// output tuple is empty.
TEST_F(XlaCompilerTest, ReturnResourceHandleOnly) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 0);
  auto d = ops::_Retval(scope.WithOpName("D"), var, 0);
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(1);
  args[0].kind = XlaCompiler::Argument::kResource;
  args[0].resource_kind = XlaResource::kVariable;
  args[0].initialized = true;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));

  // Tests that the generated computation works.
  xla::Literal param1_literal = xla::LiteralUtil::CreateR1<int32>({-3, 101});
  std::unique_ptr<xla::GlobalData> param1_data =
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
      client_->Execute(*result.computation, {param1_data.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

  // Empty tuple: the handle itself produces no output data.
  xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
1132 
// Tests a graph that returns a resource handle alongside a data retval; the
// handle retval must not disturb the variable read/write results.
TEST_F(XlaCompilerTest, ReturnResourceHandle) {
  Scope scope = Scope::NewRootScope().ExitOnError();
  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
  auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
  // Adds an identity op around the resource to make sure identity ops propagate
  // resources correctly.
  auto identity = ops::Identity(scope.WithOpName("VIdentity"), var);
  auto write = ops::AssignAddVariableOp(scope, identity, a);
  // Control dependency orders the read after the write.
  auto read = ops::ReadVariableOp(
      scope.WithControlDependencies(std::vector<Operation>{write}), var,
      DT_INT32);
  auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
  // Retval 0 is the resource handle itself; retval 1 carries the data.
  auto r = ops::_Retval(scope.WithOpName("R"), var, 0);
  auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 1);

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2});
  args[1].kind = XlaCompiler::Argument::kResource;
  args[1].resource_kind = XlaResource::kVariable;
  args[1].initialized = true;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2});

  // Compiles the graph.
  XlaCompiler compiler(DefaultOptions());

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "add",
                                     std::move(graph), args,
                                     /*user_aliases=*/{}, &result));
  RunAndCheckVariablesComputation(client_, result);
}
1171 
BuildTestGraph()1172 xla::StatusOr<std::unique_ptr<Graph>> BuildTestGraph() {
1173   Scope scope = Scope::NewRootScope().ExitOnError();
1174   auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
1175   auto var = ops::_Arg(scope.WithOpName("V"), DT_RESOURCE, 1);
1176   auto write = ops::AssignAddVariableOp(scope, var, a);
1177   auto read = ops::ReadVariableOp(
1178       scope.WithControlDependencies(std::vector<Operation>{write}), var,
1179       DT_INT32);
1180   auto read_plus_one = ops::Add(scope, read, ops::Const<int32>(scope, 1));
1181   auto d = ops::_Retval(scope.WithOpName("D"), read_plus_one, 0);
1182   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1183   TF_RETURN_IF_ERROR(scope.ToGraph(graph.get()));
1184   return std::move(graph);
1185 }
1186 
// Tests a simple graph that reads and writes a variable, with a
// shape_representation_fn passed to the compiler that flattens all
// variable tensors to vectors.
TEST_F(XlaCompilerTest, VariableRepresentationShapeFunction) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2, 2});
  args[1].kind = XlaCompiler::Argument::kResource;
  args[1].resource_kind = XlaResource::kVariable;
  args[1].initialized = true;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2, 2});

  // Compiles the graph with a representation function that flattens every
  // tensor to a rank-1 shape of the same element count.
  XlaCompiler::Options options = DefaultOptions();
  options.shape_representation_fn =
      [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
    xla::PrimitiveType ptype;
    TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
    return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
  };
  XlaCompiler compiler(options);

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = false;  // Only reshape variables.

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, /*user_aliases=*/{}, &result));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
                          client_->GetComputationShape(*result.computation));

  // Parameter 0 (plain arg) keeps its 2x2 shape; parameter 1 (variable) is
  // flattened to length-4. Likewise for the two result elements.
  ASSERT_EQ(program_shape->parameters_size(), 2);
  EXPECT_TRUE(
      xla::ShapeUtil::Compatible(program_shape->parameters(0),
                                 xla::ShapeUtil::MakeShape(xla::S32, {2, 2})));
  EXPECT_TRUE(xla::ShapeUtil::Compatible(
      program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
  EXPECT_TRUE(xla::ShapeUtil::Compatible(
      program_shape->result(),
      xla::ShapeUtil::MakeTupleShape(
          {xla::ShapeUtil::MakeShape(xla::S32, {2, 2}),
           xla::ShapeUtil::MakeShape(xla::S32, {4})})));

  // Tests that the generated computation works.
  xla::Literal param0_literal =
      xla::LiteralUtil::CreateR2<int32>({{4, 55}, {1, -3}});
  xla::Literal param1_literal =
      xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> param1_data =
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
      client_
          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

  // Expected: (var + arg + 1) as the 2x2 retval, (var + arg) as the flattened
  // updated variable value.
  xla::Literal expected0 =
      xla::LiteralUtil::CreateR2<int32>({{27, 67}, {35, 402}});
  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
  xla::Literal expected_literal =
      xla::LiteralUtil::MakeTuple({&expected0, &expected1});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
1259 
// Same graph as VariableRepresentationShapeFunction, but compiled as an entry
// computation, so the shape_representation_fn flattens args and retvals too.
TEST_F(XlaCompilerTest, ArgRetvalShapeRepresentationFunction) {
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Graph> graph, BuildTestGraph());

  // Builds a description of the arguments.
  std::vector<XlaCompiler::Argument> args(2);
  args[0].kind = XlaCompiler::Argument::kParameter;
  args[0].type = DT_INT32;
  args[0].shape = TensorShape({2, 2});
  args[1].kind = XlaCompiler::Argument::kResource;
  args[1].resource_kind = XlaResource::kVariable;
  args[1].initialized = true;
  args[1].type = DT_INT32;
  args[1].shape = TensorShape({2, 2});

  // Compiles the graph with a representation function that flattens every
  // tensor to a rank-1 shape of the same element count.
  XlaCompiler::Options options = DefaultOptions();
  options.shape_representation_fn =
      [](const TensorShape& shape, DataType type) -> xla::StatusOr<xla::Shape> {
    xla::PrimitiveType ptype;
    TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(type, &ptype));
    return xla::ShapeUtil::MakeShape(ptype, {shape.num_elements()});
  };
  XlaCompiler compiler(options);

  XlaCompiler::CompileOptions compile_options;
  compile_options.is_entry_computation = true;  // Reshape args and retvals.

  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(compile_options, "add", std::move(graph),
                                     args, /*user_aliases=*/{}, &result));

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::ProgramShape> program_shape,
                          client_->GetComputationShape(*result.computation));

  // Unlike the non-entry case, both parameters and both result elements are
  // flattened to length-4 vectors.
  ASSERT_EQ(program_shape->parameters_size(), 2);
  EXPECT_TRUE(xla::ShapeUtil::Compatible(
      program_shape->parameters(0), xla::ShapeUtil::MakeShape(xla::S32, {4})));
  EXPECT_TRUE(xla::ShapeUtil::Compatible(
      program_shape->parameters(1), xla::ShapeUtil::MakeShape(xla::S32, {4})));
  EXPECT_TRUE(xla::ShapeUtil::Compatible(
      program_shape->result(),
      xla::ShapeUtil::MakeTupleShape(
          {xla::ShapeUtil::MakeShape(xla::S32, {4}),
           xla::ShapeUtil::MakeShape(xla::S32, {4})})));

  // Tests that the generated computation works.
  xla::Literal param0_literal =
      xla::LiteralUtil::CreateR1<int32>({4, 55, 1, -3});
  xla::Literal param1_literal =
      xla::LiteralUtil::CreateR1<int32>({22, 11, 33, 404});
  std::unique_ptr<xla::GlobalData> param0_data =
      client_->TransferToServer(param0_literal).ConsumeValueOrDie();
  std::unique_ptr<xla::GlobalData> param1_data =
      client_->TransferToServer(param1_literal).ConsumeValueOrDie();

  std::unique_ptr<xla::GlobalData> actual =
      client_
          ->Execute(*result.computation, {param0_data.get(), param1_data.get()})
          .ConsumeValueOrDie();
  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();

  // Expected: (var + arg + 1) and (var + arg), both flattened.
  xla::Literal expected0 = xla::LiteralUtil::CreateR1<int32>({27, 67, 35, 402});
  xla::Literal expected1 = xla::LiteralUtil::CreateR1<int32>({26, 66, 34, 401});
  xla::Literal expected_literal =
      xla::LiteralUtil::MakeTuple({&expected0, &expected1});
  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
}
1327 
// Tests a graph which has a function with an invalid op.
TEST_F(XlaCompilerTest, FunctionWithInvalidOp) {
  XlaCompiler compiler(DefaultOptions());

  // Extends FillFn with two problematic nodes so compilation must fail.
  FunctionDefLibrary flib;
  FunctionDef fn = FillFn();
  NodeDef* node = fn.add_node_def();
  node->set_name("Invalid");
  node->set_op("InvalidOp"); /* unsupported op */
  node = fn.add_node_def();
  node->set_name("Switch");
  node->set_op("Switch"); /* control flow node */
  *flib.add_function() = fn;

  TF_ASSERT_OK(flib_def_->AddFunctionDef(fn));

  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));

  Scope scope = Scope::NewRootScope().ExitOnError();
  auto value = ops::Const<int32>(scope.WithOpName("value"), 1, {});
  auto shape = ops::Const<int32>(scope.WithOpName("shape"), {5}, {1});
  TF_ASSERT_OK(scope.graph()->AddFunctionLibrary(flib));

  // Manually inserts the function-call node (no generated op wrapper exists).
  NodeDef def;
  TF_ASSERT_OK(NodeDefBuilder("fill_fn", "FillFn", flib_def_.get())
                   .Input(value.name(), 0, DT_INT32)
                   .Input(shape.name(), 1, DT_INT32)
                   .Finalize(&def));
  Status status;
  Node* fill = scope.graph()->AddNode(def, &status);
  TF_ASSERT_OK(status);
  TF_ASSERT_OK(scope.DoShapeInference(fill));
  scope.graph()->AddEdge(value.node(), 0, fill, 0);
  scope.graph()->AddEdge(shape.node(), 0, fill, 1);

  auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0);

  TF_ASSERT_OK(scope.ToGraph(graph.get()));

  std::vector<XlaCompiler::Argument> args;
  XlaCompiler::CompilationResult result;
  status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill",
                                 std::move(graph), args, /*user_aliases=*/{},
                                 &result);
  ASSERT_FALSE(status.ok());
  // The error should name the bad op and point at the calling node.
  EXPECT_TRUE(absl::StrContains(status.error_message(), "InvalidOp"))
      << status.error_message();
  EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node fill_fn}}"))
      << status.error_message();
}
1378 
1379 // Tests a graph which has a node with invalid data type.
TEST_F(XlaCompilerTest,NodeWithInvalidDataType)1380 TEST_F(XlaCompilerTest, NodeWithInvalidDataType) {
1381   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1382   NodeDef shape;
1383   shape.set_name("Shape");
1384   shape.set_op("Shape");
1385   (*shape.mutable_attr())["T"].set_type(DT_INT32);
1386   (*shape.mutable_attr())["out_type"].set_type(DT_BOOL); /* invalid type */
1387   Status status;
1388   Node* shape_node = graph->AddNode(shape, &status);
1389   TF_ASSERT_OK(status);
1390   graph->AddControlEdge(graph->source_node(), shape_node);
1391 
1392   std::vector<XlaCompiler::Argument> args;
1393   XlaCompiler::CompilationResult result;
1394   XlaCompiler compiler(DefaultOptions());
1395   status = compiler.CompileGraph(XlaCompiler::CompileOptions(), "invalid_type",
1396                                  std::move(graph), args, /*user_aliases=*/{},
1397                                  &result);
1398   ASSERT_FALSE(status.ok());
1399   EXPECT_TRUE(absl::StrContains(status.error_message(),
1400                                 "is not in the list of allowed values"))
1401       << status.error_message();
1402   EXPECT_TRUE(absl::StrContains(status.error_message(), "{{node Shape}}"))
1403       << status.error_message();
1404 }
1405 
// Verifies that a graph containing a lone NoOp — with no inputs and no
// control edges tying it to the source/sink nodes — still compiles cleanly.
TEST_F(XlaCompilerTest, SingleOpWithoutInputs) {
  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
  NodeDef no_op_def;
  no_op_def.set_name("NoOp");
  no_op_def.set_op("NoOp");
  Status add_status;
  graph->AddNode(no_op_def, &add_status);
  TF_ASSERT_OK(add_status);

  XlaCompiler compiler(DefaultOptions());
  std::vector<XlaCompiler::Argument> args;

  // Compile a copy of the graph. Note there is deliberately no control edge
  // linking the NoOp with source/sink.
  std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
  CopyGraph(*graph, graph_copy.get());
  XlaCompiler::CompilationResult result;
  TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "NoOp",
                                     std::move(graph_copy), args,
                                     /*user_aliases=*/{}, &result));
}
1427 
1428 class DummySideEffectingOp : public XlaOpKernel {
1429  public:
DummySideEffectingOp(OpKernelConstruction * ctx)1430   explicit DummySideEffectingOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)1431   void Compile(XlaOpKernelContext* ctx) override {
1432     OP_REQUIRES_OK(ctx, ctx->compiler()->SetNodeToken(
1433                             name(), xla::CreateToken(ctx->builder())));
1434   }
1435 };
1436 
// Register the dummy op (no inputs, outputs, or attrs) with the TensorFlow op
// registry, and register its XLA compile-time kernel defined above.
REGISTER_OP("DummySideEffectingOp");

REGISTER_XLA_OP(Name("DummySideEffectingOp"), DummySideEffectingOp);
1440 
TEST_F(XlaCompilerTest,TokenInputAndOutput)1441 TEST_F(XlaCompilerTest, TokenInputAndOutput) {
1442   std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
1443   NodeDef side_effecting_op;
1444   side_effecting_op.set_name("DummySideEffectingOp");
1445   side_effecting_op.set_op("DummySideEffectingOp");
1446   AddNodeAttr(kXlaTokenInputNodesAttrName,
1447               std::vector<string>{kXlaTokenArgNodeName}, &side_effecting_op);
1448   Status status;
1449   graph->AddNode(side_effecting_op, &status);
1450   TF_ASSERT_OK(status);
1451   EXPECT_TRUE(FixupSourceAndSinkEdges(graph.get()));
1452 
1453   std::vector<XlaCompiler::Argument> args(1);
1454   args[0].kind = XlaCompiler::Argument::kResource;
1455   args[0].resource_kind = XlaResource::kVariable;
1456   args[0].initialized = true;
1457   args[0].type = DT_INT32;
1458   args[0].shape = TensorShape({2, 2});
1459 
1460   {
1461     // The case for entry computation: we don't add token input/output. Instead,
1462     // we use CreateToken HLO to create the entry token.
1463     XlaCompiler::CompileOptions options;
1464     options.is_entry_computation = true;
1465     options.add_token_input_output = false;
1466     options.return_updated_values_for_all_resources = true;
1467     XlaCompiler compiler(DefaultOptions());
1468 
1469     std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
1470     CopyGraph(*graph, graph_copy.get());
1471     XlaCompiler::CompilationResult result;
1472     TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
1473                                        args, /*user_aliases=*/{}, &result));
1474     EXPECT_EQ(result.xla_input_shapes.size(), 1);
1475     EXPECT_TRUE(result.xla_output_shape.IsTuple());
1476     EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 1);
1477   }
1478   {
1479     // The case for non-entry computation (e.g. while loop body). We add token
1480     // input/output.
1481     XlaCompiler::CompileOptions options;
1482     options.is_entry_computation = false;
1483     options.add_token_input_output = true;
1484     options.return_updated_values_for_all_resources = true;
1485     XlaCompiler compiler(DefaultOptions());
1486 
1487     std::unique_ptr<Graph> graph_copy(new Graph(OpRegistry::Global()));
1488     CopyGraph(*graph, graph_copy.get());
1489     XlaCompiler::CompilationResult result;
1490     TF_ASSERT_OK(compiler.CompileGraph(options, "NoOp", std::move(graph_copy),
1491                                        args, /*user_aliases=*/{}, &result));
1492     EXPECT_EQ(result.xla_input_shapes.size(), 2);
1493     EXPECT_TRUE(result.xla_input_shapes[1].IsToken());
1494     EXPECT_TRUE(result.xla_output_shape.IsTuple());
1495     EXPECT_EQ(xla::ShapeUtil::TupleElementCount(result.xla_output_shape), 2);
1496     EXPECT_TRUE(xla::ShapeUtil::GetTupleElementShape(result.xla_output_shape, 1)
1497                     .IsToken());
1498   }
1499 }
1500 
1501 }  // namespace
1502 }  // namespace tensorflow
1503