1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_parser.h"
17
18 #include <string>
19 #include "absl/strings/match.h"
20 #include "absl/strings/str_cat.h"
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
23 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
24 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
25 #include "tensorflow/compiler/xla/service/pattern_matcher_gmock.h"
26 #include "tensorflow/compiler/xla/window_util.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/test.h"
29
30 namespace xla {
31 namespace {
32
33 namespace m = ::xla::match;
34 using absl::string_view;
35
// A single parameterized parser test case.
//   test_name:     human-readable identifier; TestDataToString returns it as
//                  the parameter name for the generated test.
//   module_string: HLO module text that the test parses.
struct TestData {
  std::string test_name;
  std::string module_string;
};
40
TestDataToString(const::testing::TestParamInfo<TestData> & data)41 string TestDataToString(const ::testing::TestParamInfo<TestData>& data) {
42 return data.param.test_name;
43 }
44
// For each string below, we check that:
//  - it parses into an HloModule successfully, and
//  - the stringification of the resulting HloModule equals the original
//    string.
std::vector<TestData> CreateTestCases() {
  // clang-format off
  return std::vector<TestData>({
// ax + y
{
"AxpyParam",
R"(HloModule axpy_module

ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
  %alpha = f32[] parameter(0)
  %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
  %x = f32[2,4]{1,0} parameter(1)
  %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
  %y = f32[2,4]{1,0} parameter(2)
  ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
}

)"
},
// parameter replication (scalar and nested-tuple parameters)
{
"ParamReplication",
R"(HloModule param_replication_module

ENTRY %param_replication (a: f32[], b: (f32[2,4], (f32[2,4]))) -> (f32[], (f32[2,4], (f32[2,4]))) {
  %a = f32[] parameter(0), parameter_replication={true}
  %b = (f32[2,4]{1,0}, (f32[2,4]{1,0})) parameter(1), parameter_replication={false,true}
  ROOT %tuple = (f32[], (f32[2,4]{1,0}, (f32[2,4]{1,0}))) tuple(f32[] %a, (f32[2,4]{1,0}, (f32[2,4]{1,0})) %b)
}

)"
},
// pred constant, with escaped characters in metadata and backend_config
{
"ConstantPred",
R"(HloModule constant_pred_module

ENTRY %constant_pred () -> pred[] {
  ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68}, backend_config="foo\" bar"
}

)"
},
// pred array constant
{
"ConstantPredArray",
R"(HloModule module

ENTRY %constant_pred_array () -> pred[2,3] {
  ROOT %constant = pred[2,3]{1,0} constant({ { 0, 1, 0 }, { 1, 0, 1 } })
}

)"
},

// s32 constant
{
"ConstantS32",
R"(HloModule constant_s32_module

ENTRY %constant_s32 () -> s32[] {
  ROOT %constant = s32[] constant(-42)
}

)"
},
// f32 constant whose value is written without a decimal point, plus a
// backend configuration string
{
"ConstantF32",
R"(HloModule ConstantF32_module

ENTRY %ConstantF32.v4 () -> f32[] {
  ROOT %constant = f32[] constant(42), backend_config="this is a configuration"
}

)"
},
// f32 constant, rank 1 empty array.
{
"ConstantF32R1Empty",
R"(HloModule ConstantF32Empty_module

ENTRY %ConstantF32Empty.v4 () -> f32[0] {
  ROOT %constant = f32[0]{0} constant({})
}

)"
},
// f32 constant, rank 4 empty array.
{
"ConstantF32R4Empty",
R"(HloModule ConstantF32R4Empty_module

ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] {
  ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant({ { /*i0=0*/ }, { /*i0=1*/ } })
}

)"
},
// constant 4D
{
"Constant4D",
R"(HloModule Small_3x2x1x1_module

ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] {
  ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
}

)"
},
// non-finite constants: nan, inf, -inf
{
"ConstantNonFinite",
R"(HloModule IsFiniteR1F32s_module

ENTRY %IsFiniteR1F32s.v2 () -> pred[6] {
  %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf})
  ROOT %is-finite = pred[6]{0} is-finite(f32[6]{0} %constant)
}

)"
},
// constant f16
{
"ConstantF16",
R"(HloModule ConstantF16_module

ENTRY %ConstantF16.v4 () -> f16[] {
  ROOT %constant = f16[] constant(500)
}

)"
},
// bf16
{
"BF16",
R"(HloModule BF16

ENTRY %BF16.v4 () -> bf16[] {
  ROOT %constant = bf16[] constant(500)
}

)"
},
// constant + constant
{
"AddConstants",
R"(HloModule add_constants_module

ENTRY %add_constants () -> f32[] {
  %constant = f32[] constant(3.14)
  ROOT %add = f32[] add(f32[] %constant, f32[] %constant)
}

)"
},
// tuple constant
{
"TupleConstant",
R"(HloModule TupleConstant_module

ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) {
  ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant(( { {1}, {2} }, {2, 42} ))
}

)"
},
// v1 > v2 ? v1 : v2
{
"SelectR1F32",
R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module

ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] {
  %v1 = f32[4]{0} parameter(0), sharding={maximal device=1}
  %v2 = f32[4]{0} parameter(1), sharding={maximal device=1}
  %greater-than = pred[4]{0} compare(f32[4]{0} %v1, f32[4]{0} %v2), direction=GT, sharding={replicated}
  ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2), sharding={}
}

)"
},
// empty tuple
{
"EmptyTupleCreate",
R"(HloModule EmptyTupleCreate_module

ENTRY %EmptyTupleCreate.v1 () -> () {
  ROOT %tuple = () tuple()
}

)"
},
// tuple
{
"TupleCreate",
R"(HloModule TupleCreate_module

ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
  %v1 = f32[] parameter(0)
  %v2 = f32[3]{0} parameter(1)
  %v3 = f32[2,3]{1,0} parameter(2)
  ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3)
}

)"
},
// tuple with a separate sharding for each element
{
"ShardedTupleCreate",
R"(HloModule ShardedTupleCreate_module

ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) {
  %v1 = f32[] parameter(0)
  %v2 = f32[3]{0} parameter(1)
  %v3 = f32[2,3]{1,0} parameter(2)
  ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3), sharding={{replicated}, {maximal device=0}, {replicated}}
}

)"
},
// domain instruction with sharding entry/exit metadata
{
"DomainParsing",
R"(HloModule DomainParsing_module

ENTRY %DomainParsing (v1: f32[]) -> f32[] {
  %v1 = f32[] parameter(0)
  ROOT %dom = f32[] domain(f32[] %v1), domain={kind="sharding", entry={maximal device=0}, exit={maximal device=1}}
}

)"
},
// int32 result = 0;
// while (result < 5) { result = result + 1; }
{
"WhileWithScalarS32Result",
R"(HloModule WhileWithScalarS32Result_module

%body.v3 (prev.1: s32[]) -> s32[] {
  %constant = s32[] constant(1)
  %prev.1 = s32[] parameter(0)
  ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1)
}

%condition.v3 (prev.2: s32[]) -> pred[] {
  %constant.1 = s32[] constant(5)
  %prev.2 = s32[] parameter(0)
  ROOT %greater-than = pred[] compare(s32[] %constant.1, s32[] %prev.2), direction=GT
}

ENTRY %WhileWithScalarS32Result.v2 () -> s32[] {
  %constant.2 = s32[] constant(0)
  ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3
}

)"
},
// send and recv, including a control-predecessors edge
{
"SendRecv",
R"(HloModule TwoSendRecvBothWayRecvFist_module

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
  %token0 = token[] after-all()
  %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, sharding={maximal device=1}
  ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, sharding={maximal device=1}
  %constant = f32[] constant(2.1), sharding={maximal device=0}
  %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv}
  %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, sharding={maximal device=0}
}

)"
},
// send and recv marked is_host_transfer=true
{
"SendRecvWithHostTransfer",
R"(HloModule HostTransferSendRecv_module

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> (f32[], token[]) {
  %token0 = token[] after-all()
  %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15, is_host_transfer=true
  ROOT %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15, is_host_transfer=true
  %constant = f32[] constant(2.1), sharding={maximal device=0}
  %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, is_host_transfer=true
  %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16, is_host_transfer=true
}

)"
},
// get-tuple-element
{
"GetTupleElement",
R"(HloModule GetTupleElement_module

ENTRY %GetTupleElement.v4 () -> s32[2,3] {
  %constant = f32[3]{0} constant({1, 2, 3})
  %constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
  %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1)
  ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0}
}

)"
},
// call
{
"Call",
R"(HloModule CallR0F32IdentityScalar_module

%Identity.v1 (x: f32[]) -> f32[] {
  ROOT %x = f32[] parameter(0)
}

ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] {
  %constant = f32[] constant(42)
  ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1
}

)"
},
// reduce window
{
"ReduceWindow",
R"(HloModule R4UnitWindow_module

%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}

ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] {
  %operand = f32[13,12,8,15]{0,3,2,1} parameter(0)
  %constant = f32[] constant(0)
  ROOT %reduce-window = f32[13,3,8,15]{0,3,2,1} reduce-window(f32[13,12,8,15]{0,3,2,1} %operand, f32[] %constant), window={size=1x1x7x1 stride=1x4x1x1 pad=0_0x0_0x3_3x0_0}, to_apply=%add_F32.v3
}

)"
},
// reduce window on scalar (no window attribute)
{
"ReduceWindowScalar",
R"(HloModule reduce_window_scalar

%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}

ENTRY %R4UnitWindowScalar () -> f32[] {
  %constant = f32[] constant(42)
  %constant.1 = f32[] constant(1)
  ROOT %reduce-window = f32[] reduce-window(f32[] %constant, f32[] %constant.1), to_apply=%add_F32.v3
}

)"
},
// convolution, with operand_precision
{
"Convolution",
R"(HloModule Convolve1D1Window_0_module

ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
  %input = f32[1,2,1]{2,1,0} parameter(0)
  %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
  %filter = f32[1,1,1]{2,1,0} parameter(1)
  ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f, operand_precision={high,default}
}

)"
},
// convolution rank 2 (no window attribute)
{
"ConvolutionR2",
R"(HloModule ConvolveR2_module

ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] {
  %input = f32[1,2]{1,0} parameter(0)
  %filter = f32[1,1]{1,0} parameter(1)
  ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf
}

)"
},
// convolution backward (lhs dilation and rhs reversal in the window)
{
"ConvolutionBackward",
R"(HloModule ConvolveBackward_module

ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] {
  %input = f32[128,7,7,512]{0,3,2,1} parameter(0)
  %filter = f32[3,3,512,512]{3,2,1,0} parameter(1)
  ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f
}

)"
},
// reverse(constant)
{
"Reverse4D",
R"(HloModule Reverse4DFloatArrayOnDim01_module

ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] {
  %constant = f32[4,3,2,1]{0,1,2,3} constant({ { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } })
  ROOT %reverse = f32[4,3,2,1]{0,1,2,3} reverse(f32[4,3,2,1]{0,1,2,3} %constant), dimensions={0,1}
}

)"
},
// concat
{
"Concat",
R"(HloModule Concat2x3With2x5_module

ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] {
  %constant = f32[2,3]{1,0} constant({ { 0, 1, 2 }, { 1000, 1001, 1002 } })
  %constant.1 = f32[2,5]{1,0} constant({ { 64, 65, 66, 67, 68 }, { 1064, 1065, 1066, 1067, 1068 } })
  ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1}
}

)"
},
// select and scatter
{
"SelectAndScatter",
R"(HloModule R4F32OverlapSmall_module

%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
}

