1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/cc/client/client_session.h"
22 #include "tensorflow/cc/framework/ops.h"
23 #include "tensorflow/cc/framework/scope.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/compiler/tf2xla/literal_util.h"
26 #include "tensorflow/compiler/tf2xla/shape_util.h"
27 #include "tensorflow/compiler/xla/client/client_library.h"
28 #include "tensorflow/compiler/xla/client/local_client.h"
29 #include "tensorflow/compiler/xla/client/xla_builder.h"
30 #include "tensorflow/compiler/xla/client/xla_computation.h"
31 #include "tensorflow/compiler/xla/literal.h"
32 #include "tensorflow/compiler/xla/literal_util.h"
33 #include "tensorflow/compiler/xla/service/platform_util.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/xla_data.pb.h"
36 #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
37 #include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
38 #include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
39 #include "tensorflow/compiler/xrt/xrt.pb.h"
40 #include "tensorflow/core/framework/tensor.h"
41 #include "tensorflow/core/lib/core/status.h"
42 #include "tensorflow/core/lib/core/status_test_util.h"
43 #include "tensorflow/core/lib/gtl/array_slice.h"
44 #include "tensorflow/core/platform/types.h"
45 #include "tensorflow/core/util/command_line_flags.h"
46 
47 namespace tensorflow {
48 namespace {
49 
// Name of the XLA device type the tests run on; DeviceFromFlag() turns it into
// a "/device:<name>:0" device string.
string* xla_test_device_ptr;  // initial value set in main()
// Name of the XLA platform passed to xla::PlatformUtil::GetPlatform() by
// XlaCompiledProgramShape().
string* xla_platform_ptr;     // initial value set in main()
52 
DeviceFromFlag()53 string DeviceFromFlag() {
54   string xla_test_device = *xla_test_device_ptr;
55   return absl::StrCat("/device:", xla_test_device, ":0");
56 }
57 
GetAttrLayout(absl::Span<const int64> minor_to_mayor)58 std::vector<int> GetAttrLayout(absl::Span<const int64> minor_to_mayor) {
59   std::vector<int> layout;
60   for (auto dim : minor_to_mayor) {
61     layout.push_back(static_cast<int>(dim));
62   }
63   return layout;
64 }
65 
TwoElementTuple()66 xla::LiteralProto TwoElementTuple() {
67   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
68   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
69   auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
70   return tuple.ToProto();
71 }
72 
ScalarLiteral()73 xla::LiteralProto ScalarLiteral() {
74   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
75   return scalar.ToProto();
76 }
77 
NestedTuple()78 xla::LiteralProto NestedTuple() {
79   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
80   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
81   auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
82   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
83   auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar});
84   return nested.ToProto();
85 }
86 
MakeTuple0()87 xla::LiteralProto MakeTuple0() {
88   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
89   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
90   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
91   auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
92   auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple});
93   auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0});
94   return nested1.ToProto();
95 }
96 
FloatVector(absl::Span<const float> v)97 xla::LiteralProto FloatVector(absl::Span<const float> v) {
98   auto array = xla::LiteralUtil::CreateR1<float>(v);
99   return array.ToProto();
100 }
101 
// Returns an F32 rank-2 literal built from the row lists in `v`, stored with
// the given `layout`, as a proto.
xla::LiteralProto FloatMatrix(
    std::initializer_list<std::initializer_list<float>> v,
    const xla::Layout& layout) {
  return xla::LiteralUtil::CreateR2WithLayout<float>(v, layout).ToProto();
}
108 
ReadOutputLiteral(const std::vector<Tensor> & outputs,size_t idx)109 xla::Literal ReadOutputLiteral(const std::vector<Tensor>& outputs, size_t idx) {
110   xla::LiteralProto response;
111   CHECK(response.ParseFromString(outputs[idx].scalar<string>()()));
112   return xla::Literal::CreateFromProto(response).ValueOrDie();
113 }
114 
CompareLiteralProtos(const xla::LiteralProto & a,const xla::LiteralProto & b)115 bool CompareLiteralProtos(const xla::LiteralProto& a,
116                           const xla::LiteralProto& b) {
117   auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
118   auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
119   bool equal = l_a == l_b;
120   if (!equal) {
121     LOG(INFO) << "LiteralProtos don't match:\n"
122               << a.DebugString() << "\n!=\n"
123               << b.DebugString();
124   }
125   return equal;
126 }
127 
CompareLiteralToLiteralProto(const xla::Literal & a,const xla::LiteralProto & b)128 bool CompareLiteralToLiteralProto(const xla::Literal& a,
129                                   const xla::LiteralProto& b) {
130   auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
131   bool equal = a == l_b;
132   if (!equal) {
133     LOG(INFO) << "Literal and LiteralProto don't match:\n"
134               << a.ToProto().DebugString() << "\n!=\n"
135               << b.DebugString();
136   }
137   return equal;
138 }
139 
CompareLiterals(const xla::Literal & a,const xla::Literal & b)140 bool CompareLiterals(const xla::Literal& a, const xla::Literal& b) {
141   bool equal = a == b;
142   if (!equal) {
143     LOG(INFO) << "Literals don't match:\n"
144               << a.ToProto().DebugString() << "\n!=\n"
145               << b.ToProto().DebugString();
146   }
147   return equal;
148 }
149 
OnePlusTwo()150 xla::XlaComputation OnePlusTwo() {
151   xla::XlaBuilder builder("OnePlusTwo");
152   auto c0 = xla::ConstantR0(&builder, 1.0f);
153   auto c1 = xla::ConstantR0(&builder, 2.0f);
154   xla::Add(c0, c1);
155   return builder.Build().ValueOrDie();
156 }
157 
AddAndScale()158 xla::XlaComputation AddAndScale() {
159   xla::XlaBuilder builder("AddAndScale");
160   auto p0 = xla::Parameter(&builder, 0,
161                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
162   auto p1 = xla::Parameter(&builder, 1,
163                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
164   auto sum = xla::Add(p0, p1);
165   auto c = xla::ConstantR0<float>(&builder, 3.0f);
166   xla::Mul(sum, c);
167   return builder.Build().ValueOrDie();
168 }
169 
Dot()170 xla::XlaComputation Dot() {
171   xla::XlaBuilder builder("Dot");
172   auto p0 = xla::Parameter(
173       &builder, 0,
174       xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}), "P0");
175   auto p1 = xla::Parameter(
176       &builder, 1,
177       xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}), "P1");
178   xla::DotDimensionNumbers ddn;
179   ddn.add_lhs_contracting_dimensions(1);
180   ddn.add_rhs_contracting_dimensions(0);
181   xla::DotGeneral(p0, p1, ddn);
182   return builder.Build().ValueOrDie();
183 }
184 
AddS64()185 xla::XlaComputation AddS64() {
186   xla::XlaBuilder builder("AddS64");
187   auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::S64, {}),
188                            "P0");
189   auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::S64, {}),
190                            "P1");
191   xla::Add(p0, p1);
192   return builder.Build().ValueOrDie();
193 }
194 
AddAndTuple()195 xla::XlaComputation AddAndTuple() {
196   xla::XlaBuilder builder("AddAndTuple");
197   auto p0 = xla::Parameter(&builder, 0,
198                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
199   auto p1 = xla::Parameter(&builder, 1,
200                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
201   auto sum = xla::Add(p0, p1);
202   xla::Tuple(&builder, {sum});
203   return builder.Build().ValueOrDie();
204 }
205 
AddAndSubTuple()206 xla::XlaComputation AddAndSubTuple() {
207   xla::XlaBuilder builder("AddAndSubTuple");
208   auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}),
209                            "P0");
210   auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}),
211                            "P1");
212   auto sum = xla::Add(p0, p1);
213   auto sub = xla::Sub(p0, p1);
214   xla::Tuple(&builder, {sum, sub});
215   return builder.Build().ValueOrDie();
216 }
217 
StoreComputationSnapshot(const xla::XlaComputation & computation,xla::HloSnapshot * dst)218 void StoreComputationSnapshot(const xla::XlaComputation& computation,
219                               xla::HloSnapshot* dst) {
220   auto snapshot = computation.Snapshot().ValueOrDie();
221   *dst = *snapshot;
222 }
223 
// Compiles `computation` with the local XLA client for the platform named by
// *xla_platform_ptr. The argument shapes and result layout are taken from
// `input_program_shape`. Returns the program shape of the compiled module's
// entry computation. Crashes (ValueOrDie) on any lookup or compile failure.
xla::ProgramShape XlaCompiledProgramShape(
    const xla::XlaComputation& computation,
    const xla::ProgramShape& input_program_shape) {
  se::Platform* target_platform =
      xla::PlatformUtil::GetPlatform(*xla_platform_ptr).ValueOrDie();
  xla::LocalClient* local_client =
      xla::ClientLibrary::GetOrCreateLocalClient(target_platform).ValueOrDie();

  std::vector<const xla::Shape*> argument_layouts;
  const int64 num_params = input_program_shape.parameters_size();
  for (int64 param = 0; param < num_params; ++param) {
    argument_layouts.push_back(&input_program_shape.parameters(param));
  }

  xla::ExecutableBuildOptions build_options;
  build_options.set_result_layout(input_program_shape.result());

  auto compiled =
      local_client->Compile(computation, argument_layouts, build_options)
          .ValueOrDie();
  return compiled->executable()
      ->module()
      .entry_computation()
      ->ComputeProgramShape();
}
245 
// Allocates an XRT handle directly from a host Tensor, passing the literal's
// own minor-to-major layout, and checks that reading it back (and releasing
// it) yields the original literal.
TEST(RawApiTest, AllocFromTensor) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout =
      GetAttrLayout(literal.shape().layout().minor_to_major());
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: a failed Run must not fall through to outputs[0].
  TF_ASSERT_OK(session.Run({read_back}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  ASSERT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
271 
// Allocates a two-element tuple from two host Tensors and checks that reading
// it back yields the corresponding tuple literal.
TEST(RawApiTest, AllocFromTensorTuple) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal1 =
      xla::LiteralUtil::CreateR2<float>({{14.0f, -5.0f}, {16.0f, 17.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));
  Tensor tensor1;
  TF_ASSERT_OK(LiteralToHostTensor(literal1, DT_FLOAT, &tensor1));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0, tensor1},
                                           {tensor0.shape(), tensor1.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: a failed Run must not fall through to outputs[0].
  TF_ASSERT_OK(session.Run({read_back}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  ASSERT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
302 
// Allocates from a single host Tensor with MakeTuple(true) and checks that the
// result reads back as a one-element tuple literal.
TEST(RawApiTest, AllocFromTensorTupleSingle) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout).MakeTuple(true);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0}, {tensor0.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: a failed Run must not fall through to outputs[0].
  TF_ASSERT_OK(session.Run({read_back}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  ASSERT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
328 
// Allocates row-major tensor data while declaring the {0, 1} (inverse) layout
// attribute, and checks that the value read back is the transposed matrix.
TEST(RawApiTest, AllocFromTensorRelayout) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Use inverse array layout with the tensor data above.
  std::vector<int> layout({0, 1});
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: a failed Run must not fall through to outputs[0].
  TF_ASSERT_OK(session.Run({read_back}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  ASSERT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  // We have sent literal's data (in array layout) with an attribute layout
  // {0,1}, so the expected literal read from device needs to be changed
  // accordingly.
  xla::Literal expected_literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 6.0f}, {5.0f, 7.0f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected_literal, response));
}
359 
// Allocates a literal, overwrites it in place with XRTWriteLiteral, and checks
// that the write returns the same handle and that the new contents read back.
// Finally releases the handle.
TEST(RawApiTest, AllocAndRewrite) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  TF_ASSERT_OK(root.status());

  tensorflow::ClientSession session(root);
  std::vector<tensorflow::Tensor> outputs;
  // Fatal assertions throughout: each Run's outputs are indexed right after.
  TF_ASSERT_OK(session.Run({read_back, handle}, &outputs));
  ASSERT_EQ(outputs.size(), 2);

  int64 allocation_handle = outputs[1].scalar<int64>()();
  xla::LiteralProto response;
  ASSERT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
  outputs.clear();

  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto();
  auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                              new_literal.SerializeAsString());
  auto write_op =
      ops::XRTWriteLiteral(root, Input(allocation_handle), new_value);
  TF_ASSERT_OK(root.status());
  TF_ASSERT_OK(session.Run({write_op}, &outputs));
  ASSERT_EQ(outputs.size(), 1);
  EXPECT_EQ(allocation_handle, outputs[0].scalar<int64>()());
  outputs.clear();

  auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle));
  TF_ASSERT_OK(session.Run({read_after_write}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto new_response;
  ASSERT_TRUE(new_response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response));

  Tensor release_tensor(DT_INT64, TensorShape({1}));
  release_tensor.flat<int64>()(0) = allocation_handle;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release},
                           &outputs));
}
410 
// Allocates two literals and releases both handles with a single
// XRTReleaseAllocationHandle call on a 2-element handle vector.
TEST(RawApiTest, AllocReleaseMany) {
  xrt::XLAAllocation alloc1;
  *alloc1.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();
  xrt::XLAAllocation alloc2;
  *alloc2.mutable_value() =
      xla::LiteralUtil::CreateR2({{6, 7}, {4, 5}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value1 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc1.SerializeAsString());
  auto value2 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc2.SerializeAsString());
  auto handle1 = ops::XRTAllocate(root, value1);
  auto handle2 = ops::XRTAllocate(root, value2);
  TF_ASSERT_OK(root.status());

  tensorflow::ClientSession session(root);
  std::vector<tensorflow::Tensor> outputs;
  // Fatal assertions: a failed Run must not fall through to outputs[0..1].
  TF_ASSERT_OK(session.Run({handle1, handle2}, &outputs));
  ASSERT_EQ(outputs.size(), 2);

  int64 allocation_handle1 = outputs[0].scalar<int64>()();
  int64 allocation_handle2 = outputs[1].scalar<int64>()();

  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64>()(0) = allocation_handle1;
  release_tensor.flat<int64>()(1) = allocation_handle2;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  outputs.clear();
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release},
                           &outputs));
}
445 
// Compiles two computations (AddAndScale and AddAndTuple) and releases both
// compilation handles with a single XRTReleaseCompilationHandle call.
TEST(RawApiTest, CompileAndReleaseMany) {
  xrt::XLAComputation c1;
  auto config1 = c1.mutable_config();
  auto shapes1 = config1->mutable_program_shape();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c1.mutable_hlo_snapshot());

  xrt::XLAComputation c2;
  auto config2 = c2.mutable_config();
  auto shapes2 = config2->mutable_program_shape();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes2->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot());

  // Nothing is executed here, so no XRTExecutionConfig is needed.
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto computation1 =
      ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString());
  auto c_handle1 = ops::XRTCompile(root, computation1);
  auto computation2 =
      ops::Const(root.WithDevice("/device:CPU:0"), c2.SerializeAsString());
  auto c_handle2 = ops::XRTCompile(root, computation2);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: a failed Run must not fall through to outputs[0..1].
  TF_ASSERT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs));
  ASSERT_EQ(outputs.size(), 2);

  int64 compilation_handle1 = outputs[0].scalar<int64>()();
  int64 compilation_handle2 = outputs[1].scalar<int64>()();

  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64>()(0) = compilation_handle1;
  release_tensor.flat<int64>()(1) = compilation_handle2;

  auto release = ops::XRTReleaseCompilationHandle(root, release_tensor);
  outputs.clear();
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {}, {release},
                           &outputs));
}
502 
// Allocates a literal, clears every allocation with XRTReleaseAllAllocations,
// and checks that reading the old handle afterwards fails with NOT_FOUND.
TEST(RawApiTest, AllocAndClearAll) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  TF_ASSERT_OK(root.status());

  tensorflow::ClientSession session(root);
  std::vector<tensorflow::Tensor> outputs;
  // Fatal assertions: a failed Run must not fall through to outputs[0].
  TF_ASSERT_OK(session.Run({handle}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  int64 allocation_handle = outputs[0].scalar<int64>()();

  auto clear_all = ops::XRTReleaseAllAllocations(root);

  outputs.clear();
  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {},
                           {clear_all}, &outputs));
  EXPECT_EQ(outputs.size(), 0);

  auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle));
  EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(),
            tensorflow::error::Code::NOT_FOUND);
}
532 
// Allocates a tuple literal, reads it back with XRTReadLiteral, and releases
// the handle explicitly in the same Run.
TEST(RawApiTest, ReadAndWriteState) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  // The control dependency orders the release after the read.
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), handle);
  TF_ASSERT_OK(root.status());

  tensorflow::ClientSession session(root);
  std::vector<tensorflow::Tensor> outputs;
  // Fatal assertions: a failed Run must not fall through to outputs[0].
  TF_ASSERT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back},
                           {release}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  ASSERT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
}
556 
// Allocates a tuple literal and reads it back with XRTReadLiteralAndRelease
// (no separate release op), checking the value round-trips.
TEST(RawApiTest, ReadAndWriteStateAutoFree) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: a failed Run must not fall through to outputs[0].
  TF_ASSERT_OK(session.Run({read_back}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  ASSERT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
}
576 
// Extracts sub-tuples {0}, {1} and {0, 0} of a nested tuple allocation with
// XRTSubTuple / XRTSubTupleAndRelease, and checks each against the matching
// element of the original literal.
TEST(RawApiTest, SubBuffer) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = NestedTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto base_handle = ops::XRTAllocate(root, value);
  auto index_0 = ops::Const(root.WithDevice("/device:CPU:0"), {0});
  auto index_1 = ops::Const(root.WithDevice("/device:CPU:0"), {1});
  auto index_00 = ops::Const(root.WithDevice("/device:CPU:0"), {0, 0});
  auto sub_0 = ops::XRTSubTuple(root, base_handle, index_0);
  auto sub_1 = ops::XRTSubTuple(root, base_handle, index_1);
  // Ordered after sub_0 and sub_1: the "AndRelease" variant also releases
  // base_handle, which those ops still read.
  auto sub_00 = ops::XRTSubTupleAndRelease(
      root.WithControlDependencies(
          {sub_0.output_handle.op(), sub_1.output_handle.op()}),
      base_handle, index_00);
  auto value_0 = ops::XRTReadLiteralAndRelease(root, sub_0);
  auto value_1 = ops::XRTReadLiteralAndRelease(root, sub_1);
  auto value_00 = ops::XRTReadLiteralAndRelease(root, sub_00);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: a failed Run must not fall through to outputs[0..2].
  TF_ASSERT_OK(session.Run({value_0, value_1, value_00}, &outputs));
  ASSERT_EQ(outputs.size(), 3);

  auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
  auto base_elements = base_literal.DecomposeTuple();
  auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
  xla::LiteralProto response_0;
  ASSERT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0));
  xla::LiteralProto response_1;
  ASSERT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1));
  xla::LiteralProto response_00;
  ASSERT_TRUE(response_00.ParseFromString(outputs[2].scalar<string>()()));
  EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00));
}
616 
TEST(RawApiTest,MakeTuple)617 TEST(RawApiTest, MakeTuple) {
618   xrt::XLAAllocation alloc_0;
619   *alloc_0.mutable_value() = TwoElementTuple();
620   xrt::XLAAllocation alloc_1;
621   *alloc_1.mutable_value() = ScalarLiteral();
622 
623   // The trivial tuple that just forwards its input and releases it.
624   xrt::XLATupleNode desc_0;
625   desc_0.set_input_index(0);
626   desc_0.set_release_input_handle(true);
627 
628   xrt::XLATupleNode desc_1;
629   auto subdesc_10 = desc_1.add_tuples();
630   auto subdesc_11 = desc_1.add_tuples();
631   subdesc_10->set_input_index(0);
632   auto subdesc_110 = subdesc_11->add_tuples();
633   subdesc_110->set_input_index(0);
634   auto subdesc_111 = subdesc_11->add_tuples();
635   subdesc_111->set_input_index(1);
636 
637   xrt::XLATupleNode desc_2;
638   auto subdesc_20 = desc_2.add_tuples();
639   auto subdesc_21 = desc_2.add_tuples();
640   subdesc_20->set_input_index(1);
641   subdesc_20->set_release_input_handle(true);
642   subdesc_21->set_input_index(0);
643   subdesc_21->set_release_input_handle(true);
644 
645   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
646   auto value_0 =
647       ops::Const(root.WithDevice("/device:CPU:0"), alloc_0.SerializeAsString());
648   auto handle_0 = ops::XRTAllocate(root, value_0);
649   auto value_1 =
650       ops::Const(root.WithDevice("/device:CPU:0"), alloc_1.SerializeAsString());
651   auto handle_1 = ops::XRTAllocate(root, value_1);
652   auto tuple_0 =
653       ops::Const(root.WithDevice("/device:CPU:0"), desc_0.SerializeAsString());
654   auto handle_2 =
655       ops::XRTMakeTuple(root, tuple_0, {static_cast<Output>(handle_0)});
656   // handle_0 has now been released.
657   auto tuple_1 =
658       ops::Const(root.WithDevice("/device:CPU:0"), desc_1.SerializeAsString());
659   auto handle_3 = ops::XRTMakeTuple(
660       root, tuple_1,
661       {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
662   auto tuple_2 =
663       ops::Const(root.WithDevice("/device:CPU:0"), desc_2.SerializeAsString());
664   // Make sure this runs after handle_3 has completed, since it will free
665   // handle_1 and handle_2.
666   auto handle_4 = ops::XRTMakeTuple(
667       root.WithControlDependencies(handle_3), tuple_2,
668       {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
669   // handle_1 and handle_2 have now been released.
670 
671   auto res_0 = ops::XRTReadLiteralAndRelease(root, handle_3);
672   auto res_1 = ops::XRTReadLiteralAndRelease(root, handle_4);
673   TF_ASSERT_OK(root.status());
674 
675   ClientSession session(root);
676   std::vector<Tensor> outputs;
677   TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs));
678   xla::LiteralProto response_0;
679   EXPECT_TRUE(response_0.ParseFromString(outputs[0].scalar<string>()()));
680   xla::LiteralProto response_1;
681   EXPECT_TRUE(response_1.ParseFromString(outputs[1].scalar<string>()()));
682 
683   auto expected_0 = MakeTuple0();
684   EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0));
685   auto expected_1 = NestedTuple();
686   EXPECT_TRUE(CompareLiteralProtos(response_1, expected_1));
687 }
688 
// Compiles AddAndScale, executes it on the allocated inputs {1, 2} and {8, 5},
// and checks the result ((1+8)*3, (2+5)*3) = {27, 21}. Also checks that the
// program shape returned by XRTCompile has two parameters.
TEST(RawApiTest, CompileAndExecute) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: a failed Run must not fall through to outputs[0..1].
  TF_ASSERT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
  ASSERT_EQ(outputs.size(), 2);

  xla::LiteralProto response;
  ASSERT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  ASSERT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}
