1 //===- AttributeParser.cpp - MLIR Attribute Parser Implementation ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements the parser for the MLIR Attributes.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 #include "mlir/IR/Dialect.h"
17 #include "mlir/IR/IntegerSet.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/Support/Endian.h"
20 
21 using namespace mlir;
22 using namespace mlir::detail;
23 
24 /// Parse an arbitrary attribute.
25 ///
26 ///  attribute-value ::= `unit`
27 ///                    | bool-literal
28 ///                    | integer-literal (`:` (index-type | integer-type))?
29 ///                    | float-literal (`:` float-type)?
30 ///                    | string-literal (`:` type)?
31 ///                    | type
32 ///                    | `[` (attribute-value (`,` attribute-value)*)? `]`
33 ///                    | `{` (attribute-entry (`,` attribute-entry)*)? `}`
34 ///                    | symbol-ref-id (`::` symbol-ref-id)*
35 ///                    | `dense` `<` attribute-value `>` `:`
36 ///                      (tensor-type | vector-type)
37 ///                    | `sparse` `<` attribute-value `,` attribute-value `>`
38 ///                      `:` (tensor-type | vector-type)
39 ///                    | `opaque` `<` dialect-namespace  `,` hex-string-literal
40 ///                      `>` `:` (tensor-type | vector-type)
41 ///                    | extended-attribute
42 ///
parseAttribute(Type type)43 Attribute Parser::parseAttribute(Type type) {
44   switch (getToken().getKind()) {
45   // Parse an AffineMap or IntegerSet attribute.
46   case Token::kw_affine_map: {
47     consumeToken(Token::kw_affine_map);
48 
49     AffineMap map;
50     if (parseToken(Token::less, "expected '<' in affine map") ||
51         parseAffineMapReference(map) ||
52         parseToken(Token::greater, "expected '>' in affine map"))
53       return Attribute();
54     return AffineMapAttr::get(map);
55   }
56   case Token::kw_affine_set: {
57     consumeToken(Token::kw_affine_set);
58 
59     IntegerSet set;
60     if (parseToken(Token::less, "expected '<' in integer set") ||
61         parseIntegerSetReference(set) ||
62         parseToken(Token::greater, "expected '>' in integer set"))
63       return Attribute();
64     return IntegerSetAttr::get(set);
65   }
66 
67   // Parse an array attribute.
68   case Token::l_square: {
69     consumeToken(Token::l_square);
70 
71     SmallVector<Attribute, 4> elements;
72     auto parseElt = [&]() -> ParseResult {
73       elements.push_back(parseAttribute());
74       return elements.back() ? success() : failure();
75     };
76 
77     if (parseCommaSeparatedListUntil(Token::r_square, parseElt))
78       return nullptr;
79     return builder.getArrayAttr(elements);
80   }
81 
82   // Parse a boolean attribute.
83   case Token::kw_false:
84     consumeToken(Token::kw_false);
85     return builder.getBoolAttr(false);
86   case Token::kw_true:
87     consumeToken(Token::kw_true);
88     return builder.getBoolAttr(true);
89 
90   // Parse a dense elements attribute.
91   case Token::kw_dense:
92     return parseDenseElementsAttr(type);
93 
94   // Parse a dictionary attribute.
95   case Token::l_brace: {
96     NamedAttrList elements;
97     if (parseAttributeDict(elements))
98       return nullptr;
99     return elements.getDictionary(getContext());
100   }
101 
102   // Parse an extended attribute, i.e. alias or dialect attribute.
103   case Token::hash_identifier:
104     return parseExtendedAttr(type);
105 
106   // Parse floating point and integer attributes.
107   case Token::floatliteral:
108     return parseFloatAttr(type, /*isNegative=*/false);
109   case Token::integer:
110     return parseDecOrHexAttr(type, /*isNegative=*/false);
111   case Token::minus: {
112     consumeToken(Token::minus);
113     if (getToken().is(Token::integer))
114       return parseDecOrHexAttr(type, /*isNegative=*/true);
115     if (getToken().is(Token::floatliteral))
116       return parseFloatAttr(type, /*isNegative=*/true);
117 
118     return (emitError("expected constant integer or floating point value"),
119             nullptr);
120   }
121 
122   // Parse a location attribute.
123   case Token::kw_loc: {
124     consumeToken(Token::kw_loc);
125 
126     LocationAttr locAttr;
127     if (parseToken(Token::l_paren, "expected '(' in inline location") ||
128         parseLocationInstance(locAttr) ||
129         parseToken(Token::r_paren, "expected ')' in inline location"))
130       return Attribute();
131     return locAttr;
132   }
133 
134   // Parse an opaque elements attribute.
135   case Token::kw_opaque:
136     return parseOpaqueElementsAttr(type);
137 
138   // Parse a sparse elements attribute.
139   case Token::kw_sparse:
140     return parseSparseElementsAttr(type);
141 
142   // Parse a string attribute.
143   case Token::string: {
144     auto val = getToken().getStringValue();
145     consumeToken(Token::string);
146     // Parse the optional trailing colon type if one wasn't explicitly provided.
147     if (!type && consumeIf(Token::colon) && !(type = parseType()))
148       return Attribute();
149 
150     return type ? StringAttr::get(val, type)
151                 : StringAttr::get(val, getContext());
152   }
153 
154   // Parse a symbol reference attribute.
155   case Token::at_identifier: {
156     std::string nameStr = getToken().getSymbolReference();
157     consumeToken(Token::at_identifier);
158 
159     // Parse any nested references.
160     std::vector<FlatSymbolRefAttr> nestedRefs;
161     while (getToken().is(Token::colon)) {
162       // Check for the '::' prefix.
163       const char *curPointer = getToken().getLoc().getPointer();
164       consumeToken(Token::colon);
165       if (!consumeIf(Token::colon)) {
166         state.lex.resetPointer(curPointer);
167         consumeToken();
168         break;
169       }
170       // Parse the reference itself.
171       auto curLoc = getToken().getLoc();
172       if (getToken().isNot(Token::at_identifier)) {
173         emitError(curLoc, "expected nested symbol reference identifier");
174         return Attribute();
175       }
176 
177       std::string nameStr = getToken().getSymbolReference();
178       consumeToken(Token::at_identifier);
179       nestedRefs.push_back(SymbolRefAttr::get(nameStr, getContext()));
180     }
181 
182     return builder.getSymbolRefAttr(nameStr, nestedRefs);
183   }
184 
185   // Parse a 'unit' attribute.
186   case Token::kw_unit:
187     consumeToken(Token::kw_unit);
188     return builder.getUnitAttr();
189 
190   default:
191     // Parse a type attribute.
192     if (Type type = parseType())
193       return TypeAttr::get(type);
194     return nullptr;
195   }
196 }
197 
198 /// Parse an optional attribute with the provided type.
parseOptionalAttribute(Attribute & attribute,Type type)199 OptionalParseResult Parser::parseOptionalAttribute(Attribute &attribute,
200                                                    Type type) {
201   switch (getToken().getKind()) {
202   case Token::at_identifier:
203   case Token::floatliteral:
204   case Token::integer:
205   case Token::hash_identifier:
206   case Token::kw_affine_map:
207   case Token::kw_affine_set:
208   case Token::kw_dense:
209   case Token::kw_false:
210   case Token::kw_loc:
211   case Token::kw_opaque:
212   case Token::kw_sparse:
213   case Token::kw_true:
214   case Token::kw_unit:
215   case Token::l_brace:
216   case Token::l_square:
217   case Token::minus:
218   case Token::string:
219     attribute = parseAttribute(type);
220     return success(attribute != nullptr);
221 
222   default:
223     // Parse an optional type attribute.
224     Type type;
225     OptionalParseResult result = parseOptionalType(type);
226     if (result.hasValue() && succeeded(*result))
227       attribute = TypeAttr::get(type);
228     return result;
229   }
230 }
/// Parse an optional array attribute with the provided type. Forwards to
/// `parseOptionalAttributeWithToken`, passing `Token::l_square` as the token
/// kind that introduces the attribute (helper defined elsewhere; presumably no
/// result is returned when the current token differs).
OptionalParseResult Parser::parseOptionalAttribute(ArrayAttr &attribute,
                                                   Type type) {
  return parseOptionalAttributeWithToken(Token::l_square, attribute, type);
}
/// Parse an optional string attribute with the provided type. Forwards to
/// `parseOptionalAttributeWithToken`, passing `Token::string` as the token kind
/// that introduces the attribute (helper defined elsewhere; presumably no
/// result is returned when the current token differs).
OptionalParseResult Parser::parseOptionalAttribute(StringAttr &attribute,
                                                   Type type) {
  return parseOptionalAttributeWithToken(Token::string, attribute, type);
}
239 
240 /// Attribute dictionary.
241 ///
242 ///   attribute-dict ::= `{` `}`
243 ///                    | `{` attribute-entry (`,` attribute-entry)* `}`
244 ///   attribute-entry ::= (bare-id | string-literal) `=` attribute-value
245 ///
parseAttributeDict(NamedAttrList & attributes)246 ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
247   if (parseToken(Token::l_brace, "expected '{' in attribute dictionary"))
248     return failure();
249 
250   llvm::SmallDenseSet<Identifier> seenKeys;
251   auto parseElt = [&]() -> ParseResult {
252     // The name of an attribute can either be a bare identifier, or a string.
253     Optional<Identifier> nameId;
254     if (getToken().is(Token::string))
255       nameId = builder.getIdentifier(getToken().getStringValue());
256     else if (getToken().isAny(Token::bare_identifier, Token::inttype) ||
257              getToken().isKeyword())
258       nameId = builder.getIdentifier(getTokenSpelling());
259     else
260       return emitError("expected attribute name");
261     if (!seenKeys.insert(*nameId).second)
262       return emitError("duplicate key '")
263              << *nameId << "' in dictionary attribute";
264     consumeToken();
265 
266     // Lazy load a dialect in the context if there is a possible namespace.
267     auto splitName = nameId->strref().split('.');
268     if (!splitName.second.empty())
269       getContext()->getOrLoadDialect(splitName.first);
270 
271     // Try to parse the '=' for the attribute value.
272     if (!consumeIf(Token::equal)) {
273       // If there is no '=', we treat this as a unit attribute.
274       attributes.push_back({*nameId, builder.getUnitAttr()});
275       return success();
276     }
277 
278     auto attr = parseAttribute();
279     if (!attr)
280       return failure();
281     attributes.push_back({*nameId, attr});
282     return success();
283   };
284 
285   if (parseCommaSeparatedListUntil(Token::r_brace, parseElt))
286     return failure();
287 
288   return success();
289 }
290 
291 /// Parse a float attribute.
parseFloatAttr(Type type,bool isNegative)292 Attribute Parser::parseFloatAttr(Type type, bool isNegative) {
293   auto val = getToken().getFloatingPointValue();
294   if (!val.hasValue())
295     return (emitError("floating point value too large for attribute"), nullptr);
296   consumeToken(Token::floatliteral);
297   if (!type) {
298     // Default to F64 when no type is specified.
299     if (!consumeIf(Token::colon))
300       type = builder.getF64Type();
301     else if (!(type = parseType()))
302       return nullptr;
303   }
304   if (!type.isa<FloatType>())
305     return (emitError("floating point value not valid for specified type"),
306             nullptr);
307   return FloatAttr::get(type, isNegative ? -val.getValue() : val.getValue());
308 }
309 
310 /// Construct a float attribute bitwise equivalent to the integer literal.
buildHexadecimalFloatLiteral(Parser * p,FloatType type,uint64_t value)311 static Optional<APFloat> buildHexadecimalFloatLiteral(Parser *p, FloatType type,
312                                                       uint64_t value) {
313   if (type.isF64())
314     return APFloat(type.getFloatSemantics(), APInt(/*numBits=*/64, value));
315 
316   APInt apInt(type.getWidth(), value);
317   if (apInt != value) {
318     p->emitError("hexadecimal float constant out of range for type");
319     return llvm::None;
320   }
321   return APFloat(type.getFloatSemantics(), apInt);
322 }
323 
324 /// Construct an APint from a parsed value, a known attribute type and
325 /// sign.
buildAttributeAPInt(Type type,bool isNegative,StringRef spelling)326 static Optional<APInt> buildAttributeAPInt(Type type, bool isNegative,
327                                            StringRef spelling) {
328   // Parse the integer value into an APInt that is big enough to hold the value.
329   APInt result;
330   bool isHex = spelling.size() > 1 && spelling[1] == 'x';
331   if (spelling.getAsInteger(isHex ? 0 : 10, result))
332     return llvm::None;
333 
334   // Extend or truncate the bitwidth to the right size.
335   unsigned width = type.isIndex() ? IndexType::kInternalStorageBitWidth
336                                   : type.getIntOrFloatBitWidth();
337   if (width > result.getBitWidth()) {
338     result = result.zext(width);
339   } else if (width < result.getBitWidth()) {
340     // The parser can return an unnecessarily wide result with leading zeros.
341     // This isn't a problem, but truncating off bits is bad.
342     if (result.countLeadingZeros() < result.getBitWidth() - width)
343       return llvm::None;
344 
345     result = result.trunc(width);
346   }
347 
348   if (isNegative) {
349     // The value is negative, we have an overflow if the sign bit is not set
350     // in the negated apInt.
351     result.negate();
352     if (!result.isSignBitSet())
353       return llvm::None;
354   } else if ((type.isSignedInteger() || type.isIndex()) &&
355              result.isSignBitSet()) {
356     // The value is a positive signed integer or index,
357     // we have an overflow if the sign bit is set.
358     return llvm::None;
359   }
360 
361   return result;
362 }
363 
364 /// Parse a decimal or a hexadecimal literal, which can be either an integer
365 /// or a float attribute.
parseDecOrHexAttr(Type type,bool isNegative)366 Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
367   // Remember if the literal is hexadecimal.
368   StringRef spelling = getToken().getSpelling();
369   auto loc = state.curToken.getLoc();
370   bool isHex = spelling.size() > 1 && spelling[1] == 'x';
371 
372   consumeToken(Token::integer);
373   if (!type) {
374     // Default to i64 if not type is specified.
375     if (!consumeIf(Token::colon))
376       type = builder.getIntegerType(64);
377     else if (!(type = parseType()))
378       return nullptr;
379   }
380 
381   if (auto floatType = type.dyn_cast<FloatType>()) {
382     if (isNegative)
383       return emitError(
384                  loc,
385                  "hexadecimal float literal should not have a leading minus"),
386              nullptr;
387     if (!isHex) {
388       emitError(loc, "unexpected decimal integer literal for a float attribute")
389               .attachNote()
390           << "add a trailing dot to make the literal a float";
391       return nullptr;
392     }
393 
394     auto val = Token::getUInt64IntegerValue(spelling);
395     if (!val.hasValue())
396       return emitError("integer constant out of range for attribute"), nullptr;
397 
398     // Construct a float attribute bitwise equivalent to the integer literal.
399     Optional<APFloat> apVal =
400         buildHexadecimalFloatLiteral(this, floatType, *val);
401     return apVal ? FloatAttr::get(floatType, *apVal) : Attribute();
402   }
403 
404   if (!type.isa<IntegerType, IndexType>())
405     return emitError(loc, "integer literal not valid for specified type"),
406            nullptr;
407 
408   if (isNegative && type.isUnsignedInteger()) {
409     emitError(loc,
410               "negative integer literal not valid for unsigned integer type");
411     return nullptr;
412   }
413 
414   Optional<APInt> apInt = buildAttributeAPInt(type, isNegative, spelling);
415   if (!apInt)
416     return emitError(loc, "integer constant out of range for attribute"),
417            nullptr;
418   return builder.getIntegerAttr(type, *apInt);
419 }
420 
421 //===----------------------------------------------------------------------===//
422 // TensorLiteralParser
423 //===----------------------------------------------------------------------===//
424 
425 /// Parse elements values stored within a hex string. On success, the values are
426 /// stored into 'result'.
parseElementAttrHexValues(Parser & parser,Token tok,std::string & result)427 static ParseResult parseElementAttrHexValues(Parser &parser, Token tok,
428                                              std::string &result) {
429   if (Optional<std::string> value = tok.getHexStringValue()) {
430     result = std::move(*value);
431     return success();
432   }
433   return parser.emitError(
434       tok.getLoc(), "expected string containing hex digits starting with `0x`");
435 }
436 
namespace {
/// This class implements a parser for TensorLiterals. A tensor literal is
/// either a single element (e.g, 5) or a multi-dimensional list of elements
/// (e.g., [[5, 5]]).
///
/// Usage is two-phase: `parse` consumes the literal and records its tokens
/// (and, for lists, the inferred shape); `getAttr` later turns the recorded
/// tokens into an attribute once the shaped type is known.
class TensorLiteralParser {
public:
  TensorLiteralParser(Parser &p) : p(p) {}

