1 /*
2  * Copyright 2008 Google Inc.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include <setjmp.h>
17 #ifndef _WIN32
18 #include <signal.h>
19 #endif // !_WIN32
20 #include <stdarg.h>
21 #include <stddef.h>
22 #include <stdio.h>
23 #include <stdlib.h>
24 #include <string.h>
25 #ifdef _WIN32
26 #include <windows.h>
27 #endif // _WIN32
28 #include <cmockery.h>
29 
#ifdef _WIN32
// Map the standard vsnprintf() name onto the _vsnprintf() spelling used on
// Windows (presumably because older MSVC runtimes lack vsnprintf()).
#define vsnprintf _vsnprintf
#endif // _WIN32

// Size, in bytes, of the guard regions placed around dynamically allocated
// blocks.
#define MALLOC_GUARD_SIZE 16
// Byte pattern used to initialize guard regions.
#define MALLOC_GUARD_PATTERN 0xEF
// Byte pattern used to initialize memory allocated with test_malloc().
#define MALLOC_ALLOC_PATTERN 0xBA
// Byte pattern for released memory.  NOTE(review): presumably written over
// blocks freed by the test allocator, which is outside this chunk — confirm.
#define MALLOC_FREE_PATTERN 0xCD
// Alignment of allocated blocks.  NOTE: This must be a power of two.
#define MALLOC_ALIGNMENT sizeof(size_t)

// printf()-style format for source code locations ("file:line"); expects a
// const char* followed by an int.
#define SOURCE_LOCATION_FORMAT "%s:%d"

// Calculates the number of elements in an array.  Only valid for true array
// objects, not for pointers (e.g. array parameters that have decayed).
#define ARRAY_LENGTH(x) (sizeof(x) / sizeof((x)[0]))
49 
/* Node of a circular, doubly linked list.  A list is represented by a head
 * node; an empty list is a head whose next and prev point back at itself
 * (see list_initialize() and list_empty()). */
typedef struct ListNode {
	const void *value;      // Data referenced by the node (NULL for a head).
	int refcount;           // Remaining retrievals via get_symbol_value();
	                        // the node is removed when this reaches 0, so a
	                        // negative count never expires.  Heads use 1.
	struct ListNode *next;  // Next node (wraps around to the head).
	struct ListNode *prev;  // Previous node (wraps around to the head).
} ListNode;
57 
// Debug information tracked for each dynamically allocated block.
// NOTE(review): the allocator that fills this in lies outside this chunk.
typedef struct MallocBlockInfo {
	void* block;              // Address of the block returned by malloc().
	size_t allocated_size;    // Total size of the allocated block.
	size_t size;              // Requested block size.
	SourceLocation location;  // Where the block was allocated.
	ListNode node;            // Node within list of all allocated blocks.
} MallocBlockInfo;
66 
// Per-test bookkeeping.  NOTE(review): how check_point is used is not visible
// in this chunk; presumably it is compared against the allocated-blocks list
// to detect leaks introduced by a setup function — confirm.
typedef struct TestState {
	const ListNode *check_point; // Check point of the test if there's a
	                             // setup function.
	void *state;                 // State associated with the test.
} TestState;
73 
// Determines whether two values are the same; returns non-zero when they are.
// Used by list_find() to compare a node's value with a search key.
typedef int (*EqualityFunction)(const void *left, const void *right);
76 
// A value queued for a symbol (e.g. a mock return value) together with the
// source location where it was declared.  location must stay the first member
// because leftover-value reporting reads it through a SourceLocation pointer.
typedef struct SymbolValue {
	SourceLocation location;
	const void* value;
} SymbolValue;
82 
/* Entry of a symbol map: associates symbol_name with a list of values.
 * For multi-level maps (e.g. function name -> parameter name) the nodes of
 * symbol_values_list_head reference nested SymbolMapValue entries; at the
 * last level they reference the queued values themselves.
 * symbol_name is referenced, not copied, so it must outlive the entry.
 * NOTE: Each value structure stored at the last level must have a
 * SourceLocation as its first member.
 */
typedef struct SymbolMapValue {
	const char *symbol_name;
	ListNode symbol_values_list_head;
} SymbolMapValue;
91 
// Callback used by list_remove() / list_free() to release the value referenced
// by a node; cleanup_value_data is passed through unchanged on every call.
typedef void (*CleanupListValue)(const void *value, void *cleanup_value_data);
94 
// Structure used to check the range of integer types.
// NOTE(review): event is embedded first, presumably so the structure can be
// passed wherever a CheckParameterEvent* is expected — confirm.
typedef struct CheckIntegerRange {
	CheckParameterEvent event;
	int minimum;  // Lower bound of the range.
	int maximum;  // Upper bound of the range.
} CheckIntegerRange;
101 
// Structure used to check whether an integer value is in a set.  Set members
// are stored as pointers and compared by identity in
// value_in_set_display_error().
// NOTE(review): event is embedded first, presumably so the structure can be
// passed wherever a CheckParameterEvent* is expected — confirm.
typedef struct CheckIntegerSet {
	CheckParameterEvent event;
	const void **set;    // Array of set members.
	size_t size_of_set;  // Number of elements in set.
} CheckIntegerSet;
108 
/* Used to check whether a parameter matches the area of memory referenced by
 * this structure.
 * NOTE(review): event is embedded first, presumably so the structure can be
 * passed wherever a CheckParameterEvent* is expected — confirm. */
typedef struct CheckMemoryData {
	CheckParameterEvent event;
	const void *memory;  // Expected contents.
	size_t size;         // Number of bytes of memory to compare.
} CheckMemoryData;
116 
// Circular doubly linked list primitives (a list is identified by its head).
static ListNode* list_initialize(ListNode * const node);
static ListNode* list_add(ListNode * const head, ListNode *new_node);
static ListNode* list_add_value(ListNode * const head, const void *value,
                                     const int count);
static ListNode* list_remove(
    ListNode * const node, const CleanupListValue cleanup_value,
    void * const cleanup_value_data);
static void list_remove_free(
    ListNode * const node, const CleanupListValue cleanup_value,
    void * const cleanup_value_data);
static int list_empty(const ListNode * const head);
static int list_find(
    ListNode * const head, const void *value,
    const EqualityFunction equal_func, ListNode **output);
static int list_first(ListNode * const head, ListNode **output);
static ListNode* list_free(
    ListNode * const head, const CleanupListValue cleanup_value,
    void * const cleanup_value_data);

// Symbol map helpers.  A symbol map is a list of SymbolMapValue entries,
// nested one level per symbol name (number_of_symbol_names levels deep).
static void add_symbol_value(
    ListNode * const symbol_map_head, const char * const symbol_names[],
    const size_t number_of_symbol_names, const void* value, const int count);
static int get_symbol_value(
    ListNode * const symbol_map_head, const char * const symbol_names[],
    const size_t number_of_symbol_names, void **output);
static void free_value(const void *value, void *cleanup_value_data);
static void free_symbol_map_value(
    const void *value, void *cleanup_value_data);
static void remove_always_return_values(ListNode * const map_head,
                                        const size_t number_of_symbol_names);
static int check_for_leftover_values(
    const ListNode * const map_head, const char * const error_message,
    const size_t number_of_symbol_names);
// This must be called at the beginning of a test to reset the mock return
// value and expected parameter maps.
static void initialize_testing(const char *test_name);
// This must be called at the end of a test to free() allocated structures.
static void teardown_testing(const char *test_name);
155 
156 
// Keeps track of the calling context saved by setjmp() so that the fail()
// method can jump out of a test.
static jmp_buf global_run_test_env;
// Non-zero while a test is executing; exit_test() only longjmp()s back to
// global_run_test_env when this is set.
static int global_running_test = 0;

// Keeps track of the calling context saved by setjmp() so that
// mock_assert() can optionally jump back to expect_assert_failure().
jmp_buf global_expect_assert_env;
// NOTE(review): presumably non-zero while expect_assert_failure() is active;
// it is not written in this chunk.
int global_expecting_assert = 0;

// Keeps a map of the values that functions will have to return to provide
// mocked interfaces (one level: function name -> queued SymbolValues).
static ListNode global_function_result_map_head;
// Location where the last mock value returned by _mock() was declared.
static SourceLocation global_last_mock_value_location;

/* Keeps a map of the values that functions expect as parameters to their
 * mocked interfaces (two levels: function name -> parameter name -> checks). */
static ListNode global_function_parameter_map_head;
// Location where the last checked parameter value was declared.
static SourceLocation global_last_parameter_location;

// List of all currently allocated blocks.
static ListNode global_allocated_blocks;
181 
#ifndef _WIN32
// Signals caught by exception_handler().
static const int exception_signals[] = {
	SIGFPE,
	SIGILL,
	SIGSEGV,
	SIGBUS,
	SIGSYS,
};

// Default signal functions that should be restored after a test is complete;
// one slot per entry of exception_signals, in the same order.
typedef void (*SignalFunction)(int signal);
static SignalFunction default_signal_functions[
    ARRAY_LENGTH(exception_signals)];

