1 /*
2  * Copyright 2012      Ecole Normale Superieure
3  *
4  * Use of this software is governed by the MIT license
5  *
6  * Written by Sven Verdoolaege,
7  * Ecole Normale Superieure, 45 rue d’Ulm, 75230 Paris, France
8  */
9 
10 #include <string.h>
11 
12 #include <isl/aff.h>
13 
14 #include "gpu_print.h"
15 #include "print.h"
16 #include "schedule.h"
17 
18 /* Print declarations to "p" for arrays that are local to "prog"
19  * but that are used on the host and therefore require a declaration.
20  */
gpu_print_local_declarations(__isl_take isl_printer * p,struct gpu_prog * prog)21 __isl_give isl_printer *gpu_print_local_declarations(__isl_take isl_printer *p,
22 	struct gpu_prog *prog)
23 {
24 	int i;
25 
26 	if (!prog)
27 		return isl_printer_free(p);
28 
29 	for (i = 0; i < prog->n_array; ++i) {
30 		struct gpu_array_info *array = &prog->array[i];
31 		isl_ast_expr *size;
32 
33 		if (!array->declare_local)
34 			continue;
35 		size = array->declared_size;
36 		p = ppcg_print_declaration_with_size(p, array->type, size);
37 	}
38 
39 	return p;
40 }
41 
42 /* Print an expression for the size of "array" in bytes.
43  */
gpu_array_info_print_size(__isl_take isl_printer * prn,struct gpu_array_info * array)44 __isl_give isl_printer *gpu_array_info_print_size(__isl_take isl_printer *prn,
45 	struct gpu_array_info *array)
46 {
47 	int i;
48 
49 	for (i = 0; i < array->n_index; ++i) {
50 		isl_ast_expr *bound;
51 
52 		prn = isl_printer_print_str(prn, "(");
53 		bound = isl_ast_expr_get_op_arg(array->bound_expr, 1 + i);
54 		prn = isl_printer_print_ast_expr(prn, bound);
55 		isl_ast_expr_free(bound);
56 		prn = isl_printer_print_str(prn, ") * ");
57 	}
58 	prn = isl_printer_print_str(prn, "sizeof(");
59 	prn = isl_printer_print_str(prn, array->type);
60 	prn = isl_printer_print_str(prn, ")");
61 
62 	return prn;
63 }
64 
65 /* Print the declaration of a non-linearized array argument.
66  */
print_non_linearized_declaration_argument(__isl_take isl_printer * p,struct gpu_array_info * array)67 static __isl_give isl_printer *print_non_linearized_declaration_argument(
68 	__isl_take isl_printer *p, struct gpu_array_info *array)
69 {
70 	p = isl_printer_print_str(p, array->type);
71 	p = isl_printer_print_str(p, " ");
72 
73 	p = isl_printer_print_ast_expr(p, array->bound_expr);
74 
75 	return p;
76 }
77 
78 /* Print the declaration of an array argument.
79  * "memory_space" allows to specify a memory space prefix.
80  */
gpu_array_info_print_declaration_argument(__isl_take isl_printer * p,struct gpu_array_info * array,const char * memory_space)81 __isl_give isl_printer *gpu_array_info_print_declaration_argument(
82 	__isl_take isl_printer *p, struct gpu_array_info *array,
83 	const char *memory_space)
84 {
85 	if (gpu_array_is_read_only_scalar(array)) {
86 		p = isl_printer_print_str(p, array->type);
87 		p = isl_printer_print_str(p, " ");
88 		p = isl_printer_print_str(p, array->name);
89 		return p;
90 	}
91 
92 	if (memory_space) {
93 		p = isl_printer_print_str(p, memory_space);
94 		p = isl_printer_print_str(p, " ");
95 	}
96 
97 	if (array->n_index != 0 && !array->linearize)
98 		return print_non_linearized_declaration_argument(p, array);
99 
100 	p = isl_printer_print_str(p, array->type);
101 	p = isl_printer_print_str(p, " ");
102 	p = isl_printer_print_str(p, "*");
103 	p = isl_printer_print_str(p, array->name);
104 
105 	return p;
106 }
107 
108 /* Print the call of an array argument.
109  */
gpu_array_info_print_call_argument(__isl_take isl_printer * p,struct gpu_array_info * array)110 __isl_give isl_printer *gpu_array_info_print_call_argument(
111 	__isl_take isl_printer *p, struct gpu_array_info *array)
112 {
113 	if (gpu_array_is_read_only_scalar(array))
114 		return isl_printer_print_str(p, array->name);
115 
116 	p = isl_printer_print_str(p, "dev_");
117 	p = isl_printer_print_str(p, array->name);
118 
119 	return p;
120 }
121 
122 /* Print an access to the element in the private/shared memory copy
123  * described by "stmt".  The index of the copy is recorded in
124  * stmt->local_index as an access to the array.
125  */
stmt_print_local_index(__isl_take isl_printer * p,struct ppcg_kernel_stmt * stmt)126 static __isl_give isl_printer *stmt_print_local_index(__isl_take isl_printer *p,
127 	struct ppcg_kernel_stmt *stmt)
128 {
129 	return isl_printer_print_ast_expr(p, stmt->u.c.local_index);
130 }
131 
132 /* Print an access to the element in the global memory copy
133  * described by "stmt".  The index of the copy is recorded in
134  * stmt->index as an access to the array.
135  */
stmt_print_global_index(__isl_take isl_printer * p,struct ppcg_kernel_stmt * stmt)136 static __isl_give isl_printer *stmt_print_global_index(
137 	__isl_take isl_printer *p, struct ppcg_kernel_stmt *stmt)
138 {
139 	struct gpu_array_info *array = stmt->u.c.array;
140 	isl_ast_expr *index;
141 
142 	if (gpu_array_is_scalar(array)) {
143 		if (!gpu_array_is_read_only_scalar(array))
144 			p = isl_printer_print_str(p, "*");
145 		p = isl_printer_print_str(p, array->name);
146 		return p;
147 	}
148 
149 	index = isl_ast_expr_copy(stmt->u.c.index);
150 
151 	p = isl_printer_print_ast_expr(p, index);
152 	isl_ast_expr_free(index);
153 
154 	return p;
155 }
156 
157 /* Print a copy statement.
158  *
159  * A read copy statement is printed as
160  *
161  *	local = global;
162  *
163  * while a write copy statement is printed as
164  *
165  *	global = local;
166  */
ppcg_kernel_print_copy(__isl_take isl_printer * p,struct ppcg_kernel_stmt * stmt)167 __isl_give isl_printer *ppcg_kernel_print_copy(__isl_take isl_printer *p,
168 	struct ppcg_kernel_stmt *stmt)
169 {
170 	p = isl_printer_start_line(p);
171 	if (stmt->u.c.read) {
172 		p = stmt_print_local_index(p, stmt);
173 		p = isl_printer_print_str(p, " = ");
174 		p = stmt_print_global_index(p, stmt);
175 	} else {
176 		p = stmt_print_global_index(p, stmt);
177 		p = isl_printer_print_str(p, " = ");
178 		p = stmt_print_local_index(p, stmt);
179 	}
180 	p = isl_printer_print_str(p, ";");
181 	p = isl_printer_end_line(p);
182 
183 	return p;
184 }
185 
ppcg_kernel_print_domain(__isl_take isl_printer * p,struct ppcg_kernel_stmt * stmt)186 __isl_give isl_printer *ppcg_kernel_print_domain(__isl_take isl_printer *p,
187 	struct ppcg_kernel_stmt *stmt)
188 {
189 	return pet_stmt_print_body(stmt->u.d.stmt->stmt, p, stmt->u.d.ref2expr);
190 }
191 
192 /* This function is called for each node in a GPU AST.
193  * In case of a user node, print the macro definitions required
194  * for printing the AST expressions in the annotation, if any.
195  * For other nodes, return true such that descendants are also
196  * visited.
197  *
198  * In particular, for a kernel launch, print the macro definitions
199  * needed for the grid size.
200  * For a copy statement, print the macro definitions needed
201  * for the two index expressions.
202  * For an original user statement, print the macro definitions
203  * needed for the substitutions.
204  */
at_node(__isl_keep isl_ast_node * node,void * user)205 static isl_bool at_node(__isl_keep isl_ast_node *node, void *user)
206 {
207 	const char *name;
208 	isl_id *id;
209 	int is_kernel;
210 	struct ppcg_kernel *kernel;
211 	struct ppcg_kernel_stmt *stmt;
212 	isl_printer **p = user;
213 
214 	if (isl_ast_node_get_type(node) != isl_ast_node_user)
215 		return isl_bool_true;
216 
217 	id = isl_ast_node_get_annotation(node);
218 	if (!id)
219 		return isl_bool_false;
220 
221 	name = isl_id_get_name(id);
222 	if (!name)
223 		return isl_bool_error;
224 	is_kernel = !strcmp(name, "kernel");
225 	kernel = is_kernel ? isl_id_get_user(id) : NULL;
226 	stmt = is_kernel ? NULL : isl_id_get_user(id);
227 	isl_id_free(id);
228 
229 	if ((is_kernel && !kernel) || (!is_kernel && !stmt))
230 		return isl_bool_error;
231 
232 	if (is_kernel) {
233 		*p = ppcg_ast_expr_print_macros(kernel->grid_size_expr, *p);
234 	} else if (stmt->type == ppcg_kernel_copy) {
235 		*p = ppcg_ast_expr_print_macros(stmt->u.c.index, *p);
236 		*p = ppcg_ast_expr_print_macros(stmt->u.c.local_index, *p);
237 	} else if (stmt->type == ppcg_kernel_domain) {
238 		*p = ppcg_print_body_macros(*p, stmt->u.d.ref2expr);
239 	}
240 	if (!*p)
241 		return isl_bool_error;
242 
243 	return isl_bool_false;
244 }
245 
246 /* Print the required macros for the GPU AST "node" to "p",
247  * including those needed for the user statements inside the AST.
248  */
gpu_print_macros(__isl_take isl_printer * p,__isl_keep isl_ast_node * node)249 __isl_give isl_printer *gpu_print_macros(__isl_take isl_printer *p,
250 	__isl_keep isl_ast_node *node)
251 {
252 	if (isl_ast_node_foreach_descendant_top_down(node, &at_node, &p) < 0)
253 		return isl_printer_free(p);
254 	p = ppcg_print_macros(p, node);
255 	return p;
256 }
257 
258 /* Was the definition of "type" printed before?
259  * That is, does its name appear in the list of printed types "types"?
260  */
already_printed(struct gpu_types * types,struct pet_type * type)261 static int already_printed(struct gpu_types *types,
262 	struct pet_type *type)
263 {
264 	int i;
265 
266 	for (i = 0; i < types->n; ++i)
267 		if (!strcmp(types->name[i], type->name))
268 			return 1;
269 
270 	return 0;
271 }
272 
273 /* Print the definitions of all types prog->scop that have not been
274  * printed before (according to "types") on "p".
275  * Extend the list of printed types "types" with the newly printed types.
276  */
gpu_print_types(__isl_take isl_printer * p,struct gpu_types * types,struct gpu_prog * prog)277 __isl_give isl_printer *gpu_print_types(__isl_take isl_printer *p,
278 	struct gpu_types *types, struct gpu_prog *prog)
279 {
280 	int i, n;
281 	isl_ctx *ctx;
282 	char **name;
283 
284 	n = prog->scop->pet->n_type;
285 
286 	if (n == 0)
287 		return p;
288 
289 	ctx = isl_printer_get_ctx(p);
290 	name = isl_realloc_array(ctx, types->name, char *, types->n + n);
291 	if (!name)
292 		return isl_printer_free(p);
293 	types->name = name;
294 
295 	for (i = 0; i < n; ++i) {
296 		struct pet_type *type = prog->scop->pet->types[i];
297 
298 		if (already_printed(types, type))
299 			continue;
300 
301 		p = isl_printer_start_line(p);
302 		p = isl_printer_print_str(p, type->definition);
303 		p = isl_printer_print_str(p, ";");
304 		p = isl_printer_end_line(p);
305 
306 		types->name[types->n++] = strdup(type->name);
307 	}
308 
309 	return p;
310 }
311