1 //===- RunnerUtils.h - Utils for debugging MLIR execution -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares basic classes and functions to debug structured MLIR
10 // types at runtime. Entities in this file may not be compatible with targets
11 // without a C++ runtime. These may be progressively migrated to CRunnerUtils.h
12 // over time.
13 //
14 //===----------------------------------------------------------------------===//
15
16 #ifndef EXECUTIONENGINE_RUNNERUTILS_H_
17 #define EXECUTIONENGINE_RUNNERUTILS_H_
18
19 #ifdef _WIN32
20 #ifndef MLIR_RUNNERUTILS_EXPORT
21 #ifdef mlir_runner_utils_EXPORTS
22 // We are building this library
23 #define MLIR_RUNNERUTILS_EXPORT __declspec(dllexport)
24 #else
25 // We are using this library
26 #define MLIR_RUNNERUTILS_EXPORT __declspec(dllimport)
27 #endif // mlir_runner_utils_EXPORTS
28 #endif // MLIR_RUNNERUTILS_EXPORT
29 #else
30 #define MLIR_RUNNERUTILS_EXPORT
31 #endif // _WIN32
32
33 #include <assert.h>
34 #include <iostream>
35
36 #include "mlir/ExecutionEngine/CRunnerUtils.h"
37
38 template <typename T, typename StreamType>
printMemRefMetaData(StreamType & os,const DynamicMemRefType<T> & V)39 void printMemRefMetaData(StreamType &os, const DynamicMemRefType<T> &V) {
40 os << "base@ = " << reinterpret_cast<void *>(V.data) << " rank = " << V.rank
41 << " offset = " << V.offset;
42 auto print = [&](const int64_t *ptr) {
43 if (V.rank == 0)
44 return;
45 os << ptr[0];
46 for (int64_t i = 1; i < V.rank; ++i)
47 os << ", " << ptr[i];
48 };
49 os << " sizes = [";
50 print(V.sizes);
51 os << "] strides = [";
52 print(V.strides);
53 os << "]";
54 }
55
56 template <typename StreamType, typename T, int N>
printMemRefMetaData(StreamType & os,StridedMemRefType<T,N> & V)57 void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
58 static_assert(N >= 0, "Expected N > 0");
59 os << "MemRef ";
60 printMemRefMetaData(os, DynamicMemRefType<T>(V));
61 }
62
63 template <typename StreamType, typename T>
printUnrankedMemRefMetaData(StreamType & os,UnrankedMemRefType<T> & V)64 void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &V) {
65 os << "Unranked MemRef ";
66 printMemRefMetaData(os, DynamicMemRefType<T>(V));
67 }
68
69 ////////////////////////////////////////////////////////////////////////////////
70 // Templated instantiation follows.
71 ////////////////////////////////////////////////////////////////////////////////
72 namespace impl {
73 template <typename T, int M, int... Dims>
74 std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v);
75
76 template <int... Dims> struct StaticSizeMult {
77 static constexpr int value = 1;
78 };
79
80 template <int N, int... Dims> struct StaticSizeMult<N, Dims...> {
81 static constexpr int value = N * StaticSizeMult<Dims...>::value;
82 };
83
84 static inline void printSpace(std::ostream &os, int count) {
85 for (int i = 0; i < count; ++i) {
86 os << ' ';
87 }
88 }
89
90 template <typename T, int M, int... Dims> struct VectorDataPrinter {
91 static void print(std::ostream &os, const Vector<T, M, Dims...> &val);
92 };
93
94 template <typename T, int M, int... Dims>
95 void VectorDataPrinter<T, M, Dims...>::print(std::ostream &os,
96 const Vector<T, M, Dims...> &val) {
97 static_assert(M > 0, "0 dimensioned tensor");
98 static_assert(sizeof(val) == M * StaticSizeMult<Dims...>::value * sizeof(T),
99 "Incorrect vector size!");
100 // First
101 os << "(" << val[0];
102 if (M > 1)
103 os << ", ";
104 if (sizeof...(Dims) > 1)
105 os << "\n";
106 // Kernel
107 for (unsigned i = 1; i + 1 < M; ++i) {
108 printSpace(os, 2 * sizeof...(Dims));
109 os << val[i] << ", ";
110 if (sizeof...(Dims) > 1)
111 os << "\n";
112 }
113 // Last
114 if (M > 1) {
115 printSpace(os, sizeof...(Dims));
116 os << val[M - 1];
117 }
118 os << ")";
119 }
120
121 template <typename T, int M, int... Dims>
122 std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v) {
123 VectorDataPrinter<T, M, Dims...>::print(os, v);
124 return os;
125 }
126
127 template <typename T>
128 struct MemRefDataPrinter {
129 static void print(std::ostream &os, T *base, int64_t dim, int64_t rank,
130 int64_t offset, const int64_t *sizes,
131 const int64_t *strides);
132 static void printFirst(std::ostream &os, T *base, int64_t dim, int64_t rank,
133 int64_t offset, const int64_t *sizes,
134 const int64_t *strides);
135 static void printLast(std::ostream &os, T *base, int64_t dim, int64_t rank,
136 int64_t offset, const int64_t *sizes,
137 const int64_t *strides);
138 };
139
140 template <typename T>
141 void MemRefDataPrinter<T>::printFirst(std::ostream &os, T *base, int64_t dim,
142 int64_t rank, int64_t offset,
143 const int64_t *sizes,
144 const int64_t *strides) {
145 os << "[";
146 print(os, base, dim - 1, rank, offset, sizes + 1, strides + 1);
147 // If single element, close square bracket and return early.
148 if (sizes[0] <= 1) {
149 os << "]";
150 return;
151 }
152 os << ", ";
153 if (dim > 1)
154 os << "\n";
155 }
156
157 template <typename T>
158 void MemRefDataPrinter<T>::print(std::ostream &os, T *base, int64_t dim,
159 int64_t rank, int64_t offset,
160 const int64_t *sizes, const int64_t *strides) {
161 if (dim == 0) {
162 os << base[offset];
163 return;
164 }
165 printFirst(os, base, dim, rank, offset, sizes, strides);
166 for (unsigned i = 1; i + 1 < sizes[0]; ++i) {
167 printSpace(os, rank - dim + 1);
168 print(os, base, dim - 1, rank, offset + i * strides[0], sizes + 1,
169 strides + 1);
170 os << ", ";
171 if (dim > 1)
172 os << "\n";
173 }
174 if (sizes[0] <= 1)
175 return;
176 printLast(os, base, dim, rank, offset, sizes, strides);
177 }
178
179 template <typename T>
180 void MemRefDataPrinter<T>::printLast(std::ostream &os, T *base, int64_t dim,
181 int64_t rank, int64_t offset,
182 const int64_t *sizes,
183 const int64_t *strides) {
184 printSpace(os, rank - dim + 1);
185 print(os, base, dim - 1, rank, offset + (sizes[0] - 1) * (*strides),
186 sizes + 1, strides + 1);
187 os << "]";
188 }
189
190 template <typename T>
191 void printMemRef(const DynamicMemRefType<T> &M) {
192 printMemRefMetaData(std::cout, M);
193 std::cout << " data = " << std::endl;
194 if (M.rank == 0)
195 std::cout << "[";
196 MemRefDataPrinter<T>::print(std::cout, M.data, M.rank, M.rank, M.offset,
197 M.sizes, M.strides);
198 if (M.rank == 0)
199 std::cout << "]";
200 std::cout << std::endl;
201 }
202
203 template <typename T, int N>
204 void printMemRef(StridedMemRefType<T, N> &M) {
205 std::cout << "Memref ";
206 printMemRef(DynamicMemRefType<T>(M));
207 }
208
209 template <typename T>
210 void printMemRef(UnrankedMemRefType<T> &M) {
211 std::cout << "Unranked Memref ";
212 printMemRef(DynamicMemRefType<T>(M));
213 }
214 } // namespace impl
215
216 ////////////////////////////////////////////////////////////////////////////////
217 // Currently exposed C API.
218 ////////////////////////////////////////////////////////////////////////////////
219 extern "C" MLIR_RUNNERUTILS_EXPORT void
220 _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M);
221 extern "C" MLIR_RUNNERUTILS_EXPORT void
222 _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M);
223 extern "C" MLIR_RUNNERUTILS_EXPORT void
224 _mlir_ciface_print_memref_f64(UnrankedMemRefType<double> *M);
225
226 extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_i32(int64_t rank,
227 void *ptr);
228 extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_i64(int64_t rank,
229 void *ptr);
230 extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_f32(int64_t rank,
231 void *ptr);
232 extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_f64(int64_t rank,
233 void *ptr);
234
235 extern "C" MLIR_RUNNERUTILS_EXPORT void
236 _mlir_ciface_print_memref_0d_f32(StridedMemRefType<float, 0> *M);
237 extern "C" MLIR_RUNNERUTILS_EXPORT void
238 _mlir_ciface_print_memref_1d_f32(StridedMemRefType<float, 1> *M);
239 extern "C" MLIR_RUNNERUTILS_EXPORT void
240 _mlir_ciface_print_memref_2d_f32(StridedMemRefType<float, 2> *M);
241 extern "C" MLIR_RUNNERUTILS_EXPORT void
242 _mlir_ciface_print_memref_3d_f32(StridedMemRefType<float, 3> *M);
243 extern "C" MLIR_RUNNERUTILS_EXPORT void
244 _mlir_ciface_print_memref_4d_f32(StridedMemRefType<float, 4> *M);
245
246 extern "C" MLIR_RUNNERUTILS_EXPORT void
247 _mlir_ciface_print_memref_vector_4x4xf32(
248 StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
249
250 #endif // EXECUTIONENGINE_RUNNERUTILS_H_
251