#else // _WIN32

// The default exception filter.  NOTE(review): presumably saved when this
// library installs its own filter and restored afterwards — the code doing
// so is outside this chunk.
static LPTOP_LEVEL_EXCEPTION_FILTER previous_exception_filter;

// Fatal exceptions: a structured exception code and its printable name.
typedef struct ExceptionCodeInfo {
	DWORD code;
	const char* description;
} ExceptionCodeInfo;

// Expands to an ExceptionCodeInfo initializer pairing a code with the
// stringified name of the macro that defines it.
#define EXCEPTION_CODE_INFO(exception_code) {exception_code, #exception_code}

// Exception codes treated as fatal.
static const ExceptionCodeInfo exception_codes[] = {
	EXCEPTION_CODE_INFO(EXCEPTION_ACCESS_VIOLATION),
	EXCEPTION_CODE_INFO(EXCEPTION_ARRAY_BOUNDS_EXCEEDED),
	EXCEPTION_CODE_INFO(EXCEPTION_DATATYPE_MISALIGNMENT),
	EXCEPTION_CODE_INFO(EXCEPTION_FLT_DENORMAL_OPERAND),
	EXCEPTION_CODE_INFO(EXCEPTION_FLT_DIVIDE_BY_ZERO),
	EXCEPTION_CODE_INFO(EXCEPTION_FLT_INEXACT_RESULT),
	EXCEPTION_CODE_INFO(EXCEPTION_FLT_INVALID_OPERATION),
	EXCEPTION_CODE_INFO(EXCEPTION_FLT_OVERFLOW),
	EXCEPTION_CODE_INFO(EXCEPTION_FLT_STACK_CHECK),
	EXCEPTION_CODE_INFO(EXCEPTION_FLT_UNDERFLOW),
	EXCEPTION_CODE_INFO(EXCEPTION_GUARD_PAGE),
	EXCEPTION_CODE_INFO(EXCEPTION_ILLEGAL_INSTRUCTION),
	EXCEPTION_CODE_INFO(EXCEPTION_INT_DIVIDE_BY_ZERO),
	EXCEPTION_CODE_INFO(EXCEPTION_INT_OVERFLOW),
	EXCEPTION_CODE_INFO(EXCEPTION_INVALID_DISPOSITION),
	EXCEPTION_CODE_INFO(EXCEPTION_INVALID_HANDLE),
	EXCEPTION_CODE_INFO(EXCEPTION_IN_PAGE_ERROR),
	EXCEPTION_CODE_INFO(EXCEPTION_NONCONTINUABLE_EXCEPTION),
	EXCEPTION_CODE_INFO(EXCEPTION_PRIV_INSTRUCTION),
	EXCEPTION_CODE_INFO(EXCEPTION_STACK_OVERFLOW),
};
#endif // !_WIN32
233 
234 
235 // Exit the currently executing test.
exit_test(const int quit_application)236 static void exit_test(const int quit_application) {
237 	if (global_running_test) {
238 		longjmp(global_run_test_env, 1);
239 	} else if (quit_application) {
240 		exit(-1);
241 	}
242 }
243 
244 
245 // Initialize a SourceLocation structure.
initialize_source_location(SourceLocation * const location)246 static void initialize_source_location(SourceLocation * const location) {
247 	assert_true(location);
248 	location->file = NULL;
249 	location->line = 0;
250 }
251 
252 
253 // Determine whether a source location is currently set.
source_location_is_set(const SourceLocation * const location)254 static int source_location_is_set(const SourceLocation * const location) {
255 	assert_true(location);
256 	return location->file && location->line;
257 }
258 
259 
260 // Set a source location.
set_source_location(SourceLocation * const location,const char * const file,const int line)261 static void set_source_location(
262     SourceLocation * const location, const char * const file,
263     const int line) {
264 	assert_true(location);
265 	location->file = file;
266 	location->line = line;
267 }
268 
269 
270 // Create function results and expected parameter lists.
initialize_testing(const char * test_name)271 void initialize_testing(const char *test_name) {
272 	list_initialize(&global_function_result_map_head);
273 	initialize_source_location(&global_last_mock_value_location);
274 	list_initialize(&global_function_parameter_map_head);
275 	initialize_source_location(&global_last_parameter_location);
276 }
277 
278 
fail_if_leftover_values(const char * test_name)279 void fail_if_leftover_values(const char *test_name) {
280 	int error_occurred = 0;
281 	remove_always_return_values(&global_function_result_map_head, 1);
282 	if (check_for_leftover_values(
283 	        &global_function_result_map_head,
284 	        "%s() has remaining non-returned values.\n", 1)) {
285 		error_occurred = 1;
286 	}
287 
288 	remove_always_return_values(&global_function_parameter_map_head, 2);
289 	if (check_for_leftover_values(
290 	        &global_function_parameter_map_head,
291 	        "%s parameter still has values that haven't been checked.\n", 2)) {
292 		error_occurred = 1;
293 	}
294 	if (error_occurred) {
295 		exit_test(1);
296 	}
297 }
298 
299 
teardown_testing(const char * test_name)300 void teardown_testing(const char *test_name) {
301 	list_free(&global_function_result_map_head, free_symbol_map_value,
302 	          (void*)0);
303 	initialize_source_location(&global_last_mock_value_location);
304 	list_free(&global_function_parameter_map_head, free_symbol_map_value,
305 	          (void*)1);
306 	initialize_source_location(&global_last_parameter_location);
307 }
308 
309 // Initialize a list node.
list_initialize(ListNode * const node)310 static ListNode* list_initialize(ListNode * const node) {
311 	node->value = NULL;
312 	node->next = node;
313 	node->prev = node;
314 	node->refcount = 1;
315 	return node;
316 }
317 
318 
319 /* Adds a value at the tail of a given list.
320  * The node referencing the value is allocated from the heap. */
list_add_value(ListNode * const head,const void * value,const int refcount)321 static ListNode* list_add_value(ListNode * const head, const void *value,
322                                      const int refcount) {
323 	ListNode * const new_node = (ListNode*)malloc(sizeof(ListNode));
324 	assert_true(head);
325 	assert_true(value);
326 	new_node->value = value;
327 	new_node->refcount = refcount;
328 	return list_add(head, new_node);
329 }
330 
331 
332 // Add new_node to the end of the list.
list_add(ListNode * const head,ListNode * new_node)333 static ListNode* list_add(ListNode * const head, ListNode *new_node) {
334 	assert_true(head);
335 	assert_true(new_node);
336 	new_node->next = head;
337 	new_node->prev = head->prev;
338 	head->prev->next = new_node;
339 	head->prev = new_node;
340 	return new_node;
341 }
342 
343 
344 // Remove a node from a list.
list_remove(ListNode * const node,const CleanupListValue cleanup_value,void * const cleanup_value_data)345 static ListNode* list_remove(
346         ListNode * const node, const CleanupListValue cleanup_value,
347         void * const cleanup_value_data) {
348 	assert_true(node);
349 	node->prev->next = node->next;
350 	node->next->prev = node->prev;
351 	if (cleanup_value) {
352 		cleanup_value(node->value, cleanup_value_data);
353 	}
354 	return node;
355 }
356 
357 
358 /* Remove a list node from a list and free the node. */
list_remove_free(ListNode * const node,const CleanupListValue cleanup_value,void * const cleanup_value_data)359 static void list_remove_free(
360         ListNode * const node, const CleanupListValue cleanup_value,
361         void * const cleanup_value_data) {
362 	assert_true(node);
363 	free(list_remove(node, cleanup_value, cleanup_value_data));
364 }
365 
366 
367 /* Frees memory kept by a linked list
368  * The cleanup_value function is called for every "value" field of nodes in the
369  * list, except for the head.  In addition to each list value,
370  * cleanup_value_data is passed to each call to cleanup_value.  The head
371  * of the list is not deallocated.
372  */
list_free(ListNode * const head,const CleanupListValue cleanup_value,void * const cleanup_value_data)373 static ListNode* list_free(
374         ListNode * const head, const CleanupListValue cleanup_value,
375         void * const cleanup_value_data) {
376 	assert_true(head);
377 	while (!list_empty(head)) {
378 		list_remove_free(head->next, cleanup_value, cleanup_value_data);
379 	}
380 	return head;
381 }
382 
383 
384 // Determine whether a list is empty.
list_empty(const ListNode * const head)385 static int list_empty(const ListNode * const head) {
386 	assert_true(head);
387 	return head->next == head;
388 }
389 
390 
391 /* Find a value in the list using the equal_func to compare each node with the
392  * value.
393  */
list_find(ListNode * const head,const void * value,const EqualityFunction equal_func,ListNode ** output)394 static int list_find(ListNode * const head, const void *value,
395                      const EqualityFunction equal_func, ListNode **output) {
396 	ListNode *current;
397 	assert_true(head);
398 	for (current = head->next; current != head; current = current->next) {
399 		if (equal_func(current->value, value)) {
400 			*output = current;
401 			return 1;
402 		}
403 	}
404 	return 0;
405 }
406 
407 // Returns the first node of a list
list_first(ListNode * const head,ListNode ** output)408 static int list_first(ListNode * const head, ListNode **output) {
409 	ListNode *target_node;
410 	assert_true(head);
411 	if (list_empty(head)) {
412 		return 0;
413 	}
414 	target_node = head->next;
415 	*output = target_node;
416 	return 1;
417 }
418 
419 
/* CleanupListValue that releases a heap-allocated list value with free().
 * cleanup_value_data is ignored. */
