1// RUN: mlir-opt -convert-linalg-to-loops -lower-affine -convert-scf-to-std -convert-std-to-llvm %s | mlir-cpu-runner -O3 -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s
2
// Benchmark driver for @sgemm_naive.
//
// Fills two 16x16 f32 matrices with 1.0, runs the kernel %reps times inside a
// region bracketed by two @rtclock calls, prints the result matrix via
// @print_memref_f32 and passes (2*M*N*K*reps) / elapsed to @print_flops.
//
// Note: the timed region also contains the per-iteration linalg.fill that
// resets %C, while the flop count only accounts for the matrix multiply.
func @main() {
  %A = alloc() : memref<16x16xf32>
  %B = alloc() : memref<16x16xf32>
  %C = alloc() : memref<16x16xf32>

  %cf1 = constant 1.00000e+00 : f32

  linalg.fill(%A, %cf1) : memref<16x16xf32>, f32
  linalg.fill(%B, %cf1) : memref<16x16xf32>, f32

  // Number of timed repetitions. This single value drives both the loop bound
  // and the flop count below, so the two can never disagree.
  %reps = constant 5 : index

  %t_start = call @rtclock() : () -> f64
  affine.for %arg0 = 0 to %reps {
    // Reset the accumulator so every repetition computes 1 + sum_k(1*1) = 17.
    linalg.fill(%C, %cf1) : memref<16x16xf32>, f32
    call @sgemm_naive(%A, %B, %C) : (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>) -> ()
  }
  %t_end = call @rtclock() : () -> f64
  // Elapsed time in whatever unit @rtclock reports (presumably seconds; the
  // declaration visible here only fixes the f64 return type).
  %t = subf %t_end, %t_start : f64

  // Print the result matrix; this is the output the FileCheck pattern matches.
  %pC = memref_cast %C : memref<16x16xf32> to memref<*xf32>
  call @print_memref_f32(%pC) : (memref<*xf32>) -> ()

  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index

  %M = dim %C, %c0 : memref<16x16xf32>
  %N = dim %C, %c1 : memref<16x16xf32>
  %K = dim %A, %c1 : memref<16x16xf32>

  %f1 = muli %M, %N : index
  %f2 = muli %f1, %K : index

  // 2*M*N*K flops per repetition (one multiply and one add per inner step).
  %f3 = muli %c2, %f2 : index
  %num_flops = muli %reps, %f3 : index
  // Convert through i64: 2*16*16*16*5 = 40960 does not fit in a signed 16-bit
  // integer and would wrap to a negative value.
  %num_flops_i = index_cast %num_flops : index to i64
  %num_flops_f = sitofp %num_flops_i : i64 to f64
  %flops = divf %num_flops_f, %t : f64
  call @print_flops(%flops) : (f64) -> ()

  // Release the buffers now that nothing reads them any more.
  dealloc %A : memref<16x16xf32>
  dealloc %B : memref<16x16xf32>
  dealloc %C : memref<16x16xf32>

  return
}
47// CHECK: 17,   17,   17,
48
// Naive triple-loop matrix multiply-accumulate on 16x16 f32 buffers:
//   arg2[i, j] = arg2[i, j] + sum over k of arg0[i, k] * arg1[k, j]
//
// Each output element is accumulated through a one-element scratch memref
// that is allocated and freed per (i, j) pair rather than kept in an SSA
// value carried across the reduction loop.
func @sgemm_naive(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16x16xf32>) {
  affine.for %row = 0 to 16 {
    affine.for %col = 0 to 16 {
      // Seed the scratch cell with the current output value.
      %acc = alloc() : memref<1xf32>
      %init = affine.load %arg2[%row, %col] : memref<16x16xf32>
      affine.store %init, %acc[0] : memref<1xf32>

      // Reduction over the shared dimension.
      affine.for %k = 0 to 16 {
        %partial = affine.load %acc[0] : memref<1xf32>
        %lhs = affine.load %arg0[%row, %k] : memref<16x16xf32>
        %rhs = affine.load %arg1[%k, %col] : memref<16x16xf32>
        %prod = mulf %lhs, %rhs : f32
        %sum = addf %prod, %partial : f32
        affine.store %sum, %acc[0] : memref<1xf32>
      }

      // Write the finished element back and drop the scratch cell.
      %result = affine.load %acc[0] : memref<1xf32>
      affine.store %result, %arg2[%row, %col] : memref<16x16xf32>
      dealloc %acc : memref<1xf32>
    }
  }
  return
}
71
72func private @print_flops(f64)
73func private @rtclock() -> f64
74func private @print_memref_f32(memref<*xf32>)
75