//===- Sparsification.cpp - Implementation of linalg sparsification -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of annotated linalg operations to sparse code.
//
// The concept of letting a compiler generate sparse code automatically was
// pioneered for dense linear algebra code in Fortran by [Bik96] in MT1 and
// formalized for tensor algebra by [Kjolstad17,20] for the Sparse Tensor
// Algebra Compiler (TACO). The implementation in this file closely follows
// the "sparse iteration theory" that forms the foundation of TACO. A rewriting
// rule is applied to each tensor expression in linalg (MLIR's tensor index
// notation) where the sparsity of tensors is indicated with annotations that
// give a per-dimension specification of sparse/dense storage together with a
// specification of the order on the dimensions. Subsequently, a topologically
// sorted iteration graph, reflecting the required order on indices with respect
// to the dimensions of each tensor, is constructed to ensure that all tensors
// are visited in natural index order. Next, iteration lattices are constructed
// for the tensor expression for every index in topological order. Each
// iteration lattice point consists of a conjunction of tensor indices together
// with a tensor (sub)expression that needs to be evaluated for that
// conjunction. Within the lattice, iteration points are ordered according to
// the way indices are exhausted. As such, these iteration lattices drive the
// actual sparse code generation, which consists of a tedious but relatively
// straightforward one-to-one mapping from iteration lattices to combinations
// of for-loops, while-loops, and if-statements.
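//
// For example (illustrative only), a sparse vector addition
//
//   x(i) = a(i) + b(i)
//
// with both a and b annotated sparse in dimension i yields the iteration
// lattice { a /\ b, a, b } for index i, which maps onto a while-loop that
// co-iterates over the stored indices of a and b (with if-statements to
// select a+b, a, or b at each step), followed by loops that mop up the
// remaining tail of whichever operand still has entries left.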
//
// [Bik96] Aart J.C. Bik. Compiler Support for Sparse Matrix Computations.
// PhD thesis, Leiden University, May 1996 (aartbik.com/sparse.php).
// [Kjolstad17] Fredrik Berg Kjolstad, Shoaib Ashraf Kamil, Stephen Chou,
// David Lugato, and Saman Amarasinghe. The Tensor Algebra Compiler.
// Proceedings of the ACM on Programming Languages, October 2017.
// [Kjolstad20] Fredrik Berg Kjolstad. Sparse Tensor Algebra Compilation.
// PhD thesis, MIT, February 2020 (tensor-compiler.org).
//
// Implementation detail: We use llvm::SmallVector for vectors with
// variable lengths and std::vector for vectors with fixed lengths.
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"

using namespace mlir;

namespace {

enum class Kind { kTensor, kInvariant, kMulF, kMulI, kAddF, kAddI };

/// Tensor expression. Represents an MLIR expression in tensor index notation.
/// For tensors, e0 denotes the tensor index. For invariants, the IR value is
/// stored directly. For binary operations, e0 and e1 denote the indices of the
/// children tensor expressions.
struct TensorExp {
  TensorExp(Kind k, unsigned x, unsigned y, Value v)
      : kind(k), e0(x), e1(y), val(v) {
    assert((kind == Kind::kTensor && e0 != -1u && e1 == -1u && !val) ||
           (kind == Kind::kInvariant && e0 == -1u && e1 == -1u && val) ||
           (kind >= Kind::kMulF && e0 != -1u && e1 != -1u && !val));
  }
  Kind kind;
  /// Indices of children expression(s).
  unsigned e0;
  unsigned e1;
  /// Direct link to IR for an invariant. During code generation,
  /// this field is used to cache "hoisted" loop invariant tensor loads.
  Value val;
};

/// Lattice point. Each lattice point consists of a conjunction of tensor
/// loop indices (encoded in a bitvector) and the index of the corresponding
/// tensor expression.
struct LatPoint {
  LatPoint(unsigned n, unsigned e, unsigned b) : bits(n, false), exp(e) {
    bits.set(b);
  }
  LatPoint(const llvm::BitVector &b, unsigned e) : bits(b), exp(e) {}
  /// Conjunction of tensor loop indices as bitvector.
  llvm::BitVector bits;
  /// Index of the tensor expression.
  unsigned exp;
};

/// A class to handle all iteration lattice operations. This class abstracts
/// away from some implementation details of storing iteration lattices and
/// tensor expressions. This allows for fine-tuning performance characteristics
/// independently from the basic algorithm if bottlenecks are identified.
class Merger {
public:
  Merger(unsigned t, unsigned l)
      : numTensors(t), numLoops(l), isSparse(t, std::vector<bool>(l, false)) {}

  /// Adds a tensor expression. Returns its index.
  unsigned addExp(Kind k, unsigned e0, unsigned e1 = -1u, Value v = Value()) {
    unsigned e = tensorExps.size();
    tensorExps.push_back(TensorExp(k, e0, e1, v));
    return e;
  }
  unsigned addExp(Kind k, Value v) { return addExp(k, -1u, -1u, v); }

  /// Adds an iteration lattice point. Returns its index.
  unsigned addLat(unsigned t, unsigned i, unsigned e) {
    assert(t < numTensors && i < numLoops);
    unsigned p = latPoints.size();
    latPoints.push_back(LatPoint(numLoops * numTensors, e, numTensors * i + t));
    return p;
  }

  /// Adds a new, initially empty, set. Returns its index.
  unsigned addSet() {
    unsigned s = latSets.size();
    latSets.emplace_back(SmallVector<unsigned, 16>());
    return s;
  }

  /// Computes a single conjunction of two lattice points by taking the "union"
  /// of loop indices (effectively constructing a larger "intersection" of those
  /// indices) with a newly constructed tensor (sub)expression of given kind.
  /// Returns the index of the new lattice point.
  unsigned conjLatPoint(Kind kind, unsigned p0, unsigned p1) {
    unsigned p = latPoints.size();
    llvm::BitVector nb = llvm::BitVector(latPoints[p0].bits);
    nb |= latPoints[p1].bits;
    unsigned e = addExp(kind, latPoints[p0].exp, latPoints[p1].exp);
    latPoints.push_back(LatPoint(nb, e));
    return p;
  }

  /// Conjunctive merge of two lattice sets L0 and L1 is the conjunction of
  /// their cartesian product. Returns the index of the new set.
  unsigned takeConj(Kind kind, unsigned s0, unsigned s1) {
    unsigned s = addSet();
    for (unsigned p0 : latSets[s0])
      for (unsigned p1 : latSets[s1])
        latSets[s].push_back(conjLatPoint(kind, p0, p1));
    return s;
  }

  /// Disjunctive merge of L0 and L1 is (L0 /\_op L1, L0, L1).
  /// Returns the index of the new set.
  unsigned takeDisj(Kind kind, unsigned s0, unsigned s1) {
    unsigned s = takeConj(kind, s0, s1);
    for (unsigned p : latSets[s0])
      latSets[s].push_back(p);
    for (unsigned p : latSets[s1])
      latSets[s].push_back(p);
    return s;
  }
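
  // Illustrative sketch of the merge rules above: for a multiplication
  // a(i) * b(i), takeConj of the singleton sets {a} and {b} yields {a/\b},
  // since a nonzero product needs both operands present; for an addition
  // a(i) + b(i), takeDisj yields {a/\b, a, b}, since the sum is nonzero
  // whenever either operand is present.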

  /// Optimizes the iteration lattice points in the given set. This
  /// method should be called right before code generation to avoid
  /// generating redundant loops and conditions.
  unsigned optimize(unsigned s0) {
    unsigned s = addSet();
    assert(latSets[s0].size() != 0);
    unsigned p0 = latSets[s0][0];
    for (unsigned p1 : latSets[s0]) {
      bool add = true;
      if (p0 != p1) {
        // Is this a straightforward copy?
        unsigned e = latPoints[p1].exp;
        if (exp(e).kind == Kind::kTensor && exp(e).e0 == numTensors - 1)
          continue;
        // Is any dense index exhausted?
        llvm::BitVector tmp = latPoints[p1].bits;
        tmp ^= latPoints[p0].bits;
        if (hasAnyOf(tmp, false))
          continue;
        // Is this a direct duplication of an earlier conjunction?
        for (unsigned p2 : latSets[s]) {
          tmp = latPoints[p1].bits;
          tmp ^= latPoints[p2].bits;
          if (tmp.count() == 0) {
            add = false;
            break;
          }
        }
        assert(!add || latGT(p0, p1));
      }
      if (add)
        latSets[s].push_back(p1);
    }
    return s;
  }

  // Returns true if Li > Lj.
  bool latGT(unsigned i, unsigned j) const {
    const llvm::BitVector &bitsi = latPoints[i].bits;
    const llvm::BitVector &bitsj = latPoints[j].bits;
    assert(bitsi.size() == bitsj.size());
    if (bitsi.count() > bitsj.count()) {
      for (unsigned b = 0, be = bitsj.size(); b < be; b++)
        if (bitsj[b] && !bitsi[b])
          return false;
      return true;
    }
    return false;
  }

  // Bit translation.
  unsigned tensor(unsigned b) const { return b % numTensors; }
  unsigned index(unsigned b) const { return b / numTensors; }
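  // Note on the encoding (illustrative): addLat() stores tensor t at loop
  // index i as bit b = i * numTensors + t, so with e.g. numTensors == 3,
  // bit 7 denotes tensor 1 at loop index 2 (7 % 3 == 1 and 7 / 3 == 2).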

  // Returns true if bit corresponds to sparse access.
  bool isSparseBit(unsigned b) const {
    return isSparseAccess(tensor(b), index(b));
  }

  // Returns true if tensor access at given index is sparse.
  bool isSparseAccess(unsigned t, unsigned i) const {
    assert(t < numTensors && i < numLoops);
    return isSparse[t][i];
  }

  // Returns true if any set bit corresponds to sparse/dense access.
  bool hasAnyOf(const llvm::BitVector &bits, bool sparse) const {
    for (unsigned b = 0, be = bits.size(); b < be; b++)
      if (bits[b] && isSparseBit(b) == sparse)
        return true;
    return false;
  }

  // Getters.
  std::vector<std::vector<bool>> &sparse() { return isSparse; }
  TensorExp &exp(unsigned e) { return tensorExps[e]; }
  LatPoint &lat(unsigned l) { return latPoints[l]; }
  SmallVector<unsigned, 16> &set(unsigned s) { return latSets[s]; }

private:
  const unsigned numTensors;
  const unsigned numLoops;

  std::vector<std::vector<bool>> isSparse;
  llvm::SmallVector<TensorExp, 32> tensorExps;
  llvm::SmallVector<LatPoint, 16> latPoints;
  llvm::SmallVector<SmallVector<unsigned, 16>, 8> latSets;
};

// Code generation.
struct CodeGen {
  CodeGen(linalg::SparsificationOptions o, unsigned numTensors,
          unsigned numLoops)
      : options(o), loops(numLoops), sizes(numLoops), buffers(numTensors),
        pointers(numTensors, std::vector<Value>(numLoops)),
        indices(numTensors, std::vector<Value>(numLoops)),
        highs(numTensors, std::vector<Value>(numLoops)),
        pidxs(numTensors, std::vector<Value>(numLoops)),
        idxs(numTensors, std::vector<Value>(numLoops)) {}
  // Sparsification options.
  linalg::SparsificationOptions options;
  // Universal dense indices and upper bounds (by index). The loops array
  // is updated with the value of the universal dense index in the current
  // loop. The sizes array is set once with the inferred dimension sizes.
  std::vector<Value> loops;
  std::vector<Value> sizes;
  // Buffers for storing dense and sparse numerical values (by tensor).
  // This array is set once during bufferization of all tensors.
  std::vector<Value> buffers;
  // Sparse storage schemes (1-D): pointers and indices (by tensor and index).
  // This array is set once during bufferization of all sparse tensors.
  std::vector<std::vector<Value>> pointers;
  std::vector<std::vector<Value>> indices;
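  // As an illustration of the compressed scheme assumed here, a sparse vector
  //   [ 0, 0, 3.1, 0, 5.4 ]
  // would be stored with pointers = [0, 2], indices = [2, 4], and a values
  // buffer [3.1, 5.4]; pointers[p] and pointers[p+1] delimit the index/value
  // range that belongs to position p of the enclosing dimension.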
  // Sparse iteration information (by tensor and index). These arrays
  // are updated to remain current within the current loop.
  std::vector<std::vector<Value>> highs;
  std::vector<std::vector<Value>> pidxs;
  std::vector<std::vector<Value>> idxs;
};

} // namespace

/// Helper method to inspect sparse annotations in the linalg operation.
/// Fills the per-dimension sparsity information for all tensors.
static void findSparseAnnotations(linalg::GenericOp op,
                                  std::vector<std::vector<bool>> &isSparse) {
  unsigned numTensors = op.getNumInputsAndOutputs();
  ArrayAttr sparseAttr = op.sparseAttr();
  for (unsigned t = 0; t < numTensors; t++) {
    auto map = op.getIndexingMap(t);
    auto dimAttr = sparseAttr[t].cast<ArrayAttr>();
    // For each tensor, we accept a per-dimension Sparse or Dense annotation.
    // This is translated to the loop index that indexes that dimension.
    unsigned rank = op.getShapedType(t).getRank();
    for (unsigned d = 0; d < rank; d++)
      if (isSparseDim(dimAttr[d])) {
        unsigned idx = map.getDimPosition(d);
        isSparse[t][idx] = true;
      } else {
        assert(isDenseDim(dimAttr[d]));
      }
  }
}

/// A DFS helper to compute a topological sort. Note that recursion is
/// bounded by the number of implicit loops, which is always small.
/// Returns false when a cycle is detected.
static bool topSortDFS(unsigned i, std::vector<unsigned> &visit,
                       std::vector<unsigned> &topSort,
                       std::vector<std::vector<bool>> &adjM) {
  if (visit[i] != 0)
    return visit[i] != 1; // 1 denotes cycle!
  visit[i] = 1;
  for (unsigned j = 0, e = visit.size(); j < e; j++)
    if (adjM[i][j])
      if (!topSortDFS(j, visit, topSort, adjM))
        return false;
  visit[i] = 2;
  topSort.push_back(i);
  return true;
}

/// Computes a topologically sorted iteration graph for the linalg operation.
/// Ensures all tensors are visited in natural index order. This is essential
/// for sparse storage formats since these only support access along fixed
/// dimensions. Even for dense storage formats, however, the natural index
/// order yields innermost unit-stride access with better spatial locality.
static bool computeIterationGraph(linalg::GenericOp op,
                                  std::vector<unsigned> &topSort) {
  // Set up an n x n from/to adjacency matrix of the iteration graph
  // for the implicit loop indices i_0 .. i_n-1.
  unsigned n = op.getNumLoops();
  std::vector<std::vector<bool>> adjM(n, std::vector<bool>(n, false));

  // Iterate over the indexing maps of every tensor in the tensor expression.
  for (auto imap : llvm::enumerate(op.indexing_maps())) {
    auto map = imap.value().template cast<AffineMapAttr>().getValue();
    assert(map.getNumDims() == n);
    // At the moment, we take the index variables in the tensor access
    // expression in the order in which they appear (conceptually a
    // "row-major" layout of every tensor). So, a tensor access A_ijk
    // forces the ordering i < j < k on the loop indices.
    // TODO: support affine map to define alternative dimension orders.
    for (unsigned d = 1, e = map.getNumResults(); d < e; d++) {
      unsigned f = map.getDimPosition(d - 1);
      unsigned t = map.getDimPosition(d);
      adjM[f][t] = true;
    }
  }

  // Topologically sort the iteration graph to determine loop order.
  // Report failure for a cyclic iteration graph.
  topSort.reserve(n);
  std::vector<unsigned> visit(n, 0);
  for (unsigned i = 0; i < n; i++)
    if (visit[i] == 0)
      if (!topSortDFS(i, visit, topSort, adjM))
        return false; // cycle!
  std::reverse(std::begin(topSort), std::end(topSort));
  return true;
}

/// Traverses the SSA tree (possibly a DAG) to build a tensor expression.
/// This simplifies constructing (sub)expressions during iteration lattice
/// building (compared to using the SSA representation everywhere).
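/// For instance (hypothetical region shown for illustration only), a body
///   %0 = mulf %a, %b : f64
///   linalg.yield %0 : f64
/// in which %a and %b are block arguments of the generic op is built as
/// kMulF(kTensor(a), kTensor(b)) in the tensorExps array.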
static Optional<unsigned> buildTensorExp(Merger &merger, linalg::GenericOp op,
                                         Value val) {
  if (auto arg = val.dyn_cast<BlockArgument>()) {
    unsigned argN = arg.getArgNumber();
    if (arg.getOwner()->getParentOp() == op) {
      // Any parameter of the generic op is considered a tensor,
      // indexed by the implicit loop bounds.
      auto map = op.getIndexingMap(argN);
      if (map.isProjectedPermutation())
        return merger.addExp(Kind::kTensor, argN);
      // Cannot handle (yet).
      return None;
    }
    // Any parameter of a higher op is invariant.
    return merger.addExp(Kind::kInvariant, val);
  }
  Operation *def = val.getDefiningOp();
  if (def->getBlock() != &op.region().front()) {
    // Something defined outside is invariant.
    return merger.addExp(Kind::kInvariant, val);
  } else if (def->getNumOperands() == 2) {
    // Construct binary operations if subexpressions could be built.
    auto x = buildTensorExp(merger, op, def->getOperand(0));
    auto y = buildTensorExp(merger, op, def->getOperand(1));
    if (x.hasValue() && y.hasValue()) {
      unsigned e0 = x.getValue();
      unsigned e1 = y.getValue();
      if (isa<MulFOp>(def))
        return merger.addExp(Kind::kMulF, e0, e1);
      if (isa<MulIOp>(def))
        return merger.addExp(Kind::kMulI, e0, e1);
      if (isa<AddFOp>(def))
        return merger.addExp(Kind::kAddF, e0, e1);
      if (isa<AddIOp>(def))
        return merger.addExp(Kind::kAddI, e0, e1);
    }
  }
  // Cannot build (yet).
  return None;
}

/// Builds the iteration lattices in a bottom-up traversal given the remaining
/// tensor (sub)expression and the next loop index in the iteration graph.
static unsigned buildLattices(Merger &merger, linalg::GenericOp op,
                              unsigned exp, unsigned idx) {
  Kind kind = merger.exp(exp).kind;
  if (kind == Kind::kTensor || kind == Kind::kInvariant) {
    // Either the index is really used in the tensor expression, or it is
    // set to the "non-existing dense index" in that dimension. Invariant
    // expressions borrow the output tensor indices.
    unsigned s = merger.addSet();
    unsigned t = kind == Kind::kTensor ? merger.exp(exp).e0
                                       : op.getNumInputsAndOutputs() - 1;
    merger.set(s).push_back(merger.addLat(t, idx, exp));
    return s;
  }
  unsigned s0 = buildLattices(merger, op, merger.exp(exp).e0, idx);
  unsigned s1 = buildLattices(merger, op, merger.exp(exp).e1, idx);
  switch (kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
  case Kind::kMulI:
    return merger.takeConj(kind, s0, s1);
  case Kind::kAddF:
  case Kind::kAddI:
    return merger.takeDisj(kind, s0, s1);
  }
}

/// Maps sparse integer option to actual integral storage type.
static Type genIntType(PatternRewriter &rewriter, linalg::SparseIntType tp) {
  switch (tp) {
  case linalg::SparseIntType::kNative:
    return rewriter.getIndexType();
  case linalg::SparseIntType::kI64:
    return rewriter.getIntegerType(64);
  case linalg::SparseIntType::kI32:
    return rewriter.getIntegerType(32);
  }
}

/// Local bufferization of all dense and sparse data structures.
/// This code enables testing the first prototype sparse compiler.
// TODO: replace this with a proliferated bufferization strategy
static void genBuffers(Merger &merger, CodeGen &codegen,
                       PatternRewriter &rewriter, linalg::GenericOp op) {
  Location loc = op.getLoc();
  unsigned numTensors = op.getNumInputsAndOutputs();
  unsigned numInputs = op.getNumInputs();
  assert(numTensors == numInputs + 1);

  // For now, set all unknown dimensions to 999.
  // TODO: compute these values (using sparsity or by reading tensor)
  Value unknown = rewriter.create<ConstantIndexOp>(loc, 999);

  // For every tensor, find lower and upper bound on dimensions, set the
  // same bounds on loop indices, and allocate dense or sparse buffer(s).
  SmallVector<Value, 4> args;
  for (unsigned t = 0; t < numTensors; t++) {
    auto tensorType = op.getShapedType(t);
    auto shape = tensorType.getShape();
    auto map = op.getIndexingMap(t);
    // Scan all dimensions of current tensor.
    bool allDense = true;
    args.clear();
    for (unsigned d = 0, rank = shape.size(); d < rank; d++) {
      unsigned i = map.getDimPosition(d);
      // Handle sparse storage schemes.
      if (merger.isSparseAccess(t, i)) {
        allDense = false;
        auto dynShape = {ShapedType::kDynamicSize};
        auto ptrTp = MemRefType::get(
            dynShape, genIntType(rewriter, codegen.options.ptrType));
        auto indTp = MemRefType::get(
            dynShape, genIntType(rewriter, codegen.options.indType));
        codegen.pointers[t][i] = rewriter.create<AllocaOp>(loc, ptrTp, unknown);
        codegen.indices[t][i] = rewriter.create<AllocaOp>(loc, indTp, unknown);
      }
      // Find lower and upper bound in current dimension.
      Value up;
      if (shape[d] == TensorType::kDynamicSize) {
        // For the output tensor, we may need to infer the upper bound.
        // For all others, we look at the incoming argument.
        if (t == numInputs && !op.getNumInitTensors()) {
          up = codegen.sizes[i];
          assert(up); // TODO: what else?
        } else {
          Value arg = t < numInputs ? op.getInput(t) : op.getInitTensor(0);
          up = rewriter.create<DimOp>(loc, arg, d);
        }
        args.push_back(up);
      } else {
        up = rewriter.create<ConstantIndexOp>(loc, shape[d]);
      }
      codegen.sizes[i] = codegen.highs[t][i] = up;
    }
    // Allocate dense or sparse buffer for numerical values.
    if (allDense) {
      auto denseTp = MemRefType::get(shape, tensorType.getElementType());
      codegen.buffers[t] = rewriter.create<AllocaOp>(loc, denseTp, args);
    } else {
      auto sparseTp = MemRefType::get({ShapedType::kDynamicSize},
                                      tensorType.getElementType());
      codegen.buffers[t] = rewriter.create<AllocaOp>(loc, sparseTp, unknown);
    }
  }
}

/// Generates a load on a dense or sparse tensor.
static Value genTensorLoad(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned exp) {
  // Test if the load was hoisted to a higher loop nest.
  Value val = merger.exp(exp).val;
  if (val) {
    merger.exp(exp).val = Value(); // reset
    return val;
  }
  // Actual load.
  SmallVector<Value, 4> args;
  unsigned tensor = merger.exp(exp).e0;
  auto map = op.getIndexingMap(tensor);
  bool sparse = false;
  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
    unsigned idx = map.getDimPosition(i);
    args.push_back(codegen.loops[idx]); // universal dense index
    if (sparse || merger.isSparseAccess(tensor, idx)) {
      sparse = true;
      args.clear();
      args.push_back(codegen.pidxs[tensor][idx]); // position index
    }
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  return rewriter.create<LoadOp>(loc, ptr, args);
}

/// Generates a store on a dense tensor.
static void genTensorStore(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned tensor, Value rhs) {
  SmallVector<Value, 4> args;
  auto map = op.getIndexingMap(tensor);
  for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
    unsigned idx = map.getDimPosition(i);
    args.push_back(codegen.loops[idx]); // universal dense index
  }
  Location loc = op.getLoc();
  Value ptr = codegen.buffers[tensor];
  rewriter.create<StoreOp>(loc, rhs, ptr, args);
}

/// Generates a pointer/index load from the sparse storage scheme.
static Value genLoad(PatternRewriter &rewriter, Location loc, Value ptr,
                     Value s) {
  Value load = rewriter.create<LoadOp>(loc, ptr, s);
  return load.getType().isa<IndexType>()
             ? load
             : rewriter.create<IndexCastOp>(loc, load, rewriter.getIndexType());
}

/// Generates an invariant value.
static Value genInvariantValue(Merger &merger, CodeGen &codegen,
                               PatternRewriter &rewriter, unsigned exp) {
  return merger.exp(exp).val;
}

/// Recursively generates tensor expression.
static Value genExp(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, unsigned exp) {
  if (merger.exp(exp).kind == Kind::kTensor)
    return genTensorLoad(merger, codegen, rewriter, op, exp);
  else if (merger.exp(exp).kind == Kind::kInvariant)
    return genInvariantValue(merger, codegen, rewriter, exp);
  Value v0 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e0);
  Value v1 = genExp(merger, codegen, rewriter, op, merger.exp(exp).e1);
  switch (merger.exp(exp).kind) {
  case Kind::kTensor:
  case Kind::kInvariant:
    llvm_unreachable("handled above");
  case Kind::kMulF:
    return rewriter.create<MulFOp>(op.getLoc(), v0, v1);
  case Kind::kMulI:
    return rewriter.create<MulIOp>(op.getLoc(), v0, v1);
  case Kind::kAddF:
    return rewriter.create<AddFOp>(op.getLoc(), v0, v1);
  case Kind::kAddI:
    return rewriter.create<AddIOp>(op.getLoc(), v0, v1);
  }
}

/// Hoists loop invariant tensor loads for which indices have been exhausted.
static void genInvariants(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          unsigned exp) {
  if (merger.exp(exp).kind == Kind::kTensor) {
    unsigned lhs = op.getNumInputsAndOutputs() - 1;
    unsigned tensor = merger.exp(exp).e0;
    if (tensor == lhs)
      return; // TODO: scalarize reduction as well (using scf.yield)
    auto map = op.getIndexingMap(tensor);
    for (unsigned i = 0, m = map.getNumResults(); i < m; ++i) {
      unsigned idx = map.getDimPosition(i);
      if (!codegen.loops[idx])
        return; // still in play
    }
    // All exhausted at this level.
    merger.exp(exp).val = genTensorLoad(merger, codegen, rewriter, op, exp);

  } else if (merger.exp(exp).kind != Kind::kInvariant) {
    // Traverse into the binary operations. Note that we only hoist
    // tensor loads, since subsequent MLIR/LLVM passes know how to
    // deal with all other kinds of derived loop invariants.
    genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e0);
    genInvariants(merger, codegen, rewriter, op, merger.exp(exp).e1);
  }
}

/// Generates initialization code for the subsequent loop sequence at
/// current index level. Returns true if the loop sequence needs to
/// maintain the universal index.
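/// As a sketch of the compressed scheme relied upon here: for a sparse index,
/// the positions to visit are delimited by pointers[p] and pointers[p+1],
/// where p is the position index of the enclosing dimension (or simply 0 at
/// the outermost level); these two loads become the loop bounds below.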
static bool genInit(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned at, llvm::BitVector &inits) {
  bool needsUniv = false;
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse positions.
  for (unsigned b = 0, be = inits.size(); b < be; b++) {
    if (inits[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (merger.isSparseBit(b)) {
        // Initialize sparse index.
        unsigned pat = at;
        for (; pat != 0; pat--) {
          if (codegen.pidxs[tensor][topSort[pat - 1]])
            break;
        }
        Value ptr = codegen.pointers[tensor][idx];
        Value one = rewriter.create<ConstantIndexOp>(loc, 1);
        Value p0 = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                              : codegen.pidxs[tensor][topSort[pat - 1]];
        codegen.pidxs[tensor][idx] = genLoad(rewriter, loc, ptr, p0);
        Value p1 = rewriter.create<AddIOp>(loc, p0, one);
        codegen.highs[tensor][idx] = genLoad(rewriter, loc, ptr, p1);
      } else {
        // Dense index still in play.
        needsUniv = true;
      }
    }
  }

  // Initialize the universal dense index.
  codegen.loops[idx] = rewriter.create<ConstantIndexOp>(loc, 0);
  return needsUniv;
}

/// Generates a for-loop on a single index.
static Operation *genFor(Merger &merger, CodeGen &codegen,
                         PatternRewriter &rewriter, linalg::GenericOp op,
                         bool isOuter, bool isInner, unsigned idx,
                         llvm::BitVector &indices) {
  unsigned fb = indices.find_first();
  unsigned tensor = merger.tensor(fb);
  assert(idx == merger.index(fb));

  // Parallelization strategy. Any implicit loop in the Linalg operation that
  // is marked "parallel" is a candidate. Whether it is actually converted to
  // a parallel operation depends on the requested strategy.
  auto iteratorTypes = op.iterator_types().getValue();
  bool isSparse = merger.isSparseBit(fb);
  bool isParallel = linalg::isParallelIteratorType(iteratorTypes[idx]);
  switch (codegen.options.parallelizationStrategy) {
  case linalg::SparseParallelizationStrategy::kNone:
    isParallel = false;
    break;
  case linalg::SparseParallelizationStrategy::kDenseOuterLoop:
    isParallel &= isOuter && !isSparse;
    break;
  case linalg::SparseParallelizationStrategy::kAnyStorageOuterLoop:
    isParallel &= isOuter;
    break;
  case linalg::SparseParallelizationStrategy::kDenseAnyLoop:
    isParallel &= !isSparse;
    break;
  case linalg::SparseParallelizationStrategy::kAnyStorageAnyLoop:
    break;
  }

  // Loop bounds and increment.
  Location loc = op.getLoc();
  Value lo;
  Value hi;
  Value step = rewriter.create<ConstantIndexOp>(loc, 1);
  Value index;
  if (isSparse) {
    lo = codegen.pidxs[tensor][idx];
    hi = codegen.highs[tensor][idx];
  } else {
    lo = codegen.loops[idx];
    hi = codegen.sizes[idx];
  }

  // Emit a parallel loop.
  if (isParallel) {
    scf::ParallelOp parOp = rewriter.create<scf::ParallelOp>(loc, lo, hi, step);
    if (isSparse)
      codegen.pidxs[tensor][idx] = parOp.getInductionVars()[0];
    else
      codegen.loops[idx] = parOp.getInductionVars()[0];
    rewriter.setInsertionPointToStart(parOp.getBody());
    return parOp;
  }

  // Emit a sequential loop.
  scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step);
  if (isSparse)
    codegen.pidxs[tensor][idx] = forOp.getInductionVar();
  else
    codegen.loops[idx] = forOp.getInductionVar();
  rewriter.setInsertionPointToStart(forOp.getBody());
  return forOp;
}

/// Emit a while-loop for co-iteration over multiple indices.
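/// As a rough sketch (shape only) of what is emitted: an scf.while whose
/// "before" region tests pidx_t < high_t for every sparse tensor t involved
/// (and-ed together), and whose "after" region holds the loop body; the
/// matching induction code emitted by genWhileInduction() then advances each
/// pidx_t, and the universal index if present, on the yield.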
static Operation *genWhile(Merger &merger, CodeGen &codegen,
                           PatternRewriter &rewriter, linalg::GenericOp op,
                           unsigned idx, bool needsUniv,
                           llvm::BitVector &indices) {
  SmallVector<Type, 4> types;
  SmallVector<Value, 4> operands;
  // Construct the while-loop with a parameter for each index.
  Type indexType = rewriter.getIndexType();
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isSparseBit(b)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      types.push_back(indexType);
      operands.push_back(codegen.pidxs[tensor][idx]);
    }
  }
  if (needsUniv) {
    types.push_back(indexType);
    operands.push_back(codegen.loops[idx]);
  }
  Location loc = op.getLoc();
  scf::WhileOp whileOp = rewriter.create<scf::WhileOp>(loc, types, operands);
  Block *before = rewriter.createBlock(&whileOp.before(), {}, types);
  Block *after = rewriter.createBlock(&whileOp.after(), {}, types);

  // Build the "before" region, which effectively consists
  // of a conjunction of "i < upper" tests on all induction.
  rewriter.setInsertionPointToStart(&whileOp.before().front());
  Value cond;
  unsigned o = 0;
  for (unsigned b = 0, be = indices.size(); b < be; b++) {
    if (indices[b] && merger.isSparseBit(b)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = before->getArgument(o);
      Value op2 = codegen.highs[tensor][idx];
      Value opc = rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, op1, op2);
      cond = cond ? rewriter.create<AndOp>(loc, cond, opc) : opc;
      codegen.pidxs[tensor][idx] = after->getArgument(o++);
    }
  }
  if (needsUniv)
    codegen.loops[idx] = after->getArgument(o++);
  assert(o == operands.size());
  rewriter.create<scf::ConditionOp>(loc, cond, before->getArguments());
  rewriter.setInsertionPointToStart(&whileOp.after().front());
  return whileOp;
}

/// Generates a for-loop or a while-loop, depending on whether it implements
/// singleton iteration or co-iteration over the given conjunction.
static Operation *genLoop(Merger &merger, CodeGen &codegen,
                          PatternRewriter &rewriter, linalg::GenericOp op,
                          std::vector<unsigned> &topSort, unsigned at,
                          bool needsUniv, llvm::BitVector &indices) {
  unsigned idx = topSort[at];
  if (indices.count() == 1) {
    bool isOuter = at == 0;
    bool isInner = at == topSort.size() - 1;
    return genFor(merger, codegen, rewriter, op, isOuter, isInner, idx,
                  indices);
  }
  return genWhile(merger, codegen, rewriter, op, idx, needsUniv, indices);
}

/// Generates the local variables for this loop, consisting of the sparse
/// indices, restored universal dense index, and dense positions.
static void genLocals(Merger &merger, CodeGen &codegen,
                      PatternRewriter &rewriter, linalg::GenericOp op,
                      std::vector<unsigned> &topSort, unsigned at,
                      bool needsUniv, llvm::BitVector &locals) {
  Location loc = op.getLoc();
  unsigned idx = topSort[at];

  // Initialize sparse indices.
  Value min;
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && merger.isSparseBit(b)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value ptr = codegen.indices[tensor][idx];
      Value s = codegen.pidxs[tensor][idx];
      Value load = genLoad(rewriter, loc, ptr, s);
      codegen.idxs[tensor][idx] = load;
      if (!needsUniv) {
        if (min) {
          Value cmp =
              rewriter.create<CmpIOp>(loc, CmpIPredicate::ult, load, min);
          min = rewriter.create<SelectOp>(loc, cmp, load, min);
        } else {
          min = load;
        }
      }
    }
  }

  // Merge dense universal index over minimum.
  if (min) {
    assert(!needsUniv);
    codegen.loops[idx] = min;
  }

  // Initialize dense positions.
  for (unsigned b = 0, be = locals.size(); b < be; b++) {
    if (locals[b] && !merger.isSparseBit(b)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      if (!codegen.highs[tensor][idx])
        continue; // unused dimension
      unsigned pat = at;
      for (; pat != 0; pat--)
        if (codegen.pidxs[tensor][topSort[pat - 1]])
          break;
      Value p = (pat == 0) ? rewriter.create<ConstantIndexOp>(loc, 0)
                           : codegen.pidxs[tensor][topSort[pat - 1]];
      Value m = rewriter.create<MulIOp>(loc, codegen.sizes[idx], p);
      codegen.pidxs[tensor][idx] =
          rewriter.create<AddIOp>(loc, m, codegen.loops[idx]);
    }
  }
}

/// Generates the induction structure for a while-loop.
static void genWhileInduction(Merger &merger, CodeGen &codegen,
                              PatternRewriter &rewriter, linalg::GenericOp op,
                              unsigned idx, bool needsUniv,
                              llvm::BitVector &induction, ResultRange results) {
  Location loc = op.getLoc();
  unsigned o = 0;
  SmallVector<Value, 4> operands;
  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
  for (unsigned b = 0, be = induction.size(); b < be; b++)
    if (induction[b] && merger.isSparseBit(b)) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value op1 = codegen.idxs[tensor][idx];
      Value op2 = codegen.loops[idx];
      Value op3 = codegen.pidxs[tensor][idx];
      Value cmp = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      Value add = rewriter.create<AddIOp>(loc, op3, one);
      operands.push_back(rewriter.create<SelectOp>(loc, cmp, add, op3));
      codegen.pidxs[tensor][idx] = results[o++];
    }
  if (needsUniv) {
    operands.push_back(rewriter.create<AddIOp>(loc, codegen.loops[idx], one));
    codegen.loops[idx] = results[o++];
  }
  assert(o == operands.size());
  rewriter.create<scf::YieldOp>(loc, operands);
}

/// Generates a single if-statement within a while-loop.
static void genIf(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                  linalg::GenericOp op, unsigned idx,
                  llvm::BitVector &conditions, scf::IfOp &ifOp) {
  Location loc = op.getLoc();
  if (ifOp)
    rewriter.setInsertionPointToStart(&ifOp.elseRegion().front());
  Value cond;
  for (unsigned b = 0, be = conditions.size(); b < be; b++) {
    if (conditions[b]) {
      unsigned tensor = merger.tensor(b);
      assert(idx == merger.index(b));
      Value clause;
      if (merger.isSparseBit(b)) {
        Value op1 = codegen.idxs[tensor][idx];
        Value op2 = codegen.loops[idx];
        clause = rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, op1, op2);
      } else {
        clause = rewriter.create<ConstantIntOp>(loc, 1, 1); // true
      }
      cond = cond ? rewriter.create<AndOp>(loc, cond, clause) : clause;
    }
  }
  ifOp = rewriter.create<scf::IfOp>(loc, cond, /*else*/ true);
  rewriter.setInsertionPointToStart(&ifOp.thenRegion().front());
}

/// Optimizes the loop indices of Li with two rules:
/// (1) convert multiple dense to single dense, and
/// (2) convert singleton sparse/dense to sparse/random access.
static void optimizeIndices(Merger merger, unsigned lsize,
                            llvm::BitVector &indices) {
  if (merger.hasAnyOf(indices, false)) {
    bool reset = lsize == 1 && merger.hasAnyOf(indices, true);
    for (unsigned b = 0, be = indices.size(); b < be; b++) {
      if (indices[b] && !merger.isSparseBit(b)) {
        if (reset)
          indices.reset(b);
        reset = true;
      }
    }
  }
}

/// Recursively generates code while computing iteration lattices in order
/// to manage the complexity of implementing co-iteration over unions
/// and intersections of sparse iteration spaces.
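/// For the x(i) = a(i) + b(i) example from the file header (illustrative),
/// this emits a while-loop that co-iterates a and b with if-statements that
/// select between a(i)+b(i), a(i), and b(i), followed by loops for the
/// lattice points {a} and {b} that visit whatever entries remain in either
/// operand.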
static void genStmt(Merger &merger, CodeGen &codegen, PatternRewriter &rewriter,
                    linalg::GenericOp op, std::vector<unsigned> &topSort,
                    unsigned exp, unsigned at) {
  // At each leaf, assign remaining tensor (sub)expression to output tensor.
  if (at == topSort.size()) {
    unsigned lhs = op.getNumInputsAndOutputs() - 1;
    Value rhs = genExp(merger, codegen, rewriter, op, exp);
    genTensorStore(merger, codegen, rewriter, op, lhs, rhs);
    return;
  }

  // Construct iteration lattices for current loop index, with L0 at top.
  // Then emit initialization code for the loop sequence at this level.
  // We maintain the universal dense index if dense indices are still
  // in play for a non-singleton loop sequence.
  unsigned idx = topSort[at];
  unsigned lts = merger.optimize(buildLattices(merger, op, exp, idx));
  unsigned lsize = merger.set(lts).size();
  assert(lsize != 0);
  unsigned l0 = merger.set(lts)[0];
  LatPoint lat0 = merger.lat(l0);
  genInvariants(merger, codegen, rewriter, op, exp);
  bool needsUniv =
      genInit(merger, codegen, rewriter, op, topSort, at, lat0.bits) &&
      lsize > 1;

  // Emit a loop for every lattice point L0 >= Li.
  for (unsigned li : merger.set(lts)) {
    LatPoint lati = merger.lat(li);

    // Emit loop.
    llvm::BitVector indices = lati.bits;
    optimizeIndices(merger, lsize, indices);
    Operation *loop =
        genLoop(merger, codegen, rewriter, op, topSort, at, needsUniv, indices);
    genLocals(merger, codegen, rewriter, op, topSort, at, needsUniv, lati.bits);

    // Visit all lattice points with Li >= Lj to generate the
    // loop-body, possibly with if statements for co-iteration.
    bool isWhile = dyn_cast<scf::WhileOp>(loop) != nullptr;
    scf::IfOp ifOp;
    for (unsigned lj : merger.set(lts)) {
      if (li == lj || merger.latGT(li, lj)) {
        LatPoint latj = merger.lat(lj);
        llvm::BitVector tmp = latj.bits;
        tmp ^= lati.bits;
        if (merger.hasAnyOf(tmp, false))
          continue; // dense exhausted within if/else
        // Recurse into body of each branch.
        if (isWhile)
          genIf(merger, codegen, rewriter, op, idx, latj.bits, ifOp);
        genStmt(merger, codegen, rewriter, op, topSort, latj.exp, at + 1);
      }
    }

    // Wrap-up induction and restore insertion point.
    if (isWhile) {
      scf::WhileOp whileOp = cast<scf::WhileOp>(loop);
      rewriter.setInsertionPointToEnd(&whileOp.after().front());
      genWhileInduction(merger, codegen, rewriter, op, idx, needsUniv,
                        lati.bits, whileOp.results());
    } else {
      needsUniv = false;
    }
    rewriter.setInsertionPointAfter(loop);
  }
  codegen.loops[idx] = Value();
}

namespace {

/// Sparse rewriting rule for a generic Linalg operation.
struct GenericOpSparsifier : public OpRewritePattern<linalg::GenericOp> {
public:
  GenericOpSparsifier(MLIRContext *context, linalg::SparsificationOptions o)
      : OpRewritePattern<linalg::GenericOp>(context), options(o) {}

  LogicalResult matchAndRewrite(linalg::GenericOp op,
                                PatternRewriter &rewriter) const override {
    // Detects sparse annotations and translates the per-dimension sparsity
    // information for all tensors to loop indices in the kernel.
    if (!op.hasSparseSemantics())
      return failure();
    assert(op.getNumOutputs() == 1);
    unsigned numTensors = op.getNumInputsAndOutputs();
    unsigned numLoops = op.iterator_types().getValue().size();
    Merger merger(numTensors, numLoops);
    findSparseAnnotations(op, merger.sparse());

    // Computes a topologically sorted iteration graph to ensure
    // tensors are visited in natural index order. Fails on cycles.
    // This assumes that higher-level passes have already put the
    // tensors in each tensor expression in a feasible order.
    // TODO: try again without *dense* constraints on failure or
    //       even try to insert sparse reorderings to resolve cycles
    std::vector<unsigned> topSort;
    if (!computeIterationGraph(op, topSort))
      return failure();

    // Finds the terminating yield statement and builds the tensor
    // expression for the Linalg operation in SSA form.
    Operation *yield = op.region().front().getTerminator();
    Optional<unsigned> exp = buildTensorExp(merger, op, yield->getOperand(0));
    if (!exp.hasValue())
      return failure(); // build failure

    // Recursively generates code.
    CodeGen codegen(options, numTensors, numLoops);
    genBuffers(merger, codegen, rewriter, op);
    genStmt(merger, codegen, rewriter, op, topSort, exp.getValue(), 0);
    Value result =
        rewriter.create<TensorLoadOp>(op.getLoc(), codegen.buffers.back());
    rewriter.replaceOp(op, result);
    return success();
  }

private:
  /// Options to control sparse code generation.
  linalg::SparsificationOptions options;
};

} // namespace

/// Populates the given patterns list with rewriting rules required for
/// the sparsification of linear algebra operations.
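///
/// A minimal usage sketch (assuming a test pass with a greedy rewrite driver;
/// the exact driver signature may differ across MLIR revisions):
///
///   OwningRewritePatternList patterns;
///   linalg::populateSparsificationPatterns(ctx, patterns, options);
///   applyPatternsAndFoldGreedily(getFunction(), std::move(patterns));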
void linalg::populateSparsificationPatterns(
    MLIRContext *context, OwningRewritePatternList &patterns,
    const SparsificationOptions &options) {
  patterns.insert<GenericOpSparsifier>(context, options);
}