static void free_value(const void *value, void *cleanup_value_data) {
	void * const block = (void*)value;
	(void)cleanup_value_data;
	assert_true(value);
	free(block);
}
425 
426 
427 // Releases memory associated to a symbol_map_value.
free_symbol_map_value(const void * value,void * cleanup_value_data)428 static void free_symbol_map_value(const void *value,
429                                   void *cleanup_value_data) {
430 	SymbolMapValue * const map_value = (SymbolMapValue*)value;
431 	const unsigned int children = (unsigned int)cleanup_value_data;
432 	assert_true(value);
433 	list_free(&map_value->symbol_values_list_head,
434 	          children ? free_symbol_map_value : free_value,
435 	          (void*)(children - 1));
436 	free(map_value);
437 }
438 
439 
440 /* Determine whether a symbol name referenced by a symbol_map_value
441  * matches the specified function name. */
symbol_names_match(const void * map_value,const void * symbol)442 static int symbol_names_match(const void *map_value, const void *symbol) {
443 	return !strcmp(((SymbolMapValue*)map_value)->symbol_name,
444                    (const char*)symbol);
445 }
446 
447 
448 /* Adds a value to the queue of values associated with the given
449  * hierarchy of symbols.  It's assumed value is allocated from the heap.
450  */
add_symbol_value(ListNode * const symbol_map_head,const char * const symbol_names[],const size_t number_of_symbol_names,const void * value,const int refcount)451 static void add_symbol_value(ListNode * const symbol_map_head,
452                              const char * const symbol_names[],
453                              const size_t number_of_symbol_names,
454                              const void* value, const int refcount) {
455 	const char* symbol_name;
456 	ListNode *target_node;
457 	SymbolMapValue *target_map_value;
458 	assert_true(symbol_map_head);
459 	assert_true(symbol_names);
460 	assert_true(number_of_symbol_names);
461 	symbol_name = symbol_names[0];
462 
463 	if (!list_find(symbol_map_head, symbol_name, symbol_names_match,
464 	               &target_node)) {
465 		SymbolMapValue * const new_symbol_map_value =
466 		    malloc(sizeof(*new_symbol_map_value));
467 		new_symbol_map_value->symbol_name = symbol_name;
468 		list_initialize(&new_symbol_map_value->symbol_values_list_head);
469 		target_node = list_add_value(symbol_map_head, new_symbol_map_value,
470 		                                  1);
471 	}
472 
473 	target_map_value = (SymbolMapValue*)target_node->value;
474 	if (number_of_symbol_names == 1) {
475 			list_add_value(&target_map_value->symbol_values_list_head,
476 			                    value, refcount);
477 	} else {
478 		add_symbol_value(&target_map_value->symbol_values_list_head,
479 		                 &symbol_names[1], number_of_symbol_names - 1, value,
480 		                 refcount);
481 	}
482 }
483 
484 
/* Gets the next value associated with the given hierarchy of symbols.
 * symbol_names[0] selects an entry in the map headed by head; each further
 * name selects an entry one level deeper.  At the deepest level the first
 * queued value is stored in *output and its node's refcount is decremented.
 * The node is unlinked and freed once the count reaches zero; the value
 * itself is NOT freed, because ownership passes to the caller.  Map entries
 * whose lists become empty are removed as well.
 *
 * Returns the node's refcount before the decrement if a value is found, and
 * 0 otherwise (after printing an error).  A return value of 1 therefore
 * means the node was just removed from the list.  A negative return value
 * means the value never expires and stays queued.
 */
static int get_symbol_value(
        ListNode * const head, const char * const symbol_names[],
        const size_t number_of_symbol_names, void **output) {
	const char* symbol_name;
	ListNode *target_node;
	assert_true(head);
	assert_true(symbol_names);
	assert_true(number_of_symbol_names);
	assert_true(output);
	symbol_name = symbol_names[0];

	if (list_find(head, symbol_name, symbol_names_match, &target_node)) {
		SymbolMapValue *map_value;
		ListNode *child_list;
		int return_value = 0;
		assert_true(target_node);
		assert_true(target_node->value);

		map_value = (SymbolMapValue*)target_node->value;
		child_list = &map_value->symbol_values_list_head;

		if (number_of_symbol_names == 1) {
			ListNode *value_node = NULL;
			// Entries are removed below as soon as their list empties, so an
			// existing entry is expected to hold at least one value.
			return_value = list_first(child_list, &value_node);
			assert_true(return_value);
			*output = (void*) value_node->value;
			return_value = value_node->refcount;
			// A negative refcount never reaches zero, so such values stay
			// queued ("always return").  The value is not cleaned up here.
			if (--value_node->refcount == 0) {
				list_remove_free(value_node, NULL, NULL);
			}
		} else {
			return_value = get_symbol_value(
			    child_list, &symbol_names[1], number_of_symbol_names - 1,
			    output);
		}
		// Drop this symbol's entry once nothing is queued beneath it.
		if (list_empty(child_list)) {
			list_remove_free(target_node, free_symbol_map_value, (void*)0);
		}
		return return_value;
	} else {
		print_error("No entries for symbol %s.\n", symbol_name);
	}
	return 0;
}
535 
536 
/* Traverse down a tree of symbol values and remove the first symbol value
 * in each branch that has a refcount < -1 (i.e. should always be returned
 * and has been returned at least once).  Map entries whose value lists end
 * up empty are removed as well.
 */
static void remove_always_return_values(ListNode * const map_head,
                                        const size_t number_of_symbol_names) {
	ListNode *current;
	assert_true(map_head);
	assert_true(number_of_symbol_names);
	current = map_head->next;
	while (current != map_head) {
		SymbolMapValue * const value = (SymbolMapValue*)current->value;
		// Saved up front because current may be freed below.
		ListNode * const next = current->next;
		ListNode *child_list;
		assert_true(value);
		child_list = &value->symbol_values_list_head;

		if (!list_empty(child_list)) {
			if (number_of_symbol_names == 1) {
				ListNode * const child_node = child_list->next;
				// An "always return" value starts at refcount -1 and is
				// decremented on each retrieval, so < -1 means it has been
				// returned at least once: it is not a leftover, free it.
				if (child_node->refcount < -1) {
					list_remove_free(child_node, free_value, NULL);
				}
			} else {
				remove_always_return_values(child_list,
				                            number_of_symbol_names - 1);
			}
		}

		// Nothing left under this symbol.  The entry's list is empty, so
		// freeing the SymbolMapValue with free_value() releases everything.
		if (list_empty(child_list)) {
			list_remove_free(current, free_value, NULL);
		}
		current = next;
	}
}
573 
574 /* Checks if there are any leftover values set up by the test that were never
575  * retrieved through execution, and fail the test if that is the case.
576  */
check_for_leftover_values(const ListNode * const map_head,const char * const error_message,const size_t number_of_symbol_names)577 static int check_for_leftover_values(
578         const ListNode * const map_head, const char * const error_message,
579         const size_t number_of_symbol_names) {
580 	const ListNode *current;
581 	int symbols_with_leftover_values = 0;
582 	assert_true(map_head);
583 	assert_true(number_of_symbol_names);
584 
585 	for (current = map_head->next; current != map_head;
586 	     current = current->next) {
587 		const SymbolMapValue * const value =
588 		    (SymbolMapValue*)current->value;
589 		const ListNode *child_list;
590 		assert_true(value);
591 		child_list = &value->symbol_values_list_head;
592 
593 		if (!list_empty(child_list)) {
594 			if (number_of_symbol_names == 1) {
595 				const ListNode *child_node;
596 				print_error(error_message, value->symbol_name);
597 				print_error("  Remaining item(s) declared at...\n");
598 
599 				for (child_node = child_list->next; child_node != child_list;
600 				     child_node = child_node->next) {
601 					const SourceLocation * const location = child_node->value;
602 					print_error("    " SOURCE_LOCATION_FORMAT "\n",
603 					            location->file, location->line);
604 				}
605 			} else {
606 				print_error("%s.", value->symbol_name);
607 				check_for_leftover_values(child_list, error_message,
608 				                          number_of_symbol_names - 1);
609 			}
610 			symbols_with_leftover_values ++;
611 		}
612 	}
613 	return symbols_with_leftover_values;
614 }
615 
616 
617 // Get the next return value for the specified mock function.
_mock(const char * const function,const char * const file,const int line)618 void* _mock(const char * const function, const char* const file,
619             const int line) {
620 	void *result;
621 	const int rc = get_symbol_value(&global_function_result_map_head,
622 	                                &function, 1, &result);
623 	if (rc) {
624 		SymbolValue * const symbol = result;
625 		void * const value = (void*)symbol->value;
626 		global_last_mock_value_location = symbol->location;
627 		if (rc == 1) {
628 			free(symbol);
629 		}
630 		return value;
631 	} else {
632 		print_error("ERROR: " SOURCE_LOCATION_FORMAT " - Could not get value "
633 		            "to mock function %s\n", file, line, function);
634 		if (source_location_is_set(&global_last_mock_value_location)) {
635 			print_error("Previously returned mock value was declared at "
636 			            SOURCE_LOCATION_FORMAT "\n",
637 			            global_last_mock_value_location.file,
638 			            global_last_mock_value_location.line);
639 		} else {
640 			print_error("There were no previously returned mock values for "
641 			            "this test.\n");
642 		}
643 		exit_test(1);
644 	}
645 	return NULL;
646 }
647 
648 
649 // Add a return value for the specified mock function name.
_will_return(const char * const function_name,const char * const file,const int line,const void * const value,const int count)650 void _will_return(const char * const function_name, const char * const file,
651                   const int line, const void* const value, const int count) {
652 	SymbolValue * const return_value = malloc(sizeof(*return_value));
653 	assert_true(count > 0);
654 	return_value->value = value;
655 	set_source_location(&return_value->location, file, line);
656 	add_symbol_value(&global_function_result_map_head, &function_name, 1,
657 	                 return_value, count);
658 }
659 
660 
661 /* Add a custom parameter checking function.  If the event parameter is NULL
662  * the event structure is allocated internally by this function.  If event
663  * parameter is provided it must be allocated on the heap and doesn't need to
664  * be deallocated by the caller.
665  */
_expect_check(const char * const function,const char * const parameter,const char * const file,const int line,const CheckParameterValue check_function,void * const check_data,CheckParameterEvent * const event,const int count)666 void _expect_check(
667         const char* const function, const char* const parameter,
668         const char* const file, const int line,
669         const CheckParameterValue check_function, void * const check_data,
670         CheckParameterEvent * const event, const int count) {
671 	CheckParameterEvent * const check =
672 	    event ? event : malloc(sizeof(*check));
673 	const char* symbols[] = {function, parameter};
674 	check->parameter_name = parameter;
675 	check->check_value = check_function;
676 	check->check_value_data = check_data;
677 	set_source_location(&check->location, file, line);
678 	add_symbol_value(&global_function_parameter_map_head, symbols, 2, check,
679 	                 count);
680 }
681 
682 
/* Returns 1 if the specified values are equal.  If the values are not equal
 * an error is displayed and 0 is returned. */
