1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
17
18 #include "tensorflow/compiler/xla/literal_util.h"
19 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/util.h"
22 #include "tensorflow/core/lib/gtl/map_util.h"
23 #include "tensorflow/core/lib/strings/strcat.h"
24 #include "tensorflow/core/lib/strings/stringprintf.h"
25
26 namespace xla {
27 namespace tools {
28
29 namespace {
30
31 using tensorflow::StringPiece;
32 using tensorflow::gtl::optional;
33 using tensorflow::str_util::Split;
34 using tensorflow::str_util::SplitAndParseAsInts;
35 using tensorflow::strings::Printf;
36 using tensorflow::strings::StrAppend;
37 using tensorflow::strings::StrCat;
38
39 const double kF16max = 65504;
40
// Parser for the HloModule::ToString() format text.
//
// Typical usage:
//   HloParser parser(text, config);
//   if (!parser.Run()) { LOG(ERROR) << parser.GetError(); }
//   auto module = parser.ConsumeHloModule();
class HloParser {
 public:
  using LocTy = HloLexer::LocTy;

  explicit HloParser(StringPiece str, const HloModuleConfig& config)
      : lexer_(str), config_(config) {}

  // Runs the parser. Returns false if an error occurred.
  bool Run();

  // Returns the parsed HloModule. Only meaningful after a successful Run();
  // transfers ownership, so call at most once.
  std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }

  // Returns the accumulated error information, one message per line group.
  string GetError() const { return tensorflow::str_util::Join(error_, "\n"); }

 private:
  // ParseXXX returns false if an error occurred.
  bool ParseHloModule();
  bool ParseComputations();
  bool ParseComputation(HloComputation** entry_computation);
  bool ParseInstructionList(HloComputation::Builder* builder,
                            string* root_name);
  bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
  bool ParseControlPredecessors(HloInstruction* instruction);
  bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
  bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
  bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
                            const Shape& shape);
  bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
  bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
                          const Shape& shape);
  template <typename LiteralNativeT>
  bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
                                const Shape& shape);

  // Sets the sub-value of literal at the given index to the given value. The
  // literal's shape must have the default layout.
  bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal);
  bool SetValueInLiteral(double value, int64 linear_index, Literal* literal);
  bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal);
  template <typename LiteralNativeT, typename ParsedElemT>
  bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index,
                               Literal* literal);

  bool ParseOperands(std::vector<HloInstruction*>* operands);
  // Fills parsed operands into 'operands' and expects a certain number of
  // operands.
  bool ParseOperands(std::vector<HloInstruction*>* operands,
                     const int expected_size);

  // Describes the start, limit, and stride on every dimension of the operand
  // being sliced.
  struct SliceRanges {
    std::vector<int64> starts;
    std::vector<int64> limits;
    std::vector<int64> strides;
  };

  // Types of attributes.
  enum class AttrTy {
    kInt64,
    kInt32,
    kFloat,
    kString,
    kBracedInt64List,
    kHloComputation,
    kFftType,
    kWindow,
    kConvolutionDimensionNumbers,
    kSharding,
    kInstructionList,
    kSliceRanges,
    kPaddingConfig,
    kMetadata,
    kFusionKind,
    kDistribution,
  };

  struct AttrConfig {
    bool required;     // whether it's required or optional
    AttrTy attr_type;  // what type it is
    void* result;      // where to store the parsed result.
  };

  // attributes ::= (',' attribute)*
  //
  // Parses attributes given names and configs of the attributes. Each parsed
  // result is passed back through the result pointer in corresponding
  // AttrConfig. Note that the result pointer must point to an optional<T>
  // typed variable which outlives this function. Returns false on error. You
  // should not use any of the results if this function failed.
  //
  // Example usage:
  //
  //  std::unordered_map<string, AttrConfig> attrs;
  //  optional<int64> foo;
  //  attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
  //  optional<Window> bar;
  //  attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
  //  if (!ParseAttributes(attrs)) {
  //    return false; // Do not use 'foo' 'bar' if failed.
  //  }
  //  // Do something with 'bar'.
  //  if (foo) { // If attr foo is seen, do something with 'foo'. }
  //
  bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs);

  // sub_attributes ::= '{' (','? attribute)* '}'
  //
  // Usage is the same as ParseAttributes. See immediately above.
  bool ParseSubAttributes(const std::unordered_map<string, AttrConfig>& attrs);

  // Parses one attribute. If it has already been seen, return error. Returns
  // true and adds to seen_attrs on success.
  //
  // Do not call this except in ParseAttributes or ParseSubAttributes.
  bool ParseAttributeHelper(const std::unordered_map<string, AttrConfig>& attrs,
                            std::unordered_set<string>* seen_attrs);

  // Parses a name and finds the corresponding hlo computation.
  bool ParseComputationName(HloComputation** value);
  // Parses a list of names and finds the corresponding hlo instructions.
  bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
  bool ParseWindow(Window* window);
  bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
  bool ParsePaddingConfig(PaddingConfig* padding);
  bool ParseMetadata(OpMetadata* metadata);
  bool ParseSharding(OpSharding* sharding);
  bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);

  // Parses a sub-attribute of the window attribute, e.g., size=1x2x3.
  bool ParseDxD(const string& name, std::vector<int64>* result);
  // Parses window's pad sub-attribute, e.g., pad=0_0x3x3.
  bool ParseWindowPad(std::vector<std::vector<int64>>* pad);

  bool ParseSliceRanges(SliceRanges* result);
  bool ParseInt64List(const TokKind start, const TokKind end,
                      const TokKind delim, std::vector<int64>* result);

  bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
  bool ParseParamList();
  bool ParseName(string* result);
  bool ParseAttributeName(string* result);
  bool ParseString(string* result);
  bool ParseShape(Shape* result);
  bool ParseOpcode(HloOpcode* result);
  bool ParseFftType(FftType* result);
  bool ParseFusionKind(HloInstruction::FusionKind* result);
  bool ParseRandomDistribution(RandomDistribution* result);
  bool ParseInt64(int64* result);
  bool ParseDouble(double* result);
  bool ParseBool(bool* result);
  bool ParseToken(TokKind kind, const string& msg);

  // Returns true if the current token is the beginning of a shape.
  bool CanBeShape();
  // Returns true if the current token is the beginning of a
  // param_list_to_shape.
  bool CanBeParamListToShape();

  // Logs the current parsing line and the given message. Always returns false.
  bool TokenError(StringPiece msg);
  bool Error(LocTy loc, StringPiece msg);

  // If the current token is 'kind', eats it (i.e. lexes the next token) and
  // returns true.
  bool EatIfPresent(TokKind kind);
  // Parses a shape, and returns true if the result is compatible with the given
  // shape.
  bool EatShapeAndCheckCompatible(const Shape& shape);

  // Adds the instruction to the pool. Returns false and emits an error if the
  // instruction already exists.
  bool AddInstruction(const string& name, HloInstruction* instruction,
                      LocTy name_loc);
  // Adds the computation to the pool. Returns false and emits an error if the
  // computation already exists.
  bool AddComputation(const string& name, HloComputation* computation,
                      LocTy name_loc);

  // The map from the instruction/computation name to the
  // instruction/computation itself and its location. This does not own the
  // pointers.
  std::unordered_map<string, std::pair<HloInstruction*, LocTy>>
      instruction_pool_;
  std::unordered_map<string, std::pair<HloComputation*, LocTy>>
      computation_pool_;

  HloLexer lexer_;
  std::unique_ptr<HloModule> module_;
  std::vector<std::unique_ptr<HloComputation>> computations_;
  const HloModuleConfig config_;
  std::vector<string> error_;
};
237
Error(LocTy loc,StringPiece msg)238 bool HloParser::Error(LocTy loc, StringPiece msg) {
239 auto line_col = lexer_.GetLineAndColumn(loc);
240 const unsigned line = line_col.first;
241 const unsigned col = line_col.second;
242 std::vector<string> error_lines;
243 error_lines.push_back(
244 StrCat("was parsing ", line, ":", col, ": error: ", msg));
245 error_lines.push_back(lexer_.GetLine(loc).ToString());
246 error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^"));
247
248 error_.push_back(tensorflow::str_util::Join(error_lines, "\n"));
249 VLOG(1) << "Error: " << error_.back();
250 return false;
251 }
252
TokenError(StringPiece msg)253 bool HloParser::TokenError(StringPiece msg) {
254 return Error(lexer_.GetLoc(), msg);
255 }
256
Run()257 bool HloParser::Run() {
258 lexer_.Lex();
259 return ParseHloModule();
260 }
261
262 // ::= 'HloModule' name computations
ParseHloModule()263 bool HloParser::ParseHloModule() {
264 if (lexer_.GetKind() != TokKind::kw_HloModule) {
265 return TokenError("expects HloModule");
266 }
267 // Eat 'HloModule'
268 lexer_.Lex();
269
270 string name;
271 if (!ParseName(&name)) {
272 return false;
273 }
274
275 module_ = MakeUnique<HloModule>(name, config_);
276
277 return ParseComputations();
278 }
279
280 // computations ::= (computation)+
ParseComputations()281 bool HloParser::ParseComputations() {
282 HloComputation* entry_computation = nullptr;
283 do {
284 if (!ParseComputation(&entry_computation)) {
285 return false;
286 }
287 } while (lexer_.GetKind() != TokKind::kEof);
288
289 for (int i = 0; i < computations_.size(); i++) {
290 // If entry_computation is not nullptr, it means the computation it pointed
291 // to is marked with "ENTRY"; otherwise, no computation is marked with
292 // "ENTRY", and we use the last computation as the entry computation. We
293 // add the non-entry computations as embedded computations to the module.
294 if ((entry_computation != nullptr &&
295 computations_[i].get() != entry_computation) ||
296 (entry_computation == nullptr && i != computations_.size() - 1)) {
297 module_->AddEmbeddedComputation(std::move(computations_[i]));
298 continue;
299 }
300 auto computation =
301 module_->AddEntryComputation(std::move(computations_[i]));
302 // The parameters and result layouts were set to default layout. Here we
303 // set the layouts to what the hlo text says.
304 for (int p = 0; p < computation->num_parameters(); p++) {
305 const Shape& param_shape = computation->parameter_instruction(p)->shape();
306 if (param_shape.has_layout()) {
307 module_->mutable_entry_computation_layout()
308 ->mutable_parameter_layout(p)
309 ->ResetLayout(param_shape.layout());
310 }
311 }
312 const Shape& result_shape = computation->root_instruction()->shape();
313 if (result_shape.has_layout()) {
314 module_->mutable_entry_computation_layout()
315 ->mutable_result_layout()
316 ->ResetLayout(result_shape.layout());
317 }
318 }
319
320 return true;
321 }
322
// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
bool HloParser::ParseComputation(HloComputation** entry_computation) {
  // Remember where a possible 'ENTRY' keyword starts so errors point at it.
  LocTy maybe_entry_loc = lexer_.GetLoc();
  const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);

  string name;
  LocTy name_loc = lexer_.GetLoc();
  if (!ParseName(&name)) {
    return false;
  }
  auto builder = MakeUnique<HloComputation::Builder>(name);

  // Optional declared signature; shape_loc stays nullptr if it is absent.
  LocTy shape_loc = nullptr;
  Shape shape;
  if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
    return false;
  }

  // root_name is filled in by ParseInstructionList when an instruction is
  // marked 'ROOT'; it stays empty if none is.
  string root_name;
  if (!ParseInstructionList(builder.get(), &root_name)) {
    return false;
  }

  std::pair<HloInstruction*, LocTy>* root_node =
      tensorflow::gtl::FindOrNull(instruction_pool_, root_name);
  // This means some instruction was marked as ROOT but we didn't find it in the
  // pool, which should not happen.
  if (!root_name.empty() && root_node == nullptr) {
    LOG(FATAL) << "instruction " << root_name
               << " was marked as ROOT but the parser has not seen it before";
  }

  HloInstruction* root = root_node == nullptr ? nullptr : root_node->first;
  // Now root can be either an existing instruction or a nullptr. If it's a
  // nullptr, the implementation of Builder will set the last instruction as
  // root instruction.
  computations_.emplace_back(builder->Build(root));
  HloComputation* computation = computations_.back().get();

  if (!root) {
    root = computation->root_instruction();
  } else {
    // Sanity check: Builder must have honored the root we passed in.
    CHECK_EQ(root, computation->root_instruction());
  }

  // If param_list_to_shape was present, check that the declared result shape
  // is compatible with the root instruction's actual shape.
  if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) {
    return Error(
        shape_loc,
        StrCat("Shape of computation ", name, ", ",
               ShapeUtil::HumanString(shape),
               ", is not compatible with that of its root instruction ",
               root_name, ", ", ShapeUtil::HumanString(root->shape())));
  }

  if (is_entry_computation) {
    // At most one computation may be marked ENTRY.
    if (*entry_computation != nullptr) {
      return Error(maybe_entry_loc, "expects only one ENTRY");
    }
    *entry_computation = computation;
  }

  return AddComputation(name, computation, name_loc);
}
387
388 // instruction_list ::= '{' instruction_list1 '}'
389 // instruction_list1 ::= (instruction)+
ParseInstructionList(HloComputation::Builder * builder,string * root_name)390 bool HloParser::ParseInstructionList(HloComputation::Builder* builder,
391 string* root_name) {
392 if (!ParseToken(TokKind::kLbrace,
393 "expects '{' at the beginning of instruction list.")) {
394 return false;
395 }
396 do {
397 if (!ParseInstruction(builder, root_name)) {
398 return false;
399 }
400 } while (lexer_.GetKind() != TokKind::kRbrace);
401 return ParseToken(TokKind::kRbrace,
402 "expects '}' at the end of instruction list.");
403 }
404
405 // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
ParseInstruction(HloComputation::Builder * builder,string * root_name)406 bool HloParser::ParseInstruction(HloComputation::Builder* builder,
407 string* root_name) {
408 string name;
409 Shape shape;
410 HloOpcode opcode;
411 std::vector<HloInstruction*> operands;
412
413 LocTy maybe_root_loc = lexer_.GetLoc();
414 bool is_root = EatIfPresent(TokKind::kw_ROOT);
415
416 const LocTy name_loc = lexer_.GetLoc();
417 if (!ParseName(&name) ||
418 !ParseToken(TokKind::kEqual, "expects '=' in instruction") ||
419 !ParseShape(&shape) || !ParseOpcode(&opcode)) {
420 return false;
421 }
422
423 if (is_root) {
424 if (!root_name->empty()) {
425 return Error(maybe_root_loc, "one computation should have only one ROOT");
426 }
427 *root_name = name;
428 }
429
430 // Add optional attributes.
431 std::unordered_map<string, AttrConfig> attrs;
432 optional<OpSharding> sharding;
433 attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
434 optional<std::vector<HloInstruction*>> predecessors;
435 attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
436 &predecessors};
437 optional<OpMetadata> metadata;
438 attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
439
440 HloInstruction* instruction;
441 switch (opcode) {
442 case HloOpcode::kParameter: {
443 int64 parameter_number;
444 if (!ParseToken(TokKind::kLparen,
445 "expects '(' before parameter number") ||
446 !ParseInt64(¶meter_number) ||
447 !ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
448 !ParseAttributes(attrs)) {
449 return false;
450 }
451 instruction = builder->AddInstruction(
452 HloInstruction::CreateParameter(parameter_number, shape, name));
453 break;
454 }
455 case HloOpcode::kConstant: {
456 std::unique_ptr<Literal> literal;
457 if (!ParseToken(TokKind::kLparen,
458 "expects '(' before constant literal") ||
459 !ParseLiteral(&literal, shape) ||
460 !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
461 !ParseAttributes(attrs)) {
462 return false;
463 }
464 instruction = builder->AddInstruction(
465 HloInstruction::CreateConstant(std::move(literal)));
466 break;
467 }
468 // Unary ops.
469 case HloOpcode::kAbs:
470 case HloOpcode::kRoundNearestAfz:
471 case HloOpcode::kBitcast:
472 case HloOpcode::kCeil:
473 case HloOpcode::kCopy:
474 case HloOpcode::kCos:
475 case HloOpcode::kExp:
476 case HloOpcode::kImag:
477 case HloOpcode::kIsFinite:
478 case HloOpcode::kFloor:
479 case HloOpcode::kLog:
480 case HloOpcode::kNot:
481 case HloOpcode::kNegate:
482 case HloOpcode::kReal:
483 case HloOpcode::kSign:
484 case HloOpcode::kSin:
485 case HloOpcode::kSort:
486 case HloOpcode::kTanh: {
487 if (!ParseOperands(&operands, /*expected_size=*/1) ||
488 !ParseAttributes(attrs)) {
489 return false;
490 }
491 instruction = builder->AddInstruction(
492 HloInstruction::CreateUnary(shape, opcode, operands[0]));
493 break;
494 }
495 // Binary ops.
496 case HloOpcode::kAdd:
497 case HloOpcode::kDivide:
498 case HloOpcode::kMultiply:
499 case HloOpcode::kSubtract:
500 case HloOpcode::kAtan2:
501 case HloOpcode::kComplex:
502 case HloOpcode::kEq:
503 case HloOpcode::kGe:
504 case HloOpcode::kGt:
505 case HloOpcode::kLe:
506 case HloOpcode::kLt:
507 case HloOpcode::kNe:
508 case HloOpcode::kMaximum:
509 case HloOpcode::kMinimum:
510 case HloOpcode::kPower:
511 case HloOpcode::kRemainder:
512 case HloOpcode::kAnd:
513 case HloOpcode::kOr:
514 case HloOpcode::kShiftLeft:
515 case HloOpcode::kShiftRightArithmetic:
516 case HloOpcode::kShiftRightLogical: {
517 if (!ParseOperands(&operands, /*expected_size=*/2) ||
518 !ParseAttributes(attrs)) {
519 return false;
520 }
521 instruction = builder->AddInstruction(HloInstruction::CreateBinary(
522 shape, opcode, operands[0], operands[1]));
523 break;
524 }
525 // Ternary ops.
526 case HloOpcode::kClamp:
527 case HloOpcode::kSelect: {
528 if (!ParseOperands(&operands, /*expected_size=*/3) ||
529 !ParseAttributes(attrs)) {
530 return false;
531 }
532 instruction = builder->AddInstruction(HloInstruction::CreateTernary(
533 shape, opcode, operands[0], operands[1], operands[2]));
534 break;
535 }
536 // Other supported ops.
537 case HloOpcode::kConvert: {
538 if (!ParseOperands(&operands, /*expected_size=*/1) ||
539 !ParseAttributes(attrs)) {
540 return false;
541 }
542 instruction = builder->AddInstruction(
543 HloInstruction::CreateConvert(shape, operands[0]));
544 break;
545 }
546 case HloOpcode::kBitcastConvert: {
547 if (!ParseOperands(&operands, /*expected_size=*/1) ||
548 !ParseAttributes(attrs)) {
549 return false;
550 }
551 instruction = builder->AddInstruction(
552 HloInstruction::CreateBitcastConvert(shape, operands[0]));
553 break;
554 }
555 case HloOpcode::kCrossReplicaSum: {
556 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
557 return false;
558 }
559 instruction = builder->AddInstruction(
560 HloInstruction::CreateCrossReplicaSum(shape, operands));
561 break;
562 }
563 case HloOpcode::kReshape: {
564 if (!ParseOperands(&operands, /*expected_size=*/1) ||
565 !ParseAttributes(attrs)) {
566 return false;
567 }
568 instruction = builder->AddInstruction(
569 HloInstruction::CreateReshape(shape, operands[0]));
570 break;
571 }
572 case HloOpcode::kTuple: {
573 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
574 return false;
575 }
576 instruction =
577 builder->AddInstruction(HloInstruction::CreateTuple(operands));
578 break;
579 }
580 case HloOpcode::kWhile: {
581 optional<HloComputation*> condition;
582 optional<HloComputation*> body;
583 attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
584 &condition};
585 attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
586 if (!ParseOperands(&operands, /*expected_size=*/1) ||
587 !ParseAttributes(attrs)) {
588 return false;
589 }
590 instruction = builder->AddInstruction(HloInstruction::CreateWhile(
591 shape, *condition, *body, /*init=*/operands[0]));
592 break;
593 }
594 case HloOpcode::kRecv: {
595 optional<int64> channel_id;
596 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
597 if (!ParseOperands(&operands, /*expected_size=*/0) ||
598 !ParseAttributes(attrs)) {
599 return false;
600 }
601 instruction = builder->AddInstruction(
602 HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id));
603 break;
604 }
605 case HloOpcode::kRecvDone: {
606 optional<int64> channel_id;
607 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
608 if (!ParseOperands(&operands, /*expected_size=*/1) ||
609 !ParseAttributes(attrs)) {
610 return false;
611 }
612 if (channel_id != operands[0]->channel_id()) {
613 return false;
614 }
615 instruction =
616 builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0]));
617 break;
618 }
619 case HloOpcode::kSend: {
620 optional<int64> channel_id;
621 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
622 if (!ParseOperands(&operands, /*expected_size=*/1) ||
623 !ParseAttributes(attrs)) {
624 return false;
625 }
626 instruction = builder->AddInstruction(
627 HloInstruction::CreateSend(operands[0], *channel_id));
628 break;
629 }
630 case HloOpcode::kSendDone: {
631 optional<int64> channel_id;
632 attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
633 if (!ParseOperands(&operands, /*expected_size=*/1) ||
634 !ParseAttributes(attrs)) {
635 return false;
636 }
637 if (channel_id != operands[0]->channel_id()) {
638 return false;
639 }
640 instruction =
641 builder->AddInstruction(HloInstruction::CreateSendDone(operands[0]));
642 break;
643 }
644 case HloOpcode::kGetTupleElement: {
645 optional<int64> index;
646 attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
647 if (!ParseOperands(&operands, /*expected_size=*/1) ||
648 !ParseAttributes(attrs)) {
649 return false;
650 }
651 instruction = builder->AddInstruction(
652 HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
653 break;
654 }
655 case HloOpcode::kCall: {
656 optional<HloComputation*> to_apply;
657 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
658 &to_apply};
659 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
660 return false;
661 }
662 instruction = builder->AddInstruction(
663 HloInstruction::CreateCall(shape, operands, *to_apply));
664 break;
665 }
666 case HloOpcode::kReduceWindow: {
667 optional<HloComputation*> reduce_computation;
668 optional<Window> window;
669 attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
670 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
671 &reduce_computation};
672 if (!ParseOperands(&operands, /*expected_size=*/2) ||
673 !ParseAttributes(attrs)) {
674 return false;
675 }
676 if (!window) {
677 window.emplace();
678 }
679 instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
680 shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
681 *reduce_computation));
682 break;
683 }
684 case HloOpcode::kConvolution: {
685 optional<Window> window;
686 optional<ConvolutionDimensionNumbers> dnums;
687 attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
688 attrs["dim_labels"] = {/*required=*/true,
689 AttrTy::kConvolutionDimensionNumbers, &dnums};
690 if (!ParseOperands(&operands, /*expected_size=*/2) ||
691 !ParseAttributes(attrs)) {
692 return false;
693 }
694 if (!window) {
695 window.emplace();
696 }
697 instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
698 shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums));
699 break;
700 }
701 case HloOpcode::kFft: {
702 optional<FftType> fft_type;
703 optional<std::vector<int64>> fft_length;
704 attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
705 attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
706 &fft_length};
707 if (!ParseOperands(&operands, /*expected_size=*/1) ||
708 !ParseAttributes(attrs)) {
709 return false;
710 }
711 instruction = builder->AddInstruction(HloInstruction::CreateFft(
712 shape, operands[0], *fft_type, *fft_length));
713 break;
714 }
715 case HloOpcode::kBroadcast: {
716 optional<std::vector<int64>> broadcast_dimensions;
717 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
718 &broadcast_dimensions};
719 if (!ParseOperands(&operands, /*expected_size=*/1) ||
720 !ParseAttributes(attrs)) {
721 return false;
722 }
723 instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
724 shape, operands[0], *broadcast_dimensions));
725 break;
726 }
727 case HloOpcode::kConcatenate: {
728 optional<std::vector<int64>> dimensions;
729 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
730 &dimensions};
731 if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
732 dimensions->size() != 1) {
733 return false;
734 }
735 instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
736 shape, operands, dimensions->at(0)));
737 break;
738 }
739 case HloOpcode::kMap: {
740 optional<HloComputation*> to_apply;
741 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
742 &to_apply};
743 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
744 return false;
745 }
746 instruction = builder->AddInstruction(
747 HloInstruction::CreateMap(shape, operands, *to_apply));
748 break;
749 }
750 case HloOpcode::kReduce: {
751 optional<HloComputation*> reduce_computation;
752 attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
753 &reduce_computation};
754 optional<std::vector<int64>> dimensions_to_reduce;
755 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
756 &dimensions_to_reduce};
757 if (!ParseOperands(&operands, /*expected_size=*/2) ||
758 !ParseAttributes(attrs)) {
759 return false;
760 }
761 instruction = builder->AddInstruction(HloInstruction::CreateReduce(
762 shape, /*operand=*/operands[0], /*init_value=*/operands[1],
763 *dimensions_to_reduce, *reduce_computation));
764 break;
765 }
766 case HloOpcode::kReverse: {
767 optional<std::vector<int64>> dimensions;
768 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
769 &dimensions};
770 if (!ParseOperands(&operands, /*expected_size=*/1) ||
771 !ParseAttributes(attrs)) {
772 return false;
773 }
774 instruction = builder->AddInstruction(
775 HloInstruction::CreateReverse(shape, operands[0], *dimensions));
776 break;
777 }
778 case HloOpcode::kSelectAndScatter: {
779 optional<HloComputation*> select;
780 attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
781 optional<HloComputation*> scatter;
782 attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
783 optional<Window> window;
784 attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
785 if (!ParseOperands(&operands, /*expected_size=*/3) ||
786 !ParseAttributes(attrs)) {
787 return false;
788 }
789 if (!window) {
790 window.emplace();
791 }
792 instruction =
793 builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
794 shape, /*operand=*/operands[0], *select, *window,
795 /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
796 break;
797 }
798 case HloOpcode::kSlice: {
799 optional<SliceRanges> slice_ranges;
800 attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
801 if (!ParseOperands(&operands, /*expected_size=*/1) ||
802 !ParseAttributes(attrs)) {
803 return false;
804 }
805 instruction = builder->AddInstruction(HloInstruction::CreateSlice(
806 shape, operands[0], slice_ranges->starts, slice_ranges->limits,
807 slice_ranges->strides));
808 break;
809 }
810 case HloOpcode::kDynamicSlice: {
811 optional<std::vector<int64>> dynamic_slice_sizes;
812 attrs["dynamic_slice_sizes"] = {
813 /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
814 if (!ParseOperands(&operands, /*expected_size=*/2) ||
815 !ParseAttributes(attrs)) {
816 return false;
817 }
818 instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
819 shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
820 *dynamic_slice_sizes));
821 break;
822 }
823 case HloOpcode::kDynamicUpdateSlice: {
824 if (!ParseOperands(&operands, /*expected_size=*/3) ||
825 !ParseAttributes(attrs)) {
826 return false;
827 }
828 instruction =
829 builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
830 shape, /*operand=*/operands[0], /*update=*/operands[1],
831 /*start_indices=*/operands[2]));
832 break;
833 }
834 case HloOpcode::kTranspose: {
835 optional<std::vector<int64>> dimensions;
836 attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
837 &dimensions};
838 if (!ParseOperands(&operands, /*expected_size=*/1) ||
839 !ParseAttributes(attrs)) {
840 return false;
841 }
842 instruction = builder->AddInstruction(
843 HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
844 break;
845 }
846 case HloOpcode::kBatchNormTraining: {
847 optional<float> epsilon;
848 attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
849 optional<int64> feature_index;
850 attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
851 &feature_index};
852 if (!ParseOperands(&operands, /*expected_size=*/3) ||
853 !ParseAttributes(attrs)) {
854 return false;
855 }
856 instruction =
857 builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
858 shape, /*operand=*/operands[0], /*scale=*/operands[1],
859 /*offset=*/operands[2], *epsilon, *feature_index));
860 break;
861 }
862 case HloOpcode::kBatchNormInference: {
863 optional<float> epsilon;
864 attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
865 optional<int64> feature_index;
866 attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
867 &feature_index};
868 if (!ParseOperands(&operands, /*expected_size=*/5) ||
869 !ParseAttributes(attrs)) {
870 return false;
871 }
872 instruction =
873 builder->AddInstruction(HloInstruction::CreateBatchNormInference(
874 shape, /*operand=*/operands[0], /*scale=*/operands[1],
875 /*offset=*/operands[2], /*mean=*/operands[3],
876 /*variance=*/operands[4], *epsilon, *feature_index));
877 break;
878 }
879 case HloOpcode::kBatchNormGrad: {
880 optional<float> epsilon;
881 attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
882 optional<int64> feature_index;
883 attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
884 &feature_index};
885 if (!ParseOperands(&operands, /*expected_size=*/5) ||
886 !ParseAttributes(attrs)) {
887 return false;
888 }
889 instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
890 shape, /*operand=*/operands[0], /*scale=*/operands[1],
891 /*mean=*/operands[2], /*variance=*/operands[3],
892 /*grad_output=*/operands[4], *epsilon, *feature_index));
893 break;
894 }
895 case HloOpcode::kPad: {
896 optional<PaddingConfig> padding;
897 attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
898 if (!ParseOperands(&operands, /*expected_size=*/2) ||
899 !ParseAttributes(attrs)) {
900 return false;
901 }
902 instruction = builder->AddInstruction(HloInstruction::CreatePad(
903 shape, operands[0], /*padding_value=*/operands[1], *padding));
904 break;
905 }
906 case HloOpcode::kFusion: {
907 optional<HloComputation*> fusion_computation;
908 attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
909 &fusion_computation};
910 optional<HloInstruction::FusionKind> fusion_kind;
911 attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
912 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
913 return false;
914 }
915 instruction = builder->AddInstruction(HloInstruction::CreateFusion(
916 shape, *fusion_kind, operands, *fusion_computation));
917 break;
918 }
919 case HloOpcode::kInfeed: {
920 optional<string> config;
921 attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
922 if (!ParseOperands(&operands, /*expected_size=*/0) ||
923 !ParseAttributes(attrs)) {
924 return false;
925 }
926 instruction = builder->AddInstruction(
927 HloInstruction::CreateInfeed(shape, config ? *config : ""));
928 break;
929 }
930 case HloOpcode::kOutfeed: {
931 optional<string> config;
932 attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
933 if (!ParseOperands(&operands, /*expected_size=*/1) ||
934 !ParseAttributes(attrs)) {
935 return false;
936 }
937 instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
938 operands[0]->shape(), operands[0], config ? *config : ""));
939 break;
940 }
941 case HloOpcode::kRng: {
942 optional<RandomDistribution> distribution;
943 attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
944 &distribution};
945 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
946 return false;
947 }
948 instruction = builder->AddInstruction(
949 HloInstruction::CreateRng(shape, *distribution, operands));
950 break;
951 }
952 case HloOpcode::kReducePrecision: {
953 optional<int64> exponent_bits;
954 optional<int64> mantissa_bits;
955 attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
956 &exponent_bits};
957 attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
958 &mantissa_bits};
959 if (!ParseOperands(&operands, /*expected_size=*/1) ||
960 !ParseAttributes(attrs)) {
961 return false;
962 }
963 instruction =
964 builder->AddInstruction(HloInstruction::CreateReducePrecision(
965 shape, operands[0], static_cast<int>(*exponent_bits),
966 static_cast<int>(*mantissa_bits)));
967 break;
968 }
969 case HloOpcode::kConditional: {
970 optional<HloComputation*> true_computation;
971 optional<HloComputation*> false_computation;
972 attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
973 &true_computation};
974 attrs["false_computation"] = {/*required=*/true, AttrTy::kHloComputation,
975 &false_computation};
976 if (!ParseOperands(&operands, /*expected_size=*/3) ||
977 !ParseAttributes(attrs)) {
978 return false;
979 }
980 instruction = builder->AddInstruction(HloInstruction::CreateConditional(
981 shape, /*pred=*/operands[0],
982 /*true_computation_arg=*/operands[1], *true_computation,
983 /*false_computation_arg=*/operands[2], *false_computation));
984 break;
985 }
986 case HloOpcode::kCustomCall: {
987 optional<string> custom_call_target;
988 attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
989 &custom_call_target};
990 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
991 return false;
992 }
993 instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
994 shape, operands, *custom_call_target));
995 break;
996 }
997 case HloOpcode::kHostCompute: {
998 optional<string> channel_name;
999 optional<int64> cost_estimate_ns;
1000 attrs["channel_name"] = {/*required=*/true, AttrTy::kString,
1001 &channel_name};
1002 attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64,
1003 &cost_estimate_ns};
1004 if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1005 return false;
1006 }
1007 instruction = builder->AddInstruction(HloInstruction::CreateHostCompute(
1008 shape, operands, *channel_name, *cost_estimate_ns));
1009 break;
1010 }
1011 case HloOpcode::kDot: {
1012 optional<std::vector<int64>> lhs_contracting_dims;
1013 attrs["lhs_contracting_dims"] = {
1014 /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
1015 optional<std::vector<int64>> rhs_contracting_dims;
1016 attrs["rhs_contracting_dims"] = {
1017 /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
1018 optional<std::vector<int64>> lhs_batch_dims;
1019 attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
1020 &lhs_batch_dims};
1021 optional<std::vector<int64>> rhs_batch_dims;
1022 attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
1023 &rhs_batch_dims};
1024
1025 if (!ParseOperands(&operands, /*expected_size=*/2) ||
1026 !ParseAttributes(attrs)) {
1027 return false;
1028 }
1029
1030 DotDimensionNumbers dnum;
1031 if (lhs_contracting_dims) {
1032 *dnum.mutable_lhs_contracting_dimensions() = {
1033 lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
1034 }
1035 if (rhs_contracting_dims) {
1036 *dnum.mutable_rhs_contracting_dimensions() = {
1037 rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
1038 }
1039 if (lhs_batch_dims) {
1040 *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
1041 lhs_batch_dims->end()};
1042 }
1043 if (rhs_batch_dims) {
1044 *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
1045 rhs_batch_dims->end()};
1046 }
1047
1048 instruction = builder->AddInstruction(
1049 HloInstruction::CreateDot(shape, operands[0], operands[1], dnum));
1050 break;
1051 }
1052 case HloOpcode::kGather:
1053 // TODO(b/72710576): HLO parsing is not implemented for Gather.
1054 return TokenError("HLO parsing is not implemented for Gather");
1055 case HloOpcode::kTrace:
1056 return TokenError(StrCat("parsing not yet implemented for op: ",
1057 HloOpcodeString(opcode)));
1058 }
1059
1060 instruction->set_name(name);
1061
1062 // Add common attrs (sharding, control predecessors) to the instruction, if
1063 // they were seen.
1064 if (sharding) {
1065 instruction->set_sharding(
1066 HloSharding::FromProto(sharding.value()).ValueOrDie());
1067 }
1068 if (predecessors) {
1069 for (auto* pre : *predecessors) {
1070 Status status = pre->AddControlDependencyTo(instruction);
1071 if (!status.ok()) {
1072 return Error(name_loc, StrCat("error adding control dependency for: ",
1073 name, " status: ", status.ToString()));
1074 }
1075 }
1076 }
1077 if (metadata) {
1078 instruction->set_metadata(*metadata);
1079 }
1080 return AddInstruction(name, instruction, name_loc);
1081 } // NOLINT(readability/fn_size)
1082
1083 // ::= '{' (single_sharding | tuple_sharding) '}'
1084 //
1085 // tuple_sharding ::= single_sharding* (',' single_sharding)*
ParseSharding(OpSharding * sharding)1086 bool HloParser::ParseSharding(OpSharding* sharding) {
1087 // A single sharding starts with '{' and is not followed by '{'.
1088 // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
1089 // an empty tuple.
1090 if (!ParseToken(TokKind::kLbrace,
1091 "expected '{' to start sharding attribute")) {
1092 return false;
1093 }
1094
1095 if (lexer_.GetKind() != TokKind::kLbrace &&
1096 lexer_.GetKind() != TokKind::kRbrace) {
1097 return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
1098 }
1099
1100 // Tuple sharding.
1101 // Allow empty tuple shardings.
1102 if (lexer_.GetKind() != TokKind::kRbrace) {
1103 do {
1104 if (!ParseSingleSharding(sharding->add_tuple_shardings(),
1105 /*lbrace_pre_lexed=*/false)) {
1106 return false;
1107 }
1108 } while (EatIfPresent(TokKind::kComma));
1109 }
1110 sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE);
1111
1112 return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
1113 }
1114
1115 // ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
1116 // ('devices=' ('[' dims ']')* device_list)? '}'
1117 // dims ::= int_list device_list ::= int_list
// Parses one non-tuple sharding into `sharding`. If `lbrace_pre_lexed` is
// true the caller has already consumed the opening '{'. Accumulates the
// attributes seen inside the braces, then validates the combination and
// fills in the proto as REPLICATED, MAXIMAL, or OTHER (tiled).
bool HloParser::ParseSingleSharding(OpSharding* sharding,
                                    bool lbrace_pre_lexed) {
  if (!lbrace_pre_lexed &&
      !ParseToken(TokKind::kLbrace,
                  "expected '{' to start sharding attribute")) {
    return false;
  }

  LocTy loc = lexer_.GetLoc();
  bool maximal = false;
  bool replicated = false;
  std::vector<int64> devices;
  std::vector<int64> tile_assignment_dimensions;
  Shape tile_shape;
  // Collect tokens until the closing '}'; validation happens afterwards so
  // the attributes may appear in any order.
  while (lexer_.GetKind() != TokKind::kRbrace) {
    switch (lexer_.GetKind()) {
      case TokKind::kw_maximal:
        maximal = true;
        lexer_.Lex();
        break;
      case TokKind::kw_replicated:
        replicated = true;
        lexer_.Lex();
        break;
      case TokKind::kAttributeName: {
        if (lexer_.GetStrVal() == "device") {
          // device=<int>: a single assigned device (maximal sharding).
          if (lexer_.Lex() != TokKind::kInt) {
            return TokenError("device= attribute must be an integer");
          }
          devices = {lexer_.GetInt64Val()};
          lexer_.Lex();
        } else if (lexer_.GetStrVal() == "devices") {
          // devices=[dims]d0,d1,...: tile assignment dimensions in square
          // brackets followed by the flattened device list.
          lexer_.Lex();
          if (!ParseToken(TokKind::kLsquare,
                          "expected '[' to start sharding devices shape")) {
            return false;
          }

          do {
            int64 dim;
            if (!ParseInt64(&dim)) {
              return false;
            }
            tile_assignment_dimensions.push_back(dim);
          } while (EatIfPresent(TokKind::kComma));

          if (!ParseToken(TokKind::kRsquare,
                          "expected ']' to start sharding devices shape")) {
            return false;
          }
          do {
            int64 device;
            if (!ParseInt64(&device)) {
              return false;
            }
            devices.push_back(device);
          } while (EatIfPresent(TokKind::kComma));
        } else {
          return TokenError(
              "unknown attribute in sharding: expected device= or devices=");
        }
        break;
      }
      case TokKind::kShape:
        // A bare shape token sets the tile shape for tiled shardings.
        tile_shape = lexer_.GetShapeVal();
        lexer_.Lex();
        break;
      case TokKind::kRbrace:
        // Loop exit is handled by the while condition.
        break;
      default:
        return TokenError("unexpected token");
    }
  }

  // Validate the collected attributes and populate the proto.
  if (replicated) {
    if (!devices.empty()) {
      return Error(loc,
                   "replicated shardings should not have any devices assigned");
    }
    if (!ShapeUtil::Equal(tile_shape, Shape())) {
      return Error(loc,
                   "replicated shardings should not have any tile shape set");
    }
    sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
  } else if (maximal) {
    if (devices.size() != 1) {
      return Error(loc,
                   "maximal shardings should have exactly one device assigned");
    }
    if (!ShapeUtil::Equal(tile_shape, Shape())) {
      return Error(loc, "maximal shardings should not have any tile shape set");
    }
    sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
    sharding->add_tile_assignment_devices(devices[0]);
  } else {
    // Tiled (OTHER) sharding: needs a tile shape, assignment dimensions, and
    // more than one device.
    if (devices.size() <= 1) {
      return Error(
          loc, "non-maximal shardings must have more than one device assigned");
    }
    if (ShapeUtil::Equal(tile_shape, Shape())) {
      return Error(loc, "non-maximal shardings should have a tile shape set");
    }
    if (tile_assignment_dimensions.empty()) {
      return Error(
          loc,
          "non-maximal shardings must have a tile assignment list including "
          "dimensions");
    }
    sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
    *sharding->mutable_tile_shape() = tile_shape;
    for (int64 dim : tile_assignment_dimensions) {
      sharding->add_tile_assignment_dimensions(dim);
    }
    for (int64 device : devices) {
      sharding->add_tile_assignment_devices(device);
    }
  }

  // Consume the closing '}'.
  lexer_.Lex();
  return true;
}
1239
1240 // '{' name+ '}'
ParseInstructionNames(std::vector<HloInstruction * > * instructions)1241 bool HloParser::ParseInstructionNames(
1242 std::vector<HloInstruction*>* instructions) {
1243 if (!ParseToken(TokKind::kLbrace,
1244 "expects '{' at the beginning of instruction name list")) {
1245 return false;
1246 }
1247 LocTy loc = lexer_.GetLoc();
1248 do {
1249 string name;
1250 if (!ParseName(&name)) {
1251 return Error(loc, "expects a instruction name");
1252 }
1253 std::pair<HloInstruction*, LocTy>* instr =
1254 tensorflow::gtl::FindOrNull(instruction_pool_, name);
1255 if (!instr) {
1256 return TokenError(
1257 Printf("instruction '%s' is not defined", name.c_str()));
1258 }
1259 instructions->push_back(instr->first);
1260 } while (EatIfPresent(TokKind::kComma));
1261
1262 return ParseToken(TokKind::kRbrace,
1263 "expects '}' at the end of instruction name list");
1264 }
1265
SetValueInLiteral(int64 value,int64 linear_index,Literal * literal)1266 bool HloParser::SetValueInLiteral(int64 value, int64 linear_index,
1267 Literal* literal) {
1268 const Shape& shape = literal->shape();
1269 switch (shape.element_type()) {
1270 case S8:
1271 return SetValueInLiteralHelper<int8>(value, linear_index, literal);
1272 case S16:
1273 return SetValueInLiteralHelper<int16>(value, linear_index, literal);
1274 case S32:
1275 return SetValueInLiteralHelper<int32>(value, linear_index, literal);
1276 case S64:
1277 return SetValueInLiteralHelper<int64>(value, linear_index, literal);
1278 case U8:
1279 return SetValueInLiteralHelper<uint8>(value, linear_index, literal);
1280 case U16:
1281 return SetValueInLiteralHelper<uint8>(value, linear_index, literal);
1282 case U32:
1283 return SetValueInLiteralHelper<uint32>(value, linear_index, literal);
1284 case U64:
1285 return SetValueInLiteralHelper<uint64>(value, linear_index, literal);
1286 default:
1287 LOG(FATAL) << "unknown integral primitive type "
1288 << PrimitiveType_Name(shape.element_type());
1289 }
1290 }
1291
SetValueInLiteral(double value,int64 linear_index,Literal * literal)1292 bool HloParser::SetValueInLiteral(double value, int64 linear_index,
1293 Literal* literal) {
1294 const Shape& shape = literal->shape();
1295 switch (shape.element_type()) {
1296 case F16:
1297 return SetValueInLiteralHelper<half>(value, linear_index, literal);
1298 case BF16:
1299 return SetValueInLiteralHelper<bfloat16>(value, linear_index, literal);
1300 case F32:
1301 return SetValueInLiteralHelper<float>(value, linear_index, literal);
1302 case F64:
1303 return SetValueInLiteralHelper<double>(value, linear_index, literal);
1304 default:
1305 LOG(FATAL) << "unknown floating point primitive type "
1306 << PrimitiveType_Name(shape.element_type());
1307 }
1308 }
1309
SetValueInLiteral(bool value,int64 linear_index,Literal * literal)1310 bool HloParser::SetValueInLiteral(bool value, int64 linear_index,
1311 Literal* literal) {
1312 const Shape& shape = literal->shape();
1313 switch (shape.element_type()) {
1314 case PRED:
1315 return SetValueInLiteralHelper<bool>(value, linear_index, literal);
1316 default:
1317 LOG(FATAL) << PrimitiveType_Name(shape.element_type())
1318 << " is not PRED type";
1319 }
1320 }
1321
// Writes `value` into `literal` at `linear_index` as LiteralNativeT, after
// checking that the index is in bounds and that a finite value fits in the
// destination type's range. Non-finite values (NaN/inf) skip range checking.
template <typename LiteralNativeT, typename ParsedElemT>
bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index,
                                        Literal* literal) {
  // Check that linear_index is in range.
  if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) {
    return TokenError(
        StrCat("trys to set value ", value, " to a literal in shape ",
               ShapeUtil::HumanString(literal->shape()), " at linear index ",
               linear_index, ", but the index is out of range"));
  }

  if (std::isnan(value) ||
      (std::numeric_limits<ParsedElemT>::has_infinity &&
       (std::numeric_limits<ParsedElemT>::infinity() == value ||
        -std::numeric_limits<ParsedElemT>::infinity() == value))) {
    // Skip range checking for non-finite value.
  } else if (literal->shape().element_type() == F16 ||
             literal->shape().element_type() == BF16) {
    // 16-bit floats are parsed as double, so the generic limits-based check
    // below would not catch overflow; compare against the fp16 max instead.
    // NOTE(review): kF16max (65504) is the fp16 bound, but BF16's max finite
    // value is ~3.4e38 — this may reject valid large BF16 literals; confirm
    // whether sharing the bound is intentional.
    if (value > kF16max || value < -kF16max) {
      return TokenError(StrCat(
          "value ", value, " is out of range for literal's primitive type ",
          PrimitiveType_Name(literal->shape().element_type())));
    }
  } else if (value > static_cast<ParsedElemT>(
                         std::numeric_limits<LiteralNativeT>::max()) ||
             value < static_cast<ParsedElemT>(
                         std::numeric_limits<LiteralNativeT>::lowest())) {
    // Value is out of range for LiteralNativeT.
    return TokenError(StrCat(
        "value ", value, " is out of range for literal's primitive type ",
        PrimitiveType_Name(literal->shape().element_type())));
  }

  literal->data<LiteralNativeT>().at(linear_index) =
      static_cast<LiteralNativeT>(value);
  return true;
}
1359
EatShapeAndCheckCompatible(const Shape & shape)1360 bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
1361 Shape new_shape;
1362 if (!ParseShape(&new_shape)) {
1363 return TokenError(StrCat("expects shape ", ShapeUtil::HumanString(shape)));
1364 }
1365 if (!ShapeUtil::Compatible(shape, new_shape)) {
1366 return TokenError(StrCat(
1367 "expects shape ", ShapeUtil::HumanString(shape),
1368 ", but sees a different shape: ", ShapeUtil::HumanString(new_shape)));
1369 }
1370 return true;
1371 }
1372
1373 // literal
1374 // ::= tuple
1375 // ::= non_tuple
ParseLiteral(std::unique_ptr<Literal> * literal,const Shape & shape)1376 bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
1377 const Shape& shape) {
1378 return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
1379 : ParseNonTupleLiteral(literal, shape);
1380 }
1381
1382 // tuple
1383 // ::= shape '(' literal_list ')'
1384 // literal_list
1385 // ::= /*empty*/
1386 // ::= literal (',' literal)*
ParseTupleLiteral(std::unique_ptr<Literal> * literal,const Shape & shape)1387 bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
1388 const Shape& shape) {
1389 if (!EatShapeAndCheckCompatible(shape)) {
1390 return TokenError(StrCat("expects tuple constant in shape ",
1391 ShapeUtil::HumanString(shape)));
1392 }
1393 if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
1394 return false;
1395 }
1396 std::vector<std::unique_ptr<Literal>> elements(
1397 ShapeUtil::TupleElementCount(shape));
1398
1399 if (lexer_.GetKind() == TokKind::kRparen) {
1400 // empty
1401 } else {
1402 // literal, (',' literal)*
1403 for (int i = 0; i < elements.size(); i++) {
1404 if (i > 0) {
1405 ParseToken(TokKind::kComma, "exepcts ',' to separate tuple elements");
1406 }
1407 if (!ParseLiteral(&elements[i],
1408 ShapeUtil::GetTupleElementShape(shape, i))) {
1409 return TokenError(StrCat("expects the ", i, "th element"));
1410 }
1411 }
1412 }
1413 *literal = Literal::MakeTupleOwned(std::move(elements));
1414 return ParseToken(TokKind::kRparen,
1415 StrCat("expects ')' at the end of the tuple with ",
1416 ShapeUtil::TupleElementCount(shape), "elements"));
1417 }
1418
1419 // non_tuple
1420 // ::= rank01
1421 // ::= rank2345
1422 // rank2345 ::= shape sparse_or_nested_array
// Parses a non-tuple (array) literal, dispatching on the layout: sparse
// arrays use the index:value grammar, dense arrays use nested braces.
bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
                                     const Shape& shape) {
  if (LayoutUtil::IsSparseArray(shape)) {
    return ParseSparseLiteral(literal, shape);
  }

  // Only dense and sparse layouts exist, so anything else is a caller bug.
  CHECK(LayoutUtil::IsDenseArray(shape));
  return ParseDenseLiteral(literal, shape);
}
1432
// Parses a dense array literal of the given `shape`. For rank > 1 the shape
// prefix (e.g. "f32[2,3]") is consumed and checked first. Elements are read
// from the brace nest and written into the literal in linear (default
// layout) order; the literal is re-laid-out to `shape`'s layout at the end.
bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
                                  const Shape& shape) {
  const int64 rank = ShapeUtil::Rank(shape);
  // Rank-0/1 literals are printed without a leading shape token.
  if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
    return false;
  }

  // Create a literal with the given shape in default layout.
  *literal = Literal::CreateFromDimensions(shape.element_type(),
                                           AsInt64Slice(shape.dimensions()));
  int64 nest_level = 0;
  int64 linear_index = 0;
  // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
  // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
  // when we are parsing the 2nd '{' (right before '1'), we are seeing a
  // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
  // the first '}' (right after '3'), it means the sub-array ends, and the
  // sub-array is supposed to contain exactly 3 elements, so check if
  // elems_seen_per_dim[1] is 3.
  std::vector<int64> elems_seen_per_dim(rank);
  // Renders the current position before dimension `dim` (e.g. "[1,0]") for
  // use in error messages.
  auto get_index_str = [&elems_seen_per_dim](int dim) -> string {
    std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
                                            elems_seen_per_dim.begin() + dim);
    return StrCat("[",
                  tensorflow::str_util::Join(
                      elems_seen_until_dim, ",",
                      [](string* out, const int64& num_elems) {
                        tensorflow::strings::StrAppend(out, num_elems - 1);
                      }),
                  "]");
  };
  do {
    switch (lexer_.GetKind()) {
      default:
        return TokenError("unexpected token type in a literal");
      case TokKind::kLbrace: {
        // Descend one dimension; this brace is one more sub-array of the
        // enclosing dimension (when there is one).
        nest_level++;
        if (nest_level > rank) {
          return TokenError(Printf(
              "expects nested array in rank %lld, but sees larger", rank));
        }
        if (nest_level > 1) {
          elems_seen_per_dim[nest_level - 2]++;
          if (elems_seen_per_dim[nest_level - 2] >
              shape.dimensions(nest_level - 2)) {
            return TokenError(Printf(
                "expects %lld elements in the %sth element, but sees more",
                shape.dimensions(nest_level - 2),
                get_index_str(nest_level - 2).c_str()));
          }
        }
        lexer_.Lex();
        break;
      }
      case TokKind::kRbrace: {
        // Close a sub-array: it must have contained exactly as many elements
        // as the corresponding dimension declares.
        nest_level--;
        if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
          return TokenError(Printf(
              "expects %lld elements in the %sth element, but sees %lld",
              shape.dimensions(nest_level), get_index_str(nest_level).c_str(),
              elems_seen_per_dim[nest_level]));
        }
        elems_seen_per_dim[nest_level] = 0;
        lexer_.Lex();
        break;
      }
      case TokKind::kComma:
      case TokKind::kComment:
        // Skip.
        lexer_.Lex();
        break;
      case TokKind::kw_true:
      case TokKind::kw_false:
      case TokKind::kInt:
      case TokKind::kDecimal:
      case TokKind::kw_nan:
      case TokKind::kw_inf:
      case TokKind::kNegInf: {
        // A scalar element: for rank > 0 it must appear at full nesting
        // depth, and the minor-most dimension must not overflow.
        if (rank > 0) {
          if (nest_level != rank) {
            return TokenError(
                Printf("expects nested array in rank %lld, but sees %lld", rank,
                       nest_level));
          }
          elems_seen_per_dim[rank - 1]++;
          if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
            return TokenError(
                Printf("expects %lld elements on the minor-most dimension, but "
                       "sees more",
                       shape.dimensions(rank - 1)));
          }
        }
        if (lexer_.GetKind() == TokKind::kw_true ||
            lexer_.GetKind() == TokKind::kw_false) {
          // TODO(congliu): bool type literals with rank >= 1 are actually
          // printed in a compact form instead of "true" or "false". Fix that.
          if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
                                 linear_index++, literal->get())) {
            return false;
          }
          lexer_.Lex();
        } else if (primitive_util::IsIntegralType(shape.element_type())) {
          LocTy loc = lexer_.GetLoc();
          int64 value;
          if (!ParseInt64(&value)) {
            return Error(loc, StrCat("expects integer for primitive type: ",
                                     PrimitiveType_Name(shape.element_type())));
          }
          if (!SetValueInLiteral(value, linear_index++, literal->get())) {
            return false;
          }
        } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
          LocTy loc = lexer_.GetLoc();
          double value;
          if (!ParseDouble(&value)) {
            return Error(
                loc, StrCat("expect floating point value for primitive type: ",
                            PrimitiveType_Name(shape.element_type())));
          }
          if (!SetValueInLiteral(value, linear_index++, literal->get())) {
            return false;
          }
        } else {
          return TokenError(StrCat("unsupported primitive type ",
                                   PrimitiveType_Name(shape.element_type())));
        }
        break;
      }
    }  // end of switch
  } while (nest_level > 0);

  // Elements were written in default layout; convert to the requested one.
  *literal = (*literal)->Relayout(shape.layout());
  return true;
}
1567
// Parses a sparse literal: the shape prefix is consumed and checked, then
// the element-type switch dispatches to the helper instantiated with the
// corresponding native type.
bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
                                   const Shape& shape) {
  if (!EatShapeAndCheckCompatible(shape)) {
    return false;
  }

  switch (shape.element_type()) {
    case PRED:
      // PRED values are handled with a uint8 native type.
      return ParseSparseLiteralHelper<uint8>(literal, shape);
    case S8:
      return ParseSparseLiteralHelper<int8>(literal, shape);
    case S16:
      return ParseSparseLiteralHelper<int16>(literal, shape);
    case S32:
      return ParseSparseLiteralHelper<int32>(literal, shape);
    case S64:
      return ParseSparseLiteralHelper<int64>(literal, shape);
    case U8:
      return ParseSparseLiteralHelper<uint8>(literal, shape);
    case U16:
      return ParseSparseLiteralHelper<uint16>(literal, shape);
    case U32:
      return ParseSparseLiteralHelper<uint32>(literal, shape);
    case U64:
      return ParseSparseLiteralHelper<uint64>(literal, shape);
    case F16:
      return ParseSparseLiteralHelper<half>(literal, shape);
    case F32:
      return ParseSparseLiteralHelper<float>(literal, shape);
    case BF16:
      return ParseSparseLiteralHelper<bfloat16>(literal, shape);
    case F64:
      return ParseSparseLiteralHelper<double>(literal, shape);
    default:
      return Error(lexer_.GetLoc(),
                   StrCat("invalid primitive type for sparse literal: ",
                          PrimitiveType_Name(shape.element_type())));
  }
}
1607
1608 template <typename LiteralNativeT>
ParseSparseLiteralHelper(std::unique_ptr<Literal> * literal,const Shape & shape)1609 bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
1610 const Shape& shape) {
1611 std::vector<int64> index;
1612
1613 int64 rank = ShapeUtil::Rank(shape);
1614
1615 *literal = MakeUnique<Literal>(shape);
1616
1617 if (!ParseToken(TokKind::kLbrace,
1618 "expects '{' at the beginning of a sparse literal")) {
1619 return false;
1620 }
1621
1622 for (;;) {
1623 if (lexer_.GetKind() == TokKind::kRbrace) {
1624 lexer_.Lex();
1625 break;
1626 }
1627
1628 LocTy index_loc = lexer_.GetLoc();
1629 index.clear();
1630 if (lexer_.GetKind() == TokKind::kInt) {
1631 int64 single_index = lexer_.GetInt64Val();
1632 lexer_.Lex();
1633 if (rank != 1) {
1634 return Error(
1635 index_loc,
1636 StrCat("invalid single-dimensional index for shape with rank ",
1637 rank, ": ", single_index));
1638 }
1639 index.push_back(single_index);
1640 } else {
1641 if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
1642 &index)) {
1643 return false;
1644 }
1645 if (index.size() != rank) {
1646 return Error(
1647 index_loc,
1648 StrCat("invalid multi-dimension index for shape with rank ", rank,
1649 ": [", tensorflow::str_util::Join(index, ", "), "]"));
1650 }
1651 }
1652 if (!ParseToken(TokKind::kColon,
1653 "expects ':' after after the sparse array index and before "
1654 "the sparse array value")) {
1655 return false;
1656 }
1657 LocTy value_loc = lexer_.GetLoc();
1658 LiteralNativeT value;
1659 if (lexer_.GetKind() == TokKind::kw_true ||
1660 lexer_.GetKind() == TokKind::kw_false) {
1661 value = static_cast<LiteralNativeT>(lexer_.GetKind() == TokKind::kw_true);
1662 lexer_.Lex();
1663 } else if (primitive_util::IsIntegralType(shape.element_type())) {
1664 int64 value_s64;
1665 if (!ParseInt64(&value_s64)) {
1666 return Error(value_loc,
1667 StrCat("expects integer for primitive type: ",
1668 PrimitiveType_Name(shape.element_type())));
1669 }
1670 value = static_cast<LiteralNativeT>(value_s64);
1671 } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
1672 double value_f64;
1673 if (!ParseDouble(&value_f64)) {
1674 return Error(value_loc,
1675 StrCat("expects floating point value for primitive type: ",
1676 PrimitiveType_Name(shape.element_type())));
1677 }
1678 value = static_cast<LiteralNativeT>(value_f64);
1679 } else {
1680 LOG(FATAL) << "Unexpected element type: "
1681 << PrimitiveType_Name(shape.element_type());
1682 }
1683 if (lexer_.GetKind() != TokKind::kRbrace &&
1684 !ParseToken(TokKind::kComma,
1685 "expects ',' separator between sparse array elements")) {
1686 return false;
1687 }
1688
1689 if ((*literal)->sparse_element_count() + 1 ==
1690 LayoutUtil::MaxSparseElements(shape.layout())) {
1691 return Error(
1692 lexer_.GetLoc(),
1693 StrCat("number of sparse elements exceeds maximum for layout: ",
1694 ShapeUtil::HumanStringWithLayout(shape)));
1695 }
1696
1697 (*literal)->AppendSparseElement(index, value);
1698 }
1699
1700 (*literal)->SortSparseElements();
1701 return true;
1702 }
1703
1704 // operands ::= '(' operands1 ')'
1705 // operands1
1706 // ::= /*empty*/
1707 // ::= operand (, operand)*
1708 // operand ::= (shape)? name
ParseOperands(std::vector<HloInstruction * > * operands)1709 bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
1710 if (!ParseToken(TokKind::kLparen,
1711 "expects '(' at the beginning of operands")) {
1712 return false;
1713 }
1714 if (lexer_.GetKind() == TokKind::kRparen) {
1715 // empty
1716 } else {
1717 do {
1718 LocTy loc = lexer_.GetLoc();
1719 string name;
1720 if (CanBeShape()) {
1721 Shape shape;
1722 if (!ParseShape(&shape)) {
1723 return false;
1724 }
1725 }
1726 if (!ParseName(&name)) {
1727 return false;
1728 }
1729 std::pair<HloInstruction*, LocTy>* instruction =
1730 tensorflow::gtl::FindOrNull(instruction_pool_, name);
1731 if (!instruction) {
1732 return Error(loc, StrCat("instruction does not exist: ", name));
1733 }
1734 operands->push_back(instruction->first);
1735 } while (EatIfPresent(TokKind::kComma));
1736 }
1737 return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
1738 }
1739
ParseOperands(std::vector<HloInstruction * > * operands,const int expected_size)1740 bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
1741 const int expected_size) {
1742 LocTy loc = lexer_.GetLoc();
1743 if (!ParseOperands(operands)) {
1744 return false;
1745 }
1746 if (expected_size != operands->size()) {
1747 return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
1748 operands->size(), " operands"));
1749 }
1750 return true;
1751 }
1752
1753 // sub_attributes ::= '{' (','? attribute)* '}'
ParseSubAttributes(const std::unordered_map<string,AttrConfig> & attrs)1754 bool HloParser::ParseSubAttributes(
1755 const std::unordered_map<string, AttrConfig>& attrs) {
1756 LocTy loc = lexer_.GetLoc();
1757 if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
1758 return false;
1759 }
1760 std::unordered_set<string> seen_attrs;
1761 if (lexer_.GetKind() == TokKind::kRbrace) {
1762 // empty
1763 } else {
1764 do {
1765 EatIfPresent(TokKind::kComma);
1766 if (!ParseAttributeHelper(attrs, &seen_attrs)) {
1767 return false;
1768 }
1769 } while (lexer_.GetKind() != TokKind::kRbrace);
1770 }
1771 // Check that all required attrs were seen.
1772 for (const auto& attr_it : attrs) {
1773 if (attr_it.second.required &&
1774 seen_attrs.find(attr_it.first) == seen_attrs.end()) {
1775 return Error(loc, Printf("sub-attribute %s is expected but not seen",
1776 attr_it.first.c_str()));
1777 }
1778 }
1779 return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
1780 }
1781
1782 // attributes ::= (',' attribute)*
ParseAttributes(const std::unordered_map<string,AttrConfig> & attrs)1783 bool HloParser::ParseAttributes(
1784 const std::unordered_map<string, AttrConfig>& attrs) {
1785 LocTy loc = lexer_.GetLoc();
1786 std::unordered_set<string> seen_attrs;
1787 while (EatIfPresent(TokKind::kComma)) {
1788 if (!ParseAttributeHelper(attrs, &seen_attrs)) {
1789 return false;
1790 }
1791 }
1792 // Check that all required attrs were seen.
1793 for (const auto& attr_it : attrs) {
1794 if (attr_it.second.required &&
1795 seen_attrs.find(attr_it.first) == seen_attrs.end()) {
1796 return Error(loc, Printf("attribute %s is expected but not seen",
1797 attr_it.first.c_str()));
1798 }
1799 }
1800 return true;
1801 }
1802
// Parses a single "name=value" attribute.
//
// `attrs` maps attribute name -> AttrConfig {required, attr_type, result};
// `result` is a type-erased pointer to an optional<T> whose T is determined
// by `attr_type`. On success the parsed value is emplaced into that
// optional. `seen_attrs` records which names have appeared so this
// function can reject duplicates and callers can diagnose missing
// required attributes. Returns false (with a recorded error) on failure.
bool HloParser::ParseAttributeHelper(
    const std::unordered_map<string, AttrConfig>& attrs,
    std::unordered_set<string>* seen_attrs) {
  LocTy loc = lexer_.GetLoc();
  string name;
  if (!ParseAttributeName(&name)) {
    return Error(loc, "error parsing attributes");
  }
  VLOG(1) << "Parsing attribute " << name;
  // insert() returns {it, false} on a repeat -> duplicate attribute.
  if (!seen_attrs->insert(name).second) {
    return Error(loc, Printf("attribute %s already exists", name.c_str()));
  }
  auto attr_it = attrs.find(name);
  if (attr_it == attrs.end()) {
    return Error(loc, Printf("unexpected attribute %s", name.c_str()));
  }
  AttrTy attr_type = attr_it->second.attr_type;
  void* attr_out_ptr = attr_it->second.result;
  // Dispatch on the declared attribute type. Each case parses a value of
  // that type and emplaces it into the optional<T> behind attr_out_ptr.
  bool success = [&] {
    LocTy attr_loc = lexer_.GetLoc();
    switch (attr_type) {
      case AttrTy::kInt64: {
        int64 result;
        if (!ParseInt64(&result)) {
          return false;
        }
        static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kInt32: {
        // Parse as int64, then range-check before narrowing to int32.
        int64 result;
        if (!ParseInt64(&result)) {
          return false;
        }
        if (result != static_cast<int32>(result)) {
          return Error(attr_loc, "value out of range for int32");
        }
        static_cast<optional<int32>*>(attr_out_ptr)
            ->emplace(static_cast<int32>(result));
        return true;
      }
      case AttrTy::kFloat: {
        // Parse as double, then range-check before narrowing to float.
        double result;
        if (!ParseDouble(&result)) {
          return false;
        }
        if (result > std::numeric_limits<float>::max() ||
            result < std::numeric_limits<float>::lowest()) {
          return Error(attr_loc, "value out of range for float");
        }
        static_cast<optional<float>*>(attr_out_ptr)
            ->emplace(static_cast<float>(result));
        return true;
      }
      case AttrTy::kHloComputation: {
        // Resolved against previously parsed computations by name.
        HloComputation* result;
        if (!ParseComputationName(&result)) {
          return false;
        }
        static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kFftType: {
        FftType result;
        if (!ParseFftType(&result)) {
          return false;
        }
        static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kWindow: {
        Window result;
        if (!ParseWindow(&result)) {
          return false;
        }
        static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kConvolutionDimensionNumbers: {
        ConvolutionDimensionNumbers result;
        if (!ParseConvolutionDimensionNumbers(&result)) {
          return false;
        }
        static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kSharding: {
        OpSharding sharding;
        if (!ParseSharding(&sharding)) {
          return false;
        }
        static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
        return true;
      }
      case AttrTy::kInstructionList: {
        // A braced list of instruction names, resolved to pointers.
        std::vector<HloInstruction*> result;
        if (!ParseInstructionNames(&result)) {
          return false;
        }
        static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kFusionKind: {
        HloInstruction::FusionKind result;
        if (!ParseFusionKind(&result)) {
          return false;
        }
        static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kBracedInt64List: {
        // e.g. {1,2,3} -- braces with comma-separated int64s.
        std::vector<int64> result;
        if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
                            &result)) {
          return false;
        }
        static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
      case AttrTy::kSliceRanges: {
        SliceRanges result;
        if (!ParseSliceRanges(&result)) {
          return false;
        }
        static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kPaddingConfig: {
        PaddingConfig result;
        if (!ParsePaddingConfig(&result)) {
          return false;
        }
        static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kString: {
        string result;
        if (!ParseString(&result)) {
          return false;
        }
        static_cast<optional<string>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kMetadata: {
        OpMetadata result;
        if (!ParseMetadata(&result)) {
          return false;
        }
        static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
        return true;
      }
      case AttrTy::kDistribution: {
        RandomDistribution result;
        if (!ParseRandomDistribution(&result)) {
          return false;
        }
        static_cast<optional<RandomDistribution>*>(attr_out_ptr)
            ->emplace(result);
        return true;
      }
    }
    // NOTE(review): no default case -- if AttrTy ever gains a value not
    // handled above, control falls off the end of this lambda, which is
    // undefined behavior. Consider adding LOG(FATAL) here.
  }();
  if (!success) {
    return Error(loc, Printf("error parsing attribute %s", name.c_str()));
  }
  return true;
}
1974
ParseComputationName(HloComputation ** value)1975 bool HloParser::ParseComputationName(HloComputation** value) {
1976 string name;
1977 LocTy loc = lexer_.GetLoc();
1978 if (!ParseName(&name)) {
1979 return Error(loc, "expects computation name");
1980 }
1981 std::pair<HloComputation*, LocTy>* computation =
1982 tensorflow::gtl::FindOrNull(computation_pool_, name);
1983 if (computation == nullptr) {
1984 return Error(loc, StrCat("computation does not exist: ", name));
1985 }
1986 *value = computation->first;
1987 return true;
1988 }
1989
1990 // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
1991 // The subattributes can appear in any order. 'size=' is required, others are
1992 // optional.
ParseWindow(Window * window)1993 bool HloParser::ParseWindow(Window* window) {
1994 LocTy loc = lexer_.GetLoc();
1995 if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
1996 return false;
1997 }
1998
1999 std::vector<int64> size;
2000 std::vector<int64> stride;
2001 std::vector<std::vector<int64>> pad;
2002 std::vector<int64> lhs_dilate;
2003 std::vector<int64> rhs_dilate;
2004 std::vector<int64> rhs_reversal;
2005 while (lexer_.GetKind() != TokKind::kRbrace) {
2006 LocTy attr_loc = lexer_.GetLoc();
2007 string field_name;
2008 if (!ParseAttributeName(&field_name)) {
2009 return Error(attr_loc, "expects sub-attributes in window");
2010 }
2011 bool ok = [&] {
2012 if (field_name == "size") {
2013 return ParseDxD("size", &size);
2014 }
2015 if (field_name == "stride") {
2016 return ParseDxD("stride", &stride);
2017 }
2018 if (field_name == "lhs_dilate") {
2019 return ParseDxD("lhs_dilate", &lhs_dilate);
2020 }
2021 if (field_name == "rhs_dilate") {
2022 return ParseDxD("rls_dilate", &rhs_dilate);
2023 }
2024 if (field_name == "pad") {
2025 return ParseWindowPad(&pad);
2026 }
2027 if (field_name == "rhs_reversal") {
2028 return ParseDxD("rhs_reversal", &rhs_reversal);
2029 }
2030 return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
2031 }();
2032 if (!ok) {
2033 return false;
2034 }
2035 }
2036
2037 if (size.empty()) {
2038 return Error(loc,
2039 "sub-attribute 'size=' is required in the window attribute");
2040 }
2041 if (!stride.empty() && stride.size() != size.size()) {
2042 return Error(loc, "expects 'stride=' has the same size as 'size='");
2043 }
2044 if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
2045 return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
2046 }
2047 if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
2048 return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
2049 }
2050 if (!pad.empty() && pad.size() != size.size()) {
2051 return Error(loc, "expects 'pad=' has the same size as 'size='");
2052 }
2053
2054 for (int i = 0; i < size.size(); i++) {
2055 window->add_dimensions()->set_size(size[i]);
2056 if (!pad.empty()) {
2057 window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
2058 window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
2059 }
2060 // If some field is not present, it has the default value.
2061 window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
2062 window->mutable_dimensions(i)->set_base_dilation(
2063 lhs_dilate.empty() ? 1 : lhs_dilate[i]);
2064 window->mutable_dimensions(i)->set_window_dilation(
2065 rhs_dilate.empty() ? 1 : rhs_dilate[i]);
2066 window->mutable_dimensions(i)->set_window_reversal(
2067 rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
2068 }
2069 return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
2070 }
2071
2072 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
2073 // The string looks like "dim_labels=0bf_0io->0bf".
ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers * dnums)2074 bool HloParser::ParseConvolutionDimensionNumbers(
2075 ConvolutionDimensionNumbers* dnums) {
2076 if (lexer_.GetKind() != TokKind::kDimLabels) {
2077 return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
2078 }
2079 string str = lexer_.GetStrVal();
2080
2081 // The str is expected to have 3 items, lhs, rhs, out, and it must looks like
2082 // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
2083 // So we replace the "->" with "_" and then split on "_".
2084 str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->",
2085 /*newsub=*/"_",
2086 /*replace_all=*/false);
2087 std::vector<string> lhs_rhs_out = Split(str, "_");
2088 if (lhs_rhs_out.size() != 3) {
2089 LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
2090 << str;
2091 }
2092
2093 const int64 rank = lhs_rhs_out[0].length();
2094 if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
2095 return TokenError(
2096 "convolution lhs, rhs, and output must have the same rank");
2097 }
2098 if (rank < 2) {
2099 return TokenError("convolution rank must >=2");
2100 }
2101
2102 auto is_unique = [](string str) -> bool {
2103 std::sort(str.begin(), str.end());
2104 return std::unique(str.begin(), str.end()) == str.end();
2105 };
2106
2107 // lhs
2108 {
2109 const string& lhs = lhs_rhs_out[0];
2110 if (!is_unique(lhs)) {
2111 return TokenError(
2112 StrCat("expects unique lhs dimension numbers, but sees ", lhs));
2113 }
2114 for (int i = 0; i < rank - 2; i++) {
2115 dnums->add_input_spatial_dimensions(-1);
2116 }
2117 for (int i = 0; i < rank; i++) {
2118 char c = lhs[i];
2119 if (c == 'b') {
2120 dnums->set_input_batch_dimension(i);
2121 } else if (c == 'f') {
2122 dnums->set_input_feature_dimension(i);
2123 } else if (c < '0' + rank && c >= '0') {
2124 dnums->set_input_spatial_dimensions(c - '0', i);
2125 } else {
2126 return TokenError(
2127 Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1));
2128 }
2129 }
2130 }
2131 // rhs
2132 {
2133 const string& rhs = lhs_rhs_out[1];
2134 if (!is_unique(rhs)) {
2135 return TokenError(
2136 StrCat("expects unique rhs dimension numbers, but sees ", rhs));
2137 }
2138 for (int i = 0; i < rank - 2; i++) {
2139 dnums->add_kernel_spatial_dimensions(-1);
2140 }
2141 for (int i = 0; i < rank; i++) {
2142 char c = rhs[i];
2143 if (c == 'i') {
2144 dnums->set_kernel_input_feature_dimension(i);
2145 } else if (c == 'o') {
2146 dnums->set_kernel_output_feature_dimension(i);
2147 } else if (c < '0' + rank && c >= '0') {
2148 dnums->set_kernel_spatial_dimensions(c - '0', i);
2149 } else {
2150 return TokenError(
2151 Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1));
2152 }
2153 }
2154 }
2155 // output
2156 {
2157 const string& out = lhs_rhs_out[2];
2158 if (!is_unique(out)) {
2159 return TokenError(
2160 StrCat("expects unique output dimension numbers, but sees ", out));
2161 }
2162 for (int i = 0; i < rank - 2; i++) {
2163 dnums->add_output_spatial_dimensions(-1);
2164 }
2165 for (int i = 0; i < rank; i++) {
2166 char c = out[i];
2167 if (c == 'b') {
2168 dnums->set_output_batch_dimension(i);
2169 } else if (c == 'f') {
2170 dnums->set_output_feature_dimension(i);
2171 } else if (c < '0' + rank && c >= '0') {
2172 dnums->set_output_spatial_dimensions(c - '0', i);
2173 } else {
2174 return TokenError(
2175 Printf("expects [0-%lldbf] in output dimension numbers", rank - 1));
2176 }
2177 }
2178 }
2179
2180 lexer_.Lex();
2181 return true;
2182 }
2183
2184 // ::= '{' ranges '}'
2185 // ::= /*empty*/
2186 // ::= range (',' range)*
2187 // range ::= '[' start ':' limit (':' stride)? ']'
2188 //
2189 // The slice ranges are printed as:
2190 //
2191 // {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
2192 //
2193 // This function extracts the starts, limits, and strides as 3 vectors to the
2194 // result. If stride is not present, stride is 1. For example, if the slice
2195 // ranges is printed as:
2196 //
2197 // {[2:3:4], [5:6:7], [8:9]}
2198 //
2199 // The parsed result will be:
2200 //
2201 // {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
2202 //
ParseSliceRanges(SliceRanges * result)2203 bool HloParser::ParseSliceRanges(SliceRanges* result) {
2204 if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
2205 return false;
2206 }
2207 std::vector<std::vector<int64>> ranges;
2208 if (lexer_.GetKind() == TokKind::kRbrace) {
2209 // empty
2210 return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
2211 }
2212 do {
2213 LocTy loc = lexer_.GetLoc();
2214 ranges.emplace_back();
2215 if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
2216 &ranges.back())) {
2217 return false;
2218 }
2219 const auto& range = ranges.back();
2220 if (range.size() != 2 && range.size() != 3) {
2221 return Error(loc, Printf("expects [start:limit:step] or [start:limit], "
2222 "but sees %ld elements.",
2223 range.size()));
2224 }
2225 } while (EatIfPresent(TokKind::kComma));
2226
2227 for (const auto& range : ranges) {
2228 result->starts.push_back(range[0]);
2229 result->limits.push_back(range[1]);
2230 result->strides.push_back(range.size() == 3 ? range[2] : 1);
2231 }
2232 return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
2233 }
2234
2235 // int64list ::= start int64_elements end
2236 // int64_elements
2237 // ::= /*empty*/
2238 // ::= int64_val (delim int64_val)*
ParseInt64List(const TokKind start,const TokKind end,const TokKind delim,std::vector<int64> * result)2239 bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
2240 const TokKind delim,
2241 std::vector<int64>* result) {
2242 if (!ParseToken(start, StrCat("expects an int64 list starting with ",
2243 TokKindToString(start)))) {
2244 return false;
2245 }
2246 if (lexer_.GetKind() == end) {
2247 // empty
2248 } else {
2249 do {
2250 int64 i;
2251 if (!ParseInt64(&i)) {
2252 return false;
2253 }
2254 result->push_back(i);
2255 } while (EatIfPresent(delim));
2256 }
2257 return ParseToken(
2258 end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
2259 }
2260
2261 // param_list_to_shape ::= param_list '->' shape
ParseParamListToShape(Shape * shape,LocTy * shape_loc)2262 bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
2263 if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
2264 return false;
2265 }
2266 *shape_loc = lexer_.GetLoc();
2267 return ParseShape(shape);
2268 }
2269
CanBeParamListToShape()2270 bool HloParser::CanBeParamListToShape() {
2271 return lexer_.GetKind() == TokKind::kLparen;
2272 }
2273
2274 // param_list ::= '(' param_list1 ')'
2275 // param_list1
2276 // ::= /*empty*/
2277 // ::= param (',' param)*
2278 // param ::= name shape
ParseParamList()2279 bool HloParser::ParseParamList() {
2280 if (!ParseToken(TokKind::kLparen,
2281 "expects '(' at the beginning of param list")) {
2282 return false;
2283 }
2284
2285 if (lexer_.GetKind() == TokKind::kRparen) {
2286 // empty
2287 } else {
2288 do {
2289 Shape shape;
2290 string name;
2291 if (!ParseName(&name) || !ParseShape(&shape)) {
2292 return false;
2293 }
2294 } while (EatIfPresent(TokKind::kComma));
2295 }
2296 return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
2297 }
2298
2299 // shape ::= shape_val_
2300 // shape ::= '(' tuple_elements ')'
2301 // tuple_elements
2302 // ::= /*empty*/
2303 // ::= shape (',' shape)*
ParseShape(Shape * result)2304 bool HloParser::ParseShape(Shape* result) {
2305 if (EatIfPresent(TokKind::kLparen)) { // Tuple
2306 std::vector<Shape> shapes;
2307 if (lexer_.GetKind() == TokKind::kRparen) {
2308 /*empty*/
2309 } else {
2310 // shape (',' shape)*
2311 do {
2312 shapes.emplace_back();
2313 if (!ParseShape(&shapes.back())) {
2314 return false;
2315 }
2316 } while (EatIfPresent(TokKind::kComma));
2317 }
2318 *result = ShapeUtil::MakeTupleShape(shapes);
2319 return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
2320 }
2321
2322 if (lexer_.GetKind() != TokKind::kShape) {
2323 return TokenError("expects shape");
2324 }
2325 *result = lexer_.GetShapeVal();
2326 lexer_.Lex();
2327 return true;
2328 }
2329
CanBeShape()2330 bool HloParser::CanBeShape() {
2331 // A non-tuple shape starts with a kShape token; a tuple shape starts with
2332 // '('.
2333 return lexer_.GetKind() == TokKind::kShape ||
2334 lexer_.GetKind() == TokKind::kLparen;
2335 }
2336
ParseName(string * result)2337 bool HloParser::ParseName(string* result) {
2338 VLOG(1) << "ParseName";
2339 if (lexer_.GetKind() != TokKind::kIdent &&
2340 lexer_.GetKind() != TokKind::kName) {
2341 return TokenError("expects name");
2342 }
2343 *result = lexer_.GetStrVal();
2344 lexer_.Lex();
2345 return true;
2346 }
2347
ParseAttributeName(string * result)2348 bool HloParser::ParseAttributeName(string* result) {
2349 if (lexer_.GetKind() != TokKind::kAttributeName) {
2350 return TokenError("expects attribute name");
2351 }
2352 *result = lexer_.GetStrVal();
2353 lexer_.Lex();
2354 return true;
2355 }
2356
ParseString(string * result)2357 bool HloParser::ParseString(string* result) {
2358 VLOG(1) << "ParseString";
2359 if (lexer_.GetKind() != TokKind::kString) {
2360 return TokenError("expects string");
2361 }
2362 *result = lexer_.GetStrVal();
2363 lexer_.Lex();
2364 return true;
2365 }
2366
ParseDxD(const string & name,std::vector<int64> * result)2367 bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
2368 LocTy loc = lexer_.GetLoc();
2369 if (!result->empty()) {
2370 return Error(loc,
2371 Printf("sub-attribute '%s=' already exists", name.c_str()));
2372 }
2373 // 1D
2374 if (lexer_.GetKind() == TokKind::kInt) {
2375 int64 number;
2376 if (!ParseInt64(&number)) {
2377 return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str()));
2378 }
2379 result->push_back(number);
2380 return true;
2381 }
2382 // 2D or higher.
2383 if (lexer_.GetKind() == TokKind::kDxD) {
2384 string str = lexer_.GetStrVal();
2385 if (!SplitAndParseAsInts(str, 'x', result)) {
2386 return Error(loc,
2387 Printf("expects sub-attribute '%s=ixj...'", name.c_str()));
2388 }
2389 lexer_.Lex();
2390 return true;
2391 }
2392 return TokenError("expects token type kInt or kDxD");
2393 }
2394
ParseWindowPad(std::vector<std::vector<int64>> * pad)2395 bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
2396 LocTy loc = lexer_.GetLoc();
2397 if (!pad->empty()) {
2398 return Error(loc, "sub-attribute 'pad=' already exists");
2399 }
2400 if (lexer_.GetKind() != TokKind::kPad) {
2401 return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
2402 }
2403 string str = lexer_.GetStrVal();
2404 std::vector<string> padding_str = Split(str, 'x');
2405 for (int i = 0; i < padding_str.size(); i++) {
2406 std::vector<int64> low_high;
2407 if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
2408 low_high.size() != 2) {
2409 return Error(loc,
2410 "expects padding_low and padding_high separated by '_'");
2411 }
2412 pad->push_back(low_high);
2413 }
2414 lexer_.Lex();
2415 return true;
2416 }
2417
2418 // This is the inverse xla::ToString(PaddingConfig). The padding config string
2419 // looks like "0_0_0x3_3_1". The string is first separated by 'x', each
2420 // substring represents one PaddingConfigDimension. The substring is 3 (or 2)
2421 // numbers joined by '_'.
ParsePaddingConfig(PaddingConfig * padding)2422 bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
2423 if (lexer_.GetKind() != TokKind::kPad) {
2424 return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
2425 }
2426 LocTy loc = lexer_.GetLoc();
2427 string str = lexer_.GetStrVal();
2428 std::vector<string> padding_str = Split(str, 'x');
2429 for (const auto& padding_dim_str : padding_str) {
2430 std::vector<int64> padding_dim;
2431 if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) ||
2432 (padding_dim.size() != 2 && padding_dim.size() != 3)) {
2433 return Error(loc,
2434 "expects padding config pattern like 'low_high_interior' or "
2435 "'low_high'");
2436 }
2437 auto* dim = padding->add_dimensions();
2438 dim->set_edge_padding_low(padding_dim[0]);
2439 dim->set_edge_padding_high(padding_dim[1]);
2440 dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
2441 }
2442 lexer_.Lex();
2443 return true;
2444 }
2445
2446 // '{' metadata_string '}'
ParseMetadata(OpMetadata * metadata)2447 bool HloParser::ParseMetadata(OpMetadata* metadata) {
2448 std::unordered_map<string, AttrConfig> attrs;
2449 optional<string> op_type;
2450 optional<string> op_name;
2451 optional<string> source_file;
2452 optional<int32> source_line;
2453 attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
2454 attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
2455 attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
2456 attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
2457 if (!ParseSubAttributes(attrs)) {
2458 return false;
2459 }
2460 if (op_type) {
2461 metadata->set_op_type(*op_type);
2462 }
2463 if (op_name) {
2464 metadata->set_op_name(*op_name);
2465 }
2466 if (source_file) {
2467 metadata->set_source_file(*source_file);
2468 }
2469 if (source_line) {
2470 metadata->set_source_line(*source_line);
2471 }
2472 return true;
2473 }
2474
ParseOpcode(HloOpcode * result)2475 bool HloParser::ParseOpcode(HloOpcode* result) {
2476 VLOG(1) << "ParseOpcode";
2477 if (lexer_.GetKind() != TokKind::kIdent) {
2478 return TokenError("expects opcode");
2479 }
2480 string val = lexer_.GetStrVal();
2481 auto status_or_result = StringToHloOpcode(val);
2482 if (!status_or_result.ok()) {
2483 return TokenError(
2484 Printf("expects opcode but sees: %s, error: %s", val.c_str(),
2485 status_or_result.status().error_message().c_str()));
2486 }
2487 *result = status_or_result.ValueOrDie();
2488 lexer_.Lex();
2489 return true;
2490 }
2491
ParseFftType(FftType * result)2492 bool HloParser::ParseFftType(FftType* result) {
2493 VLOG(1) << "ParseFftType";
2494 if (lexer_.GetKind() != TokKind::kIdent) {
2495 return TokenError("expects fft type");
2496 }
2497 string val = lexer_.GetStrVal();
2498 if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
2499 return TokenError(Printf("expects fft type but sees: %s", val.c_str()));
2500 }
2501 lexer_.Lex();
2502 return true;
2503 }
2504
ParseFusionKind(HloInstruction::FusionKind * result)2505 bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
2506 VLOG(1) << "ParseFusionKind";
2507 if (lexer_.GetKind() != TokKind::kIdent) {
2508 return TokenError("expects fusion kind");
2509 }
2510 string val = lexer_.GetStrVal();
2511 auto status_or_result = StringToFusionKind(val);
2512 if (!status_or_result.ok()) {
2513 return TokenError(
2514 Printf("expects fusion kind but sees: %s, error: %s", val.c_str(),
2515 status_or_result.status().error_message().c_str()));
2516 }
2517 *result = status_or_result.ValueOrDie();
2518 lexer_.Lex();
2519 return true;
2520 }
2521
ParseRandomDistribution(RandomDistribution * result)2522 bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
2523 VLOG(1) << "ParseRandomDistribution";
2524 if (lexer_.GetKind() != TokKind::kIdent) {
2525 return TokenError("expects random distribution");
2526 }
2527 string val = lexer_.GetStrVal();
2528 auto status_or_result = StringToRandomDistribution(val);
2529 if (!status_or_result.ok()) {
2530 return TokenError(
2531 Printf("expects random distribution but sees: %s, error: %s",
2532 val.c_str(), status_or_result.status().error_message().c_str()));
2533 }
2534 *result = status_or_result.ValueOrDie();
2535 lexer_.Lex();
2536 return true;
2537 }
2538
ParseInt64(int64 * result)2539 bool HloParser::ParseInt64(int64* result) {
2540 VLOG(1) << "ParseInt64";
2541 if (lexer_.GetKind() != TokKind::kInt) {
2542 return TokenError("expects integer");
2543 }
2544 *result = lexer_.GetInt64Val();
2545 lexer_.Lex();
2546 return true;
2547 }
2548
ParseDouble(double * result)2549 bool HloParser::ParseDouble(double* result) {
2550 switch (lexer_.GetKind()) {
2551 case TokKind::kDecimal:
2552 *result = lexer_.GetDecimalVal();
2553 break;
2554 case TokKind::kInt:
2555 *result = static_cast<double>(lexer_.GetInt64Val());
2556 break;
2557 case TokKind::kw_nan:
2558 *result = std::numeric_limits<double>::quiet_NaN();
2559 break;
2560 case TokKind::kw_inf:
2561 *result = std::numeric_limits<double>::infinity();
2562 break;
2563 case TokKind::kNegInf:
2564 *result = -std::numeric_limits<double>::infinity();
2565 break;
2566 default:
2567 return TokenError("expects decimal or integer");
2568 }
2569 lexer_.Lex();
2570 return true;
2571 }
2572
ParseBool(bool * result)2573 bool HloParser::ParseBool(bool* result) {
2574 if (lexer_.GetKind() != TokKind::kw_true &&
2575 lexer_.GetKind() != TokKind::kw_false) {
2576 return TokenError("expects true or false");
2577 }
2578 *result = lexer_.GetKind() == TokKind::kw_true;
2579 lexer_.Lex();
2580 return true;
2581 }
2582
ParseToken(TokKind kind,const string & msg)2583 bool HloParser::ParseToken(TokKind kind, const string& msg) {
2584 VLOG(1) << "ParseToken " << TokKindToString(kind) << " " << msg;
2585 if (lexer_.GetKind() != kind) {
2586 return TokenError(msg);
2587 }
2588 lexer_.Lex();
2589 return true;
2590 }
2591
EatIfPresent(TokKind kind)2592 bool HloParser::EatIfPresent(TokKind kind) {
2593 if (lexer_.GetKind() != kind) {
2594 return false;
2595 }
2596 lexer_.Lex();
2597 return true;
2598 }
2599
AddInstruction(const string & name,HloInstruction * instruction,LocTy name_loc)2600 bool HloParser::AddInstruction(const string& name, HloInstruction* instruction,
2601 LocTy name_loc) {
2602 auto result = instruction_pool_.insert({name, {instruction, name_loc}});
2603 if (!result.second) {
2604 Error(name_loc, StrCat("instruction already exists: ", name));
2605 return Error(/*loc=*/result.first->second.second,
2606 "instruction previously defined here");
2607 }
2608 return true;
2609 }
2610
AddComputation(const string & name,HloComputation * computation,LocTy name_loc)2611 bool HloParser::AddComputation(const string& name, HloComputation* computation,
2612 LocTy name_loc) {
2613 auto result = computation_pool_.insert({name, {computation, name_loc}});
2614 if (!result.second) {
2615 Error(name_loc, StrCat("computation already exists: ", name));
2616 return Error(/*loc=*/result.first->second.second,
2617 "computation previously defined here");
2618 }
2619 return true;
2620 }
2621
2622 } // namespace
2623
Parse(StringPiece str,const HloModuleConfig & config)2624 StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str,
2625 const HloModuleConfig& config) {
2626 HloParser parser(str, config);
2627 if (!parser.Run()) {
2628 return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str());
2629 }
2630 return parser.ConsumeHloModule();
2631 }
2632
Parse(StringPiece str)2633 StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str) {
2634 HloModuleConfig config;
2635 return Parse(str, config);
2636 }
2637
2638 } // namespace tools
2639 } // namespace xla
2640