1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17 
18 #include <string>
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
23 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
24 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
25 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
26 #include "tensorflow/compiler/xla/window_util.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/test.h"
29 
30 namespace xla {
31 namespace {
32 
33 namespace m = ::xla::match;
34 using absl::string_view;
35 
// A single parser test case: a label plus the HLO module text to parse.
struct TestData {
  // Label for the case; TestDataToString returns this field unchanged.
  std::string test_name;
  // HLO module text handed to the parser for this case.
  std::string module_string;
};
40 
TestDataToString(const::testing::TestParamInfo<TestData> & data)41 string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
42   return data.param.test_name;
43 }
44 
// Returns the round-trip test cases written in "long form": every operand is
// spelled out with its full shape and instruction/computation names carry the
// `%` prefix.
//
// For each string below, we check that:
//  - we parse it to an HloModule successfully, and
//  - the stringification of the resulting HloModule is equal to our original
//    string.
//
// Each entry is {test_name, module_string} (see TestData). Because the printed
// module is compared against the original text, the raw strings must not be
// reformatted: any whitespace change inside R"(...)" changes the expectation.
// Every module string ends with "}\n\n". NOTE(review): presumably this matches
// the printer's trailing newlines; confirm before changing it. clang-format is
// turned off around the list by the markers below.
std::vector<TestData> CreateTestCases() {
  // clang-format off
  return std::vector<TestData>({
// ax + y
{
"AxpyParam",
R"(HloModule axpy_module

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %x = f32[2,4]{1,0} parameter(1)
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  %y = f32[2,4]{1,0} parameter(2)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}

)"
},
// parameter replication
{
"ParamReplication",
R"(HloModule param_replication_module

ENTRY %param_replication (a: f32[], b: (f32[2,4], (f32[2,4]))) -> (f32[], (f32[2,4], (f32[2,4]))) {
  %a = f32[] parameter(0), parameter_replication={true}
  %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true}
  ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b)
}

)"
},
// pred constant
{
"ConstantPred",
R"(HloModule constant_pred_module

ENTRY %constant_pred () -> pred[] {
  ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar"
}

)"
},
// pred array constant
{
"ConstantPredArray",
R"(HloModule module

ENTRY %constant_pred_array () -> pred[2,3] {
  ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } })
}

)"
},

// s32 constant
{
"ConstantS32",
R"(HloModule constant_s32_module

ENTRY %constant_s32 () -> s32[] {
  ROOT %constant = s32[] constant(-42)
}

)"
},
// f32 constant whose value is written without a decimal point, plus a
// backend configuration string
{
"ConstantF32",
R"(HloModule ConstantF32_module

ENTRY %ConstantF32.v4 () -> f32[] {
  ROOT %constant = f32[] constant(42), backend_config="this is a configuration"
}

)"
},
// f32 constant, rank 1 empty array.
{
"ConstantF32R1Empty",
R"(HloModule ConstantF32Empty_module

ENTRY %ConstantF32Empty.v4 () -> f32[0] {
  ROOT %constant = f32[0]{0} constant({})
}

)"
},
// f32 constant, rank 4 empty array.
{
"ConstantF32R4Empty",
R"(HloModule ConstantF32R4Empty_module

ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] {
  ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } })
}

)"
},
// constant 4D
{
"Constant4D",
R"(HloModule Small_3x2x1x1_module

ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] {
  ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
}

)"
},
// non-finite constants: nan, inf, -inf
{
"ConstantNonFinite",
R"(HloModule IsFiniteR1F32s_module

ENTRY %IsFiniteR1F32s.v2 () -> pred[6] {
  %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf})
  ROOT %is-finite = pred[6]{0} is-finite(f32[6]{0} %constant)
}

)"
},
// constant f16
{
"ConstantF16",
R"(HloModule ConstantF16_module

ENTRY %ConstantF16.v4 () -> f16[] {
  ROOT %constant = f16[] constant(500)
}

)"
},
// bf16
{
"BF16",
R"(HloModule BF16

ENTRY %BF16.v4 () -> bf16[] {
  ROOT %constant = bf16[] constant(500)
}

)"
},
// constant + constant
{
"AddConstants",
R"(HloModule add_constants_module

ENTRY %add_constants () -> f32[] {
  %constant = f32[] constant(3.14)
  ROOT %add = f32[] add(f32[] %constant, f32[] %constant)
}

)"
},
// tuple constant
{
"TupleConstant",
R"(HloModule TupleConstant_module

ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) {
  ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} ))
}

)"
},
// v1 > v2 ? v1 : v2
{
"SelectR1F32",
R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module

ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
  %v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
  %v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
  %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated}
  ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
}

)"
},
// empty tuple
{
"EmptyTupleCreate",
R"(HloModule EmptyTupleCreate_module

ENTRY %EmptyTupleCreate.v1 () -> () {
  ROOT %tuple = () tuple()
}

)"
},
// tuple
{
"TupleCreate",
R"(HloModule TupleCreate_module

ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
  %v1 = f32[] parameter(0)
  %v2 = f32[3]{0} parameter(1)
  %v3 = f32[2,3]{1,0} parameter(2)
  ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
}

)"
},
// tuple with a per-element (tuple) sharding
{
"ShardedTupleCreate",
R"(HloModule ShardedTupleCreate_module

ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
  %v1 = f32[] parameter(0)
  %v2 = f32[3]{0} parameter(1)
  %v3 = f32[2,3]{1,0} parameter(2)
  ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}}
}

)"
},
// domain instruction carrying sharding entry/exit metadata
{
"DomainParsing",
R"(HloModule DomainParsing_module

ENTRY %DomainParsing (v1: f32[]) -> f32[] {
  %v1 = f32[] parameter(0)
  ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
}

)"
},
// int32 result = 0;
// while (result < 5) { result = result + 1; }
{
"WhileWithScalarS32Result",
R"(HloModule WhileWithScalarS32Result_module

%body.v3 (prev.1: s32[]) -> s32[] {
  %constant = s32[] constant(1)
  %prev.1 = s32[] parameter(0)
  ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
}

%condition.v3 (prev.2: s32[]) -> pred[] {
  %constant.1 = s32[] constant(5)
  %prev.2 = s32[] parameter(0)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
}

ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
  %constant.2 = s32[] constant(0)
  ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
}

)"
},
// send and recv
{
"SendRecv",
R"(HloModule TwoSendRecvBothWayRecvFist_module

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
  %token0 = token[] after-all()
  %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
  ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
  %constant = f32[] constant(2.1), sharding={maximal device=0}
  %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
  %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
}

)"
},
// send and recv with is_host_transfer=true
{
"SendRecvWithHostTransfer",
R"(HloModule HostTransferSendRecv_module

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
  %token0 = token[] after-all()
  %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true
  ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true
  %constant = f32[] constant(2.1), sharding={maximal device=0}
  %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true
  %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true
}

)"
},
// get-tuple-element
{
"GetTupleElement",
R"(HloModule GetTupleElement_module

ENTRY %GetTupleElement.v4 () -> s32[2,3] {
  %constant = f32[3]{0} constant({1, 2, 3})
  %constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
  %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1)
  ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0}
}

)"
},
// call
{
"Call",
R"(HloModule CallR0F32IdentityScalar_module

%Identity.v1 (x: f32[]) -> f32[] {
  ROOT %x = f32[] parameter(0)
}

ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
  %constant = f32[] constant(42)
  ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1
}

)"
},
// reduce window
{
"ReduceWindow",
R"(HloModule R4UnitWindow_module

%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}

ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
  %operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
  %constant = f32[] constant(0)
  ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
}

)"
},
// reduce window on scalar
{
"ReduceWindowScalar",
R"(HloModule reduce_window_scalar

%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}

ENTRY %R4UnitWindowScalar () -> f32[] {
  %constant = f32[] constant(42)
  %constant.1 = f32[] constant(1)
  ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3
}

)"
},
// convolution
{
"Convolution",
R"(HloModule Convolve1D1Window_0_module

ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
  %input = f32[1,2,1]{2,1,0} parameter(0)
  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
  %filter = f32[1,1,1]{2,1,0} parameter(1)
  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
}

)"
},
// convolution rank 2
{
"ConvolutionR2",
R"(HloModule ConvolveR2_module

ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] {
  %input = f32[1,2]{1,0} parameter(0)
  %filter = f32[1,1]{1,0} parameter(1)
  ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf
}

)"
},
// convolution backward
{
"ConvolutionBackward",
R"(HloModule ConvolveBackward_module

ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
  %input = f32[128,7,7,512]{0,3,2,1} parameter(0)
  %filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
  ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
}

)"
},
// reverse(constant)
{
"Reverse4D",
R"(HloModule Reverse4DFloatArrayOnDim01_module

ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
  %constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
  ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1}
}

)"
},
// concat
{
"Concat",
R"(HloModule Concat2x3With2x5_module

ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
  %constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 1002 } })
  %constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
  ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
}

)"
},
// select and scatter
{
"SelectAndScatter",
R"(HloModule R4F32OverlapSmall_module

%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
}

%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
  %lhs.1 = f32[] parameter(0)
  %rhs.1 = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
}

ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] {
  %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
  %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
  %constant.2 = f32[] constant(0)
  ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3
}

)"
},
// select and scatter on scalar
{
"SelectAndScatterScalar",
R"(HloModule select_and_scatter_scalar

%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
}

%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
  %lhs.1 = f32[] parameter(0)
  %rhs.1 = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
}