static int values_equal_display_error(const void* const left,
                                      const void* const right) {
	const int equal = left == right;
	if (!equal) {
		// %p is the portable conversion for pointer arguments; "%x" expects
		// an unsigned int, which is undefined for pointers (and truncates
		// them on 64-bit targets).
		print_error("%p != %p\n", left, right);
	}
	return equal;
}
693 
/* Returns 1 if the specified values are not equal.  If the values are equal
 * an error is displayed and 0 is returned. */
static int values_not_equal_display_error(const void* const left,
                                          const void* const right) {
	const int not_equal = left != right;
	if (!not_equal) {
		// %p is the portable conversion for pointer arguments; "%x" expects
		// an unsigned int, which is undefined for pointers.
		print_error("%p == %p\n", left, right);
	}
	return not_equal;
}
704 
705 
706 /* Determine whether value is contained within check_integer_set.
707  * If invert is 0 and the value is in the set 1 is returned, otherwise 0 is
708  * returned and an error is displayed.  If invert is 1 and the value is not
709  * in the set 1 is returned, otherwise 0 is returned and an error is
710  * displayed. */
value_in_set_display_error(const void * value,const CheckIntegerSet * const check_integer_set,const int invert)711 static int value_in_set_display_error(
712         const void *value, const CheckIntegerSet * const check_integer_set,
713         const int invert) {
714 	int succeeded = invert;
715 	assert_true(check_integer_set);
716 	{
717 		const void ** const set = check_integer_set->set;
718 		const size_t size_of_set = check_integer_set->size_of_set;
719 		size_t i;
720 		for (i = 0; i < size_of_set; i++) {
721 			if (set[i] == value) {
722 				if (invert) {
723 					succeeded = 0;
724 				}
725 				break;
726 			}
727 		}
728 		if (succeeded) {
729 			return 1;
730 		}
731 		print_error("%d is %sin the set (", value, invert ? "" : "not ");
732 		for (i = 0; i < size_of_set; i++) {
733 			print_error("%d, ", set[i]);
734 		}
735 		print_error(")\n");
736 	}
737 	return 0;
738 }
739 
740 
/* Returns 1 when range_min <= value <= range_max (bounds inclusive).
 * Otherwise an error naming the value and the range is displayed and 0 is
 * returned. */
static int integer_in_range_display_error(
        const int value, const int range_min, const int range_max) {
	const int below = value < range_min;
	const int above = value > range_max;
	if (below || above) {
		print_error("%d is not within the range %d-%d\n", value, range_min,
		            range_max);
		return 0;
	}
	return 1;
}
753 
754 
/* Returns 1 when value lies outside [range_min, range_max].  Otherwise (value
 * within the inclusive range) an error is displayed and 0 is returned. */
static int integer_not_in_range_display_error(
        const int value, const int range_min, const int range_max) {
	const int inside = value >= range_min && value <= range_max;
	if (!inside) {
		return 1;
	}
	print_error("%d is within the range %d-%d\n", value, range_min,
	            range_max);
	return 0;
}
767 
768 
/* Returns 1 when the NUL-terminated strings left and right have identical
 * contents.  Otherwise both strings are displayed in an error and 0 is
 * returned. */
static int string_equal_display_error(
        const char * const left, const char * const right) {
	const int equal = strcmp(left, right) == 0;
	if (!equal) {
		print_error("\"%s\" != \"%s\"\n", left, right);
	}
	return equal;
}
780 
781 
/* Returns 1 when the NUL-terminated strings left and right differ.  If they
 * are equal, both strings are displayed in an error and 0 is returned. */
static int string_not_equal_display_error(
        const char * const left, const char * const right) {
	const int differ = strcmp(left, right) != 0;
	if (!differ) {
		print_error("\"%s\" == \"%s\"\n", left, right);
	}
	return differ;
}
793 
794 
/* Determine whether the first size bytes of a and b are equal.  If they're
 * equal 1 is returned.  Otherwise every differing offset is displayed,
 * followed by a summary, and 0 is returned. */
static int memory_equal_display_error(const char* a, const char* b,
                                      const size_t size) {
	int differences = 0;
	size_t i;
	for (i = 0; i < size; i++) {
		// Treat bytes as unsigned: plain char may be signed, in which case
		// "%02x" would print bytes >= 0x80 as ffffffxx.
		const unsigned char l = (unsigned char)a[i];
		const unsigned char r = (unsigned char)b[i];
		if (l != r) {
			// size_t has no C89 conversion; widen to unsigned long for %lu.
			print_error("difference at offset %lu 0x%02x 0x%02x\n",
			            (unsigned long)i, (unsigned int)l, (unsigned int)r);
			differences ++;
		}
	}
	if (differences) {
		// Pointers must be printed with %p, not "%x".
		print_error("%d bytes of %p and %p differ\n", differences,
		            (const void*)a, (const void*)b);
		return 0;
	}
	return 1;
}
816 
817 
/* Determine whether the specified areas of memory are not equal.  If they're
 * not equal 1 is returned otherwise an error is displayed and 0 is
 * returned.
 *
 * Note that this is a byte-wise check: any offset at which the two areas
 * hold the same byte is reported and counts as a failure, followed by a
 * summary line giving the count and both addresses.
 *
 * a, b: the areas to compare; both must be at least size bytes long.
 * size: number of bytes to compare. */