%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
  %lhs.1 = f32[] parameter(0)
  %rhs.1 = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
}

ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] {
  %constant = f32[4,5,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {7} }, { /*i1=1*/ {2} }, { /*i1=2*/ {5} }, { /*i1=3*/ {3} }, { /*i1=4*/ {8} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {8} }, { /*i1=2*/ {9} }, { /*i1=3*/ {3} }, { /*i1=4*/ {4} } }, { /*i0=2*/ { /*i1=0*/ {1} }, { /*i1=1*/ {5} }, { /*i1=2*/ {7} }, { /*i1=3*/ {5} }, { /*i1=4*/ {6} } }, { /*i0=3*/ { /*i1=0*/ {0} }, { /*i1=1*/ {6} }, { /*i1=2*/ {2} }, { /*i1=3*/ {10} }, { /*i1=4*/ {2} } } })
  %constant.1 = f32[2,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {2} }, { /*i1=1*/ {6} } }, { /*i0=1*/ { /*i1=0*/ {3} }, { /*i1=1*/ {1} } } })
  %constant.2 = f32[] constant(0)
  ROOT %select-and-scatter = f32[4,5,1,1]{3,2,1,0} select-and-scatter(f32[4,5,1,1]{3,2,1,0} %constant, f32[2,2,1,1]{3,2,1,0} %constant.1, f32[] %constant.2), window={size=2x3x1x1 stride=2x2x1x1}, select=%ge_F32.v3, scatter=%add_F32.v3
}

)"
},
// select and scatter on scalar (no window attribute)
{
"SelectAndScatterScalar",
R"(HloModule select_and_scatter_scalar

%ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %greater-than-or-equal-to = pred[] compare(f32[] %lhs, f32[] %rhs), direction=GE
}

%add_F32.v3 (lhs.1: f32[], rhs.1: f32[]) -> f32[] {
  %lhs.1 = f32[] parameter(0)
  %rhs.1 = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs.1, f32[] %rhs.1)
}

ENTRY %SelectAndScatterScalar () -> f32[] {
  %constant = f32[] constant(42)
  %constant.1 = f32[] constant(1)
  %constant.2 = f32[] constant(2)
  ROOT %select-and-scatter = f32[] select-and-scatter(f32[] %constant, f32[] %constant.1, f32[] %constant.2), select=%ge_F32.v3, scatter=%add_F32.v3
}

)"
},
// slice
{
"Slice",
R"(HloModule slice_module

ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
  %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
  ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3:1], [0:3:1], [0:4:2], [0:4:1]}
}

)"
},
// slice, no stride
{
"SliceNoStride",
R"(HloModule Slice3x3x3_To_1x3x3_F32_module

ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] {
  %constant = f32[3,3,3]{2,1,0} constant({ { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } })
  ROOT %slice = f32[1,3,3]{2,1,0} slice(f32[3,3,3]{2,1,0} %constant), slice={[0:1], [0:3], [0:3]}
}

)"
},
// slice R0
{
"SliceR0",
R"(HloModule SliceR0_module

ENTRY %SliceR0.v2 () -> s32[] {
  %constant = s32[] constant(1)
  ROOT %slice = s32[] slice(s32[] %constant), slice={}
}

)"
},
// transpose
{
"Transpose",
R"(HloModule Transpose_module

ENTRY %Transpose.v2 () -> s32[1,2,3] {
  %constant = s32[1,2,3]{2,1,0} constant({ { { 1, 2, 3 }, { 4, 5, 6 } } })
  ROOT %transpose = s32[1,2,3]{2,1,0} transpose(s32[1,2,3]{2,1,0} %constant), dimensions={0,1,2}
}

)"
},
// transpose of a c128 parameter
{
"TransposeC128",
R"(HloModule TransposeC128_module

ENTRY %Transpose.v3 (input: c128[1,2,3]) -> c128[1,2,3] {
  %input = c128[1,2,3]{2,1,0} parameter(0)
  ROOT %transpose = c128[1,2,3]{2,1,0} transpose(c128[1,2,3]{2,1,0} %input), dimensions={0,1,2}
}

)"
},
// Triangular solve
{
"TriangularSolve",
R"(HloModule TriangularSolve_module

ENTRY %SimpleRightLowerNotranspose.4 (a.1: f32[4,4], b.2: f32[3,4]) -> f32[3,4] {
  %a.1 = f32[4,4]{1,0} parameter(0)
  %b.2 = f32[3,4]{1,0} parameter(1)
  ROOT %triangular-solve.3 = f32[3,4]{1,0} triangular-solve(f32[4,4]{1,0} %a.1, f32[3,4]{1,0} %b.2), lower=true, transpose_a=NO_TRANSPOSE
}

)"
},
// Dynamic slice with start indices packed into one s32[3] operand
{
"DynamicSlice",
R"(HloModule DynamicSlice_module

ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] {
  %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
  %constant = s32[1]{0} constant({0})
  %start_index = s32[1]{0} parameter(1)
  %concatenate = s32[3]{0} concatenate(s32[1]{0} %constant, s32[1]{0} %constant, s32[1]{0} %start_index), dimensions={0}
  ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[3]{0} %concatenate), dynamic_slice_sizes={2,2,258}
}

)"
},
// Dynamic slice with scalar indices
{
"DynamicSliceScalarIndices",
R"(HloModule DynamicSlice_module

ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[]) -> s32[2,2,258] {
  %original_parameter = s32[2,2,258]{2,1,0} parameter(0)
  %constant = s32[] constant(0)
  %start_index = s32[] parameter(1)
  ROOT %dynamic-slice = s32[2,2,258]{2,1,0} dynamic-slice(s32[2,2,258]{2,1,0} %original_parameter, s32[] %constant, s32[] %constant, s32[] %start_index), dynamic_slice_sizes={2,2,258}
}

)"
},
// Dynamic update slice with start indices packed into one s32[4] operand
{
"DynamicUpdateSlice",
R"(HloModule DynamicSlice_module

ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] {
  %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
  %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
  %start_indices = s32[4]{0} parameter(2)
  ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[4]{0} %start_indices)
}

)"
},
// Dynamic update slice with scalar indices
{
"DynamicUpdateSliceScalarIndex",
R"(HloModule DynamicUpdateSlice_module

ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_index.0: s32[], start_index.1: s32[], start_index.2: s32[], start_index.3: s32[]) -> s32[1,1,25,1] {
  %input = s32[1,1,25,1]{3,2,1,0} parameter(0)
  %update = s32[1,1,2,1]{3,2,1,0} parameter(1)
  %start_index.0 = s32[] parameter(2)
  %start_index.1 = s32[] parameter(3)
  %start_index.2 = s32[] parameter(4)
  %start_index.3 = s32[] parameter(5)
  ROOT %dynamic-update-slice = s32[1,1,25,1]{3,2,1,0} dynamic-update-slice(s32[1,1,25,1]{3,2,1,0} %input, s32[1,1,2,1]{3,2,1,0} %update, s32[] %start_index.0, s32[] %start_index.1, s32[] %start_index.2, s32[] %start_index.3)
}

)"
},
// batch norm training
{
"BatchNormTraining",
R"(HloModule BasicTraining_module

ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) {
  %constant = f32[2,2,1,2]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ { 1, 2 } }, { /*i1=1*/ { 3, 4 } } }, { /*i0=1*/ { /*i1=0*/ { 5, 6 } }, { /*i1=1*/ { 7, 8 } } } })
  %constant.1 = f32[2]{0} constant({2, 3})
  %constant.2 = f32[2]{0} constant({1, 2})
  ROOT %batch-norm-training = (f32[2,2,1,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-training(f32[2,2,1,2]{3,2,1,0} %constant, f32[2]{0} %constant.1, f32[2]{0} %constant.2), epsilon=0.001, feature_index=3
}

)"
},
// batch norm inference
{
"BatchNormInference",
R"(HloModule BatchNormInference_module

ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] {
  %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
  %offset = f32[2]{0} parameter(1)
  %scale = f32[2]{0} parameter(2)
  %mean = f32[2]{0} parameter(3)
  %variance = f32[2]{0} parameter(4)
  ROOT %batch-norm-inference = f32[2,2,2,2]{3,2,1,0} batch-norm-inference(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %offset, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance), epsilon=0.001, feature_index=0
}

)"
},
// batch norm grad
{
"BatchNormGrad",
R"(HloModule BatchNormGrad_module

ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) {
  %input = f32[2,2,2,2]{3,2,1,0} parameter(0)
  %scale = f32[2]{0} parameter(1)
  %mean = f32[2]{0} parameter(2)
  %variance = f32[2]{0} parameter(3)
  %grad_output = f32[2,2,2,2]{3,2,1,0} parameter(4)
  ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0
}

)"
},
// fft (1D)
{
"Fft",
R"(HloModule Fft_module

ENTRY %Fft (input: c64[8,32]) -> c64[8,32] {
  %input = c64[8,32]{1,0} parameter(0)
  ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32}
}

)"
},
// ifft (2D)
{
"Ifft2d",
R"(HloModule Ifft2d_module

ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] {
  %input = c64[5,8,32]{2,1,0} parameter(0)
  ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32}
}

)"
},
// rfft2d: real f32 input, complex c64 output
{
"Rfft2d",
R"(HloModule Rfft2d_module

ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] {
  %input = f32[5,64,32]{2,1,0} parameter(0)
  ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32}
}

)"
},
// irfft3d: complex c64 input, real f32 output
{
"Irfft3d",
R"(HloModule Irfft3d_module

ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] {
  %input = c64[5,64,128,33]{3,2,1,0} parameter(0)
  ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64}
}

)"
},
// pad
{
"Pad",
R"(HloModule Pad1DS3Array_module

ENTRY %Pad1DS3Array.v3 () -> f32[8] {
  %constant = f32[3]{0} constant({1, 2, 3})
  %constant.1 = f32[] constant(0.1)
  ROOT %pad = f32[8]{0} pad(f32[3]{0} %constant, f32[] %constant.1), padding=3_1
}

)"
},
// pad has interior
{
"PadHasInterior",
R"(HloModule PadHasInterior_module

ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] {
  %input = f32[1,25,7,7]{3,2,1,0} parameter(0)
  %constant = f32[] constant(-5.123)
  ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_0_0x2_2_1x2_2_0
}

)"
},
// pad with negative low/high padding values in several dimensions
{
"PadHasNegativePadding",
R"(HloModule PadHasNegativePadding_module

ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,29] {
  %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0)
  %constant = f32[] constant(-5.123)
  ROOT %pad = f32[1,15,6,3,29]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3
}

)"
},
// fusion
{
"Fusion",
R"(HloModule fusion_module

%fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] {
  %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
  %constant.1.param_1 = f32[2]{0} parameter(1)
  %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %constant.1.param_1), dimensions={1}
  ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %constant.param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
}

ENTRY %fusion.v3 () -> f32[3,2,1,1] {
  %constant = f32[3,2,1,1]{3,2,1,0} constant({ { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } })
  %constant.1 = f32[2]{0} constant({3.14, 4.25})
  ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation
}

)"
},
// f32 constant with a sparse{10} layout, given as [index]: value pairs
{
"Sparse",
R"(HloModule sparse_f32

ENTRY %sparse () -> f32[2,3,4] {
  ROOT %foo = f32[2,3,4]sparse{10} constant({[0, 1, 2]: 1, [1, 2, 2]: 2, [1, 2, 3]: 3})
}

)"
},
// c128 constant with a sparse{10} layout; values are (real, imag) pairs
{
"SparseC128",
R"(HloModule sparse_c128

ENTRY %sparse () -> c128[2,3,4] {
  ROOT %foo = c128[2,3,4]sparse{10} constant({[0, 1, 2]: (1, 0), [1, 2, 2]: (2, 5), [1, 2, 3]: (3, 10)})
}

)"
},
// sparse-layout constant with no stored elements
{
"SparseEmpty",
R"(HloModule sparse_f32_empty

ENTRY %sparse_f32_empty () -> f32[2,3,4] {
  ROOT %foo = f32[2,3,4]sparse{10} constant({})
}

)"
},
// rank-1 sparse-layout constant; indices are written without brackets
{
"SparseR1",
R"(HloModule sparse_f32_r1

ENTRY %sparse_f32_r1 () -> f32[9] {
  ROOT %foo = f32[9]sparse{10} constant({1: 2, 3: 4, 5: 6})
}

)"
},
// gather with all of its dimension-number attributes
{
"gather",
R"(HloModule StringifyGather

ENTRY %Gather (input_tensor: f32[50,49,48,47,46], start_indices: s64[10,9,8,7,5]) -> f32[10,9,8,7,30,29,28,27,26] {
  %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
  %start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
  ROOT %gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
}

)"
},
// scatter with all of its dimension-number attributes and an add combiner
{
"scatter",
R"(HloModule StringifyScatter

%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] {
  %lhs = f32[] parameter(0)
  %rhs = f32[] parameter(1)
  ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
}