ENTRY %SelectAndScatterScalar () -> f32[] {
  %constant = f32[] constant(42)
  %constant.1 = f32[] constant(1)
  %constant.2 = f32[] constant(2)
  ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3
}

)"
},
// slice
{
"Slice",
R"(HloModule slice_module

ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
  %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
  ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]}
}

)"
},
// slice, no stride
{
"SliceNoStride",
R"(HloModule Slice3x3x3_To_1x3x3_F32_module

ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] {
  %constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
  ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]}
}

)"
},
// slice R0
{
"SliceR0",
R"(HloModule SliceR0_module

ENTRY %SliceR0.v2 () -> s32[] {
  %constant = s32[] constant(1)
  ROOT %slice = s32[] slice(s32[] %constant), slice={}
}

)"
},
// transpose
{
"Transpose",
R"(HloModule Transpose_module

ENTRY %Transpose.v2 () -> s32[1,2,3] {
  %constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } })
  ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
}

)"
},
// transpose of a c128 parameter
{
"TransposeC128",
R"(HloModule TransposeC128_module

ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] {
  %input = c128[1,2,3]{2,1,0} parameter(0)
  ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2}
}

)"
},
// Triangular solve
{
"TriangularSolve",
R"(HloModule TriangularSolve_module

ENTRY %SimpleRightLowerNotranspose.4 (a.1: f32[4,4], b.2: f32[3,4]) -> f32[3,4] {
  %a.1 = f32[4,4]{1,0} parameter(0)
  %b.2 = f32[3,4]{1,0} parameter(1)
  ROOT %triangular-solve.3 = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a.1, f32[3,4]{1,0} %b.2), lower=true, transpose_a=NO_TRANSPOSE
}

)"
},
// Dynamic slice
{
"DynamicSlice",
R"(HloModule DynamicSlice_module

ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] {
  %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
  %constant = s32[1]{0} constant({0})
  %start_index = s32[1]{0} parameter(1)
  %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0}
  ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258}
}

)"
},
// Dynamic slice with scalar indices
{
"DynamicSliceScalarIndices",
R"(HloModule DynamicSlice_module

ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
  %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
  %constant = s32[] constant(0)
  %start_index = s32[] parameter(1)
  ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
}

)"
},
// Dynamic update slice
{
"DynamicUpdateSlice",
R"(HloModule DynamicSlice_module

ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] {
  %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
  %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
  %start_indices = s32[4]{0} parameter(2)
  ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices)
}

)"
},
// Dynamic update slice with scalar indices
{
"DynamicUpdateSliceScalarIndex",
R"(HloModule DynamicUpdateSlice_module

ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
  %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
  %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
  %start_index.0 = s32[] parameter(2)
  %start_index.1 = s32[] parameter(3)
  %start_index.2 = s32[] parameter(4)
  %start_index.3 = s32[] parameter(5)
  ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
}

)"
},
// batch norm training
{
"BatchNormTraining",
R"(HloModule BasicTraining_module

ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) {
  %constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } })
  %constant.1 = f32[2]{0} constant({2, 3})
  %constant.2 = f32[2]{0} constant({1, 2})
  ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3
}

)"
},
// batch norm inference
{
"BatchNormInference",
R"(HloModule BatchNormInference_module

ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] {
  %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
  %offset = f32[2]{0} parameter(1)
  %scale = f32[2]{0} parameter(2)
  %mean = f32[2]{0} parameter(3)
  %variance = f32[2]{0} parameter(4)
  ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0
}

)"
},
// batch norm grad
{
"BatchNormGrad",
R"(HloModule BatchNormGrad_module

ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
  %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
  %scale = f32[2]{0} parameter(1)
  %mean = f32[2]{0} parameter(2)
  %variance = f32[2]{0} parameter(3)
  %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4)
  ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0
}

)"
},
// fft
{
"Fft",
R"(HloModule Fft_module

ENTRY %Fft (input: c64[8,32]) -> c64[8,32] {
  %input = c64[8,32]{1,0} parameter(0)
  ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32}
}

)"
},
// ifft
{
"Ifft2d",
R"(HloModule Ifft2d_module

ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] {
  %input = c64[5,8,32]{2,1,0} parameter(0)
  ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32}
}

)"
},
// rfft2d
{
"Rfft2d",
R"(HloModule Rfft2d_module

ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] {
  %input = f32[5,64,32]{2,1,0} parameter(0)
  ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32}
}

)"
},
// irfft3d
{
"Irfft3d",
R"(HloModule Irfft3d_module

ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] {
  %input = c64[5,64,128,33]{3,2,1,0} parameter(0)
  ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64}
}

)"
},
// pad
{
"Pad",
R"(HloModule Pad1DS3Array_module

ENTRY %Pad1DS3Array.v3 () -> f32[8] {
  %constant = f32[3]{0} constant({1, 2, 3})
  %constant.1 = f32[] constant(0.1)
  ROOT %pad = f32[8]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
}

)"
},
// pad has interior
{
"PadHasInterior",
R"(HloModule PadHasInterior_module

ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
  %input = f32[1,25,7,7]{3,2,1,0} parameter(0)
  %constant = f32[] constant(-5.123)
  ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0
}

)"
},
// Negative padding
{
"PadHasNegativePadding",
R"(HloModule PadHasNegativePadding_module

ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,29] {
  %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
  %constant = f32[] constant(-5.123)
  ROOT %pad = f32[1,15,6,3,29]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
}

)"
},
// fusion
{
"Fusion",
R"(HloModule fusion_module

%fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] {
  %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
  %constant.1.param_1 = f32[2]{0} parameter(1)
  %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1}
  ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
}

ENTRY %fusion.v3 () -> f32[3,2,1,1] {
  %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
  %constant.1 = f32[2]{0} constant({3.14, 4.25})
  ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
}

)"
},
// sparse-layout f32 constant
{
"Sparse",
R"(HloModule sparse_f32

ENTRY %sparse () -> f32[2,3,4] {
  ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3})
}

)"
},
// sparse-layout c128 constant
{
"SparseC128",
R"(HloModule sparse_c128

ENTRY %sparse () -> c128[2,3,4] {
  ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)})
}

)"
},
// sparse-layout constant with no elements
{
"SparseEmpty",
R"(HloModule sparse_f32_empty

ENTRY %sparse_f32_empty () -> f32[2,3,4] {
  ROOT %foo = f32[2,3,4]sparse{10} constant({})
}

)"
},
// rank-1 sparse-layout constant
{
"SparseR1",
R"(HloModule sparse_f32_r1

ENTRY %sparse_f32_r1 () -> f32[9] {
  ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6})
}

)"
},
// gather
{
"gather",
R"(HloModule StringifyGather

ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
  %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
  %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
  ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}

)"
},
// scatter with an add combiner computation
{
"scatter",
R"(HloModule StringifyScatter

%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}

ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
  %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
  %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
  %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
  ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3
}

)"
},
// small u64 constant
{
  "ConstantUnsignedNoUnderflow",
  R"(HloModule ConstantUnsignedNoUnderflow_module

ENTRY %ConstantUnsignedNoUnderflow () -> u64[] {
  ROOT %constant = u64[] constant(1)
}

)"
},

