1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 #include <vector>
19 
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/cc/client/client_session.h"
22 #include "tensorflow/cc/framework/ops.h"
23 #include "tensorflow/cc/framework/scope.h"
24 #include "tensorflow/cc/ops/standard_ops.h"
25 #include "tensorflow/compiler/tf2xla/literal_util.h"
26 #include "tensorflow/compiler/tf2xla/shape_util.h"
27 #include "tensorflow/compiler/xla/client/client_library.h"
28 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
29 #include "tensorflow/compiler/xla/client/lib/constants.h"
30 #include "tensorflow/compiler/xla/client/local_client.h"
31 #include "tensorflow/compiler/xla/client/xla_builder.h"
32 #include "tensorflow/compiler/xla/client/xla_computation.h"
33 #include "tensorflow/compiler/xla/literal.h"
34 #include "tensorflow/compiler/xla/literal_util.h"
35 #include "tensorflow/compiler/xla/service/platform_util.h"
36 #include "tensorflow/compiler/xla/shape_util.h"
37 #include "tensorflow/compiler/xla/xla_data.pb.h"
38 #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
39 #include "tensorflow/compiler/xrt/cc/ops/xrt_execute_op.h"
40 #include "tensorflow/compiler/xrt/cc/ops/xrt_state_ops.h"
41 #include "tensorflow/compiler/xrt/xrt.pb.h"
42 #include "tensorflow/core/framework/tensor.h"
43 #include "tensorflow/core/lib/core/status.h"
44 #include "tensorflow/core/lib/core/status_test_util.h"
45 #include "tensorflow/core/lib/gtl/array_slice.h"
46 #include "tensorflow/core/platform/types.h"
47 #include "tensorflow/core/util/command_line_flags.h"
48 
49 namespace tensorflow {
50 namespace {
51 
ReturnDynamicR1()52 xla::XlaComputation ReturnDynamicR1() {
53   xla::XlaBuilder builder("ReturnDynamicR1");
54   auto p0 = xla::Parameter(&builder, 0,
55                            xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
56   auto p1 = xla::Parameter(&builder, 1,
57                            xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
58   auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
59                            "P2");
60   auto sum = xla::Add(p0, p1);
61   auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
62   return builder.Build(pad_sum).ValueOrDie();
63 }
64 
ReturnDynamicR2()65 xla::XlaComputation ReturnDynamicR2() {
66   xla::XlaBuilder builder("ReturnDynamicR2");
67   auto p0 = xla::Parameter(&builder, 0,
68                            xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P0");
69   auto p1 = xla::Parameter(&builder, 1,
70                            xla::ShapeUtil::MakeShape(xla::F32, {2, 4}), "P1");
71   auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
72                            "P2");
73   auto sum = xla::Add(p0, p1);
74   auto pad_sum_dim0 = xla::SetDimensionSize(sum, p2, 0);
75   auto pad_sum_dim1 = xla::SetDimensionSize(pad_sum_dim0, p2, 1);
76   return builder.Build(pad_sum_dim1).ValueOrDie();
77 }
78 
AcceptDynamicR1()79 xla::XlaComputation AcceptDynamicR1() {
80   xla::XlaBuilder builder("AcceptDynamicR1");
81   xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
82   dyn_shape.set_dynamic_dimension(0, true);
83   auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0");
84   auto p1 = xla::Parameter(&builder, 1, dyn_shape, "P1");
85   auto sum = xla::Add(p0, p1);
86   return builder.Build(sum).ValueOrDie();
87 }
88 
AcceptDynamicR2()89 xla::XlaComputation AcceptDynamicR2() {
90   xla::XlaBuilder builder("AcceptDynamicR2");
91   xla::Shape dyn_shape;
92   dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
93   dyn_shape.set_dynamic_dimension(1, true);
94   auto p0 = xla::Parameter(&builder, 0, dyn_shape, "P0");
95   auto negate = xla::Neg(p0);
96   return builder.Build(negate).ValueOrDie();
97 }
98 
ReturnDynamicR1Tuple()99 xla::XlaComputation ReturnDynamicR1Tuple() {
100   xla::XlaBuilder builder("ReturnDynamicR1Tuple");
101   auto p0 = xla::Parameter(&builder, 0,
102                            xla::ShapeUtil::MakeShape(xla::F32, {4}), "P0");
103   auto p1 = xla::Parameter(&builder, 1,
104                            xla::ShapeUtil::MakeShape(xla::F32, {4}), "P1");
105   auto p2 = xla::Parameter(&builder, 2, xla::ShapeUtil::MakeShape(xla::S32, {}),
106                            "P2");
107   auto sum = xla::Add(p0, p1);
108   auto sub = xla::Sub(p0, p1);
109   auto one = xla::One(&builder, xla::S32);
110   auto pad_sum = xla::SetDimensionSize(sum, p2, 0);
111   auto pad_sub = xla::SetDimensionSize(sub, p2 + one, 0);
112   auto tuple = xla::Tuple(&builder, {pad_sum, sum, pad_sub});
113   return builder.Build(tuple, /*remove_dynamic_dimensions=*/true).ValueOrDie();
114 }
115 
AcceptDynamicR1Tuple()116 xla::XlaComputation AcceptDynamicR1Tuple() {
117   xla::XlaBuilder builder("AcceptDynamicR1");
118   xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
119   dyn_shape.set_dynamic_dimension(0, true);
120   xla::Shape tuple_shape =
121       xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape});
122   xla::Shape nest_tuple_shape =
123       xla::ShapeUtil::MakeTupleShape({dyn_shape, dyn_shape});
124   auto p = xla::Parameter(&builder, 0, tuple_shape, "P0");
125   auto p0 = xla::GetTupleElement(p, 0);
126   auto p1 = xla::GetTupleElement(p, 1);
127   auto sum = xla::Add(p0, p1);
128   return builder.Build(sum).ValueOrDie();
129 }
130 
131 template <typename T>
CreateR0(T v)132 xla::LiteralProto CreateR0(T v) {
133   auto array = xla::LiteralUtil::CreateR0<T>(v);
134   return array.ToProto();
135 }
136 
137 class XrtClientSession : public ClientSession {
138  public:
XrtClientSession(const Scope & scope)139   explicit XrtClientSession(const Scope& scope) : ClientSession(scope) {
140     auto clear_all = ops::XRTReleaseAllAllocations(scope);
141     std::vector<Tensor> outputs;
142     TF_CHECK_OK(Run(ClientSession::FeedType(), {}, {clear_all}, &outputs));
143   }
144 };
145 
// Device type name that DeviceFromFlag() splices into "/device:<name>:0".
string* xla_test_device_ptr;  // initial value set in main()
// Platform name that XlaCompiledProgramShape() passes to
// xla::PlatformUtil::GetPlatform().
string* xla_platform_ptr;     // initial value set in main()
148 
DeviceFromFlag()149 string DeviceFromFlag() {
150   string xla_test_device = *xla_test_device_ptr;
151   return absl::StrCat("/device:", xla_test_device, ":0");
152 }
153 
GetAttrLayout(absl::Span<const int64> minor_to_mayor)154 std::vector<int> GetAttrLayout(absl::Span<const int64> minor_to_mayor) {
155   std::vector<int> layout;
156   for (auto dim : minor_to_mayor) {
157     layout.push_back(static_cast<int>(dim));
158   }
159   return layout;
160 }
161 
TwoElementTuple()162 xla::LiteralProto TwoElementTuple() {
163   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
164   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
165   auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
166   return tuple.ToProto();
167 }
168 
BasedTwoElementTuple(float base)169 xla::LiteralProto BasedTwoElementTuple(float base) {
170   auto array = xla::LiteralUtil::CreateR1<float>({base, base + 1});
171   auto matrix = xla::LiteralUtil::CreateR2<float>(
172       {{base + 2, base + 3}, {base + 4, base + 5}});
173   auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
174   return tuple.ToProto();
175 }
176 
ScalarLiteral()177 xla::LiteralProto ScalarLiteral() {
178   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
179   return scalar.ToProto();
180 }
181 
NestedTuple()182 xla::LiteralProto NestedTuple() {
183   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
184   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
185   auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
186   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
187   auto nested = xla::LiteralUtil::MakeTuple({&tuple, &scalar});
188   return nested.ToProto();
189 }
190 
MakeTuple0()191 xla::LiteralProto MakeTuple0() {
192   auto scalar = xla::LiteralUtil::CreateR0<float>(12.0f);
193   auto array = xla::LiteralUtil::CreateR1<float>({1.0f, 3.0f});
194   auto matrix = xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}});
195   auto tuple = xla::LiteralUtil::MakeTuple({&array, &matrix});
196   auto nested0 = xla::LiteralUtil::MakeTuple({&scalar, &tuple});
197   auto nested1 = xla::LiteralUtil::MakeTuple({&scalar, &nested0});
198   return nested1.ToProto();
199 }
200 
FloatVector(absl::Span<const float> v)201 xla::LiteralProto FloatVector(absl::Span<const float> v) {
202   auto array = xla::LiteralUtil::CreateR1<float>(v);
203   return array.ToProto();
204 }
205 
// Returns the proto of a rank-2 float literal holding the values in `v`,
// created with the given `layout`.
xla::LiteralProto FloatMatrix(
    std::initializer_list<std::initializer_list<float>> v,
    const xla::Layout& layout) {
  return xla::LiteralUtil::CreateR2WithLayout<float>(v, layout).ToProto();
}
212 
ReadOutputLiteral(const std::vector<Tensor> & outputs,size_t idx)213 xla::Literal ReadOutputLiteral(const std::vector<Tensor>& outputs, size_t idx) {
214   xla::LiteralProto response;
215   CHECK(ParseFromTString(outputs[idx].scalar<tstring>()(), &response));
216   return xla::Literal::CreateFromProto(response).ValueOrDie();
217 }
218 
CompareLiteralProtos(const xla::LiteralProto & a,const xla::LiteralProto & b)219 bool CompareLiteralProtos(const xla::LiteralProto& a,
220                           const xla::LiteralProto& b) {
221   auto l_a = xla::Literal::CreateFromProto(a).ValueOrDie();
222   auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
223   bool equal = l_a == l_b;
224   if (!equal) {
225     LOG(INFO) << "LiteralProtos don't match:\n"
226               << a.DebugString() << "\n!=\n"
227               << b.DebugString();
228   }
229   return equal;
230 }
231 
CompareLiteralToLiteralProto(const xla::Literal & a,const xla::LiteralProto & b)232 bool CompareLiteralToLiteralProto(const xla::Literal& a,
233                                   const xla::LiteralProto& b) {
234   auto l_b = xla::Literal::CreateFromProto(b).ValueOrDie();
235   bool equal = a == l_b;
236   if (!equal) {
237     LOG(INFO) << "Literal and LiteralProto don't match:\n"
238               << a.ToProto().DebugString() << "\n!=\n"
239               << b.DebugString();
240   }
241   return equal;
242 }
243 
CompareLiterals(const xla::Literal & a,const xla::Literal & b)244 bool CompareLiterals(const xla::Literal& a, const xla::Literal& b) {
245   bool equal = a == b;
246   if (!equal) {
247     LOG(INFO) << "Literals don't match:\n"
248               << a.ToProto().DebugString() << "\n!=\n"
249               << b.ToProto().DebugString();
250   }
251   return equal;
252 }
253 
OnePlusTwo()254 xla::XlaComputation OnePlusTwo() {
255   xla::XlaBuilder builder("OnePlusTwo");
256   auto c0 = xla::ConstantR0(&builder, 1.0f);
257   auto c1 = xla::ConstantR0(&builder, 2.0f);
258   xla::Add(c0, c1);
259   return builder.Build().ValueOrDie();
260 }
261 
AddAndScale()262 xla::XlaComputation AddAndScale() {
263   xla::XlaBuilder builder("AddAndScale");
264   auto p0 = xla::Parameter(&builder, 0,
265                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
266   auto p1 = xla::Parameter(&builder, 1,
267                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
268   auto sum = xla::Add(p0, p1);
269   auto c = xla::ConstantR0<float>(&builder, 3.0f);
270   xla::Mul(sum, c);
271   return builder.Build().ValueOrDie();
272 }
273 
SubAndScale()274 xla::XlaComputation SubAndScale() {
275   xla::XlaBuilder builder("SubAndScale");
276   auto p0 = xla::Parameter(&builder, 0,
277                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
278   auto p1 = xla::Parameter(&builder, 1,
279                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
280   auto sum = xla::Sub(p0, p1);
281   auto c = xla::ConstantR0<float>(&builder, 11.0f);
282   xla::Mul(sum, c);
283   return builder.Build().ValueOrDie();
284 }
285 
Dot()286 xla::XlaComputation Dot() {
287   xla::XlaBuilder builder("Dot");
288   auto p0 = xla::Parameter(
289       &builder, 0,
290       xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}), "P0");
291   auto p1 = xla::Parameter(
292       &builder, 1,
293       xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}), "P1");
294   xla::DotDimensionNumbers ddn;
295   ddn.add_lhs_contracting_dimensions(1);
296   ddn.add_rhs_contracting_dimensions(0);
297   xla::DotGeneral(p0, p1, ddn);
298   return builder.Build().ValueOrDie();
299 }
300 
AddS64()301 xla::XlaComputation AddS64() {
302   xla::XlaBuilder builder("AddS64");
303   auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::S64, {}),
304                            "P0");
305   auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::S64, {}),
306                            "P1");
307   xla::Add(p0, p1);
308   return builder.Build().ValueOrDie();
309 }
310 
AddAndTuple()311 xla::XlaComputation AddAndTuple() {
312   xla::XlaBuilder builder("AddAndTuple");
313   auto p0 = xla::Parameter(&builder, 0,
314                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
315   auto p1 = xla::Parameter(&builder, 1,
316                            xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
317   auto sum = xla::Add(p0, p1);
318   xla::Tuple(&builder, {sum});
319   return builder.Build().ValueOrDie();
320 }
321 
AddAndSubTuple()322 xla::XlaComputation AddAndSubTuple() {
323   xla::XlaBuilder builder("AddAndSubTuple");
324   auto p0 = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(xla::F32, {}),
325                            "P0");
326   auto p1 = xla::Parameter(&builder, 1, xla::ShapeUtil::MakeShape(xla::F32, {}),
327                            "P1");
328   auto sum = xla::Add(p0, p1);
329   auto sub = xla::Sub(p0, p1);
330   xla::Tuple(&builder, {sum, sub});
331   return builder.Build().ValueOrDie();
332 }
333 
// Builds a computation that applies xla::Broadcast with `dimensions` to a
// single parameter of the given `shape`.
xla::XlaComputation BroadcastComputation(
    const xla::Shape& shape, absl::Span<const xla::int64> dimensions) {
  xla::XlaBuilder builder("BroadcastComputation");
  auto input = xla::Parameter(&builder, 0, shape, "P0");
  // The broadcast is the last op created before Build() is called without an
  // explicit root.
  xla::Broadcast(input, dimensions);
  return builder.Build().ValueOrDie();
}
341 
// Builds a computation over two parameters of the given `shape` that compares
// (P0 - P1) against zero with Ne, converts the result to S32 and sums it over
// all elements. The result is therefore 0 when no element differs.
xla::XlaComputation IsEqualComputation(const xla::Shape& shape) {
  xla::XlaBuilder builder("IsEqualComputation");
  auto lhs = xla::Parameter(&builder, 0, shape, "P0");
  auto rhs = xla::Parameter(&builder, 1, shape, "P1");
  auto difference = xla::Sub(lhs, rhs);
  auto typed_zero = xla::Zero(&builder, shape.element_type());
  auto mismatch = xla::Ne(difference, typed_zero);
  auto mismatch_s32 = xla::ConvertElementType(mismatch, xla::S32);
  auto initial_count = xla::Zero(&builder, xla::S32);
  auto scalar_add = xla::CreateScalarAddComputation(xla::S32, &builder);
  // The reduction is the last op created before Build() is called without an
  // explicit root.
  xla::ReduceAll(mismatch_s32, initial_count, scalar_add);
  return builder.Build().ValueOrDie();
}
353 
StoreComputationSnapshot(const xla::XlaComputation & computation,xla::HloSnapshot * dst)354 void StoreComputationSnapshot(const xla::XlaComputation& computation,
355                               xla::HloSnapshot* dst) {
356   auto snapshot = computation.Snapshot().ValueOrDie();
357   *dst = *snapshot;
358 }
359 
// Compiles `computation` with a local XLA client for the platform named by
// *xla_platform_ptr and returns the program shape of the compiled module's
// entry computation.
//
// `input_program_shape` supplies the parameter shapes passed to the compiler
// and the result layout set in the build options. Any failure (platform
// lookup, client creation, compilation, or an unexpected number of
// executables) terminates the process.
xla::ProgramShape XlaCompiledProgramShape(
    const xla::XlaComputation& computation,
    const xla::ProgramShape& input_program_shape) {
  se::Platform* platform =
      xla::PlatformUtil::GetPlatform(*xla_platform_ptr).ValueOrDie();
  xla::LocalClient* client =
      xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
  xla::ExecutableBuildOptions exec_options;
  exec_options.set_result_layout(input_program_shape.result());
  std::vector<const xla::Shape*> parameters_shapes;
  parameters_shapes.reserve(input_program_shape.parameters_size());
  for (int64 i = 0; i < input_program_shape.parameters_size(); ++i) {
    parameters_shapes.push_back(&input_program_shape.parameters(i));
  }
  std::vector<std::unique_ptr<xla::LocalExecutable>> local_executables =
      client->Compile(computation, parameters_shapes, exec_options)
          .ConsumeValueOrDie();
  // A non-fatal EXPECT here would still fall through to the element access
  // below; use a fatal check so an empty result cannot be indexed.
  CHECK_EQ(local_executables.size(), 1);
  std::unique_ptr<xla::LocalExecutable> local_executable =
      std::move(local_executables[0]);
  return local_executable->executable()
      ->module()
      .entry_computation()
      ->ComputeProgramShape();
}
384 
// Passes a 2x2 float tensor to XRTAllocateFromTensor using the literal's own
// minor-to-major layout, and checks that XRTReadLiteralAndRelease returns a
// literal equal to the original.
TEST(RawApiTest, AllocFromTensor) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout =
      GetAttrLayout(literal.shape().layout().minor_to_major());
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: outputs[0] is read below, so a failed run or an
  // unexpected output count must stop the test here.
  TF_ASSERT_OK(session.Run({read_back}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
410 
TEST(RawApiTest,AllocUninitialized)411 TEST(RawApiTest, AllocUninitialized) {
412   xla::Literal literal =
413       xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
414   Tensor tensor;
415   TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));
416 
417   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
418   std::vector<int> layout =
419       GetAttrLayout(literal.shape().layout().minor_to_major());
420 
421   auto allocate_op =
422       ops::XRTAllocateUninitialized(root, DT_FLOAT, tensor.shape());
423 
424   Tensor handle;
425   std::vector<Tensor> outputs;
426   XrtClientSession session(root);
427   // Allocate the tensor
428   {
429     TF_EXPECT_OK(session.Run({allocate_op}, &outputs));
430     handle = outputs[0];
431   }
432 
433   // Make sure it has the expected shape
434   {
435     auto read_back_op = ops::XRTReadLiteral(root, handle);
436     TF_ASSERT_OK(root.status());
437 
438     TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
439     EXPECT_EQ(outputs.size(), 1);
440     xla::LiteralProto read_back_literal;
441     EXPECT_TRUE(
442         ParseFromTString(outputs[0].scalar<tstring>()(), &read_back_literal));
443     Tensor read_back_tensor;
444     TF_ASSERT_OK(LiteralToHostTensor(
445         xla::Literal::CreateFromProto(read_back_literal).ValueOrDie(), DT_FLOAT,
446         &read_back_tensor));
447 
448     // The shape should be the same as 'tensor', but we don't have any
449     // expectation about the value of the tensors yet since it is uninitialized
450     EXPECT_EQ(tensor.shape(), read_back_tensor.shape());
451   }
452 
453   // Make sure we can write to it
454   xla::LiteralProto new_literal =
455       xla::LiteralUtil::CreateR2({{9.0f, 2.0f}, {4.0f, 1.0f}}).ToProto();
456   {
457     auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
458                                 new_literal.SerializeAsString());
459     auto write_op = ops::XRTWriteLiteral(root, Input(handle), new_value);
460     TF_ASSERT_OK(root.status());
461     TF_EXPECT_OK(session.Run({write_op}, &outputs));
462   }
463 
464   // Now read it back
465   {
466     auto read_back_op = ops::XRTReadLiteralAndRelease(root, handle);
467     TF_ASSERT_OK(root.status());
468     TF_EXPECT_OK(session.Run({read_back_op}, &outputs));
469     EXPECT_EQ(outputs.size(), 1);
470 
471     xla::LiteralProto response;
472     EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
473     EXPECT_TRUE(CompareLiteralProtos(response, new_literal));
474   }
475 }
476 
// Passes two 2x2 float tensors to a single XRTAllocateFromTensor op and checks
// that XRTReadLiteralAndRelease returns a literal equal to the tuple of the
// two source literals.
TEST(RawApiTest, AllocFromTensorTuple) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal1 =
      xla::LiteralUtil::CreateR2<float>({{14.0f, -5.0f}, {16.0f, 17.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));
  Tensor tensor1;
  TF_ASSERT_OK(LiteralToHostTensor(literal1, DT_FLOAT, &tensor1));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0, tensor1},
                                           {tensor0.shape(), tensor1.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: outputs[0] is read below, so a failed run or an
  // unexpected output count must stop the test here.
  TF_ASSERT_OK(session.Run({read_back}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
507 
// Passes a single 2x2 float tensor to XRTAllocateFromTensor with the MakeTuple
// attribute set to true, and checks that XRTReadLiteralAndRelease returns a
// literal equal to the one-element tuple of the source literal.
TEST(RawApiTest, AllocFromTensorTupleSingle) {
  xla::Literal literal0 =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  xla::Literal literal = xla::LiteralUtil::MakeTuple({&literal0});
  Tensor tensor0;
  TF_ASSERT_OK(LiteralToHostTensor(literal0, DT_FLOAT, &tensor0));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  std::vector<int> layout = GetShapeLayoutVector(literal.shape()).ValueOrDie();
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout).MakeTuple(true);
  auto handle = ops::XRTAllocateFromTensor(root, {tensor0}, {tensor0.shape()},
                                           alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: outputs[0] is read below, so a failed run or an
  // unexpected output count must stop the test here.
  TF_ASSERT_OK(session.Run({read_back}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralToLiteralProto(literal, response));
}
533 
// Passes a 2x2 float tensor to XRTAllocateFromTensor with the explicit layout
// attribute {0, 1} and checks that the literal read back equals the expected
// transposed values.
TEST(RawApiTest, AllocFromTensorRelayout) {
  xla::Literal literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 5.0f}, {6.0f, 7.0f}});
  Tensor tensor;
  TF_ASSERT_OK(LiteralToHostTensor(literal, DT_FLOAT, &tensor));

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Use inverse array layout with the tensor data above.
  std::vector<int> layout({0, 1});
  ops::XRTAllocateFromTensor::Attrs alloc_attrs =
      ops::XRTAllocateFromTensor::Layouts(layout);
  auto handle =
      ops::XRTAllocateFromTensor(root, {tensor}, {tensor.shape()}, alloc_attrs);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: outputs[0] is read below, so a failed run or an
  // unexpected output count must stop the test here.
  TF_ASSERT_OK(session.Run({read_back}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  // We have sent the literal's data (in array layout) with an attribute layout
  // of {0,1}, so the expected literal read from device needs to be changed
  // accordingly.
  xla::Literal expected_literal =
      xla::LiteralUtil::CreateR2<float>({{4.0f, 6.0f}, {5.0f, 7.0f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected_literal, response));
}
564 
// Allocates a literal with XRTAllocate, reads it back, overwrites it through
// XRTWriteLiteral using the returned handle, reads it back again to confirm
// the new contents, and finally releases the handle.
TEST(RawApiTest, AllocAndRewrite) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: outputs[0] and outputs[1] are read below.
  TF_ASSERT_OK(session.Run({read_back, handle}, &outputs));
  ASSERT_EQ(outputs.size(), 2);

  int64 allocation_handle = outputs[1].scalar<int64>()();
  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));

  xla::LiteralProto new_literal =
      xla::LiteralUtil::CreateR2({{9, 2}, {4, 1}}).ToProto();
  auto new_value = ops::Const(root.WithDevice("/device:CPU:0"),
                              new_literal.SerializeAsString());
  auto write_op =
      ops::XRTWriteLiteral(root, Input(allocation_handle), new_value);
  TF_ASSERT_OK(root.status());
  TF_ASSERT_OK(session.Run({write_op}, &outputs));
  ASSERT_EQ(outputs.size(), 1);
  // The write op is expected to output the handle it wrote to.
  EXPECT_EQ(allocation_handle, outputs[0].scalar<int64>()());

  auto read_after_write = ops::XRTReadLiteral(root, Input(allocation_handle));
  TF_ASSERT_OK(session.Run({read_after_write}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto new_response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &new_response));
  EXPECT_TRUE(CompareLiteralProtos(new_literal, new_response));

  Tensor release_tensor(DT_INT64, TensorShape({1}));
  release_tensor.flat<int64>()(0) = allocation_handle;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}
612 
// Allocates two literals with XRTAllocate and releases both handles with a
// single XRTReleaseAllocationHandle op fed a two-element handle tensor.
TEST(RawApiTest, AllocReleaseMany) {
  xrt::XLAAllocation alloc1;
  *alloc1.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();
  xrt::XLAAllocation alloc2;
  *alloc2.mutable_value() =
      xla::LiteralUtil::CreateR2({{6, 7}, {4, 5}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value1 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc1.SerializeAsString());
  auto value2 =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc2.SerializeAsString());
  auto handle1 = ops::XRTAllocate(root, value1);
  auto handle2 = ops::XRTAllocate(root, value2);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: outputs[0] and outputs[1] are read below.
  TF_ASSERT_OK(session.Run({handle1, handle2}, &outputs));
  ASSERT_EQ(outputs.size(), 2);

  int64 allocation_handle1 = outputs[0].scalar<int64>()();
  int64 allocation_handle2 = outputs[1].scalar<int64>()();

  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64>()(0) = allocation_handle1;
  release_tensor.flat<int64>()(1) = allocation_handle2;

  auto release = ops::XRTReleaseAllocationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}
645 
// Compiles two computations (AddAndScale and AddAndTuple) with XRTCompile and
// releases both compilation handles with a single
// XRTReleaseCompilationHandle op fed a two-element handle tensor.
TEST(RawApiTest, CompileAndReleaseMany) {
  xrt::XLAComputation c1;
  auto config1 = c1.mutable_config();
  auto shapes1 = config1->mutable_program_shape();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes1->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c1.mutable_hlo_snapshot());

  xrt::XLAComputation c2;
  auto config2 = c2.mutable_config();
  auto shapes2 = config2->mutable_program_shape();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes2->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  // AddAndTuple returns a one-element tuple, so the result shape is a tuple.
  *shapes2->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c2.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto computation1 =
      ops::Const(root.WithDevice("/device:CPU:0"), c1.SerializeAsString());
  auto c_handle1 = ops::XRTCompile(root, computation1);
  auto computation2 =
      ops::Const(root.WithDevice("/device:CPU:0"), c2.SerializeAsString());
  auto c_handle2 = ops::XRTCompile(root, computation2);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: outputs[0] and outputs[1] are read below.
  TF_ASSERT_OK(session.Run({c_handle1.handle, c_handle2.handle}, &outputs));
  ASSERT_EQ(outputs.size(), 2);

  int64 compilation_handle1 = outputs[0].scalar<int64>()();
  int64 compilation_handle2 = outputs[1].scalar<int64>()();

  Tensor release_tensor(DT_INT64, TensorShape({2}));
  release_tensor.flat<int64>()(0) = compilation_handle1;
  release_tensor.flat<int64>()(1) = compilation_handle2;

  auto release = ops::XRTReleaseCompilationHandle(root, release_tensor);
  TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {}, {release}, &outputs));
}
694 
// Allocates a literal, runs XRTReleaseAllAllocations, and checks that reading
// the previously returned handle afterwards fails with NOT_FOUND.
TEST(RawApiTest, AllocAndClearAll) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() =
      xla::LiteralUtil::CreateR2({{4, 5}, {6, 7}}).ToProto();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal assertions: outputs[0] is read below, so a failed run or an
  // unexpected output count must stop the test here.
  TF_ASSERT_OK(session.Run({handle}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  int64 allocation_handle = outputs[0].scalar<int64>()();

  auto clear_all = ops::XRTReleaseAllAllocations(root);

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {}, {clear_all}, &outputs));
  EXPECT_EQ(outputs.size(), 0);

  auto read_after_clear = ops::XRTReadLiteral(root, Input(allocation_handle));
  EXPECT_EQ(session.Run({read_after_clear}, &outputs).code(),
            error::Code::NOT_FOUND);
}
723 
// Allocates the literal produced by TwoElementTuple(), reads it back with
// XRTReadLiteral and then releases the handle. The release op is built with a
// control dependency on the read, so the read is ordered before the release.
// The literal read back must match the one that was allocated.
TEST(RawApiTest, ReadAndWriteState) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteral(root, handle);
  auto release = ops::XRTReleaseAllocationHandle(
      root.WithControlDependencies(read_back), handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal checks: outputs[0] is dereferenced below, so the test must stop here
  // if the run failed or produced no result.
  TF_ASSERT_OK(
      session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
}
747 
// Same round trip as ReadAndWriteState, but uses the combined
// XRTReadLiteralAndRelease op instead of separate read and release ops.
TEST(RawApiTest, ReadAndWriteStateAutoFree) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = TwoElementTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto handle = ops::XRTAllocate(root, value);
  auto read_back = ops::XRTReadLiteralAndRelease(root, handle);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal checks: outputs[0] is dereferenced below, so the test must stop here
  // if the run failed or produced no result.
  TF_ASSERT_OK(session.Run({read_back}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  EXPECT_TRUE(CompareLiteralProtos(alloc.value(), response));
}
767 
// Allocates the literal produced by NestedTuple() and extracts three
// sub-buffers from it via XRTSubTuple: index {0}, index {1} and index {0, 0}.
// Each sub-buffer is read back and compared against the corresponding piece of
// the original literal, decomposed on the host.
TEST(RawApiTest, SubBuffer) {
  xrt::XLAAllocation alloc;
  *alloc.mutable_value() = NestedTuple();

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto value =
      ops::Const(root.WithDevice("/device:CPU:0"), alloc.SerializeAsString());
  auto base_handle = ops::XRTAllocate(root, value);
  auto index_0 = ops::Const(root.WithDevice("/device:CPU:0"), {0});
  auto index_1 = ops::Const(root.WithDevice("/device:CPU:0"), {1});
  auto index_00 = ops::Const(root.WithDevice("/device:CPU:0"), {0, 0});
  auto sub_0 = ops::XRTSubTuple(root, base_handle, index_0);
  auto sub_1 = ops::XRTSubTuple(root, base_handle, index_1);
  // The releasing variant is ordered after the two non-releasing extractions
  // via control dependencies, since it also uses base_handle.
  auto sub_00 = ops::XRTSubTupleAndRelease(
      root.WithControlDependencies(
          {sub_0.output_handle.op(), sub_1.output_handle.op()}),
      base_handle, index_00);
  auto value_0 = ops::XRTReadLiteralAndRelease(root, sub_0);
  auto value_1 = ops::XRTReadLiteralAndRelease(root, sub_1);
  auto value_00 = ops::XRTReadLiteralAndRelease(root, sub_00);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal checks: outputs[0..2] are dereferenced below, so the test must stop
  // here if the run failed or returned fewer results than fetched.
  TF_ASSERT_OK(session.Run({value_0, value_1, value_00}, &outputs));
  ASSERT_EQ(outputs.size(), 3);

  // Build the expected pieces by decomposing the original literal on the host.
  auto base_literal = xla::Literal::CreateFromProto(alloc.value()).ValueOrDie();
  auto base_elements = base_literal.DecomposeTuple();
  auto nested_0_elements = base_elements[0].Clone().DecomposeTuple();
  xla::LiteralProto response_0;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response_0));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[0], response_0));
  xla::LiteralProto response_1;
  EXPECT_TRUE(ParseFromTString(outputs[1].scalar<tstring>()(), &response_1));
  EXPECT_TRUE(CompareLiteralToLiteralProto(base_elements[1], response_1));
  xla::LiteralProto response_00;
  EXPECT_TRUE(ParseFromTString(outputs[2].scalar<tstring>()(), &response_00));
  EXPECT_TRUE(CompareLiteralToLiteralProto(nested_0_elements[0], response_00));
}
807 
TEST(RawApiTest,MakeTuple)808 TEST(RawApiTest, MakeTuple) {
809   xrt::XLAAllocation alloc_0;
810   *alloc_0.mutable_value() = TwoElementTuple();
811   xrt::XLAAllocation alloc_1;
812   *alloc_1.mutable_value() = ScalarLiteral();
813 
814   // The trivial tuple that just forwards its input and releases it.
815   xrt::XLATupleNode desc_0;
816   desc_0.set_input_index(0);
817   desc_0.set_release_input_handle(true);
818 
819   xrt::XLATupleNode desc_1;
820   auto subdesc_10 = desc_1.add_tuples();
821   auto subdesc_11 = desc_1.add_tuples();
822   subdesc_10->set_input_index(0);
823   auto subdesc_110 = subdesc_11->add_tuples();
824   subdesc_110->set_input_index(0);
825   auto subdesc_111 = subdesc_11->add_tuples();
826   subdesc_111->set_input_index(1);
827 
828   xrt::XLATupleNode desc_2;
829   auto subdesc_20 = desc_2.add_tuples();
830   auto subdesc_21 = desc_2.add_tuples();
831   subdesc_20->set_input_index(1);
832   subdesc_20->set_release_input_handle(true);
833   subdesc_21->set_input_index(0);
834   subdesc_21->set_release_input_handle(true);
835 
836   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
837   auto value_0 =
838       ops::Const(root.WithDevice("/device:CPU:0"), alloc_0.SerializeAsString());
839   auto handle_0 = ops::XRTAllocate(root, value_0);
840   auto value_1 =
841       ops::Const(root.WithDevice("/device:CPU:0"), alloc_1.SerializeAsString());
842   auto handle_1 = ops::XRTAllocate(root, value_1);
843   auto tuple_0 =
844       ops::Const(root.WithDevice("/device:CPU:0"), desc_0.SerializeAsString());
845   auto handle_2 =
846       ops::XRTMakeTuple(root, tuple_0, {static_cast<Output>(handle_0)});
847   // handle_0 has now been released.
848   auto tuple_1 =
849       ops::Const(root.WithDevice("/device:CPU:0"), desc_1.SerializeAsString());
850   auto handle_3 = ops::XRTMakeTuple(
851       root, tuple_1,
852       {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
853   auto tuple_2 =
854       ops::Const(root.WithDevice("/device:CPU:0"), desc_2.SerializeAsString());
855   // Make sure this runs after handle_3 has completed, since it will free
856   // handle_1 and handle_2.
857   auto handle_4 = ops::XRTMakeTuple(
858       root.WithControlDependencies(handle_3), tuple_2,
859       {static_cast<Output>(handle_1), static_cast<Output>(handle_2)});
860   // handle_1 and handle_2 have now been released.
861 
862   auto res_0 = ops::XRTReadLiteralAndRelease(root, handle_3);
863   auto res_1 = ops::XRTReadLiteralAndRelease(root, handle_4);
864   TF_ASSERT_OK(root.status());
865 
866   XrtClientSession session(root);
867   std::vector<Tensor> outputs;
868   TF_EXPECT_OK(session.Run({res_0, res_1}, &outputs));
869   xla::LiteralProto response_0;
870   EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response_0));
871   xla::LiteralProto response_1;
872   EXPECT_TRUE(ParseFromTString(outputs[1].scalar<tstring>()(), &response_1));
873 
874   auto expected_0 = MakeTuple0();
875   EXPECT_TRUE(CompareLiteralProtos(response_0, expected_0));
876   auto expected_1 = NestedTuple();
877   EXPECT_TRUE(CompareLiteralProtos(response_1, expected_1));
878 }
879 
TEST(RawApiTest,ExecuteChainedOpByOp)880 TEST(RawApiTest, ExecuteChainedOpByOp) {
881   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
882 
883   auto make_computation = [](const std::function<xla::XlaComputation()>& fn) {
884     xrt::XLAComputation c;
885     auto config = c.mutable_config();
886     auto shapes = config->mutable_program_shape();
887     *shapes->add_parameters() =
888         xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
889     *shapes->add_parameters() =
890         xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
891     *shapes->mutable_result() =
892         xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
893     StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot());
894     return c.SerializeAsString();
895   };
896 
897   auto c_add_scale = make_computation(AddAndScale);
898   auto c_sub_scale = make_computation(SubAndScale);
899 
900   auto c_add_scale_op = ops::XRTCompile(
901       root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale));
902   auto c_sub_scale_op = ops::XRTCompile(
903       root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale));
904   TF_ASSERT_OK(root.status());
905 
906   XrtClientSession session(root);
907   std::vector<Tensor> outputs;
908   TF_EXPECT_OK(
909       session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs));
910   EXPECT_EQ(outputs.size(), 2);
911 
912   int64 c_add_scale_handle = outputs[0].scalar<int64>()();
913   int64 c_sub_scale_handle = outputs[1].scalar<int64>()();
914 
915   xrt::XLAAllocation p0;
916   *p0.mutable_value() = FloatVector({1.0f, 2.0f});
917   xrt::XLAAllocation p1;
918   *p1.mutable_value() = FloatVector({8.0f, 5.0f});
919 
920   auto p0_handle = ops::XRTAllocate(
921       root,
922       ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()));
923   auto p1_handle = ops::XRTAllocate(
924       root,
925       ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()));
926 
927   xrt::XRTExecutionConfig e;
928   e.set_release_input_handles(false);
929   e.set_release_compilation_handle(false);
930   auto e_config =
931       ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
932   auto result0 = ops::XRTExecute(root, Input(c_add_scale_handle), e_config,
933                                  {Output(p0_handle), Output(p1_handle)});
934   auto result1 = ops::XRTExecute(root, Input(c_sub_scale_handle), e_config,
935                                  {Output(p0_handle), Output(p1_handle)});
936   auto result = ops::XRTExecute(root, Input(c_add_scale_handle), e_config,
937                                 {result0.output_handle, result1.output_handle});
938   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
939   TF_ASSERT_OK(root.status());
940 
941   TF_EXPECT_OK(session.Run({read_back}, &outputs));
942 
943   xla::LiteralProto response;
944   EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
945 
946   auto expected = xla::LiteralUtil::CreateR1<float>({-150.0f, -36.0f});
947   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
948 }
949 
TEST(RawApiTest,ExecuteChained)950 TEST(RawApiTest, ExecuteChained) {
951   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
952 
953   auto make_computation = [](const std::function<xla::XlaComputation()>& fn) {
954     xrt::XLAComputation c;
955     auto config = c.mutable_config();
956     auto shapes = config->mutable_program_shape();
957     *shapes->add_parameters() =
958         xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
959     *shapes->add_parameters() =
960         xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
961     *shapes->mutable_result() =
962         xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
963     StoreComputationSnapshot(fn(), c.mutable_hlo_snapshot());
964     return c.SerializeAsString();
965   };
966 
967   auto c_add_scale = make_computation(AddAndScale);
968   auto c_sub_scale = make_computation(SubAndScale);
969 
970   auto c_add_scale_op = ops::XRTCompile(
971       root, ops::Const(root.WithDevice("/device:CPU:0"), c_add_scale));
972   auto c_sub_scale_op = ops::XRTCompile(
973       root, ops::Const(root.WithDevice("/device:CPU:0"), c_sub_scale));
974   TF_ASSERT_OK(root.status());
975 
976   XrtClientSession session(root);
977   std::vector<Tensor> outputs;
978   TF_EXPECT_OK(
979       session.Run({c_add_scale_op.handle, c_sub_scale_op.handle}, &outputs));
980   EXPECT_EQ(outputs.size(), 2);
981 
982   int64 c_add_scale_handle = outputs[0].scalar<int64>()();
983   int64 c_sub_scale_handle = outputs[1].scalar<int64>()();
984 
985   xrt::XLAAllocation p0;
986   *p0.mutable_value() = FloatVector({1.0f, 2.0f});
987   xrt::XLAAllocation p1;
988   *p1.mutable_value() = FloatVector({8.0f, 5.0f});
989 
990   auto p0_handle_op = ops::XRTAllocate(
991       root,
992       ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString()));
993   auto p1_handle_op = ops::XRTAllocate(
994       root,
995       ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()));
996 
997   TF_EXPECT_OK(session.Run({p0_handle_op, p1_handle_op}, &outputs));
998   EXPECT_EQ(outputs.size(), 2);
999 
1000   int64 p0_handle = outputs[0].scalar<int64>()();
1001   int64 p1_handle = outputs[1].scalar<int64>()();
1002 
1003   xrt::XRTChainedExecuteConfig config;
1004   auto config_const =
1005       ops::Const(root.WithDevice("/device:CPU:0"), config.SerializeAsString());
1006 
1007   xrt::XRTChainedExecutePlan plan;
1008   xrt::XRTChainedExecuteOp* op;
1009   xrt::XRTChainedExecuteOp::Input* input;
1010   xrt::XRTChainedExecuteOp::Output* output;
1011 
1012   // Index 0
1013   op = plan.add_ops();
1014   op->set_data_handle(p0_handle);
1015 
1016   // Index 1
1017   op = plan.add_ops();
1018   op->set_data_handle(p1_handle);
1019 
1020   // Index 2
1021   op = plan.add_ops();
1022   op->set_computation_handle(c_add_scale_handle);
1023   input = op->add_inputs();
1024   input->set_op_index(0);
1025   input = op->add_inputs();
1026   input->set_op_index(1);
1027 
1028   // Index 3
1029   op = plan.add_ops();
1030   op->set_computation_handle(c_sub_scale_handle);
1031   input = op->add_inputs();
1032   input->set_op_index(0);
1033   input = op->add_inputs();
1034   input->set_op_index(1);
1035 
1036   // Index 4
1037   op = plan.add_ops();
1038   op->set_computation_handle(c_add_scale_handle);
1039   input = op->add_inputs();
1040   input->set_op_index(2);
1041   input = op->add_inputs();
1042   input->set_op_index(3);
1043   output = op->add_outputs();
1044   output->set_result_index(0);
1045 
1046   auto plan_const =
1047       ops::Const(root.WithDevice("/device:CPU:0"), plan.SerializeAsString());
1048   auto result = ops::XRTExecuteChained(root, plan_const, config_const);
1049   TF_ASSERT_OK(root.status());
1050 
1051   TF_EXPECT_OK(session.Run({result}, &outputs));
1052   EXPECT_EQ(outputs.size(), 1);
1053 
1054   auto handles_vec = outputs[0].vec<int64>();
1055   EXPECT_EQ(handles_vec.size(), 1);
1056 
1057   auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(0)));
1058   TF_ASSERT_OK(root.status());
1059 
1060   TF_EXPECT_OK(session.Run({read_back}, &outputs));
1061   EXPECT_EQ(outputs.size(), 1);
1062 
1063   xla::LiteralProto response;
1064   EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
1065 
1066   auto expected = xla::LiteralUtil::CreateR1<float>({-150.0f, -36.0f});
1067   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1068 }
1069 
// Compiles AddAndScale, executes it on two f32[2] inputs within a single
// graph, and checks both the result literal ({27, 21}) and that the program
// shape reported by XRTCompile has two parameters.
TEST(RawApiTest, CompileAndExecute) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());

  // Inputs and the compiled program are used exactly once, so let the execute
  // op release them.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  auto e_config =
      ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value =
      ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal checks: outputs[0] and outputs[1] are dereferenced below, so the
  // test must stop here if the run failed or returned too few results.
  TF_ASSERT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
  ASSERT_EQ(outputs.size(), 2);

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}
1122 
// Executes ReturnDynamicR1() on two f32[4] inputs plus an s32 scalar (2). The
// declared result shape is f32[4] with dimension 0 marked dynamic, and the
// literal read back is expected to be the two-element vector {2, 1}.
// NOTE(review): presumably the scalar selects the dynamic size of the result;
// confirm against ReturnDynamicR1().
TEST(RawApiTest, DynamicR1Test) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({1.0f, -1.0f, 2.5f, 1.17f});
  xrt::XLAAllocation p2;
  *p2.mutable_value() = CreateR0<xla::int32>(2);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_shape.set_dynamic_dimension(0, true);
  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(ReturnDynamicR1(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
  auto p2_handle = ops::XRTAllocate(root, p2_value);
  auto result = ops::XRTExecute(
      root, c_handle.handle, e_config,
      {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal checks: outputs[0] is dereferenced below, so the test must stop here
  // if the run failed or returned too few results.
  TF_ASSERT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
  ASSERT_EQ(outputs.size(), 2);

  xla::LiteralProto response;
  // Use the tstring-aware helper, as the other tests in this file do.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
  auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1174 
TEST(RawApiTest,DynamicR2Test)1175 TEST(RawApiTest, DynamicR2Test) {
1176   xrt::XLAAllocation p0;
1177   *p0.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, 2.0f, 0.5f, -1.0f},
1178                                                     {1.5f, 2.5f, 3.0f, -2.0f}})
1179                             .ToProto();
1180   xrt::XLAAllocation p1;
1181   *p1.mutable_value() = xla::LiteralUtil::CreateR2({{1.0f, -1.0f, 2.5f, 1.17f},
1182                                                     {1.2f, -1.6f, 2.8f, 1.24f}})
1183                             .ToProto();
1184   xrt::XLAAllocation p2;
1185   *p2.mutable_value() = CreateR0<xla::int32>(2);
1186 
1187   xrt::XLAComputation c;
1188   auto config = c.mutable_config();
1189   auto shapes = config->mutable_program_shape();
1190   *shapes->add_parameters() =
1191       xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto();
1192   *shapes->add_parameters() =
1193       xla::ShapeUtil::MakeShape(xla::F32, {2, 4}).ToProto();
1194   *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
1195   xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
1196   dyn_shape.set_dynamic_dimension(0, true);
1197   dyn_shape.set_dynamic_dimension(1, true);
1198   *shapes->mutable_result() = dyn_shape.ToProto();
1199   StoreComputationSnapshot(ReturnDynamicR2(), c.mutable_hlo_snapshot());
1200 
1201   xrt::XRTExecutionConfig e;
1202   e.set_release_input_handles(true);
1203   e.set_release_compilation_handle(true);
1204 
1205   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1206   Scope cpu_root = root.WithDevice("/device:CPU:0");
1207   auto e_config = ops::Const(cpu_root, e.SerializeAsString());
1208   auto computation = ops::Const(cpu_root, c.SerializeAsString());
1209   auto c_handle = ops::XRTCompile(root, computation);
1210   auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
1211   auto p0_handle = ops::XRTAllocate(root, p0_value);
1212   auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
1213   auto p1_handle = ops::XRTAllocate(root, p1_value);
1214   auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
1215   auto p2_handle = ops::XRTAllocate(root, p2_value);
1216   auto result = ops::XRTExecute(
1217       root, c_handle.handle, e_config,
1218       {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
1219   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
1220   TF_ASSERT_OK(root.status());
1221 
1222   XrtClientSession session(root);
1223   std::vector<Tensor> outputs;
1224   TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
1225 
1226   xla::LiteralProto response;
1227   EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
1228   auto expected = xla::LiteralUtil::CreateR2<float>({{2.0f, 1.0f}, {2.7, 0.9}});
1229   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1230 }
1231 
TEST(RawApiTest,DynamicR1TupleTest)1232 TEST(RawApiTest, DynamicR1TupleTest) {
1233   xrt::XLAAllocation p0;
1234   *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f, -1.0f});
1235   xrt::XLAAllocation p1;
1236   *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f, 1.0f});
1237   xrt::XLAAllocation p2;
1238   *p2.mutable_value() = CreateR0<xla::int32>(2);
1239 
1240   xrt::XLAComputation c;
1241   auto config = c.mutable_config();
1242   auto shapes = config->mutable_program_shape();
1243   *shapes->add_parameters() =
1244       xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
1245   *shapes->add_parameters() =
1246       xla::ShapeUtil::MakeShape(xla::F32, {4}).ToProto();
1247   *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
1248   xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
1249   dyn_shape.set_dynamic_dimension(0, true);
1250   *shapes->mutable_result() =
1251       xla::ShapeUtil::MakeTupleShape(
1252           {dyn_shape, xla::ShapeUtil::MakeShape(xla::F32, {4}), dyn_shape})
1253           .ToProto();
1254   StoreComputationSnapshot(ReturnDynamicR1Tuple(), c.mutable_hlo_snapshot());
1255 
1256   xrt::XRTExecutionConfig e;
1257   e.set_release_input_handles(true);
1258   e.set_release_compilation_handle(true);
1259 
1260   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1261   Scope cpu_root = root.WithDevice("/device:CPU:0");
1262   auto e_config = ops::Const(cpu_root, e.SerializeAsString());
1263   auto computation = ops::Const(cpu_root, c.SerializeAsString());
1264   auto c_handle = ops::XRTCompile(root, computation);
1265   auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
1266   auto p0_handle = ops::XRTAllocate(root, p0_value);
1267   auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
1268   auto p1_handle = ops::XRTAllocate(root, p1_value);
1269   auto p2_value = ops::Const(cpu_root, p2.SerializeAsString());
1270   auto p2_handle = ops::XRTAllocate(root, p2_value);
1271   auto result = ops::XRTExecute(
1272       root, c_handle.handle, e_config,
1273       {Output(p0_handle), Output(p1_handle), Output(p2_handle)});
1274   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
1275   TF_ASSERT_OK(root.status());
1276 
1277   XrtClientSession session(root);
1278   std::vector<Tensor> outputs;
1279   TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
1280 
1281   xla::LiteralProto response;
1282   EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
1283 
1284   auto expected0 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f});
1285   auto expected1 = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f, 0.0f});
1286   auto expected2 = xla::LiteralUtil::CreateR1<float>({0.0f, 3.0f, 1.0f});
1287   auto expected =
1288       xla::LiteralUtil::MakeTuple({&expected0, &expected1, &expected2});
1289   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1290 }
1291 
TEST(RawApiTest,AcceptDynamicR1TupleTest)1292 TEST(RawApiTest, AcceptDynamicR1TupleTest) {
1293   xrt::XLAAllocation p0;
1294   *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
1295   xrt::XLAAllocation p1;
1296   *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});
1297 
1298   xrt::XLATupleNode tuple_desc;
1299   auto subdesc_10 = tuple_desc.add_tuples();
1300   auto subdesc_11 = tuple_desc.add_tuples();
1301   subdesc_10->set_input_index(0);
1302   subdesc_10->set_release_input_handle(true);
1303   subdesc_11->set_input_index(1);
1304   subdesc_11->set_release_input_handle(true);
1305 
1306   xrt::XLAComputation c;
1307   auto config = c.mutable_config();
1308   auto shapes = config->mutable_program_shape();
1309   xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
1310   dyn_input_shape.set_dynamic_dimension(0, true);
1311   xla::Shape dyn_tuple_shape =
1312       xla::ShapeUtil::MakeTupleShape({dyn_input_shape, dyn_input_shape});
1313   *shapes->add_parameters() = dyn_tuple_shape.ToProto();
1314   xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
1315   dyn_shape.set_dynamic_dimension(0, true);
1316   *shapes->mutable_result() = dyn_shape.ToProto();
1317   StoreComputationSnapshot(AcceptDynamicR1Tuple(), c.mutable_hlo_snapshot());
1318 
1319   xrt::XRTExecutionConfig e;
1320   e.set_release_input_handles(true);
1321   e.set_release_compilation_handle(true);
1322 
1323   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1324   Scope cpu_root = root.WithDevice("/device:CPU:0");
1325   auto e_config = ops::Const(cpu_root, e.SerializeAsString());
1326   auto computation = ops::Const(cpu_root, c.SerializeAsString());
1327   auto c_handle = ops::XRTCompile(root, computation);
1328   auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
1329   auto p0_handle = ops::XRTAllocate(root, p0_value);
1330   auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
1331   auto p1_handle = ops::XRTAllocate(root, p1_value);
1332 
1333   auto tuple_0 = ops::Const(root.WithDevice("/device:CPU:0"),
1334                             tuple_desc.SerializeAsString());
1335   auto t0_handle = ops::XRTMakeTuple(
1336       root, tuple_0,
1337       {static_cast<Output>(p0_handle), static_cast<Output>(p1_handle)});
1338   auto result = ops::XRTExecute(root, c_handle.handle, e_config,
1339                                 {static_cast<Output>(t0_handle)});
1340   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
1341   TF_ASSERT_OK(root.status());
1342 
1343   XrtClientSession session(root);
1344   std::vector<Tensor> outputs;
1345   TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
1346 
1347   xla::LiteralProto response;
1348   EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<tstring>()()));
1349 
1350   auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
1351   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1352 }
1353 
// Passes two three-element f32 vectors to AcceptDynamicR1(), whose two
// parameters are declared as f32[4] with dimension 0 dynamic. The literal read
// back is expected to be {2, 1, 0}.
TEST(RawApiTest, AcceptDynamicR1Test) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f, 0.5f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({1.0f, -1.0f, -0.5f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  xla::Shape dyn_input_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_input_shape.set_dynamic_dimension(0, true);
  *shapes->add_parameters() = dyn_input_shape.ToProto();
  *shapes->add_parameters() = dyn_input_shape.ToProto();
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {4});
  dyn_shape.set_dynamic_dimension(0, true);
  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(AcceptDynamicR1(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto allocate_op_0 = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto allocate_op_1 = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(allocate_op_0), Output(allocate_op_1)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal checks: outputs[0] is dereferenced below, so the test must stop here
  // if the run failed or returned too few results.
  TF_ASSERT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
  ASSERT_EQ(outputs.size(), 2);

  xla::LiteralProto response;
  // Use the tstring-aware helper, as the other tests in this file do.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({2.0f, 1.0f, 0.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1400 
// Passes a 2x3 f32 matrix to AcceptDynamicR2(), whose parameter and result are
// both declared as f32[2,4] with dimension 1 dynamic. The literal read back is
// expected to be the element-wise negation of the input.
TEST(RawApiTest, AcceptDynamicR2Test) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() =
      xla::LiteralUtil::CreateR2({{-1.0f, 2.0f, 3.0f}, {-4.0f, -5.0f, 6.0f}})
          .ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  // Compile time expects ascending layout.
  xla::Shape dyn_shape = xla::ShapeUtil::MakeShape(xla::F32, {2, 4});
  dyn_shape.set_dynamic_dimension(1, true);
  *shapes->add_parameters() = dyn_shape.ToProto();

  *shapes->mutable_result() = dyn_shape.ToProto();
  StoreComputationSnapshot(AcceptDynamicR2(), c.mutable_hlo_snapshot());

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal checks: outputs[0] is dereferenced below, so the test must stop here
  // if the run failed or returned too few results.
  TF_ASSERT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));
  ASSERT_EQ(outputs.size(), 2);

  xla::LiteralProto response;
  // Use the tstring-aware helper, as the other tests in this file do.
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR2<float>(
      {{1.0f, -2.0f, -3.0f}, {4.0f, 5.0f, -6.0f}});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1445 
// Runs the AddAndScale() computation with its two argument handles packed
// (via ops::Stack) into a single tensor that is passed to XRTExecute as the
// only input. Expects {27, 21} for inputs {1, 2} and {8, 5}, and a compiled
// program shape that still reports two parameters.
TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  StoreComputationSnapshot(AddAndScale(), c.mutable_hlo_snapshot());

  // Input and compilation handles are released by the execute op itself.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Serialized protos and the handle stacking are placed on the CPU.
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  // Pack both allocation handles into one tensor argument.
  auto packed_args =
      ops::Stack(cpu_root, {Output(p0_handle), Output(p1_handle)});
  auto result =
      ops::XRTExecute(root, c_handle.handle, e_config, {Output(packed_args)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal on failure: both fetched outputs are indexed below.
  TF_ASSERT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
}
1500 
TEST(RawApiTest,CompileWithXlaReturnShapes)1501 TEST(RawApiTest, CompileWithXlaReturnShapes) {
1502   xla::XlaBuilder builder("XrtXlaShapes");
1503   auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128});
1504   auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5});
1505   // Clear layouts to signal XLA we are ready to get whatever are coming out of
1506   // the compilation process.
1507   xla::LayoutUtil::ClearLayout(&input_shape);
1508   xla::LayoutUtil::ClearLayout(&kernel_shape);
1509   auto param_shape =
1510       xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape});
1511   auto param = xla::Parameter(&builder, 0, param_shape, "param");
1512   auto input = xla::GetTupleElement(param, 0);
1513   auto kernel = xla::GetTupleElement(param, 1);
1514   xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame);
1515   TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation, builder.Build());
1516 
1517   auto result_shape = xla_computation.GetProgramShape().ValueOrDie().result();
1518   // Clear the result shape layout to tell XLA we are accepting whatever are
1519   // coming out of the compilation process.
1520   xla::LayoutUtil::ClearLayout(&result_shape);
1521 
1522   xrt::XLAComputation c;
1523   auto config = c.mutable_config();
1524   auto shapes = config->mutable_program_shape();
1525   *shapes->add_parameters() = param_shape.ToProto();
1526   *shapes->mutable_result() = result_shape.ToProto();
1527   StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot());
1528 
1529   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1530   auto computation =
1531       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
1532   auto c_handle = ops::XRTCompile(root, computation);
1533   auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle);
1534   TF_ASSERT_OK(root.status());
1535 
1536   XrtClientSession session(root);
1537   std::vector<Tensor> outputs;
1538   TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {c_handle.program_shape},
1539                            {release}, &outputs));
1540 
1541   xla::ProgramShapeProto program_shape_proto;
1542   EXPECT_TRUE(
1543       ParseFromTString(outputs[0].vec<tstring>()(0), &program_shape_proto));
1544   xla::ProgramShape program_shape(program_shape_proto);
1545   EXPECT_EQ(program_shape.parameters_size(), 1);
1546 
1547   VLOG(2) << "Param: "
1548           << xla::ShapeUtil::HumanStringWithLayout(program_shape.parameters(0));
1549   VLOG(2) << "Result: "
1550           << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());
1551 
1552   xla::ProgramShape xla_program_shape =
1553       XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
1554   EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
1555       xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
1556       xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
1557           .layout()));
1558   EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
1559       xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
1560       xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
1561           .layout()));
1562   EXPECT_TRUE(xla::Layout::Equal().MinorToMajorOnly()(
1563       program_shape.result().layout(), xla_program_shape.result().layout()));
1564 }
1565 
// Runs the Dot() computation on a 2x2 and a 2x1 matrix that, like the declared
// parameter and result shapes, use the {0, 1} minor-to-major layout. Expects
// {{18}, {44}} with that same layout.
TEST(RawApiTest, DotGeneralWithLayoutTest) {
  auto layout = xla::LayoutUtil::MakeLayout({0, 1});

  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatMatrix({{1.0f, 2.0f}, {3.0f, 4.0f}}, layout);
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatMatrix({{8.0f}, {5.0f}}, layout);

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 1}, {0, 1}).ToProto();
  StoreComputationSnapshot(Dot(), c.mutable_hlo_snapshot());

  // Input and compilation handles are released by the execute op itself.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Serialized protos are fed through constants pinned to the CPU.
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal on failure: `outputs` is indexed below.
  TF_ASSERT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected =
      xla::LiteralUtil::CreateR2WithLayout<float>({{18.0f}, {44.0f}}, layout);

  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1618 
