1 //===- RunnerUtils.h - Utils for debugging MLIR execution -----------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file declares basic classes and functions to debug structured MLIR
10 // types at runtime. Entities in this file may not be compatible with targets
11 // without a C++ runtime. These may be progressively migrated to CRunnerUtils.h
12 // over time.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #ifndef EXECUTIONENGINE_RUNNERUTILS_H_
17 #define EXECUTIONENGINE_RUNNERUTILS_H_
18 
// MLIR_RUNNERUTILS_EXPORT annotates the C entry points declared at the bottom
// of this header. Outside of Windows it expands to nothing. On Windows, unless
// the macro was already provided, it selects dllexport while the library
// itself is being built (mlir_runner_utils_EXPORTS defined) and dllimport for
// code that merely uses the library.
#if !defined(_WIN32)
#define MLIR_RUNNERUTILS_EXPORT
#elif !defined(MLIR_RUNNERUTILS_EXPORT)
#if defined(mlir_runner_utils_EXPORTS)
// We are building this library
#define MLIR_RUNNERUTILS_EXPORT __declspec(dllexport)
#else
// We are using this library
#define MLIR_RUNNERUTILS_EXPORT __declspec(dllimport)
#endif // mlir_runner_utils_EXPORTS
#endif // _WIN32 / MLIR_RUNNERUTILS_EXPORT
32 
33 #include <assert.h>
34 #include <iostream>
35 
36 #include "mlir/ExecutionEngine/CRunnerUtils.h"
37 
38 template <typename T, typename StreamType>
printMemRefMetaData(StreamType & os,const DynamicMemRefType<T> & V)39 void printMemRefMetaData(StreamType &os, const DynamicMemRefType<T> &V) {
40   os << "base@ = " << reinterpret_cast<void *>(V.data) << " rank = " << V.rank
41      << " offset = " << V.offset;
42   auto print = [&](const int64_t *ptr) {
43     if (V.rank == 0)
44       return;
45     os << ptr[0];
46     for (int64_t i = 1; i < V.rank; ++i)
47       os << ", " << ptr[i];
48   };
49   os << " sizes = [";
50   print(V.sizes);
51   os << "] strides = [";
52   print(V.strides);
53   os << "]";
54 }
55 
56 template <typename StreamType, typename T, int N>
printMemRefMetaData(StreamType & os,StridedMemRefType<T,N> & V)57 void printMemRefMetaData(StreamType &os, StridedMemRefType<T, N> &V) {
58   static_assert(N >= 0, "Expected N > 0");
59   os << "MemRef ";
60   printMemRefMetaData(os, DynamicMemRefType<T>(V));
61 }
62 
63 template <typename StreamType, typename T>
printUnrankedMemRefMetaData(StreamType & os,UnrankedMemRefType<T> & V)64 void printUnrankedMemRefMetaData(StreamType &os, UnrankedMemRefType<T> &V) {
65   os << "Unranked MemRef ";
66   printMemRefMetaData(os, DynamicMemRefType<T>(V));
67 }
68 
69 ////////////////////////////////////////////////////////////////////////////////
70 // Templated instantiation follows.
71 ////////////////////////////////////////////////////////////////////////////////
72 namespace impl {
73 template <typename T, int M, int... Dims>
74 std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v);
75 
76 template <int... Dims> struct StaticSizeMult {
77   static constexpr int value = 1;
78 };
79 
80 template <int N, int... Dims> struct StaticSizeMult<N, Dims...> {
81   static constexpr int value = N * StaticSizeMult<Dims...>::value;
82 };
83 
/// Writes `count` space characters to `os`; nothing is written when `count`
/// is zero or negative.
static inline void printSpace(std::ostream &os, int count) {
  for (int remaining = count; remaining > 0; --remaining)
    os << ' ';
}
89 
90 template <typename T, int M, int... Dims> struct VectorDataPrinter {
91   static void print(std::ostream &os, const Vector<T, M, Dims...> &val);
92 };
93 
94 template <typename T, int M, int... Dims>
95 void VectorDataPrinter<T, M, Dims...>::print(std::ostream &os,
96                                              const Vector<T, M, Dims...> &val) {
97   static_assert(M > 0, "0 dimensioned tensor");
98   static_assert(sizeof(val) == M * StaticSizeMult<Dims...>::value * sizeof(T),
99                 "Incorrect vector size!");
100   // First
101   os << "(" << val[0];
102   if (M > 1)
103     os << ", ";
104   if (sizeof...(Dims) > 1)
105     os << "\n";
106   // Kernel
107   for (unsigned i = 1; i + 1 < M; ++i) {
108     printSpace(os, 2 * sizeof...(Dims));
109     os << val[i] << ", ";
110     if (sizeof...(Dims) > 1)
111       os << "\n";
112   }
113   // Last
114   if (M > 1) {
115     printSpace(os, sizeof...(Dims));
116     os << val[M - 1];
117   }
118   os << ")";
119 }
120 
121 template <typename T, int M, int... Dims>
122 std::ostream &operator<<(std::ostream &os, const Vector<T, M, Dims...> &v) {
123   VectorDataPrinter<T, M, Dims...>::print(os, v);
124   return os;
125 }
126 
/// Recursive printer for the elements of a strided buffer.
///
/// All three functions share one parameter list:
///   os      - destination stream.
///   base    - pointer that elements are read from (as `base[offset]`).
///   dim     - number of dimensions still to be printed at this level.
///   rank    - total number of dimensions; used to compute indentation.
///   offset  - linear element index of the slice being printed.
///   sizes   - extents, starting at the current dimension.
///   strides - strides, starting at the current dimension.
template <typename T> struct MemRefDataPrinter {
  /// Prints the slice at `offset`: a single element when `dim` is 0,
  /// otherwise a bracketed list of sub-slices along the current dimension.
  static void print(std::ostream &os, T *base, int64_t dim, int64_t rank,
                    int64_t offset, const int64_t *sizes,
                    const int64_t *strides);

  /// Opens the bracket for the current dimension and prints its first
  /// sub-slice.
  static void printFirst(std::ostream &os, T *base, int64_t dim, int64_t rank,
                         int64_t offset, const int64_t *sizes,
                         const int64_t *strides);

  /// Prints the last sub-slice of the current dimension and closes the
  /// bracket.
  static void printLast(std::ostream &os, T *base, int64_t dim, int64_t rank,
                        int64_t offset, const int64_t *sizes,
                        const int64_t *strides);
};
139 
140 template <typename T>
141 void MemRefDataPrinter<T>::printFirst(std::ostream &os, T *base, int64_t dim,
142                                       int64_t rank, int64_t offset,
143                                       const int64_t *sizes,
144                                       const int64_t *strides) {
145   os << "[";
146   print(os, base, dim - 1, rank, offset, sizes + 1, strides + 1);
147   // If single element, close square bracket and return early.
148   if (sizes[0] <= 1) {
149     os << "]";
150     return;
151   }
152   os << ", ";
153   if (dim > 1)
154     os << "\n";
155 }
156 
157 template <typename T>
158 void MemRefDataPrinter<T>::print(std::ostream &os, T *base, int64_t dim,
159                                  int64_t rank, int64_t offset,
160                                  const int64_t *sizes, const int64_t *strides) {
161   if (dim == 0) {
162     os << base[offset];
163     return;
164   }
165   printFirst(os, base, dim, rank, offset, sizes, strides);
166   for (unsigned i = 1; i + 1 < sizes[0]; ++i) {
167     printSpace(os, rank - dim + 1);
168     print(os, base, dim - 1, rank, offset + i * strides[0], sizes + 1,
169           strides + 1);
170     os << ", ";
171     if (dim > 1)
172       os << "\n";
173   }
174   if (sizes[0] <= 1)
175     return;
176   printLast(os, base, dim, rank, offset, sizes, strides);
177 }
178 
179 template <typename T>
180 void MemRefDataPrinter<T>::printLast(std::ostream &os, T *base, int64_t dim,
181                                      int64_t rank, int64_t offset,
182                                      const int64_t *sizes,
183                                      const int64_t *strides) {
184   printSpace(os, rank - dim + 1);
185   print(os, base, dim - 1, rank, offset + (sizes[0] - 1) * (*strides),
186         sizes + 1, strides + 1);
187   os << "]";
188 }
189 
190 template <typename T>
191 void printMemRef(const DynamicMemRefType<T> &M) {
192   printMemRefMetaData(std::cout, M);
193   std::cout << " data = " << std::endl;
194   if (M.rank == 0)
195     std::cout << "[";
196   MemRefDataPrinter<T>::print(std::cout, M.data, M.rank, M.rank, M.offset,
197                               M.sizes, M.strides);
198   if (M.rank == 0)
199     std::cout << "]";
200   std::cout << std::endl;
201 }
202 
203 template <typename T, int N>
204 void printMemRef(StridedMemRefType<T, N> &M) {
205   std::cout << "Memref ";
206   printMemRef(DynamicMemRefType<T>(M));
207 }
208 
209 template <typename T>
210 void printMemRef(UnrankedMemRefType<T> &M) {
211   std::cout << "Unranked Memref ";
212   printMemRef(DynamicMemRefType<T>(M));
213 }
214 } // namespace impl
215 
216 ////////////////////////////////////////////////////////////////////////////////
217 // Currently exposed C API.
218 ////////////////////////////////////////////////////////////////////////////////
219 extern "C" MLIR_RUNNERUTILS_EXPORT void
220 _mlir_ciface_print_memref_i8(UnrankedMemRefType<int8_t> *M);
221 extern "C" MLIR_RUNNERUTILS_EXPORT void
222 _mlir_ciface_print_memref_f32(UnrankedMemRefType<float> *M);
223 extern "C" MLIR_RUNNERUTILS_EXPORT void
224 _mlir_ciface_print_memref_f64(UnrankedMemRefType<double> *M);
225 
226 extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_i32(int64_t rank,
227                                                          void *ptr);
228 extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_i64(int64_t rank,
229                                                          void *ptr);
230 extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_f32(int64_t rank,
231                                                          void *ptr);
232 extern "C" MLIR_RUNNERUTILS_EXPORT void print_memref_f64(int64_t rank,
233                                                          void *ptr);
234 
235 extern "C" MLIR_RUNNERUTILS_EXPORT void
236 _mlir_ciface_print_memref_0d_f32(StridedMemRefType<float, 0> *M);
237 extern "C" MLIR_RUNNERUTILS_EXPORT void
238 _mlir_ciface_print_memref_1d_f32(StridedMemRefType<float, 1> *M);
239 extern "C" MLIR_RUNNERUTILS_EXPORT void
240 _mlir_ciface_print_memref_2d_f32(StridedMemRefType<float, 2> *M);
241 extern "C" MLIR_RUNNERUTILS_EXPORT void
242 _mlir_ciface_print_memref_3d_f32(StridedMemRefType<float, 3> *M);
243 extern "C" MLIR_RUNNERUTILS_EXPORT void
244 _mlir_ciface_print_memref_4d_f32(StridedMemRefType<float, 4> *M);
245 
246 extern "C" MLIR_RUNNERUTILS_EXPORT void
247 _mlir_ciface_print_memref_vector_4x4xf32(
248     StridedMemRefType<Vector2D<4, 4, float>, 2> *M);
249 
250 #endif // EXECUTIONENGINE_RUNNERUTILS_H_
251