1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h"
17 
18 #include "tensorflow/compiler/xla/literal_util.h"
19 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/util.h"
22 #include "tensorflow/core/lib/gtl/map_util.h"
23 #include "tensorflow/core/lib/strings/strcat.h"
24 #include "tensorflow/core/lib/strings/stringprintf.h"
25 
26 namespace xla {
27 namespace tools {
28 
29 namespace {
30 
31 using tensorflow::StringPiece;
32 using tensorflow::gtl::optional;
33 using tensorflow::str_util::Split;
34 using tensorflow::str_util::SplitAndParseAsInts;
35 using tensorflow::strings::Printf;
36 using tensorflow::strings::StrAppend;
37 using tensorflow::strings::StrCat;
38 
// Largest finite value representable in IEEE 754 half precision (F16).
constexpr double kF16max = 65504;
40 
41 // Parser for the HloModule::ToString() format text.
42 class HloParser {
43  public:
44   using LocTy = HloLexer::LocTy;
45 
HloParser(StringPiece str,const HloModuleConfig & config)46   explicit HloParser(StringPiece str, const HloModuleConfig& config)
47       : lexer_(str), config_(config) {}
48 
49   // Runs the parser. Returns false if an error occurred.
50   bool Run();
51 
52   // Returns the parsed HloModule.
ConsumeHloModule()53   std::unique_ptr<HloModule> ConsumeHloModule() { return std::move(module_); }
54 
55   // Returns the error information.
GetError() const56   string GetError() const { return tensorflow::str_util::Join(error_, "\n"); }
57 
58  private:
59   // ParseXXX returns false if an error occurred.
60   bool ParseHloModule();
61   bool ParseComputations();
62   bool ParseComputation(HloComputation** entry_computation);
63   bool ParseInstructionList(HloComputation::Builder* builder,
64                             string* root_name);
65   bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
66   bool ParseControlPredecessors(HloInstruction* instruction);
67   bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
68   bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
69   bool ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
70                             const Shape& shape);
71   bool ParseDenseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
72   bool ParseSparseLiteral(std::unique_ptr<Literal>* literal,
73                           const Shape& shape);
74   template <typename LiteralNativeT>
75   bool ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
76                                 const Shape& shape);
77 
78   // Sets the sub-value of literal at the given index to the given value. The
79   // literal's shape must have the default layout.
80   bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal);
81   bool SetValueInLiteral(double value, int64 linear_index, Literal* literal);
82   bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal);
83   template <typename LiteralNativeT, typename ParsedElemT>
84   bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index,
85                                Literal* literal);
86 
87   bool ParseOperands(std::vector<HloInstruction*>* operands);
88   // Fills parsed operands into 'operands' and expects a certain number of
89   // operands.
90   bool ParseOperands(std::vector<HloInstruction*>* operands,
91                      const int expected_size);
92 
93   // Describes the start, limit, and stride on every dimension of the operand
94   // being sliced.
95   struct SliceRanges {
96     std::vector<int64> starts;
97     std::vector<int64> limits;
98     std::vector<int64> strides;
99   };
100 
101   // Types of attributes.
102   enum class AttrTy {
103     kInt64,
104     kInt32,
105     kFloat,
106     kString,
107     kBracedInt64List,
108     kHloComputation,
109     kFftType,
110     kWindow,
111     kConvolutionDimensionNumbers,
112     kSharding,
113     kInstructionList,
114     kSliceRanges,
115     kPaddingConfig,
116     kMetadata,
117     kFusionKind,
118     kDistribution,
119   };
120 
121   struct AttrConfig {
122     bool required;     // whether it's required or optional
123     AttrTy attr_type;  // what type it is
124     void* result;      // where to store the parsed result.
125   };
126 
127   // attributes ::= (',' attribute)*
128   //
129   // Parses attributes given names and configs of the attributes. Each parsed
130   // result is passed back through the result pointer in corresponding
131   // AttrConfig. Note that the result pointer must point to a optional<T> typed
132   // variable which outlives this function. Returns false on error. You should
133   // not use the any of the results if this function failed.
134   //
135   // Example usage:
136   //
137   //  std::unordered_map<string, AttrConfig> attrs;
138   //  optional<int64> foo;
139   //  attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
140   //  optional<Window> bar;
141   //  attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
142   //  if (!ParseAttributes(attrs)) {
143   //    return false; // Do not use 'foo' 'bar' if failed.
144   //  }
145   //  // Do something with 'bar'.
146   //  if (foo) { // If attr foo is seen, do something with 'foo'. }
147   //
148   bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs);
149 
150   // sub_attributes ::= '{' (','? attribute)* '}'
151   //
152   // Usage is the same as ParseAttributes. See immediately above.
153   bool ParseSubAttributes(const std::unordered_map<string, AttrConfig>& attrs);
154 
155   // Parses one attribute. If it has already been seen, return error. Returns
156   // true and adds to seen_attrs on success.
157   //
158   // Do not call this except in ParseAttributes or ParseSubAttributes.
159   bool ParseAttributeHelper(const std::unordered_map<string, AttrConfig>& attrs,
160                             std::unordered_set<string>* seen_attrs);
161 
162   // Parses a name and finds the corresponding hlo computation.
163   bool ParseComputationName(HloComputation** value);
164   // Parses a list of names and finds the corresponding hlo instructions.
165   bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
166   bool ParseWindow(Window* window);
167   bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
168   bool ParsePaddingConfig(PaddingConfig* padding);
169   bool ParseMetadata(OpMetadata* metadata);
170   bool ParseSharding(OpSharding* sharding);
171   bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
172 
173   // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
174   bool ParseDxD(const string& name, std::vector<int64>* result);
175   // Parses window's pad sub-attriute, e.g., pad=0_0x3x3.
176   bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
177 
178   bool ParseSliceRanges(SliceRanges* result);
179   bool ParseInt64List(const TokKind start, const TokKind end,
180                       const TokKind delim, std::vector<int64>* result);
181 
182   bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
183   bool ParseParamList();
184   bool ParseName(string* result);
185   bool ParseAttributeName(string* result);
186   bool ParseString(string* result);
187   bool ParseShape(Shape* result);
188   bool ParseOpcode(HloOpcode* result);
189   bool ParseFftType(FftType* result);
190   bool ParseFusionKind(HloInstruction::FusionKind* result);
191   bool ParseRandomDistribution(RandomDistribution* result);
192   bool ParseInt64(int64* result);
193   bool ParseDouble(double* result);
194   bool ParseBool(bool* result);
195   bool ParseToken(TokKind kind, const string& msg);
196 
197   // Returns true if the current token is the beginning of a shape.
198   bool CanBeShape();
199   // Returns true if the current token is the beginning of a
200   // param_list_to_shape.
201   bool CanBeParamListToShape();
202 
203   // Logs the current parsing line and the given message. Always returns false.
204   bool TokenError(StringPiece msg);
205   bool Error(LocTy loc, StringPiece msg);
206 
207   // If the current token is 'kind', eats it (i.e. lexes the next token) and
208   // returns true.
209   bool EatIfPresent(TokKind kind);
210   // Parses a shape, and returns true if the result is compatible with the given
211   // shape.
212   bool EatShapeAndCheckCompatible(const Shape& shape);
213 
214   // Adds the instruction to the pool. Returns false and emits an error if the
215   // instruction already exists.
216   bool AddInstruction(const string& name, HloInstruction* instruction,
217                       LocTy name_loc);
218   // Adds the computation to the pool. Returns false and emits an error if the
219   // computation already exists.
220   bool AddComputation(const string& name, HloComputation* computation,
221                       LocTy name_loc);
222 
223   // The map from the instruction/computation name to the
224   // instruction/computation itself and it's location. This does not own the
225   // pointers.
226   std::unordered_map<string, std::pair<HloInstruction*, LocTy>>
227       instruction_pool_;
228   std::unordered_map<string, std::pair<HloComputation*, LocTy>>
229       computation_pool_;
230 
231   HloLexer lexer_;
232   std::unique_ptr<HloModule> module_;
233   std::vector<std::unique_ptr<HloComputation>> computations_;
234   const HloModuleConfig config_;
235   std::vector<string> error_;
236 };
237 
Error(LocTy loc,StringPiece msg)238 bool HloParser::Error(LocTy loc, StringPiece msg) {
239   auto line_col = lexer_.GetLineAndColumn(loc);
240   const unsigned line = line_col.first;
241   const unsigned col = line_col.second;
242   std::vector<string> error_lines;
243   error_lines.push_back(
244       StrCat("was parsing ", line, ":", col, ": error: ", msg));
245   error_lines.push_back(lexer_.GetLine(loc).ToString());
246   error_lines.push_back(col == 0 ? "" : StrCat(string(col - 1, ' '), "^"));
247 
248   error_.push_back(tensorflow::str_util::Join(error_lines, "\n"));
249   VLOG(1) << "Error: " << error_.back();
250   return false;
251 }
252 
TokenError(StringPiece msg)253 bool HloParser::TokenError(StringPiece msg) {
254   return Error(lexer_.GetLoc(), msg);
255 }
256 
Run()257 bool HloParser::Run() {
258   lexer_.Lex();
259   return ParseHloModule();
260 }
261 
262 // ::= 'HloModule' name computations
ParseHloModule()263 bool HloParser::ParseHloModule() {
264   if (lexer_.GetKind() != TokKind::kw_HloModule) {
265     return TokenError("expects HloModule");
266   }
267   // Eat 'HloModule'
268   lexer_.Lex();
269 
270   string name;
271   if (!ParseName(&name)) {
272     return false;
273   }
274 
275   module_ = MakeUnique<HloModule>(name, config_);
276 
277   return ParseComputations();
278 }
279 
280 // computations ::= (computation)+
ParseComputations()281 bool HloParser::ParseComputations() {
282   HloComputation* entry_computation = nullptr;
283   do {
284     if (!ParseComputation(&entry_computation)) {
285       return false;
286     }
287   } while (lexer_.GetKind() != TokKind::kEof);
288 
289   for (int i = 0; i < computations_.size(); i++) {
290     // If entry_computation is not nullptr, it means the computation it pointed
291     // to is marked with "ENTRY"; otherwise, no computation is marked with
292     // "ENTRY", and we use the last computation as the entry computation. We
293     // add the non-entry computations as embedded computations to the module.
294     if ((entry_computation != nullptr &&
295          computations_[i].get() != entry_computation) ||
296         (entry_computation == nullptr && i != computations_.size() - 1)) {
297       module_->AddEmbeddedComputation(std::move(computations_[i]));
298       continue;
299     }
300     auto computation =
301         module_->AddEntryComputation(std::move(computations_[i]));
302     // The parameters and result layouts were set to default layout. Here we
303     // set the layouts to what the hlo text says.
304     for (int p = 0; p < computation->num_parameters(); p++) {
305       const Shape& param_shape = computation->parameter_instruction(p)->shape();
306       if (param_shape.has_layout()) {
307         module_->mutable_entry_computation_layout()
308             ->mutable_parameter_layout(p)
309             ->ResetLayout(param_shape.layout());
310       }
311     }
312     const Shape& result_shape = computation->root_instruction()->shape();
313     if (result_shape.has_layout()) {
314       module_->mutable_entry_computation_layout()
315           ->mutable_result_layout()
316           ->ResetLayout(result_shape.layout());
317     }
318   }
319 
320   return true;
321 }
322 
323 // computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list
ParseComputation(HloComputation ** entry_computation)324 bool HloParser::ParseComputation(HloComputation** entry_computation) {
325   LocTy maybe_entry_loc = lexer_.GetLoc();
326   const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY);
327 
328   string name;
329   LocTy name_loc = lexer_.GetLoc();
330   if (!ParseName(&name)) {
331     return false;
332   }
333   auto builder = MakeUnique<HloComputation::Builder>(name);
334 
335   LocTy shape_loc = nullptr;
336   Shape shape;
337   if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) {
338     return false;
339   }
340 
341   string root_name;
342   if (!ParseInstructionList(builder.get(), &root_name)) {
343     return false;
344   }
345 
346   std::pair<HloInstruction*, LocTy>* root_node =
347       tensorflow::gtl::FindOrNull(instruction_pool_, root_name);
348   // This means some instruction was marked as ROOT but we didn't find it in the
349   // pool, which should not happen.
350   if (!root_name.empty() && root_node == nullptr) {
351     LOG(FATAL) << "instruction " << root_name
352                << " was marked as ROOT but the parser has not seen it before";
353   }
354 
355   HloInstruction* root = root_node == nullptr ? nullptr : root_node->first;
356   // Now root can be either an existing instruction or a nullptr. If it's a
357   // nullptr, the implementation of Builder will set the last instruction as
358   // root instruction.
359   computations_.emplace_back(builder->Build(root));
360   HloComputation* computation = computations_.back().get();
361 
362   if (!root) {
363     root = computation->root_instruction();
364   } else {
365     CHECK_EQ(root, computation->root_instruction());
366   }
367 
368   // If param_list_to_shape was present, check compatibility.
369   if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) {
370     return Error(
371         shape_loc,
372         StrCat("Shape of computation ", name, ", ",
373                ShapeUtil::HumanString(shape),
374                ", is not compatible with that of its root instruction ",
375                root_name, ", ", ShapeUtil::HumanString(root->shape())));
376   }
377 
378   if (is_entry_computation) {
379     if (*entry_computation != nullptr) {
380       return Error(maybe_entry_loc, "expects only one ENTRY");
381     }
382     *entry_computation = computation;
383   }
384 
385   return AddComputation(name, computation, name_loc);
386 }
387 
388 // instruction_list ::= '{' instruction_list1 '}'
389 // instruction_list1 ::= (instruction)+
ParseInstructionList(HloComputation::Builder * builder,string * root_name)390 bool HloParser::ParseInstructionList(HloComputation::Builder* builder,
391                                      string* root_name) {
392   if (!ParseToken(TokKind::kLbrace,
393                   "expects '{' at the beginning of instruction list.")) {
394     return false;
395   }
396   do {
397     if (!ParseInstruction(builder, root_name)) {
398       return false;
399     }
400   } while (lexer_.GetKind() != TokKind::kRbrace);
401   return ParseToken(TokKind::kRbrace,
402                     "expects '}' at the end of instruction list.");
403 }
404 
405 // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
ParseInstruction(HloComputation::Builder * builder,string * root_name)406 bool HloParser::ParseInstruction(HloComputation::Builder* builder,
407                                  string* root_name) {
408   string name;
409   Shape shape;
410   HloOpcode opcode;
411   std::vector<HloInstruction*> operands;
412 
413   LocTy maybe_root_loc = lexer_.GetLoc();
414   bool is_root = EatIfPresent(TokKind::kw_ROOT);
415 
416   const LocTy name_loc = lexer_.GetLoc();
417   if (!ParseName(&name) ||
418       !ParseToken(TokKind::kEqual, "expects '=' in instruction") ||
419       !ParseShape(&shape) || !ParseOpcode(&opcode)) {
420     return false;
421   }
422 
423   if (is_root) {
424     if (!root_name->empty()) {
425       return Error(maybe_root_loc, "one computation should have only one ROOT");
426     }
427     *root_name = name;
428   }
429 
430   // Add optional attributes.
431   std::unordered_map<string, AttrConfig> attrs;
432   optional<OpSharding> sharding;
433   attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
434   optional<std::vector<HloInstruction*>> predecessors;
435   attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
436                                    &predecessors};
437   optional<OpMetadata> metadata;
438   attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata};
439 
440   HloInstruction* instruction;
441   switch (opcode) {
442     case HloOpcode::kParameter: {
443       int64 parameter_number;
444       if (!ParseToken(TokKind::kLparen,
445                       "expects '(' before parameter number") ||
446           !ParseInt64(&parameter_number) ||
447           !ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
448           !ParseAttributes(attrs)) {
449         return false;
450       }
451       instruction = builder->AddInstruction(
452           HloInstruction::CreateParameter(parameter_number, shape, name));
453       break;
454     }
455     case HloOpcode::kConstant: {
456       std::unique_ptr<Literal> literal;
457       if (!ParseToken(TokKind::kLparen,
458                       "expects '(' before constant literal") ||
459           !ParseLiteral(&literal, shape) ||
460           !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
461           !ParseAttributes(attrs)) {
462         return false;
463       }
464       instruction = builder->AddInstruction(
465           HloInstruction::CreateConstant(std::move(literal)));
466       break;
467     }
468     // Unary ops.
469     case HloOpcode::kAbs:
470     case HloOpcode::kRoundNearestAfz:
471     case HloOpcode::kBitcast:
472     case HloOpcode::kCeil:
473     case HloOpcode::kCopy:
474     case HloOpcode::kCos:
475     case HloOpcode::kExp:
476     case HloOpcode::kImag:
477     case HloOpcode::kIsFinite:
478     case HloOpcode::kFloor:
479     case HloOpcode::kLog:
480     case HloOpcode::kNot:
481     case HloOpcode::kNegate:
482     case HloOpcode::kReal:
483     case HloOpcode::kSign:
484     case HloOpcode::kSin:
485     case HloOpcode::kSort:
486     case HloOpcode::kTanh: {
487       if (!ParseOperands(&operands, /*expected_size=*/1) ||
488           !ParseAttributes(attrs)) {
489         return false;
490       }
491       instruction = builder->AddInstruction(
492           HloInstruction::CreateUnary(shape, opcode, operands[0]));
493       break;
494     }
495     // Binary ops.
496     case HloOpcode::kAdd:
497     case HloOpcode::kDivide:
498     case HloOpcode::kMultiply:
499     case HloOpcode::kSubtract:
500     case HloOpcode::kAtan2:
501     case HloOpcode::kComplex:
502     case HloOpcode::kEq:
503     case HloOpcode::kGe:
504     case HloOpcode::kGt:
505     case HloOpcode::kLe:
506     case HloOpcode::kLt:
507     case HloOpcode::kNe:
508     case HloOpcode::kMaximum:
509     case HloOpcode::kMinimum:
510     case HloOpcode::kPower:
511     case HloOpcode::kRemainder:
512     case HloOpcode::kAnd:
513     case HloOpcode::kOr:
514     case HloOpcode::kShiftLeft:
515     case HloOpcode::kShiftRightArithmetic:
516     case HloOpcode::kShiftRightLogical: {
517       if (!ParseOperands(&operands, /*expected_size=*/2) ||
518           !ParseAttributes(attrs)) {
519         return false;
520       }
521       instruction = builder->AddInstruction(HloInstruction::CreateBinary(
522           shape, opcode, operands[0], operands[1]));
523       break;
524     }
525     // Ternary ops.
526     case HloOpcode::kClamp:
527     case HloOpcode::kSelect: {
528       if (!ParseOperands(&operands, /*expected_size=*/3) ||
529           !ParseAttributes(attrs)) {
530         return false;
531       }
532       instruction = builder->AddInstruction(HloInstruction::CreateTernary(
533           shape, opcode, operands[0], operands[1], operands[2]));
534       break;
535     }
536     // Other supported ops.
537     case HloOpcode::kConvert: {
538       if (!ParseOperands(&operands, /*expected_size=*/1) ||
539           !ParseAttributes(attrs)) {
540         return false;
541       }
542       instruction = builder->AddInstruction(
543           HloInstruction::CreateConvert(shape, operands[0]));
544       break;
545     }
546     case HloOpcode::kBitcastConvert: {
547       if (!ParseOperands(&operands, /*expected_size=*/1) ||
548           !ParseAttributes(attrs)) {
549         return false;
550       }
551       instruction = builder->AddInstruction(
552           HloInstruction::CreateBitcastConvert(shape, operands[0]));
553       break;
554     }
555     case HloOpcode::kCrossReplicaSum: {
556       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
557         return false;
558       }
559       instruction = builder->AddInstruction(
560           HloInstruction::CreateCrossReplicaSum(shape, operands));
561       break;
562     }
563     case HloOpcode::kReshape: {
564       if (!ParseOperands(&operands, /*expected_size=*/1) ||
565           !ParseAttributes(attrs)) {
566         return false;
567       }
568       instruction = builder->AddInstruction(
569           HloInstruction::CreateReshape(shape, operands[0]));
570       break;
571     }
572     case HloOpcode::kTuple: {
573       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
574         return false;
575       }
576       instruction =
577           builder->AddInstruction(HloInstruction::CreateTuple(operands));
578       break;
579     }
580     case HloOpcode::kWhile: {
581       optional<HloComputation*> condition;
582       optional<HloComputation*> body;
583       attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
584                             &condition};
585       attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
586       if (!ParseOperands(&operands, /*expected_size=*/1) ||
587           !ParseAttributes(attrs)) {
588         return false;
589       }
590       instruction = builder->AddInstruction(HloInstruction::CreateWhile(
591           shape, *condition, *body, /*init=*/operands[0]));
592       break;
593     }
594     case HloOpcode::kRecv: {
595       optional<int64> channel_id;
596       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
597       if (!ParseOperands(&operands, /*expected_size=*/0) ||
598           !ParseAttributes(attrs)) {
599         return false;
600       }
601       instruction = builder->AddInstruction(
602           HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id));
603       break;
604     }
605     case HloOpcode::kRecvDone: {
606       optional<int64> channel_id;
607       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
608       if (!ParseOperands(&operands, /*expected_size=*/1) ||
609           !ParseAttributes(attrs)) {
610         return false;
611       }
612       if (channel_id != operands[0]->channel_id()) {
613         return false;
614       }
615       instruction =
616           builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0]));
617       break;
618     }
619     case HloOpcode::kSend: {
620       optional<int64> channel_id;
621       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
622       if (!ParseOperands(&operands, /*expected_size=*/1) ||
623           !ParseAttributes(attrs)) {
624         return false;
625       }
626       instruction = builder->AddInstruction(
627           HloInstruction::CreateSend(operands[0], *channel_id));
628       break;
629     }
630     case HloOpcode::kSendDone: {
631       optional<int64> channel_id;
632       attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
633       if (!ParseOperands(&operands, /*expected_size=*/1) ||
634           !ParseAttributes(attrs)) {
635         return false;
636       }
637       if (channel_id != operands[0]->channel_id()) {
638         return false;
639       }
640       instruction =
641           builder->AddInstruction(HloInstruction::CreateSendDone(operands[0]));
642       break;
643     }
644     case HloOpcode::kGetTupleElement: {
645       optional<int64> index;
646       attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
647       if (!ParseOperands(&operands, /*expected_size=*/1) ||
648           !ParseAttributes(attrs)) {
649         return false;
650       }
651       instruction = builder->AddInstruction(
652           HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
653       break;
654     }
655     case HloOpcode::kCall: {
656       optional<HloComputation*> to_apply;
657       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
658                            &to_apply};
659       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
660         return false;
661       }
662       instruction = builder->AddInstruction(
663           HloInstruction::CreateCall(shape, operands, *to_apply));
664       break;
665     }
666     case HloOpcode::kReduceWindow: {
667       optional<HloComputation*> reduce_computation;
668       optional<Window> window;
669       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
670       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
671                            &reduce_computation};
672       if (!ParseOperands(&operands, /*expected_size=*/2) ||
673           !ParseAttributes(attrs)) {
674         return false;
675       }
676       if (!window) {
677         window.emplace();
678       }
679       instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
680           shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
681           *reduce_computation));
682       break;
683     }
684     case HloOpcode::kConvolution: {
685       optional<Window> window;
686       optional<ConvolutionDimensionNumbers> dnums;
687       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
688       attrs["dim_labels"] = {/*required=*/true,
689                              AttrTy::kConvolutionDimensionNumbers, &dnums};
690       if (!ParseOperands(&operands, /*expected_size=*/2) ||
691           !ParseAttributes(attrs)) {
692         return false;
693       }
694       if (!window) {
695         window.emplace();
696       }
697       instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
698           shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums));
699       break;
700     }
701     case HloOpcode::kFft: {
702       optional<FftType> fft_type;
703       optional<std::vector<int64>> fft_length;
704       attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type};
705       attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List,
706                              &fft_length};
707       if (!ParseOperands(&operands, /*expected_size=*/1) ||
708           !ParseAttributes(attrs)) {
709         return false;
710       }
711       instruction = builder->AddInstruction(HloInstruction::CreateFft(
712           shape, operands[0], *fft_type, *fft_length));
713       break;
714     }
715     case HloOpcode::kBroadcast: {
716       optional<std::vector<int64>> broadcast_dimensions;
717       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
718                              &broadcast_dimensions};
719       if (!ParseOperands(&operands, /*expected_size=*/1) ||
720           !ParseAttributes(attrs)) {
721         return false;
722       }
723       instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
724           shape, operands[0], *broadcast_dimensions));
725       break;
726     }
727     case HloOpcode::kConcatenate: {
728       optional<std::vector<int64>> dimensions;
729       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
730                              &dimensions};
731       if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
732           dimensions->size() != 1) {
733         return false;
734       }
735       instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
736           shape, operands, dimensions->at(0)));
737       break;
738     }
739     case HloOpcode::kMap: {
740       optional<HloComputation*> to_apply;
741       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
742                            &to_apply};
743       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
744         return false;
745       }
746       instruction = builder->AddInstruction(
747           HloInstruction::CreateMap(shape, operands, *to_apply));
748       break;
749     }
750     case HloOpcode::kReduce: {
751       optional<HloComputation*> reduce_computation;
752       attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
753                            &reduce_computation};
754       optional<std::vector<int64>> dimensions_to_reduce;
755       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
756                              &dimensions_to_reduce};
757       if (!ParseOperands(&operands, /*expected_size=*/2) ||
758           !ParseAttributes(attrs)) {
759         return false;
760       }
761       instruction = builder->AddInstruction(HloInstruction::CreateReduce(
762           shape, /*operand=*/operands[0], /*init_value=*/operands[1],
763           *dimensions_to_reduce, *reduce_computation));
764       break;
765     }
766     case HloOpcode::kReverse: {
767       optional<std::vector<int64>> dimensions;
768       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
769                              &dimensions};
770       if (!ParseOperands(&operands, /*expected_size=*/1) ||
771           !ParseAttributes(attrs)) {
772         return false;
773       }
774       instruction = builder->AddInstruction(
775           HloInstruction::CreateReverse(shape, operands[0], *dimensions));
776       break;
777     }
778     case HloOpcode::kSelectAndScatter: {
779       optional<HloComputation*> select;
780       attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
781       optional<HloComputation*> scatter;
782       attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
783       optional<Window> window;
784       attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
785       if (!ParseOperands(&operands, /*expected_size=*/3) ||
786           !ParseAttributes(attrs)) {
787         return false;
788       }
789       if (!window) {
790         window.emplace();
791       }
792       instruction =
793           builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
794               shape, /*operand=*/operands[0], *select, *window,
795               /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
796       break;
797     }
798     case HloOpcode::kSlice: {
799       optional<SliceRanges> slice_ranges;
800       attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
801       if (!ParseOperands(&operands, /*expected_size=*/1) ||
802           !ParseAttributes(attrs)) {
803         return false;
804       }
805       instruction = builder->AddInstruction(HloInstruction::CreateSlice(
806           shape, operands[0], slice_ranges->starts, slice_ranges->limits,
807           slice_ranges->strides));
808       break;
809     }
810     case HloOpcode::kDynamicSlice: {
811       optional<std::vector<int64>> dynamic_slice_sizes;
812       attrs["dynamic_slice_sizes"] = {
813           /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
814       if (!ParseOperands(&operands, /*expected_size=*/2) ||
815           !ParseAttributes(attrs)) {
816         return false;
817       }
818       instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
819           shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
820           *dynamic_slice_sizes));
821       break;
822     }
823     case HloOpcode::kDynamicUpdateSlice: {
824       if (!ParseOperands(&operands, /*expected_size=*/3) ||
825           !ParseAttributes(attrs)) {
826         return false;
827       }
828       instruction =
829           builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
830               shape, /*operand=*/operands[0], /*update=*/operands[1],
831               /*start_indices=*/operands[2]));
832       break;
833     }
834     case HloOpcode::kTranspose: {
835       optional<std::vector<int64>> dimensions;
836       attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
837                              &dimensions};
838       if (!ParseOperands(&operands, /*expected_size=*/1) ||
839           !ParseAttributes(attrs)) {
840         return false;
841       }
842       instruction = builder->AddInstruction(
843           HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
844       break;
845     }
846     case HloOpcode::kBatchNormTraining: {
847       optional<float> epsilon;
848       attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
849       optional<int64> feature_index;
850       attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
851                                 &feature_index};
852       if (!ParseOperands(&operands, /*expected_size=*/3) ||
853           !ParseAttributes(attrs)) {
854         return false;
855       }
856       instruction =
857           builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
858               shape, /*operand=*/operands[0], /*scale=*/operands[1],
859               /*offset=*/operands[2], *epsilon, *feature_index));
860       break;
861     }
862     case HloOpcode::kBatchNormInference: {
863       optional<float> epsilon;
864       attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
865       optional<int64> feature_index;
866       attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
867                                 &feature_index};
868       if (!ParseOperands(&operands, /*expected_size=*/5) ||
869           !ParseAttributes(attrs)) {
870         return false;
871       }
872       instruction =
873           builder->AddInstruction(HloInstruction::CreateBatchNormInference(
874               shape, /*operand=*/operands[0], /*scale=*/operands[1],
875               /*offset=*/operands[2], /*mean=*/operands[3],
876               /*variance=*/operands[4], *epsilon, *feature_index));
877       break;
878     }
879     case HloOpcode::kBatchNormGrad: {
880       optional<float> epsilon;
881       attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
882       optional<int64> feature_index;
883       attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
884                                 &feature_index};
885       if (!ParseOperands(&operands, /*expected_size=*/5) ||
886           !ParseAttributes(attrs)) {
887         return false;
888       }
889       instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
890           shape, /*operand=*/operands[0], /*scale=*/operands[1],
891           /*mean=*/operands[2], /*variance=*/operands[3],
892           /*grad_output=*/operands[4], *epsilon, *feature_index));
893       break;
894     }
895     case HloOpcode::kPad: {
896       optional<PaddingConfig> padding;
897       attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
898       if (!ParseOperands(&operands, /*expected_size=*/2) ||
899           !ParseAttributes(attrs)) {
900         return false;
901       }
902       instruction = builder->AddInstruction(HloInstruction::CreatePad(
903           shape, operands[0], /*padding_value=*/operands[1], *padding));
904       break;
905     }
906     case HloOpcode::kFusion: {
907       optional<HloComputation*> fusion_computation;
908       attrs["calls"] = {/*required=*/true, AttrTy::kHloComputation,
909                         &fusion_computation};
910       optional<HloInstruction::FusionKind> fusion_kind;
911       attrs["kind"] = {/*required=*/true, AttrTy::kFusionKind, &fusion_kind};
912       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
913         return false;
914       }
915       instruction = builder->AddInstruction(HloInstruction::CreateFusion(
916           shape, *fusion_kind, operands, *fusion_computation));
917       break;
918     }
919     case HloOpcode::kInfeed: {
920       optional<string> config;
921       attrs["infeed_config"] = {/*required=*/false, AttrTy::kString, &config};
922       if (!ParseOperands(&operands, /*expected_size=*/0) ||
923           !ParseAttributes(attrs)) {
924         return false;
925       }
926       instruction = builder->AddInstruction(
927           HloInstruction::CreateInfeed(shape, config ? *config : ""));
928       break;
929     }
930     case HloOpcode::kOutfeed: {
931       optional<string> config;
932       attrs["outfeed_config"] = {/*required=*/false, AttrTy::kString, &config};
933       if (!ParseOperands(&operands, /*expected_size=*/1) ||
934           !ParseAttributes(attrs)) {
935         return false;
936       }
937       instruction = builder->AddInstruction(HloInstruction::CreateOutfeed(
938           operands[0]->shape(), operands[0], config ? *config : ""));
939       break;
940     }
941     case HloOpcode::kRng: {
942       optional<RandomDistribution> distribution;
943       attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution,
944                                &distribution};
945       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
946         return false;
947       }
948       instruction = builder->AddInstruction(
949           HloInstruction::CreateRng(shape, *distribution, operands));
950       break;
951     }
952     case HloOpcode::kReducePrecision: {
953       optional<int64> exponent_bits;
954       optional<int64> mantissa_bits;
955       attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64,
956                                 &exponent_bits};
957       attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64,
958                                 &mantissa_bits};
959       if (!ParseOperands(&operands, /*expected_size=*/1) ||
960           !ParseAttributes(attrs)) {
961         return false;
962       }
963       instruction =
964           builder->AddInstruction(HloInstruction::CreateReducePrecision(
965               shape, operands[0], static_cast<int>(*exponent_bits),
966               static_cast<int>(*mantissa_bits)));
967       break;
968     }
969     case HloOpcode::kConditional: {
970       optional<HloComputation*> true_computation;
971       optional<HloComputation*> false_computation;
972       attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation,
973                                    &true_computation};
974       attrs["false_computation"] = {/*required=*/true, AttrTy::kHloComputation,
975                                     &false_computation};
976       if (!ParseOperands(&operands, /*expected_size=*/3) ||
977           !ParseAttributes(attrs)) {
978         return false;
979       }
980       instruction = builder->AddInstruction(HloInstruction::CreateConditional(
981           shape, /*pred=*/operands[0],
982           /*true_computation_arg=*/operands[1], *true_computation,
983           /*false_computation_arg=*/operands[2], *false_computation));
984       break;
985     }
986     case HloOpcode::kCustomCall: {
987       optional<string> custom_call_target;
988       attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
989                                      &custom_call_target};
990       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
991         return false;
992       }
993       instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
994           shape, operands, *custom_call_target));
995       break;
996     }
997     case HloOpcode::kHostCompute: {
998       optional<string> channel_name;
999       optional<int64> cost_estimate_ns;
1000       attrs["channel_name"] = {/*required=*/true, AttrTy::kString,
1001                                &channel_name};
1002       attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64,
1003                                    &cost_estimate_ns};
1004       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
1005         return false;
1006       }
1007       instruction = builder->AddInstruction(HloInstruction::CreateHostCompute(
1008           shape, operands, *channel_name, *cost_estimate_ns));
1009       break;
1010     }
1011     case HloOpcode::kDot: {
1012       optional<std::vector<int64>> lhs_contracting_dims;
1013       attrs["lhs_contracting_dims"] = {
1014           /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims};
1015       optional<std::vector<int64>> rhs_contracting_dims;
1016       attrs["rhs_contracting_dims"] = {
1017           /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims};
1018       optional<std::vector<int64>> lhs_batch_dims;
1019       attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
1020                                  &lhs_batch_dims};
1021       optional<std::vector<int64>> rhs_batch_dims;
1022       attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List,
1023                                  &rhs_batch_dims};
1024 
1025       if (!ParseOperands(&operands, /*expected_size=*/2) ||
1026           !ParseAttributes(attrs)) {
1027         return false;
1028       }
1029 
1030       DotDimensionNumbers dnum;
1031       if (lhs_contracting_dims) {
1032         *dnum.mutable_lhs_contracting_dimensions() = {
1033             lhs_contracting_dims->begin(), lhs_contracting_dims->end()};
1034       }
1035       if (rhs_contracting_dims) {
1036         *dnum.mutable_rhs_contracting_dimensions() = {
1037             rhs_contracting_dims->begin(), rhs_contracting_dims->end()};
1038       }
1039       if (lhs_batch_dims) {
1040         *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(),
1041                                                 lhs_batch_dims->end()};
1042       }
1043       if (rhs_batch_dims) {
1044         *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(),
1045                                                 rhs_batch_dims->end()};
1046       }
1047 
1048       instruction = builder->AddInstruction(
1049           HloInstruction::CreateDot(shape, operands[0], operands[1], dnum));
1050       break;
1051     }
1052     case HloOpcode::kGather:
1053       // TODO(b/72710576): HLO parsing is not implemented for Gather.
1054       return TokenError("HLO parsing is not implemented for Gather");
1055     case HloOpcode::kTrace:
1056       return TokenError(StrCat("parsing not yet implemented for op: ",
1057                                HloOpcodeString(opcode)));
1058   }
1059 
1060   instruction->set_name(name);
1061 
1062   // Add common attrs (sharding, control predecessors) to the instruction, if
1063   // they were seen.
1064   if (sharding) {
1065     instruction->set_sharding(
1066         HloSharding::FromProto(sharding.value()).ValueOrDie());
1067   }
1068   if (predecessors) {
1069     for (auto* pre : *predecessors) {
1070       Status status = pre->AddControlDependencyTo(instruction);
1071       if (!status.ok()) {
1072         return Error(name_loc, StrCat("error adding control dependency for: ",
1073                                       name, " status: ", status.ToString()));
1074       }
1075     }
1076   }
1077   if (metadata) {
1078     instruction->set_metadata(*metadata);
1079   }
1080   return AddInstruction(name, instruction, name_loc);
1081 }  // NOLINT(readability/fn_size)
1082 
1083 // ::= '{' (single_sharding | tuple_sharding) '}'
1084 //
1085 // tuple_sharding ::= single_sharding* (',' single_sharding)*
ParseSharding(OpSharding * sharding)1086 bool HloParser::ParseSharding(OpSharding* sharding) {
1087   // A single sharding starts with '{' and is not followed by '{'.
1088   // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
1089   // an empty tuple.
1090   if (!ParseToken(TokKind::kLbrace,
1091                   "expected '{' to start sharding attribute")) {
1092     return false;
1093   }
1094 
1095   if (lexer_.GetKind() != TokKind::kLbrace &&
1096       lexer_.GetKind() != TokKind::kRbrace) {
1097     return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
1098   }
1099 
1100   // Tuple sharding.
1101   // Allow empty tuple shardings.
1102   if (lexer_.GetKind() != TokKind::kRbrace) {
1103     do {
1104       if (!ParseSingleSharding(sharding->add_tuple_shardings(),
1105                                /*lbrace_pre_lexed=*/false)) {
1106         return false;
1107       }
1108     } while (EatIfPresent(TokKind::kComma));
1109   }
1110   sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE);
1111 
1112   return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
1113 }
1114 
1115 //  ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
1116 //          ('devices=' ('[' dims ']')* device_list)? '}'
1117 // dims ::= int_list device_list ::= int_list
ParseSingleSharding(OpSharding * sharding,bool lbrace_pre_lexed)1118 bool HloParser::ParseSingleSharding(OpSharding* sharding,
1119                                     bool lbrace_pre_lexed) {
1120   if (!lbrace_pre_lexed &&
1121       !ParseToken(TokKind::kLbrace,
1122                   "expected '{' to start sharding attribute")) {
1123     return false;
1124   }
1125 
1126   LocTy loc = lexer_.GetLoc();
1127   bool maximal = false;
1128   bool replicated = false;
1129   std::vector<int64> devices;
1130   std::vector<int64> tile_assignment_dimensions;
1131   Shape tile_shape;
1132   while (lexer_.GetKind() != TokKind::kRbrace) {
1133     switch (lexer_.GetKind()) {
1134       case TokKind::kw_maximal:
1135         maximal = true;
1136         lexer_.Lex();
1137         break;
1138       case TokKind::kw_replicated:
1139         replicated = true;
1140         lexer_.Lex();
1141         break;
1142       case TokKind::kAttributeName: {
1143         if (lexer_.GetStrVal() == "device") {
1144           if (lexer_.Lex() != TokKind::kInt) {
1145             return TokenError("device= attribute must be an integer");
1146           }
1147           devices = {lexer_.GetInt64Val()};
1148           lexer_.Lex();
1149         } else if (lexer_.GetStrVal() == "devices") {
1150           lexer_.Lex();
1151           if (!ParseToken(TokKind::kLsquare,
1152                           "expected '[' to start sharding devices shape")) {
1153             return false;
1154           }
1155 
1156           do {
1157             int64 dim;
1158             if (!ParseInt64(&dim)) {
1159               return false;
1160             }
1161             tile_assignment_dimensions.push_back(dim);
1162           } while (EatIfPresent(TokKind::kComma));
1163 
1164           if (!ParseToken(TokKind::kRsquare,
1165                           "expected ']' to start sharding devices shape")) {
1166             return false;
1167           }
1168           do {
1169             int64 device;
1170             if (!ParseInt64(&device)) {
1171               return false;
1172             }
1173             devices.push_back(device);
1174           } while (EatIfPresent(TokKind::kComma));
1175         } else {
1176           return TokenError(
1177               "unknown attribute in sharding: expected device= or devices=");
1178         }
1179         break;
1180       }
1181       case TokKind::kShape:
1182         tile_shape = lexer_.GetShapeVal();
1183         lexer_.Lex();
1184         break;
1185       case TokKind::kRbrace:
1186         break;
1187       default:
1188         return TokenError("unexpected token");
1189     }
1190   }
1191 
1192   if (replicated) {
1193     if (!devices.empty()) {
1194       return Error(loc,
1195                    "replicated shardings should not have any devices assigned");
1196     }
1197     if (!ShapeUtil::Equal(tile_shape, Shape())) {
1198       return Error(loc,
1199                    "replicated shardings should not have any tile shape set");
1200     }
1201     sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
1202   } else if (maximal) {
1203     if (devices.size() != 1) {
1204       return Error(loc,
1205                    "maximal shardings should have exactly one device assigned");
1206     }
1207     if (!ShapeUtil::Equal(tile_shape, Shape())) {
1208       return Error(loc, "maximal shardings should not have any tile shape set");
1209     }
1210     sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
1211     sharding->add_tile_assignment_devices(devices[0]);
1212   } else {
1213     if (devices.size() <= 1) {
1214       return Error(
1215           loc, "non-maximal shardings must have more than one device assigned");
1216     }
1217     if (ShapeUtil::Equal(tile_shape, Shape())) {
1218       return Error(loc, "non-maximal shardings should have a tile shape set");
1219     }
1220     if (tile_assignment_dimensions.empty()) {
1221       return Error(
1222           loc,
1223           "non-maximal shardings must have a tile assignment list including "
1224           "dimensions");
1225     }
1226     sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
1227     *sharding->mutable_tile_shape() = tile_shape;
1228     for (int64 dim : tile_assignment_dimensions) {
1229       sharding->add_tile_assignment_dimensions(dim);
1230     }
1231     for (int64 device : devices) {
1232       sharding->add_tile_assignment_devices(device);
1233     }
1234   }
1235 
1236   lexer_.Lex();
1237   return true;
1238 }
1239 
1240 // '{' name+ '}'
ParseInstructionNames(std::vector<HloInstruction * > * instructions)1241 bool HloParser::ParseInstructionNames(
1242     std::vector<HloInstruction*>* instructions) {
1243   if (!ParseToken(TokKind::kLbrace,
1244                   "expects '{' at the beginning of instruction name list")) {
1245     return false;
1246   }
1247   LocTy loc = lexer_.GetLoc();
1248   do {
1249     string name;
1250     if (!ParseName(&name)) {
1251       return Error(loc, "expects a instruction name");
1252     }
1253     std::pair<HloInstruction*, LocTy>* instr =
1254         tensorflow::gtl::FindOrNull(instruction_pool_, name);
1255     if (!instr) {
1256       return TokenError(
1257           Printf("instruction '%s' is not defined", name.c_str()));
1258     }
1259     instructions->push_back(instr->first);
1260   } while (EatIfPresent(TokKind::kComma));
1261 
1262   return ParseToken(TokKind::kRbrace,
1263                     "expects '}' at the end of instruction name list");
1264 }
1265 
SetValueInLiteral(int64 value,int64 linear_index,Literal * literal)1266 bool HloParser::SetValueInLiteral(int64 value, int64 linear_index,
1267                                   Literal* literal) {
1268   const Shape& shape = literal->shape();
1269   switch (shape.element_type()) {
1270     case S8:
1271       return SetValueInLiteralHelper<int8>(value, linear_index, literal);
1272     case S16:
1273       return SetValueInLiteralHelper<int16>(value, linear_index, literal);
1274     case S32:
1275       return SetValueInLiteralHelper<int32>(value, linear_index, literal);
1276     case S64:
1277       return SetValueInLiteralHelper<int64>(value, linear_index, literal);
1278     case U8:
1279       return SetValueInLiteralHelper<uint8>(value, linear_index, literal);
1280     case U16:
1281       return SetValueInLiteralHelper<uint8>(value, linear_index, literal);
1282     case U32:
1283       return SetValueInLiteralHelper<uint32>(value, linear_index, literal);
1284     case U64:
1285       return SetValueInLiteralHelper<uint64>(value, linear_index, literal);
1286     default:
1287       LOG(FATAL) << "unknown integral primitive type "
1288                  << PrimitiveType_Name(shape.element_type());
1289   }
1290 }
1291 
SetValueInLiteral(double value,int64 linear_index,Literal * literal)1292 bool HloParser::SetValueInLiteral(double value, int64 linear_index,
1293                                   Literal* literal) {
1294   const Shape& shape = literal->shape();
1295   switch (shape.element_type()) {
1296     case F16:
1297       return SetValueInLiteralHelper<half>(value, linear_index, literal);
1298     case BF16:
1299       return SetValueInLiteralHelper<bfloat16>(value, linear_index, literal);
1300     case F32:
1301       return SetValueInLiteralHelper<float>(value, linear_index, literal);
1302     case F64:
1303       return SetValueInLiteralHelper<double>(value, linear_index, literal);
1304     default:
1305       LOG(FATAL) << "unknown floating point primitive type "
1306                  << PrimitiveType_Name(shape.element_type());
1307   }
1308 }
1309 
SetValueInLiteral(bool value,int64 linear_index,Literal * literal)1310 bool HloParser::SetValueInLiteral(bool value, int64 linear_index,
1311                                   Literal* literal) {
1312   const Shape& shape = literal->shape();
1313   switch (shape.element_type()) {
1314     case PRED:
1315       return SetValueInLiteralHelper<bool>(value, linear_index, literal);
1316     default:
1317       LOG(FATAL) << PrimitiveType_Name(shape.element_type())
1318                  << " is not PRED type";
1319   }
1320 }
1321 
1322 template <typename LiteralNativeT, typename ParsedElemT>
SetValueInLiteralHelper(ParsedElemT value,int64 linear_index,Literal * literal)1323 bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index,
1324                                         Literal* literal) {
1325   // Check that linear_index is in range.
1326   if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) {
1327     return TokenError(
1328         StrCat("trys to set value ", value, " to a literal in shape ",
1329                ShapeUtil::HumanString(literal->shape()), " at linear index ",
1330                linear_index, ", but the index is out of range"));
1331   }
1332 
1333   if (std::isnan(value) ||
1334       (std::numeric_limits<ParsedElemT>::has_infinity &&
1335        (std::numeric_limits<ParsedElemT>::infinity() == value ||
1336         -std::numeric_limits<ParsedElemT>::infinity() == value))) {
1337     // Skip range checking for non-finite value.
1338   } else if (literal->shape().element_type() == F16 ||
1339              literal->shape().element_type() == BF16) {
1340     if (value > kF16max || value < -kF16max) {
1341       return TokenError(StrCat(
1342           "value ", value, " is out of range for literal's primitive type ",
1343           PrimitiveType_Name(literal->shape().element_type())));
1344     }
1345   } else if (value > static_cast<ParsedElemT>(
1346                          std::numeric_limits<LiteralNativeT>::max()) ||
1347              value < static_cast<ParsedElemT>(
1348                          std::numeric_limits<LiteralNativeT>::lowest())) {
1349     // Value is out of range for LiteralNativeT.
1350     return TokenError(StrCat(
1351         "value ", value, " is out of range for literal's primitive type ",
1352         PrimitiveType_Name(literal->shape().element_type())));
1353   }
1354 
1355   literal->data<LiteralNativeT>().at(linear_index) =
1356       static_cast<LiteralNativeT>(value);
1357   return true;
1358 }
1359 
EatShapeAndCheckCompatible(const Shape & shape)1360 bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) {
1361   Shape new_shape;
1362   if (!ParseShape(&new_shape)) {
1363     return TokenError(StrCat("expects shape ", ShapeUtil::HumanString(shape)));
1364   }
1365   if (!ShapeUtil::Compatible(shape, new_shape)) {
1366     return TokenError(StrCat(
1367         "expects shape ", ShapeUtil::HumanString(shape),
1368         ", but sees a different shape: ", ShapeUtil::HumanString(new_shape)));
1369   }
1370   return true;
1371 }
1372 
1373 // literal
1374 //  ::= tuple
1375 //  ::= non_tuple
ParseLiteral(std::unique_ptr<Literal> * literal,const Shape & shape)1376 bool HloParser::ParseLiteral(std::unique_ptr<Literal>* literal,
1377                              const Shape& shape) {
1378   return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape)
1379                                    : ParseNonTupleLiteral(literal, shape);
1380 }
1381 
1382 // tuple
1383 //  ::= shape '(' literal_list ')'
1384 // literal_list
1385 //  ::= /*empty*/
1386 //  ::= literal (',' literal)*
ParseTupleLiteral(std::unique_ptr<Literal> * literal,const Shape & shape)1387 bool HloParser::ParseTupleLiteral(std::unique_ptr<Literal>* literal,
1388                                   const Shape& shape) {
1389   if (!EatShapeAndCheckCompatible(shape)) {
1390     return TokenError(StrCat("expects tuple constant in shape ",
1391                              ShapeUtil::HumanString(shape)));
1392   }
1393   if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) {
1394     return false;
1395   }
1396   std::vector<std::unique_ptr<Literal>> elements(
1397       ShapeUtil::TupleElementCount(shape));
1398 
1399   if (lexer_.GetKind() == TokKind::kRparen) {
1400     // empty
1401   } else {
1402     // literal, (',' literal)*
1403     for (int i = 0; i < elements.size(); i++) {
1404       if (i > 0) {
1405         ParseToken(TokKind::kComma, "exepcts ',' to separate tuple elements");
1406       }
1407       if (!ParseLiteral(&elements[i],
1408                         ShapeUtil::GetTupleElementShape(shape, i))) {
1409         return TokenError(StrCat("expects the ", i, "th element"));
1410       }
1411     }
1412   }
1413   *literal = Literal::MakeTupleOwned(std::move(elements));
1414   return ParseToken(TokKind::kRparen,
1415                     StrCat("expects ')' at the end of the tuple with ",
1416                            ShapeUtil::TupleElementCount(shape), "elements"));
1417 }
1418 
1419 // non_tuple
1420 //   ::= rank01
1421 //   ::= rank2345
1422 // rank2345 ::= shape sparse_or_nested_array
ParseNonTupleLiteral(std::unique_ptr<Literal> * literal,const Shape & shape)1423 bool HloParser::ParseNonTupleLiteral(std::unique_ptr<Literal>* literal,
1424                                      const Shape& shape) {
1425   if (LayoutUtil::IsSparseArray(shape)) {
1426     return ParseSparseLiteral(literal, shape);
1427   }
1428 
1429   CHECK(LayoutUtil::IsDenseArray(shape));
1430   return ParseDenseLiteral(literal, shape);
1431 }
1432 
ParseDenseLiteral(std::unique_ptr<Literal> * literal,const Shape & shape)1433 bool HloParser::ParseDenseLiteral(std::unique_ptr<Literal>* literal,
1434                                   const Shape& shape) {
1435   const int64 rank = ShapeUtil::Rank(shape);
1436   if (rank > 1 && !EatShapeAndCheckCompatible(shape)) {
1437     return false;
1438   }
1439 
1440   // Create a literal with the given shape in default layout.
1441   *literal = Literal::CreateFromDimensions(shape.element_type(),
1442                                            AsInt64Slice(shape.dimensions()));
1443   int64 nest_level = 0;
1444   int64 linear_index = 0;
1445   // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for
1446   // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}},
1447   // when we are parsing the 2nd '{' (right before '1'), we are seeing a
1448   // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at
1449   // the first '}' (right after '3'), it means the sub-array ends, and the
1450   // sub-array is supposed to contain exactly 3 elements, so check if
1451   // elems_seen_per_dim[1] is 3.
1452   std::vector<int64> elems_seen_per_dim(rank);
1453   auto get_index_str = [&elems_seen_per_dim](int dim) -> string {
1454     std::vector<int64> elems_seen_until_dim(elems_seen_per_dim.begin(),
1455                                             elems_seen_per_dim.begin() + dim);
1456     return StrCat("[",
1457                   tensorflow::str_util::Join(
1458                       elems_seen_until_dim, ",",
1459                       [](string* out, const int64& num_elems) {
1460                         tensorflow::strings::StrAppend(out, num_elems - 1);
1461                       }),
1462                   "]");
1463   };
1464   do {
1465     switch (lexer_.GetKind()) {
1466       default:
1467         return TokenError("unexpected token type in a literal");
1468       case TokKind::kLbrace: {
1469         nest_level++;
1470         if (nest_level > rank) {
1471           return TokenError(Printf(
1472               "expects nested array in rank %lld, but sees larger", rank));
1473         }
1474         if (nest_level > 1) {
1475           elems_seen_per_dim[nest_level - 2]++;
1476           if (elems_seen_per_dim[nest_level - 2] >
1477               shape.dimensions(nest_level - 2)) {
1478             return TokenError(Printf(
1479                 "expects %lld elements in the %sth element, but sees more",
1480                 shape.dimensions(nest_level - 2),
1481                 get_index_str(nest_level - 2).c_str()));
1482           }
1483         }
1484         lexer_.Lex();
1485         break;
1486       }
1487       case TokKind::kRbrace: {
1488         nest_level--;
1489         if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) {
1490           return TokenError(Printf(
1491               "expects %lld elements in the %sth element, but sees %lld",
1492               shape.dimensions(nest_level), get_index_str(nest_level).c_str(),
1493               elems_seen_per_dim[nest_level]));
1494         }
1495         elems_seen_per_dim[nest_level] = 0;
1496         lexer_.Lex();
1497         break;
1498       }
1499       case TokKind::kComma:
1500       case TokKind::kComment:
1501         // Skip.
1502         lexer_.Lex();
1503         break;
1504       case TokKind::kw_true:
1505       case TokKind::kw_false:
1506       case TokKind::kInt:
1507       case TokKind::kDecimal:
1508       case TokKind::kw_nan:
1509       case TokKind::kw_inf:
1510       case TokKind::kNegInf: {
1511         if (rank > 0) {
1512           if (nest_level != rank) {
1513             return TokenError(
1514                 Printf("expects nested array in rank %lld, but sees %lld", rank,
1515                        nest_level));
1516           }
1517           elems_seen_per_dim[rank - 1]++;
1518           if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) {
1519             return TokenError(
1520                 Printf("expects %lld elements on the minor-most dimension, but "
1521                        "sees more",
1522                        shape.dimensions(rank - 1)));
1523           }
1524         }
1525         if (lexer_.GetKind() == TokKind::kw_true ||
1526             lexer_.GetKind() == TokKind::kw_false) {
1527           // TODO(congliu): bool type literals with rank >= 1 are actually
1528           // printed in a compact form instead of "true" or "false". Fix that.
1529           if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true,
1530                                  linear_index++, literal->get())) {
1531             return false;
1532           }
1533           lexer_.Lex();
1534         } else if (primitive_util::IsIntegralType(shape.element_type())) {
1535           LocTy loc = lexer_.GetLoc();
1536           int64 value;
1537           if (!ParseInt64(&value)) {
1538             return Error(loc, StrCat("expects integer for primitive type: ",
1539                                      PrimitiveType_Name(shape.element_type())));
1540           }
1541           if (!SetValueInLiteral(value, linear_index++, literal->get())) {
1542             return false;
1543           }
1544         } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
1545           LocTy loc = lexer_.GetLoc();
1546           double value;
1547           if (!ParseDouble(&value)) {
1548             return Error(
1549                 loc, StrCat("expect floating point value for primitive type: ",
1550                             PrimitiveType_Name(shape.element_type())));
1551           }
1552           if (!SetValueInLiteral(value, linear_index++, literal->get())) {
1553             return false;
1554           }
1555         } else {
1556           return TokenError(StrCat("unsupported primitive type ",
1557                                    PrimitiveType_Name(shape.element_type())));
1558         }
1559         break;
1560       }
1561     }  // end of switch
1562   } while (nest_level > 0);
1563 
1564   *literal = (*literal)->Relayout(shape.layout());
1565   return true;
1566 }
1567 
ParseSparseLiteral(std::unique_ptr<Literal> * literal,const Shape & shape)1568 bool HloParser::ParseSparseLiteral(std::unique_ptr<Literal>* literal,
1569                                    const Shape& shape) {
1570   if (!EatShapeAndCheckCompatible(shape)) {
1571     return false;
1572   }
1573 
1574   switch (shape.element_type()) {
1575     case PRED:
1576       return ParseSparseLiteralHelper<uint8>(literal, shape);
1577     case S8:
1578       return ParseSparseLiteralHelper<int8>(literal, shape);
1579     case S16:
1580       return ParseSparseLiteralHelper<int16>(literal, shape);
1581     case S32:
1582       return ParseSparseLiteralHelper<int32>(literal, shape);
1583     case S64:
1584       return ParseSparseLiteralHelper<int64>(literal, shape);
1585     case U8:
1586       return ParseSparseLiteralHelper<uint8>(literal, shape);
1587     case U16:
1588       return ParseSparseLiteralHelper<uint16>(literal, shape);
1589     case U32:
1590       return ParseSparseLiteralHelper<uint32>(literal, shape);
1591     case U64:
1592       return ParseSparseLiteralHelper<uint64>(literal, shape);
1593     case F16:
1594       return ParseSparseLiteralHelper<half>(literal, shape);
1595     case F32:
1596       return ParseSparseLiteralHelper<float>(literal, shape);
1597     case BF16:
1598       return ParseSparseLiteralHelper<bfloat16>(literal, shape);
1599     case F64:
1600       return ParseSparseLiteralHelper<double>(literal, shape);
1601     default:
1602       return Error(lexer_.GetLoc(),
1603                    StrCat("invalid primitive type for sparse literal: ",
1604                           PrimitiveType_Name(shape.element_type())));
1605   }
1606 }
1607 
1608 template <typename LiteralNativeT>
ParseSparseLiteralHelper(std::unique_ptr<Literal> * literal,const Shape & shape)1609 bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
1610                                          const Shape& shape) {
1611   std::vector<int64> index;
1612 
1613   int64 rank = ShapeUtil::Rank(shape);
1614 
1615   *literal = MakeUnique<Literal>(shape);
1616 
1617   if (!ParseToken(TokKind::kLbrace,
1618                   "expects '{' at the beginning of a sparse literal")) {
1619     return false;
1620   }
1621 
1622   for (;;) {
1623     if (lexer_.GetKind() == TokKind::kRbrace) {
1624       lexer_.Lex();
1625       break;
1626     }
1627 
1628     LocTy index_loc = lexer_.GetLoc();
1629     index.clear();
1630     if (lexer_.GetKind() == TokKind::kInt) {
1631       int64 single_index = lexer_.GetInt64Val();
1632       lexer_.Lex();
1633       if (rank != 1) {
1634         return Error(
1635             index_loc,
1636             StrCat("invalid single-dimensional index for shape with rank ",
1637                    rank, ": ", single_index));
1638       }
1639       index.push_back(single_index);
1640     } else {
1641       if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma,
1642                           &index)) {
1643         return false;
1644       }
1645       if (index.size() != rank) {
1646         return Error(
1647             index_loc,
1648             StrCat("invalid multi-dimension index for shape with rank ", rank,
1649                    ": [", tensorflow::str_util::Join(index, ", "), "]"));
1650       }
1651     }
1652     if (!ParseToken(TokKind::kColon,
1653                     "expects ':' after after the sparse array index and before "
1654                     "the sparse array value")) {
1655       return false;
1656     }
1657     LocTy value_loc = lexer_.GetLoc();
1658     LiteralNativeT value;
1659     if (lexer_.GetKind() == TokKind::kw_true ||
1660         lexer_.GetKind() == TokKind::kw_false) {
1661       value = static_cast<LiteralNativeT>(lexer_.GetKind() == TokKind::kw_true);
1662       lexer_.Lex();
1663     } else if (primitive_util::IsIntegralType(shape.element_type())) {
1664       int64 value_s64;
1665       if (!ParseInt64(&value_s64)) {
1666         return Error(value_loc,
1667                      StrCat("expects integer for primitive type: ",
1668                             PrimitiveType_Name(shape.element_type())));
1669       }
1670       value = static_cast<LiteralNativeT>(value_s64);
1671     } else if (primitive_util::IsFloatingPointType(shape.element_type())) {
1672       double value_f64;
1673       if (!ParseDouble(&value_f64)) {
1674         return Error(value_loc,
1675                      StrCat("expects floating point value for primitive type: ",
1676                             PrimitiveType_Name(shape.element_type())));
1677       }
1678       value = static_cast<LiteralNativeT>(value_f64);
1679     } else {
1680       LOG(FATAL) << "Unexpected element type: "
1681                  << PrimitiveType_Name(shape.element_type());
1682     }
1683     if (lexer_.GetKind() != TokKind::kRbrace &&
1684         !ParseToken(TokKind::kComma,
1685                     "expects ',' separator between sparse array elements")) {
1686       return false;
1687     }
1688 
1689     if ((*literal)->sparse_element_count() + 1 ==
1690         LayoutUtil::MaxSparseElements(shape.layout())) {
1691       return Error(
1692           lexer_.GetLoc(),
1693           StrCat("number of sparse elements exceeds maximum for layout: ",
1694                  ShapeUtil::HumanStringWithLayout(shape)));
1695     }
1696 
1697     (*literal)->AppendSparseElement(index, value);
1698   }
1699 
1700   (*literal)->SortSparseElements();
1701   return true;
1702 }
1703 
1704 // operands ::= '(' operands1 ')'
1705 // operands1
1706 //   ::= /*empty*/
1707 //   ::= operand (, operand)*
1708 // operand ::= (shape)? name
ParseOperands(std::vector<HloInstruction * > * operands)1709 bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
1710   if (!ParseToken(TokKind::kLparen,
1711                   "expects '(' at the beginning of operands")) {
1712     return false;
1713   }
1714   if (lexer_.GetKind() == TokKind::kRparen) {
1715     // empty
1716   } else {
1717     do {
1718       LocTy loc = lexer_.GetLoc();
1719       string name;
1720       if (CanBeShape()) {
1721         Shape shape;
1722         if (!ParseShape(&shape)) {
1723           return false;
1724         }
1725       }
1726       if (!ParseName(&name)) {
1727         return false;
1728       }
1729       std::pair<HloInstruction*, LocTy>* instruction =
1730           tensorflow::gtl::FindOrNull(instruction_pool_, name);
1731       if (!instruction) {
1732         return Error(loc, StrCat("instruction does not exist: ", name));
1733       }
1734       operands->push_back(instruction->first);
1735     } while (EatIfPresent(TokKind::kComma));
1736   }
1737   return ParseToken(TokKind::kRparen, "expects ')' at the end of operands");
1738 }
1739 
ParseOperands(std::vector<HloInstruction * > * operands,const int expected_size)1740 bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
1741                               const int expected_size) {
1742   LocTy loc = lexer_.GetLoc();
1743   if (!ParseOperands(operands)) {
1744     return false;
1745   }
1746   if (expected_size != operands->size()) {
1747     return Error(loc, StrCat("expects ", expected_size, " operands, but has ",
1748                              operands->size(), " operands"));
1749   }
1750   return true;
1751 }
1752 
1753 // sub_attributes ::= '{' (','? attribute)* '}'
ParseSubAttributes(const std::unordered_map<string,AttrConfig> & attrs)1754 bool HloParser::ParseSubAttributes(
1755     const std::unordered_map<string, AttrConfig>& attrs) {
1756   LocTy loc = lexer_.GetLoc();
1757   if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) {
1758     return false;
1759   }
1760   std::unordered_set<string> seen_attrs;
1761   if (lexer_.GetKind() == TokKind::kRbrace) {
1762     // empty
1763   } else {
1764     do {
1765       EatIfPresent(TokKind::kComma);
1766       if (!ParseAttributeHelper(attrs, &seen_attrs)) {
1767         return false;
1768       }
1769     } while (lexer_.GetKind() != TokKind::kRbrace);
1770   }
1771   // Check that all required attrs were seen.
1772   for (const auto& attr_it : attrs) {
1773     if (attr_it.second.required &&
1774         seen_attrs.find(attr_it.first) == seen_attrs.end()) {
1775       return Error(loc, Printf("sub-attribute %s is expected but not seen",
1776                                attr_it.first.c_str()));
1777     }
1778   }
1779   return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes");
1780 }
1781 
1782 // attributes ::= (',' attribute)*
ParseAttributes(const std::unordered_map<string,AttrConfig> & attrs)1783 bool HloParser::ParseAttributes(
1784     const std::unordered_map<string, AttrConfig>& attrs) {
1785   LocTy loc = lexer_.GetLoc();
1786   std::unordered_set<string> seen_attrs;
1787   while (EatIfPresent(TokKind::kComma)) {
1788     if (!ParseAttributeHelper(attrs, &seen_attrs)) {
1789       return false;
1790     }
1791   }
1792   // Check that all required attrs were seen.
1793   for (const auto& attr_it : attrs) {
1794     if (attr_it.second.required &&
1795         seen_attrs.find(attr_it.first) == seen_attrs.end()) {
1796       return Error(loc, Printf("attribute %s is expected but not seen",
1797                                attr_it.first.c_str()));
1798     }
1799   }
1800   return true;
1801 }
1802 
ParseAttributeHelper(const std::unordered_map<string,AttrConfig> & attrs,std::unordered_set<string> * seen_attrs)1803 bool HloParser::ParseAttributeHelper(
1804     const std::unordered_map<string, AttrConfig>& attrs,
1805     std::unordered_set<string>* seen_attrs) {
1806   LocTy loc = lexer_.GetLoc();
1807   string name;
1808   if (!ParseAttributeName(&name)) {
1809     return Error(loc, "error parsing attributes");
1810   }
1811   VLOG(1) << "Parsing attribute " << name;
1812   if (!seen_attrs->insert(name).second) {
1813     return Error(loc, Printf("attribute %s already exists", name.c_str()));
1814   }
1815   auto attr_it = attrs.find(name);
1816   if (attr_it == attrs.end()) {
1817     return Error(loc, Printf("unexpected attribute %s", name.c_str()));
1818   }
1819   AttrTy attr_type = attr_it->second.attr_type;
1820   void* attr_out_ptr = attr_it->second.result;
1821   bool success = [&] {
1822     LocTy attr_loc = lexer_.GetLoc();
1823     switch (attr_type) {
1824       case AttrTy::kInt64: {
1825         int64 result;
1826         if (!ParseInt64(&result)) {
1827           return false;
1828         }
1829         static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
1830         return true;
1831       }
1832       case AttrTy::kInt32: {
1833         int64 result;
1834         if (!ParseInt64(&result)) {
1835           return false;
1836         }
1837         if (result != static_cast<int32>(result)) {
1838           return Error(attr_loc, "value out of range for int32");
1839         }
1840         static_cast<optional<int32>*>(attr_out_ptr)
1841             ->emplace(static_cast<int32>(result));
1842         return true;
1843       }
1844       case AttrTy::kFloat: {
1845         double result;
1846         if (!ParseDouble(&result)) {
1847           return false;
1848         }
1849         if (result > std::numeric_limits<float>::max() ||
1850             result < std::numeric_limits<float>::lowest()) {
1851           return Error(attr_loc, "value out of range for float");
1852         }
1853         static_cast<optional<float>*>(attr_out_ptr)
1854             ->emplace(static_cast<float>(result));
1855         return true;
1856       }
1857       case AttrTy::kHloComputation: {
1858         HloComputation* result;
1859         if (!ParseComputationName(&result)) {
1860           return false;
1861         }
1862         static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
1863         return true;
1864       }
1865       case AttrTy::kFftType: {
1866         FftType result;
1867         if (!ParseFftType(&result)) {
1868           return false;
1869         }
1870         static_cast<optional<FftType>*>(attr_out_ptr)->emplace(result);
1871         return true;
1872       }
1873       case AttrTy::kWindow: {
1874         Window result;
1875         if (!ParseWindow(&result)) {
1876           return false;
1877         }
1878         static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
1879         return true;
1880       }
1881       case AttrTy::kConvolutionDimensionNumbers: {
1882         ConvolutionDimensionNumbers result;
1883         if (!ParseConvolutionDimensionNumbers(&result)) {
1884           return false;
1885         }
1886         static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
1887             ->emplace(result);
1888         return true;
1889       }
1890       case AttrTy::kSharding: {
1891         OpSharding sharding;
1892         if (!ParseSharding(&sharding)) {
1893           return false;
1894         }
1895         static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
1896         return true;
1897       }
1898       case AttrTy::kInstructionList: {
1899         std::vector<HloInstruction*> result;
1900         if (!ParseInstructionNames(&result)) {
1901           return false;
1902         }
1903         static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
1904             ->emplace(result);
1905         return true;
1906       }
1907       case AttrTy::kFusionKind: {
1908         HloInstruction::FusionKind result;
1909         if (!ParseFusionKind(&result)) {
1910           return false;
1911         }
1912         static_cast<optional<HloInstruction::FusionKind>*>(attr_out_ptr)
1913             ->emplace(result);
1914         return true;
1915       }
1916       case AttrTy::kBracedInt64List: {
1917         std::vector<int64> result;
1918         if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
1919                             &result)) {
1920           return false;
1921         }
1922         static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
1923             ->emplace(result);
1924         return true;
1925       }
1926       case AttrTy::kSliceRanges: {
1927         SliceRanges result;
1928         if (!ParseSliceRanges(&result)) {
1929           return false;
1930         }
1931         static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
1932         return true;
1933       }
1934       case AttrTy::kPaddingConfig: {
1935         PaddingConfig result;
1936         if (!ParsePaddingConfig(&result)) {
1937           return false;
1938         }
1939         static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
1940         return true;
1941       }
1942       case AttrTy::kString: {
1943         string result;
1944         if (!ParseString(&result)) {
1945           return false;
1946         }
1947         static_cast<optional<string>*>(attr_out_ptr)->emplace(result);
1948         return true;
1949       }
1950       case AttrTy::kMetadata: {
1951         OpMetadata result;
1952         if (!ParseMetadata(&result)) {
1953           return false;
1954         }
1955         static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result);
1956         return true;
1957       }
1958       case AttrTy::kDistribution: {
1959         RandomDistribution result;
1960         if (!ParseRandomDistribution(&result)) {
1961           return false;
1962         }
1963         static_cast<optional<RandomDistribution>*>(attr_out_ptr)
1964             ->emplace(result);
1965         return true;
1966       }
1967     }
1968   }();
1969   if (!success) {
1970     return Error(loc, Printf("error parsing attribute %s", name.c_str()));
1971   }
1972   return true;
1973 }
1974 
ParseComputationName(HloComputation ** value)1975 bool HloParser::ParseComputationName(HloComputation** value) {
1976   string name;
1977   LocTy loc = lexer_.GetLoc();
1978   if (!ParseName(&name)) {
1979     return Error(loc, "expects computation name");
1980   }
1981   std::pair<HloComputation*, LocTy>* computation =
1982       tensorflow::gtl::FindOrNull(computation_pool_, name);
1983   if (computation == nullptr) {
1984     return Error(loc, StrCat("computation does not exist: ", name));
1985   }
1986   *value = computation->first;
1987   return true;
1988 }
1989 
1990 // ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
1991 // The subattributes can appear in any order. 'size=' is required, others are
1992 // optional.
ParseWindow(Window * window)1993 bool HloParser::ParseWindow(Window* window) {
1994   LocTy loc = lexer_.GetLoc();
1995   if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
1996     return false;
1997   }
1998 
1999   std::vector<int64> size;
2000   std::vector<int64> stride;
2001   std::vector<std::vector<int64>> pad;
2002   std::vector<int64> lhs_dilate;
2003   std::vector<int64> rhs_dilate;
2004   std::vector<int64> rhs_reversal;
2005   while (lexer_.GetKind() != TokKind::kRbrace) {
2006     LocTy attr_loc = lexer_.GetLoc();
2007     string field_name;
2008     if (!ParseAttributeName(&field_name)) {
2009       return Error(attr_loc, "expects sub-attributes in window");
2010     }
2011     bool ok = [&] {
2012       if (field_name == "size") {
2013         return ParseDxD("size", &size);
2014       }
2015       if (field_name == "stride") {
2016         return ParseDxD("stride", &stride);
2017       }
2018       if (field_name == "lhs_dilate") {
2019         return ParseDxD("lhs_dilate", &lhs_dilate);
2020       }
2021       if (field_name == "rhs_dilate") {
2022         return ParseDxD("rls_dilate", &rhs_dilate);
2023       }
2024       if (field_name == "pad") {
2025         return ParseWindowPad(&pad);
2026       }
2027       if (field_name == "rhs_reversal") {
2028         return ParseDxD("rhs_reversal", &rhs_reversal);
2029       }
2030       return Error(attr_loc, StrCat("unexpected attribute name: ", field_name));
2031     }();
2032     if (!ok) {
2033       return false;
2034     }
2035   }
2036 
2037   if (size.empty()) {
2038     return Error(loc,
2039                  "sub-attribute 'size=' is required in the window attribute");
2040   }
2041   if (!stride.empty() && stride.size() != size.size()) {
2042     return Error(loc, "expects 'stride=' has the same size as 'size='");
2043   }
2044   if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
2045     return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='");
2046   }
2047   if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
2048     return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='");
2049   }
2050   if (!pad.empty() && pad.size() != size.size()) {
2051     return Error(loc, "expects 'pad=' has the same size as 'size='");
2052   }
2053 
2054   for (int i = 0; i < size.size(); i++) {
2055     window->add_dimensions()->set_size(size[i]);
2056     if (!pad.empty()) {
2057       window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
2058       window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
2059     }
2060     // If some field is not present, it has the default value.
2061     window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
2062     window->mutable_dimensions(i)->set_base_dilation(
2063         lhs_dilate.empty() ? 1 : lhs_dilate[i]);
2064     window->mutable_dimensions(i)->set_window_dilation(
2065         rhs_dilate.empty() ? 1 : rhs_dilate[i]);
2066     window->mutable_dimensions(i)->set_window_reversal(
2067         rhs_reversal.empty() ? false : (rhs_reversal[i] == 1));
2068   }
2069   return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
2070 }
2071 
2072 // This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
2073 // The string looks like "dim_labels=0bf_0io->0bf".
ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers * dnums)2074 bool HloParser::ParseConvolutionDimensionNumbers(
2075     ConvolutionDimensionNumbers* dnums) {
2076   if (lexer_.GetKind() != TokKind::kDimLabels) {
2077     return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
2078   }
2079   string str = lexer_.GetStrVal();
2080 
2081   // The str is expected to have 3 items, lhs, rhs, out, and it must looks like
2082   // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
2083   // So we replace the "->" with "_" and then split on "_".
2084   str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->",
2085                                             /*newsub=*/"_",
2086                                             /*replace_all=*/false);
2087   std::vector<string> lhs_rhs_out = Split(str, "_");
2088   if (lhs_rhs_out.size() != 3) {
2089     LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
2090                << str;
2091   }
2092 
2093   const int64 rank = lhs_rhs_out[0].length();
2094   if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
2095     return TokenError(
2096         "convolution lhs, rhs, and output must have the same rank");
2097   }
2098   if (rank < 2) {
2099     return TokenError("convolution rank must >=2");
2100   }
2101 
2102   auto is_unique = [](string str) -> bool {
2103     std::sort(str.begin(), str.end());
2104     return std::unique(str.begin(), str.end()) == str.end();
2105   };
2106 
2107   // lhs
2108   {
2109     const string& lhs = lhs_rhs_out[0];
2110     if (!is_unique(lhs)) {
2111       return TokenError(
2112           StrCat("expects unique lhs dimension numbers, but sees ", lhs));
2113     }
2114     for (int i = 0; i < rank - 2; i++) {
2115       dnums->add_input_spatial_dimensions(-1);
2116     }
2117     for (int i = 0; i < rank; i++) {
2118       char c = lhs[i];
2119       if (c == 'b') {
2120         dnums->set_input_batch_dimension(i);
2121       } else if (c == 'f') {
2122         dnums->set_input_feature_dimension(i);
2123       } else if (c < '0' + rank && c >= '0') {
2124         dnums->set_input_spatial_dimensions(c - '0', i);
2125       } else {
2126         return TokenError(
2127             Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1));
2128       }
2129     }
2130   }
2131   // rhs
2132   {
2133     const string& rhs = lhs_rhs_out[1];
2134     if (!is_unique(rhs)) {
2135       return TokenError(
2136           StrCat("expects unique rhs dimension numbers, but sees ", rhs));
2137     }
2138     for (int i = 0; i < rank - 2; i++) {
2139       dnums->add_kernel_spatial_dimensions(-1);
2140     }
2141     for (int i = 0; i < rank; i++) {
2142       char c = rhs[i];
2143       if (c == 'i') {
2144         dnums->set_kernel_input_feature_dimension(i);
2145       } else if (c == 'o') {
2146         dnums->set_kernel_output_feature_dimension(i);
2147       } else if (c < '0' + rank && c >= '0') {
2148         dnums->set_kernel_spatial_dimensions(c - '0', i);
2149       } else {
2150         return TokenError(
2151             Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1));
2152       }
2153     }
2154   }
2155   // output
2156   {
2157     const string& out = lhs_rhs_out[2];
2158     if (!is_unique(out)) {
2159       return TokenError(
2160           StrCat("expects unique output dimension numbers, but sees ", out));
2161     }
2162     for (int i = 0; i < rank - 2; i++) {
2163       dnums->add_output_spatial_dimensions(-1);
2164     }
2165     for (int i = 0; i < rank; i++) {
2166       char c = out[i];
2167       if (c == 'b') {
2168         dnums->set_output_batch_dimension(i);
2169       } else if (c == 'f') {
2170         dnums->set_output_feature_dimension(i);
2171       } else if (c < '0' + rank && c >= '0') {
2172         dnums->set_output_spatial_dimensions(c - '0', i);
2173       } else {
2174         return TokenError(
2175             Printf("expects [0-%lldbf] in output dimension numbers", rank - 1));
2176       }
2177     }
2178   }
2179 
2180   lexer_.Lex();
2181   return true;
2182 }
2183 
2184 // ::= '{' ranges '}'
2185 //   ::= /*empty*/
2186 //   ::= range (',' range)*
2187 // range ::= '[' start ':' limit (':' stride)? ']'
2188 //
2189 // The slice ranges are printed as:
2190 //
2191 //  {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
2192 //
2193 // This function extracts the starts, limits, and strides as 3 vectors to the
2194 // result. If stride is not present, stride is 1. For example, if the slice
2195 // ranges is printed as:
2196 //
2197 //  {[2:3:4], [5:6:7], [8:9]}
2198 //
2199 // The parsed result will be:
2200 //
2201 //  {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
2202 //
ParseSliceRanges(SliceRanges * result)2203 bool HloParser::ParseSliceRanges(SliceRanges* result) {
2204   if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
2205     return false;
2206   }
2207   std::vector<std::vector<int64>> ranges;
2208   if (lexer_.GetKind() == TokKind::kRbrace) {
2209     // empty
2210     return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
2211   }
2212   do {
2213     LocTy loc = lexer_.GetLoc();
2214     ranges.emplace_back();
2215     if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
2216                         &ranges.back())) {
2217       return false;
2218     }
2219     const auto& range = ranges.back();
2220     if (range.size() != 2 && range.size() != 3) {
2221       return Error(loc, Printf("expects [start:limit:step] or [start:limit], "
2222                                "but sees %ld elements.",
2223                                range.size()));
2224     }
2225   } while (EatIfPresent(TokKind::kComma));
2226 
2227   for (const auto& range : ranges) {
2228     result->starts.push_back(range[0]);
2229     result->limits.push_back(range[1]);
2230     result->strides.push_back(range.size() == 3 ? range[2] : 1);
2231   }
2232   return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
2233 }
2234 
2235 // int64list ::= start int64_elements end
2236 // int64_elements
2237 //   ::= /*empty*/
2238 //   ::= int64_val (delim int64_val)*
ParseInt64List(const TokKind start,const TokKind end,const TokKind delim,std::vector<int64> * result)2239 bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
2240                                const TokKind delim,
2241                                std::vector<int64>* result) {
2242   if (!ParseToken(start, StrCat("expects an int64 list starting with ",
2243                                 TokKindToString(start)))) {
2244     return false;
2245   }
2246   if (lexer_.GetKind() == end) {
2247     // empty
2248   } else {
2249     do {
2250       int64 i;
2251       if (!ParseInt64(&i)) {
2252         return false;
2253       }
2254       result->push_back(i);
2255     } while (EatIfPresent(delim));
2256   }
2257   return ParseToken(
2258       end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
2259 }
2260 
2261 // param_list_to_shape ::= param_list '->' shape
ParseParamListToShape(Shape * shape,LocTy * shape_loc)2262 bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
2263   if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {
2264     return false;
2265   }
2266   *shape_loc = lexer_.GetLoc();
2267   return ParseShape(shape);
2268 }
2269 
CanBeParamListToShape()2270 bool HloParser::CanBeParamListToShape() {
2271   return lexer_.GetKind() == TokKind::kLparen;
2272 }
2273 
2274 // param_list ::= '(' param_list1 ')'
2275 // param_list1
2276 //   ::= /*empty*/
2277 //   ::= param (',' param)*
2278 // param ::= name shape
ParseParamList()2279 bool HloParser::ParseParamList() {
2280   if (!ParseToken(TokKind::kLparen,
2281                   "expects '(' at the beginning of param list")) {
2282     return false;
2283   }
2284 
2285   if (lexer_.GetKind() == TokKind::kRparen) {
2286     // empty
2287   } else {
2288     do {
2289       Shape shape;
2290       string name;
2291       if (!ParseName(&name) || !ParseShape(&shape)) {
2292         return false;
2293       }
2294     } while (EatIfPresent(TokKind::kComma));
2295   }
2296   return ParseToken(TokKind::kRparen, "expects ')' at the end of param list");
2297 }
2298 
2299 // shape ::= shape_val_
2300 // shape ::= '(' tuple_elements ')'
2301 // tuple_elements
2302 //   ::= /*empty*/
2303 //   ::= shape (',' shape)*
ParseShape(Shape * result)2304 bool HloParser::ParseShape(Shape* result) {
2305   if (EatIfPresent(TokKind::kLparen)) {  // Tuple
2306     std::vector<Shape> shapes;
2307     if (lexer_.GetKind() == TokKind::kRparen) {
2308       /*empty*/
2309     } else {
2310       // shape (',' shape)*
2311       do {
2312         shapes.emplace_back();
2313         if (!ParseShape(&shapes.back())) {
2314           return false;
2315         }
2316       } while (EatIfPresent(TokKind::kComma));
2317     }
2318     *result = ShapeUtil::MakeTupleShape(shapes);
2319     return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple.");
2320   }
2321 
2322   if (lexer_.GetKind() != TokKind::kShape) {
2323     return TokenError("expects shape");
2324   }
2325   *result = lexer_.GetShapeVal();
2326   lexer_.Lex();
2327   return true;
2328 }
2329 
CanBeShape()2330 bool HloParser::CanBeShape() {
2331   // A non-tuple shape starts with a kShape token; a tuple shape starts with
2332   // '('.
2333   return lexer_.GetKind() == TokKind::kShape ||
2334          lexer_.GetKind() == TokKind::kLparen;
2335 }
2336 
ParseName(string * result)2337 bool HloParser::ParseName(string* result) {
2338   VLOG(1) << "ParseName";
2339   if (lexer_.GetKind() != TokKind::kIdent &&
2340       lexer_.GetKind() != TokKind::kName) {
2341     return TokenError("expects name");
2342   }
2343   *result = lexer_.GetStrVal();
2344   lexer_.Lex();
2345   return true;
2346 }
2347 
ParseAttributeName(string * result)2348 bool HloParser::ParseAttributeName(string* result) {
2349   if (lexer_.GetKind() != TokKind::kAttributeName) {
2350     return TokenError("expects attribute name");
2351   }
2352   *result = lexer_.GetStrVal();
2353   lexer_.Lex();
2354   return true;
2355 }
2356 
ParseString(string * result)2357 bool HloParser::ParseString(string* result) {
2358   VLOG(1) << "ParseString";
2359   if (lexer_.GetKind() != TokKind::kString) {
2360     return TokenError("expects string");
2361   }
2362   *result = lexer_.GetStrVal();
2363   lexer_.Lex();
2364   return true;
2365 }
2366 
ParseDxD(const string & name,std::vector<int64> * result)2367 bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
2368   LocTy loc = lexer_.GetLoc();
2369   if (!result->empty()) {
2370     return Error(loc,
2371                  Printf("sub-attribute '%s=' already exists", name.c_str()));
2372   }
2373   // 1D
2374   if (lexer_.GetKind() == TokKind::kInt) {
2375     int64 number;
2376     if (!ParseInt64(&number)) {
2377       return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str()));
2378     }
2379     result->push_back(number);
2380     return true;
2381   }
2382   // 2D or higher.
2383   if (lexer_.GetKind() == TokKind::kDxD) {
2384     string str = lexer_.GetStrVal();
2385     if (!SplitAndParseAsInts(str, 'x', result)) {
2386       return Error(loc,
2387                    Printf("expects sub-attribute '%s=ixj...'", name.c_str()));
2388     }
2389     lexer_.Lex();
2390     return true;
2391   }
2392   return TokenError("expects token type kInt or kDxD");
2393 }
2394 
ParseWindowPad(std::vector<std::vector<int64>> * pad)2395 bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
2396   LocTy loc = lexer_.GetLoc();
2397   if (!pad->empty()) {
2398     return Error(loc, "sub-attribute 'pad=' already exists");
2399   }
2400   if (lexer_.GetKind() != TokKind::kPad) {
2401     return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
2402   }
2403   string str = lexer_.GetStrVal();
2404   std::vector<string> padding_str = Split(str, 'x');
2405   for (int i = 0; i < padding_str.size(); i++) {
2406     std::vector<int64> low_high;
2407     if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
2408         low_high.size() != 2) {
2409       return Error(loc,
2410                    "expects padding_low and padding_high separated by '_'");
2411     }
2412     pad->push_back(low_high);
2413   }
2414   lexer_.Lex();
2415   return true;
2416 }
2417 
2418 // This is the inverse xla::ToString(PaddingConfig). The padding config string
2419 // looks like "0_0_0x3_3_1". The string is first separated by 'x', each
2420 // substring represents one PaddingConfigDimension. The substring is 3 (or 2)
2421 // numbers joined by '_'.
ParsePaddingConfig(PaddingConfig * padding)2422 bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
2423   if (lexer_.GetKind() != TokKind::kPad) {
2424     return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
2425   }
2426   LocTy loc = lexer_.GetLoc();
2427   string str = lexer_.GetStrVal();
2428   std::vector<string> padding_str = Split(str, 'x');
2429   for (const auto& padding_dim_str : padding_str) {
2430     std::vector<int64> padding_dim;
2431     if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) ||
2432         (padding_dim.size() != 2 && padding_dim.size() != 3)) {
2433       return Error(loc,
2434                    "expects padding config pattern like 'low_high_interior' or "
2435                    "'low_high'");
2436     }
2437     auto* dim = padding->add_dimensions();
2438     dim->set_edge_padding_low(padding_dim[0]);
2439     dim->set_edge_padding_high(padding_dim[1]);
2440     dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
2441   }
2442   lexer_.Lex();
2443   return true;
2444 }
2445 
2446 // '{' metadata_string '}'
ParseMetadata(OpMetadata * metadata)2447 bool HloParser::ParseMetadata(OpMetadata* metadata) {
2448   std::unordered_map<string, AttrConfig> attrs;
2449   optional<string> op_type;
2450   optional<string> op_name;
2451   optional<string> source_file;
2452   optional<int32> source_line;
2453   attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type};
2454   attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name};
2455   attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file};
2456   attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line};
2457   if (!ParseSubAttributes(attrs)) {
2458     return false;
2459   }
2460   if (op_type) {
2461     metadata->set_op_type(*op_type);
2462   }
2463   if (op_name) {
2464     metadata->set_op_name(*op_name);
2465   }
2466   if (source_file) {
2467     metadata->set_source_file(*source_file);
2468   }
2469   if (source_line) {
2470     metadata->set_source_line(*source_line);
2471   }
2472   return true;
2473 }
2474 
ParseOpcode(HloOpcode * result)2475 bool HloParser::ParseOpcode(HloOpcode* result) {
2476   VLOG(1) << "ParseOpcode";
2477   if (lexer_.GetKind() != TokKind::kIdent) {
2478     return TokenError("expects opcode");
2479   }
2480   string val = lexer_.GetStrVal();
2481   auto status_or_result = StringToHloOpcode(val);
2482   if (!status_or_result.ok()) {
2483     return TokenError(
2484         Printf("expects opcode but sees: %s, error: %s", val.c_str(),
2485                status_or_result.status().error_message().c_str()));
2486   }
2487   *result = status_or_result.ValueOrDie();
2488   lexer_.Lex();
2489   return true;
2490 }
2491 
ParseFftType(FftType * result)2492 bool HloParser::ParseFftType(FftType* result) {
2493   VLOG(1) << "ParseFftType";
2494   if (lexer_.GetKind() != TokKind::kIdent) {
2495     return TokenError("expects fft type");
2496   }
2497   string val = lexer_.GetStrVal();
2498   if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) {
2499     return TokenError(Printf("expects fft type but sees: %s", val.c_str()));
2500   }
2501   lexer_.Lex();
2502   return true;
2503 }
2504 
ParseFusionKind(HloInstruction::FusionKind * result)2505 bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) {
2506   VLOG(1) << "ParseFusionKind";
2507   if (lexer_.GetKind() != TokKind::kIdent) {
2508     return TokenError("expects fusion kind");
2509   }
2510   string val = lexer_.GetStrVal();
2511   auto status_or_result = StringToFusionKind(val);
2512   if (!status_or_result.ok()) {
2513     return TokenError(
2514         Printf("expects fusion kind but sees: %s, error: %s", val.c_str(),
2515                status_or_result.status().error_message().c_str()));
2516   }
2517   *result = status_or_result.ValueOrDie();
2518   lexer_.Lex();
2519   return true;
2520 }
2521 
ParseRandomDistribution(RandomDistribution * result)2522 bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
2523   VLOG(1) << "ParseRandomDistribution";
2524   if (lexer_.GetKind() != TokKind::kIdent) {
2525     return TokenError("expects random distribution");
2526   }
2527   string val = lexer_.GetStrVal();
2528   auto status_or_result = StringToRandomDistribution(val);
2529   if (!status_or_result.ok()) {
2530     return TokenError(
2531         Printf("expects random distribution but sees: %s, error: %s",
2532                val.c_str(), status_or_result.status().error_message().c_str()));
2533   }
2534   *result = status_or_result.ValueOrDie();
2535   lexer_.Lex();
2536   return true;
2537 }
2538 
ParseInt64(int64 * result)2539 bool HloParser::ParseInt64(int64* result) {
2540   VLOG(1) << "ParseInt64";
2541   if (lexer_.GetKind() != TokKind::kInt) {
2542     return TokenError("expects integer");
2543   }
2544   *result = lexer_.GetInt64Val();
2545   lexer_.Lex();
2546   return true;
2547 }
2548 
ParseDouble(double * result)2549 bool HloParser::ParseDouble(double* result) {
2550   switch (lexer_.GetKind()) {
2551     case TokKind::kDecimal:
2552       *result = lexer_.GetDecimalVal();
2553       break;
2554     case TokKind::kInt:
2555       *result = static_cast<double>(lexer_.GetInt64Val());
2556       break;
2557     case TokKind::kw_nan:
2558       *result = std::numeric_limits<double>::quiet_NaN();
2559       break;
2560     case TokKind::kw_inf:
2561       *result = std::numeric_limits<double>::infinity();
2562       break;
2563     case TokKind::kNegInf:
2564       *result = -std::numeric_limits<double>::infinity();
2565       break;
2566     default:
2567       return TokenError("expects decimal or integer");
2568   }
2569   lexer_.Lex();
2570   return true;
2571 }
2572 
ParseBool(bool * result)2573 bool HloParser::ParseBool(bool* result) {
2574   if (lexer_.GetKind() != TokKind::kw_true &&
2575       lexer_.GetKind() != TokKind::kw_false) {
2576     return TokenError("expects true or false");
2577   }
2578   *result = lexer_.GetKind() == TokKind::kw_true;
2579   lexer_.Lex();
2580   return true;
2581 }
2582 
ParseToken(TokKind kind,const string & msg)2583 bool HloParser::ParseToken(TokKind kind, const string& msg) {
2584   VLOG(1) << "ParseToken " << TokKindToString(kind) << " " << msg;
2585   if (lexer_.GetKind() != kind) {
2586     return TokenError(msg);
2587   }
2588   lexer_.Lex();
2589   return true;
2590 }
2591 
EatIfPresent(TokKind kind)2592 bool HloParser::EatIfPresent(TokKind kind) {
2593   if (lexer_.GetKind() != kind) {
2594     return false;
2595   }
2596   lexer_.Lex();
2597   return true;
2598 }
2599 
AddInstruction(const string & name,HloInstruction * instruction,LocTy name_loc)2600 bool HloParser::AddInstruction(const string& name, HloInstruction* instruction,
2601                                LocTy name_loc) {
2602   auto result = instruction_pool_.insert({name, {instruction, name_loc}});
2603   if (!result.second) {
2604     Error(name_loc, StrCat("instruction already exists: ", name));
2605     return Error(/*loc=*/result.first->second.second,
2606                  "instruction previously defined here");
2607   }
2608   return true;
2609 }
2610 
AddComputation(const string & name,HloComputation * computation,LocTy name_loc)2611 bool HloParser::AddComputation(const string& name, HloComputation* computation,
2612                                LocTy name_loc) {
2613   auto result = computation_pool_.insert({name, {computation, name_loc}});
2614   if (!result.second) {
2615     Error(name_loc, StrCat("computation already exists: ", name));
2616     return Error(/*loc=*/result.first->second.second,
2617                  "computation previously defined here");
2618   }
2619   return true;
2620 }
2621 
2622 }  // namespace
2623 
Parse(StringPiece str,const HloModuleConfig & config)2624 StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str,
2625                                            const HloModuleConfig& config) {
2626   HloParser parser(str, config);
2627   if (!parser.Run()) {
2628     return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str());
2629   }
2630   return parser.ConsumeHloModule();
2631 }
2632 
Parse(StringPiece str)2633 StatusOr<std::unique_ptr<HloModule>> Parse(StringPiece str) {
2634   HloModuleConfig config;
2635   return Parse(str, config);
2636 }
2637 
2638 }  // namespace tools
2639 }  // namespace xla
2640