static int memory_not_equal_display_error(const char* a, const char* b,
                                          const size_t size) {
	int same = 0;
	size_t i;
	for (i = 0; i < size; i++) {
		// Print as unsigned bytes so values >= 0x80 are not sign-extended.
		const unsigned char l = (unsigned char)a[i];
		const unsigned char r = (unsigned char)b[i];
		if (l == r) {
			// Widen size_t explicitly; %d does not match it.
			print_error("equal at offset %lu 0x%02x 0x%02x\n",
			            (unsigned long)i, (unsigned int)l, (unsigned int)r);
			same ++;
		}
	}
	if (same) {
		// %p is the portable conversion for pointer arguments.
		print_error("%d bytes of %p and %p the same\n", same,
		            (const void*)a, (const void*)b);
		return 0;
	}
	return 1;
}
840 
841 
842 // CheckParameterValue callback to check whether a value is within a set.
check_in_set(const void * value,void * check_value_data)843 static int check_in_set(const void *value, void *check_value_data) {
844 	return value_in_set_display_error(value,
845 	    (CheckIntegerSet*)check_value_data, 0);
846 }
847 
848 
849 // CheckParameterValue callback to check whether a value isn't within a set.
check_not_in_set(const void * value,void * check_value_data)850 static int check_not_in_set(const void *value, void *check_value_data) {
851 	return value_in_set_display_error(value,
852 	    (CheckIntegerSet*)check_value_data, 1);
853 }
854 
855 
856 /* Create the callback data for check_in_set() or check_not_in_set() and
857  * register a check event. */
expect_set(const char * const function,const char * const parameter,const char * const file,const int line,const void * values[],const size_t number_of_values,const CheckParameterValue check_function,const int count)858 static void expect_set(
859         const char* const function, const char* const parameter,
860         const char* const file, const int line, const void *values[],
861         const size_t number_of_values,
862         const CheckParameterValue check_function, const int count) {
863 	CheckIntegerSet * const check_integer_set =
864 	    malloc(sizeof(*check_integer_set) +
865 	           (sizeof(values[0]) * number_of_values));
866 	void ** const set = (void**)(check_integer_set + 1);
867 	assert_true(values);
868 	assert_true(number_of_values);
869 	memcpy(set, values, number_of_values * sizeof(values[0]));
870 	check_integer_set->set = (const void**)set;
871 	_expect_check(function, parameter, file, line, check_function,
872 	              check_integer_set, &check_integer_set->event, count);
873 }
874 
875 
876 // Add an event to check whether a value is in a set.
_expect_in_set(const char * const function,const char * const parameter,const char * const file,const int line,const void * values[],const size_t number_of_values,const int count)877 void _expect_in_set(
878         const char* const function, const char* const parameter,
879         const char* const file, const int line, const void *values[],
880         const size_t number_of_values, const int count) {
881 	expect_set(function, parameter, file, line, values, number_of_values,
882 	           check_in_set, count);
883 }
884 
885 
886 // Add an event to check whether a value isn't in a set.
_expect_not_in_set(const char * const function,const char * const parameter,const char * const file,const int line,const void * values[],const size_t number_of_values,const int count)887 void _expect_not_in_set(
888         const char* const function, const char* const parameter,
889         const char* const file, const int line, const void *values[],
890         const size_t number_of_values, const int count) {
891 	expect_set(function, parameter, file, line, values, number_of_values,
892 	           check_not_in_set, count);
893 }
894 
895 
896 // CheckParameterValue callback to check whether a value is within a range.
check_in_range(const void * value,void * check_value_data)897 static int check_in_range(const void *value, void *check_value_data) {
898 	CheckIntegerRange * const check_integer_range = check_value_data;
899 	assert_true(check_value_data);
900 	return integer_in_range_display_error(
901 	    (int)value, check_integer_range->minimum,
902 	    check_integer_range->maximum);
903 }
904 
905 
906 // CheckParameterValue callback to check whether a value is not within a range.
check_not_in_range(const void * value,void * check_value_data)907 static int check_not_in_range(const void *value, void *check_value_data) {
908 	CheckIntegerRange * const check_integer_range = check_value_data;
909 	assert_true(check_value_data);
910 	return integer_not_in_range_display_error(
911 	    (int)value, check_integer_range->minimum,
912 	    check_integer_range->maximum);
913 }
914 
915 
916 /* Create the callback data for check_in_range() or check_not_in_range() and
917  * register a check event. */
expect_range(const char * const function,const char * const parameter,const char * const file,const int line,const int minimum,const int maximum,const CheckParameterValue check_function,const int count)918 static void expect_range(
919         const char* const function, const char* const parameter,
920         const char* const file, const int line,
921         const int minimum, const int maximum,
922         const CheckParameterValue check_function, const int count) {
923 	CheckIntegerRange * const check_integer_range =
924 	    malloc(sizeof(*check_integer_range));
925 	check_integer_range->minimum = minimum;
926 	check_integer_range->maximum = maximum;
927 	_expect_check(function, parameter, file, line, check_function,
928 	              check_integer_range, &check_integer_range->event, count);
929 }
930 
931 
932 // Add an event to determine whether a parameter is within a range.
_expect_in_range(const char * const function,const char * const parameter,const char * const file,const int line,const int minimum,const int maximum,const int count)933 void _expect_in_range(
934         const char* const function, const char* const parameter,
935         const char* const file, const int line,
936         const int minimum, const int maximum, const int count) {
937 	expect_range(function, parameter, file, line, minimum, maximum,
938 	             check_in_range, count);
939 }
940 
941 
942 // Add an event to determine whether a parameter is not within a range.
_expect_not_in_range(const char * const function,const char * const parameter,const char * const file,const int line,const int minimum,const int maximum,const int count)943 void _expect_not_in_range(
944         const char* const function, const char* const parameter,
945         const char* const file, const int line,
946         const int minimum, const int maximum, const int count) {
947 	expect_range(function, parameter, file, line, minimum, maximum,
948 	             check_not_in_range, count);
949 }
950 
951 
/* CheckParameterValue callback: the expected value itself is stored in
 * check_value_data, and equality is judged by values_equal_display_error(). */
static int check_value(const void *value, void *check_value_data) {
	void * const expected = check_value_data;
	return values_equal_display_error(value, expected);
}
957 
958 
959 // Add an event to check a parameter equals an expected value.
_expect_value(const char * const function,const char * const parameter,const char * const file,const int line,const void * const value,const int count)960 void _expect_value(
961         const char* const function, const char* const parameter,
962         const char* const file, const int line, const void* const value,
963         const int count) {
964 	_expect_check(function, parameter, file, line, check_value,
965 	              (void*)value, NULL, count);
966 }
967 
968 
/* CheckParameterValue callback: succeeds when value differs from the value
 * stored in check_value_data, as judged by
 * values_not_equal_display_error(). */
static int check_not_value(const void *value, void *check_value_data) {
	void * const rejected = check_value_data;
	return values_not_equal_display_error(value, rejected);
}
974 
975 
976 // Add an event to check a parameter is not equal to an expected value.
_expect_not_value(const char * const function,const char * const parameter,const char * const file,const int line,const void * const value,const int count)977 void _expect_not_value(
978         const char* const function, const char* const parameter,
979         const char* const file, const int line, const void* const value,
980         const int count) {
981 	_expect_check(function, parameter, file, line, check_not_value,
982 	              (void*)value, NULL, count);
983 }
984 
985 
/* CheckParameterValue callback: treats both value and check_value_data as
 * NUL-terminated strings and succeeds when they match. */
static int check_string(const void * value, void *check_value_data) {
	const char * const actual = value;
	const char * const expected = check_value_data;
	return string_equal_display_error(actual, expected);
}
990 
991 
992 // Add an event to check whether a parameter is equal to a string.
_expect_string(const char * const function,const char * const parameter,const char * const file,const int line,const char * string,const int count)993 void _expect_string(
994         const char* const function, const char* const parameter,
995         const char* const file, const int line, const char* string,
996         const int count) {
997 	_expect_check(function, parameter, file, line, check_string, (void*)string,
998 	              NULL, count);
999 }
1000 
1001 
/* CheckParameterValue callback to check whether a parameter is not equal to
 * a string: succeeds when the NUL-terminated strings value and
 * check_value_data differ. */
static int check_not_string(const void * value, void *check_value_data) {
	const char * const actual = value;
	const char * const rejected = check_value_data;
	return string_not_equal_display_error(actual, rejected);
}
1007 
1008 
1009 // Add an event to check whether a parameter is not equal to a string.
_expect_not_string(const char * const function,const char * const parameter,const char * const file,const int line,const char * string,const int count)1010 void _expect_not_string(
1011         const char* const function, const char* const parameter,
1012         const char* const file, const int line, const char* string,
1013         const int count) {
1014 	_expect_check(function, parameter, file, line, check_not_string,
1015 	              (void*)string, NULL, count);
1016 }
1017 
1018 /* CheckParameterValue callback to check whether a parameter equals an area of
1019  * memory. */
check_memory(const void * value,void * check_value_data)1020 static int check_memory(const void* value, void *check_value_data) {
1021 	CheckMemoryData * const check = (CheckMemoryData*)check_value_data;
1022 	assert_true(check);
1023 	return memory_equal_display_error(value, check->memory, check->size);
1024 }
1025 
1026 
1027 /* Create the callback data for check_memory() or check_not_memory() and
1028  * register a check event. */
expect_memory_setup(const char * const function,const char * const parameter,const char * const file,const int line,const void * const memory,const size_t size,const CheckParameterValue check_function,const int count)1029 static void expect_memory_setup(
1030         const char* const function, const char* const parameter,
1031         const char* const file, const int line,
1032         const void * const memory, const size_t size,
1033         const CheckParameterValue check_function, const int count) {
1034 	CheckMemoryData * const check_data = malloc(sizeof(*check_data) + size);
1035 	void * const mem = (void*)(check_data + 1);
1036 	assert_true(memory);
1037 	assert_true(size);
1038 	memcpy(mem, memory, size);
1039 	check_data->memory = mem;
1040 	check_data->size = size;
1041 	_expect_check(function, parameter, file, line, check_function,
1042 	              check_data, &check_data->event, count);
1043 }
1044 
1045 
1046 // Add an event to check whether a parameter matches an area of memory.
_expect_memory(const char * const function,const char * const parameter,const char * const file,const int line,const void * const memory,const size_t size,const int count)1047 void _expect_memory(
1048         const char* const function, const char* const parameter,
1049         const char* const file, const int line, const void* const memory,
1050         const size_t size, const int count) {
1051 	expect_memory_setup(function, parameter, file, line, memory, size,
1052 	                    check_memory, count);
1053 }
1054 
1055 
1056 /* CheckParameterValue callback to check whether a parameter is not equal to
1057  * an area of memory. */
check_not_memory(const void * value,void * check_value_data)1058 static int check_not_memory(const void* value, void *check_value_data) {
1059 	CheckMemoryData * const check = (CheckMemoryData*)check_value_data;
1060 	assert_true(check);
1061 	return memory_not_equal_display_error(value, check->memory, check->size);
1062 }
1063 
1064 
1065 // Add an event to check whether a parameter doesn't match an area of memory.
_expect_not_memory(const char * const function,const char * const parameter,const char * const file,const int line,const void * const memory,const size_t size,const int count)1066 void _expect_not_memory(
1067         const char* const function, const char* const parameter,
1068         const char* const file, const int line, const void* const memory,
1069         const size_t size, const int count) {
1070 	expect_memory_setup(function, parameter, file, line, memory, size,
1071 	                    check_not_memory, count);
1072 }
1073 
1074 
/* CheckParameterValue callback that accepts every value: both arguments are
 * ignored and 1 is always returned. */