// u64 constant equal to 2^63 - 1 (the largest signed 64-bit value)
{
  "ConstantUnsignedNoOverflow",
  R"(HloModule ConstantUnsignedNoOverflow_module

ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
  ROOT %constant = u64[] constant(9223372036854775807)
}

)"
},
// custom-call with operand layout constraints
{
"CustomCallWithLayoutConstraints",
R"(HloModule CustomCallWithLayoutConstraints

ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
  %p0 = f32[42,2,3]{0,1,2} parameter(0)
  %p1 = f32[123,4]{0,1} parameter(1)
  ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}}
}

)"
},
// custom-call with no operands and an empty layout-constraint list
{
"CustomCallWithLayoutConstraintsNoOperands",
R"(HloModule CustomCallWithLayoutConstraintsNoOperands

ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] {
  ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
}

)"
},
// custom-call whose layout constraints include a tuple-shaped operand
{
"CustomCallWithLayoutConstraintsTupleShapes",
R"(HloModule CustomCallWithLayoutConstraintsTupleShapes

ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
  %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
  %p1 = f32[123,4]{0,1} parameter(1)
  ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}}
}

)"
},
// Parse c64 literal
{
"ParseC64Literal",
R"(HloModule ParseC64Literal

ENTRY %ParseC64Literal () -> c64[2] {
  ROOT %c = c64[2]{0} constant({(1, 2), (-inf, nan)})
}

)"
},
// Parse c128 literal
{
"ParseC128Literal",
R"(HloModule ParseC128Literal

ENTRY %ParseC128Literal () -> c128[2] {
  ROOT %c = c128[2]{0} constant({(1, 2), (-inf, nan)})
}

)"
},
// Indexed Conditional: s32 branch index selecting among three computations
{
"IndexedConditional",
R"(HloModule indexed_conditional

%Negate (x: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  ROOT %negate = f32[] negate(f32[] %x)
}

%Identity (y: f32[]) -> f32[] {
  %y = f32[] parameter(0)
  ROOT %copy = f32[] copy(f32[] %y)
}

%Floor (z: f32[]) -> f32[] {
  %z = f32[] parameter(0)
  ROOT %floor = f32[] floor(f32[] %z)
}

ENTRY %Parameters1.v4 () -> f32[] {
  %constant = s32[] constant(1)
  %constant.1 = f32[] constant(56)
  %constant.2 = f32[] constant(12)
  %constant.3 = f32[] constant(13)
  ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor}
}

)"
},
  });
  // clang-format on
}
990 
991 std::vector<TestData> CreateShortTestCases() {
992   // clang-format off
993   return std::vector<TestData>({
994 // map
995 {
996 "Map",
997 R"(HloModule MapBinaryAdder_module
998 
999 add_F32.v3 {
1000   lhs = f32[] parameter(0)
1001   rhs = f32[] parameter(1)
1002   ROOT add = f32[] add(lhs, rhs)
1003 }
1004 
1005 ENTRY MapBinaryAdder.v3 {
1006   param0 = f32[4]{0} parameter(0)
1007   param1 = f32[4]{0} parameter(1)
1008   ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3
1009 }
1010 
1011 )"
1012 },
1013 // reduce
1014 {
1015 "Reduce",
1016 R"(HloModule ReduceR3ToR2_module
1017 
1018 add_F32.v3 {
1019   lhs = f32[] parameter(0)
1020   rhs = f32[] parameter(1)
1021   ROOT add = f32[] add(lhs, rhs)
1022 }
1023 
1024 ENTRY ReduceR3ToR2.v3 {
1025   input = f32[8,16,256]{2,1,0} parameter(0)
1026   constant = f32[] constant(0)
1027   ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
1028 }
1029 
1030 )"
1031 },
1032 // tuple reduce
1033 {
1034 "TupleReduce",
1035 R"(HloModule TupleReduce
1036 
1037 max_argmax {
1038   value = f32[] parameter(2)
1039   prev_max = f32[] parameter(0)
1040   is_next_larger = pred[] compare(value, prev_max), direction=GE
1041   max = f32[] select(is_next_larger, value, prev_max)
1042   index = s32[] parameter(3)
1043   prev_argmax = s32[] parameter(1)
1044   argmax = s32[] select(is_next_larger, index, prev_argmax)
1045   ROOT pair = (f32[], s32[]) tuple(max, argmax)
1046 }
1047 
1048 ENTRY reduce_entry {
1049   values = f32[1024]{0} parameter(0)
1050   indices = f32[1024]{0} parameter(1)
1051   init_value = f32[] constant(-inf)
1052   init_index = s32[] constant(-1)
1053   ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
1054 }
1055 
1056 )"
1057 },
1058 // infeed/outfeed
1059 {
1060 "InfeedOutfeed",
1061 R"(HloModule outfeed_module
1062 
1063 ENTRY InfeedToOutfeed {
1064   token0 = token[] after-all()
1065   infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
1066   infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
1067   outfeed = token[] outfeed(infeed.data, token0)
1068   ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0)
1069   infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
1070   infeed.1.token = token[] get-tuple-element(infeed.1), index=1
1071   outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
1072 }
1073 
1074 )"
1075 },
1076 // Rng
1077 {
1078 "Rng",
1079 R"(HloModule rng_module
1080 
1081 ENTRY Rng {
1082   constant = f32[] constant(0)
1083   constant.1 = f32[] constant(1)
1084   ROOT rng = f32[8]{0} rng(constant, constant.1), distribution=rng_uniform
1085 }
1086 
1087 )"
1088 },
1089 // Reduce precision
1090 {
1091 "ReducePrevison",
1092 R"(HloModule reduce_precision
1093 
1094 ENTRY ReducePrecision {
1095   constant = f32[1]{0} constant({3.14159})
1096   ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10
1097 }
1098 
1099 )"
1100 },
1101 // Sort (Key)
1102 {
1103 "SortKey",
1104 R"(HloModule sort
1105 
1106 compare {
1107   p.0.lhs = f32[] parameter(0)
1108   p.0.rhs = f32[] parameter(1)
1109   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1110 }
1111 
1112 ENTRY Sort {
1113   x = f32[1024]{0} parameter(0)
1114   ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, to_apply=compare
1115 }
1116 
1117 )"
1118 },
1119 // Sort (Key, Value)
1120 {
1121 "SortKeyValue",
1122 R"(HloModule sort
1123 
1124 compare {
1125   p.1.lhs = s32[] parameter(2)
1126   p.1.rhs = s32[] parameter(3)
1127   p.0.lhs = f32[] parameter(0)
1128   p.0.rhs = f32[] parameter(1)
1129   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1130 }
1131 
1132 ENTRY Sort {
1133   keys = f32[1024]{0} parameter(0)
1134   values = s32[1024]{0} parameter(1)
1135   ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
1136 }
1137 
1138 )"
1139 },
1140 // R2 Sort (Key)
1141 {
1142 "SortKeyR2",
1143 R"(HloModule sort
1144 
1145 compare {
1146   p.0.lhs = f32[] parameter(0)
1147   p.0.rhs = f32[] parameter(1)
1148   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1149 }
1150 
1151 ENTRY Sort {
1152   x = f32[1024,16]{0,1} parameter(0)
1153   ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}, to_apply=compare
1154 }
1155 
1156 )"
1157 },
1158 // R2 Sort (Key, Value)
1159 {
1160 "SortKeyValueR2",
1161 R"(HloModule sort
1162 
1163 compare {
1164   p.1.lhs = s32[] parameter(2)
1165   p.1.rhs = s32[] parameter(3)
1166   p.0.lhs = f32[] parameter(0)
1167   p.0.rhs = f32[] parameter(1)
1168   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1169 }
1170 
1171 ENTRY Sort {
1172   keys = f32[1024,16]{0,1} parameter(0)
1173   values = s32[1024,16]{0,1} parameter(1)
1174   ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}, to_apply=compare
1175 }
1176 
1177 )"
1178 },
1179 // Sort (Key, Value, Value, Value)
1180 {
1181 "SortManyValues",
1182 R"(HloModule sort
1183 
1184 compare {
1185   p.1.lhs = s32[] parameter(2)
1186   p.1.rhs = s32[] parameter(3)
1187   p.2.lhs = u32[] parameter(4)
1188   p.2.rhs = u32[] parameter(5)
1189   p.3.lhs = f32[] parameter(6)
1190   p.3.rhs = f32[] parameter(7)
1191   p.0.lhs = f32[] parameter(0)
1192   p.0.rhs = f32[] parameter(1)
1193   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1194 }
1195 
1196 ENTRY Sort {
1197   keys = f32[1024,16]{0,1} parameter(0)
1198   values.0 = s32[1024,16]{0,1} parameter(1)
1199   values.1 = u32[1024,16]{0,1} parameter(2)
1200   values.2 = f32[1024,16]{0,1} parameter(3)
1201   ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare
1202 }
1203 
1204 )"
1205 },
1206 // Sort (Key) is_stable=true
1207 {
1208 "SortKeyStable",
1209 R"(HloModule sort
1210 
1211 compare {
1212   p.0.lhs = f32[] parameter(0)
1213   p.0.rhs = f32[] parameter(1)
1214   ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1215 }
1216 
1217 ENTRY Sort {
1218   x = f32[1024]{0} parameter(0)
1219   ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
1220 }
1221 
1222 )"
1223 },
1224 // Indexed Conditional
1225 {
1226 "IndexedConditional",
1227 R"(HloModule indexed_conditional
1228 
1229 Negate {
1230   x = f32[] parameter(0)
1231   ROOT negate = f32[] negate(x)
1232 }
1233 
1234 Identity {
1235   y = f32[] parameter(0)
1236   ROOT copy = f32[] copy(y)
1237 }
1238 
1239 Floor {
1240   z = f32[] parameter(0)
1241   ROOT floor = f32[] floor(z)
1242 }
1243 
1244 ENTRY Parameters1.v4 {
1245   constant = s32[] constant(1)
1246   constant.1 = f32[] constant(56)
1247   constant.2 = f32[] constant(12)
1248   constant.3 = f32[] constant(13)
1249   ROOT conditional = f32[] conditional(constant, constant.1, constant.2, constant.3), branch_computations={Negate, Identity, Floor}
1250 }
1251 
1252 )"
1253 },
1254 // Predicated Conditional
1255 {
1256 "PredicatedConditional",
1257 R"(HloModule pred_conditional
1258 
1259 Negate {
1260   x = f32[] parameter(0)
1261   ROOT negate = f32[] negate(x)
1262 }
1263 
1264 Identity {
1265   y = f32[] parameter(0)
1266   ROOT copy = f32[] copy(y)
1267 }
1268 
1269 ENTRY Parameters1.v4 {
1270   constant = pred[] constant(true)
1271   constant.1 = f32[] constant(56)
1272   constant.2 = f32[] constant(12)
1273   ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity
1274 }
1275 
1276 )"
1277 },
1278 // CustomCall
1279 {
1280 "CustomCall",
1281 R"(HloModule custom_call
1282 
1283 ENTRY CustomCall {
1284   constant = f32[1]{0} constant({12345})
1285   ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
1286 }
1287 
1288 )"
1289 },
1290 // CustomCall with opaque value.
1291 {
1292 "CustomCallWithOpaque",
1293 R"(HloModule custom_call
1294 
1295 ENTRY CustomCall {
1296   constant = f32[1]{0} constant({12345})
1297   ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", opaque="this string is opaque"
1298 }
1299 
1300 )"
1301 },
1302 // Variables with non-default names
1303 {
1304 "NonDefaultNames",
1305 R"(HloModule add_constants_module
1306 
1307 ENTRY add_constants {
1308   foo = f32[] constant(3.14)
1309   ROOT bar = f32[] add(foo, foo)
1310 }
1311 
1312 )"
1313 },
1314 {
1315 "Dot",
1316 R"(HloModule dot
1317 
1318 ENTRY dot {
1319   a = f32[2,10]{1,0} parameter(0)
1320   b = f32[10,3]{1,0} parameter(1)
1321   ROOT dot = f32[2,3]{1,0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={0}
1322 }
1323 
1324 )"
1325 },
1326 {
1327 "gather",
1328 R"(HloModule gather
1329 
1330 ENTRY Gather {
1331   input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1332   start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
1333   ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
1334 }
1335 
1336 )"
1337 },
1338 // all-reduce
1339 {
1340 "AllReduce",
1341 R"(HloModule CRS
1342 
1343 add {
1344   lhs = f32[] parameter(0)
1345   rhs = f32[] parameter(1)
1346   ROOT add = f32[] add(lhs, rhs)
1347 }
1348 
1349 ENTRY CRS {
1350   input = f32[8]{0} parameter(0)
1351   ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, to_apply=add
1352 }
1353 
1354 )"
1355 },
1356 // all-reduce with subgroups
1357 {
1358 "AllReduceWithSubgroups",
1359 R"(HloModule CRS_Subgroups
1360 
1361 add {
1362   lhs = f32[] parameter(0)
1363   rhs = f32[] parameter(1)
1364   ROOT add = f32[] add(lhs, rhs)
1365 }
1366 
1367 ENTRY AllReduceWithSubgroups {
1368   input = f32[128,32]{0,1} parameter(0)
1369   ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add
1370 }
1371 
1372 )"
1373 },
1374 // all-reduce with all-reduce-id
1375 {
1376 "AllReduceAllReduce",
1377 R"(HloModule CRS
1378 
1379 add {
1380   lhs = f32[] parameter(0)
1381   rhs = f32[] parameter(1)
1382   ROOT add = f32[] add(lhs, rhs)
1383 }
1384 
1385 ENTRY CRS {
1386   input = f32[8]{0} parameter(0)
1387   crs.1 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add
1388   ROOT crs.0 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add
1389 }
1390 
1391 )"
1392 },
1393 // all-to-all
1394 {
1395 "AllToAll",
1396 R"(HloModule AllToAll
1397 
1398 ENTRY AllToAll {
1399   input = f32[128,32]{0,1} parameter(0)
1400   ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={}
1401 }
1402 
1403 )"
1404 },
1405 // all-to-all with subgroups
1406 {
1407 "AllToAllWithSubgroups",
1408 R"(HloModule AllToAllWithSubgroups
1409 
1410 ENTRY AllToAllWithSubgroups {
1411   input = f32[128,32]{0,1} parameter(0)
1412   ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}
1413 }
1414 
1415 )"
1416 },
1417 // collective-permute
1418 {
1419 "CollectivePermute",
1420 R"(HloModule CollectivePermute
1421 
1422 ENTRY CollectivePermute {
1423   input = f32[128,32]{0,1} parameter(0)
1424   ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
1425 }
1426 
1427 )"
1428 },
1429 // replica-id
1430 {
1431 "ReplicaId",
1432 R"(HloModule replica-id
1433 
1434 ENTRY Replica-id {
1435   ROOT replica-id = u32[] replica-id()
1436 }
1437 
1438 )"
1439 },
1440 // Iota
1441 {
1442 "Iota",
1443 R"(HloModule iota
1444 
1445 ENTRY Iota {
1446   ROOT iota = f32[100]{0} iota(), iota_dimension=0
1447 }
1448 
1449 )"
1450 },
1451 // custom-call with window, dim_labels and feature_group_count
1452 {
1453 "CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
1454 R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount
1455 
1456 ENTRY Computation {
1457   ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
1458 }
1459 
1460 )"
1461     },
1462 // is_scheduled=true attribute
1463 {
1464 "ScheduledModule",
1465 R"(HloModule scheduled_module, is_scheduled=true
1466 
1467 compare {
1468   p.1.lhs = s32[] parameter(2)
1469   p.1.rhs = s32[] parameter(3)
1470   p.0.lhs = f32[] parameter(0)
1471   p.0.rhs = f32[] parameter(1)
1472   ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1473 }
1474 
1475 ENTRY Sort {
1476   keys = f32[1024]{0} parameter(0)
1477   values = s32[1024]{0} parameter(1)
1478   ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
1479 }
1480 
1481 )"
1482     },
1483 // AfterAll with multiple operands
1484 {
1485 "AfterAllWithMultipleOperands",
1486 R"(HloModule AfterAllWithMultipleOperands
1487 
1488 ENTRY AfterAllWithMultipleOperands {
1489   p0 = f32[] parameter(0)
1490   token0 = token[] after-all()
1491   token1 = token[] after-all()
1492   ROOT after-all = token[] after-all(p0, token0, token1)
1493 }
1494 
1495 )"
1496 },
1497 // AddDependency
1498 // A dependency chain is created from 'neg' to 'exp' using tokens.
1499 {
1500 "AddDependency",
1501 R"(HloModule AddDependency
1502 
1503 ENTRY AddDependency {
1504   p = f32[] parameter(0)
1505   neg = f32[] negate(p)
1506   token0 = token[] after-all(neg)
1507   p_after_token = f32[] add-dependency(p, token0)
1508   exp = f32[] exponential(p_after_token)
1509   ROOT sum = f32[] add(neg, exp)
1510 }
1511 
1512 )"
1513 },
1514 
1515 // A module containing constants equal to the min/max values of various data
1516 // types.
1517 {
1518 "MinMaxValues",
1519 R"(HloModule MinMaxValues
1520 
1521 ENTRY MinMaxValues {
1522   x.s8 = s8[2]{0} constant({-128, 127})
1523   x.s16 = s16[2]{0} constant({-32768, 32767})
1524   x.s32 = s32[2]{0} constant({-2147483648, 2147483647})
1525   x.u8 = u8[2]{0} constant({0, 255})
1526   x.u16 = u16[2]{0} constant({0, 65535})
1527   x.u32 = u32[2]{0} constant({0, 4294967295})
1528   x.f16 = f16[2]{0} constant({-65504, 65504})
1529   x.bf16 = bf16[2]{0} constant({-3.38953e+38, 3.38953e+38})
1530   x.f32 = f32[2]{0} constant({-3.40282e+38, 3.40282e+38})
1531   x.f64 = f64[2]{0} constant({-1.79769e+308, 1.79769e+308})
1532   x.c64 = c64[2]{0} constant({(-3.40282e+38, 3.40282e+38), (3.40282e+38, -3.40282e+38)})
1533   ROOT c.c128 = c128[2]{0} constant({(-1.79769e+308, 1.79769e+308), (1.79769e+308, -1.79769e+308)})
1534 }
1535 
1536 )"
1537 },
1538 });
1539   // clang-format on
1540 }
1541 
1542 // The test class for those tests defined above which round-trip through the
1543 // parser and ToString is templatized on two bool parameters:
1544 //
1545 //  short_form : used for the "short" test cases which use the ShortParsable
1546 //    output form.
1547 //  proto_round_trip : whether the module should also be round-tripped through
1548 //    HloProto form. This provides much better coverage for the proto
1549 //    serialization/deserialization.
1550 //
1551 // The proto_round_trip=true case also technically covers the Parser->ToString
1552 // roundtrip as well, but separating out the Parser->ToString roundtrip as its
1553 // own test provides better isolation and could conceivably catch weirdo bugs
1554 // which are hidden by interaction between the textual and proto roundtripping.
1555 template <bool short_form, bool proto_round_trip>
1556 class HloParameterizedParserTest
1557     : public ::testing::Test,
1558       public ::testing::WithParamInterface<TestData> {
1559  protected:
1560   // Expects "ToString(ParseHloString(string)) == string", that is, parses the
1561   // string, asserts that it succeeded, stringifies the parsed module, and
1562   // checks that it equals the original string.
1563   void ExpectEqual() {
1564     const string& original = GetParam().module_string;
1565     TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
1566                             ParseHloString(original));
1567     if (proto_round_trip) {
1568       TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
1569                                           module->ToProto(), module->config()));
1570     }
1571     if (short_form) {
1572       EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable()));
1573     } else {
1574       EXPECT_EQ(
1575           original,
1576           module->ToString(HloPrintOptions().set_print_large_constants(true)));
1577     }
1578   }
1579 };
1580 
// TEST_P is a macro, so a template-id such as
// HloParameterizedParserTest<false, true> cannot be passed to it directly:
// the preprocessor would split it at the comma into two macro arguments.
// These aliases give each fixture specialization a comma-free name.
//   Long*  : default printing with large constants, fed by CreateTestCases().
//   Short* : ShortParsable printing, fed by CreateShortTestCases().
//   *Proto : additionally round-trips the module through its proto form.
using HloParserTestLong = HloParameterizedParserTest<false, false>;
using HloParserTestLongProto = HloParameterizedParserTest<false, true>;
using HloParserTestShort = HloParameterizedParserTest<true, false>;
using HloParserTestShortProto = HloParameterizedParserTest<true, true>;