  /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
  /// may also parse a tensor literal that is stored as a hex string.
  ParseResult parse(bool allowHex);

  /// Build a dense attribute instance with the parsed elements and the given
  /// shaped type. 'loc' is the location used for diagnostics about the literal
  /// as a whole.
  DenseElementsAttr getAttr(llvm::SMLoc loc, ShapedType type);

  /// Return the shape inferred while parsing; empty unless a list was parsed.
  ArrayRef<int64_t> getShape() const { return shape; }

private:
  /// Get the parsed elements for an integer attribute.
  ParseResult getIntAttrElements(llvm::SMLoc loc, Type eltTy,
                                 std::vector<APInt> &intValues);

  /// Get the parsed elements for a float attribute.
  ParseResult getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
                                   std::vector<APFloat> &floatValues);

  /// Build a Dense String attribute for the given type.
  DenseElementsAttr getStringAttr(llvm::SMLoc loc, ShapedType type, Type eltTy);

  /// Build a Dense attribute with hex data for the given type.
  DenseElementsAttr getHexAttr(llvm::SMLoc loc, ShapedType type);

  /// Parse a single element, returning failure if it isn't a valid element
  /// literal. For example:
  /// parseElement(1) -> Success, 1
  /// parseElement([1]) -> Failure
  ParseResult parseElement();