static int check_any(const void *value, void *check_value_data) {
	(void)value;
	(void)check_value_data;
	return 1;
}
1079 
1080 
1081 // Add an event to allow any value for a parameter.
_expect_any(const char * const function,const char * const parameter,const char * const file,const int line,const int count)1082 void _expect_any(
1083         const char* const function, const char* const parameter,
1084         const char* const file, const int line, const int count) {
1085 	_expect_check(function, parameter, file, line, check_any, NULL, NULL,
1086 	              count);
1087 }
1088 
1089 
_check_expected(const char * const function_name,const char * const parameter_name,const char * file,const int line,const void * value)1090 void _check_expected(
1091         const char * const function_name, const char * const parameter_name,
1092         const char* file, const int line, const void* value) {
1093 	void *result;
1094 	const char* symbols[] = {function_name, parameter_name};
1095 	const int rc = get_symbol_value(&global_function_parameter_map_head,
1096 	                                symbols, 2, &result);
1097 	if (rc) {
1098 		CheckParameterEvent * const check = (CheckParameterEvent*)result;
1099 		int check_succeeded;
1100 		global_last_parameter_location = check->location;
1101 		check_succeeded = check->check_value(value, check->check_value_data);
1102 		if (rc == 1) {
1103 			free(check);
1104 		}
1105 		if (!check_succeeded) {
1106 			print_error("ERROR: Check of parameter %s, function %s failed\n"
1107 			            "Expected parameter declared at "
1108 			            SOURCE_LOCATION_FORMAT "\n",
1109 			            parameter_name, function_name,
1110 			            global_last_parameter_location.file,
1111 			            global_last_parameter_location.line);
1112 			_fail(file, line);
1113 		}
1114 	} else {
1115 		print_error("ERROR: " SOURCE_LOCATION_FORMAT " - Could not get value "
1116 		            "to check parameter %s of function %s\n", file, line,
1117 		            parameter_name, function_name);
1118 		if (source_location_is_set(&global_last_parameter_location)) {
1119 			print_error("Previously declared parameter value was declared at "
1120 			            SOURCE_LOCATION_FORMAT "\n",
1121 			            global_last_parameter_location.file,
1122 			            global_last_parameter_location.line);
1123 		} else {
1124 			print_error("There were no previously declared parameter values "
1125 			            "for this test.\n");
1126 		}
1127 		exit_test(1);
1128 	}
1129 }
1130 
1131 
1132 // Replacement for assert.
mock_assert(const int result,const char * const expression,const char * const file,const int line)1133 void mock_assert(const int result, const char* const expression,
1134                  const char* const file, const int line) {
1135 	if (!result) {
1136 		if (global_expecting_assert) {
1137 			longjmp(global_expect_assert_env, (int)expression);
1138 		} else {
1139 			print_error("ASSERT: %s\n", expression);
1140 			_fail(file, line);
1141 		}
1142 	}
1143 }
1144 
1145 
_assert_true(const int result,const char * const expression,const char * const file,const int line)1146 void _assert_true(const int result, const char * const expression,
1147                   const char * const file, const int line) {
1148 	if (!result) {
1149 		print_error("%s\n", expression);
1150 		_fail(file, line);
1151 	}
1152 }
1153 
_assert_int_equal(const int a,const int b,const char * const file,const int line)1154 void _assert_int_equal(const int a, const int b, const char * const file,
1155                        const int line) {
1156 	if (!values_equal_display_error((void*)a, (void*)b)) {
1157 		_fail(file, line);
1158 	}
1159 }
1160 
1161 
_assert_int_not_equal(const int a,const int b,const char * const file,const int line)1162 void _assert_int_not_equal(const int a, const int b, const char * const file,
1163                            const int line) {
1164 	if (!values_not_equal_display_error((void*)a, (void*)b)) {
1165 		_fail(file, line);
1166 	}
1167 }
1168 
1169 
_assert_string_equal(const char * const a,const char * const b,const char * const file,const int line)1170 void _assert_string_equal(const char * const a, const char * const b,
1171                           const char * const file, const int line) {
1172 	if (!string_equal_display_error(a, b)) {
1173 		_fail(file, line);
1174 	}
1175 }
1176 
1177 
_assert_string_not_equal(const char * const a,const char * const b,const char * file,const int line)1178 void _assert_string_not_equal(const char * const a, const char * const b,
1179                               const char *file, const int line) {
1180 	if (!string_not_equal_display_error(a, b)) {
1181 		_fail(file, line);
1182 	}
1183 }
1184 
1185 
_assert_memory_equal(const void * const a,const void * const b,const size_t size,const char * const file,const int line)1186 void _assert_memory_equal(const void * const a, const void * const b,
1187                           const size_t size, const char* const file,
1188                           const int line) {
1189 	if (!memory_equal_display_error((const char*)a, (const char*)b, size)) {
1190 		_fail(file, line);
1191 	}
1192 }
1193 
1194 
_assert_memory_not_equal(const void * const a,const void * const b,const size_t size,const char * const file,const int line)1195 void _assert_memory_not_equal(const void * const a, const void * const b,
1196                               const size_t size, const char* const file,
1197                               const int line) {
1198 	if (!memory_not_equal_display_error((const char*)a, (const char*)b,
1199 	                                    size)) {
1200 		_fail(file, line);
1201 	}
1202 }
1203 
1204 
_assert_in_range(const int value,const int minimum,const int maximum,const char * const file,const int line)1205 void _assert_in_range(const int value, const int minimum, const int maximum,
1206                       const char* const file, const int line) {
1207 	if (!integer_in_range_display_error(value, minimum, maximum)) {
1208 		_fail(file, line);
1209 	}
1210 }
1211 
_assert_not_in_range(const int value,const int minimum,const int maximum,const char * const file,const int line)1212 void _assert_not_in_range(const int value, const int minimum,
1213                           const int maximum, const char* const file,
1214                           const int line) {
1215 	if (!integer_not_in_range_display_error(value, minimum, maximum)) {
1216 		_fail(file, line);
1217 	}
1218 }
1219 
_assert_in_set(const void * const value,const void * values[],const size_t number_of_values,const char * const file,const int line)1220 void _assert_in_set(const void* const value, const void *values[],
1221                     const size_t number_of_values, const char* const file,
1222                     const int line) {
1223 	CheckIntegerSet check_integer_set;
1224 	check_integer_set.set = values;
1225 	check_integer_set.size_of_set = number_of_values;
1226 	if (!value_in_set_display_error(value, &check_integer_set, 0)) {
1227 		_fail(file, line);
1228 	}
1229 }
1230 
_assert_not_in_set(const void * const value,const void * values[],const size_t number_of_values,const char * const file,const int line)1231 void _assert_not_in_set(const void* const value, const void *values[],
1232                         const size_t number_of_values, const char* const file,
1233                         const int line) {
1234 	CheckIntegerSet check_integer_set;
1235 	check_integer_set.set = values;
1236 	check_integer_set.size_of_set = number_of_values;
1237 	if (!value_in_set_display_error(value, &check_integer_set, 1)) {
1238 		_fail(file, line);
1239 	}
1240 }
1241 
1242 
1243 // Get the list of allocated blocks.
get_allocated_blocks_list()1244 static ListNode* get_allocated_blocks_list() {
1245 	// If it initialized, initialize the list of allocated blocks.
1246 	if (!global_allocated_blocks.value) {
1247 		list_initialize(&global_allocated_blocks);
1248 		global_allocated_blocks.value = (void*)1;
1249 	}
1250 	return &global_allocated_blocks;
1251 }
1252 
1253 // Use the real malloc in this function.
1254 #undef malloc
_test_malloc(const size_t size,const char * file,const int line)1255 void* _test_malloc(const size_t size, const char* file, const int line) {
1256 	char* ptr;
1257 	MallocBlockInfo *block_info;
1258 	ListNode * const block_list = get_allocated_blocks_list();
1259 	const size_t allocate_size = size + (MALLOC_GUARD_SIZE * 2) +
1260 	    sizeof(*block_info) + MALLOC_ALIGNMENT;
1261 	char* const block = (char*)malloc(allocate_size);
1262 	assert_true(block);
1263 
1264 	// Calculate the returned address.
1265 	ptr = (char*)(((size_t)block + MALLOC_GUARD_SIZE + sizeof(*block_info) +
1266 	              MALLOC_ALIGNMENT) & ~(MALLOC_ALIGNMENT - 1));
1267 
1268 	// Initialize the guard blocks.
1269 	memset(ptr - MALLOC_GUARD_SIZE, MALLOC_GUARD_PATTERN, MALLOC_GUARD_SIZE);
1270 	memset(ptr + size, MALLOC_GUARD_PATTERN, MALLOC_GUARD_SIZE);
1271 	memset(ptr, MALLOC_ALLOC_PATTERN, size);
1272 
1273 	block_info = (MallocBlockInfo*)(ptr - (MALLOC_GUARD_SIZE +
1274 	                                         sizeof(*block_info)));
1275 	set_source_location(&block_info->location, file, line);
1276 	block_info->allocated_size = allocate_size;
1277 	block_info->size = size;
1278 	block_info->block = block;
1279 	block_info->node.value = block_info;
1280 	list_add(block_list, &block_info->node);
1281 	return ptr;
1282 }
1283 #define malloc test_malloc
1284 
1285 
_test_calloc(const size_t number_of_elements,const size_t size,const char * file,const int line)1286 void* _test_calloc(const size_t number_of_elements, const size_t size,
1287                    const char* file, const int line) {
1288 	void* const ptr = _test_malloc(number_of_elements * size, file, line);
1289 	if (ptr) {
1290 		memset(ptr, 0, number_of_elements * size);
1291 	}
1292 	return ptr;
1293 }
1294 
1295 
1296 // Use the real free in this function.
1297 #undef free
_test_free(void * const ptr,const char * file,const int line)1298 void _test_free(void* const ptr, const char* file, const int line) {
1299 	unsigned int i;
1300 	char *block = (char*)ptr;
1301 	MallocBlockInfo *block_info;
1302 	_assert_true((int)ptr, "ptr", file, line);
1303 	block_info = (MallocBlockInfo*)(block - (MALLOC_GUARD_SIZE +
1304 	                                           sizeof(*block_info)));
1305 	// Check the guard blocks.
1306 	{
1307 		char *guards[2] = {block - MALLOC_GUARD_SIZE,
1308 		                   block + block_info->size};
1309 		for (i = 0; i < ARRAY_LENGTH(guards); i++) {
1310 			unsigned int j;
1311 			char * const guard = guards[i];
1312 			for (j = 0; j < MALLOC_GUARD_SIZE; j++) {
1313 				const char diff = guard[j] - MALLOC_GUARD_PATTERN;
1314 				if (diff) {
1315 					print_error(
1316 					    "Guard block of 0x%08x size=%d allocated by "
1317 					    SOURCE_LOCATION_FORMAT " at 0x%08x is corrupt\n",
1318 					    (size_t)ptr, block_info->size,
1319 					    block_info->location.file, block_info->location.line,
1320 					    (size_t)&guard[j]);
1321 					_fail(file, line);
1322 				}
1323 			}
1324 		}
1325 	}
1326 	list_remove(&block_info->node, NULL, NULL);
1327 
1328 	block = block_info->block;
1329 	memset(block, MALLOC_FREE_PATTERN, block_info->allocated_size);
1330 	free(block);
1331 }
1332 #define free test_free
1333 
1334 
1335 // Crudely checkpoint the current heap state.
check_point_allocated_blocks()1336 static const ListNode* check_point_allocated_blocks() {
1337 	return get_allocated_blocks_list()->prev;
1338 }
1339 
1340 
1341 /* Display the blocks allocated after the specified check point.  This
1342  * function returns the number of blocks displayed. */
display_allocated_blocks(const ListNode * const check_point)1343 static int display_allocated_blocks(const ListNode * const check_point) {
1344 	const ListNode * const head = get_allocated_blocks_list();
1345 	const ListNode *node;
1346 	int allocated_blocks = 0;
1347 	assert_true(check_point);
1348 	assert_true(check_point->next);
1349 
1350 	for (node = check_point->next; node != head; node = node->next) {
1351 		const MallocBlockInfo * const block_info = node->value;
1352 		assert_true(block_info);
1353 
1354 		if (!allocated_blocks) {
1355 			print_error("Blocks allocated...\n");
1356 		}
1357 		print_error("  0x%08x : " SOURCE_LOCATION_FORMAT "\n",
1358 		            block_info->block, block_info->location.file,
1359 		            block_info->location.line);
1360 		allocated_blocks ++;
1361 	}
1362 	return allocated_blocks;
1363 }
1364 
1365 
1366 // Free all blocks allocated after the specified check point.
free_allocated_blocks(const ListNode * const check_point)1367 static void free_allocated_blocks(const ListNode * const check_point) {
1368 	const ListNode * const head = get_allocated_blocks_list();
1369 	const ListNode *node;
1370 	assert_true(check_point);
1371 
1372 	node = check_point->next;
1373 	assert_true(node);
1374 
1375 	while (node != head) {
1376 		MallocBlockInfo * const block_info = (MallocBlockInfo*)node->value;
1377 		node = node->next;
1378 		free((char*)block_info + sizeof(*block_info) + MALLOC_GUARD_SIZE);
1379 	}
1380 }
1381 
1382 
1383 // Fail if any any blocks are allocated after the specified check point.
fail_if_blocks_allocated(const ListNode * const check_point,const char * const test_name)1384 static void fail_if_blocks_allocated(const ListNode * const check_point,
1385                                      const char * const test_name) {
1386 	const int allocated_blocks = display_allocated_blocks(check_point);
1387 	if (allocated_blocks) {
1388 		free_allocated_blocks(check_point);
1389 		print_error("ERROR: %s leaked %d block(s)\n", test_name,
1390 		            allocated_blocks);
1391 		exit_test(1);
1392 	}
1393 }
1394 
1395 
_fail(const char * const file,const int line)1396 void _fail(const char * const file, const int line) {
1397 	print_error("ERROR: " SOURCE_LOCATION_FORMAT " Failure!\n", file, line);
1398 	exit_test(1);
1399 }
1400 
1401 
1402 #ifndef _WIN32
exception_handler(int sig)1403 static void exception_handler(int sig) {
1404 	print_error("%s\n", strsignal(sig));
1405 	exit_test(1);
1406 }
1407 
1408 #else // _WIN32
1409 
exception_filter(EXCEPTION_POINTERS * exception_pointers)1410 static LONG WINAPI exception_filter(EXCEPTION_POINTERS *exception_pointers) {
1411 	EXCEPTION_RECORD * const exception_record =
1412 	    exception_pointers->ExceptionRecord;
1413 	const DWORD code = exception_record->ExceptionCode;
1414 	unsigned int i;
1415 	for (i = 0; i < ARRAY_LENGTH(exception_codes); i++) {
1416 		const ExceptionCodeInfo * const code_info = &exception_codes[i];
1417 		if (code == code_info->code) {
1418 			static int shown_debug_message = 0;
1419 			fflush(stdout);
1420 			print_error("%s occurred at 0x%08x.\n", code_info->description,
1421 			            exception_record->ExceptionAddress);
1422 			if (!shown_debug_message) {
1423 				print_error(
1424 				    "\n"
1425 				    "To debug in Visual Studio...\n"
1426 				    "1. Select menu item File->Open Project\n"
1427 				    "2. Change 'Files of type' to 'Executable Files'\n"
1428 				    "3. Open this executable.\n"
1429 				    "4. Select menu item Debug->Start\n"
1430 				    "\n"
1431 				    "Alternatively, set the environment variable \n"
1432 				    "UNIT_TESTING_DEBUG to 1 and rebuild this executable, \n"
1433 				    "then click 'Debug' in the popup dialog box.\n"
1434 				    "\n");
1435 				shown_debug_message = 1;
1436 			}
1437 			exit_test(0);
1438 			return EXCEPTION_EXECUTE_HANDLER;
1439 		}
1440 	}
1441 	return EXCEPTION_CONTINUE_SEARCH;
1442 }
1443 #endif // !_WIN32
1444 
1445 
1446 // Standard output and error print methods.
vprint_message(const char * const format,va_list args)1447 void vprint_message(const char* const format, va_list args) {
1448 	char buffer[1024];
1449 	vsnprintf(buffer, sizeof(buffer), format, args);
1450 	printf(buffer);
1451 #ifdef _WIN32
1452 	OutputDebugString(buffer);
1453 #endif // _WIN32
1454 }
1455 
1456 
vprint_error(const char * const format,va_list args)1457 void vprint_error(const char* const format, va_list args) {
1458 	char buffer[1024];
1459 	vsnprintf(buffer, sizeof(buffer), format, args);
1460 	fprintf(stderr, buffer);
1461 #ifdef _WIN32
1462 	OutputDebugString(buffer);
1463 #endif // _WIN32
1464 }
1465 
1466 
print_message(const char * const format,...)1467 void print_message(const char* const format, ...) {
1468 	va_list args;
1469 	va_start(args, format);
1470 	vprint_message(format, args);
1471 	va_end(args);
1472 }
1473 
1474 
print_error(const char * const format,...)1475 void print_error(const char* const format, ...) {
1476 	va_list args;
1477 	va_start(args, format);
1478 	vprint_error(format, args);
1479 	va_end(args);
1480 }
1481 
1482 
_run_test(const char * const function_name,const UnitTestFunction Function,void ** const state,const UnitTestFunctionType function_type,const void * const heap_check_point)1483 int _run_test(
1484         const char * const function_name,  const UnitTestFunction Function,
1485         void ** const state, const UnitTestFunctionType function_type,
1486         const void* const heap_check_point) {
1487 	const ListNode * const check_point = heap_check_point ?
1488 	    heap_check_point : check_point_allocated_blocks();
1489 	void *current_state = NULL;
1490 	int rc = 1;
1491 	int handle_exceptions = 1;
1492 #ifdef _WIN32
1493 	handle_exceptions = !IsDebuggerPresent();
1494 #endif // _WIN32
1495 #if UNIT_TESTING_DEBUG
1496 	handle_exceptions = 0;
1497 #endif // UNIT_TESTING_DEBUG
1498 
1499 	if (handle_exceptions) {
1500 #ifndef _WIN32
1501 		unsigned int i;
1502 		for (i = 0; i < ARRAY_LENGTH(exception_signals); i++) {
1503 			default_signal_functions[i] = signal(
1504 			    exception_signals[i], exception_handler);
1505 		}
1506 #else // _WIN32
1507 		previous_exception_filter = SetUnhandledExceptionFilter(
1508 		    exception_filter);
1509 #endif // !_WIN32
1510 	}
1511 
1512 	if (function_type == UNIT_TEST_FUNCTION_TYPE_TEST) {
1513 		print_message("%s: Starting test\n", function_name);
1514 	}
1515 	initialize_testing(function_name);
1516 	global_running_test = 1;
1517 	if (setjmp(global_run_test_env) == 0) {
1518 		Function(state ? state : &current_state);
1519 		fail_if_leftover_values(function_name);
1520 
1521 		/* If this is a setup function then ignore any allocated blocks
1522 		 * only ensure they're deallocated on tear down. */
1523 		if (function_type != UNIT_TEST_FUNCTION_TYPE_SETUP) {
1524 			fail_if_blocks_allocated(check_point, function_name);
1525 		}
1526 
1527 		global_running_test = 0;
1528 
1529 		if (function_type == UNIT_TEST_FUNCTION_TYPE_TEST) {
1530 			print_message("%s: Test completed successfully.\n", function_name);
1531 		}
1532 		rc = 0;
1533 	} else {
1534 		global_running_test = 0;
1535 		print_message("%s: Test failed.\n", function_name);
1536 	}
1537 	teardown_testing(function_name);
1538 
1539 	if (handle_exceptions) {
1540 #ifndef _WIN32
1541 		unsigned int i;
1542 		for (i = 0; i < ARRAY_LENGTH(exception_signals); i++) {
1543 			signal(exception_signals[i], default_signal_functions[i]);
1544 		}
1545 #else // _WIN32
1546 		if (previous_exception_filter) {
1547 			SetUnhandledExceptionFilter(previous_exception_filter);
1548 			previous_exception_filter = NULL;
1549 		}
1550 #endif // !_WIN32
1551 	}
1552 
1553 	return rc;
1554 }
1555 
1556 
_run_tests(const UnitTest * const tests,const size_t number_of_tests)1557 int _run_tests(const UnitTest * const tests, const size_t number_of_tests) {
1558 	// Whether to execute the next test.
1559 	int run_next_test = 1;
1560 	// Whether the previous test failed.
1561 	int previous_test_failed = 0;
1562 	// Check point of the heap state.
1563 	const ListNode * const check_point = check_point_allocated_blocks();
1564 	// Current test being executed.
1565 	size_t current_test = 0;
1566 	// Number of tests executed.
1567 	size_t tests_executed = 0;
1568 	// Number of failed tests.
1569 	size_t total_failed = 0;
1570 	// Number of setup functions.
1571 	size_t setups = 0;
1572 	// Number of teardown functions.
1573 	size_t teardowns = 0;
1574 	/* A stack of test states.  A state is pushed on the stack
1575 	 * when a test setup occurs and popped on tear down. */
1576 	TestState* test_states = malloc(number_of_tests * sizeof(*test_states));
1577 	size_t number_of_test_states = 0;
1578 	// Names of the tests that failed.
1579 	const char** failed_names = malloc(number_of_tests *
1580 	                                   sizeof(*failed_names));
1581 	void **current_state = NULL;
1582 
1583 	while (current_test < number_of_tests) {
1584 		const ListNode *test_check_point = NULL;
1585 		TestState *current_TestState;
1586 		const UnitTest * const test = &tests[current_test++];
1587 		if (!test->function) {
1588 			continue;
1589 		}
1590 
1591 		switch (test->function_type) {
1592 		case UNIT_TEST_FUNCTION_TYPE_TEST:
1593 			run_next_test = 1;
1594 			break;
1595 		case UNIT_TEST_FUNCTION_TYPE_SETUP: {
1596 			// Checkpoint the heap before the setup.
1597 			current_TestState = &test_states[number_of_test_states++];
1598 			current_TestState->check_point = check_point_allocated_blocks();
1599 			test_check_point = current_TestState->check_point;
1600 			current_state = &current_TestState->state;
1601 			*current_state = NULL;
1602 			run_next_test = 1;
1603 			setups ++;
1604 			break;
1605 		}
1606 		case UNIT_TEST_FUNCTION_TYPE_TEARDOWN:
1607 			// Check the heap based on the last setup checkpoint.
1608 			assert_true(number_of_test_states);
1609 			current_TestState = &test_states[--number_of_test_states];
1610 			test_check_point = current_TestState->check_point;
1611 			current_state = &current_TestState->state;
1612 			teardowns ++;
1613 			break;
1614 		default:
1615 			print_error("Invalid unit test function type %d\n",
1616 			            test->function_type);
1617 			exit_test(1);
1618 			break;
1619 		}
1620 
1621 		if (run_next_test) {
1622 			int failed = _run_test(test->name, test->function, current_state,
1623 			                       test->function_type, test_check_point);
1624 			if (failed) {
1625 				failed_names[total_failed] = test->name;
1626 			}
1627 
1628 			switch (test->function_type) {
1629 			case UNIT_TEST_FUNCTION_TYPE_TEST:
1630 				previous_test_failed = failed;
1631 				total_failed += failed;
1632 				tests_executed ++;
1633 				break;
1634 
1635 			case UNIT_TEST_FUNCTION_TYPE_SETUP:
1636 				if (failed) {
1637 					total_failed ++;
1638 					tests_executed ++;
1639 					// Skip forward until the next test or setup function.
1640 					run_next_test = 0;
1641 				}
1642 				previous_test_failed = 0;
1643 				break;
1644 
1645 			case UNIT_TEST_FUNCTION_TYPE_TEARDOWN:
1646 				// If this test failed.
1647 				if (failed && !previous_test_failed) {
1648 					total_failed ++;
1649 				}
1650 				break;
1651 			default:
1652 				assert_false("BUG: shouldn't be here!");
1653 				break;
1654 			}
1655 		}
1656 	}
1657 
1658 	if (total_failed) {
1659 		size_t i;
1660 		print_error("%d out of %d tests failed!\n", total_failed,
1661 		            tests_executed);
1662 		for (i = 0; i < total_failed; i++) {
1663 			print_error("    %s\n", failed_names[i]);
1664 		}
1665 	} else {
1666 		print_message("All %d tests passed\n", tests_executed);
1667 	}
1668 
1669 	if (number_of_test_states) {
1670 		print_error("Mismatched number of setup %d and teardown %d "
1671 		            "functions\n", setups, teardowns);
1672 		total_failed = -1;
1673 	}
1674 
1675 	free(test_states);
1676 	free((void*)failed_names);
1677 
1678 	fail_if_blocks_allocated(check_point, "run_tests");
1679 	return (int)total_failed;
1680 }
1681