// Each fixture variant has a single parameterized test that runs the
// round-trip check for one TestData entry.
TEST_P(HloParserTestLong, Run) { ExpectEqual(); }
TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); }
TEST_P(HloParserTestShort, Run) { ExpectEqual(); }
TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); }

// Register every entry of each table with its two fixture variants.
// TestDataToString uses TestData::test_name as the generated name suffix, so
// names should be unique within a table and use only the characters gtest
// allows in parameter names (ASCII letters, digits, underscore).
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestLong,
                         ::testing::ValuesIn(CreateTestCases()),
                         TestDataToString);
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
                         HloParserTestLongProto,
                         ::testing::ValuesIn(CreateTestCases()),
                         TestDataToString);
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestShort,
                         ::testing::ValuesIn(CreateShortTestCases()),
                         TestDataToString);
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
                         HloParserTestShortProto,
                         ::testing::ValuesIn(CreateShortTestCases()),
                         TestDataToString);
1607 
1608 class HloParserTest : public ::testing::Test {
1609  protected:
1610   static void ExpectHasSubstr(string_view s, string_view expected) {
1611     EXPECT_TRUE(absl::StrContains(s, expected))
1612         << "'" << s << "' does not contain '" << expected << "'";
1613   }
1614 };
1615 
1616 TEST_F(HloParserTest, Empty) {
1617   const string original = "";
1618   auto result = ParseHloString(original);
1619   EXPECT_NE(Status::OK(), result.status());
1620 }
1621 
1622 TEST_F(HloParserTest, Garbage) {
1623   const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
1624   auto result = ParseHloString(original);
1625   EXPECT_NE(Status::OK(), result.status());
1626 }
1627 
1628 TEST_F(HloParserTest, WrongOpcode) {
1629   const string original = R"(HloModule wrong_opcode:
1630 
1631 ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
1632   %x = f32[]{} parameter(0)
1633   %y = f32[]{} parameter(1)
1634   %le = pred[]{} le(f32[]{} %x, f32[]{} %y)
1635 }
1636 
1637 )";
1638   auto result = ParseHloString(original);
1639   EXPECT_NE(Status::OK(), result.status());
1640 }
1641 
1642 TEST_F(HloParserTest, WrongShape) {
1643   const string original = R"(HloModule wrong_opcode:
1644 
1645 ENTRY %blabla (x: g32[]) -> g32[] {
1646   %x = g32[]{} parameter(0)
1647 }
1648 
1649 )";
1650   auto result = ParseHloString(original);
1651   EXPECT_NE(Status::OK(), result.status());
1652 }
1653 
1654 TEST_F(HloParserTest, WrongOperandsSize) {
1655   const string original = R"(HloModule wrong_opcode:
1656 
1657 ENTRY %blabla (x: f32[]) -> pred[] {
1658   %x = f32[]{} parameter(0)
1659   %eq = pred[]{} compare(f32[]{} %x), direction=EQ
1660 }
1661 
1662 )";
1663   auto result = ParseHloString(original);
1664   EXPECT_NE(Status::OK(), result.status());
1665 }
1666 
1667 TEST_F(HloParserTest, OperandNotFound) {
1668   const string original = R"(HloModule operand_not_found:
1669 ENTRY %blabla (x: f32[]) -> pred[] {
1670   %x = f32[]{} parameter(0)
1671   %eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ
1672 }
1673 )";
1674   auto result = ParseHloString(original);
1675   EXPECT_NE(Status::OK(), result.status());
1676 }
1677 
1678 TEST_F(HloParserTest, MoreConstants) {
1679   const string original = R"(HloModule SelectScalarS32True_module
1680 
1681 ENTRY %SelectScalarS32True.v4 () -> s32[] {
1682   %constant.2 = pred[] constant(true)
1683   %constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4}
1684   %constant = s32[] constant(42)
1685   %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
1686 }
1687 
1688 )";
1689   auto result = ParseHloString(original);
1690   TF_EXPECT_OK(result.status());
1691   // Constant instructions have no name. The string will be parsed successfully
1692   // but the constant names will not be exactly the same.
1693 }
1694 
1695 TEST_F(HloParserTest, ConfigurationField) {
1696   const string original = R"(HloModule AModule
1697 ENTRY %configuration_test() -> s32[] {
1698   %constant = s32[] constant(42), backend_config="foo bar"
1699 })";
1700   auto result = ParseHloString(original);
1701   TF_ASSERT_OK(result.status());
1702   EXPECT_EQ("foo bar", result.ValueOrDie()
1703                            ->entry_computation()
1704                            ->root_instruction()
1705                            ->raw_backend_config_string());
1706 }
1707 
1708 TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
1709   const string original = R"(HloModule some_2_module
1710 
1711 ENTRY %some_2 () -> f32[2] {
1712   ROOT %constant = f32[2]{0} constant({1,{2}})
1713 }
1714 
1715 )";
1716   auto result = ParseHloString(original);
1717   EXPECT_NE(Status::OK(), result.status());
1718   ExpectHasSubstr(result.status().error_message(),
1719                   "expects nested array in rank 1, but sees larger");
1720 }
1721 
1722 TEST_F(HloParserTest, LiteralDimensionsMismatch_2) {
1723   const string original = R"(HloModule some_2x3_module
1724 
1725 ENTRY %some_2x3 () -> f32[2,3] {
1726   ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6})
1727 }
1728 
1729 )";
1730   auto result = ParseHloString(original);
1731   EXPECT_NE(Status::OK(), result.status());
1732   ExpectHasSubstr(result.status().error_message(),
1733                   "expects nested array in rank 2, but sees 1");
1734 }
1735 
1736 TEST_F(HloParserTest, LiteralDimensionsMismatch_3) {
1737   const string original = R"(HloModule some_2x3x2_module
1738 
1739 ENTRY %some_2x3x2 () -> f32[2,3,2] {
1740   ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}})
1741 }
1742 
1743 )";
1744   auto result = ParseHloString(original);
1745   EXPECT_NE(Status::OK(), result.status());
1746   ExpectHasSubstr(result.status().error_message(),
1747                   "expects 3 elements in the [0]th element");
1748 }
1749 
1750 TEST_F(HloParserTest, ConstantF16Overflow) {
1751   const string original =
1752       R"(HloModule ConstantF16Overflow_module
1753 
1754 ENTRY %ConstantF16Overflow.v4 () -> f16[] {
1755   ROOT %constant = f16[] constant(-65505)
1756 }
1757 
1758 )";
1759   auto result = ParseHloString(original);
1760   EXPECT_NE(Status::OK(), result.status());
1761   ExpectHasSubstr(result.status().error_message(),
1762                   "is out of range for literal's primitive type F16");
1763 }
1764 
1765 TEST_F(HloParserTest, ConstantBf16NoOverflow) {
1766   // 65505 is in range for bf16.
1767   const string original = R"(
1768   HloModule test_module
1769   ENTRY test {
1770     ROOT c = bf16[] constant(-65505)
1771   })";
1772   EXPECT_EQ(Status::OK(), ParseHloString(original).status());
1773 }
1774 
1775 TEST_F(HloParserTest, ConstantBf16Overflow) {
1776   // 1e100 is out of range for bf16.
1777   const string original = R"(
1778   HloModule test_module
1779   ENTRY test {
1780     ROOT c = bf16[] constant(1e100)
1781   })";
1782   ExpectHasSubstr(ParseHloString(original).status().error_message(),
1783                   "out of range");
1784 }
1785 
1786 TEST_F(HloParserTest, ConstantF16OverflowInSparseArray) {
1787   const string original = R"(
1788     HloModule test_module
1789     ENTRY test {
1790       ROOT c = f16[5]sparse{10} constant({[0]: 0, [1]: -65505})
1791     })";
1792   ExpectHasSubstr(ParseHloString(original).status().error_message(),
1793                   "is out of range for literal's primitive type F16");
1794 }
1795 
1796 TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
1797   const string original = R"(
1798       HloModule ConstantUnsignedUnderflow_module
1799       ENTRY %ConstantUnsignedUnderflow () -> u64[] {
1800         ROOT %constant = u64[] constant(-1)
1801       })";
1802   auto result = ParseHloString(original);
1803   EXPECT_NE(Status::OK(), result.status());
1804   ExpectHasSubstr(result.status().error_message(),
1805                   "is out of range for literal's primitive type U64");
1806 }
1807 
1808 TEST_F(HloParserTest, ConstantUnsignedOverflow) {
1809   const string original = R"(
1810       HloModule ConstantUnsignedOverflow_module
1811       ENTRY %ConstantUnsignedOverflow () -> u32[] {
1812         ROOT %constant = u32[] constant(4294967296)
1813       })";
1814   auto result = ParseHloString(original);
1815   EXPECT_NE(Status::OK(), result.status());
1816   ExpectHasSubstr(result.status().error_message(),
1817                   "is out of range for literal's primitive type U32");
1818 }
1819 
1820 TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) {
1821   const string original = R"(
1822       HloModule ConstantUnsignedOverflow_module
1823       ENTRY %ConstantUnsignedOverflow () -> u64[] {
1824         ROOT %constant = u64[] constant(9223372036854775808)
1825       })";
1826   auto result = ParseHloString(original);
1827   EXPECT_NE(Status::OK(), result.status());
1828 }
1829 
1830 TEST_F(HloParserTest, ConstantC64Overflow) {
1831   const string original = R"(
1832       HloModule test_module
1833       ENTRY test () -> c64[] {
1834         ROOT c = c64[] constant((1e100, 0))
1835       })";
1836   auto result = ParseHloString(original);
1837   EXPECT_NE(Status::OK(), result.status());
1838 }
1839 
1840 TEST_F(HloParserTest, ConstantC64Underflow) {
1841   const string original = R"(
1842       HloModule test_module
1843       ENTRY test () -> c64[] {
1844         ROOT c = c64[] constant((0, -1e100))
1845       })";
1846   auto result = ParseHloString(original);
1847   EXPECT_NE(Status::OK(), result.status());
1848 }
1849 
1850 TEST_F(HloParserTest, ConstantF64Overflow) {
1851   const string original = R"(
1852       HloModule test_module
1853       ENTRY test {
1854         ROOT c = f64[] constant(1.8e308)
1855       })";
1856   auto result = ParseHloString(original);
1857   EXPECT_NE(Status::OK(), result.status());
1858 }
1859 
1860 TEST_F(HloParserTest, ConstantF64Underflow) {
1861   const string original = R"(
1862       HloModule test_module
1863       ENTRY test {
1864         ROOT c = f64[] constant(-1.8e308)
1865       })";
1866   auto result = ParseHloString(original);
1867   EXPECT_NE(Status::OK(), result.status());
1868 }
1869 
1870 TEST_F(HloParserTest, ConstantWithExp) {
1871   const string original = R"(HloModule ConstantWithExp_module
1872 
1873 ENTRY %ConstantWithExp.v4 () -> f32[] {
1874   %constant.1 = f32[] constant(3e+2)
1875 }
1876 
1877 )";
1878   auto result = ParseHloString(original);
1879   TF_EXPECT_OK(result.status());
1880   // The string will be parsed successfully but the output strings are not
1881   // exactly the same, because "3e2" is parsed into value 300 and will be
1882   // printed as "300".
1883 }
1884 
1885 TEST_F(HloParserTest, ShortConstant) {
1886   const string original = R"(HloModule ShortCOnstant_module
1887 
1888 ENTRY %ShortConstant.v4 () -> f32[67,89] {
1889   ROOT %constant.1 = f32[67,89]{1,0} constant({...})
1890 }
1891 
1892 )";
1893   auto result = ParseHloString(original);
1894   TF_EXPECT_OK(result.status());
1895   EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
1896 }
1897 
1898 TEST_F(HloParserTest, AttibutesAnyOrder) {
1899   const string original = R"(HloModule any_order_module
1900 
1901 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
1902   %input = f32[1,2,1]{2,1,0} parameter(0)
1903   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
1904   %filter = f32[1,1,1]{2,1,0} parameter(1)
1905   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
1906 }
1907 
1908 )";
1909   TF_EXPECT_OK(ParseHloString(original).status());
1910 }
1911 
1912 TEST_F(HloParserTest, InvalidDimLabels) {
1913   string prefix = R"(HloModule invalid_dim_labels_module
1914 
1915 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
1916   %input = f32[1,2,1]{2,1,0} parameter(0)
1917   %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
1918   %filter = f32[1,1,1]{2,1,0} parameter(1)
1919   ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
1920   string suffix = R"(
1921 }
1922 
1923 )";
1924 
1925   ExpectHasSubstr(
1926       ParseHloString(absl::StrCat(prefix, ",dim_labels=00_01_10", suffix))
1927           .status()
1928           .error_message(),
1929       "expects dim labels pattern");
1930 
1931   ExpectHasSubstr(
1932       ParseHloString(absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix))
1933           .status()
1934           .error_message(),
1935       "must have the same rank");
1936 }
1937 
1938 TEST_F(HloParserTest, UnexpectedAttribute) {
1939   const string original = R"(HloModule unexpected_attr_module
1940 
1941 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
1942   %token0 = token[] after-all()
1943   %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
1944   %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
1945   ROOT %constant = f32[] constant(2.1)
1946   %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv
1947   %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
1948 }
1949 
1950 )";
1951   ExpectHasSubstr(ParseHloString(original).status().error_message(),
1952                   "unexpected attribute \"calls\"");
1953 }
1954 
TEST_F(HloParserTest, MissingAttribute) {
  // The send below omits channel_id; the parser must complain about it.
  const std::string hlo_string = R"(HloModule missing_attr_module

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
  %token0 = token[] after-all()
  %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
  %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
  ROOT %constant = f32[] constant(-2.1)
  %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0)
  %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}

)";
  const auto parse_status = ParseHloString(hlo_string).status();
  ExpectHasSubstr(parse_status.error_message(),
                  "attribute channel_id is expected but not seen");
}
1971 
TEST_F(HloParserTest, PredecessorUndefined) {
  // control-predecessors names %done, which no instruction defines.
  const std::string hlo_string = R"(HloModule pre_not_found_module

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
  %token0 = token[] after-all()
  %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
  %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
  ROOT %constant = f32[] constant(2.1)
  %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done}
  %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}

)";
  const auto parse_status = ParseHloString(hlo_string).status();
  ExpectHasSubstr(parse_status.error_message(), "'done' is not defined");
}
1988 
TEST_F(HloParserTest, SliceAllowOmitStride1) {
  // Slice dimensions may be written as [start:limit] without a stride.
  const std::string hlo_string = R"(HloModule slice_module

ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
  %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
  ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]}
}

)";
  const auto parse_status = ParseHloString(hlo_string).status();
  TF_EXPECT_OK(parse_status);
}
2000 
TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
  // A window pad takes low_high only; the three-part "1_1_0" is rejected.
  const std::string hlo_string = R"(HloModule window_pad_module

ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
  %input = f32[1,2,1]{2,1,0} parameter(0)
  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
  %filter = f32[1,1,1]{2,1,0} parameter(1)
  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1}
}

)";
  const auto parse_status = ParseHloString(hlo_string).status();
  ExpectHasSubstr(parse_status.error_message(),
                  "expects padding_low and padding_high separated by '_'");
}
2015 
TEST_F(HloParserTest, CommaBetweenSubAttributes) {
  // Sub-attributes inside metadata={...} may be comma separated.
  const std::string hlo_string = R"(HloModule test_comma_module

ENTRY %test_comma.v4 () -> f32[] {
  ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"}
}

)";
  const auto parse_status = ParseHloString(hlo_string).status();
  TF_EXPECT_OK(parse_status);
}
2026 
TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
  // The signature declares f32[1] but the ROOT produces f32[1,2,3].
  const std::string hlo_string = R"(HloModule custom_call:

ENTRY %CustomCall () -> f32[1] {
  %constant = f32[1]{0} constant({12345})
  ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
})";
  const auto parse_status = ParseHloString(hlo_string).status();
  ExpectHasSubstr(parse_status.error_message(),
                  "Shape of computation CustomCall, f32[1], is not compatible "
                  "with that of its root instruction foo, f32[1,2,3]");
}
2038 
TEST_F(HloParserTest, EntryComputationWithLayout) {
  // The entry computation layout must reflect the layouts written on the
  // parameter instruction ({0,1,2}) and on the ROOT ({0,1}).
  const string original = R"(HloModule layout:
add_F32.v3 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
  input = f32[8,16,256]{0,1,2} parameter(0)
  constant = f32[] constant(0)
  ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original));
  // Bind by const reference: inspecting the layout does not need a copy of
  // the whole ComputationLayout (the module outlives this reference).
  const auto& program_layout = module->entry_computation_layout();
  ASSERT_EQ(program_layout.parameter_count(), 1);
  const auto param_layout = program_layout.parameter_layout(0).layout();
  const auto result_layout = program_layout.result_layout().layout();
  EXPECT_TRUE(
      LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), param_layout))
      << "actual layout of parameter(0) is "
      << LayoutUtil::HumanString(param_layout);
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), result_layout))
      << "actual layout of result is "
      << LayoutUtil::HumanString(result_layout);
}
2067 
TEST_F(HloParserTest, NoEntry) {
  // With no ENTRY marker, the last computation ("c2") becomes the entry.
  const std::string hlo_string = R"(HloModule no_entry:
c1 {
  const1 = f32[1]{0} constant({12345})
}
c2 {
  const2 = f32[1]{0} constant({67890})
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
  EXPECT_EQ(module->entry_computation()->name(), "c2");
}
2080 
TEST_F(HloParserTest, NoRoot) {
  // With no ROOT marker, the last instruction ("last") becomes the root.
  const std::string hlo_string = R"(HloModule no_root:
ENTRY consts {
  first = f32[1]{0} constant({12345})
  last = f32[1]{0} constant({67890})
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->name(), "last");
}
2093 
TEST_F(HloParserTest, Comments) {
  // /* ... */ comments before the header, between tokens and inside a
  // literal must all be skipped.
  const std::string hlo_string = R"(/* module description. */
HloModule comments:

ENTRY /*comment*/ c1 {
  /* blah */
  ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/})
  /* comment */
}

