1 //===- ByteCode.cpp - Pattern ByteCode Interpreter ------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements MLIR to byte-code generation and the interpreter.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "ByteCode.h"
14 #include "mlir/Analysis/Liveness.h"
15 #include "mlir/Dialect/PDL/IR/PDLTypes.h"
16 #include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/RegionGraphTraits.h"
19 #include "llvm/ADT/IntervalMap.h"
20 #include "llvm/ADT/PostOrderIterator.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23 
24 #define DEBUG_TYPE "pdl-bytecode"
25 
26 using namespace mlir;
27 using namespace mlir::detail;
28 
29 //===----------------------------------------------------------------------===//
30 // PDLByteCodePattern
31 //===----------------------------------------------------------------------===//
32 
create(pdl_interp::RecordMatchOp matchOp,ByteCodeAddr rewriterAddr)33 PDLByteCodePattern PDLByteCodePattern::create(pdl_interp::RecordMatchOp matchOp,
34                                               ByteCodeAddr rewriterAddr) {
35   SmallVector<StringRef, 8> generatedOps;
36   if (ArrayAttr generatedOpsAttr = matchOp.generatedOpsAttr())
37     generatedOps =
38         llvm::to_vector<8>(generatedOpsAttr.getAsValueRange<StringAttr>());
39 
40   PatternBenefit benefit = matchOp.benefit();
41   MLIRContext *ctx = matchOp.getContext();
42 
43   // Check to see if this is pattern matches a specific operation type.
44   if (Optional<StringRef> rootKind = matchOp.rootKind())
45     return PDLByteCodePattern(rewriterAddr, *rootKind, generatedOps, benefit,
46                               ctx);
47   return PDLByteCodePattern(rewriterAddr, generatedOps, benefit, ctx,
48                             MatchAnyOpTypeTag());
49 }
50 
51 //===----------------------------------------------------------------------===//
52 // PDLByteCodeMutableState
53 //===----------------------------------------------------------------------===//
54 
55 /// Set the new benefit for a bytecode pattern. The `patternIndex` corresponds
56 /// to the position of the pattern within the range returned by
57 /// `PDLByteCode::getPatterns`.
updatePatternBenefit(unsigned patternIndex,PatternBenefit benefit)58 void PDLByteCodeMutableState::updatePatternBenefit(unsigned patternIndex,
59                                                    PatternBenefit benefit) {
60   currentPatternBenefits[patternIndex] = benefit;
61 }
62 
63 //===----------------------------------------------------------------------===//
64 // Bytecode OpCodes
65 //===----------------------------------------------------------------------===//
66 
67 namespace {
/// The opcodes that may appear in the generated bytecode. Each opcode is
/// stored in the bytecode stream as a single `ByteCodeField`.
enum OpCode : ByteCodeField {
  /// Apply an externally registered constraint.
  ApplyConstraint,
  /// Apply an externally registered rewrite.
  ApplyRewrite,
  /// Check if two generic values are equal. The generator also emits this
  /// opcode for the attribute and type checks (`CheckAttributeOp` and
  /// `CheckTypeOp`).
  AreEqual,
  /// Unconditional branch.
  Branch,
  /// Compare the operand count of an operation with a constant.
  CheckOperandCount,
  /// Compare the name of an operation with a constant.
  CheckOperationName,
  /// Compare the result count of an operation with a constant.
  CheckResultCount,
  /// Invoke a native creation method.
  CreateNative,
  /// Create an operation.
  CreateOperation,
  /// Erase an operation.
  EraseOp,
  /// Terminate a matcher or rewrite sequence.
  Finalize,
  /// Get a specific attribute of an operation.
  GetAttribute,
  /// Get the type of an attribute.
  GetAttributeType,
  /// Get the defining operation of a value.
  GetDefiningOp,
  /// Get a specific operand of an operation.
  /// NOTE: The generator selects `GetOperand0`-`GetOperand3` by adding the
  /// operand index to `GetOperand0`, so these enumerators must stay contiguous
  /// and in increasing order. `GetOperandN` is used for all other indices.
  GetOperand0,
  GetOperand1,
  GetOperand2,
  GetOperand3,
  GetOperandN,
  /// Get a specific result of an operation.
  /// NOTE: The generator selects `GetResult0`-`GetResult3` by adding the
  /// result index to `GetResult0`, so these enumerators must stay contiguous
  /// and in increasing order. `GetResultN` is used for all other indices.
  GetResult0,
  GetResult1,
  GetResult2,
  GetResult3,
  GetResultN,
  /// Get the type of a value.
  GetValueType,
  /// Check if a generic value is not null.
  IsNotNull,
  /// Record a successful pattern match.
  RecordMatch,
  /// Replace an operation.
  ReplaceOp,
  /// Compare an attribute with a set of constants.
  SwitchAttribute,
  /// Compare the operand count of an operation with a set of constants.
  SwitchOperandCount,
  /// Compare the name of an operation with a set of constants.
  SwitchOperationName,
  /// Compare the result count of an operation with a set of constants.
  SwitchResultCount,
  /// Compare a type with a set of constants.
  SwitchType,
};
128 
/// The kind of a generic PDL value. `ByteCodeWriter::appendPDLValueList`
/// writes this kind, cast to a `ByteCodeField`, ahead of each value it appends.
enum class PDLValueKind { Attribute, Operation, Type, Value };
130 } // end anonymous namespace
131 
132 //===----------------------------------------------------------------------===//
133 // ByteCode Generation
134 //===----------------------------------------------------------------------===//
135 
136 //===----------------------------------------------------------------------===//
137 // Generator
138 
139 namespace {
140 struct ByteCodeWriter;
141 
142 /// This class represents the main generator for the pattern bytecode.
143 class Generator {
144 public:
Generator(MLIRContext * ctx,std::vector<const void * > & uniquedData,SmallVectorImpl<ByteCodeField> & matcherByteCode,SmallVectorImpl<ByteCodeField> & rewriterByteCode,SmallVectorImpl<PDLByteCodePattern> & patterns,ByteCodeField & maxValueMemoryIndex,llvm::StringMap<PDLConstraintFunction> & constraintFns,llvm::StringMap<PDLCreateFunction> & createFns,llvm::StringMap<PDLRewriteFunction> & rewriteFns)145   Generator(MLIRContext *ctx, std::vector<const void *> &uniquedData,
146             SmallVectorImpl<ByteCodeField> &matcherByteCode,
147             SmallVectorImpl<ByteCodeField> &rewriterByteCode,
148             SmallVectorImpl<PDLByteCodePattern> &patterns,
149             ByteCodeField &maxValueMemoryIndex,
150             llvm::StringMap<PDLConstraintFunction> &constraintFns,
151             llvm::StringMap<PDLCreateFunction> &createFns,
152             llvm::StringMap<PDLRewriteFunction> &rewriteFns)
153       : ctx(ctx), uniquedData(uniquedData), matcherByteCode(matcherByteCode),
154         rewriterByteCode(rewriterByteCode), patterns(patterns),
155         maxValueMemoryIndex(maxValueMemoryIndex) {
156     for (auto it : llvm::enumerate(constraintFns))
157       constraintToMemIndex.try_emplace(it.value().first(), it.index());
158     for (auto it : llvm::enumerate(createFns))
159       nativeCreateToMemIndex.try_emplace(it.value().first(), it.index());
160     for (auto it : llvm::enumerate(rewriteFns))
161       externalRewriterToMemIndex.try_emplace(it.value().first(), it.index());
162   }
163 
164   /// Generate the bytecode for the given PDL interpreter module.
165   void generate(ModuleOp module);
166 
167   /// Return the memory index to use for the given value.
getMemIndex(Value value)168   ByteCodeField &getMemIndex(Value value) {
169     assert(valueToMemIndex.count(value) &&
170            "expected memory index to be assigned");
171     return valueToMemIndex[value];
172   }
173 
174   /// Return an index to use when referring to the given data that is uniqued in
175   /// the MLIR context.
176   template <typename T>
177   std::enable_if_t<!std::is_convertible<T, Value>::value, ByteCodeField &>
getMemIndex(T val)178   getMemIndex(T val) {
179     const void *opaqueVal = val.getAsOpaquePointer();
180 
181     // Get or insert a reference to this value.
182     auto it = uniquedDataToMemIndex.try_emplace(
183         opaqueVal, maxValueMemoryIndex + uniquedData.size());
184     if (it.second)
185       uniquedData.push_back(opaqueVal);
186     return it.first->second;
187   }
188 
189 private:
190   /// Allocate memory indices for the results of operations within the matcher
191   /// and rewriters.
192   void allocateMemoryIndices(FuncOp matcherFunc, ModuleOp rewriterModule);
193 
194   /// Generate the bytecode for the given operation.
195   void generate(Operation *op, ByteCodeWriter &writer);
196   void generate(pdl_interp::ApplyConstraintOp op, ByteCodeWriter &writer);
197   void generate(pdl_interp::ApplyRewriteOp op, ByteCodeWriter &writer);
198   void generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer);
199   void generate(pdl_interp::BranchOp op, ByteCodeWriter &writer);
200   void generate(pdl_interp::CheckAttributeOp op, ByteCodeWriter &writer);
201   void generate(pdl_interp::CheckOperandCountOp op, ByteCodeWriter &writer);
202   void generate(pdl_interp::CheckOperationNameOp op, ByteCodeWriter &writer);
203   void generate(pdl_interp::CheckResultCountOp op, ByteCodeWriter &writer);
204   void generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer);
205   void generate(pdl_interp::CreateAttributeOp op, ByteCodeWriter &writer);
206   void generate(pdl_interp::CreateNativeOp op, ByteCodeWriter &writer);
207   void generate(pdl_interp::CreateOperationOp op, ByteCodeWriter &writer);
208   void generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer);
209   void generate(pdl_interp::EraseOp op, ByteCodeWriter &writer);
210   void generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer);
211   void generate(pdl_interp::GetAttributeOp op, ByteCodeWriter &writer);
212   void generate(pdl_interp::GetAttributeTypeOp op, ByteCodeWriter &writer);
213   void generate(pdl_interp::GetDefiningOpOp op, ByteCodeWriter &writer);
214   void generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer);
215   void generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer);
216   void generate(pdl_interp::GetValueTypeOp op, ByteCodeWriter &writer);
217   void generate(pdl_interp::InferredTypeOp op, ByteCodeWriter &writer);
218   void generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer);
219   void generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer);
220   void generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer);
221   void generate(pdl_interp::SwitchAttributeOp op, ByteCodeWriter &writer);
222   void generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer);
223   void generate(pdl_interp::SwitchOperandCountOp op, ByteCodeWriter &writer);
224   void generate(pdl_interp::SwitchOperationNameOp op, ByteCodeWriter &writer);
225   void generate(pdl_interp::SwitchResultCountOp op, ByteCodeWriter &writer);
226 
227   /// Mapping from value to its corresponding memory index.
228   DenseMap<Value, ByteCodeField> valueToMemIndex;
229 
230   /// Mapping from the name of an externally registered rewrite to its index in
231   /// the bytecode registry.
232   llvm::StringMap<ByteCodeField> externalRewriterToMemIndex;
233 
234   /// Mapping from the name of an externally registered constraint to its index
235   /// in the bytecode registry.
236   llvm::StringMap<ByteCodeField> constraintToMemIndex;
237 
238   /// Mapping from the name of an externally registered creation method to its
239   /// index in the bytecode registry.
240   llvm::StringMap<ByteCodeField> nativeCreateToMemIndex;
241 
242   /// Mapping from rewriter function name to the bytecode address of the
243   /// rewriter function in byte.
244   llvm::StringMap<ByteCodeAddr> rewriterToAddr;
245 
246   /// Mapping from a uniqued storage object to its memory index within
247   /// `uniquedData`.
248   DenseMap<const void *, ByteCodeField> uniquedDataToMemIndex;
249 
250   /// The current MLIR context.
251   MLIRContext *ctx;
252 
253   /// Data of the ByteCode class to be populated.
254   std::vector<const void *> &uniquedData;
255   SmallVectorImpl<ByteCodeField> &matcherByteCode;
256   SmallVectorImpl<ByteCodeField> &rewriterByteCode;
257   SmallVectorImpl<PDLByteCodePattern> &patterns;
258   ByteCodeField &maxValueMemoryIndex;
259 };
260 
261 /// This class provides utilities for writing a bytecode stream.
262 struct ByteCodeWriter {
ByteCodeWriter__anon726545990211::ByteCodeWriter263   ByteCodeWriter(SmallVectorImpl<ByteCodeField> &bytecode, Generator &generator)
264       : bytecode(bytecode), generator(generator) {}
265 
266   /// Append a field to the bytecode.
append__anon726545990211::ByteCodeWriter267   void append(ByteCodeField field) { bytecode.push_back(field); }
append__anon726545990211::ByteCodeWriter268   void append(OpCode opCode) { bytecode.push_back(opCode); }
269 
270   /// Append an address to the bytecode.
append__anon726545990211::ByteCodeWriter271   void append(ByteCodeAddr field) {
272     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
273                   "unexpected ByteCode address size");
274 
275     ByteCodeField fieldParts[2];
276     std::memcpy(fieldParts, &field, sizeof(ByteCodeAddr));
277     bytecode.append({fieldParts[0], fieldParts[1]});
278   }
279 
280   /// Append a successor range to the bytecode, the exact address will need to
281   /// be resolved later.
append__anon726545990211::ByteCodeWriter282   void append(SuccessorRange successors) {
283     // Add back references to the any successors so that the address can be
284     // resolved later.
285     for (Block *successor : successors) {
286       unresolvedSuccessorRefs[successor].push_back(bytecode.size());
287       append(ByteCodeAddr(0));
288     }
289   }
290 
291   /// Append a range of values that will be read as generic PDLValues.
appendPDLValueList__anon726545990211::ByteCodeWriter292   void appendPDLValueList(OperandRange values) {
293     bytecode.push_back(values.size());
294     for (Value value : values) {
295       // Append the type of the value in addition to the value itself.
296       PDLValueKind kind =
297           TypeSwitch<Type, PDLValueKind>(value.getType())
298               .Case<pdl::AttributeType>(
299                   [](Type) { return PDLValueKind::Attribute; })
300               .Case<pdl::OperationType>(
301                   [](Type) { return PDLValueKind::Operation; })
302               .Case<pdl::TypeType>([](Type) { return PDLValueKind::Type; })
303               .Case<pdl::ValueType>([](Type) { return PDLValueKind::Value; });
304       bytecode.push_back(static_cast<ByteCodeField>(kind));
305       append(value);
306     }
307   }
308 
309   /// Check if the given class `T` has an iterator type.
310   template <typename T, typename... Args>
311   using has_pointer_traits = decltype(std::declval<T>().getAsOpaquePointer());
312 
313   /// Append a value that will be stored in a memory slot and not inline within
314   /// the bytecode.
315   template <typename T>
316   std::enable_if_t<llvm::is_detected<has_pointer_traits, T>::value ||
317                    std::is_pointer<T>::value>
append__anon726545990211::ByteCodeWriter318   append(T value) {
319     bytecode.push_back(generator.getMemIndex(value));
320   }
321 
322   /// Append a range of values.
323   template <typename T, typename IteratorT = llvm::detail::IterOfRange<T>>
324   std::enable_if_t<!llvm::is_detected<has_pointer_traits, T>::value>
append__anon726545990211::ByteCodeWriter325   append(T range) {
326     bytecode.push_back(llvm::size(range));
327     for (auto it : range)
328       append(it);
329   }
330 
331   /// Append a variadic number of fields to the bytecode.
332   template <typename FieldTy, typename Field2Ty, typename... FieldTys>
append__anon726545990211::ByteCodeWriter333   void append(FieldTy field, Field2Ty field2, FieldTys... fields) {
334     append(field);
335     append(field2, fields...);
336   }
337 
338   /// Successor references in the bytecode that have yet to be resolved.
339   DenseMap<Block *, SmallVector<unsigned, 4>> unresolvedSuccessorRefs;
340 
341   /// The underlying bytecode buffer.
342   SmallVectorImpl<ByteCodeField> &bytecode;
343 
344   /// The main generator producing PDL.
345   Generator &generator;
346 };
347 } // end anonymous namespace
348 
generate(ModuleOp module)349 void Generator::generate(ModuleOp module) {
350   FuncOp matcherFunc = module.lookupSymbol<FuncOp>(
351       pdl_interp::PDLInterpDialect::getMatcherFunctionName());
352   ModuleOp rewriterModule = module.lookupSymbol<ModuleOp>(
353       pdl_interp::PDLInterpDialect::getRewriterModuleName());
354   assert(matcherFunc && rewriterModule && "invalid PDL Interpreter module");
355 
356   // Allocate memory indices for the results of operations within the matcher
357   // and rewriters.
358   allocateMemoryIndices(matcherFunc, rewriterModule);
359 
360   // Generate code for the rewriter functions.
361   ByteCodeWriter rewriterByteCodeWriter(rewriterByteCode, *this);
362   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
363     rewriterToAddr.try_emplace(rewriterFunc.getName(), rewriterByteCode.size());
364     for (Operation &op : rewriterFunc.getOps())
365       generate(&op, rewriterByteCodeWriter);
366   }
367   assert(rewriterByteCodeWriter.unresolvedSuccessorRefs.empty() &&
368          "unexpected branches in rewriter function");
369 
370   // Generate code for the matcher function.
371   DenseMap<Block *, ByteCodeAddr> blockToAddr;
372   llvm::ReversePostOrderTraversal<Region *> rpot(&matcherFunc.getBody());
373   ByteCodeWriter matcherByteCodeWriter(matcherByteCode, *this);
374   for (Block *block : rpot) {
375     // Keep track of where this block begins within the matcher function.
376     blockToAddr.try_emplace(block, matcherByteCode.size());
377     for (Operation &op : *block)
378       generate(&op, matcherByteCodeWriter);
379   }
380 
381   // Resolve successor references in the matcher.
382   for (auto &it : matcherByteCodeWriter.unresolvedSuccessorRefs) {
383     ByteCodeAddr addr = blockToAddr[it.first];
384     for (unsigned offsetToFix : it.second)
385       std::memcpy(&matcherByteCode[offsetToFix], &addr, sizeof(ByteCodeAddr));
386   }
387 }
388 
allocateMemoryIndices(FuncOp matcherFunc,ModuleOp rewriterModule)389 void Generator::allocateMemoryIndices(FuncOp matcherFunc,
390                                       ModuleOp rewriterModule) {
391   // Rewriters use simplistic allocation scheme that simply assigns an index to
392   // each result.
393   for (FuncOp rewriterFunc : rewriterModule.getOps<FuncOp>()) {
394     ByteCodeField index = 0;
395     for (BlockArgument arg : rewriterFunc.getArguments())
396       valueToMemIndex.try_emplace(arg, index++);
397     rewriterFunc.getBody().walk([&](Operation *op) {
398       for (Value result : op->getResults())
399         valueToMemIndex.try_emplace(result, index++);
400     });
401     if (index > maxValueMemoryIndex)
402       maxValueMemoryIndex = index;
403   }
404 
405   // The matcher function uses a more sophisticated numbering that tries to
406   // minimize the number of memory indices assigned. This is done by determining
407   // a live range of the values within the matcher, then the allocation is just
408   // finding the minimal number of overlapping live ranges. This is essentially
409   // a simplified form of register allocation where we don't necessarily have a
410   // limited number of registers, but we still want to minimize the number used.
411   DenseMap<Operation *, ByteCodeField> opToIndex;
412   matcherFunc.getBody().walk([&](Operation *op) {
413     opToIndex.insert(std::make_pair(op, opToIndex.size()));
414   });
415 
416   // Liveness info for each of the defs within the matcher.
417   using LivenessSet = llvm::IntervalMap<ByteCodeField, char, 16>;
418   LivenessSet::Allocator allocator;
419   DenseMap<Value, LivenessSet> valueDefRanges;
420 
421   // Assign the root operation being matched to slot 0.
422   BlockArgument rootOpArg = matcherFunc.getArgument(0);
423   valueToMemIndex[rootOpArg] = 0;
424 
425   // Walk each of the blocks, computing the def interval that the value is used.
426   Liveness matcherLiveness(matcherFunc);
427   for (Block &block : matcherFunc.getBody()) {
428     const LivenessBlockInfo *info = matcherLiveness.getLiveness(&block);
429     assert(info && "expected liveness info for block");
430     auto processValue = [&](Value value, Operation *firstUseOrDef) {
431       // We don't need to process the root op argument, this value is always
432       // assigned to the first memory slot.
433       if (value == rootOpArg)
434         return;
435 
436       // Set indices for the range of this block that the value is used.
437       auto defRangeIt = valueDefRanges.try_emplace(value, allocator).first;
438       defRangeIt->second.insert(
439           opToIndex[firstUseOrDef],
440           opToIndex[info->getEndOperation(value, firstUseOrDef)],
441           /*dummyValue*/ 0);
442     };
443 
444     // Process the live-ins of this block.
445     for (Value liveIn : info->in())
446       processValue(liveIn, &block.front());
447 
448     // Process any new defs within this block.
449     for (Operation &op : block)
450       for (Value result : op.getResults())
451         processValue(result, &op);
452   }
453 
454   // Greedily allocate memory slots using the computed def live ranges.
455   std::vector<LivenessSet> allocatedIndices;
456   for (auto &defIt : valueDefRanges) {
457     ByteCodeField &memIndex = valueToMemIndex[defIt.first];
458     LivenessSet &defSet = defIt.second;
459 
460     // Try to allocate to an existing index.
461     for (auto existingIndexIt : llvm::enumerate(allocatedIndices)) {
462       LivenessSet &existingIndex = existingIndexIt.value();
463       llvm::IntervalMapOverlaps<LivenessSet, LivenessSet> overlaps(
464           defIt.second, existingIndex);
465       if (overlaps.valid())
466         continue;
467       // Union the range of the def within the existing index.
468       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
469         existingIndex.insert(it.start(), it.stop(), /*dummyValue*/ 0);
470       memIndex = existingIndexIt.index() + 1;
471     }
472 
473     // If no existing index could be used, add a new one.
474     if (memIndex == 0) {
475       allocatedIndices.emplace_back(allocator);
476       for (auto it = defSet.begin(), e = defSet.end(); it != e; ++it)
477         allocatedIndices.back().insert(it.start(), it.stop(), /*dummyValue*/ 0);
478       memIndex = allocatedIndices.size();
479     }
480   }
481 
482   // Update the max number of indices.
483   ByteCodeField numMatcherIndices = allocatedIndices.size() + 1;
484   if (numMatcherIndices > maxValueMemoryIndex)
485     maxValueMemoryIndex = numMatcherIndices;
486 }
487 
generate(Operation * op,ByteCodeWriter & writer)488 void Generator::generate(Operation *op, ByteCodeWriter &writer) {
489   TypeSwitch<Operation *>(op)
490       .Case<pdl_interp::ApplyConstraintOp, pdl_interp::ApplyRewriteOp,
491             pdl_interp::AreEqualOp, pdl_interp::BranchOp,
492             pdl_interp::CheckAttributeOp, pdl_interp::CheckOperandCountOp,
493             pdl_interp::CheckOperationNameOp, pdl_interp::CheckResultCountOp,
494             pdl_interp::CheckTypeOp, pdl_interp::CreateAttributeOp,
495             pdl_interp::CreateNativeOp, pdl_interp::CreateOperationOp,
496             pdl_interp::CreateTypeOp, pdl_interp::EraseOp,
497             pdl_interp::FinalizeOp, pdl_interp::GetAttributeOp,
498             pdl_interp::GetAttributeTypeOp, pdl_interp::GetDefiningOpOp,
499             pdl_interp::GetOperandOp, pdl_interp::GetResultOp,
500             pdl_interp::GetValueTypeOp, pdl_interp::InferredTypeOp,
501             pdl_interp::IsNotNullOp, pdl_interp::RecordMatchOp,
502             pdl_interp::ReplaceOp, pdl_interp::SwitchAttributeOp,
503             pdl_interp::SwitchTypeOp, pdl_interp::SwitchOperandCountOp,
504             pdl_interp::SwitchOperationNameOp, pdl_interp::SwitchResultCountOp>(
505           [&](auto interpOp) { this->generate(interpOp, writer); })
506       .Default([](Operation *) {
507         llvm_unreachable("unknown `pdl_interp` operation");
508       });
509 }
510 
generate(pdl_interp::ApplyConstraintOp op,ByteCodeWriter & writer)511 void Generator::generate(pdl_interp::ApplyConstraintOp op,
512                          ByteCodeWriter &writer) {
513   assert(constraintToMemIndex.count(op.name()) &&
514          "expected index for constraint function");
515   writer.append(OpCode::ApplyConstraint, constraintToMemIndex[op.name()],
516                 op.constParamsAttr());
517   writer.appendPDLValueList(op.args());
518   writer.append(op.getSuccessors());
519 }
generate(pdl_interp::ApplyRewriteOp op,ByteCodeWriter & writer)520 void Generator::generate(pdl_interp::ApplyRewriteOp op,
521                          ByteCodeWriter &writer) {
522   assert(externalRewriterToMemIndex.count(op.name()) &&
523          "expected index for rewrite function");
524   writer.append(OpCode::ApplyRewrite, externalRewriterToMemIndex[op.name()],
525                 op.constParamsAttr(), op.root());
526   writer.appendPDLValueList(op.args());
527 }
generate(pdl_interp::AreEqualOp op,ByteCodeWriter & writer)528 void Generator::generate(pdl_interp::AreEqualOp op, ByteCodeWriter &writer) {
529   writer.append(OpCode::AreEqual, op.lhs(), op.rhs(), op.getSuccessors());
530 }
generate(pdl_interp::BranchOp op,ByteCodeWriter & writer)531 void Generator::generate(pdl_interp::BranchOp op, ByteCodeWriter &writer) {
532   writer.append(OpCode::Branch, SuccessorRange(op.getOperation()));
533 }
generate(pdl_interp::CheckAttributeOp op,ByteCodeWriter & writer)534 void Generator::generate(pdl_interp::CheckAttributeOp op,
535                          ByteCodeWriter &writer) {
536   writer.append(OpCode::AreEqual, op.attribute(), op.constantValue(),
537                 op.getSuccessors());
538 }
generate(pdl_interp::CheckOperandCountOp op,ByteCodeWriter & writer)539 void Generator::generate(pdl_interp::CheckOperandCountOp op,
540                          ByteCodeWriter &writer) {
541   writer.append(OpCode::CheckOperandCount, op.operation(), op.count(),
542                 op.getSuccessors());
543 }
generate(pdl_interp::CheckOperationNameOp op,ByteCodeWriter & writer)544 void Generator::generate(pdl_interp::CheckOperationNameOp op,
545                          ByteCodeWriter &writer) {
546   writer.append(OpCode::CheckOperationName, op.operation(),
547                 OperationName(op.name(), ctx), op.getSuccessors());
548 }
generate(pdl_interp::CheckResultCountOp op,ByteCodeWriter & writer)549 void Generator::generate(pdl_interp::CheckResultCountOp op,
550                          ByteCodeWriter &writer) {
551   writer.append(OpCode::CheckResultCount, op.operation(), op.count(),
552                 op.getSuccessors());
553 }
generate(pdl_interp::CheckTypeOp op,ByteCodeWriter & writer)554 void Generator::generate(pdl_interp::CheckTypeOp op, ByteCodeWriter &writer) {
555   writer.append(OpCode::AreEqual, op.value(), op.type(), op.getSuccessors());
556 }
generate(pdl_interp::CreateAttributeOp op,ByteCodeWriter & writer)557 void Generator::generate(pdl_interp::CreateAttributeOp op,
558                          ByteCodeWriter &writer) {
559   // Simply repoint the memory index of the result to the constant.
560   getMemIndex(op.attribute()) = getMemIndex(op.value());
561 }
generate(pdl_interp::CreateNativeOp op,ByteCodeWriter & writer)562 void Generator::generate(pdl_interp::CreateNativeOp op,
563                          ByteCodeWriter &writer) {
564   assert(nativeCreateToMemIndex.count(op.name()) &&
565          "expected index for creation function");
566   writer.append(OpCode::CreateNative, nativeCreateToMemIndex[op.name()],
567                 op.result(), op.constParamsAttr());
568   writer.appendPDLValueList(op.args());
569 }
generate(pdl_interp::CreateOperationOp op,ByteCodeWriter & writer)570 void Generator::generate(pdl_interp::CreateOperationOp op,
571                          ByteCodeWriter &writer) {
572   writer.append(OpCode::CreateOperation, op.operation(),
573                 OperationName(op.name(), ctx), op.operands());
574 
575   // Add the attributes.
576   OperandRange attributes = op.attributes();
577   writer.append(static_cast<ByteCodeField>(attributes.size()));
578   for (auto it : llvm::zip(op.attributeNames(), op.attributes())) {
579     writer.append(
580         Identifier::get(std::get<0>(it).cast<StringAttr>().getValue(), ctx),
581         std::get<1>(it));
582   }
583   writer.append(op.types());
584 }
generate(pdl_interp::CreateTypeOp op,ByteCodeWriter & writer)585 void Generator::generate(pdl_interp::CreateTypeOp op, ByteCodeWriter &writer) {
586   // Simply repoint the memory index of the result to the constant.
587   getMemIndex(op.result()) = getMemIndex(op.value());
588 }
generate(pdl_interp::EraseOp op,ByteCodeWriter & writer)589 void Generator::generate(pdl_interp::EraseOp op, ByteCodeWriter &writer) {
590   writer.append(OpCode::EraseOp, op.operation());
591 }
generate(pdl_interp::FinalizeOp op,ByteCodeWriter & writer)592 void Generator::generate(pdl_interp::FinalizeOp op, ByteCodeWriter &writer) {
593   writer.append(OpCode::Finalize);
594 }
generate(pdl_interp::GetAttributeOp op,ByteCodeWriter & writer)595 void Generator::generate(pdl_interp::GetAttributeOp op,
596                          ByteCodeWriter &writer) {
597   writer.append(OpCode::GetAttribute, op.attribute(), op.operation(),
598                 Identifier::get(op.name(), ctx));
599 }
generate(pdl_interp::GetAttributeTypeOp op,ByteCodeWriter & writer)600 void Generator::generate(pdl_interp::GetAttributeTypeOp op,
601                          ByteCodeWriter &writer) {
602   writer.append(OpCode::GetAttributeType, op.result(), op.value());
603 }
generate(pdl_interp::GetDefiningOpOp op,ByteCodeWriter & writer)604 void Generator::generate(pdl_interp::GetDefiningOpOp op,
605                          ByteCodeWriter &writer) {
606   writer.append(OpCode::GetDefiningOp, op.operation(), op.value());
607 }
generate(pdl_interp::GetOperandOp op,ByteCodeWriter & writer)608 void Generator::generate(pdl_interp::GetOperandOp op, ByteCodeWriter &writer) {
609   uint32_t index = op.index();
610   if (index < 4)
611     writer.append(static_cast<OpCode>(OpCode::GetOperand0 + index));
612   else
613     writer.append(OpCode::GetOperandN, index);
614   writer.append(op.operation(), op.value());
615 }
generate(pdl_interp::GetResultOp op,ByteCodeWriter & writer)616 void Generator::generate(pdl_interp::GetResultOp op, ByteCodeWriter &writer) {
617   uint32_t index = op.index();
618   if (index < 4)
619     writer.append(static_cast<OpCode>(OpCode::GetResult0 + index));
620   else
621     writer.append(OpCode::GetResultN, index);
622   writer.append(op.operation(), op.value());
623 }
generate(pdl_interp::GetValueTypeOp op,ByteCodeWriter & writer)624 void Generator::generate(pdl_interp::GetValueTypeOp op,
625                          ByteCodeWriter &writer) {
626   writer.append(OpCode::GetValueType, op.result(), op.value());
627 }
generate(pdl_interp::InferredTypeOp op,ByteCodeWriter & writer)628 void Generator::generate(pdl_interp::InferredTypeOp op,
629                          ByteCodeWriter &writer) {
630   // InferType maps to a null type as a marker for inferring a result type.
631   getMemIndex(op.type()) = getMemIndex(Type());
632 }
generate(pdl_interp::IsNotNullOp op,ByteCodeWriter & writer)633 void Generator::generate(pdl_interp::IsNotNullOp op, ByteCodeWriter &writer) {
634   writer.append(OpCode::IsNotNull, op.value(), op.getSuccessors());
635 }
generate(pdl_interp::RecordMatchOp op,ByteCodeWriter & writer)636 void Generator::generate(pdl_interp::RecordMatchOp op, ByteCodeWriter &writer) {
637   ByteCodeField patternIndex = patterns.size();
638   patterns.emplace_back(PDLByteCodePattern::create(
639       op, rewriterToAddr[op.rewriter().getLeafReference()]));
640   writer.append(OpCode::RecordMatch, patternIndex,
641                 SuccessorRange(op.getOperation()), op.matchedOps(),
642                 op.inputs());
643 }
generate(pdl_interp::ReplaceOp op,ByteCodeWriter & writer)644 void Generator::generate(pdl_interp::ReplaceOp op, ByteCodeWriter &writer) {
645   writer.append(OpCode::ReplaceOp, op.operation(), op.replValues());
646 }
generate(pdl_interp::SwitchAttributeOp op,ByteCodeWriter & writer)647 void Generator::generate(pdl_interp::SwitchAttributeOp op,
648                          ByteCodeWriter &writer) {
649   writer.append(OpCode::SwitchAttribute, op.attribute(), op.caseValuesAttr(),
650                 op.getSuccessors());
651 }
generate(pdl_interp::SwitchOperandCountOp op,ByteCodeWriter & writer)652 void Generator::generate(pdl_interp::SwitchOperandCountOp op,
653                          ByteCodeWriter &writer) {
654   writer.append(OpCode::SwitchOperandCount, op.operation(), op.caseValuesAttr(),
655                 op.getSuccessors());
656 }
generate(pdl_interp::SwitchOperationNameOp op,ByteCodeWriter & writer)657 void Generator::generate(pdl_interp::SwitchOperationNameOp op,
658                          ByteCodeWriter &writer) {
659   auto cases = llvm::map_range(op.caseValuesAttr(), [&](Attribute attr) {
660     return OperationName(attr.cast<StringAttr>().getValue(), ctx);
661   });
662   writer.append(OpCode::SwitchOperationName, op.operation(), cases,
663                 op.getSuccessors());
664 }
generate(pdl_interp::SwitchResultCountOp op,ByteCodeWriter & writer)665 void Generator::generate(pdl_interp::SwitchResultCountOp op,
666                          ByteCodeWriter &writer) {
667   writer.append(OpCode::SwitchResultCount, op.operation(), op.caseValuesAttr(),
668                 op.getSuccessors());
669 }
generate(pdl_interp::SwitchTypeOp op,ByteCodeWriter & writer)670 void Generator::generate(pdl_interp::SwitchTypeOp op, ByteCodeWriter &writer) {
671   writer.append(OpCode::SwitchType, op.value(), op.caseValuesAttr(),
672                 op.getSuccessors());
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // PDLByteCode
677 //===----------------------------------------------------------------------===//
678 
PDLByteCode(ModuleOp module,llvm::StringMap<PDLConstraintFunction> constraintFns,llvm::StringMap<PDLCreateFunction> createFns,llvm::StringMap<PDLRewriteFunction> rewriteFns)679 PDLByteCode::PDLByteCode(ModuleOp module,
680                          llvm::StringMap<PDLConstraintFunction> constraintFns,
681                          llvm::StringMap<PDLCreateFunction> createFns,
682                          llvm::StringMap<PDLRewriteFunction> rewriteFns) {
683   Generator generator(module.getContext(), uniquedData, matcherByteCode,
684                       rewriterByteCode, patterns, maxValueMemoryIndex,
685                       constraintFns, createFns, rewriteFns);
686   generator.generate(module);
687 
688   // Initialize the external functions.
689   for (auto &it : constraintFns)
690     constraintFunctions.push_back(std::move(it.second));
691   for (auto &it : createFns)
692     createFunctions.push_back(std::move(it.second));
693   for (auto &it : rewriteFns)
694     rewriteFunctions.push_back(std::move(it.second));
695 }
696 
697 /// Initialize the given state such that it can be used to execute the current
698 /// bytecode.
initializeMutableState(PDLByteCodeMutableState & state) const699 void PDLByteCode::initializeMutableState(PDLByteCodeMutableState &state) const {
700   state.memory.resize(maxValueMemoryIndex, nullptr);
701   state.currentPatternBenefits.reserve(patterns.size());
702   for (const PDLByteCodePattern &pattern : patterns)
703     state.currentPatternBenefits.push_back(pattern.getBenefit());
704 }
705 
//===----------------------------------------------------------------------===//
// ByteCode Execution
//===----------------------------------------------------------------------===//
708 
709 namespace {
710 /// This class provides support for executing a bytecode stream.
711 class ByteCodeExecutor {
712 public:
ByteCodeExecutor(const ByteCodeField * curCodeIt,MutableArrayRef<const void * > memory,ArrayRef<const void * > uniquedMemory,ArrayRef<ByteCodeField> code,ArrayRef<PatternBenefit> currentPatternBenefits,ArrayRef<PDLByteCodePattern> patterns,ArrayRef<PDLConstraintFunction> constraintFunctions,ArrayRef<PDLCreateFunction> createFunctions,ArrayRef<PDLRewriteFunction> rewriteFunctions)713   ByteCodeExecutor(const ByteCodeField *curCodeIt,
714                    MutableArrayRef<const void *> memory,
715                    ArrayRef<const void *> uniquedMemory,
716                    ArrayRef<ByteCodeField> code,
717                    ArrayRef<PatternBenefit> currentPatternBenefits,
718                    ArrayRef<PDLByteCodePattern> patterns,
719                    ArrayRef<PDLConstraintFunction> constraintFunctions,
720                    ArrayRef<PDLCreateFunction> createFunctions,
721                    ArrayRef<PDLRewriteFunction> rewriteFunctions)
722       : curCodeIt(curCodeIt), memory(memory), uniquedMemory(uniquedMemory),
723         code(code), currentPatternBenefits(currentPatternBenefits),
724         patterns(patterns), constraintFunctions(constraintFunctions),
725         createFunctions(createFunctions), rewriteFunctions(rewriteFunctions) {}
726 
727   /// Start executing the code at the current bytecode index. `matches` is an
728   /// optional field provided when this function is executed in a matching
729   /// context.
730   void execute(PatternRewriter &rewriter,
731                SmallVectorImpl<PDLByteCode::MatchResult> *matches = nullptr,
732                Optional<Location> mainRewriteLoc = {});
733 
734 private:
735   /// Read a value from the bytecode buffer, optionally skipping a certain
736   /// number of prefix values. These methods always update the buffer to point
737   /// to the next field after the read data.
738   template <typename T = ByteCodeField>
read(size_t skipN=0)739   T read(size_t skipN = 0) {
740     curCodeIt += skipN;
741     return readImpl<T>();
742   }
read(size_t skipN=0)743   ByteCodeField read(size_t skipN = 0) { return read<ByteCodeField>(skipN); }
744 
745   /// Read a list of values from the bytecode buffer.
746   template <typename ValueT, typename T>
readList(SmallVectorImpl<T> & list)747   void readList(SmallVectorImpl<T> &list) {
748     list.clear();
749     for (unsigned i = 0, e = read(); i != e; ++i)
750       list.push_back(read<ValueT>());
751   }
752 
753   /// Jump to a specific successor based on a predicate value.
selectJump(bool isTrue)754   void selectJump(bool isTrue) { selectJump(size_t(isTrue ? 0 : 1)); }
755   /// Jump to a specific successor based on a destination index.
selectJump(size_t destIndex)756   void selectJump(size_t destIndex) {
757     curCodeIt = &code[read<ByteCodeAddr>(destIndex * 2)];
758   }
759 
760   /// Handle a switch operation with the provided value and cases.
761   template <typename T, typename RangeT>
handleSwitch(const T & value,RangeT && cases)762   void handleSwitch(const T &value, RangeT &&cases) {
763     LLVM_DEBUG({
764       llvm::dbgs() << "  * Value: " << value << "\n"
765                    << "  * Cases: ";
766       llvm::interleaveComma(cases, llvm::dbgs());
767       llvm::dbgs() << "\n\n";
768     });
769 
770     // Check to see if the attribute value is within the case list. Jump to
771     // the correct successor index based on the result.
772     for (auto it = cases.begin(), e = cases.end(); it != e; ++it)
773       if (*it == value)
774         return selectJump(size_t((it - cases.begin()) + 1));
775     selectJump(size_t(0));
776   }
777 
778   /// Internal implementation of reading various data types from the bytecode
779   /// stream.
780   template <typename T>
readFromMemory()781   const void *readFromMemory() {
782     size_t index = *curCodeIt++;
783 
784     // If this type is an SSA value, it can only be stored in non-const memory.
785     if (llvm::is_one_of<T, Operation *, Value>::value || index < memory.size())
786       return memory[index];
787 
788     // Otherwise, if this index is not inbounds it is uniqued.
789     return uniquedMemory[index - memory.size()];
790   }
791   template <typename T>
readImpl()792   std::enable_if_t<std::is_pointer<T>::value, T> readImpl() {
793     return reinterpret_cast<T>(const_cast<void *>(readFromMemory<T>()));
794   }
795   template <typename T>
796   std::enable_if_t<std::is_class<T>::value && !std::is_same<PDLValue, T>::value,
797                    T>
readImpl()798   readImpl() {
799     return T(T::getFromOpaquePointer(readFromMemory<T>()));
800   }
801   template <typename T>
readImpl()802   std::enable_if_t<std::is_same<PDLValue, T>::value, T> readImpl() {
803     switch (static_cast<PDLValueKind>(read())) {
804     case PDLValueKind::Attribute:
805       return read<Attribute>();
806     case PDLValueKind::Operation:
807       return read<Operation *>();
808     case PDLValueKind::Type:
809       return read<Type>();
810     case PDLValueKind::Value:
811       return read<Value>();
812     }
813   }
814   template <typename T>
readImpl()815   std::enable_if_t<std::is_same<T, ByteCodeAddr>::value, T> readImpl() {
816     static_assert((sizeof(ByteCodeAddr) / sizeof(ByteCodeField)) == 2,
817                   "unexpected ByteCode address size");
818     ByteCodeAddr result;
819     std::memcpy(&result, curCodeIt, sizeof(ByteCodeAddr));
820     curCodeIt += 2;
821     return result;
822   }
823   template <typename T>
readImpl()824   std::enable_if_t<std::is_same<T, ByteCodeField>::value, T> readImpl() {
825     return *curCodeIt++;
826   }
827 
828   /// The underlying bytecode buffer.
829   const ByteCodeField *curCodeIt;
830 
831   /// The current execution memory.
832   MutableArrayRef<const void *> memory;
833 
834   /// References to ByteCode data necessary for execution.
835   ArrayRef<const void *> uniquedMemory;
836   ArrayRef<ByteCodeField> code;
837   ArrayRef<PatternBenefit> currentPatternBenefits;
838   ArrayRef<PDLByteCodePattern> patterns;
839   ArrayRef<PDLConstraintFunction> constraintFunctions;
840   ArrayRef<PDLCreateFunction> createFunctions;
841   ArrayRef<PDLRewriteFunction> rewriteFunctions;
842 };
843 } // end anonymous namespace
844 
execute(PatternRewriter & rewriter,SmallVectorImpl<PDLByteCode::MatchResult> * matches,Optional<Location> mainRewriteLoc)845 void ByteCodeExecutor::execute(
846     PatternRewriter &rewriter,
847     SmallVectorImpl<PDLByteCode::MatchResult> *matches,
848     Optional<Location> mainRewriteLoc) {
849   while (true) {
850     OpCode opCode = static_cast<OpCode>(read());
851     switch (opCode) {
852     case ApplyConstraint: {
853       LLVM_DEBUG(llvm::dbgs() << "Executing ApplyConstraint:\n");
854       const PDLConstraintFunction &constraintFn = constraintFunctions[read()];
855       ArrayAttr constParams = read<ArrayAttr>();
856       SmallVector<PDLValue, 16> args;
857       readList<PDLValue>(args);
858       LLVM_DEBUG({
859         llvm::dbgs() << "  * Arguments: ";
860         llvm::interleaveComma(args, llvm::dbgs());
861         llvm::dbgs() << "\n  * Parameters: " << constParams << "\n\n";
862       });
863 
864       // Invoke the constraint and jump to the proper destination.
865       selectJump(succeeded(constraintFn(args, constParams, rewriter)));
866       break;
867     }
868     case ApplyRewrite: {
869       LLVM_DEBUG(llvm::dbgs() << "Executing ApplyRewrite:\n");
870       const PDLRewriteFunction &rewriteFn = rewriteFunctions[read()];
871       ArrayAttr constParams = read<ArrayAttr>();
872       Operation *root = read<Operation *>();
873       SmallVector<PDLValue, 16> args;
874       readList<PDLValue>(args);
875 
876       LLVM_DEBUG({
877         llvm::dbgs() << "  * Root: " << *root << "\n"
878                      << "  * Arguments: ";
879         llvm::interleaveComma(args, llvm::dbgs());
880         llvm::dbgs() << "\n  * Parameters: " << constParams << "\n\n";
881       });
882       rewriteFn(root, args, constParams, rewriter);
883       break;
884     }
885     case AreEqual: {
886       LLVM_DEBUG(llvm::dbgs() << "Executing AreEqual:\n");
887       const void *lhs = read<const void *>();
888       const void *rhs = read<const void *>();
889 
890       LLVM_DEBUG(llvm::dbgs() << "  * " << lhs << " == " << rhs << "\n\n");
891       selectJump(lhs == rhs);
892       break;
893     }
894     case Branch: {
895       LLVM_DEBUG(llvm::dbgs() << "Executing Branch\n\n");
896       curCodeIt = &code[read<ByteCodeAddr>()];
897       break;
898     }
899     case CheckOperandCount: {
900       LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperandCount:\n");
901       Operation *op = read<Operation *>();
902       uint32_t expectedCount = read<uint32_t>();
903 
904       LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumOperands() << "\n"
905                               << "  * Expected: " << expectedCount << "\n\n");
906       selectJump(op->getNumOperands() == expectedCount);
907       break;
908     }
909     case CheckOperationName: {
910       LLVM_DEBUG(llvm::dbgs() << "Executing CheckOperationName:\n");
911       Operation *op = read<Operation *>();
912       OperationName expectedName = read<OperationName>();
913 
914       LLVM_DEBUG(llvm::dbgs()
915                  << "  * Found: \"" << op->getName() << "\"\n"
916                  << "  * Expected: \"" << expectedName << "\"\n\n");
917       selectJump(op->getName() == expectedName);
918       break;
919     }
920     case CheckResultCount: {
921       LLVM_DEBUG(llvm::dbgs() << "Executing CheckResultCount:\n");
922       Operation *op = read<Operation *>();
923       uint32_t expectedCount = read<uint32_t>();
924 
925       LLVM_DEBUG(llvm::dbgs() << "  * Found: " << op->getNumResults() << "\n"
926                               << "  * Expected: " << expectedCount << "\n\n");
927       selectJump(op->getNumResults() == expectedCount);
928       break;
929     }
930     case CreateNative: {
931       LLVM_DEBUG(llvm::dbgs() << "Executing CreateNative:\n");
932       const PDLCreateFunction &createFn = createFunctions[read()];
933       ByteCodeField resultIndex = read();
934       ArrayAttr constParams = read<ArrayAttr>();
935       SmallVector<PDLValue, 16> args;
936       readList<PDLValue>(args);
937 
938       LLVM_DEBUG({
939         llvm::dbgs() << "  * Arguments: ";
940         llvm::interleaveComma(args, llvm::dbgs());
941         llvm::dbgs() << "\n  * Parameters: " << constParams << "\n";
942       });
943 
944       PDLValue result = createFn(args, constParams, rewriter);
945       memory[resultIndex] = result.getAsOpaquePointer();
946 
947       LLVM_DEBUG(llvm::dbgs() << "  * Result: " << result << "\n\n");
948       break;
949     }
950     case CreateOperation: {
951       LLVM_DEBUG(llvm::dbgs() << "Executing CreateOperation:\n");
952       assert(mainRewriteLoc && "expected rewrite loc to be provided when "
953                                "executing the rewriter bytecode");
954 
955       unsigned memIndex = read();
956       OperationState state(*mainRewriteLoc, read<OperationName>());
957       readList<Value>(state.operands);
958       for (unsigned i = 0, e = read(); i != e; ++i) {
959         Identifier name = read<Identifier>();
960         if (Attribute attr = read<Attribute>())
961           state.addAttribute(name, attr);
962       }
963 
964       bool hasInferredTypes = false;
965       for (unsigned i = 0, e = read(); i != e; ++i) {
966         Type resultType = read<Type>();
967         hasInferredTypes |= !resultType;
968         state.types.push_back(resultType);
969       }
970 
971       // Handle the case where the operation has inferred types.
972       if (hasInferredTypes) {
973         InferTypeOpInterface::Concept *concept =
974             state.name.getAbstractOperation()
975                 ->getInterface<InferTypeOpInterface>();
976 
977         // TODO: Handle failure.
978         SmallVector<Type, 2> inferredTypes;
979         if (failed(concept->inferReturnTypes(
980                 state.getContext(), state.location, state.operands,
981                 state.attributes.getDictionary(state.getContext()),
982                 state.regions, inferredTypes)))
983           return;
984 
985         for (unsigned i = 0, e = state.types.size(); i != e; ++i)
986           if (!state.types[i])
987             state.types[i] = inferredTypes[i];
988       }
989       Operation *resultOp = rewriter.createOperation(state);
990       memory[memIndex] = resultOp;
991 
992       LLVM_DEBUG({
993         llvm::dbgs() << "  * Attributes: "
994                      << state.attributes.getDictionary(state.getContext())
995                      << "\n  * Operands: ";
996         llvm::interleaveComma(state.operands, llvm::dbgs());
997         llvm::dbgs() << "\n  * Result Types: ";
998         llvm::interleaveComma(state.types, llvm::dbgs());
999         llvm::dbgs() << "\n  * Result: " << *resultOp << "\n\n";
1000       });
1001       break;
1002     }
1003     case EraseOp: {
1004       LLVM_DEBUG(llvm::dbgs() << "Executing EraseOp:\n");
1005       Operation *op = read<Operation *>();
1006 
1007       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n\n");
1008       rewriter.eraseOp(op);
1009       break;
1010     }
1011     case Finalize: {
1012       LLVM_DEBUG(llvm::dbgs() << "Executing Finalize\n\n");
1013       return;
1014     }
1015     case GetAttribute: {
1016       LLVM_DEBUG(llvm::dbgs() << "Executing GetAttribute:\n");
1017       unsigned memIndex = read();
1018       Operation *op = read<Operation *>();
1019       Identifier attrName = read<Identifier>();
1020       Attribute attr = op->getAttr(attrName);
1021 
1022       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1023                               << "  * Attribute: " << attrName << "\n"
1024                               << "  * Result: " << attr << "\n\n");
1025       memory[memIndex] = attr.getAsOpaquePointer();
1026       break;
1027     }
1028     case GetAttributeType: {
1029       LLVM_DEBUG(llvm::dbgs() << "Executing GetAttributeType:\n");
1030       unsigned memIndex = read();
1031       Attribute attr = read<Attribute>();
1032 
1033       LLVM_DEBUG(llvm::dbgs() << "  * Attribute: " << attr << "\n"
1034                               << "  * Result: " << attr.getType() << "\n\n");
1035       memory[memIndex] = attr.getType().getAsOpaquePointer();
1036       break;
1037     }
1038     case GetDefiningOp: {
1039       LLVM_DEBUG(llvm::dbgs() << "Executing GetDefiningOp:\n");
1040       unsigned memIndex = read();
1041       Value value = read<Value>();
1042       Operation *op = value ? value.getDefiningOp() : nullptr;
1043 
1044       LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1045                               << "  * Result: " << *op << "\n\n");
1046       memory[memIndex] = op;
1047       break;
1048     }
1049     case GetOperand0:
1050     case GetOperand1:
1051     case GetOperand2:
1052     case GetOperand3:
1053     case GetOperandN: {
1054       LLVM_DEBUG({
1055         llvm::dbgs() << "Executing GetOperand"
1056                      << (opCode == GetOperandN ? Twine("N")
1057                                                : Twine(opCode - GetOperand0))
1058                      << ":\n";
1059       });
1060       unsigned index =
1061           opCode == GetOperandN ? read<uint32_t>() : (opCode - GetOperand0);
1062       Operation *op = read<Operation *>();
1063       unsigned memIndex = read();
1064       Value operand =
1065           index < op->getNumOperands() ? op->getOperand(index) : Value();
1066 
1067       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1068                               << "  * Index: " << index << "\n"
1069                               << "  * Result: " << operand << "\n\n");
1070       memory[memIndex] = operand.getAsOpaquePointer();
1071       break;
1072     }
1073     case GetResult0:
1074     case GetResult1:
1075     case GetResult2:
1076     case GetResult3:
1077     case GetResultN: {
1078       LLVM_DEBUG({
1079         llvm::dbgs() << "Executing GetResult"
1080                      << (opCode == GetResultN ? Twine("N")
1081                                               : Twine(opCode - GetResult0))
1082                      << ":\n";
1083       });
1084       unsigned index =
1085           opCode == GetResultN ? read<uint32_t>() : (opCode - GetResult0);
1086       Operation *op = read<Operation *>();
1087       unsigned memIndex = read();
1088       OpResult result =
1089           index < op->getNumResults() ? op->getResult(index) : OpResult();
1090 
1091       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n"
1092                               << "  * Index: " << index << "\n"
1093                               << "  * Result: " << result << "\n\n");
1094       memory[memIndex] = result.getAsOpaquePointer();
1095       break;
1096     }
1097     case GetValueType: {
1098       LLVM_DEBUG(llvm::dbgs() << "Executing GetValueType:\n");
1099       unsigned memIndex = read();
1100       Value value = read<Value>();
1101 
1102       LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n"
1103                               << "  * Result: " << value.getType() << "\n\n");
1104       memory[memIndex] = value.getType().getAsOpaquePointer();
1105       break;
1106     }
1107     case IsNotNull: {
1108       LLVM_DEBUG(llvm::dbgs() << "Executing IsNotNull:\n");
1109       const void *value = read<const void *>();
1110 
1111       LLVM_DEBUG(llvm::dbgs() << "  * Value: " << value << "\n\n");
1112       selectJump(value != nullptr);
1113       break;
1114     }
1115     case RecordMatch: {
1116       LLVM_DEBUG(llvm::dbgs() << "Executing RecordMatch:\n");
1117       assert(matches &&
1118              "expected matches to be provided when executing the matcher");
1119       unsigned patternIndex = read();
1120       PatternBenefit benefit = currentPatternBenefits[patternIndex];
1121       const ByteCodeField *dest = &code[read<ByteCodeAddr>()];
1122 
1123       // If the benefit of the pattern is impossible, skip the processing of the
1124       // rest of the pattern.
1125       if (benefit.isImpossibleToMatch()) {
1126         LLVM_DEBUG(llvm::dbgs() << "  * Benefit: Impossible To Match\n\n");
1127         curCodeIt = dest;
1128         break;
1129       }
1130 
1131       // Create a fused location containing the locations of each of the
1132       // operations used in the match. This will be used as the location for
1133       // created operations during the rewrite that don't already have an
1134       // explicit location set.
1135       unsigned numMatchLocs = read();
1136       SmallVector<Location, 4> matchLocs;
1137       matchLocs.reserve(numMatchLocs);
1138       for (unsigned i = 0; i != numMatchLocs; ++i)
1139         matchLocs.push_back(read<Operation *>()->getLoc());
1140       Location matchLoc = rewriter.getFusedLoc(matchLocs);
1141 
1142       LLVM_DEBUG(llvm::dbgs() << "  * Benefit: " << benefit.getBenefit() << "\n"
1143                               << "  * Location: " << matchLoc << "\n\n");
1144       matches->emplace_back(matchLoc, patterns[patternIndex], benefit);
1145       readList<const void *>(matches->back().values);
1146       curCodeIt = dest;
1147       break;
1148     }
1149     case ReplaceOp: {
1150       LLVM_DEBUG(llvm::dbgs() << "Executing ReplaceOp:\n");
1151       Operation *op = read<Operation *>();
1152       SmallVector<Value, 16> args;
1153       readList<Value>(args);
1154 
1155       LLVM_DEBUG({
1156         llvm::dbgs() << "  * Operation: " << *op << "\n"
1157                      << "  * Values: ";
1158         llvm::interleaveComma(args, llvm::dbgs());
1159         llvm::dbgs() << "\n\n";
1160       });
1161       rewriter.replaceOp(op, args);
1162       break;
1163     }
1164     case SwitchAttribute: {
1165       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchAttribute:\n");
1166       Attribute value = read<Attribute>();
1167       ArrayAttr cases = read<ArrayAttr>();
1168       handleSwitch(value, cases);
1169       break;
1170     }
1171     case SwitchOperandCount: {
1172       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperandCount:\n");
1173       Operation *op = read<Operation *>();
1174       auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1175 
1176       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1177       handleSwitch(op->getNumOperands(), cases);
1178       break;
1179     }
1180     case SwitchOperationName: {
1181       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchOperationName:\n");
1182       OperationName value = read<Operation *>()->getName();
1183       size_t caseCount = read();
1184 
1185       // The operation names are stored in-line, so to print them out for
1186       // debugging purposes we need to read the array before executing the
1187       // switch so that we can display all of the possible values.
1188       LLVM_DEBUG({
1189         const ByteCodeField *prevCodeIt = curCodeIt;
1190         llvm::dbgs() << "  * Value: " << value << "\n"
1191                      << "  * Cases: ";
1192         llvm::interleaveComma(
1193             llvm::map_range(llvm::seq<size_t>(0, caseCount),
1194                             [&](size_t i) { return read<OperationName>(); }),
1195             llvm::dbgs());
1196         llvm::dbgs() << "\n\n";
1197         curCodeIt = prevCodeIt;
1198       });
1199 
1200       // Try to find the switch value within any of the cases.
1201       size_t jumpDest = 0;
1202       for (size_t i = 0; i != caseCount; ++i) {
1203         if (read<OperationName>() == value) {
1204           curCodeIt += (caseCount - i - 1);
1205           jumpDest = i + 1;
1206           break;
1207         }
1208       }
1209       selectJump(jumpDest);
1210       break;
1211     }
1212     case SwitchResultCount: {
1213       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchResultCount:\n");
1214       Operation *op = read<Operation *>();
1215       auto cases = read<DenseIntOrFPElementsAttr>().getValues<uint32_t>();
1216 
1217       LLVM_DEBUG(llvm::dbgs() << "  * Operation: " << *op << "\n");
1218       handleSwitch(op->getNumResults(), cases);
1219       break;
1220     }
1221     case SwitchType: {
1222       LLVM_DEBUG(llvm::dbgs() << "Executing SwitchType:\n");
1223       Type value = read<Type>();
1224       auto cases = read<ArrayAttr>().getAsValueRange<TypeAttr>();
1225       handleSwitch(value, cases);
1226       break;
1227     }
1228     }
1229   }
1230 }
1231 
1232 /// Run the pattern matcher on the given root operation, collecting the matched
1233 /// patterns in `matches`.
match(Operation * op,PatternRewriter & rewriter,SmallVectorImpl<MatchResult> & matches,PDLByteCodeMutableState & state) const1234 void PDLByteCode::match(Operation *op, PatternRewriter &rewriter,
1235                         SmallVectorImpl<MatchResult> &matches,
1236                         PDLByteCodeMutableState &state) const {
1237   // The first memory slot is always the root operation.
1238   state.memory[0] = op;
1239 
1240   // The matcher function always starts at code address 0.
1241   ByteCodeExecutor executor(matcherByteCode.data(), state.memory, uniquedData,
1242                             matcherByteCode, state.currentPatternBenefits,
1243                             patterns, constraintFunctions, createFunctions,
1244                             rewriteFunctions);
1245   executor.execute(rewriter, &matches);
1246 
1247   // Order the found matches by benefit.
1248   std::stable_sort(matches.begin(), matches.end(),
1249                    [](const MatchResult &lhs, const MatchResult &rhs) {
1250                      return lhs.benefit > rhs.benefit;
1251                    });
1252 }
1253 
1254 /// Run the rewriter of the given pattern on the root operation `op`.
rewrite(PatternRewriter & rewriter,const MatchResult & match,PDLByteCodeMutableState & state) const1255 void PDLByteCode::rewrite(PatternRewriter &rewriter, const MatchResult &match,
1256                           PDLByteCodeMutableState &state) const {
1257   // The arguments of the rewrite function are stored at the start of the
1258   // memory buffer.
1259   llvm::copy(match.values, state.memory.begin());
1260 
1261   ByteCodeExecutor executor(
1262       &rewriterByteCode[match.pattern->getRewriterAddr()], state.memory,
1263       uniquedData, rewriterByteCode, state.currentPatternBenefits, patterns,
1264       constraintFunctions, createFunctions, rewriteFunctions);
1265   executor.execute(rewriter, /*matches=*/nullptr, match.location);
1266 }
1267