  /// Parse a list of either lists or elements, returning the dimensions of the
  /// parsed sub-tensors in dims. For example:
  ///   parseList([1, 2, 3]) -> Success, [3]
  ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
  ///   parseList([[1, 2], 3]) -> Failure
  ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
  ParseResult parseList(SmallVectorImpl<int64_t> &dims);

  /// Parse a literal that was printed as a hex string.
  ParseResult parseHexElements();

  /// The parser used to consume tokens and to emit diagnostics.
  Parser &p;

  /// The shape inferred from the parsed elements.
  SmallVector<int64_t, 4> shape;

  /// Storage used when parsing elements, this is a pair of <is_negated, token>.
  /// Elements are recorded in the order in which they were parsed.
  std::vector<std::pair<bool, Token>> storage;

  /// Storage used when parsing elements that were stored as hex values. This
  /// holds the string token of the literal and is only set by `parse` when hex
  /// is allowed.
  Optional<Token> hexStorage;
};
} // end anonymous namespace
499 
500 /// Parse the elements of a tensor literal. If 'allowHex' is true, the parser
501 /// may also parse a tensor literal that is store as a hex string.
parse(bool allowHex)502 ParseResult TensorLiteralParser::parse(bool allowHex) {
503   // If hex is allowed, check for a string literal.
504   if (allowHex && p.getToken().is(Token::string)) {
505     hexStorage = p.getToken();
506     p.consumeToken(Token::string);
507     return success();
508   }
509   // Otherwise, parse a list or an individual element.
510   if (p.getToken().is(Token::l_square))
511     return parseList(shape);
512   return parseElement();
513 }
514 
515 /// Build a dense attribute instance with the parsed elements and the given
516 /// shaped type.
getAttr(llvm::SMLoc loc,ShapedType type)517 DenseElementsAttr TensorLiteralParser::getAttr(llvm::SMLoc loc,
518                                                ShapedType type) {
519   Type eltType = type.getElementType();
520 
521   // Check to see if we parse the literal from a hex string.
522   if (hexStorage.hasValue() &&
523       (eltType.isIntOrFloat() || eltType.isa<ComplexType>()))
524     return getHexAttr(loc, type);
525 
526   // Check that the parsed storage size has the same number of elements to the
527   // type, or is a known splat.
528   if (!shape.empty() && getShape() != type.getShape()) {
529     p.emitError(loc) << "inferred shape of elements literal ([" << getShape()
530                      << "]) does not match type ([" << type.getShape() << "])";
531     return nullptr;
532   }
533 
534   // Handle complex types in the specific element type cases below.
535   bool isComplex = false;
536   if (ComplexType complexTy = eltType.dyn_cast<ComplexType>()) {
537     eltType = complexTy.getElementType();
538     isComplex = true;
539   }
540 
541   // Handle integer and index types.
542   if (eltType.isIntOrIndex()) {
543     std::vector<APInt> intValues;
544     if (failed(getIntAttrElements(loc, eltType, intValues)))
545       return nullptr;
546     if (isComplex) {
547       // If this is a complex, treat the parsed values as complex values.
548       auto complexData = llvm::makeArrayRef(
549           reinterpret_cast<std::complex<APInt> *>(intValues.data()),
550           intValues.size() / 2);
551       return DenseElementsAttr::get(type, complexData);
552     }
553     return DenseElementsAttr::get(type, intValues);
554   }
555   // Handle floating point types.
556   if (FloatType floatTy = eltType.dyn_cast<FloatType>()) {
557     std::vector<APFloat> floatValues;
558     if (failed(getFloatAttrElements(loc, floatTy, floatValues)))
559       return nullptr;
560     if (isComplex) {
561       // If this is a complex, treat the parsed values as complex values.
562       auto complexData = llvm::makeArrayRef(
563           reinterpret_cast<std::complex<APFloat> *>(floatValues.data()),
564           floatValues.size() / 2);
565       return DenseElementsAttr::get(type, complexData);
566     }
567     return DenseElementsAttr::get(type, floatValues);
568   }
569 
570   // Other types are assumed to be string representations.
571   return getStringAttr(loc, type, type.getElementType());
572 }
573 
574 /// Build a Dense Integer attribute for the given type.
575 ParseResult
getIntAttrElements(llvm::SMLoc loc,Type eltTy,std::vector<APInt> & intValues)576 TensorLiteralParser::getIntAttrElements(llvm::SMLoc loc, Type eltTy,
577                                         std::vector<APInt> &intValues) {
578   intValues.reserve(storage.size());
579   bool isUintType = eltTy.isUnsignedInteger();
580   for (const auto &signAndToken : storage) {
581     bool isNegative = signAndToken.first;
582     const Token &token = signAndToken.second;
583     auto tokenLoc = token.getLoc();
584 
585     if (isNegative && isUintType) {
586       return p.emitError(tokenLoc)
587              << "expected unsigned integer elements, but parsed negative value";
588     }
589 
590     // Check to see if floating point values were parsed.
591     if (token.is(Token::floatliteral)) {
592       return p.emitError(tokenLoc)
593              << "expected integer elements, but parsed floating-point";
594     }
595 
596     assert(token.isAny(Token::integer, Token::kw_true, Token::kw_false) &&
597            "unexpected token type");
598     if (token.isAny(Token::kw_true, Token::kw_false)) {
599       if (!eltTy.isInteger(1)) {
600         return p.emitError(tokenLoc)
601                << "expected i1 type for 'true' or 'false' values";
602       }
603       APInt apInt(1, token.is(Token::kw_true), /*isSigned=*/false);
604       intValues.push_back(apInt);
605       continue;
606     }
607 
608     // Create APInt values for each element with the correct bitwidth.
609     Optional<APInt> apInt =
610         buildAttributeAPInt(eltTy, isNegative, token.getSpelling());
611     if (!apInt)
612       return p.emitError(tokenLoc, "integer constant out of range for type");
613     intValues.push_back(*apInt);
614   }
615   return success();
616 }
617 
618 /// Build a Dense Float attribute for the given type.
619 ParseResult
getFloatAttrElements(llvm::SMLoc loc,FloatType eltTy,std::vector<APFloat> & floatValues)620 TensorLiteralParser::getFloatAttrElements(llvm::SMLoc loc, FloatType eltTy,
621                                           std::vector<APFloat> &floatValues) {
622   floatValues.reserve(storage.size());
623   for (const auto &signAndToken : storage) {
624     bool isNegative = signAndToken.first;
625     const Token &token = signAndToken.second;
626 
627     // Handle hexadecimal float literals.
628     if (token.is(Token::integer) && token.getSpelling().startswith("0x")) {
629       if (isNegative) {
630         return p.emitError(token.getLoc())
631                << "hexadecimal float literal should not have a leading minus";
632       }
633       auto val = token.getUInt64IntegerValue();
634       if (!val.hasValue()) {
635         return p.emitError(
636             "hexadecimal float constant out of range for attribute");
637       }
638       Optional<APFloat> apVal = buildHexadecimalFloatLiteral(&p, eltTy, *val);
639       if (!apVal)
640         return failure();
641       floatValues.push_back(*apVal);
642       continue;
643     }
644 
645     // Check to see if any decimal integers or booleans were parsed.
646     if (!token.is(Token::floatliteral))
647       return p.emitError()
648              << "expected floating-point elements, but parsed integer";
649 
650     // Build the float values from tokens.
651     auto val = token.getFloatingPointValue();
652     if (!val.hasValue())
653       return p.emitError("floating point value too large for attribute");
654 
655     APFloat apVal(isNegative ? -*val : *val);
656     if (!eltTy.isF64()) {
657       bool unused;
658       apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
659                     &unused);
660     }
661     floatValues.push_back(apVal);
662   }
663   return success();
664 }
665 
666 /// Build a Dense String attribute for the given type.
getStringAttr(llvm::SMLoc loc,ShapedType type,Type eltTy)667 DenseElementsAttr TensorLiteralParser::getStringAttr(llvm::SMLoc loc,
668                                                      ShapedType type,
669                                                      Type eltTy) {
670   if (hexStorage.hasValue()) {
671     auto stringValue = hexStorage.getValue().getStringValue();
672     return DenseStringElementsAttr::get(type, {stringValue});
673   }
674 
675   std::vector<std::string> stringValues;
676   std::vector<StringRef> stringRefValues;
677   stringValues.reserve(storage.size());
678   stringRefValues.reserve(storage.size());
679 
680   for (auto val : storage) {
681     stringValues.push_back(val.second.getStringValue());
682     stringRefValues.push_back(stringValues.back());
683   }
684 
685   return DenseStringElementsAttr::get(type, stringRefValues);
686 }
687 
688 /// Build a Dense attribute with hex data for the given type.
getHexAttr(llvm::SMLoc loc,ShapedType type)689 DenseElementsAttr TensorLiteralParser::getHexAttr(llvm::SMLoc loc,
690                                                   ShapedType type) {
691   Type elementType = type.getElementType();
692   if (!elementType.isIntOrIndexOrFloat() && !elementType.isa<ComplexType>()) {
693     p.emitError(loc)
694         << "expected floating-point, integer, or complex element type, got "
695         << elementType;
696     return nullptr;
697   }
698 
699   std::string data;
700   if (parseElementAttrHexValues(p, hexStorage.getValue(), data))
701     return nullptr;
702 
703   ArrayRef<char> rawData(data.data(), data.size());
704   bool detectedSplat = false;
705   if (!DenseElementsAttr::isValidRawBuffer(type, rawData, detectedSplat)) {
706     p.emitError(loc) << "elements hex data size is invalid for provided type: "
707                      << type;
708     return nullptr;
709   }
710 
711   if (llvm::support::endian::system_endianness() ==
712       llvm::support::endianness::big) {
713     // Convert endianess in big-endian(BE) machines. `rawData` is
714     // little-endian(LE) because HEX in raw data of dense element attribute
715     // is always LE format. It is converted into BE here to be used in BE
716     // machines.
717     SmallVector<char, 64> outDataVec(rawData.size());
718     MutableArrayRef<char> convRawData(outDataVec);
719     DenseIntOrFPElementsAttr::convertEndianOfArrayRefForBEmachine(
720         rawData, convRawData, type);
721     return DenseElementsAttr::getFromRawBuffer(type, convRawData,
722                                                detectedSplat);
723   }
724 
725   return DenseElementsAttr::getFromRawBuffer(type, rawData, detectedSplat);
726 }
727 
parseElement()728 ParseResult TensorLiteralParser::parseElement() {
729   switch (p.getToken().getKind()) {
730   // Parse a boolean element.
731   case Token::kw_true:
732   case Token::kw_false:
733   case Token::floatliteral:
734   case Token::integer:
735     storage.emplace_back(/*isNegative=*/false, p.getToken());
736     p.consumeToken();
737     break;
738 
739   // Parse a signed integer or a negative floating-point element.
740   case Token::minus:
741     p.consumeToken(Token::minus);
742     if (!p.getToken().isAny(Token::floatliteral, Token::integer))
743       return p.emitError("expected integer or floating point literal");
744     storage.emplace_back(/*isNegative=*/true, p.getToken());
745     p.consumeToken();
746     break;
747 
748   case Token::string:
749     storage.emplace_back(/*isNegative=*/false, p.getToken());
750     p.consumeToken();
751     break;
752 
753   // Parse a complex element of the form '(' element ',' element ')'.
754   case Token::l_paren:
755     p.consumeToken(Token::l_paren);
756     if (parseElement() ||
757         p.parseToken(Token::comma, "expected ',' between complex elements") ||
758         parseElement() ||
759         p.parseToken(Token::r_paren, "expected ')' after complex elements"))
760       return failure();
761     break;
762 
763   default:
764     return p.emitError("expected element literal of primitive type");
765   }
766 
767   return success();
768 }
769 
770 /// Parse a list of either lists or elements, returning the dimensions of the
771 /// parsed sub-tensors in dims. For example:
772 ///   parseList([1, 2, 3]) -> Success, [3]
773 ///   parseList([[1, 2], [3, 4]]) -> Success, [2, 2]
774 ///   parseList([[1, 2], 3]) -> Failure
775 ///   parseList([[1, [2, 3]], [4, [5]]]) -> Failure
parseList(SmallVectorImpl<int64_t> & dims)776 ParseResult TensorLiteralParser::parseList(SmallVectorImpl<int64_t> &dims) {
777   p.consumeToken(Token::l_square);
778 
779   auto checkDims = [&](const SmallVectorImpl<int64_t> &prevDims,
780                        const SmallVectorImpl<int64_t> &newDims) -> ParseResult {
781     if (prevDims == newDims)
782       return success();
783     return p.emitError("tensor literal is invalid; ranks are not consistent "
784                        "between elements");
785   };
786 
787   bool first = true;
788   SmallVector<int64_t, 4> newDims;
789   unsigned size = 0;
790   auto parseCommaSeparatedList = [&]() -> ParseResult {
791     SmallVector<int64_t, 4> thisDims;
792     if (p.getToken().getKind() == Token::l_square) {
793       if (parseList(thisDims))
794         return failure();
795     } else if (parseElement()) {
796       return failure();
797     }
798     ++size;
799     if (!first)
800       return checkDims(newDims, thisDims);
801     newDims = thisDims;
802     first = false;
803     return success();
804   };
805   if (p.parseCommaSeparatedListUntil(Token::r_square, parseCommaSeparatedList))
806     return failure();
807 
808   // Return the sublists' dimensions with 'size' prepended.
809   dims.clear();
810   dims.push_back(size);
811   dims.append(newDims.begin(), newDims.end());
812   return success();
813 }
814 
815 //===----------------------------------------------------------------------===//
816 // ElementsAttr Parser
817 //===----------------------------------------------------------------------===//
818 
819 /// Parse a dense elements attribute.
parseDenseElementsAttr(Type attrType)820 Attribute Parser::parseDenseElementsAttr(Type attrType) {
821   auto attribLoc = getToken().getLoc();
822   consumeToken(Token::kw_dense);
823   if (parseToken(Token::less, "expected '<' after 'dense'"))
824     return nullptr;
825 
826   // Parse the literal data if necessary.
827   TensorLiteralParser literalParser(*this);
828   if (!consumeIf(Token::greater)) {
829     if (literalParser.parse(/*allowHex=*/true) ||
830         parseToken(Token::greater, "expected '>'"))
831       return nullptr;
832   }
833 
834   // If the type is specified `parseElementsLiteralType` will not parse a type.
835   // Use the attribute location as the location for error reporting in that
836   // case.
837   auto loc = attrType ? attribLoc : getToken().getLoc();
838   auto type = parseElementsLiteralType(attrType);
839   if (!type)
840     return nullptr;
841   return literalParser.getAttr(loc, type);
842 }
843 
844 /// Parse an opaque elements attribute.
parseOpaqueElementsAttr(Type attrType)845 Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
846   consumeToken(Token::kw_opaque);
847   if (parseToken(Token::less, "expected '<' after 'opaque'"))
848     return nullptr;
849 
850   if (getToken().isNot(Token::string))
851     return (emitError("expected dialect namespace"), nullptr);
852 
853   auto name = getToken().getStringValue();
854   // Lazy load a dialect in the context if there is a possible namespace.
855   Dialect *dialect = builder.getContext()->getOrLoadDialect(name);
856 
857   // TODO: Allow for having an unknown dialect on an opaque
858   // attribute. Otherwise, it can't be roundtripped without having the dialect
859   // registered.
860   if (!dialect)
861     return (emitError("no registered dialect with namespace '" + name + "'"),
862             nullptr);
863   consumeToken(Token::string);
864 
865   if (parseToken(Token::comma, "expected ','"))
866     return nullptr;
867 
868   Token hexTok = getToken();
869   if (parseToken(Token::string, "elements hex string should start with '0x'") ||
870       parseToken(Token::greater, "expected '>'"))
871     return nullptr;
872   auto type = parseElementsLiteralType(attrType);
873   if (!type)
874     return nullptr;
875 
876   std::string data;
877   if (parseElementAttrHexValues(*this, hexTok, data))
878     return nullptr;
879   return OpaqueElementsAttr::get(dialect, type, data);
880 }
881 
882 /// Shaped type for elements attribute.
883 ///
884 ///   elements-literal-type ::= vector-type | ranked-tensor-type
885 ///
886 /// This method also checks the type has static shape.
parseElementsLiteralType(Type type)887 ShapedType Parser::parseElementsLiteralType(Type type) {
888   // If the user didn't provide a type, parse the colon type for the literal.
889   if (!type) {
890     if (parseToken(Token::colon, "expected ':'"))
891       return nullptr;
892     if (!(type = parseType()))
893       return nullptr;
894   }
895 
896   if (!type.isa<RankedTensorType, VectorType>()) {
897     emitError("elements literal must be a ranked tensor or vector type");
898     return nullptr;
899   }
900 
901   auto sType = type.cast<ShapedType>();
902   if (!sType.hasStaticShape())
903     return (emitError("elements literal type must have static shape"), nullptr);
904 
905   return sType;
906 }
907 
908 /// Parse a sparse elements attribute.
parseSparseElementsAttr(Type attrType)909 Attribute Parser::parseSparseElementsAttr(Type attrType) {
910   consumeToken(Token::kw_sparse);
911   if (parseToken(Token::less, "Expected '<' after 'sparse'"))
912     return nullptr;
913 
914   // Check for the case where all elements are sparse. The indices are
915   // represented by a 2-dimensional shape where the second dimension is the rank
916   // of the type.
917   Type indiceEltType = builder.getIntegerType(64);
918   if (consumeIf(Token::greater)) {
919     ShapedType type = parseElementsLiteralType(attrType);
920     if (!type)
921       return nullptr;
922 
923     // Construct the sparse elements attr using zero element indice/value
924     // attributes.
925     ShapedType indicesType =
926         RankedTensorType::get({0, type.getRank()}, indiceEltType);
927     ShapedType valuesType = RankedTensorType::get({0}, type.getElementType());
928     return SparseElementsAttr::get(
929         type, DenseElementsAttr::get(indicesType, ArrayRef<Attribute>()),
930         DenseElementsAttr::get(valuesType, ArrayRef<Attribute>()));
931   }
932 
933   /// Parse the indices. We don't allow hex values here as we may need to use
934   /// the inferred shape.
935   auto indicesLoc = getToken().getLoc();
936   TensorLiteralParser indiceParser(*this);
937   if (indiceParser.parse(/*allowHex=*/false))
938     return nullptr;
939 
940   if (parseToken(Token::comma, "expected ','"))
941     return nullptr;
942 
943   /// Parse the values.
944   auto valuesLoc = getToken().getLoc();
945   TensorLiteralParser valuesParser(*this);
946   if (valuesParser.parse(/*allowHex=*/true))
947     return nullptr;
948 
949   if (parseToken(Token::greater, "expected '>'"))
950     return nullptr;
951 
952   auto type = parseElementsLiteralType(attrType);
953   if (!type)
954     return nullptr;
955 
956   // If the indices are a splat, i.e. the literal parser parsed an element and
957   // not a list, we set the shape explicitly. The indices are represented by a
958   // 2-dimensional shape where the second dimension is the rank of the type.
959   // Given that the parsed indices is a splat, we know that we only have one
960   // indice and thus one for the first dimension.
961   ShapedType indicesType;
962   if (indiceParser.getShape().empty()) {
963     indicesType = RankedTensorType::get({1, type.getRank()}, indiceEltType);
964   } else {
965     // Otherwise, set the shape to the one parsed by the literal parser.
966     indicesType = RankedTensorType::get(indiceParser.getShape(), indiceEltType);
967   }
968   auto indices = indiceParser.getAttr(indicesLoc, indicesType);
969 
970   // If the values are a splat, set the shape explicitly based on the number of
971   // indices. The number of indices is encoded in the first dimension of the
972   // indice shape type.
973   auto valuesEltType = type.getElementType();
974   ShapedType valuesType =
975       valuesParser.getShape().empty()
976           ? RankedTensorType::get({indicesType.getDimSize(0)}, valuesEltType)
977           : RankedTensorType::get(valuesParser.getShape(), valuesEltType);
978   auto values = valuesParser.getAttr(valuesLoc, valuesType);
979 
980   /// Sanity check.
981   if (valuesType.getRank() != 1)
982     return (emitError("expected 1-d tensor for values"), nullptr);
983 
984   auto sameShape = (indicesType.getRank() == 1) ||
985                    (type.getRank() == indicesType.getDimSize(1));
986   auto sameElementNum = indicesType.getDimSize(0) == valuesType.getDimSize(0);
987   if (!sameShape || !sameElementNum) {
988     emitError() << "expected shape ([" << type.getShape()
989                 << "]); inferred shape of indices literal (["
990                 << indicesType.getShape()
991                 << "]); inferred shape of values literal (["
992                 << valuesType.getShape() << "])";
993     return nullptr;
994   }
995 
996   // Build the sparse elements attribute by the indices and values.
997   return SparseElementsAttr::get(type, indices, values);
998 }
999