// RUN: mlir-opt -convert-linalg-to-loops -lower-affine -convert-scf-to-std -convert-std-to-llvm %s | mlir-cpu-runner -O3 -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext | FileCheck %s

// Benchmark driver for a naive 16x16x16 f32 matrix multiply.
// A and B are filled with ones. Each repetition resets C to ones and calls
// @sgemm_naive (C += A * B), so every element of the printed C is 1 + 16 = 17.
// The repetitions are timed with @rtclock, and the throughput
// 2*M*N*K*reps / elapsed is reported through @print_flops.
func @main() {
  %A = alloc() : memref<16x16xf32>
  %B = alloc() : memref<16x16xf32>
  %C = alloc() : memref<16x16xf32>

  %cf1 = constant 1.00000e+00 : f32

  linalg.fill(%A, %cf1) : memref<16x16xf32>, f32
  linalg.fill(%B, %cf1) : memref<16x16xf32>, f32

  // Number of timed repetitions. The loop below is bounded by this same value
  // so the flop count used for the throughput matches the work performed.
  %reps = constant 5 : index

  %t_start = call @rtclock() : () -> f64
  affine.for %arg0 = 0 to %reps {
    // Reset C every iteration so the printed result does not depend on %reps.
    linalg.fill(%C, %cf1) : memref<16x16xf32>, f32
    call @sgemm_naive(%A, %B, %C) : (memref<16x16xf32>, memref<16x16xf32>, memref<16x16xf32>) -> ()
  }
  %t_end = call @rtclock() : () -> f64
  %t = subf %t_end, %t_start : f64

  %pC = memref_cast %C : memref<16x16xf32> to memref<*xf32>
  call @print_memref_f32(%pC) : (memref<*xf32>) -> ()

  %c0 = constant 0 : index
  %c1 = constant 1 : index
  %c2 = constant 2 : index

  %M = dim %C, %c0 : memref<16x16xf32>
  %N = dim %C, %c1 : memref<16x16xf32>
  %K = dim %A, %c1 : memref<16x16xf32>

  %f1 = muli %M, %N : index
  %f2 = muli %f1, %K : index

  // 2*M*N*K.
  %f3 = muli %c2, %f2 : index
  %num_flops = muli %reps, %f3 : index
  // Cast through i64: 2*16*16*16*5 = 40960 already exceeds the i16 range,
  // so a narrower integer would silently wrap.
  %num_flops_i = index_cast %num_flops : index to i64
  %num_flops_f = sitofp %num_flops_i : i64 to f64
  %flops = divf %num_flops_f, %t : f64
  call @print_flops(%flops) : (f64) -> ()

  // Release the buffers allocated at the top of the function.
  dealloc %A : memref<16x16xf32>
  dealloc %B : memref<16x16xf32>
  dealloc %C : memref<16x16xf32>
  return
}
// CHECK: 17, 17, 17,

// Naive triple-loop matrix multiply on 16x16 f32 memrefs:
//   %arg2[i, j] += sum_k %arg0[i, k] * %arg1[k, j]
// For each (i, j) the running sum is kept in a one-element scratch buffer %m
// across the k loop and then written back to %arg2.
func @sgemm_naive(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf32>, %arg2: memref<16x16xf32>) {
  %c0 = constant 0 : index
  affine.for %arg3 = 0 to 16 {
    affine.for %arg4 = 0 to 16 {
      %m = alloc() : memref<1xf32>
      // Seed the accumulator with the current value of C[i, j].
      %v = affine.load %arg2[%arg3, %arg4] : memref<16x16xf32>
      affine.store %v, %m[%c0] : memref<1xf32>
      affine.for %arg5 = 0 to 16 {
        %3 = affine.load %arg0[%arg3, %arg5] : memref<16x16xf32>
        %4 = affine.load %arg1[%arg5, %arg4] : memref<16x16xf32>
        %5 = affine.load %m[%c0] : memref<1xf32>
        %6 = mulf %3, %4 : f32
        %7 = addf %6, %5 : f32
        affine.store %7, %m[%c0] : memref<1xf32>
      }
      %s = affine.load %m[%c0] : memref<1xf32>
      affine.store %s, %arg2[%arg3, %arg4] : memref<16x16xf32>
      dealloc %m : memref<1xf32>
    }
  }
  return
}

// External helpers. Presumably provided by the runner utils shared library
// passed to mlir-cpu-runner in the RUN line; verify against that library.
func private @print_flops(f64)
func private @rtclock() -> f64
func private @print_memref_f32(memref<*xf32>)