1 //===- TestAllReduceLowering.cpp - Test gpu.all_reduce lowering -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains test passes for lowering the gpu.all_reduce op.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Dialect/GPU/Passes.h"
14 #include "mlir/Dialect/StandardOps/IR/Ops.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 
18 using namespace mlir;
19 
20 namespace {
21 struct TestGpuRewritePass
22     : public PassWrapper<TestGpuRewritePass, OperationPass<ModuleOp>> {
getDependentDialects__anon814ba0d40111::TestGpuRewritePass23   void getDependentDialects(DialectRegistry &registry) const override {
24     registry.insert<StandardOpsDialect>();
25   }
runOnOperation__anon814ba0d40111::TestGpuRewritePass26   void runOnOperation() override {
27     OwningRewritePatternList patterns;
28     populateGpuRewritePatterns(&getContext(), patterns);
29     applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
30   }
31 };
32 } // namespace
33 
34 namespace mlir {
registerTestAllReduceLoweringPass()35 void registerTestAllReduceLoweringPass() {
36   PassRegistration<TestGpuRewritePass> pass(
37       "test-gpu-rewrite",
38       "Applies all rewrite patterns within the GPU dialect.");
39 }
40 } // namespace mlir
41