1 //===- TypeParser.cpp - MLIR Type Parser Implementation -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the parser for the MLIR Types.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "Parser.h"
14 #include "mlir/IR/AffineMap.h"
15 #include "mlir/IR/BuiltinTypes.h"
16 
17 using namespace mlir;
18 using namespace mlir::detail;
19 
20 /// Optionally parse a type.
parseOptionalType(Type & type)21 OptionalParseResult Parser::parseOptionalType(Type &type) {
22   // There are many different starting tokens for a type, check them here.
23   switch (getToken().getKind()) {
24   case Token::l_paren:
25   case Token::kw_memref:
26   case Token::kw_tensor:
27   case Token::kw_complex:
28   case Token::kw_tuple:
29   case Token::kw_vector:
30   case Token::inttype:
31   case Token::kw_bf16:
32   case Token::kw_f16:
33   case Token::kw_f32:
34   case Token::kw_f64:
35   case Token::kw_index:
36   case Token::kw_none:
37   case Token::exclamation_identifier:
38     return failure(!(type = parseType()));
39 
40   default:
41     return llvm::None;
42   }
43 }
44 
45 /// Parse an arbitrary type.
46 ///
47 ///   type ::= function-type
48 ///          | non-function-type
49 ///
parseType()50 Type Parser::parseType() {
51   if (getToken().is(Token::l_paren))
52     return parseFunctionType();
53   return parseNonFunctionType();
54 }
55 
56 /// Parse a function result type.
57 ///
58 ///   function-result-type ::= type-list-parens
59 ///                          | non-function-type
60 ///
parseFunctionResultTypes(SmallVectorImpl<Type> & elements)61 ParseResult Parser::parseFunctionResultTypes(SmallVectorImpl<Type> &elements) {
62   if (getToken().is(Token::l_paren))
63     return parseTypeListParens(elements);
64 
65   Type t = parseNonFunctionType();
66   if (!t)
67     return failure();
68   elements.push_back(t);
69   return success();
70 }
71 
72 /// Parse a list of types without an enclosing parenthesis.  The list must have
73 /// at least one member.
74 ///
75 ///   type-list-no-parens ::=  type (`,` type)*
76 ///
parseTypeListNoParens(SmallVectorImpl<Type> & elements)77 ParseResult Parser::parseTypeListNoParens(SmallVectorImpl<Type> &elements) {
78   auto parseElt = [&]() -> ParseResult {
79     auto elt = parseType();
80     elements.push_back(elt);
81     return elt ? success() : failure();
82   };
83 
84   return parseCommaSeparatedList(parseElt);
85 }
86 
87 /// Parse a parenthesized list of types.
88 ///
89 ///   type-list-parens ::= `(` `)`
90 ///                      | `(` type-list-no-parens `)`
91 ///
parseTypeListParens(SmallVectorImpl<Type> & elements)92 ParseResult Parser::parseTypeListParens(SmallVectorImpl<Type> &elements) {
93   if (parseToken(Token::l_paren, "expected '('"))
94     return failure();
95 
96   // Handle empty lists.
97   if (getToken().is(Token::r_paren))
98     return consumeToken(), success();
99 
100   if (parseTypeListNoParens(elements) ||
101       parseToken(Token::r_paren, "expected ')'"))
102     return failure();
103   return success();
104 }
105 
106 /// Parse a complex type.
107 ///
108 ///   complex-type ::= `complex` `<` type `>`
109 ///
parseComplexType()110 Type Parser::parseComplexType() {
111   consumeToken(Token::kw_complex);
112 
113   // Parse the '<'.
114   if (parseToken(Token::less, "expected '<' in complex type"))
115     return nullptr;
116 
117   llvm::SMLoc elementTypeLoc = getToken().getLoc();
118   auto elementType = parseType();
119   if (!elementType ||
120       parseToken(Token::greater, "expected '>' in complex type"))
121     return nullptr;
122   if (!elementType.isa<FloatType>() && !elementType.isa<IntegerType>())
123     return emitError(elementTypeLoc, "invalid element type for complex"),
124            nullptr;
125 
126   return ComplexType::get(elementType);
127 }
128 
129 /// Parse a function type.
130 ///
131 ///   function-type ::= type-list-parens `->` function-result-type
132 ///
parseFunctionType()133 Type Parser::parseFunctionType() {
134   assert(getToken().is(Token::l_paren));
135 
136   SmallVector<Type, 4> arguments, results;
137   if (parseTypeListParens(arguments) ||
138       parseToken(Token::arrow, "expected '->' in function type") ||
139       parseFunctionResultTypes(results))
140     return nullptr;
141 
142   return builder.getFunctionType(arguments, results);
143 }
144 
145 /// Parse the offset and strides from a strided layout specification.
146 ///
147 ///   strided-layout ::= `offset:` dimension `,` `strides: ` stride-list
148 ///
parseStridedLayout(int64_t & offset,SmallVectorImpl<int64_t> & strides)149 ParseResult Parser::parseStridedLayout(int64_t &offset,
150                                        SmallVectorImpl<int64_t> &strides) {
151   // Parse offset.
152   consumeToken(Token::kw_offset);
153   if (!consumeIf(Token::colon))
154     return emitError("expected colon after `offset` keyword");
155   auto maybeOffset = getToken().getUnsignedIntegerValue();
156   bool question = getToken().is(Token::question);
157   if (!maybeOffset && !question)
158     return emitError("invalid offset");
159   offset = maybeOffset ? static_cast<int64_t>(maybeOffset.getValue())
160                        : MemRefType::getDynamicStrideOrOffset();
161   consumeToken();
162 
163   if (!consumeIf(Token::comma))
164     return emitError("expected comma after offset value");
165 
166   // Parse stride list.
167   if (!consumeIf(Token::kw_strides))
168     return emitError("expected `strides` keyword after offset specification");
169   if (!consumeIf(Token::colon))
170     return emitError("expected colon after `strides` keyword");
171   if (failed(parseStrideList(strides)))
172     return emitError("invalid braces-enclosed stride list");
173   if (llvm::any_of(strides, [](int64_t st) { return st == 0; }))
174     return emitError("invalid memref stride");
175 
176   return success();
177 }
178 
179 /// Parse a memref type.
180 ///
181 ///   memref-type ::= ranked-memref-type | unranked-memref-type
182 ///
183 ///   ranked-memref-type ::= `memref` `<` dimension-list-ranked type
184 ///                          (`,` semi-affine-map-composition)? (`,`
185 ///                          memory-space)? `>`
186 ///
187 ///   unranked-memref-type ::= `memref` `<*x` type (`,` memory-space)? `>`
188 ///
189 ///   semi-affine-map-composition ::= (semi-affine-map `,` )* semi-affine-map
190 ///   memory-space ::= integer-literal /* | TODO: address-space-id */
191 ///
parseMemRefType()192 Type Parser::parseMemRefType() {
193   consumeToken(Token::kw_memref);
194 
195   if (parseToken(Token::less, "expected '<' in memref type"))
196     return nullptr;
197 
198   bool isUnranked;
199   SmallVector<int64_t, 4> dimensions;
200 
201   if (consumeIf(Token::star)) {
202     // This is an unranked memref type.
203     isUnranked = true;
204     if (parseXInDimensionList())
205       return nullptr;
206 
207   } else {
208     isUnranked = false;
209     if (parseDimensionListRanked(dimensions))
210       return nullptr;
211   }
212 
213   // Parse the element type.
214   auto typeLoc = getToken().getLoc();
215   auto elementType = parseType();
216   if (!elementType)
217     return nullptr;
218 
219   // Check that memref is formed from allowed types.
220   if (!elementType.isIntOrIndexOrFloat() &&
221       !elementType.isa<VectorType, ComplexType>())
222     return emitError(typeLoc, "invalid memref element type"), nullptr;
223 
224   // Parse semi-affine-map-composition.
225   SmallVector<AffineMap, 2> affineMapComposition;
226   Optional<unsigned> memorySpace;
227   unsigned numDims = dimensions.size();
228 
229   auto parseElt = [&]() -> ParseResult {
230     // Check for the memory space.
231     if (getToken().is(Token::integer)) {
232       if (memorySpace)
233         return emitError("multiple memory spaces specified in memref type");
234       memorySpace = getToken().getUnsignedIntegerValue();
235       if (!memorySpace.hasValue())
236         return emitError("invalid memory space in memref type");
237       consumeToken(Token::integer);
238       return success();
239     }
240     if (isUnranked)
241       return emitError("cannot have affine map for unranked memref type");
242     if (memorySpace)
243       return emitError("expected memory space to be last in memref type");
244 
245     AffineMap map;
246     llvm::SMLoc mapLoc = getToken().getLoc();
247     if (getToken().is(Token::kw_offset)) {
248       int64_t offset;
249       SmallVector<int64_t, 4> strides;
250       if (failed(parseStridedLayout(offset, strides)))
251         return failure();
252       // Construct strided affine map.
253       map = makeStridedLinearLayoutMap(strides, offset, state.context);
254     } else {
255       // Parse an affine map attribute.
256       auto affineMap = parseAttribute();
257       if (!affineMap)
258         return failure();
259       auto affineMapAttr = affineMap.dyn_cast<AffineMapAttr>();
260       if (!affineMapAttr)
261         return emitError("expected affine map in memref type");
262       map = affineMapAttr.getValue();
263     }
264 
265     if (map.getNumDims() != numDims) {
266       size_t i = affineMapComposition.size();
267       return emitError(mapLoc, "memref affine map dimension mismatch between ")
268              << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
269              << " and affine map" << i + 1 << ": " << numDims
270              << " != " << map.getNumDims();
271     }
272     numDims = map.getNumResults();
273     affineMapComposition.push_back(map);
274     return success();
275   };
276 
277   // Parse a list of mappings and address space if present.
278   if (!consumeIf(Token::greater)) {
279     // Parse comma separated list of affine maps, followed by memory space.
280     if (parseToken(Token::comma, "expected ',' or '>' in memref type") ||
281         parseCommaSeparatedListUntil(Token::greater, parseElt,
282                                      /*allowEmptyList=*/false)) {
283       return nullptr;
284     }
285   }
286 
287   if (isUnranked)
288     return UnrankedMemRefType::get(elementType, memorySpace.getValueOr(0));
289 
290   return MemRefType::get(dimensions, elementType, affineMapComposition,
291                          memorySpace.getValueOr(0));
292 }
293 
294 /// Parse any type except the function type.
295 ///
296 ///   non-function-type ::= integer-type
297 ///                       | index-type
298 ///                       | float-type
299 ///                       | extended-type
300 ///                       | vector-type
301 ///                       | tensor-type
302 ///                       | memref-type
303 ///                       | complex-type
304 ///                       | tuple-type
305 ///                       | none-type
306 ///
307 ///   index-type ::= `index`
308 ///   float-type ::= `f16` | `bf16` | `f32` | `f64`
309 ///   none-type ::= `none`
310 ///
parseNonFunctionType()311 Type Parser::parseNonFunctionType() {
312   switch (getToken().getKind()) {
313   default:
314     return (emitError("expected non-function type"), nullptr);
315   case Token::kw_memref:
316     return parseMemRefType();
317   case Token::kw_tensor:
318     return parseTensorType();
319   case Token::kw_complex:
320     return parseComplexType();
321   case Token::kw_tuple:
322     return parseTupleType();
323   case Token::kw_vector:
324     return parseVectorType();
325   // integer-type
326   case Token::inttype: {
327     auto width = getToken().getIntTypeBitwidth();
328     if (!width.hasValue())
329       return (emitError("invalid integer width"), nullptr);
330     if (width.getValue() > IntegerType::kMaxWidth) {
331       emitError(getToken().getLoc(), "integer bitwidth is limited to ")
332           << IntegerType::kMaxWidth << " bits";
333       return nullptr;
334     }
335 
336     IntegerType::SignednessSemantics signSemantics = IntegerType::Signless;
337     if (Optional<bool> signedness = getToken().getIntTypeSignedness())
338       signSemantics = *signedness ? IntegerType::Signed : IntegerType::Unsigned;
339 
340     consumeToken(Token::inttype);
341     return IntegerType::get(width.getValue(), signSemantics, getContext());
342   }
343 
344   // float-type
345   case Token::kw_bf16:
346     consumeToken(Token::kw_bf16);
347     return builder.getBF16Type();
348   case Token::kw_f16:
349     consumeToken(Token::kw_f16);
350     return builder.getF16Type();
351   case Token::kw_f32:
352     consumeToken(Token::kw_f32);
353     return builder.getF32Type();
354   case Token::kw_f64:
355     consumeToken(Token::kw_f64);
356     return builder.getF64Type();
357 
358   // index-type
359   case Token::kw_index:
360     consumeToken(Token::kw_index);
361     return builder.getIndexType();
362 
363   // none-type
364   case Token::kw_none:
365     consumeToken(Token::kw_none);
366     return builder.getNoneType();
367 
368   // extended type
369   case Token::exclamation_identifier:
370     return parseExtendedType();
371   }
372 }
373 
374 /// Parse a tensor type.
375 ///
376 ///   tensor-type ::= `tensor` `<` dimension-list type `>`
377 ///   dimension-list ::= dimension-list-ranked | `*x`
378 ///
parseTensorType()379 Type Parser::parseTensorType() {
380   consumeToken(Token::kw_tensor);
381 
382   if (parseToken(Token::less, "expected '<' in tensor type"))
383     return nullptr;
384 
385   bool isUnranked;
386   SmallVector<int64_t, 4> dimensions;
387 
388   if (consumeIf(Token::star)) {
389     // This is an unranked tensor type.
390     isUnranked = true;
391 
392     if (parseXInDimensionList())
393       return nullptr;
394 
395   } else {
396     isUnranked = false;
397     if (parseDimensionListRanked(dimensions))
398       return nullptr;
399   }
400 
401   // Parse the element type.
402   auto elementTypeLoc = getToken().getLoc();
403   auto elementType = parseType();
404   if (!elementType || parseToken(Token::greater, "expected '>' in tensor type"))
405     return nullptr;
406   if (!TensorType::isValidElementType(elementType))
407     return emitError(elementTypeLoc, "invalid tensor element type"), nullptr;
408 
409   if (isUnranked)
410     return UnrankedTensorType::get(elementType);
411   return RankedTensorType::get(dimensions, elementType);
412 }
413 
414 /// Parse a tuple type.
415 ///
416 ///   tuple-type ::= `tuple` `<` (type (`,` type)*)? `>`
417 ///
parseTupleType()418 Type Parser::parseTupleType() {
419   consumeToken(Token::kw_tuple);
420 
421   // Parse the '<'.
422   if (parseToken(Token::less, "expected '<' in tuple type"))
423     return nullptr;
424 
425   // Check for an empty tuple by directly parsing '>'.
426   if (consumeIf(Token::greater))
427     return TupleType::get(getContext());
428 
429   // Parse the element types and the '>'.
430   SmallVector<Type, 4> types;
431   if (parseTypeListNoParens(types) ||
432       parseToken(Token::greater, "expected '>' in tuple type"))
433     return nullptr;
434 
435   return TupleType::get(types, getContext());
436 }
437 
438 /// Parse a vector type.
439 ///
440 ///   vector-type ::= `vector` `<` non-empty-static-dimension-list type `>`
441 ///   non-empty-static-dimension-list ::= decimal-literal `x`
442 ///                                       static-dimension-list
443 ///   static-dimension-list ::= (decimal-literal `x`)*
444 ///
parseVectorType()445 VectorType Parser::parseVectorType() {
446   consumeToken(Token::kw_vector);
447 
448   if (parseToken(Token::less, "expected '<' in vector type"))
449     return nullptr;
450 
451   SmallVector<int64_t, 4> dimensions;
452   if (parseDimensionListRanked(dimensions, /*allowDynamic=*/false))
453     return nullptr;
454   if (dimensions.empty())
455     return (emitError("expected dimension size in vector type"), nullptr);
456   if (any_of(dimensions, [](int64_t i) { return i <= 0; }))
457     return emitError(getToken().getLoc(),
458                      "vector types must have positive constant sizes"),
459            nullptr;
460 
461   // Parse the element type.
462   auto typeLoc = getToken().getLoc();
463   auto elementType = parseType();
464   if (!elementType || parseToken(Token::greater, "expected '>' in vector type"))
465     return nullptr;
466   if (!VectorType::isValidElementType(elementType))
467     return emitError(typeLoc, "vector elements must be int or float type"),
468            nullptr;
469 
470   return VectorType::get(dimensions, elementType);
471 }
472 
/// Parse a dimension list of a tensor or memref type.  This populates the
/// dimension list, using -1 for the `?` dimensions if `allowDynamic` is set and
/// errors out on `?` otherwise.
///
///   dimension-list-ranked ::= (dimension `x`)*
///   dimension ::= `?` | decimal-literal
///
/// When `allowDynamic` is not set, this is used to parse:
///
///   static-dimension-list ::= (decimal-literal `x`)*
///
/// Dimensions are appended to `dimensions`. The loop stops, successfully, at
/// the first token that is neither an integer nor `?`, so an empty list is
/// accepted. Every dimension must be followed by an `x`. That `x` may be
/// fused with the following token (e.g. `xf32`); parseXInDimensionList
/// splits it off.
ParseResult
Parser::parseDimensionListRanked(SmallVectorImpl<int64_t> &dimensions,
                                 bool allowDynamic) {
  // Keep going while the current token can start a dimension.
  while (getToken().isAny(Token::integer, Token::question)) {
    if (consumeIf(Token::question)) {
      if (!allowDynamic)
        return emitError("expected static shape");
      dimensions.push_back(-1);
    } else {
      // Hexadecimal integer literals (starting with `0x`) are not allowed in
      // aggregate type declarations.  Therefore, `0xf32` should be processed as
      // a sequence of separate elements `0`, `x`, `f32`.
      if (getTokenSpelling().size() > 1 && getTokenSpelling()[1] == 'x') {
        // We can get here only if the token is an integer literal.  Hexadecimal
        // integer literals can only start with `0x` (`1x` wouldn't lex as a
        // literal, just `1` would, at which point we don't get into this
        // branch).
        assert(getTokenSpelling()[0] == '0' && "invalid integer literal");
        // Record the leading `0`, then rewind the lexer to just after it so
        // the `x...` remainder is lexed as the next token.
        dimensions.push_back(0);
        state.lex.resetPointer(getTokenSpelling().data() + 1);
        consumeToken();
      } else {
        // Make sure this integer value is in bound and valid.
        auto dimension = getToken().getUnsignedIntegerValue();
        if (!dimension.hasValue())
          return emitError("invalid dimension");
        dimensions.push_back((int64_t)dimension.getValue());
        consumeToken(Token::integer);
      }
    }

    // Make sure we have an 'x' or something like 'xbf32'.
    if (parseXInDimensionList())
      return failure();
  }

  return success();
}
521 
522 /// Parse an 'x' token in a dimension list, handling the case where the x is
523 /// juxtaposed with an element type, as in "xf32", leaving the "f32" as the next
524 /// token.
parseXInDimensionList()525 ParseResult Parser::parseXInDimensionList() {
526   if (getToken().isNot(Token::bare_identifier) || getTokenSpelling()[0] != 'x')
527     return emitError("expected 'x' in dimension list");
528 
529   // If we had a prefix of 'x', lex the next token immediately after the 'x'.
530   if (getTokenSpelling().size() != 1)
531     state.lex.resetPointer(getTokenSpelling().data() + 1);
532 
533   // Consume the 'x'.
534   consumeToken(Token::bare_identifier);
535 
536   return success();
537 }
538 
539 // Parse a comma-separated list of dimensions, possibly empty:
540 //   stride-list ::= `[` (dimension (`,` dimension)*)? `]`
parseStrideList(SmallVectorImpl<int64_t> & dimensions)541 ParseResult Parser::parseStrideList(SmallVectorImpl<int64_t> &dimensions) {
542   if (!consumeIf(Token::l_square))
543     return failure();
544   // Empty list early exit.
545   if (consumeIf(Token::r_square))
546     return success();
547   while (true) {
548     if (consumeIf(Token::question)) {
549       dimensions.push_back(MemRefType::getDynamicStrideOrOffset());
550     } else {
551       // This must be an integer value.
552       int64_t val;
553       if (getToken().getSpelling().getAsInteger(10, val))
554         return emitError("invalid integer value: ") << getToken().getSpelling();
555       // Make sure it is not the one value for `?`.
556       if (ShapedType::isDynamic(val))
557         return emitError("invalid integer value: ")
558                << getToken().getSpelling()
559                << ", use `?` to specify a dynamic dimension";
560       dimensions.push_back(val);
561       consumeToken(Token::integer);
562     }
563     if (!consumeIf(Token::comma))
564       break;
565   }
566   if (!consumeIf(Token::r_square))
567     return failure();
568   return success();
569 }
570