741 
// Same computation as CompileAndExecute, but the two allocation handles are
// packed on the host with ops::Stack into a single tensor. That tensor is
// passed to XRTExecute as its only input, instead of one input per parameter.
TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  // Pack both handles into one tensor so XRTExecute receives them as a single
  // argument vector.
  auto packed_args = ops::Stack(root.WithDevice("/device:CPU:0"),
                                {Output(p0_handle), Output(p1_handle)});
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Output(packed_args)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal check: `outputs` is indexed below and is empty if Run fails.
  TF_ASSERT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  // 27 = (1 + 8) * 3 and 21 = (2 + 5) * 3, matching the unpacked variant.
  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}
796 
TEST(RawApiTest,CompileWithXlaReturnShapes)797 TEST(RawApiTest, CompileWithXlaReturnShapes) {
798   xla::XlaBuilder builder("XrtXlaShapes");
799   auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128});
800   auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5});
801   // Clear layouts to signal XLA we are ready to get whatever are coming out of
802   // the compilation process.
803   xla::LayoutUtil::ClearLayout(&input_shape);
804   xla::LayoutUtil::ClearLayout(&kernel_shape);
805   auto param_shape =
806       xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape});
807   auto param = xla::Parameter(&builder, 0, param_shape, "param");
808   auto input = xla::GetTupleElement(param, 0);
809   auto kernel = xla::GetTupleElement(param, 1);
810   xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame);
811   TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation, builder.Build());
812 
813   auto result_shape = xla_computation.GetProgramShape().ValueOrDie().result();
814   // Clear the result shape layout to tell XLA we are accepting whatever are
815   // coming out of the compilation process.
816   xla::LayoutUtil::ClearLayout(&result_shape);
817 
818   xrt::XLAComputation c;
819   auto config = c.mutable_config();
820   auto shapes = config->mutable_program_shape();
821   *shapes->add_parameters() = param_shape.ToProto();
822   *shapes->mutable_result() = result_shape.ToProto();
823   StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot());
824 
825   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
826   auto computation =
827       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
828   auto c_handle = ops::XRTCompile(root, computation);
829   auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle);
830   TF_ASSERT_OK(root.status());
831 
832   ClientSession session(root);
833   std::vector<Tensor> outputs;
834   TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
835                            {c_handle.program_shape}, {release}, &outputs));
836 
837   xla::ProgramShapeProto program_shape_proto;
838   EXPECT_TRUE(program_shape_proto.ParseFromString(outputs[0].vec<string>()(0)));
839   xla::ProgramShape program_shape(program_shape_proto);
840   EXPECT_EQ(program_shape.parameters_size(), 1);
841 
842   VLOG(2) << "Param: "
843           << xla::ShapeUtil::HumanStringWithLayout(program_shape.parameters(0));
844   VLOG(2) << "Result: "
845           << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());
846 
847   xla::ProgramShape xla_program_shape =
848       XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
849   EXPECT_TRUE(xla::LayoutUtil::Equal(
850       xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
851       xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
852           .layout()));
853   EXPECT_TRUE(xla::LayoutUtil::Equal(
854       xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
855       xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
856           .layout()));
857   EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(),
858                                      xla_program_shape.result().layout()));
859 }
860 
// Runs the Dot() computation on an f32[2,2] and an f32[2,1] operand. Both
// operands and the result use minor_to_major layout {0, 1}. The test checks
// that the result literal, built with that same layout, matches.
TEST(RawApiTest, DotGeneralWithLayoutTest) {
  auto layout = xla::LayoutUtil::MakeLayout({0, 1});

  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout);
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal check: `outputs[0]` is read below and is absent if Run fails.
  TF_ASSERT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  // Matrix product: 18 = 1 * 8 + 2 * 5 and 44 = 3 * 8 + 4 * 5.
  auto expected =
      xla::LiteralUtil::CreateR2WithLayout<float>({{18.0f}, {44.0f}}, layout);

  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