// Compiles and executes the OnePlusTwo() computation, which is declared with
// no parameters and an F32 scalar result, passing an empty input list to
// XRTExecute. Expects the scalar 3.0 back.
TEST(RawApiTest, CompileAndExecuteZeroArg) {
  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();

  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
  StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Serialized protos are fed through constants pinned to the CPU.
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  // The explicitly typed empty initializer list supplies zero input handles.
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                std::initializer_list<Input>({}));
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal on failure: `outputs` is indexed below.
  TF_ASSERT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR0<float>(3.0f);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1651 
// Runs the AddAndTuple() computation on {1, 2} and {8, 5}. The declared result
// is a one-element tuple of F32[2]; the literal read back is expected to be
// the tuple ({9, 7}).
TEST(RawApiTest, CompileAndExecuteReturnTuple) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = FloatVector({1.0f, 2.0f});
  xrt::XLAAllocation p1;
  *p1.mutable_value() = FloatVector({8.0f, 5.0f});

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->add_parameters() =
      xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
  *shapes->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
          .ToProto();
  StoreComputationSnapshot(AddAndTuple(), c.mutable_hlo_snapshot());

  // Input and compilation handles are released by the execute op itself.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Serialized protos are fed through constants pinned to the CPU.
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal on failure: `outputs` is indexed below.
  TF_ASSERT_OK(session.Run({read_back}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
  auto expected = xla::LiteralUtil::MakeTuple({&sum});
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
1702 
TEST(RawApiTest,CompileAndExecuteReturnExplodedTuple)1703 TEST(RawApiTest, CompileAndExecuteReturnExplodedTuple) {
1704   xrt::XLAAllocation p0;
1705   *p0.mutable_value() = xla::LiteralUtil::CreateR0<float>(12.0f).ToProto();
1706 
1707   xrt::XLAAllocation p1;
1708   *p1.mutable_value() = xla::LiteralUtil::CreateR0<float>(3.0f).ToProto();
1709 
1710   xrt::XLAComputation c;
1711   auto config = c.mutable_config();
1712   auto shapes = config->mutable_program_shape();
1713   *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
1714   *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::F32, {}).ToProto();
1715   *shapes->mutable_result() =
1716       xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {}),
1717                                       xla::ShapeUtil::MakeShape(xla::F32, {})})
1718           .ToProto();
1719   StoreComputationSnapshot(AddAndSubTuple(), c.mutable_hlo_snapshot());
1720 
1721   xrt::XRTExecutionConfig e;
1722   e.set_release_input_handles(true);
1723   e.set_release_compilation_handle(true);
1724   e.set_return_exploded_tuple(true);
1725 
1726   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1727   auto e_config =
1728       ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
1729   auto computation =
1730       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
1731   auto c_handle = ops::XRTCompile(root, computation);
1732   auto p0_value =
1733       ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
1734   auto p0_handle = ops::XRTAllocate(root, p0_value);
1735   auto p1_value =
1736       ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
1737   auto p1_handle = ops::XRTAllocate(root, p1_value);
1738   auto result = ops::XRTExecute(root, c_handle.handle, e_config,
1739                                 {Output(p0_handle), Output(p1_handle)});
1740   TF_ASSERT_OK(root.status());
1741 
1742   XrtClientSession session(root);
1743   std::vector<Tensor> outputs;
1744   TF_EXPECT_OK(session.Run({result}, &outputs));
1745   EXPECT_EQ(outputs.size(), 1);
1746 
1747   auto handles_vec = outputs.front().vec<int64>();
1748   EXPECT_EQ(handles_vec.size(), 2);
1749 
1750   const float kResults[2] = {15.0f, 9.0f};
1751   for (int64 i = 0; i < handles_vec.size(); ++i) {
1752     auto read_back = ops::XRTReadLiteralAndRelease(root, Input(handles_vec(i)));
1753     std::vector<Tensor> voutputs;
1754     TF_EXPECT_OK(session.Run({read_back}, &voutputs));
1755     EXPECT_EQ(voutputs.size(), 1);
1756 
1757     xla::LiteralProto response;
1758     EXPECT_TRUE(ParseFromTString(voutputs[0].scalar<tstring>()(), &response));
1759 
1760     auto expected = xla::LiteralUtil::CreateR0<float>(kResults[i]);
1761     EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
1762   }
1763 }
1764 
// Compiles the AddAndTuple() computation and fetches its compilation handle.
// No op releasing that handle is built or run here.
// NOTE(review): presumably this exercises cleanup of a compilation reference
// that the client never released - confirm against the XRT runtime.
TEST(RawApiTest, LeakCompilationReference) {
  const xla::Shape vec2_shape = xla::ShapeUtil::MakeShape(xla::F32, {2});

  // Two F32[2] parameters, and a one-element tuple of F32[2] as the result.
  xrt::XLAComputation xrt_computation;
  auto* program_shape =
      xrt_computation.mutable_config()->mutable_program_shape();
  *program_shape->add_parameters() = vec2_shape.ToProto();
  *program_shape->add_parameters() = vec2_shape.ToProto();
  *program_shape->mutable_result() =
      xla::ShapeUtil::MakeTupleShape({vec2_shape}).ToProto();
  StoreComputationSnapshot(AddAndTuple(),
                           xrt_computation.mutable_hlo_snapshot());

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto serialized_computation =
      ops::Const(cpu_root, xrt_computation.SerializeAsString());
  auto compile_op = ops::XRTCompile(root, serialized_computation);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> fetched;
  TF_EXPECT_OK(session.Run({compile_op.handle}, &fetched));
}
1788 
TEST(RawApiTest,CompileAndExecuteWithReusedBuffers)1789 TEST(RawApiTest, CompileAndExecuteWithReusedBuffers) {
1790   xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::F32, {2});
1791   xla::Shape shape =
1792       xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
1793   xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
1794       {element_shape, element_shape, element_shape, element_shape});
1795   xla::XlaBuilder builder("ReuseBuffer");
1796   auto param = xla::Parameter(&builder, 0, shape, "param");
1797   auto p0 = xla::GetTupleElement(param, 0);
1798   auto p1 = xla::GetTupleElement(param, 1);
1799   auto add = xla::Add(p0, p1);
1800   auto sub = xla::Sub(p0, p1);
1801   xla::Tuple(&builder, {add, sub, p0, p1});
1802 
1803   // Flip the tuple literals in the input handle.
1804   builder.SetUpAlias({1}, 0, {0});
1805   builder.SetUpAlias({0}, 0, {1});
1806 
1807   auto computation = builder.Build().ValueOrDie();
1808 
1809   auto literal0 = xla::LiteralUtil::CreateR1<float>({1.0f, 2.0f});
1810   auto literal1 = xla::LiteralUtil::CreateR1<float>({5.0f, 9.0f});
1811   auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
1812 
1813   xrt::XLAAllocation param_alloc;
1814   *param_alloc.mutable_value() = literal.ToProto();
1815 
1816   xrt::XLAComputation c;
1817   auto config = c.mutable_config();
1818   auto shapes = config->mutable_program_shape();
1819   *shapes->add_parameters() = shape.ToProto();
1820   *shapes->mutable_result() = return_shape.ToProto();
1821   StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());
1822 
1823   xrt::XRTExecutionConfig e;
1824   e.set_release_input_handles(false);
1825   e.set_release_compilation_handle(true);
1826 
1827   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1828   XrtClientSession session(root);
1829   auto e_config =
1830       ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
1831   auto c_data =
1832       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
1833   auto c_handle = ops::XRTCompile(root, c_data);
1834   auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
1835                                 param_alloc.SerializeAsString());
1836   auto param_handle = ops::XRTAllocate(root, param_value);
1837   TF_ASSERT_OK(root.status());
1838 
1839   std::vector<Tensor> outputs;
1840   TF_EXPECT_OK(session.Run({param_handle}, &outputs));
1841 
1842   int64 alloc_handle = outputs[0].scalar<int64>()();
1843 
1844   // Note that we release the result handle immediately, but since we aliased
1845   // the output buffers onto the input allocation ones (held in alloc_handle),
1846   // we can fetch the result from there.
1847   auto result =
1848       ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
1849   auto read_back = ops::XRTReadLiteral(root, result);
1850   auto release = ops::XRTReleaseAllocationHandle(
1851       root.WithControlDependencies(read_back), result);
1852   TF_ASSERT_OK(root.status());
1853 
1854   TF_EXPECT_OK(
1855       session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));
1856 
1857   xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
1858   auto exec_literal_parts = exec_literal.DecomposeTuple();
1859   ASSERT_EQ(exec_literal_parts.size(), 4);
1860 
1861   EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
1862   EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));
1863 
1864   // Now we read back the original input handle values, which at this point
1865   // should contain the result of the XLA computation.
1866   auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
1867   TF_ASSERT_OK(root.status());
1868   auto release_handle = ops::XRTReleaseAllocationHandle(
1869       root.WithControlDependencies(read_handle), Input(alloc_handle));
1870   TF_ASSERT_OK(root.status());
1871 
1872   TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
1873                            {release_handle}, &outputs));
1874 
1875   xla::Literal return_literal = ReadOutputLiteral(outputs, 0);
1876 
1877   auto expected_literal0 = xla::LiteralUtil::CreateR1<float>({6.0f, 11.0f});
1878   auto expected_literal1 = xla::LiteralUtil::CreateR1<float>({-4.0f, -7.0f});
1879   // The first element of the computation returned tuple would be the add
1880   // (expected_literal0), but since we flipped the buffers, the sub
1881   // (expected_literal1) should come first.
1882   auto expected_literal =
1883       xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});
1884 
1885   EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
1886 }
1887 
TEST(RawApiTest,CompileAndExecuteWithReusedBuffersS64)1888 TEST(RawApiTest, CompileAndExecuteWithReusedBuffersS64) {
1889   xla::Shape element_shape = xla::ShapeUtil::MakeShape(xla::S64, {2});
1890   xla::Shape shape =
1891       xla::ShapeUtil::MakeTupleShape({element_shape, element_shape});
1892   xla::Shape return_shape = xla::ShapeUtil::MakeTupleShape(
1893       {element_shape, element_shape, element_shape, element_shape});
1894   xla::XlaBuilder builder("ReuseBuffer");
1895   auto param = xla::Parameter(&builder, 0, shape, "param");
1896   auto p0 = xla::GetTupleElement(param, 0);
1897   auto p1 = xla::GetTupleElement(param, 1);
1898   auto add = xla::Add(p0, p1);
1899   auto sub = xla::Sub(p0, p1);
1900   xla::Tuple(&builder, {add, sub, p0, p1});
1901 
1902   // Flip the tuple literals in the input handle.
1903   builder.SetUpAlias({1}, 0, {0});
1904   builder.SetUpAlias({0}, 0, {1});
1905 
1906   auto computation = builder.Build().ValueOrDie();
1907 
1908   auto literal0 = xla::LiteralUtil::CreateR1<int64>({1, 2});
1909   auto literal1 = xla::LiteralUtil::CreateR1<int64>({5, 9});
1910   auto literal = xla::LiteralUtil::MakeTuple({&literal0, &literal1});
1911 
1912   xrt::XLAAllocation param_alloc;
1913   *param_alloc.mutable_value() = literal.ToProto();
1914 
1915   xrt::XLAComputation c;
1916   auto config = c.mutable_config();
1917   auto shapes = config->mutable_program_shape();
1918   *shapes->add_parameters() = shape.ToProto();
1919   *shapes->mutable_result() = return_shape.ToProto();
1920   StoreComputationSnapshot(computation, c.mutable_hlo_snapshot());
1921 
1922   xrt::XRTExecutionConfig e;
1923   e.set_release_input_handles(false);
1924   e.set_release_compilation_handle(true);
1925 
1926   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
1927   XrtClientSession session(root);
1928   auto e_config =
1929       ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
1930   auto c_data =
1931       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
1932   auto c_handle = ops::XRTCompile(root, c_data);
1933   auto param_value = ops::Const(root.WithDevice("/device:CPU:0"),
1934                                 param_alloc.SerializeAsString());
1935   auto param_handle = ops::XRTAllocate(root, param_value);
1936   TF_ASSERT_OK(root.status());
1937 
1938   std::vector<Tensor> outputs;
1939   TF_EXPECT_OK(session.Run({param_handle}, &outputs));
1940 
1941   int64 alloc_handle = outputs[0].scalar<int64>()();
1942 
1943   // Note that we release the result handle immediately, but since we aliased
1944   // the output buffers onto the input allocation ones (held in alloc_handle),
1945   // we can fetch the result from there.
1946   auto result =
1947       ops::XRTExecute(root, c_handle.handle, e_config, {Input(alloc_handle)});
1948   auto read_back = ops::XRTReadLiteral(root, result);
1949   auto release = ops::XRTReleaseAllocationHandle(
1950       root.WithControlDependencies(read_back), result);
1951   TF_ASSERT_OK(root.status());
1952 
1953   TF_EXPECT_OK(
1954       session.Run(ClientSession::FeedType(), {read_back}, {release}, &outputs));
1955 
1956   xla::Literal exec_literal = ReadOutputLiteral(outputs, 0);
1957   auto exec_literal_parts = exec_literal.DecomposeTuple();
1958   ASSERT_EQ(exec_literal_parts.size(), 4);
1959 
1960   EXPECT_TRUE(CompareLiterals(exec_literal_parts[2], literal0));
1961   EXPECT_TRUE(CompareLiterals(exec_literal_parts[3], literal1));
1962 
1963   // Now we read back the original input handle values, which at this point
1964   // should contain the result of the XLA computation.
1965   auto read_handle = ops::XRTReadLiteral(root, Input(alloc_handle));
1966   TF_ASSERT_OK(root.status());
1967   auto release_handle = ops::XRTReleaseAllocationHandle(
1968       root.WithControlDependencies(read_handle), Input(alloc_handle));
1969   TF_ASSERT_OK(root.status());
1970 
1971   TF_EXPECT_OK(session.Run(ClientSession::FeedType(), {read_handle},
1972                            {release_handle}, &outputs));
1973 
1974   xla::Literal return_literal = ReadOutputLiteral(outputs, 0);
1975 
1976   auto expected_literal0 = xla::LiteralUtil::CreateR1<int64>({6, 11});
1977   auto expected_literal1 = xla::LiteralUtil::CreateR1<int64>({-4, -7});
1978   // The first element of the computation returned tuple would be the add
1979   // (expected_literal0), but since we flipped the buffers, the sub
1980   // (expected_literal1) should come first.
1981   auto expected_literal =
1982       xla::LiteralUtil::MakeTuple({&expected_literal1, &expected_literal0});
1983 
1984   EXPECT_TRUE(CompareLiterals(return_literal, expected_literal));
1985 }
1986 
// Runs the AddS64() computation on two S64 scalars (11031965 and 4091934).
// Expects their sum, 15123899, and a compiled program shape with two
// parameters and an S64 result.
TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
  xrt::XLAAllocation p0;
  *p0.mutable_value() = xla::LiteralUtil::CreateR0<int64>(11031965).ToProto();
  xrt::XLAAllocation p1;
  *p1.mutable_value() = xla::LiteralUtil::CreateR0<int64>(4091934).ToProto();

  xrt::XLAComputation c;
  auto config = c.mutable_config();
  auto shapes = config->mutable_program_shape();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->add_parameters() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::S64, {}).ToProto();
  StoreComputationSnapshot(AddS64(), c.mutable_hlo_snapshot());

  // NOTE(review): return_exploded_tuple is enabled although the declared
  // result is a scalar rather than a tuple; the read-back below treats the
  // execute output as a single handle. Confirm this is intentional.
  xrt::XRTExecutionConfig e;
  e.set_release_input_handles(true);
  e.set_release_compilation_handle(true);
  e.set_return_exploded_tuple(true);

  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  // Serialized protos are fed through constants pinned to the CPU.
  Scope cpu_root = root.WithDevice("/device:CPU:0");
  auto e_config = ops::Const(cpu_root, e.SerializeAsString());
  auto computation = ops::Const(cpu_root, c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  auto p0_value = ops::Const(cpu_root, p0.SerializeAsString());
  auto p0_handle = ops::XRTAllocate(root, p0_value);
  auto p1_value = ops::Const(cpu_root, p1.SerializeAsString());
  auto p1_handle = ops::XRTAllocate(root, p1_value);
  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                {Output(p0_handle), Output(p1_handle)});
  auto read_back = ops::XRTReadLiteralAndRelease(root, result);
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal on failure: both fetched outputs are indexed below.
  TF_ASSERT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

  xla::LiteralProto response;
  EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));

  auto expected = xla::LiteralUtil::CreateR0<int64>(15123899);
  EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));

  xla::ProgramShapeProto program_shape;
  EXPECT_TRUE(ParseFromTString(outputs[1].vec<tstring>()(0), &program_shape));
  EXPECT_EQ(program_shape.parameters_size(), 2);
  EXPECT_TRUE(xla::ShapeUtil::HasPrimitiveType(
      xla::Shape(program_shape.result()), xla::S64));
}
2039 
2040 // Tests the XRT device memory compaction API (XRTCompactAllocations).
// Allocates kNumAllocs two-element tuples, releases the even-indexed ones to
// create holes, runs the compaction op, and then verifies that every
// odd-indexed allocation still reads back the value it was created with.
TEST(RawApiTest, TestDeviceMemoryCompaction) {
  static const int kNumAllocs = 32;
  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Scope cpu_root = root.WithDevice("/device:CPU:0");

  std::vector<xrt::XLAAllocation> allocs(kNumAllocs);
  std::vector<Output> handle_outputs;
  handle_outputs.reserve(kNumAllocs);
  for (int i = 0; i < kNumAllocs; ++i) {
    // Each allocation gets a distinct base value so mix-ups are detectable.
    *allocs[i].mutable_value() = BasedTwoElementTuple(i * 4.0f);
    auto value = ops::Const(cpu_root, allocs[i].SerializeAsString());
    handle_outputs.push_back(ops::XRTAllocate(root, value));
  }
  TF_ASSERT_OK(root.status());

  XrtClientSession session(root);
  std::vector<Tensor> outputs;
  // Fatal on failure: the handles collected from `outputs` drive every later
  // step of the test.
  TF_ASSERT_OK(session.Run(handle_outputs, &outputs));
  ASSERT_EQ(outputs.size(), handle_outputs.size());

  std::vector<int64> handles;
  handles.reserve(outputs.size());
  for (auto& output : outputs) {
    handles.push_back(output.scalar<int64>()());
  }
  // Create holes by releasing even allocations.
  std::vector<Operation> handle_releases;
  for (size_t i = 0; i < handles.size(); i += 2) {
    handle_releases.push_back(
        ops::XRTReleaseAllocationHandle(root, Input(handles[i])));
  }
  TF_ASSERT_OK(root.status());

  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {}, handle_releases, &outputs));

  // Run the compaction API.
  auto compact_op = ops::XRTCompactAllocations(root);
  TF_EXPECT_OK(
      session.Run(ClientSession::FeedType(), {}, {compact_op}, &outputs));

  // Read back the allocation left at odd indices.
  std::vector<Output> read_outputs;
  for (size_t i = 1; i < handles.size(); i += 2) {
    read_outputs.push_back(ops::XRTReadLiteral(root, Input(handles[i])));
  }
  TF_ASSERT_OK(root.status());

  // Fatal on failure or size mismatch: outputs[j] is indexed below for every
  // odd handle.
  TF_ASSERT_OK(session.Run(read_outputs, &outputs));
  ASSERT_EQ(outputs.size(), read_outputs.size());

  // Verify that everything got moved correctly and the device data matches what
  // we have on record.
  for (size_t i = 1, j = 0; i < handles.size(); i += 2, ++j) {
    xla::LiteralProto response;
    EXPECT_TRUE(ParseFromTString(outputs[j].scalar<tstring>()(), &response));
    EXPECT_TRUE(CompareLiteralProtos(allocs[i].value(), response));
  }
}
2098 
TEST(RawApiTest,TestDeviceMemorySwap)2099 TEST(RawApiTest, TestDeviceMemorySwap) {
2100   const xla::Shape scalar_shape = xla::ShapeUtil::MakeShape(xla::F32, {});
2101   // 100MB F32 tensor.
2102   const xla::Shape shape = xla::ShapeUtil::MakeShape(xla::F32, {5000, 5000});
2103   const xla::int64 tensor_size = xla::ShapeUtil::ByteSizeOf(shape);
2104   // On CPU we cannot trigger OOM/swap. For TPU and GPU we select 16GB as
2105   // maximum memory.
2106   xla::int64 device_memory_size = 8LL * 1024 * 1024 * 1024;
2107   if (*xla_test_device_ptr == "TPU" || *xla_test_device_ptr == "XLA_GPU") {
2108     device_memory_size = 16LL * 1024 * 1024 * 1024;
2109   }
2110 
2111   xrt::XLAAllocation p0;
2112   *p0.mutable_value() = xla::LiteralUtil::CreateR0<float>(0.90434).ToProto();
2113 
2114   // Create a computation which broadcasts a scalar to a big tensor.
2115   xrt::XLAComputation c_bcast;
2116   {
2117     auto shapes = c_bcast.mutable_config()->mutable_program_shape();
2118     *shapes->add_parameters() = scalar_shape.ToProto();
2119     *shapes->mutable_result() = shape.ToProto();
2120     StoreComputationSnapshot(
2121         BroadcastComputation(scalar_shape, shape.dimensions()),
2122         c_bcast.mutable_hlo_snapshot());
2123   }
2124 
2125   // Create a computation which compares two tensors.
2126   xrt::XLAComputation c_equal;
2127   {
2128     auto shapes = c_equal.mutable_config()->mutable_program_shape();
2129     *shapes->add_parameters() = shape.ToProto();
2130     *shapes->add_parameters() = shape.ToProto();
2131     *shapes->mutable_result() =
2132         xla::ShapeUtil::MakeShape(xla::S32, {}).ToProto();
2133     StoreComputationSnapshot(IsEqualComputation(shape),
2134                              c_equal.mutable_hlo_snapshot());
2135   }
2136 
2137   xrt::XRTExecutionConfig e;
2138   e.set_release_input_handles(false);
2139   e.set_release_compilation_handle(false);
2140 
2141   Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
2142   XrtClientSession session(root);
2143   auto e_config =
2144       ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
2145   auto bcast_computation =
2146       ops::Const(root.WithDevice("/device:CPU:0"), c_bcast.SerializeAsString());
2147   auto c_bcast_handle = ops::XRTCompile(root, bcast_computation);
2148   auto equal_computation =
2149       ops::Const(root.WithDevice("/device:CPU:0"), c_equal.SerializeAsString());
2150   auto c_equal_handle = ops::XRTCompile(root, equal_computation);
2151   auto p0_value =
2152       ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
2153   auto p0_handle = ops::XRTAllocate(root, p0_value);
2154   std::vector<Tensor> outputs;
2155   std::vector<xla::int64> device_handles;
2156 
2157   // Create more data the device can take using the broadcast computation.
2158   xla::int64 num_tensors = 8 + device_memory_size / tensor_size;
2159   for (xla::int64 i = 0; i < num_tensors; ++i) {
2160     auto result = ops::XRTExecute(root, c_bcast_handle.handle, e_config,
2161                                   {Output(p0_handle)});
2162     TF_ASSERT_OK(root.status());
2163     TF_ASSERT_OK(session.Run({result}, &outputs));
2164     EXPECT_EQ(outputs.size(), 1);
2165     device_handles.push_back(outputs[0].scalar<int64>()());
2166   }
2167 
2168   // Trigger computations on XRT handles to verify the swap-out/swap-in logic,
2169   // by comparing sequential couple of tensors.
2170   auto zero_literal = xla::LiteralUtil::CreateR0<xla::int32>(0);
2171   for (size_t i = 0; i + 1 < device_handles.size(); ++i) {
2172     auto exec_op = ops::XRTExecute(
2173         root, c_equal_handle.handle, e_config,
2174         {Input(device_handles[i]), Input(device_handles[i + 1])});
2175     auto read_back = ops::XRTReadLiteral(root, exec_op);
2176 
2177     TF_ASSERT_OK(root.status());
2178     TF_ASSERT_OK(session.Run({read_back}, &outputs));
2179     EXPECT_EQ(outputs.size(), 1);
2180 
2181     xla::LiteralProto response;
2182     EXPECT_TRUE(ParseFromTString(outputs[0].scalar<tstring>()(), &response));
2183     auto literal = xla::Literal::CreateFromProto(response).ValueOrDie();
2184     EXPECT_EQ(literal, zero_literal);
2185   }
2186 }
2187 
// Requests every metric matching "/tensorflow/xrt/.*" through the
// XRTMetricsCollect op and checks that each metric in the returned report has
// a name starting with "/tensorflow/xrt/".
TEST(RawApiTest, TestMetricsFetch) {
  // Build the serialized collection request.
  xrt::XRTMetricsCollect collect_request;
  collect_request.add_metrics_regex("/tensorflow/xrt/.*");

  Scope scope = Scope::NewRootScope().WithDevice("/device:CPU:0");
  auto serialized_request =
      ops::Const(scope, collect_request.SerializeAsString());
  Output collect_op = ops::XRTMetricsCollect(scope, serialized_request);
  TF_ASSERT_OK(scope.status());

  // Run the op and fetch its single output tensor.
  ClientSession client(scope);
  std::vector<Tensor> fetched;
  TF_EXPECT_OK(client.Run({collect_op}, &fetched));
  ASSERT_EQ(fetched.size(), 1);

  // Decode the report and validate the name prefix of each entry.
  xrt::MetricsReport metrics_report;
  EXPECT_TRUE(
      ParseFromTString(fetched[0].scalar<tstring>()(), &metrics_report));
  for (auto& entry : metrics_report.metrics()) {
    EXPECT_EQ(entry.name().compare(0, 16, "/tensorflow/xrt/"), 0);
  }
}
2208 
// Runs the XRTMemoryInfo op on the device selected by the test flags and
// checks that the returned xrt::MemoryInfo reports positive total and free
// memory figures.
TEST(RawApiTest, TestMemoryInfo) {
  Scope scope = Scope::NewRootScope().WithDevice(DeviceFromFlag());
  Output memory_info_op = ops::XRTMemoryInfo(scope);
  TF_ASSERT_OK(scope.status());

  // Run the op and fetch its single output tensor.
  ClientSession client(scope);
  std::vector<Tensor> fetched;
  TF_EXPECT_OK(client.Run({memory_info_op}, &fetched));
  ASSERT_EQ(fetched.size(), 1);

  // Decode the serialized proto carried by the scalar string tensor.
  xrt::MemoryInfo info;
  EXPECT_TRUE(ParseFromTString(fetched[0].scalar<tstring>()(), &info));
  EXPECT_GT(info.kb_total(), 0);
  EXPECT_GT(info.kb_free(), 0);
}
2224 
2225 }  // namespace
2226 
2227 }  // namespace tensorflow
2228 
main(int argc,char ** argv)2229 int main(int argc, char** argv) {
2230   tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
2231   tensorflow::xla_platform_ptr = new tensorflow::string("CPU");
2232   std::vector<tensorflow::Flag> flag_list = {
2233       tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
2234                        "Tensorflow device type to use for test, e.g., XLA_CPU"),
2235       tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr,
2236                        "The XLA platform to select for the device"),
2237   };
2238   tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
2239   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
2240   if (!parse_result) {
2241     LOG(ERROR) << "\n" << usage;
2242     return 2;
2243   }
2244   testing::InitGoogleTest(&argc, argv);
2245   if (argc > 1) {
2246     LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
2247     return 2;
2248   }
2249   return RUN_ALL_TESTS();
2250 }
2251