ENTRY %Scatter (input_tensor: f32[50,49,48,47,46], scatter_indices: s64[10,9,8,7,5], updates: f32[10,9,8,7,30,29,28,27,26]) -> f32[50,49,48,47,46] {
  %input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
  %scatter_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
  %updates = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} parameter(2)
  ROOT %scatter = f32[50,49,48,47,46]{4,3,2,1,0} scatter(f32[50,49,48,47,46]{4,3,2,1,0} %input_tensor, s64[10,9,8,7,5]{4,3,2,1,0} %scatter_indices, f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} %updates), update_window_dims={4,5,6,7,8}, inserted_window_dims={}, scatter_dims_to_operand_dims={0,1,2,3,4}, index_vector_dim=4, to_apply=%add_F32.v3
}

)"
},
// u64 constant 1; per the test name, parsing it must not underflow
{
"ConstantUnsignedNoUnderflow",
R"(HloModule ConstantUnsignedNoUnderflow_module

ENTRY %ConstantUnsignedNoUnderflow () -> u64[] {
  ROOT %constant = u64[] constant(1)
}

)"
},

// u64 constant equal to 2^63 - 1 (the largest int64 value); per the test
// name, parsing it must not overflow
{
"ConstantUnsignedNoOverflow",
R"(HloModule ConstantUnsignedNoOverflow_module

ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
  ROOT %constant = u64[] constant(9223372036854775807)
}

)"
},
// custom-call with per-operand layout constraints
{
"CustomCallWithLayoutConstraints",
R"(HloModule CustomCallWithLayoutConstraints

ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
  %p0 = f32[42,2,3]{0,1,2} parameter(0)
  %p1 = f32[123,4]{0,1} parameter(1)
  ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}}
}

)"
},
// custom-call with no operands and an empty operand_layout_constraints list
{
"CustomCallWithLayoutConstraintsNoOperands",
R"(HloModule CustomCallWithLayoutConstraintsNoOperands

ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] {
  ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
}

)"
},
// custom-call whose operand layout constraints include a tuple shape
{
"CustomCallWithLayoutConstraintsTupleShapes",
R"(HloModule CustomCallWithLayoutConstraintsTupleShapes

ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
  %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
  %p1 = f32[123,4]{0,1} parameter(1)
  ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}}
}

)"
},
// Parse c64 literal, including non-finite components
{
"ParseC64Literal",
R"(HloModule ParseC64Literal

ENTRY %ParseC64Literal () -> c64[2] {
  ROOT %c = c64[2]{0} constant({(1, 2), (-inf, nan)})
}

)"
},
// Parse c128 literal, including non-finite components
{
"ParseC128Literal",
R"(HloModule ParseC128Literal

ENTRY %ParseC128Literal () -> c128[2] {
  ROOT %c = c128[2]{0} constant({(1, 2), (-inf, nan)})
}

)"
},
// Indexed conditional: an s32 index operand followed by one operand per
// branch, with three branch_computations
{
"IndexedConditional",
R"(HloModule indexed_conditional

%Negate (x: f32[]) -> f32[] {
  %x = f32[] parameter(0)
  ROOT %negate = f32[] negate(f32[] %x)
}

%Identity (y: f32[]) -> f32[] {
  %y = f32[] parameter(0)
  ROOT %copy = f32[] copy(f32[] %y)
}

%Floor (z: f32[]) -> f32[] {
  %z = f32[] parameter(0)
  ROOT %floor = f32[] floor(f32[] %z)
}

ENTRY %Parameters1.v4 () -> f32[] {
  %constant = s32[] constant(1)
  %constant.1 = f32[] constant(56)
  %constant.2 = f32[] constant(12)
  %constant.3 = f32[] constant(13)
  ROOT %conditional = f32[] conditional(s32[] %constant, f32[] %constant.1, f32[] %constant.2, f32[] %constant.3), branch_computations={%Negate, %Identity, %Floor}
}

)"
},
  });
  // clang-format on
}
990
991 std::vector<TestData> CreateShortTestCases() {
992 // clang-format off
993 return std::vector<TestData>({
994 // map
995 {
996 "Map",
997 R"(HloModule MapBinaryAdder_module
998
999 add_F32.v3 {
1000 lhs = f32[] parameter(0)
1001 rhs = f32[] parameter(1)
1002 ROOT add = f32[] add(lhs, rhs)
1003 }
1004
1005 ENTRY MapBinaryAdder.v3 {
1006 param0 = f32[4]{0} parameter(0)
1007 param1 = f32[4]{0} parameter(1)
1008 ROOT map = f32[4]{0} map(param0, param1), dimensions={0}, to_apply=add_F32.v3
1009 }
1010
1011 )"
1012 },
1013 // reduce
1014 {
1015 "Reduce",
1016 R"(HloModule ReduceR3ToR2_module
1017
1018 add_F32.v3 {
1019 lhs = f32[] parameter(0)
1020 rhs = f32[] parameter(1)
1021 ROOT add = f32[] add(lhs, rhs)
1022 }
1023
1024 ENTRY ReduceR3ToR2.v3 {
1025 input = f32[8,16,256]{2,1,0} parameter(0)
1026 constant = f32[] constant(0)
1027 ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
1028 }
1029
1030 )"
1031 },
1032 // tuple reduce
1033 {
1034 "TupleReduce",
1035 R"(HloModule TupleReduce
1036
1037 max_argmax {
1038 value = f32[] parameter(2)
1039 prev_max = f32[] parameter(0)
1040 is_next_larger = pred[] compare(value, prev_max), direction=GE
1041 max = f32[] select(is_next_larger, value, prev_max)
1042 index = s32[] parameter(3)
1043 prev_argmax = s32[] parameter(1)
1044 argmax = s32[] select(is_next_larger, index, prev_argmax)
1045 ROOT pair = (f32[], s32[]) tuple(max, argmax)
1046 }
1047
1048 ENTRY reduce_entry {
1049 values = f32[1024]{0} parameter(0)
1050 indices = f32[1024]{0} parameter(1)
1051 init_value = f32[] constant(-inf)
1052 init_index = s32[] constant(-1)
1053 ROOT result = (f32[], s32[]) reduce(values, indices, init_value, init_index), dimensions={0}, to_apply=max_argmax
1054 }
1055
1056 )"
1057 },
1058 // infeed/outfeed
1059 {
1060 "InfeedOutfeed",
1061 R"(HloModule outfeed_module
1062
1063 ENTRY InfeedToOutfeed {
1064 token0 = token[] after-all()
1065 infeed = ((u32[3]{0}, pred[]), token[]) infeed(token0)
1066 infeed.data = (u32[3]{0}, pred[]) get-tuple-element(infeed), index=0
1067 outfeed = token[] outfeed(infeed.data, token0)
1068 ROOT infeed.1 = ((u32[3]{0}, pred[]), token[]) infeed(token0)
1069 infeed.1.data = (u32[3]{0}, pred[]) get-tuple-element(infeed.1), index=0
1070 infeed.1.token = token[] get-tuple-element(infeed.1), index=1
1071 outfeed.1 = token[] outfeed(infeed.1.data, infeed.1.token)
1072 }
1073
1074 )"
1075 },
1076 // Rng
1077 {
1078 "Rng",
1079 R"(HloModule rng_module
1080
1081 ENTRY Rng {
1082 constant = f32[] constant(0)
1083 constant.1 = f32[] constant(1)
1084 ROOT rng = f32[8]{0} rng(constant, constant.1), distribution=rng_uniform
1085 }
1086
1087 )"
1088 },
1089 // Reduce precision
1090 {
1091 "ReducePrevison",
1092 R"(HloModule reduce_precision
1093
1094 ENTRY ReducePrecision {
1095 constant = f32[1]{0} constant({3.14159})
1096 ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10
1097 }
1098
1099 )"
1100 },
1101 // Sort (Key)
1102 {
1103 "SortKey",
1104 R"(HloModule sort
1105
1106 compare {
1107 p.0.lhs = f32[] parameter(0)
1108 p.0.rhs = f32[] parameter(1)
1109 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1110 }
1111
1112 ENTRY Sort {
1113 x = f32[1024]{0} parameter(0)
1114 ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, to_apply=compare
1115 }
1116
1117 )"
1118 },
1119 // Sort (Key, Value)
1120 {
1121 "SortKeyValue",
1122 R"(HloModule sort
1123
1124 compare {
1125 p.1.lhs = s32[] parameter(2)
1126 p.1.rhs = s32[] parameter(3)
1127 p.0.lhs = f32[] parameter(0)
1128 p.0.rhs = f32[] parameter(1)
1129 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1130 }
1131
1132 ENTRY Sort {
1133 keys = f32[1024]{0} parameter(0)
1134 values = s32[1024]{0} parameter(1)
1135 ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
1136 }
1137
1138 )"
1139 },
1140 // R2 Sort (Key)
1141 {
1142 "SortKeyR2",
1143 R"(HloModule sort
1144
1145 compare {
1146 p.0.lhs = f32[] parameter(0)
1147 p.0.rhs = f32[] parameter(1)
1148 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1149 }
1150
1151 ENTRY Sort {
1152 x = f32[1024,16]{0,1} parameter(0)
1153 ROOT sorted = f32[1024,16]{0,1} sort(x), dimensions={0}, to_apply=compare
1154 }
1155
1156 )"
1157 },
1158 // R2 Sort (Key, Value)
1159 {
1160 "SortKeyValueR2",
1161 R"(HloModule sort
1162
1163 compare {
1164 p.1.lhs = s32[] parameter(2)
1165 p.1.rhs = s32[] parameter(3)
1166 p.0.lhs = f32[] parameter(0)
1167 p.0.rhs = f32[] parameter(1)
1168 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1169 }
1170
1171 ENTRY Sort {
1172 keys = f32[1024,16]{0,1} parameter(0)
1173 values = s32[1024,16]{0,1} parameter(1)
1174 ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}) sort(keys, values), dimensions={0}, to_apply=compare
1175 }
1176
1177 )"
1178 },
1179 // Sort (Key, Value, Value, Value)
1180 {
1181 "SortManyValues",
1182 R"(HloModule sort
1183
1184 compare {
1185 p.1.lhs = s32[] parameter(2)
1186 p.1.rhs = s32[] parameter(3)
1187 p.2.lhs = u32[] parameter(4)
1188 p.2.rhs = u32[] parameter(5)
1189 p.3.lhs = f32[] parameter(6)
1190 p.3.rhs = f32[] parameter(7)
1191 p.0.lhs = f32[] parameter(0)
1192 p.0.rhs = f32[] parameter(1)
1193 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1194 }
1195
1196 ENTRY Sort {
1197 keys = f32[1024,16]{0,1} parameter(0)
1198 values.0 = s32[1024,16]{0,1} parameter(1)
1199 values.1 = u32[1024,16]{0,1} parameter(2)
1200 values.2 = f32[1024,16]{0,1} parameter(3)
1201 ROOT sorted = (f32[1024,16]{0,1}, s32[1024,16]{0,1}, u32[1024,16]{0,1}, f32[1024,16]{0,1}) sort(keys, values.0, values.1, values.2), dimensions={0}, to_apply=compare
1202 }
1203
1204 )"
1205 },
1206 // Sort (Key) is_stable=true
1207 {
1208 "SortKeyStable",
1209 R"(HloModule sort
1210
1211 compare {
1212 p.0.lhs = f32[] parameter(0)
1213 p.0.rhs = f32[] parameter(1)
1214 ROOT lt = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1215 }
1216
1217 ENTRY Sort {
1218 x = f32[1024]{0} parameter(0)
1219 ROOT sorted = f32[1024]{0} sort(x), dimensions={0}, is_stable=true, to_apply=compare
1220 }
1221
1222 )"
1223 },
1224 // Indexed Conditional
1225 {
1226 "IndexedConditional",
1227 R"(HloModule indexed_conditional
1228
1229 Negate {
1230 x = f32[] parameter(0)
1231 ROOT negate = f32[] negate(x)
1232 }
1233
1234 Identity {
1235 y = f32[] parameter(0)
1236 ROOT copy = f32[] copy(y)
1237 }
1238
1239 Floor {
1240 z = f32[] parameter(0)
1241 ROOT floor = f32[] floor(z)
1242 }
1243
1244 ENTRY Parameters1.v4 {
1245 constant = s32[] constant(1)
1246 constant.1 = f32[] constant(56)
1247 constant.2 = f32[] constant(12)
1248 constant.3 = f32[] constant(13)
1249 ROOT conditional = f32[] conditional(constant, constant.1, constant.2, constant.3), branch_computations={Negate, Identity, Floor}
1250 }
1251
1252 )"
1253 },
1254 // Predicated Conditional
1255 {
1256 "PredicatedConditional",
1257 R"(HloModule pred_conditional
1258
1259 Negate {
1260 x = f32[] parameter(0)
1261 ROOT negate = f32[] negate(x)
1262 }
1263
1264 Identity {
1265 y = f32[] parameter(0)
1266 ROOT copy = f32[] copy(y)
1267 }
1268
1269 ENTRY Parameters1.v4 {
1270 constant = pred[] constant(true)
1271 constant.1 = f32[] constant(56)
1272 constant.2 = f32[] constant(12)
1273 ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity
1274 }
1275
1276 )"
1277 },
1278 // CustomCall
1279 {
1280 "CustomCall",
1281 R"(HloModule custom_call
1282
1283 ENTRY CustomCall {
1284 constant = f32[1]{0} constant({12345})
1285 ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar"
1286 }
1287
1288 )"
1289 },
1290 // CustomCall with opaque value.
1291 {
1292 "CustomCallWithOpaque",
1293 R"(HloModule custom_call
1294
1295 ENTRY CustomCall {
1296 constant = f32[1]{0} constant({12345})
1297 ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", opaque="this string is opaque"
1298 }
1299
1300 )"
1301 },
1302 // Variables with non-default names
1303 {
1304 "NonDefaultNames",
1305 R"(HloModule add_constants_module
1306
1307 ENTRY add_constants {
1308 foo = f32[] constant(3.14)
1309 ROOT bar = f32[] add(foo, foo)
1310 }
1311
1312 )"
1313 },
1314 {
1315 "Dot",
1316 R"(HloModule dot
1317
1318 ENTRY dot {
1319 a = f32[2,10]{1,0} parameter(0)
1320 b = f32[10,3]{1,0} parameter(1)
1321 ROOT dot = f32[2,3]{1,0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={0}
1322 }
1323
1324 )"
1325 },
1326 {
1327 "gather",
1328 R"(HloModule gather
1329
1330 ENTRY Gather {
1331 input_tensor = f32[50,49,48,47,46]{4,3,2,1,0} parameter(0)
1332 start_indices = s64[10,9,8,7,5]{4,3,2,1,0} parameter(1)
1333 ROOT gather = f32[10,9,8,7,30,29,28,27,26]{8,7,6,5,4,3,2,1,0} gather(input_tensor, start_indices), offset_dims={4,5,6,7,8}, collapsed_slice_dims={}, start_index_map={0,1,2,3,4}, index_vector_dim=4, slice_sizes={30,29,28,27,26}
1334 }
1335
1336 )"
1337 },
1338 // all-reduce
1339 {
1340 "AllReduce",
1341 R"(HloModule CRS
1342
1343 add {
1344 lhs = f32[] parameter(0)
1345 rhs = f32[] parameter(1)
1346 ROOT add = f32[] add(lhs, rhs)
1347 }
1348
1349 ENTRY CRS {
1350 input = f32[8]{0} parameter(0)
1351 ROOT crs = f32[8]{0} all-reduce(input), replica_groups={}, to_apply=add
1352 }
1353
1354 )"
1355 },
1356 // all-reduce with subgroups
1357 {
1358 "AllReduceWithSubgroups",
1359 R"(HloModule CRS_Subgroups
1360
1361 add {
1362 lhs = f32[] parameter(0)
1363 rhs = f32[] parameter(1)
1364 ROOT add = f32[] add(lhs, rhs)
1365 }
1366
1367 ENTRY AllReduceWithSubgroups {
1368 input = f32[128,32]{0,1} parameter(0)
1369 ROOT all-reduce = f32[128,32]{0,1} all-reduce(input), replica_groups={{0,1},{2,3}}, barrier="abc", to_apply=add
1370 }
1371
1372 )"
1373 },
1374 // all-reduce with all-reduce-id
1375 {
1376 "AllReduceAllReduce",
1377 R"(HloModule CRS
1378
1379 add {
1380 lhs = f32[] parameter(0)
1381 rhs = f32[] parameter(1)
1382 ROOT add = f32[] add(lhs, rhs)
1383 }
1384
1385 ENTRY CRS {
1386 input = f32[8]{0} parameter(0)
1387 crs.1 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add
1388 ROOT crs.0 = f32[8]{0} all-reduce(input), replica_groups={{0}}, all_reduce_id=1, to_apply=add
1389 }
1390
1391 )"
1392 },
1393 // all-to-all
1394 {
1395 "AllToAll",
1396 R"(HloModule AllToAll
1397
1398 ENTRY AllToAll {
1399 input = f32[128,32]{0,1} parameter(0)
1400 ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={}
1401 }
1402
1403 )"
1404 },
1405 // all-to-all with subgroups
1406 {
1407 "AllToAllWithSubgroups",
1408 R"(HloModule AllToAllWithSubgroups
1409
1410 ENTRY AllToAllWithSubgroups {
1411 input = f32[128,32]{0,1} parameter(0)
1412 ROOT a2a = f32[128,32]{0,1} all-to-all(input), replica_groups={{1,2},{3,0}}
1413 }
1414
1415 )"
1416 },
1417 // collective-permute
1418 {
1419 "CollectivePermute",
1420 R"(HloModule CollectivePermute
1421
1422 ENTRY CollectivePermute {
1423 input = f32[128,32]{0,1} parameter(0)
1424 ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}}
1425 }
1426
1427 )"
1428 },
1429 // replica-id
1430 {
1431 "ReplicaId",
1432 R"(HloModule replica-id
1433
1434 ENTRY Replica-id {
1435 ROOT replica-id = u32[] replica-id()
1436 }
1437
1438 )"
1439 },
1440 // Iota
1441 {
1442 "Iota",
1443 R"(HloModule iota
1444
1445 ENTRY Iota {
1446 ROOT iota = f32[100]{0} iota(), iota_dimension=0
1447 }
1448
1449 )"
1450 },
1451 // custom-call with window, dim_labels and feature_group_count
1452 {
1453 "CustomCallWithWindowAndDimLabelsAndFeatureGroupCount",
1454 R"(HloModule CustomCallWithWindowAndDimLabelsAndFeatureGroupCount
1455
1456 ENTRY Computation {
1457 ROOT r = f32[100]{0} custom-call(), window={size=2x2}, dim_labels=b01f_01io->b01f, feature_group_count=2, custom_call_target="target"
1458 }
1459
1460 )"
1461 },
1462 // is_scheduled=true attribute
1463 {
1464 "ScheduledModule",
1465 R"(HloModule scheduled_module, is_scheduled=true
1466
1467 compare {
1468 p.1.lhs = s32[] parameter(2)
1469 p.1.rhs = s32[] parameter(3)
1470 p.0.lhs = f32[] parameter(0)
1471 p.0.rhs = f32[] parameter(1)
1472 ROOT lhs = pred[] compare(p.0.lhs, p.0.rhs), direction=LT
1473 }
1474
1475 ENTRY Sort {
1476 keys = f32[1024]{0} parameter(0)
1477 values = s32[1024]{0} parameter(1)
1478 ROOT sorted = (f32[1024]{0}, s32[1024]{0}) sort(keys, values), dimensions={0}, to_apply=compare
1479 }
1480
1481 )"
1482 },
1483 // AfterAll with multiple operands
1484 {
1485 "AfterAllWithMultipleOperands",
1486 R"(HloModule AfterAllWithMultipleOperands
1487
1488 ENTRY AfterAllWithMultipleOperands {
1489 p0 = f32[] parameter(0)
1490 token0 = token[] after-all()
1491 token1 = token[] after-all()
1492 ROOT after-all = token[] after-all(p0, token0, token1)
1493 }
1494
1495 )"
1496 },
1497 // AddDependency
1498 // A dependency chain is created from 'neg' to 'exp' using tokens.
1499 {
1500 "AddDependency",
1501 R"(HloModule AddDependency
1502
1503 ENTRY AddDependency {
1504 p = f32[] parameter(0)
1505 neg = f32[] negate(p)
1506 token0 = token[] after-all(neg)
1507 p_after_token = f32[] add-dependency(p, token0)
1508 exp = f32[] exponential(p_after_token)
1509 ROOT sum = f32[] add(neg, exp)
1510 }
1511
1512 )"
1513 },
1514
1515 // A module containing constants equal to the min/max values of various data
1516 // types.
1517 {
1518 "MinMaxValues",
1519 R"(HloModule MinMaxValues
1520
1521 ENTRY MinMaxValues {
1522 x.s8 = s8[2]{0} constant({-128, 127})
1523 x.s16 = s16[2]{0} constant({-32768, 32767})
1524 x.s32 = s32[2]{0} constant({-2147483648, 2147483647})
1525 x.u8 = u8[2]{0} constant({0, 255})
1526 x.u16 = u16[2]{0} constant({0, 65535})
1527 x.u32 = u32[2]{0} constant({0, 4294967295})
1528 x.f16 = f16[2]{0} constant({-65504, 65504})
1529 x.bf16 = bf16[2]{0} constant({-3.38953e+38, 3.38953e+38})
1530 x.f32 = f32[2]{0} constant({-3.40282e+38, 3.40282e+38})
1531 x.f64 = f64[2]{0} constant({-1.79769e+308, 1.79769e+308})
1532 x.c64 = c64[2]{0} constant({(-3.40282e+38, 3.40282e+38), (3.40282e+38, -3.40282e+38)})
1533 ROOT c.c128 = c128[2]{0} constant({(-1.79769e+308, 1.79769e+308), (1.79769e+308, -1.79769e+308)})
1534 }
1535
1536 )"
1537 },
1538 });
1539 // clang-format on
1540 }
1541
1542 // The test class for those tests defined above which round-trip through the
1543 // parser and ToString is templatized on two bool parameters:
1544 //
1545 // short_form : used for the "short" test cases which use the ShortParsable
1546 // output form.
1547 // proto_round_trip : whether the module should also be round-tripped through
1548 // HloProto form. This provides much better coverage for the proto
1549 // serialization/deserialization.
1550 //
1551 // The proto_round_trip=true case also technically covers the Parser->ToString
1552 // roundtrip as well, but separating out the Parser->ToString roundtrip as its
1553 // own test provides better isolation and could conceivably catch weirdo bugs
1554 // which are hidden by interaction between the textual and proto roundtripping.
1555 template <bool short_form, bool proto_round_trip>
1556 class HloParameterizedParserTest
1557 : public ::testing::Test,
1558 public ::testing::WithParamInterface<TestData> {
1559 protected:
1560 // Expects "ToString(ParseHloString(string)) == string", that is, parses the
1561 // string, asserts that it succeeded, stringifies the parsed module, and
1562 // checks that it equals the original string.
1563 void ExpectEqual() {
1564 const string& original = GetParam().module_string;
1565 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
1566 ParseHloString(original));
1567 if (proto_round_trip) {
1568 TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
1569 module->ToProto(), module->config()));
1570 }
1571 if (short_form) {
1572 EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable()));
1573 } else {
1574 EXPECT_EQ(
1575 original,
1576 module->ToString(HloPrintOptions().set_print_large_constants(true)));
1577 }
1578 }
1579 };
1580
// These using shenanigans are required because the TEST_P macro doesn't like
// template instantiations which contain commas.
// Long form, parse -> print only.
using HloParserTestLong = HloParameterizedParserTest<false, false>;
// Long form, parse -> proto -> print.
using HloParserTestLongProto = HloParameterizedParserTest<false, true>;
// Short (ShortParsable) form, parse -> print only.
using HloParserTestShort = HloParameterizedParserTest<true, false>;
// Short (ShortParsable) form, parse -> proto -> print.
using HloParserTestShortProto = HloParameterizedParserTest<true, true>;

