1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_lexer.h"
17 
18 #include <unordered_map>
19 
20 #include "absl/strings/ascii.h"
21 #include "absl/strings/escaping.h"
22 #include "absl/strings/numbers.h"
23 #include "absl/strings/str_split.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/compiler/xla/shape_util.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/util.h"
28 #include "tensorflow/core/lib/strings/numbers.h"
29 #include "tensorflow/core/platform/regexp.h"
30 
31 namespace xla {
32 namespace {
33 
34 using absl::string_view;
35 
// Sentinel returned by PeekCurrentChar()/GetNextChar() when the read position
// has reached the end of the input buffer. Real characters are returned as
// unsigned char values, so the negative sentinels never collide with them.
constexpr int kEOF = -1;
// Sentinel returned by PeekCurrentChar()/GetNextChar() when the character at
// the read position is '\0', which is treated as invalid input.
constexpr int kError = -2;
38 
39 // [a-zA-Z0-9_.-]
IsIdentifierChar(char c)40 bool IsIdentifierChar(char c) {
41   return absl::ascii_isalnum(static_cast<unsigned char>(c)) || c == '-' ||
42          c == '.' || c == '_';
43 }
44 
45 }  // namespace
46 
GetNextChar()47 int HloLexer::GetNextChar() {
48   int current_char = PeekCurrentChar();
49   if (current_char != kEOF && current_char != kError) {
50     current_ptr_++;
51   }
52   return current_char;
53 }
54 
PeekCurrentChar() const55 int HloLexer::PeekCurrentChar() const {
56   if (current_ptr_ == buf_.end()) {
57     return kEOF;
58   }
59   char current_char = *current_ptr_;
60   if (current_char == 0) {
61     // '\0' should not appear in the middle of the string.
62     return kError;
63   }
64   return static_cast<unsigned char>(current_char);
65 }
66 
CanDereference(const char * ptr) const67 bool HloLexer::CanDereference(const char* ptr) const {
68   return ptr < buf_.end() && ptr >= buf_.begin();
69 }
70 
// Returns a view over the half-open range [begin, end) of the input buffer.
// CHECK-fails if `begin` is after `end`, or if either pointer is neither a
// dereferenceable position in the buffer nor buf_.end().
absl::string_view HloLexer::StringPieceFromPointers(const char* begin,
                                                    const char* end) const {
  CHECK(begin <= end);
  // Each endpoint may be the one-past-the-end position or any position inside
  // the buffer.
  CHECK(begin == buf_.end() || CanDereference(begin));
  CHECK(end == buf_.end() || CanDereference(end));
  return absl::string_view(begin, end - begin);
}
78 
// Same as StringPieceFromPointers(), but returns the string piece type
// expected by the regexp library. CHECK-fails if `begin` is after `end`, or if
// either pointer is neither a dereferenceable position in the buffer nor
// buf_.end().
tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers(
    const char* begin, const char* end) const {
  CHECK(begin <= end);
  // Each endpoint may be the one-past-the-end position or any position inside
  // the buffer.
  CHECK(begin == buf_.end() || CanDereference(begin));
  CHECK(end == buf_.end() || CanDereference(end));
  return tensorflow::RegexpStringPiece(begin, end - begin);
}
86 
LookAhead()87 TokKind HloLexer::LookAhead() {
88   if (GetKind() == TokKind::kEof || GetKind() == TokKind::kError) {
89     return GetKind();
90   }
91 
92   const char* old_current_ptr = current_ptr_;
93   TokenState old_token_state = token_state_;
94   Lex();
95   TokKind kind = GetKind();
96   token_state_ = old_token_state;
97   current_ptr_ = old_current_ptr;
98   return kind;
99 }
100 
LexToken()101 TokKind HloLexer::LexToken() {
102   while (true) {
103     token_state_.token_start = current_ptr_;
104 
105     int current_char = GetNextChar();
106     switch (current_char) {
107       default:
108         // [a-zA-Z_]
109         if (absl::ascii_isalpha(static_cast<unsigned char>(current_char)) ||
110             current_char == '_') {
111           return LexIdentifier();
112         }
113         return TokKind::kError;
114       case kEOF:
115         // Hit the end of the input buffer.
116         return TokKind::kEof;
117       case kError:
118         // Hit an invalid character in the input buffer.
119         return TokKind::kError;
120       case ' ':
121       case '\t':
122       case '\n':
123       case '\r':
124         // Ignore whitespace.
125         continue;
126       case '0':
127       case '1':
128       case '2':
129       case '3':
130       case '4':
131       case '5':
132       case '6':
133       case '7':
134       case '8':
135       case '9':
136       case '-':
137         if (current_char == '-' && PeekCurrentChar() == '>') {
138           current_ptr_++;
139           return TokKind::kArrow;
140         }
141         return LexNumberOrPattern();
142       case '=':
143         return TokKind::kEqual;
144       case '<':
145         if (current_char == '<' && PeekCurrentChar() == '=') {
146           current_ptr_++;
147           return TokKind::kLeq;
148         }
149         return TokKind::kError;
150       case ',':
151         return TokKind::kComma;
152       case '%':
153         return LexPercent();
154       case ':':
155         return TokKind::kColon;
156       case '*':
157         return TokKind::kAsterisk;
158       case '[':
159         return TokKind::kLsquare;
160       case ']':
161         return TokKind::kRsquare;
162       case '{':
163         return TokKind::kLbrace;
164       case '}':
165         return TokKind::kRbrace;
166       case '(':
167         return TokKind::kLparen;
168       case ')':
169         return TokKind::kRparen;
170       case '/': {
171         if (PeekCurrentChar() == '*') {
172           // This is the start of a /*...*/ delimited comment. Save the current
173           // location in case the comment is unterminated so the error message
174           // will point to the beginning of the comment.
175           const char* comment_start = current_ptr_;
176           current_ptr_++;
177           // Advance until '*/' is found.
178           while (true) {
179             int current = GetNextChar();
180             if (current == '*' && PeekCurrentChar() == '/') {
181               // End of comment.
182               current_ptr_++;
183               break;
184             }
185             if (current == kEOF) {
186               // Unterminated comment.
187               current_ptr_ = comment_start;
188               return TokKind::kError;
189             }
190             if (current == kError) {
191               return TokKind::kError;
192             }
193           }
194           // Return no token for the comment. Keep lexing.
195           continue;
196         } else if (PeekCurrentChar() == '/') {
197           // This is the start of a '//' delimited comment. Throw away
198           // everything until end of line or file. The end-of-line character(s)
199           // are left unlexed in the buffer which is harmless because these are
200           // skipped later by the lexer. This approach enables support for
201           // different end-of-line encodings.
202           while (true) {
203             int current = PeekCurrentChar();
204             if (current == kEOF || current == '\n' || current == '\r') {
205               break;
206             }
207             if (current == kError) {
208               return TokKind::kError;
209             }
210             current_ptr_++;
211           }
212           continue;
213         }
214         // A lone '/' is an error.
215         return TokKind::kError;
216       }
217       case '.':
218         if (PeekCurrentChar() == '.') {
219           current_ptr_++;
220           if (PeekCurrentChar() == '.') {
221             current_ptr_++;
222             return TokKind::kDots;
223           }
224         }
225         return TokKind::kError;
226       case '"':
227         return LexString();
228     }
229   }
230 }
231 
232 // Lex a shape, name, keyword, attribute name, the dim labels pattern, and
233 // other identifiers.
234 //
235 // shape    ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})?
236 // name     ::= [a-zA-Z_][a-zA-Z0-9_.-]*:
237 // keyword  ::= HloModule, ENTRY, ...
238 // attribute_name ::= condition, body, dimensions, ...
239 // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
240 // identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]*
LexIdentifier()241 TokKind HloLexer::LexIdentifier() {
242   while (IsIdentifierChar(PeekCurrentChar())) {
243     current_ptr_++;
244   }
245 
246   // If followed by ':', it's a name.
247   if (PeekCurrentChar() == ':') {
248     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
249     current_ptr_++;  // skip ':'
250     return TokKind::kName;
251   }
252 
253   // If followed by '=', it's a attribute name.
254   if (PeekCurrentChar() == '=') {
255     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
256     current_ptr_++;  // skip '='
257     return TokKind::kAttributeName;
258   }
259 
260   absl::string_view identifier =
261       StringPieceFromPointers(token_state_.token_start, current_ptr_);
262 
263   // Primitive type strings are reserved words. The exception is 'tuple' whose
264   // type is represented using nested parentheses without the string 'tuple'.
265   if (primitive_util::IsPrimitiveTypeName(identifier)) {
266     PrimitiveType primitive_type =
267         primitive_util::StringToPrimitiveType(identifier).ValueOrDie();
268     if (primitive_type != TUPLE) {
269       token_state_.primitive_type_val = primitive_type;
270       return TokKind::kPrimitiveType;
271     }
272   }
273 
274   // See if this is a keyword.
275 #define KEYWORD(STR)            \
276   do {                          \
277     if (identifier == #STR) {   \
278       return TokKind::kw_##STR; \
279     }                           \
280   } while (false)
281 
282   KEYWORD(true);
283   KEYWORD(false);
284   KEYWORD(inf);
285   KEYWORD(nan);
286   KEYWORD(HloModule);
287   KEYWORD(ENTRY);
288   KEYWORD(ROOT);
289   KEYWORD(maximal);
290   KEYWORD(replicated);
291   KEYWORD(sparse);
292 
293 #undef KEYWORD
294 
295   {
296     auto consumable =
297         RegexpStringPieceFromPointers(token_state_.token_start, buf_.end());
298     static LazyRE2 dim_labels_pattern = {
299         R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"};
300     if (RE2::Consume(&consumable, *dim_labels_pattern)) {
301       current_ptr_ = consumable.begin();
302       token_state_.str_val.assign(token_state_.token_start, current_ptr_);
303       return TokKind::kDimLabels;
304     }
305   }
306 
307   token_state_.str_val = string(identifier);
308   return TokKind::kIdent;
309 }
310 
311 // Lex names after a % character.
312 // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*
LexPercent()313 TokKind HloLexer::LexPercent() {
314   const char* name_start = current_ptr_;
315   if (absl::ascii_isalpha(static_cast<unsigned char>(PeekCurrentChar())) ||
316       PeekCurrentChar() == '_') {
317     current_ptr_++;
318     while (IsIdentifierChar(PeekCurrentChar())) {
319       current_ptr_++;
320     }
321     token_state_.str_val.assign(name_start, current_ptr_);
322     return TokKind::kName;
323   }
324   return TokKind::kError;
325 }
326 
327 // Lex integer and floating-point values, -inf, and patterns for dim labels,
328 // dxd (e.g. 1x2x3), and pad.
329 //
330 // fp with exp ::= [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+)
331 // fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+)
332 // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,}
333 // dxd_pattern ::= [0-9]+(x[0-9]+)+
334 // pad_pattern ::=
335 //   [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*
336 // int ::=  [-]?[0-9]+
337 // negative inf ::= '-inf'
LexNumberOrPattern()338 TokKind HloLexer::LexNumberOrPattern() {
339   auto consumable =
340       RegexpStringPieceFromPointers(token_state_.token_start, buf_.end());
341   static LazyRE2 float_pattern = {
342       R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|[-]?(\d+[.]\d*|\d*[.]\d+))"};
343   if (RE2::Consume(&consumable, *float_pattern)) {
344     current_ptr_ = consumable.begin();
345     CHECK(absl::SimpleAtod(string(token_state_.token_start, current_ptr_),
346                            &token_state_.decimal_val));
347     return TokKind::kDecimal;
348   }
349 
350   static LazyRE2 dim_labels_pattern = {
351       R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"};
352   static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"};
353   static LazyRE2 pad_pattern = {
354       R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"};
355 
356   if (RE2::Consume(&consumable, *dim_labels_pattern)) {
357     current_ptr_ = consumable.begin();
358     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
359     return TokKind::kDimLabels;
360   }
361 
362   if (RE2::Consume(&consumable, *dxd_pattern)) {
363     current_ptr_ = consumable.begin();
364     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
365     return TokKind::kDxD;
366   }
367 
368   if (RE2::Consume(&consumable, *pad_pattern)) {
369     current_ptr_ = consumable.begin();
370     token_state_.str_val.assign(token_state_.token_start, current_ptr_);
371     return TokKind::kPad;
372   }
373 
374   static LazyRE2 int_pattern = {R"([-]?\d+)"};
375   if (RE2::Consume(&consumable, *int_pattern)) {
376     current_ptr_ = consumable.begin();
377     auto slice =
378         StringPieceFromPointers(token_state_.token_start, current_ptr_);
379     if (absl::SimpleAtoi(slice, &token_state_.int64_val)) {
380       return TokKind::kInt;
381     }
382     LOG(ERROR) << "Failed to parse int literal: " << slice;
383     return TokKind::kError;
384   }
385 
386   static LazyRE2 neg_inf = {"-inf"};
387   if (RE2::Consume(&consumable, *neg_inf)) {
388     current_ptr_ = consumable.begin();
389     return TokKind::kNegInf;
390   }
391 
392   return TokKind::kError;
393 }
394 
GetLineAndColumn(LocTy location) const395 std::pair<unsigned, unsigned> HloLexer::GetLineAndColumn(LocTy location) const {
396   unsigned line_no = 1;
397   const char* start = buf_.begin();
398   const char* ptr = start;
399   if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) &&
400       line_no_cache_.last_query <= location) {
401     ptr = line_no_cache_.last_query;
402     line_no = line_no_cache_.line_no_of_query;
403   }
404   for (; ptr != location; ptr++) {
405     CHECK_LT(ptr, buf_.end());
406     if (*ptr == '\n') {
407       line_no++;
408     }
409   }
410 
411   // Update the line number cache.
412   line_no_cache_.last_query = ptr;
413   line_no_cache_.line_no_of_query = line_no;
414   size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n');
415   if (line_offset == absl::string_view::npos) {
416     line_offset = 0;
417   }
418   return {line_no, ptr - start - line_offset};
419 }
420 
GetLine(LocTy loc) const421 absl::string_view HloLexer::GetLine(LocTy loc) const {
422   if (!CanDereference(loc)) {
423     return "LINE OUT OF RANGE";
424   }
425   size_t line_start =
426       StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n');
427   const char* start = line_start == absl::string_view::npos
428                           ? buf_.begin()
429                           : buf_.begin() + line_start + 1;
430   size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n');
431   const char* end =
432       line_end == absl::string_view::npos ? buf_.end() : loc + line_end;
433 
434   return StringPieceFromPointers(start, end);
435 }
436 
437 // Lexes quoted string with escaping characters. If matched, the quoted string
438 // will be unescaped and stored to token_state_.str_val.
LexString()439 TokKind HloLexer::LexString() {
440   auto consumable =
441       RegexpStringPieceFromPointers(token_state_.token_start, buf_.end());
442   static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"};
443   if (RE2::Consume(&consumable, *escaping_pattern)) {
444     current_ptr_ = consumable.begin();
445     absl::string_view raw =
446         StringPieceFromPointers(token_state_.token_start + 1, current_ptr_ - 1);
447     string error;
448     if (!absl::CUnescape(raw, &token_state_.str_val, &error)) {
449       LOG(ERROR) << "Failed unescaping string: " << raw << ". error: " << error;
450       return TokKind::kError;
451     }
452     return TokKind::kString;
453   }
454   return TokKind::kError;
455 }
456 
457 string TokKindToString(TokKind kind) {
458   switch (kind) {
459     case TokKind::kEof:
460       return "kEof";
461     case TokKind::kError:
462       return "kError";
463     case TokKind::kEqual:
464       return "kEqaul";
465     case TokKind::kComma:
466       return "kComma";
467     case TokKind::kColon:
468       return "kColon";
469     case TokKind::kAsterisk:
470       return "kAsterisk";
471     case TokKind::kLsquare:
472       return "kLsquare";
473     case TokKind::kRsquare:
474       return "kRsquare";
475     case TokKind::kLbrace:
476       return "kLbrace";
477     case TokKind::kRbrace:
478       return "kRbrace";
479     case TokKind::kLparen:
480       return "kLparen";
481     case TokKind::kRparen:
482       return "kRparen";
483     case TokKind::kArrow:
484       return "kArrow";
485     case TokKind::kLeq:
486       return "kLeq";
487     case TokKind::kw_HloModule:
488       return "kw_HloModule";
489     case TokKind::kw_ENTRY:
490       return "kw_ENTRY";
491     case TokKind::kw_ROOT:
492       return "kw_ROOT";
493     case TokKind::kw_true:
494       return "kw_true";
495     case TokKind::kw_false:
496       return "kw_false";
497     case TokKind::kw_maximal:
498       return "kw_maximal";
499     case TokKind::kw_replicated:
500       return "kw_replicated";
501     case TokKind::kw_nan:
502       return "kw_nan";
503     case TokKind::kw_inf:
504       return "kw_inf";
505     case TokKind::kNegInf:
506       return "kNegInf";
507     case TokKind::kw_sparse:
508       return "kw_sparse";
509     case TokKind::kPrimitiveType:
510       return "kPrimitiveType";
511     case TokKind::kName:
512       return "kName";
513     case TokKind::kAttributeName:
514       return "kAttributeName";
515     case TokKind::kDimLabels:
516       return "kDimLabels";
517     case TokKind::kDxD:
518       return "kDxD";
519     case TokKind::kPad:
520       return "kPad";
521     case TokKind::kIdent:
522       return "kIdent";
523     case TokKind::kString:
524       return "kString";
525     case TokKind::kInt:
526       return "kInt";
527     case TokKind::kDecimal:
528       return "kDecimal";
529     case TokKind::kDots:
530       return "kDots";
531   }
532 }
533 
534 }  // namespace xla
535