913 
// Compiles and runs OnePlusTwo(), a computation with no parameters and an
// f32[] result, passing an empty input list to XRTExecute. The test expects
// the scalar 3.0.
TEST(RawApiTest, CompileAndExecuteZeroArg) {
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
  StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  // The explicit empty initializer_list selects "no inputs" for XRTExecute.
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                std::initializer_list<Input>({}));
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal check: `outputs[0]` is read below and is absent if Run fails.
  TF_ASSERT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  auto expected = xla::LiteralUtil::CreateR0<float>(3.0f);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
946 
// Runs AddAndTuple(), whose declared result is a 1-element tuple of f32[2].
// The whole tuple is read back as a single literal, without exploding it into
// per-element handles.
TEST(RawApiTest, CompileAndExecuteReturnTuple) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal check: `outputs[0]` is read below and is absent if Run fails.
  TF_ASSERT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  // 9 = 1 + 8 and 7 = 2 + 5. NOTE(review): presumably AddAndTuple() returns
  // tuple(p0 + p1); it is defined outside this chunk.
  auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
  auto expected = xla::LiteralUtil::MakeTuple({&sum});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
997 
TEST(RawApiTest,CompileAndExecuteReturnExplodedTuple)998 TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) {
999   xrt::XLAAllocation p0;
1000   *p0.mutable_value() = xla::LiteralUtil::CreateR0<float>(12.0f).ToProto();
1001 
1002   xrt::XLAAllocation p1;
1003   *p1.mutable_value() = xla::LiteralUtil::CreateR0<float>(3.0f).ToProto();
1004 
1005   xrt::XLAComputation c;
1006   auto config = c.mutable_config();
1007   auto shapes = config->mutable_program_shape();
1008   *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
1009   *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
1010   *shapes->mutable_result() =
1011       xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}),
1012                                       xla::ShapeUtil::MakeShape(xla::F32, {})})
1013           .ToProto();
1014   StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot());
1015 
1016   xrt::XRTExecutionConfig e;
1017   e.set_release_input_handles(true);
1018   e.set_release_compilation_handle(true);
1019   e.set_return_exploded_tuple(true);
1020 
1021   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1022   auto e_config =
1023       ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
1024   auto computation =
1025       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
1026   auto c_handle = ops::XRTCompile(root, computation);
1027   auto p0_value =
1028       ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
1029   auto p0_handle = ops::XRTAllocate(root, p0_value);
1030   auto p1_value =
1031       ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
1032   auto p1_handle = ops::XRTAllocate(root, p1_value);
1033   auto result = ops::XRTExecute(root, c_handle.handle, e_config,
1034                                 {Output(p0_handle), Output(p1_handle)});
1035   TF_ASSERT_OK(root.status());
1036 
1037   ClientSession session(root);
1038   std::vector<Tensor> outputs;
1039   TF_EXPECT_OK(session.Run({result}, &outputs));
1040   EXPECT_EQ(outputs.size(), 1);
1041 
1042   auto handles_vec = outputs.front().vec<int64>();
1043   EXPECT_EQ(handles_vec.size(), 2);
1044 
1045   const float kResults[2] = {15.0f, 9.0f};
1046   for (int64 i = 0; i < handles_vec.size(); ++i) {
1047     auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i)));
1048     std::vector<Tensor> voutputs;
1049     TF_EXPECT_OK(session.Run({read_back}, &voutputs));
1050     EXPECT_EQ(voutputs.size(), 1);
1051 
1052     xla::LiteralProto response;
1053     EXPECT_TRUE(response.ParseFromString(voutputs[0].scalar<string>()()));
1054 
1055     auto expected = xla::LiteralUtil::CreateR0<float>(kResults[i]);
1056     EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1057   }
1058 }
1059 
// Compiles AddAndTuple() and fetches the resulting compilation handle. No
// release op is ever built for that handle in this test. Only the success of
// the session run is checked.
TEST(RawApiTest, LeakCompilationReference) {
  const xla::Shape f32_vec2 = xla::ShapeUtil::MakeShape(xla::F32, {2});

  xrt::XLAComputation computation_proto;
  auto* program_shape =
      computation_proto.mutable_config()->mutable_program_shape();
  // Two f32[2] parameters, producing a 1-tuple of f32[2].
  *program_shape->add_parameters() = f32_vec2.ToProto();
  *program_shape->add_parameters() = f32_vec2.ToProto();
  *program_shape->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({f32_vec2}).ToProto();
  StoreComputationSnapshot(AddAndTuple(),
                           computation_proto.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope host_scope = root.WithDevice("/device:CPU:0");
  auto serialized_computation =
      ops::Const(host_scope, computation_proto.SerializeAsString());
  auto compiled = ops::XRTCompile(root, serialized_computation);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> fetched;
  TF_EXPECT_OK(session.Run({compiled.handle}, &fetched));
}
1083 
TEST(RawApiTest,CompileAndExecuteWithReusedBuffers)1084 TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) {
1085   xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
1086   xla::Shape shape =
1087       xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
1088   xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
1089       {element_shape, element_shape, element_shape, element_shape});
1090   xla::XlaBuilder builder("ReuseBuffer");
1091   auto param = xla::Parameter(&builder, 0, shape, "param");
1092   auto p0 = xla::GetTupleElement(param, 0);
1093   auto p1 = xla::GetTupleElement(param, 1);
1094   auto add = xla::Add(p0, p1);
1095   auto sub = xla::Sub(p0, p1);
1096   xla::Tuple(&builder, {add, sub, p0, p1});
1097 
1098   // Flip the tuple literals in the input handle.
1099   builder.SetUpAlias({1}, 0, {0});
1100   builder.SetUpAlias({0}, 0, {1});
1101 
1102   auto computation = builder.Build().ValueOrDie();
1103 
1104   auto literal0 = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
1105   auto literal1 = xla::LiteralUtil::CreateR1<float>({5.0f, 9.0f});
1106   auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
1107 
1108   xrt::XLAAllocation param_alloc;
1109   *param_alloc.mutable_value() = literal.ToProto();
1110 
1111   xrt::XLAComputation c;
1112   auto config = c.mutable_config();
1113   auto shapes = config->mutable_program_shape();
1114   *shapes->add_parameters() = shape.ToProto();
1115   *shapes->mutable_result() = return_shape.ToProto();
1116   StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());
1117 
1118   xrt::XRTExecutionConfig e;
1119   e.set_release_input_handles(false);
1120   e.set_release_compilation_handle(true);
1121 
1122   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1123   ClientSession session(root);
1124   auto e_config =
1125       ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
1126   auto c_data =
1127       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
1128   auto c_handle = ops::XRTCompile(root, c_data);
1129   auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
1130                                 param_alloc.SerializeAsString());
1131   auto param_handle = ops::XRTAllocate(root, param_value);
1132   TF_ASSERT_OK(root.status());
1133 
1134   std::vector<Tensor> outputs;
1135   TF_EXPECT_OK(session.Run({param_handle}, &outputs));
1136 
1137   int64 alloc_handle = outputs[0].scalar<int64>()();
1138 
1139   // Note that we release the result handle immediately, but since we aliased
1140   // the output buffers onto the input allocation ones (held in alloc_handle),
1141   // we can fetch the result from there.
1142   auto result =
1143       ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
1144   auto read_back = ops::XRTReadLiteral(root, result);
1145   auto release = ops::XRTReleaseAllocationHandle(
1146       root.WithControlDependencies(read_back), result);
1147   TF_ASSERT_OK(root.status());
1148 
1149   outputs.clear();
1150   TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_back},
1151                            {release}, &outputs));
1152 
1153   xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
1154   auto exec_literal_parts = exec_literal.DecomposeTuple();
1155   ASSERT_EQ(exec_literal_parts.size(), 4);
1156 
1157   EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
1158   EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));
1159 
1160   // Now we read back the original input handle values, which at this point
1161   // should contain the result of the XLA computation.
1162   auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
1163   TF_ASSERT_OK(root.status());
1164   auto release_handle = ops::XRTReleaseAllocationHandle(
1165       root.WithControlDependencies(read_handle), Input(alloc_handle));
1166   TF_ASSERT_OK(root.status());
1167 
1168   outputs.clear();
1169   TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {read_handle},
1170                            {release_handle}, &outputs));
1171 
1172   xla::Literal return_literal = ReadOutputLiteral(outputs, 0);
1173 
1174   auto expected_literal0 = xla::LiteralUtil::CreateR1<float>({6.0f, 11.0f});
1175   auto expected_literal1 = xla::LiteralUtil::CreateR1<float>({-4.0f, -7.0f});
1176   // The first element of the computation returned tuple would be the add
1177   // (expected_literal0), but since we flipped the buffers, the sub
1178   // (expected_literal1) should come first.
1179   auto expected_literal =
1180       xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});
1181 
1182   EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
1183 }
1184 
// Runs AddS64() on two s64 scalars and checks two things: the s64 result, and
// that the reported program shape has two parameters and an S64 result type.
TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<int64>(11031965).ToProto();
  xrt::XLAAllocation p1;
  *p1.mutable_value() = xla::LiteralUtil::CreateR0<int64>(4091934).ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
  // NOTE(review): the declared result is a scalar, not a tuple, so this flag
  // presumably has no effect here; the single result handle is read back
  // directly below. Confirm whether it is meant to exercise that path.
  e.set_return_exploded_tuple(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal check: `outputs` is indexed below and is empty if Run fails.
  TF_ASSERT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

  // 15123899 = 11031965 + 4091934.
  auto expected = xla::LiteralUtil::CreateR0<int64>(15123899);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
  EXPECT_EQ(program_shape.parameters_size(), 2);
  EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType(
      xla::Shape(program_shape.result()), xla::S64));
}
1237 
1238 }  // namespace
1239 
1240 }  // namespace tensorflow
1241 
main(int argc,char ** argv)1242 int main(int argc, char** argv) {
1243   tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
1244   tensorflow::xla_platform_ptr = new tensorflow::string("CPU");
1245   std::vector<tensorflow::Flag> flag_list = {
1246       tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
1247                        "Tensorflow device type to use for test, e.g., XLA_CPU"),
1248       tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr,
1249                        "The XLA platform to select for the device"),
1250   };
1251   tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
1252   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
1253   if (!parse_result) {
1254     LOG(ERROR) << "\n" << usage;
1255     return 2;
1256   }
1257   testing::InitGoogleTest(&argc, argv);
1258   if (argc > 1) {
1259     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
1260     return 2;
1261   }
1262   return RUN_ALL_TESTS();
1263 }
1264