TEST_P(HloParserTestLong, Run) { ExpectEqual(); }
TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); }
TEST_P(HloParserTestShort, Run) { ExpectEqual(); }
TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); }

// The long-form fixtures run over CreateTestCases() and the short-form
// fixtures run over CreateShortTestCases(). TestDataToString names each
// instantiated test after its TestData::test_name.
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestLong,
                         ::testing::ValuesIn(CreateTestCases()),
                         TestDataToString);
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
                         HloParserTestLongProto,
                         ::testing::ValuesIn(CreateTestCases()),
                         TestDataToString);
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation, HloParserTestShort,
                         ::testing::ValuesIn(CreateShortTestCases()),
                         TestDataToString);
INSTANTIATE_TEST_SUITE_P(HloParserTestSuccessInstantiation,
                         HloParserTestShortProto,
                         ::testing::ValuesIn(CreateShortTestCases()),
                         TestDataToString);
1607
// Fixture for the one-off parser tests below. Most of them check that a
// malformed module is rejected with a particular error message.
class HloParserTest : public ::testing::Test {
 protected:
  // Non-fatally fails the current test unless `s` contains `expected` as a
  // substring. The failure message quotes both strings.
  static void ExpectHasSubstr(string_view s, string_view expected) {
    EXPECT_TRUE(absl::StrContains(s, expected))
        << "'" << s << "' does not contain '" << expected << "'";
  }
};
1615
1616 TEST_F(HloParserTest, Empty) {
1617 const string original = "";
1618 auto result = ParseHloString(original);
1619 EXPECT_NE(Status::OK(), result.status());
1620 }
1621
1622 TEST_F(HloParserTest, Garbage) {
1623 const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$";
1624 auto result = ParseHloString(original);
1625 EXPECT_NE(Status::OK(), result.status());
1626 }
1627
1628 TEST_F(HloParserTest, WrongOpcode) {
1629 const string original = R"(HloModule wrong_opcode:
1630
1631 ENTRY %blabla (x: f32[], y: f32[]) -> f32[] {
1632 %x = f32[]{} parameter(0)
1633 %y = f32[]{} parameter(1)
1634 %le = pred[]{} le(f32[]{} %x, f32[]{} %y)
1635 }
1636
1637 )";
1638 auto result = ParseHloString(original);
1639 EXPECT_NE(Status::OK(), result.status());
1640 }
1641
1642 TEST_F(HloParserTest, WrongShape) {
1643 const string original = R"(HloModule wrong_opcode:
1644
1645 ENTRY %blabla (x: g32[]) -> g32[] {
1646 %x = g32[]{} parameter(0)
1647 }
1648
1649 )";
1650 auto result = ParseHloString(original);
1651 EXPECT_NE(Status::OK(), result.status());
1652 }
1653
1654 TEST_F(HloParserTest, WrongOperandsSize) {
1655 const string original = R"(HloModule wrong_opcode:
1656
1657 ENTRY %blabla (x: f32[]) -> pred[] {
1658 %x = f32[]{} parameter(0)
1659 %eq = pred[]{} compare(f32[]{} %x), direction=EQ
1660 }
1661
1662 )";
1663 auto result = ParseHloString(original);
1664 EXPECT_NE(Status::OK(), result.status());
1665 }
1666
1667 TEST_F(HloParserTest, OperandNotFound) {
1668 const string original = R"(HloModule operand_not_found:
1669 ENTRY %blabla (x: f32[]) -> pred[] {
1670 %x = f32[]{} parameter(0)
1671 %eq = pred[]{} compare(f32[]{} %x, f32[]{} %y), direction=EQ
1672 }
1673 )";
1674 auto result = ParseHloString(original);
1675 EXPECT_NE(Status::OK(), result.status());
1676 }
1677
1678 TEST_F(HloParserTest, MoreConstants) {
1679 const string original = R"(HloModule SelectScalarS32True_module
1680
1681 ENTRY %SelectScalarS32True.v4 () -> s32[] {
1682 %constant.2 = pred[] constant(true)
1683 %constant.1 = s32[] constant(-42), sharding={devices=[2,2]1,2,3,4}
1684 %constant = s32[] constant(42)
1685 %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant)
1686 }
1687
1688 )";
1689 auto result = ParseHloString(original);
1690 TF_EXPECT_OK(result.status());
1691 // Constant instructions have no name. The string will be parsed successfully
1692 // but the constant names will not be exactly the same.
1693 }
1694
1695 TEST_F(HloParserTest, ConfigurationField) {
1696 const string original = R"(HloModule AModule
1697 ENTRY %configuration_test() -> s32[] {
1698 %constant = s32[] constant(42), backend_config="foo bar"
1699 })";
1700 auto result = ParseHloString(original);
1701 TF_ASSERT_OK(result.status());
1702 EXPECT_EQ("foo bar", result.ValueOrDie()
1703 ->entry_computation()
1704 ->root_instruction()
1705 ->raw_backend_config_string());
1706 }
1707
1708 TEST_F(HloParserTest, LiteralDimensionsMismatch_1) {
1709 const string original = R"(HloModule some_2_module
1710
1711 ENTRY %some_2 () -> f32[2] {
1712 ROOT %constant = f32[2]{0} constant({1,{2}})
1713 }
1714
1715 )";
1716 auto result = ParseHloString(original);
1717 EXPECT_NE(Status::OK(), result.status());
1718 ExpectHasSubstr(result.status().error_message(),
1719 "expects nested array in rank 1, but sees larger");
1720 }
1721
1722 TEST_F(HloParserTest, LiteralDimensionsMismatch_2) {
1723 const string original = R"(HloModule some_2x3_module
1724
1725 ENTRY %some_2x3 () -> f32[2,3] {
1726 ROOT %constant = f32[2,3]{1,0} constant({1, 2, 3, 4, 5, 6})
1727 }
1728
1729 )";
1730 auto result = ParseHloString(original);
1731 EXPECT_NE(Status::OK(), result.status());
1732 ExpectHasSubstr(result.status().error_message(),
1733 "expects nested array in rank 2, but sees 1");
1734 }
1735
1736 TEST_F(HloParserTest, LiteralDimensionsMismatch_3) {
1737 const string original = R"(HloModule some_2x3x2_module
1738
1739 ENTRY %some_2x3x2 () -> f32[2,3,2] {
1740 ROOT %constant = f32[2,3,2]{2,1,0} constant({{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}})
1741 }
1742
1743 )";
1744 auto result = ParseHloString(original);
1745 EXPECT_NE(Status::OK(), result.status());
1746 ExpectHasSubstr(result.status().error_message(),
1747 "expects 3 elements in the [0]th element");
1748 }
1749
1750 TEST_F(HloParserTest, ConstantF16Overflow) {
1751 const string original =
1752 R"(HloModule ConstantF16Overflow_module
1753
1754 ENTRY %ConstantF16Overflow.v4 () -> f16[] {
1755 ROOT %constant = f16[] constant(-65505)
1756 }
1757
1758 )";
1759 auto result = ParseHloString(original);
1760 EXPECT_NE(Status::OK(), result.status());
1761 ExpectHasSubstr(result.status().error_message(),
1762 "is out of range for literal's primitive type F16");
1763 }
1764
1765 TEST_F(HloParserTest, ConstantBf16NoOverflow) {
1766 // 65505 is in range for bf16.
1767 const string original = R"(
1768 HloModule test_module
1769 ENTRY test {
1770 ROOT c = bf16[] constant(-65505)
1771 })";
1772 EXPECT_EQ(Status::OK(), ParseHloString(original).status());
1773 }
1774
1775 TEST_F(HloParserTest, ConstantBf16Overflow) {
1776 // 1e100 is out of range for bf16.
1777 const string original = R"(
1778 HloModule test_module
1779 ENTRY test {
1780 ROOT c = bf16[] constant(1e100)
1781 })";
1782 ExpectHasSubstr(ParseHloString(original).status().error_message(),
1783 "out of range");
1784 }
1785
1786 TEST_F(HloParserTest, ConstantF16OverflowInSparseArray) {
1787 const string original = R"(
1788 HloModule test_module
1789 ENTRY test {
1790 ROOT c = f16[5]sparse{10} constant({[0]: 0, [1]: -65505})
1791 })";
1792 ExpectHasSubstr(ParseHloString(original).status().error_message(),
1793 "is out of range for literal's primitive type F16");
1794 }
1795
1796 TEST_F(HloParserTest, ConstantUnsignedUnderflow) {
1797 const string original = R"(
1798 HloModule ConstantUnsignedUnderflow_module
1799 ENTRY %ConstantUnsignedUnderflow () -> u64[] {
1800 ROOT %constant = u64[] constant(-1)
1801 })";
1802 auto result = ParseHloString(original);
1803 EXPECT_NE(Status::OK(), result.status());
1804 ExpectHasSubstr(result.status().error_message(),
1805 "is out of range for literal's primitive type U64");
1806 }
1807
1808 TEST_F(HloParserTest, ConstantUnsignedOverflow) {
1809 const string original = R"(
1810 HloModule ConstantUnsignedOverflow_module
1811 ENTRY %ConstantUnsignedOverflow () -> u32[] {
1812 ROOT %constant = u32[] constant(4294967296)
1813 })";
1814 auto result = ParseHloString(original);
1815 EXPECT_NE(Status::OK(), result.status());
1816 ExpectHasSubstr(result.status().error_message(),
1817 "is out of range for literal's primitive type U32");
1818 }
1819
TEST_F(HloParserTest, ConstantUnsignedInt64Overflow) {
  // 9223372036854775808 is 2^63. That is one more than the largest s64 value,
  // but it is still representable as a u64. The test nevertheless expects
  // parsing to fail.
  // NOTE(review): presumably this pins a parser limitation (integer literals
  // read as signed 64-bit) rather than a genuine u64 overflow. Confirm this,
  // and flip the expectation if the parser is taught to accept the full u64
  // range.
  const string original = R"(
HloModule ConstantUnsignedOverflow_module
ENTRY %ConstantUnsignedOverflow () -> u64[] {
ROOT %constant = u64[] constant(9223372036854775808)
})";
  auto result = ParseHloString(original);
  EXPECT_NE(Status::OK(), result.status());
}
1829
1830 TEST_F(HloParserTest, ConstantC64Overflow) {
1831 const string original = R"(
1832 HloModule test_module
1833 ENTRY test () -> c64[] {
1834 ROOT c = c64[] constant((1e100, 0))
1835 })";
1836 auto result = ParseHloString(original);
1837 EXPECT_NE(Status::OK(), result.status());
1838 }
1839
1840 TEST_F(HloParserTest, ConstantC64Underflow) {
1841 const string original = R"(
1842 HloModule test_module
1843 ENTRY test () -> c64[] {
1844 ROOT c = c64[] constant((0, -1e100))
1845 })";
1846 auto result = ParseHloString(original);
1847 EXPECT_NE(Status::OK(), result.status());
1848 }
1849
1850 TEST_F(HloParserTest, ConstantF64Overflow) {
1851 const string original = R"(
1852 HloModule test_module
1853 ENTRY test {
1854 ROOT c = f64[] constant(1.8e308)
1855 })";
1856 auto result = ParseHloString(original);
1857 EXPECT_NE(Status::OK(), result.status());
1858 }
1859
1860 TEST_F(HloParserTest, ConstantF64Underflow) {
1861 const string original = R"(
1862 HloModule test_module
1863 ENTRY test {
1864 ROOT c = f64[] constant(-1.8e308)
1865 })";
1866 auto result = ParseHloString(original);
1867 EXPECT_NE(Status::OK(), result.status());
1868 }
1869
TEST_F(HloParserTest, ConstantWithExp) {
  // A floating-point constant written in exponent notation must be accepted.
  const string original = R"(HloModule ConstantWithExp_module

ENTRY %ConstantWithExp.v4 () -> f32[] {
%constant.1 = f32[] constant(3e+2)
}

)";
  auto result = ParseHloString(original);
  TF_EXPECT_OK(result.status());
  // The string will be parsed successfully but the output strings are not
  // exactly the same, because "3e+2" is parsed into value 300 and will be
  // printed as "300". That is why only the parse status is checked.
}
1884
1885 TEST_F(HloParserTest, ShortConstant) {
1886 const string original = R"(HloModule ShortCOnstant_module
1887
1888 ENTRY %ShortConstant.v4 () -> f32[67,89] {
1889 ROOT %constant.1 = f32[67,89]{1,0} constant({...})
1890 }
1891
1892 )";
1893 auto result = ParseHloString(original);
1894 TF_EXPECT_OK(result.status());
1895 EXPECT_EQ(result.ValueOrDie()->ToString(HloPrintOptions()), original);
1896 }
1897
1898 TEST_F(HloParserTest, AttibutesAnyOrder) {
1899 const string original = R"(HloModule any_order_module
1900
1901 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
1902 %input = f32[1,2,1]{2,1,0} parameter(0)
1903 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
1904 %filter = f32[1,1,1]{2,1,0} parameter(1)
1905 ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), feature_group_count=1, sharding={maximal device=1}, backend_config="foo", dim_labels=b0f_0io->b0f, window={pad=1_1 size=2}
1906 }
1907
1908 )";
1909 TF_EXPECT_OK(ParseHloString(original).status());
1910 }
1911
1912 TEST_F(HloParserTest, InvalidDimLabels) {
1913 string prefix = R"(HloModule invalid_dim_labels_module
1914
1915 ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
1916 %input = f32[1,2,1]{2,1,0} parameter(0)
1917 %copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
1918 %filter = f32[1,1,1]{2,1,0} parameter(1)
1919 ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1} )";
1920 string suffix = R"(
1921 }
1922
1923 )";
1924
1925 ExpectHasSubstr(
1926 ParseHloString(absl::StrCat(prefix, ",dim_labels=00_01_10", suffix))
1927 .status()
1928 .error_message(),
1929 "expects dim labels pattern");
1930
1931 ExpectHasSubstr(
1932 ParseHloString(absl::StrCat(prefix, ",dim_labels=010_1100->010", suffix))
1933 .status()
1934 .error_message(),
1935 "must have the same rank");
1936 }
1937
1938 TEST_F(HloParserTest, UnexpectedAttribute) {
1939 const string original = R"(HloModule unexpected_attr_module
1940
1941 ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
1942 %token0 = token[] after-all()
1943 %recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
1944 %recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
1945 ROOT %constant = f32[] constant(2.1)
1946 %send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, calls=%recv
1947 %send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
1948 }
1949
1950 )";
1951 ExpectHasSubstr(ParseHloString(original).status().error_message(),
1952 "unexpected attribute \"calls\"");
1953 }
1954
TEST_F(HloParserTest, MissingAttribute) {
  // %send omits channel_id, which the parser must report as a missing
  // attribute.
  const string hlo_text = R"(HloModule missing_attr_module

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(-2.1)
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0)
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}

)";
  const string error = ParseHloString(hlo_text).status().error_message();
  ExpectHasSubstr(error, "attribute channel_id is expected but not seen");
}
1971
TEST_F(HloParserTest, PredecessorUndefined) {
  // %send lists %done as a control predecessor, but no instruction of that
  // name exists.
  const string hlo_text = R"(HloModule pre_not_found_module

ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] {
%token0 = token[] after-all()
%recv = (f32[], u32[], token[]) recv(token[] %token0), channel_id=15
%recv-done = (f32[], token[]) recv-done((f32[], u32[], token[]) %recv), channel_id=15
ROOT %constant = f32[] constant(2.1)
%send = (f32[], u32[], token[]) send(f32[] %constant, token[] %token0), channel_id=16, control-predecessors={%done}
%send-done = token[] send-done((f32[], u32[], token[]) %send), channel_id=16
}

)";
  const string error = ParseHloString(hlo_text).status().error_message();
  ExpectHasSubstr(error, "'done' is not defined");
}
1988
TEST_F(HloParserTest, SliceAllowOmitStride1) {
  // Slice ranges written as [start:limit], with no stride, must parse just
  // like ranges that spell out the stride.
  const string hlo_text = R"(HloModule slice_module

ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] {
%p0 = f32[3,3,4,4]{3,2,1,0} parameter(0)
ROOT %slice = f32[3,3,2,4]{3,2,1,0} slice(f32[3,3,4,4]{3,2,1,0} %p0), slice={[0:3], [0:3], [0:4:2], [0:4]}
}

)";
  const auto parsed = ParseHloString(hlo_text);
  TF_EXPECT_OK(parsed.status());
}
2000
TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) {
  // The window's pad is written with three components (1_1_0). A window pad
  // only accepts low_high, so parsing must fail.
  const string hlo_text = R"(HloModule window_pad_module

ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] {
%input = f32[1,2,1]{2,1,0} parameter(0)
%copy = f32[1,2,1]{2,0,1} copy(f32[1,2,1]{2,1,0} %input)
%filter = f32[1,1,1]{2,1,0} parameter(1)
ROOT %convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), dim_labels=b0f_0io->b0f, window={pad=1_1_0 size=1}
}

)";
  const string error = ParseHloString(hlo_text).status().error_message();
  ExpectHasSubstr(error,
                  "expects padding_low and padding_high separated by '_'");
}
2015
// Sub-attributes inside metadata={...} may be separated by commas.
TEST_F(HloParserTest, CommaBetweenSubAttributes) {
  const string hlo_string = R"(HloModule test_comma_module

ENTRY %test_comma.v4 () -> f32[] {
  ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"}
}

)";
  const auto parsed = ParseHloString(hlo_string);
  TF_EXPECT_OK(parsed.status());
}
2026
// The declared result shape of a computation (f32[1]) must be compatible with
// the shape of its ROOT instruction (f32[1,2,3]); otherwise parsing fails.
TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) {
  const string hlo_string = R"(HloModule custom_call:

ENTRY %CustomCall () -> f32[1] {
  %constant = f32[1]{0} constant({12345})
  ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar"
})";
  const auto status = ParseHloString(hlo_string).status();
  ExpectHasSubstr(status.error_message(),
                  "Shape of computation CustomCall, f32[1], is not compatible "
                  "with that of its root instruction foo, f32[1,2,3]");
}
2038
// The layouts written on the ENTRY computation's parameter ({0,1,2}) and root
// ({0,1}) must be reflected in the module's entry computation layout.
TEST_F(HloParserTest, EntryComputationWithLayout) {
  const string hlo_string = R"(HloModule layout:
add_F32.v3 {
  lhs = f32[] parameter(0)
  rhs = f32[] parameter(1)
  ROOT add = f32[] add(lhs, rhs)
}

ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] {
  input = f32[8,16,256]{0,1,2} parameter(0)
  constant = f32[] constant(0)
  ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
  const auto program_layout = module->entry_computation_layout();
  ASSERT_EQ(program_layout.parameter_count(), 1);

  const auto param_layout = program_layout.parameter_layout(0).layout();
  EXPECT_TRUE(
      LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), param_layout))
      << "actual layout of parameter(0) is "
      << LayoutUtil::HumanString(param_layout);

  const auto result_layout = program_layout.result_layout().layout();
  EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), result_layout))
      << "actual layout of result is "
      << LayoutUtil::HumanString(result_layout);
}
2067
// A module with no computation marked ENTRY still parses; the test expects the
// final computation (c2) to be chosen as the entry computation.
TEST_F(HloParserTest, NoEntry) {
  const string hlo_string = R"(HloModule no_entry:
c1 {
  const1 = f32[1]{0} constant({12345})
}
c2 {
  const2 = f32[1]{0} constant({67890})
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
  EXPECT_EQ(module->entry_computation()->name(), "c2");
}
2080
// A computation with no ROOT marker still parses; the test expects the last
// instruction listed ("last") to become the root.
TEST_F(HloParserTest, NoRoot) {
  const string hlo_string = R"(HloModule no_root:
ENTRY consts {
  first = f32[1]{0} constant({12345})
  last = f32[1]{0} constant({67890})
})";
  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(hlo_string));
  const auto* root = module->entry_computation()->root_instruction();
  EXPECT_EQ(root->name(), "last");
}
2093
// /* ... */ comments are accepted before the module header, between tokens,
// inside a literal, and after the last computation.
TEST_F(HloParserTest, Comments) {
  const string hlo_string = R"(/* module description. */
HloModule comments:

ENTRY /*comment*/ c1 {
  /* blah */
  ROOT const1 = /*foo*/f32[1]{0} constant({12345 /*bar*/})
  /* comment */
}

/* something else */

)";
  TF_ASSERT_OK(ParseHloString(hlo_string).status());
}
2110
// Block comments may span several lines, including lines that would otherwise
// be valid instructions (the commented-out ROOT must be ignored).
TEST_F(HloParserTest, MultilineComments) {
  const string hlo_string = R"(HloModule multiline_comment:
ENTRY c1 {
  /*
  ROOT foo = f32[1]{0} constant({12345})
  */
  ROOT const1 = f32[1]{0} constant({12345})
/*
a
b
c
d

*/
})";
  TF_ASSERT_OK(ParseHloString(hlo_string).status());
}
2129
// An unterminated /* comment is an error. The error message must point (with
// a caret in column 1) at the line where the comment starts; the comment line
// is therefore deliberately left unindented in the HLO text below.
TEST_F(HloParserTest, UnterminatedComment) {
  const string hlo_string = R"(HloModule unterminated_comment:
ENTRY c1 {
/* unterminated
ROOT const1 = f32[1]{0} constant({12345})
})";
  const auto status = ParseHloString(hlo_string).status();
  ExpectHasSubstr(status.error_message(), "/* unterminated\n^");
}
2141
2142 TEST_F(HloParserTest, SlashSlashComments) {
2143 const string original = R"(HloModule slash_slash_comment:
2144 // Garbage
2145 ENTRY c1 {
2146 // Foo bar
2147 ROOT const1 = f32[1]{0} constant({12345}) // Something else
2148 })";
2149 auto module = ParseHloString(original);
2150 TF_ASSERT_OK(module.status());
2151 }
2152
2153 TEST_F(HloParserTest, SlashSlashCommentMsDosEolFormat) {
2154 const string original =
2155 "HloModule slash_slash_comment:\r\n// Garbage\r\nENTRY c1 {\r\n// Foo "
2156 "bar\r\nROOT const1 = f32[1]{0} constant({12345}) // Something else\r\n}";
2157 auto module = ParseHloString(original);
2158 TF_ASSERT_OK(module.status());
2159 }
2160
2161 TEST_F(HloParserTest, SlashSlashCommentMacEolFormat) {
2162 const string original =
2163 "HloModule slash_slash_comment:\r// Garbage\rENTRY c1 {\r// Foo "
2164 "bar\rROOT const1 = f32[1]{0} constant({12345}) // Something else\r}";
2165 auto module = ParseHloString(original);
2166 TF_ASSERT_OK(module.status());
2167 }
2168
2169 TEST_F(HloParserTest, MultipleEntries) {
2170 const string original = R"(HloModule multiple_entries:
2171 ENTRY c1 {
2172 const1 = f32[1]{0} constant({12345})
2173 }
2174 ENTRY c2 {
2175 const2 = f32[1]{0} constant({67890})
2176 })";
2177 ExpectHasSubstr(ParseHloString(original).status().error_message(),
2178 "expects only one ENTRY");
2179 }
2180
2181 TEST_F(HloParserTest, MultipleRoots) {
2182 const string original = R"(HloModule multiple_roots:
2183 ENTRY consts {
2184 ROOT const1 = f32[1]{0} constant({12345})
2185 ROOT const2 = f32[1]{0} constant({12345})
2186 })";
2187 ExpectHasSubstr(ParseHloString(original).status().error_message(),
2188 "one computation should have only one ROOT");
2189 }
2190
2191 TEST_F(HloParserTest, ComputationExists) {
2192 const string original = R"(HloModule comp_exists
2193 comp {
2194 const1 = f32[1]{0} constant({12345})
2195 }
2196 comp {
2197 const2 = f32[1]{0} constant({67890})
2198 })";
2199 ExpectHasSubstr(ParseHloString(original).status().error_message(),
2200 R"(was parsing 2:1: error: computation previously defined here
2201 comp {
2202 ^)");
2203 }
2204
2205 TEST_F(HloParserTest, CrossComputationLookup) {
2206 const string original = R"(HloModule cross_computation_lookup:
2207 tcalla (a: (s32[], s32[])) -> (s32[], s32[]) {
2208 ROOT aparam = (s32[], s32[]) parameter(0)
2209 }
2210
2211 tcallb (b: (s32[], s32[])) -> s32[] {
2212 rparam = (s32[], s32[]) parameter(0)
2213 ROOT gte0 = s32[] get-tuple-element(aparam), index=0
2214 }
2215
2216 ENTRY entry {
2217 param = (s32[], s32[]) parameter(0)
2218 call0 = (s32[], s32[]) call(param), to_apply=tcalla
2219 ROOT call1 = s32[] call(param), to_apply=tcallb
2220 })";
2221 ExpectHasSubstr(
2222 ParseHloString(original).status().error_message(),
2223 "was parsing 8:39: error: instruction does not exist: aparam");
2224 }
2225
2226 TEST_F(HloParserTest, SameNameDiffComputations) {
2227 const string original = R"(HloModule same_names:
2228 add {
2229 p0 = f32[] parameter(0)
2230 p1 = f32[] parameter(1)
2231 ROOT result = f32[] add(p0, p1)
2232 }
2233
2234 ENTRY ReduceR3ToR2 {
2235 p0 = f32[8,16,256]{2,1,0} parameter(0)
2236 p1 = f32[] constant(0)
2237 ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
2238 }
2239 )";
2240 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original));
2241 ASSERT_NE(module->entry_computation(), nullptr);
2242 EXPECT_THAT(module->entry_computation()->root_instruction(),
2243 GmockMatch(m::Reduce()));
2244 }
2245
2246 TEST_F(HloParserTest, ParseSharding) {
2247 const string original = "{maximal device=42}";
2248 TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
2249 EXPECT_EQ(sharding.ToString(), original);
2250 }
2251
2252 TEST_F(HloParserTest, ParseWindow) {
2253 Window original = window_util::MakeWindow({1, 2, 3});
2254 TF_ASSERT_OK_AND_ASSIGN(Window parsed,
2255 ParseWindow(window_util::ToString(original)))
2256 EXPECT_EQ(window_util::ToString(original), window_util::ToString(parsed));
2257 }
2258
2259 TEST_F(HloParserTest, ParseConvolutionDimensionNumbers) {
2260 const string original = "b0f_0io->b0f";
2261 TF_ASSERT_OK_AND_ASSIGN(ConvolutionDimensionNumbers dnums,
2262 ParseConvolutionDimensionNumbers(original));
2263 EXPECT_EQ(original, ConvolutionDimensionNumbersToString(dnums));
2264 }
2265
2266 TEST_F(HloParserTest, ParsePaddingConfigNoInteriorPadding) {
2267 const string original = "0_1x2_3";
2268 TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
2269 EXPECT_EQ(original, PaddingConfigToString(dnums));
2270 }
2271
2272 TEST_F(HloParserTest, ParsePaddingConfigInteriorPadding) {
2273 const string original = "0_1_0x2_3_4";
2274 TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig(original));
2275 EXPECT_EQ(original, PaddingConfigToString(dnums));
2276 }
2277
2278 TEST_F(HloParserTest, ParsePaddingConfigInteriorPaddingImplicitZeroDim) {
2279 TF_ASSERT_OK_AND_ASSIGN(PaddingConfig dnums, ParsePaddingConfig("0_1x2_3_4"));
2280 // The extra "_0" gets added to the canonical string because the other dim has
2281 // interior padding.
2282 EXPECT_EQ("0_1_0x2_3_4", PaddingConfigToString(dnums));
2283 }
2284
2285 TEST_F(HloParserTest, NontupleInfeed) {
2286 const string original = R"(HloModule nontuple_infeed:
2287 ENTRY nontuple_infeed {
2288 token0 = token[] after-all()
2289 ROOT infeed = pred[] infeed(token0)
2290 })";
2291 ExpectHasSubstr(ParseHloString(original).status().error_message(),
2292 "infeed must have a non-empty tuple shape");
2293 }
2294
2295 TEST(HloParserSingleOpTest, SingleOp) {
2296 const string text =
2297 "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, "
2298 "f32[2,4]{1,0} %x)";
2299 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2300 const HloComputation* computation = module->entry_computation();
2301 ASSERT_NE(computation, nullptr);
2302 EXPECT_THAT(computation->root_instruction(),
2303 GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2304 }
2305
2306 TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
2307 const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
2308 StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text);
2309 ASSERT_TRUE(!module.status().ok());
2310 LOG(INFO) << "Status: " << module.status();
2311 EXPECT_THAT(module.status().ToString(),
2312 ::testing::HasSubstr("expects '=' in instruction"));
2313 }
2314
2315 TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
2316 const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
2317 StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text);
2318 ASSERT_TRUE(!module.status().ok());
2319 LOG(INFO) << "Status: " << module.status();
2320 EXPECT_THAT(module.status().ToString(),
2321 ::testing::HasSubstr("Operand had no shape in HLO text"));
2322 }
2323
2324 TEST(HloParserSingleOpTest, SingleOpNoNames) {
2325 const string text =
2326 "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
2327 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2328 const HloComputation* computation = module->entry_computation();
2329 ASSERT_NE(computation, nullptr);
2330 EXPECT_THAT(computation->root_instruction(),
2331 GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2332 }
2333
2334 TEST(HloParserSingleOpTest, CanonicalOp) {
2335 const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
2336 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2337 const HloComputation* computation = module->entry_computation();
2338 ASSERT_NE(computation, nullptr);
2339 EXPECT_THAT(computation->root_instruction(),
2340 GmockMatch(m::Multiply(m::Parameter(0), m::Parameter(1))));
2341 EXPECT_EQ(
2342 computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2343 text);
2344 }
2345
2346 TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
2347 const string text =
2348 R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
2349 {
2350 tmp_0 = f32[5,10]{1,0} parameter(0)
2351 tmp_1 = f32[20,10]{1,0} parameter(1)
2352 ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
2353 {
2354 tmp_0 = f32[5,10]{1,0} parameter(0)
2355 tmp_1 = f32[20,10]{1,0} parameter(1)
2356 tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
2357 ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
2358 }
2359 }, body=
2360 {
2361 tmp_0 = f32[5,10]{1,0} parameter(0)
2362 tmp_1 = f32[20,10]{1,0} parameter(1)
2363 ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
2364 {
2365 tmp_0 = f32[5,10]{1,0} parameter(0)
2366 tmp_1 = f32[20,10]{1,0} parameter(1)
2367 tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
2368 ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
2369 }
2370 })";
2371
2372 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2373 const HloComputation* computation = module->entry_computation();
2374 ASSERT_NE(computation, nullptr);
2375 EXPECT_EQ(
2376 computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2377 text);
2378 }
2379
2380 TEST(HloParserSingleOpTest, CanonicalOpIndexedConditionalInlinedBranches) {
2381 const string text =
2382 R"(f32[5,10]{1,0} conditional(s32[], f32[5,10]{1,0}, f32[5,10]{1,0}, f32[5,10]{1,0}), branch_computations={
2383 {
2384 tmp_0 = f32[5,10]{1,0} parameter(0)
2385 ROOT tmp_1 = f32[5,10]{1,0} ceil(f32[5,10]{1,0} tmp_0)
2386 },
2387 {
2388 tmp_0 = f32[5,10]{1,0} parameter(0)
2389 ROOT tmp_1 = f32[5,10]{1,0} floor(f32[5,10]{1,0} tmp_0)
2390 },
2391 {
2392 tmp_0 = f32[5,10]{1,0} parameter(0)
2393 ROOT tmp_1 = f32[5,10]{1,0} copy(f32[5,10]{1,0} tmp_0)
2394 }
2395 })";
2396
2397 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2398 const HloComputation* computation = module->entry_computation();
2399 ASSERT_NE(computation, nullptr);
2400 EXPECT_EQ(
2401 computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
2402 text);
2403 }
2404
2405 TEST(HloParserSingleOpTest, SingleOpWithNested) {
2406 const string text =
2407 R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls=
2408 {
2409 %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
2410 %param_1 = f32[2]{0} parameter(1)
2411 %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1}
2412 ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
2413 })";
2414
2415 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2416 const HloComputation* computation = module->entry_computation();
2417 ASSERT_NE(computation, nullptr);
2418 EXPECT_THAT(computation->root_instruction(),
2419 GmockMatch(m::Op()
2420 .WithOpcode(HloOpcode::kFusion)
2421 .WithNumOperands(2)
2422 .WithOperand(0, m::Parameter(0))
2423 .WithOperand(1, m::Parameter(1))));
2424 }
2425
2426 TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
2427 const string text =
2428 R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2429 {
2430 result = f32[] add(f32[] x, f32[] y)
2431 })";
2432 auto status = ParseHloString(text).status();
2433 ASSERT_FALSE(status.ok());
2434 EXPECT_THAT(status.error_message(),
2435 ::testing::HasSubstr("does not exist: x"));
2436 }
2437
2438 TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) {
2439 const string text =
2440 R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2441 {
2442 f32[] add(f32[] x, f32[] y)
2443 })";
2444 auto status = ParseHloString(text).status();
2445 ASSERT_FALSE(status.ok());
2446 EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
2447 }
2448
2449 TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
2450 const string text =
2451 R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
2452 {
2453 result = f32[] add(f32[], f32[])
2454 })";
2455 auto status = ParseHloString(text).status();
2456 ASSERT_FALSE(status.ok());
2457 EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
2458 }
2459
2460 TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
2461 const string text =
2462 R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
2463 TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
2464 const HloComputation* computation = module->entry_computation();
2465 ASSERT_NE(computation, nullptr);
2466 EXPECT_THAT(computation->root_instruction(),
2467 GmockMatch(m::Convolution(m::Parameter(0), m::Parameter(1))));
2468 auto* convolution =
2469 Cast<HloConvolutionInstruction>(computation->root_instruction());
2470 EXPECT_EQ(convolution->feature_group_count(), 1);
2471 }
2472
2473 TEST_F(HloParserTest, IsScheduledIsFalse) {
2474 const string text = R"(
2475 HloModule axpy_module, is_scheduled=false
2476
2477 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2478 %alpha = f32[] parameter(0)
2479 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2480 %x = f32[2,4]{1,0} parameter(1)
2481 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2482 %y = f32[2,4]{1,0} parameter(2)
2483 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2484 }
2485 )";
2486 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
2487 ParseHloString(text));
2488 ASSERT_FALSE(module->has_schedule());
2489 }
2490
2491 TEST_F(HloParserTest, IsScheduledNotPresent) {
2492 const string text = R"(
2493 HloModule axpy_module
2494
2495 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2496 %alpha = f32[] parameter(0)
2497 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2498 %x = f32[2,4]{1,0} parameter(1)
2499 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2500 %y = f32[2,4]{1,0} parameter(2)
2501 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2502 }
2503 )";
2504 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
2505 ParseHloString(text));
2506 ASSERT_FALSE(module->has_schedule());
2507 }
2508
2509 TEST_F(HloParserTest, IsScheduledIsTrue) {
2510 const string text = R"(
2511 HloModule axpy_module, is_scheduled=true
2512
2513 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2514 %alpha = f32[] parameter(0)
2515 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2516 %x = f32[2,4]{1,0} parameter(1)
2517 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2518 %y = f32[2,4]{1,0} parameter(2)
2519 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2520 }
2521 )";
2522 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
2523 ParseHloString(text));
2524 ASSERT_TRUE(module->has_schedule());
2525 TF_ASSERT_OK(module->schedule().Verify());
2526 EXPECT_EQ(module->schedule().sequences().size(), 1);
2527 ASSERT_TRUE(
2528 module->schedule().is_computation_scheduled(module->entry_computation()));
2529 EXPECT_THAT(
2530 module->schedule().sequence(module->entry_computation()).instructions(),
2531 ::testing::ElementsAre(
2532 GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
2533 GmockMatch(m::Parameter()), GmockMatch(m::Multiply()),
2534 GmockMatch(m::Parameter()), GmockMatch(m::Add())));
2535 }
2536
2537 TEST_F(HloParserTest, IsScheduledIsTrueDifferentOrder) {
2538 // As above but in with a different schedule order.
2539 const string text = R"(
2540 HloModule axpy_module, is_scheduled=true
2541
2542 ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
2543 %alpha = f32[] parameter(0)
2544 %x = f32[2,4]{1,0} parameter(1)
2545 %y = f32[2,4]{1,0} parameter(2)
2546 %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={}
2547 %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)
2548 ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y)
2549 }
2550 )";
2551 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
2552 ParseHloString(text));
2553 ASSERT_TRUE(module->has_schedule());
2554 TF_ASSERT_OK(module->schedule().Verify());
2555 EXPECT_EQ(module->schedule().sequences().size(), 1);
2556 ASSERT_TRUE(
2557 module->schedule().is_computation_scheduled(module->entry_computation()));
2558 EXPECT_THAT(
2559 module->schedule().sequence(module->entry_computation()).instructions(),
2560 ::testing::ElementsAre(
2561 GmockMatch(m::Parameter()), GmockMatch(m::Parameter()),
2562 GmockMatch(m::Parameter()), GmockMatch(m::Broadcast()),
2563 GmockMatch(m::Multiply()), GmockMatch(m::Add())));
2564 }
2565
2566 TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {
2567 const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints
2568
2569 ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
2570 %p0 = f32[42,2,3]{0,1,2} parameter(0)
2571 %p1 = f32[123,4]{0,1} parameter(1)
2572 ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}}
2573 }
2574
2575 )";
2576 ExpectHasSubstr(ParseHloString(original).status().error_message(),
2577 "Expected 2 operand layout constraints, 1 given");
2578 }
2579
2580 TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) {
2581 const string original = R"(HloModule CustomCallIncompatibleOperandConstraints
2582
2583 ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
2584 %p0 = f32[42,2,3]{0,1,2} parameter(0)
2585 %p1 = f32[123,4]{0,1} parameter(1)
2586 ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}}
2587 }
2588
2589 )";
2590 ExpectHasSubstr(ParseHloString(original).status().error_message(),
2591 "operand 1 is not compatible with operand shape");
2592 }
2593
2594 TEST_F(HloParserTest, AllowShapeWhitespace) {
2595 const string text = R"(
2596 HloModule module
2597
2598 ENTRY entry {
2599 ROOT root = f32[ 1, 2,3, 4, 5]{0, 1, 2,3, 4 } parameter(0)
2600 }
2601 )";
2602 TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
2603 ParseHloString(text));
2604 }
2605
2606 TEST_F(HloParserTest, ShapeMismatchInOperand) {
2607 const string text = R"(
2608 HloModule foobar
2609
2610 ENTRY %entrycomp (p: f32[2,2]) -> f32[2,2] {
2611 %p = f32[2,2] parameter(0)
2612 %constant.1 = f32[2,2] constant({{1, 2}, {3, 4}})
2613 ROOT %add.1 = f32[2,2] add(f32[2,2] %p, f32[2,5] %constant.1)
2614 }
2615 )";
2616
2617 ExpectHasSubstr(ParseHloString(text).status().error_message(),
2618 "The declared operand shape f32[2,5]{1,0} is not compatible"
2619 " with the shape of the operand instruction f32[2,2]{1,0}.");
2620 }
2621
2622 TEST_F(HloParserTest, OutOfRangeSparseIndex) {
2623 const string original = R"(
2624 HloModule test_module
2625 ENTRY test {
2626 ROOT c = f16[5]sparse{10} constant({[100]: 0})
2627 })";
2628 ExpectHasSubstr(ParseHloString(original).status().error_message(),
2629 "Invalid sparse index");
2630 }
2631
2632 TEST_F(HloParserTest, NegativeSparseIndex) {
2633 const string original = R"(
2634 HloModule test_module
2635 ENTRY test {
2636 ROOT c = f16[5]sparse{10} constant({-1: 0})
2637 })";
2638 ExpectHasSubstr(ParseHloString(original).status().error_message(),
2639 "Invalid sparse index");
2640 }
2641
2642 TEST_F(HloParserTest, SparseIndexWithRankTooLarge) {
2643 const string original = R"(
2644 HloModule test_module
2645 ENTRY test {
2646 ROOT c = f16[5]sparse{10} constant({[0, 0]: 0})
2647 })";
2648 ExpectHasSubstr(ParseHloString(original).status().error_message(),
2649 "Invalid sparse index");
2650 }
2651
2652 TEST_F(HloParserTest, SparseIndexWithRankTooSmall) {
2653 const string original = R"(
2654 HloModule test_module
2655 ENTRY test {
2656 ROOT c = f16[5, 5]sparse{10} constant({[0]: 0})
2657 })";
2658 ExpectHasSubstr(ParseHloString(original).status().error_message(),
2659 "Invalid sparse index");
2660 }
2661
2662 TEST_F(HloParserTest, ParseShapeStringR2F32) {
2663 string shape_string = "f32[123,456]";
2664 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2665 Shape expected = ShapeUtil::MakeShape(F32, {123, 456});
2666 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2667 << "expected: " << ShapeUtil::HumanString(expected)
2668 << "actual: " << ShapeUtil::HumanString(actual);
2669 }
2670
2671 TEST_F(HloParserTest, ParseShapeStringTupleOfArrays) {
2672 string shape_string = "(f32[1572864],s8[5120,1024])";
2673 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2674 Shape expected =
2675 ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}),
2676 ShapeUtil::MakeShape(S8, {5120, 1024})});
2677 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2678 << "expected: " << ShapeUtil::HumanString(expected)
2679 << "actual: " << ShapeUtil::HumanString(actual);
2680 }
2681
2682 TEST_F(HloParserTest, ParseShapeStringNestedTuple) {
2683 string shape_string = "(f32[1],(f32[2], token[]), opaque[], f32[3])";
2684 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2685 Shape expected = ShapeUtil::MakeTupleShape({
2686 ShapeUtil::MakeShape(F32, {1}),
2687 ShapeUtil::MakeTupleShape(
2688 {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeTokenShape()}),
2689 ShapeUtil::MakeOpaqueShape(),
2690 ShapeUtil::MakeShape(F32, {3}),
2691 });
2692 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2693 << "expected: " << ShapeUtil::HumanString(expected)
2694 << "actual: " << ShapeUtil::HumanString(actual);
2695 }
2696
2697 TEST_F(HloParserTest, ParseShapeStringWithLayout) {
2698 string shape_string = "f32[123,456]{0,1}";
2699 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2700 Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1});
2701 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2702 << "expected: " << ShapeUtil::HumanString(expected)
2703 << "actual: " << ShapeUtil::HumanString(actual);
2704 }
2705
2706 TEST_F(HloParserTest, ParseShapeStringWithTilingLayout) {
2707 // One tile.
2708 string shape_string = "f32[123,456]{0,1:T(2,128)}";
2709 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2710 Shape expected =
2711 ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}, {Tile({2, 128})});
2712 EXPECT_EQ(expected, actual)
2713 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2714 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2715
2716 // Tile with negative dimension size for combining dimensions.
2717 shape_string = "f32[123,456,789]{0,1,2:T(2, * , 128)}";
2718 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2719 expected =
2720 ShapeUtil::MakeShapeWithLayout(F32, {123, 456, 789}, {0, 1, 2},
2721 {Tile({2, Tile::kCombineDimension, 128})});
2722 EXPECT_EQ(expected, actual)
2723 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2724 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2725
2726 // Two tiles.
2727 shape_string = "bf16[123,456,789]{2,1,0:T(2,*,128)(2,1)}";
2728 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2729 expected = ShapeUtil::MakeShapeWithLayout(
2730 BF16, {123, 456, 789}, {2, 1, 0},
2731 {Tile({2, Tile::kCombineDimension, 128}), Tile({2, 1})});
2732 EXPECT_EQ(expected, actual)
2733 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2734 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2735
2736 // Tile with element size in bits.
2737 shape_string = "pred[123,456]{1,0:T(2,128)E(1)}";
2738 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2739 expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0},
2740 {Tile({2, 128})}, 1);
2741 EXPECT_EQ(expected, actual)
2742 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2743 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2744
2745 // Element size in bits without tile.
2746 shape_string = "pred[123,456]{1,0:E(1)}";
2747 TF_ASSERT_OK_AND_ASSIGN(actual, ParseShape(shape_string));
2748 expected = ShapeUtil::MakeShapeWithLayout(PRED, {123, 456}, {1, 0}, {}, 1);
2749 EXPECT_EQ(expected, actual)
2750 << "expected: " << ShapeUtil::HumanStringWithLayout(expected)
2751 << "actual: " << ShapeUtil::HumanStringWithLayout(actual);
2752
2753 // Wrong minor_to_major.
2754 shape_string = "f32[123,456,789]{1:T(2, * , 128)}";
2755 auto result = ParseShape(shape_string);
2756 ExpectHasSubstr(result.status().error_message(),
2757 "Dimensions size is 3, but minor to major size is 1.");
2758 }
2759
2760 TEST_F(HloParserTest, ParseShapeStringWithSparseLayout) {
2761 string shape_string = "f32[123,456]sparse{10}";
2762 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2763 Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10);
2764 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2765 << "expected: " << ShapeUtil::HumanString(expected)
2766 << "actual: " << ShapeUtil::HumanString(actual);
2767 }
2768
2769 TEST_F(HloParserTest, ParseOpaqueType) {
2770 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("opaque[]"));
2771 Shape expected = ShapeUtil::MakeOpaqueShape();
2772 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2773 << "expected: " << ShapeUtil::HumanString(expected)
2774 << "actual: " << ShapeUtil::HumanString(actual);
2775 }
2776
2777 TEST_F(HloParserTest, ParseTokenType) {
2778 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape("token[]"));
2779 Shape expected = ShapeUtil::MakeTokenShape();
2780 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2781 << "expected: " << ShapeUtil::HumanString(expected)
2782 << "actual: " << ShapeUtil::HumanString(actual);
2783 }
2784
2785 TEST_F(HloParserTest, ParseInvalidShapeString) {
2786 string shape_strings[] = {
2787 "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}",
2788 "f32[123,456]dense{foo}", "f32[123,456]sparse{foo}",
2789 };
2790 for (const string& shape_string : shape_strings) {
2791 StatusOr<Shape> result = ParseShape(shape_string);
2792 ASSERT_FALSE(result.ok()) << "shape: " << shape_string;
2793 }
2794 }
2795
2796 TEST_F(HloParserTest, ParseDynamicArray) {
2797 string shape_string = "f32[123,<=456]";
2798 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2799 Shape expected = ShapeUtil::MakeShape(F32, {123, 456}, {false, true});
2800 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2801 << "expected: " << ShapeUtil::HumanString(expected)
2802 << "actual: " << ShapeUtil::HumanString(actual);
2803 }
2804
2805 TEST_F(HloParserTest, ParseDynamicTuple) {
2806 string shape_string = "(f32[42], u32[<=123,<=456])";
2807 TF_ASSERT_OK_AND_ASSIGN(Shape actual, ParseShape(shape_string));
2808 Shape expected = ShapeUtil::MakeTupleShape(
2809 {ShapeUtil::MakeShape(F32, {42}),
2810 ShapeUtil::MakeShape(U32, {123, 456}, {true, true})});
2811 ASSERT_TRUE(ShapeUtil::Equal(expected, actual))
2812 << "expected: " << ShapeUtil::HumanString(expected)
2813 << "actual: " << ShapeUtil::HumanString(actual);
2814 }
2815
2816 TEST_F(HloParserTest, NegativeParameterNumber) {
2817 const string hlo_string = "par0 = f32[3,5] parameter(-1)";
2818 auto result = ParseHloString(hlo_string);
2819 ASSERT_FALSE(result.status().ok());
2820 EXPECT_THAT(result.status().error_message(),
2821 ::testing::HasSubstr("parameter number must be >= 0"));
2822 }
2823
2824 TEST_F(HloParserTest, WrongNumberOfParameterLeafBuffersInReplication) {
2825 const string hlo_string =
2826 "par0 = (f32[3,5], f32[]) parameter(0), "
2827 "parameter_replication={true,false,true}";
2828 auto result = ParseHloString(hlo_string);
2829 ASSERT_FALSE(result.status().ok());
2830 EXPECT_THAT(result.status().error_message(),
2831 ::testing::HasSubstr("parameter has 2 leaf buffers, but "
2832 "parameter_replication has 3 elements"));
2833 }
2834
2835 } // namespace
2836 } // namespace xla
2837