/* something else */

)";
  TF_ASSERT_OK(ParseHloString(hlo_string).status());
}
2110 
TEST_F(HloParserTest, MultilineComments) {
  // Multi-line /* */ comments, including one that wraps HLO text, are
  // ignored.
  const std::string hlo_string = R"(HloModule multiline_comment:
ENTRY c1 {
  /*
     ROOT foo = f32[1]{0} constant({12345})
  */
  ROOT const1 = f32[1]{0} constant({12345})
/*
a
b
c
d

*/
})";
  TF_ASSERT_OK(ParseHloString(hlo_string).status());
}
2129 
TEST_F(HloParserTest, UnterminatedComment) {
  const std::string hlo_string = R"(HloModule unterminated_comment:
ENTRY c1 {
/* unterminated
  ROOT const1 = f32[1]{0} constant({12345})
})";
  // The caret in the error output must sit under the opening "/*" of the
  // comment that never closes.
  const auto parse_status = ParseHloString(hlo_string).status();
  ExpectHasSubstr(parse_status.error_message(), "/* unterminated\n^");
}
2141 
2142 TEST_F(HloParserTest, SlashSlashComments) {
2143   const string original = R"(HloModule slash_slash_comment:
2144 // Garbage
2145 ENTRY c1 {
2146   // Foo bar
2147   ROOT const1 = f32[1]{0} constant({12345}) // Something else
2148 })";
2149   auto module = ParseHloString(original);
2150   TF_ASSERT_OK(module.status());
2151 }
2152 
2153 TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
2154   const string original =
2155       "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
2156       "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
2157   auto module = ParseHloString(original);
2158   TF_ASSERT_OK(module.status());
2159 }
2160 
2161 TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
2162   const string original =
2163       "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
2164       "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
2165   auto module = ParseHloString(original);
2166   TF_ASSERT_OK(module.status());
2167 }
2168 
2169 TEST_F(HloParserTest, MultipleEntries) {
2170   const string original = R"(HloModule multiple_entries:
2171 ENTRY c1 {
2172   const1 = f32[1]{0} constant({12345})
2173 }
2174 ENTRY c2 {
2175   const2 = f32[1]{0} constant({67890})
2176 })";
2177   ExpectHasSubstr(ParseHloString(original).status().error_message(),
2178                   "expects only one ENTRY");
2179 }
2180 
2181 TEST_F(HloParserTest, MultipleRoots) {
2182   const string original = R"(HloModule multiple_roots:
2183 ENTRY consts {
2184   ROOT const1 = f32[1]{0} constant({12345})
2185   ROOT const2 = f32[1]{0} constant({12345})
2186 })";
2187   ExpectHasSubstr(ParseHloString(original).status().error_message(),
2188                   "one computation should have only one ROOT");
2189 }
2190 
2191 TEST_F(HloParserTest, ComputationExists) {
2192   const string original = R"(HloModule comp_exists
2193 comp {
2194   const1 = f32[1]{0} constant({12345})
2195 }
2196 comp {
2197   const2 = f32[1]{0} constant({67890})
2198 })";
2199   ExpectHasSubstr(ParseHloString(original).status().error_message(),
2200                   R"(was parsing 2:1: error: computation previously defined here
2201 comp {
2202 ^)");
2203 }
2204 
2205 TEST_F(HloParserTest, CrossComputationLookup) {
2206   const string original = R"(HloModule cross_computation_lookup:
2207 tcalla (a: (s32[], s32[])) -> (s32[], s32[]) {
2208   ROOT aparam = (s32[], s32[]) parameter(0)
2209 }
2210 
2211 tcallb (b: (s32[], s32[])) -> s32[] {
2212   rparam = (s32[], s32[]) parameter(0)
2213   ROOT gte0 = s32[] get-tuple-element(aparam), index=0
2214 }
2215 
2216 ENTRY entry {
2217   param = (s32[], s32[]) parameter(0)
2218   call0 = (s32[], s32[]) call(param), to_apply=tcalla
2219   ROOT call1 = s32[] call(param), to_apply=tcallb
2220 })";
2221   ExpectHasSubstr(
2222       ParseHloString(original).status().error_message(),
2223       "was parsing 8:39: error: instruction does not exist: aparam");
2224 }
2225 
2226 TEST_F(HloParserTest, SameNameDiffComputations) {
2227   const string original = R"(HloModule same_names:
2228 add {
2229   p0 = f32[] parameter(0)
2230   p1 = f32[] parameter(1)
2231   ROOT result = f32[] add(p0, p1)
2232 }
2233 
2234 ENTRY ReduceR3ToR2 {
2235   p0 = f32[8,16,256]{2,1,0} parameter(0)
2236   p1 = f32[] constant(0)
2237   ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
2238 }
2239 )";
2240   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original));
2241   ASSERT_NE(module->entry_computation(), nullptr);
2242   EXPECT_THAT(module->entry_computation()->root_instruction(),
2243               GmockMatch(m::Reduce()));
2244 }
2245 
2246 TEST_F(HloParserTest, ParseSharding) {
2247   const string original = "{maximal device=42}";
2248   TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
2249   EXPECT_EQ(sharding.ToString(), original);
2250 }
2251 
2252 TEST_F(HloParserTest, ParseWindow) {
2253   Window original = window_util::MakeWindow({1, 2, 3});
2254   TF_ASSERT_OK_AND_ASSIGN(Window parsed,
2255                           ParseWindow(window_util::ToString(original)))
2256   EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed));
2257 }
2258 
2259 TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
2260   const string original = "b0f_0io->b0f";
2261   TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
2262                           ParseConvolutionDimensionNumbers(original));
2263   EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
2264 }
2265 
2266 TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) {
2267   const string original = "0_1x2_3";
2268   TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
2269   EXPECT_EQ(original, PaddingConfigToString(dnums));
2270 }
2271 
2272 TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) {
2273   const string original = "0_1_0x2_3_4";
2274   TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
2275   EXPECT_EQ(original, PaddingConfigToString(dnums));
2276 }
2277 
2278 TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) {
2279   TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig("0_1x2_3_4"));
2280   // The extra "_0" gets added to the canonical string because the other dim has
2281   // interior padding.
2282   EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(dnums));
2283 }
2284 
2285 TEST_F(HloParserTest, NontupleInfeed) {
2286   const string original = R"(HloModule nontuple_infeed:
2287 ENTRY nontuple_infeed {
2288   token0 = token[] after-all()
2289   ROOT infeed = pred[] infeed(token0)
2290 })";
2291   ExpectHasSubstr(ParseHloString(original).status().error_message(),
2292                   "infeed must have a non-empty tuple shape");
2293 }
2294 
2295 TEST(HloParserSingleOpTest, SingleOp) {
2296   const string text =
2297       "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, "
2298       "f32[2,4]{1,0} %x)";
2299   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2300   const HloComputation* computation = module->entry_computation();
2301   ASSERT_NE(computation, nullptr);
2302   EXPECT_THAT(computation->root_instruction(),
2303               GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2304 }
2305 
2306 TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
2307   const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
2308   StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text);
2309   ASSERT_TRUE(!module.status().ok());
2310   LOG(INFO) << "Status: " << module.status();
2311   EXPECT_THAT(module.status().ToString(),
2312               ::testing::HasSubstr("expects '=' in instruction"));
2313 }
2314 
2315 TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
2316   const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
2317   StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text);
2318   ASSERT_TRUE(!module.status().ok());
2319   LOG(INFO) << "Status: " << module.status();
2320   EXPECT_THAT(module.status().ToString(),
2321               ::testing::HasSubstr("Operand had no shape in HLO text"));
2322 }
2323 
2324 TEST(HloParserSingleOpTest, SingleOpNoNames) {
2325   const string text =
2326       "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
2327   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2328   const HloComputation* computation = module->entry_computation();
2329   ASSERT_NE(computation, nullptr);
2330   EXPECT_THAT(computation->root_instruction(),
2331               GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2332 }
2333 
2334 TEST(HloParserSingleOpTest, CanonicalOp) {
2335   const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
2336   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2337   const HloComputation* computation = module->entry_computation();
2338   ASSERT_NE(computation, nullptr);
2339   EXPECT_THAT(computation->root_instruction(),
2340               GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2341   EXPECT_EQ(
2342       computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2343       text);
2344 }
2345 
2346 TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
2347   const string text =
2348       R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
2349 {
2350   tmp_0 = f32[5,10]{1,0} parameter(0)
2351   tmp_1 = f32[20,10]{1,0} parameter(1)
2352   ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
2353   {
2354     tmp_0 = f32[5,10]{1,0} parameter(0)
2355     tmp_1 = f32[20,10]{1,0} parameter(1)
2356     tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
2357     ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
2358   }
2359 }, body=
2360 {
2361   tmp_0 = f32[5,10]{1,0} parameter(0)
2362   tmp_1 = f32[20,10]{1,0} parameter(1)
2363   ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
2364   {
2365     tmp_0 = f32[5,10]{1,0} parameter(0)
2366     tmp_1 = f32[20,10]{1,0} parameter(1)
2367     tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
2368     ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
2369   }
2370 })";
2371 
2372   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2373   const HloComputation* computation = module->entry_computation();
2374   ASSERT_NE(computation, nullptr);
2375   EXPECT_EQ(
2376       computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2377       text);
2378 }
2379 
2380 TEST(HloParserSingleOpTest, CanonicalOpIndexedConditionalInlinedBranches) {
2381   const string text =
2382       R"(f32[5,10]{1,0} conditional(s32[], f32[5,10]{1,0}, f32[5,10]{1,0}, f32[5,10]{1,0}), branch_computations={
2383 {
2384   tmp_0 = f32[5,10]{1,0} parameter(0)
2385   ROOT tmp_1 = f32[5,10]{1,0} ceil(f32[5,10]{1,0} tmp_0)
2386 },
2387 {
2388   tmp_0 = f32[5,10]{1,0} parameter(0)
2389   ROOT tmp_1 = f32[5,10]{1,0} floor(f32[5,10]{1,0} tmp_0)
2390 },
2391 {
2392   tmp_0 = f32[5,10]{1,0} parameter(0)
2393   ROOT tmp_1 = f32[5,10]{1,0} copy(f32[5,10]{1,0} tmp_0)
2394 }
2395 })";
2396 
2397   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2398   const HloComputation* computation = module->entry_computation();
2399   ASSERT_NE(computation, nullptr);
2400   EXPECT_EQ(
2401       computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2402       text);
2403 }
2404 
2405 TEST(HloParserSingleOpTest, SingleOpWithNested) {
2406   const string text =
2407       R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls=
2408 {
2409   %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
2410   %param_1 = f32[2]{0} parameter(1)
2411   %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1}
2412   ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
2413 })";
2414 
2415   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2416   const HloComputation* computation = module->entry_computation();
2417   ASSERT_NE(computation, nullptr);
2418   EXPECT_THAT(computation->root_instruction(),
2419               GmockMatch(m::Op()
2420                              .WithOpcode(HloOpcode::kFusion)
2421                              .WithNumOperands(2)
2422                              .WithOperand(0, m::Parameter(0))
2423                              .WithOperand(1, m::Parameter(1))));
2424 }
2425 
2426 TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
2427   const string text =
2428       R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2429 {
2430   result = f32[] add(f32[] x, f32[] y)
2431 })";
2432   auto status = ParseHloString(text).status();
2433   ASSERT_FALSE(status.ok());
2434   EXPECT_THAT(status.error_message(),
2435               ::testing::HasSubstr("does not exist: x"));
2436 }
2437 
2438 TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) {
2439   const string text =
2440       R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2441 {
2442   f32[] add(f32[] x, f32[] y)
2443 })";
2444   auto status = ParseHloString(text).status();
2445   ASSERT_FALSE(status.ok());
2446   EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
2447 }
2448 
2449 TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
2450   const string text =
2451       R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2452 {
2453   result = f32[] add(f32[], f32[])
2454 })";
2455   auto status = ParseHloString(text).status();
2456   ASSERT_FALSE(status.ok());
2457   EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
2458 }
2459 
2460 TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
2461   const string text =
2462       R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
2463   TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2464   const HloComputation* computation = module->entry_computation();
2465   ASSERT_NE(computation, nullptr);
2466   EXPECT_THAT(computation->root_instruction(),
2467               GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
2468   auto* convolution =
2469       Cast<HloConvolutionInstruction>(computation->root_instruction());
2470   EXPECT_EQ(convolution->feature_group_count(), 1);
2471 }
2472 
2473 TEST_F(HloParserTest, IsScheduledIsFalse) {
2474   const string text = R"(
2475 HloModule axpy_module, is_scheduled=false
2476 
2477 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2478   %alpha = f32[] parameter(0)
2479   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2480   %x = f32[2,4]{1,0} parameter(1)
2481   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2482   %y = f32[2,4]{1,0} parameter(2)
2483   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2484 }
2485 )";
2486   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
2487                           ParseHloString(text));
2488   ASSERT_FALSE(module->has_schedule());
2489 }
2490 
2491 TEST_F(HloParserTest, IsScheduledNotPresent) {
2492   const string text = R"(
2493 HloModule axpy_module
2494 
2495 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2496   %alpha = f32[] parameter(0)
2497   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2498   %x = f32[2,4]{1,0} parameter(1)
2499   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2500   %y = f32[2,4]{1,0} parameter(2)
2501   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2502 }
2503 )";
2504   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
2505                           ParseHloString(text));
2506   ASSERT_FALSE(module->has_schedule());
2507 }
2508 
2509 TEST_F(HloParserTest, IsScheduledIsTrue) {
2510   const string text = R"(
2511 HloModule axpy_module, is_scheduled=true
2512 
2513 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2514   %alpha = f32[] parameter(0)
2515   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2516   %x = f32[2,4]{1,0} parameter(1)
2517   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2518   %y = f32[2,4]{1,0} parameter(2)
2519   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2520 }
2521 )";
2522   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
2523                           ParseHloString(text));
2524   ASSERT_TRUE(module->has_schedule());
2525   TF_ASSERT_OK(module->schedule().Verify());
2526   EXPECT_EQ(module->schedule().sequences().size(), 1);
2527   ASSERT_TRUE(
2528       module->schedule().is_computation_scheduled(module->entry_computation()));
2529   EXPECT_THAT(
2530       module->schedule().sequence(module->entry_computation()).instructions(),
2531       ::testing::ElementsAre(
2532           GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
2533           GmockMatch(m::Parameter()), GmockMatch(m::Multiply()),
2534           GmockMatch(m::Parameter()), GmockMatch(m::Add())));
2535 }
2536 
2537 TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
2538   // As above but in with a different schedule order.
2539   const string text = R"(
2540 HloModule axpy_module, is_scheduled=true
2541 
2542 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2543   %alpha = f32[] parameter(0)
2544   %x = f32[2,4]{1,0} parameter(1)
2545   %y = f32[2,4]{1,0} parameter(2)
2546   %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2547   %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2548   ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2549 }
2550 )";
2551   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
2552                           ParseHloString(text));
2553   ASSERT_TRUE(module->has_schedule());
2554   TF_ASSERT_OK(module->schedule().Verify());
2555   EXPECT_EQ(module->schedule().sequences().size(), 1);
2556   ASSERT_TRUE(
2557       module->schedule().is_computation_scheduled(module->entry_computation()));
2558   EXPECT_THAT(
2559       module->schedule().sequence(module->entry_computation()).instructions(),
2560       ::testing::ElementsAre(
2561           GmockMatch(m::Parameter()), GmockMatch(m::Parameter()),
2562           GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
2563           GmockMatch(m::Multiply()), GmockMatch(m::Add())));
2564 }
2565 
2566 TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {
2567   const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints
2568 
2569 ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
2570   %p0 = f32[42,2,3]{0,1,2} parameter(0)
2571   %p1 = f32[123,4]{0,1} parameter(1)
2572   ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}}
2573 }
2574 
2575 )";
2576   ExpectHasSubstr(ParseHloString(original).status().error_message(),
2577                   "Expected 2 operand layout constraints, 1 given");
2578 }
2579 
2580 TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) {
2581   const string original = R"(HloModule CustomCallIncompatibleOperandConstraints
2582 
2583 ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
2584   %p0 = f32[42,2,3]{0,1,2} parameter(0)
2585   %p1 = f32[123,4]{0,1} parameter(1)
2586   ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}}
2587 }
2588 
2589 )";
2590   ExpectHasSubstr(ParseHloString(original).status().error_message(),
2591                   "operand 1 is not compatible with operand shape");
2592 }
2593 
2594 TEST_F(HloParserTest, AllowShapeWhitespace) {
2595   const string text = R"(
2596 HloModule module
2597 
2598 ENTRY entry {
2599   ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0)
2600 }
2601 )";
2602   TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
2603                           ParseHloString(text));
2604 }
2605 
2606 TEST_F(HloParserTest, ShapeMismatchInOperand) {
2607   const string text = R"(
2608 HloModule foobar
2609 
2610 ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] {
2611   %p = f32[2,2] parameter(0)
2612   %constant.1 = f32[2,2] constant({{1, 2}, {3, 4}})
2613   ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1)
2614 }
2615 )";
2616 
2617   ExpectHasSubstr(ParseHloString(text).status().error_message(),
2618                   "The declared operand shape f32[2,5]{1,0} is not compatible"
2619                   " with the shape of the operand instruction f32[2,2]{1,0}.");
2620 }
2621 
2622 TEST_F(HloParserTest, OutOfRangeSparseIndex) {
2623   const string original = R"(
2624     HloModule test_module
2625     ENTRY test {
2626       ROOT c = f16[5]sparse{10} constant({[100]: 0})
2627     })";
2628   ExpectHasSubstr(ParseHloString(original).status().error_message(),
2629                   "Invalid sparse index");
2630 }
2631 
2632 TEST_F(HloParserTest, NegativeSparseIndex) {
2633   const string original = R"(
2634     HloModule test_module
2635     ENTRY test {
2636       ROOT c = f16[5]sparse{10} constant({-1: 0})
2637     })";
2638   ExpectHasSubstr(ParseHloString(original).status().error_message(),
2639                   "Invalid sparse index");
2640 }
2641 
2642 TEST_F(HloParserTest, SparseIndexWithRankTooLarge) {
2643   const string original = R"(
2644     HloModule test_module
2645     ENTRY test {
2646       ROOT c = f16[5]sparse{10} constant({[0, 0]: 0})
2647     })";
2648   ExpectHasSubstr(ParseHloString(original).status().error_message(),
2649                   "Invalid sparse index");
2650 }
2651 
2652 TEST_F(HloParserTest, SparseIndexWithRankTooSmall) {
2653   const string original = R"(
2654     HloModule test_module
2655     ENTRY test {
2656       ROOT c = f16[5, 5]sparse{10} constant({[0]: 0})
2657     })";
2658   ExpectHasSubstr(ParseHloString(original).status().error_message(),
2659                   "Invalid sparse index");
2660 }
2661 
2662 TEST_F(HloParserTest, ParseShapeStringR2F32) {
2663   string shape_string = "f32[123,456]";
2664   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2665   Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
2666   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2667       << "expected: " << ShapeUtil::HumanString(expected)
2668       << "actual:   " << ShapeUtil::HumanString(actual);
2669 }
2670 
2671 TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) {
2672   string shape_string = "(f32[1572864],s8[5120,1024])";
2673   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2674   Shape expected =
2675       ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
2676                                  ShapeUtil::MakeShape(S8, {5120, 1024})});
2677   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2678       << "expected: " << ShapeUtil::HumanString(expected)
2679       << "actual:   " << ShapeUtil::HumanString(actual);
2680 }
2681 
2682 TEST_F(HloParserTest, ParseShapeStringNestedTuple) {
2683   string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
2684   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2685   Shape expected = ShapeUtil::MakeTupleShape({
2686       ShapeUtil::MakeShape(F32, {1}),
2687       ShapeUtil::MakeTupleShape(
2688           {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
2689       ShapeUtil::MakeOpaqueShape(),
2690       ShapeUtil::MakeShape(F32, {3}),
2691   });
2692   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2693       << "expected: " << ShapeUtil::HumanString(expected)
2694       << "actual:   " << ShapeUtil::HumanString(actual);
2695 }
2696 
2697 TEST_F(HloParserTest, ParseShapeStringWithLayout) {
2698   string shape_string = "f32[123,456]{0,1}";
2699   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2700   Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
2701   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2702       << "expected: " << ShapeUtil::HumanString(expected)
2703       << "actual:   " << ShapeUtil::HumanString(actual);
2704 }
2705 
2706 TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) {
2707   // One tile.
2708   string shape_string = "f32[123,456]{0,1:T(2,128)}";
2709   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2710   Shape expected =
2711       ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}, {Tile({2, 128})});
2712   EXPECT_EQ(expected, actual)
2713       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2714       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2715 
2716   // Tile with negative dimension size for combining dimensions.
2717   shape_string = "f32[123,456,789]{0,1,2:T(2, * , 128)}";
2718   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2719   expected =
2720       ShapeUtil::MakeShapeWithLayout(F32, {123, 456, 789}, {0, 1, 2},
2721                                      {Tile({2, Tile::kCombineDimension, 128})});
2722   EXPECT_EQ(expected, actual)
2723       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2724       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2725 
2726   // Two tiles.
2727   shape_string = "bf16[123,456,789]{2,1,0:T(2,*,128)(2,1)}";
2728   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2729   expected = ShapeUtil::MakeShapeWithLayout(
2730       BF16, {123, 456, 789}, {2, 1, 0},
2731       {Tile({2, Tile::kCombineDimension, 128}), Tile({2, 1})});
2732   EXPECT_EQ(expected, actual)
2733       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2734       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2735 
2736   // Tile with element size in bits.
2737   shape_string = "pred[123,456]{1,0:T(2,128)E(1)}";
2738   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2739   expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0},
2740                                             {Tile({2, 128})}, 1);
2741   EXPECT_EQ(expected, actual)
2742       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2743       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2744 
2745   // Element size in bits without tile.
2746   shape_string = "pred[123,456]{1,0:E(1)}";
2747   TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2748   expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 1);
2749   EXPECT_EQ(expected, actual)
2750       << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2751       << "actual:   " << ShapeUtil::HumanStringWithLayout(actual);
2752 
2753   // Wrong minor_to_major.
2754   shape_string = "f32[123,456,789]{1:T(2, * , 128)}";
2755   auto result = ParseShape(shape_string);
2756   ExpectHasSubstr(result.status().error_message(),
2757                   "Dimensions size is 3, but minor to major size is 1.");
2758 }
2759 
2760 TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) {
2761   string shape_string = "f32[123,456]sparse{10}";
2762   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2763   Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10);
2764   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2765       << "expected: " << ShapeUtil::HumanString(expected)
2766       << "actual: " << ShapeUtil::HumanString(actual);
2767 }
2768 
2769 TEST_F(HloParserTest, ParseOpaqueType) {
2770   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]"));
2771   Shape expected = ShapeUtil::MakeOpaqueShape();
2772   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2773       << "expected: " << ShapeUtil::HumanString(expected)
2774       << "actual:   " << ShapeUtil::HumanString(actual);
2775 }
2776 
2777 TEST_F(HloParserTest, ParseTokenType) {
2778   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]"));
2779   Shape expected = ShapeUtil::MakeTokenShape();
2780   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2781       << "expected: " << ShapeUtil::HumanString(expected)
2782       << "actual:   " << ShapeUtil::HumanString(actual);
2783 }
2784 
2785 TEST_F(HloParserTest, ParseInvalidShapeString) {
2786   string shape_strings[] = {
2787       "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
2788       "f32[123,456]dense{foo}",  "f32[123,456]sparse{foo}",
2789   };
2790   for (const string& shape_string : shape_strings) {
2791     StatusOr<Shape> result = ParseShape(shape_string);
2792     ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
2793   }
2794 }
2795 
2796 TEST_F(HloParserTest, ParseDynamicArray) {
2797   string shape_string = "f32[123,<=456]";
2798   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2799   Shape expected = ShapeUtil::MakeShape(F32, {123, 456}, {false, true});
2800   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2801       << "expected: " << ShapeUtil::HumanString(expected)
2802       << "actual:   " << ShapeUtil::HumanString(actual);
2803 }
2804 
2805 TEST_F(HloParserTest, ParseDynamicTuple) {
2806   string shape_string = "(f32[42], u32[<=123,<=456])";
2807   TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2808   Shape expected = ShapeUtil::MakeTupleShape(
2809       {ShapeUtil::MakeShape(F32, {42}),
2810        ShapeUtil::MakeShape(U32, {123, 456}, {true, true})});
2811   ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2812       << "expected: " << ShapeUtil::HumanString(expected)
2813       << "actual:   " << ShapeUtil::HumanString(actual);
2814 }
2815 
2816 TEST_F(HloParserTest, NegativeParameterNumber) {
2817   const string hlo_string = "par0 = f32[3,5] parameter(-1)";
2818   auto result = ParseHloString(hlo_string);
2819   ASSERT_FALSE(result.status().ok());
2820   EXPECT_THAT(result.status().error_message(),
2821               ::testing::HasSubstr("parameter number must be >= 0"));
2822 }
2823 
2824 TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) {
2825   const string hlo_string =
2826       "par0 = (f32[3,5], f32[]) parameter(0), "
2827       "parameter_replication={true,false,true}";
2828   auto result = ParseHloString(hlo_string);
2829   ASSERT_FALSE(result.status().ok());
2830   EXPECT_THAT(result.status().error_message(),
2831               ::testing::HasSubstr("parameter has 2 leaf buffers, but "
2832                                    "parameter_replication has 3 elements"));
2833 }
2834 
2835 }  // namespace
2836 }  // namespace xla
2837