1# coding=utf-8
2#
3# Copyright © 2011, 2018 Intel Corporation
4#
5# Permission is hereby granted, free of charge, to any person obtaining a
6# copy of this software and associated documentation files (the "Software"),
7# to deal in the Software without restriction, including without limitation
8# the rights to use, copy, modify, merge, publish, distribute, sublicense,
9# and/or sell copies of the Software, and to permit persons to whom the
10# Software is furnished to do so, subject to the following conditions:
11#
12# The above copyright notice and this permission notice (including the next
13# paragraph) shall be included in all copies or substantial portions of the
14# Software.
15#
16# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22# DEALINGS IN THE SOFTWARE.
23
24from sexps import *
25
def make_test_case(f_name, ret_type, body):
    """Build a one-function test program from a statement list.

    The result is a list of top-level s-expressions: one global
    declaration for every variable that body references without
    declaring it, followed by the definition of function f_name with
    return type ret_type, no parameters, and the given body.

    Undeclared variables are always declared as floats.  A variable
    read through a var_ref becomes an 'in' global and one that is the
    target of an assign becomes an 'out' global; if the same name is
    seen both ways, the declaration recorded last wins.
    """
    check_sexp(body)
    globals_by_name = {}

    def declare_global(name, mode, visible):
        # Names covered by a visible 'declare' statement need no global.
        if name not in visible:
            globals_by_name[name] = ['declare', [mode], 'float', name]

    def scan(node, visible=()):
        # Atoms (strings) carry no variable references.
        if not isinstance(node, list):
            return
        if len(node) == 2 and node[0] == 'var_ref':
            declare_global(node[1], 'in', visible)
            return
        if len(node) == 4 and node[0] == 'assign':
            target = node[2]
            assert target[0] == 'var_ref'
            declare_global(target[1], 'out', visible)
            scan(node[3], visible)
            return
        # Any other list is walked element by element.  The set of
        # visible names is copied so that declarations found here do
        # not leak out to the caller's scope, and it grows as 'declare'
        # statements are encountered so later siblings can see them.
        scope = set(visible)
        for child in node:
            is_declaration = (isinstance(child, list) and len(child) >= 4
                              and child[0] == 'declare')
            if is_declaration:
                scope.add(child[3])
            else:
                scan(child, scope)

    scan(body)
    function = ['function', f_name,
                ['signature', ret_type, ['parameters'], body]]
    return list(globals_by_name.values()) + [function]
59
60
61# The following functions can be used to build expressions.
62
def const_float(value):
    """Return a float constant expression for value.

    The value is rendered in fixed-point notation with six digits
    after the decimal point.
    """
    text = format(value, '.6f')
    return ['constant', 'float', [text]]
66
def const_bool(value):
    """Return a bool constant expression for value.

    Only the truthiness of value matters: any truthy value yields the
    constant 1 and any falsy value yields the constant 0, so
    const_bool(1) and const_bool(True) produce the same expression.
    """
    digit = '1' if value else '0'
    return ['constant', 'bool', [digit]]
74
def gt_zero(var_name):
    """Construct the boolean expression var_name > 0.

    The comparison is emitted with the operands swapped, as
    0.0 < var_name.
    """
    zero = const_float(0)
    variable = ['var_ref', var_name]
    return ['expression', 'bool', '<', zero, variable]
78
79
80# The following functions can be used to build complex control flow
81# statements.  All of these functions return statement lists (even
82# those which only create a single statement), so that statements can
83# be sequenced together using the '+' operator.
84
def return_(value = None):
    """Return a one-element statement list holding a return statement.

    When value is None the return carries no operand; otherwise value
    is used as the returned expression.
    """
    if value is None:
        return [['return']]
    return [['return', value]]
91
def break_():
    """Return a statement list containing only a break.

    The break statement itself is the bare atom 'break'; a new list
    is built on every call.
    """
    statement = 'break'
    return [statement]
95
def continue_():
    """Return a statement list containing only a continue.

    The continue statement itself is the bare atom 'continue'; a new
    list is built on every call.
    """
    statement = 'continue'
    return [statement]
99
def simple_if(var_name, then_statements, else_statements = None):
    """Return a statement list holding one if-statement of the form

    if (var_name > 0.0) {
       <then_statements>
    } else {
       <else_statements>
    }

    Leaving else_statements out (or passing None) produces an empty
    else branch.  Both branches are validated with check_sexp.
    """
    else_branch = [] if else_statements is None else else_statements
    check_sexp(then_statements)
    check_sexp(else_branch)
    condition = gt_zero(var_name)
    return [['if', condition, then_statements, else_branch]]
116
def loop(statements):
    """Return a statement list holding a single loop whose body is
    the given statements (validated with check_sexp first).
    """
    check_sexp(statements)
    loop_statement = ['loop', statements]
    return [loop_statement]
123
def declare_temp(var_type, var_name):
    """Return a statement list holding a single declaration of the form

    (declare (temporary) <var_type> <var_name>)
    """
    declaration = ['declare', ['temporary'], var_type, var_name]
    return [declaration]
130
def assign_x(var_name, value):
    """Return a statement list holding one assignment of <value> to
    the variable <var_name>, using the write mask (x).

    value is validated with check_sexp before the statement is built.
    """
    check_sexp(value)
    target = ['var_ref', var_name]
    return [['assign', ['x'], target, value]]
137
def complex_if(var_prefix, statements):
    """Return a statement list holding two nested ifs of the form

    if (<var_prefix>a > 0.0) {
       if (<var_prefix>b > 0.0) {
          <statements>
       }
    }

    The nesting matters when testing jump lowering: if <statements>
    ends in a jump, lower_jumps.cpp won't try to combine this
    construct with the code that follows it, as it might do for a
    simple if.

    The two condition variables are named var_prefix + 'a' and
    var_prefix + 'b', so distinct prefixes give distinct variables.
    """
    check_sexp(statements)
    inner_if = simple_if(var_prefix + 'b', statements)
    outer_if = simple_if(var_prefix + 'a', inner_if)
    return outer_if
157
def declare_execute_flag():
    """Return the statements lower_jumps.cpp emits for execute_flag:
    a temporary bool declaration followed by an assignment that
    initializes it to true.
    """
    declaration = declare_temp('bool', 'execute_flag')
    initialization = assign_x('execute_flag', const_bool(True))
    return declaration + initialization
164
def declare_return_flag():
    """Return the statements lower_jumps.cpp emits for return_flag:
    a temporary bool declaration followed by an assignment that
    initializes it to false.
    """
    declaration = declare_temp('bool', 'return_flag')
    initialization = assign_x('return_flag', const_bool(False))
    return declaration + initialization
171
def declare_return_value():
    """Return the statement lower_jumps.cpp emits to declare the
    temporary variable return_value.

    The variable is assumed to be a float, and no initializing
    assignment is generated for it.
    """
    declaration = declare_temp('float', 'return_value')
    return declaration
178
def declare_break_flag():
    """Return the statements lower_jumps.cpp emits for break_flag:
    a temporary bool declaration followed by an assignment that
    initializes it to false.
    """
    declaration = declare_temp('bool', 'break_flag')
    initialization = assign_x('break_flag', const_bool(False))
    return declaration + initialization
185
def lowered_return_simple(value = None):
    """Return the statements a return is lowered to by
    lower_jumps.cpp when the execute flag does not need clearing.

    A truthy value is first stored in return_value; a falsy value
    (including the default None) produces no such store.  In both
    cases return_flag is then set to true.
    """
    store_value = assign_x('return_value', value) if value else []
    set_flag = assign_x('return_flag', const_bool(True))
    return store_value + set_flag
196
def lowered_return(value = None):
    """Return the statements a return is lowered to by
    lower_jumps.cpp when the execute flag must also be cleared.

    This is the sequence from lowered_return_simple(value) followed
    by an assignment of false to execute_flag.
    """
    statements = lowered_return_simple(value)
    clear_execute = assign_x('execute_flag', const_bool(False))
    return statements + clear_execute
204
def lowered_continue():
    """Return the statement a continue is lowered to by
    lower_jumps.cpp: an assignment of false to execute_flag.
    """
    cleared = const_bool(False)
    return assign_x('execute_flag', cleared)
210
def lowered_break_simple():
    """Return the statement a break is lowered to by lower_jumps.cpp
    when the execute flag does not need clearing: an assignment of
    true to break_flag.
    """
    raised = const_bool(True)
    return assign_x('break_flag', raised)
217
def lowered_break():
    """Return the statements a break is lowered to by lower_jumps.cpp
    when the execute flag must also be cleared: break_flag is set to
    true, then execute_flag is set to false.
    """
    set_break = lowered_break_simple()
    clear_execute = assign_x('execute_flag', const_bool(False))
    return set_break + clear_execute
224
def if_execute_flag(statements):
    """Guard statements with execute_flag: return an if-statement
    whose then-branch is statements and whose else-branch is empty.
    """
    check_sexp(statements)
    condition = ['var_ref', 'execute_flag']
    return [['if', condition, statements, []]]
231
def if_return_flag(then_statements, else_statements):
    """Return an if-statement conditioned on return_flag, with the
    given then- and else-branches (both validated with check_sexp).
    """
    for branch in (then_statements, else_statements):
        check_sexp(branch)
    condition = ['var_ref', 'return_flag']
    return [['if', condition, then_statements, else_statements]]
238
def if_not_return_flag(statements):
    """Guard statements so they run only while return_flag is false.

    The negation is expressed by testing return_flag itself and
    placing statements in the else-branch; the then-branch is empty.
    """
    check_sexp(statements)
    condition = ['var_ref', 'return_flag']
    return [['if', condition, [], statements]]
245
def final_return():
    """Return the statement lower_jumps.cpp appends to a function
    when lowering returns: a return of the variable return_value.
    """
    returned = ['var_ref', 'return_value']
    return [['return', returned]]
251
def final_break():
    """Return the conditional break that lower_jumps.cpp places at
    the end of a loop body when lowering breaks: an if-statement on
    break_flag whose then-branch breaks and whose else-branch is
    empty.
    """
    condition = ['var_ref', 'break_flag']
    return [['if', condition, break_(), []]]
257
def bash_quote(*args):
    """Join the arguments into one string, quoting each so that bash
    reads it back as a single word.

    A word made only of letters, digits and the characters
    @%_-+=:,./ is passed through unchanged.  The empty string becomes
    ''.  Anything else is wrapped in single quotes, with every
    embedded single quote rewritten as '"'"'.
    """
    safe_punctuation = '@%_-+=:,./'

    def is_plain(ch):
        return ch.isalpha() or ch.isdigit() or ch in safe_punctuation

    def quote_word(word):
        # The characters are inspected before the emptiness check so
        # that a non-iterable argument fails here, while iterating.
        if all(is_plain(ch) for ch in word):
            return word if word else "''"
        return "'{0}'".format(word.replace("'", "'\"'\"'"))

    return ' '.join(quote_word(word) for word in args)
272
def create_test_case(input_sexp, expected_sexp, test_name,
                     pull_out_jumps=False, lower_sub_return=False,
                     lower_main_return=False, lower_continue=False,
                     lower_break=False):
    """Describe one do_lower_jumps test as a tuple.

    Both s-expressions are validated with check_sexp, passed through
    sort_decls and converted to strings.  The result is the tuple
    (test_name, optimization, input_str, expected_output), where
    optimization is the text of a do_lower_jumps(...) call whose five
    arguments are the flags, formatted as integers, in the order
    pull_out_jumps, lower_sub_return, lower_main_return,
    lower_continue, lower_break.
    """
    for sexp in (input_sexp, expected_sexp):
        check_sexp(sexp)
    input_str = sexp_to_string(sort_decls(input_sexp))
    # XXX: don't stringify this
    expected_output = sexp_to_string(sort_decls(expected_sexp))
    flags = (pull_out_jumps, lower_sub_return, lower_main_return,
             lower_continue, lower_break)
    flag_text = ', '.join('{0:d}'.format(flag) for flag in flags)
    optimization = 'do_lower_jumps({0})'.format(flag_text)

    return (test_name, optimization, input_str, expected_output)
290
def test_lower_returns_main():
    """Check that the lower_main_return flag controls whether returns
    in the main function are lowered.

    Yields one case with the flag set, expecting the lowered form,
    and one with it clear, expecting the input to come back unchanged.
    """
    original = make_test_case('main', 'void', complex_if('', return_()))
    lowered_body = (declare_execute_flag() +
                    declare_return_flag() +
                    complex_if('', lowered_return()))
    lowered = make_test_case('main', 'void', lowered_body)
    yield create_test_case(original, lowered, 'lower_returns_main_true',
                           lower_main_return=True)
    yield create_test_case(original, original, 'lower_returns_main_false',
                           lower_main_return=False)
309
def test_lower_returns_sub():
    """Check that the lower_sub_return flag controls whether returns
    in subroutines are lowered.

    Yields one case with the flag set, expecting the lowered form,
    and one with it clear, expecting the input to come back unchanged.
    """
    original = make_test_case('sub', 'void', complex_if('', return_()))
    lowered_body = (declare_execute_flag() +
                    declare_return_flag() +
                    complex_if('', lowered_return()))
    lowered = make_test_case('sub', 'void', lowered_body)
    yield create_test_case(original, lowered, 'lower_returns_sub_true',
                           lower_sub_return=True)
    yield create_test_case(original, original, 'lower_returns_sub_false',
                           lower_sub_return=False)
328
def test_lower_returns_1():
    """Check that a void return ending a function is simply removed."""
    assignment = assign_x('a', const_float(1))
    with_return = make_test_case('main', 'void', assignment + return_())
    without_return = make_test_case('main', 'void',
                                    assign_x('a', const_float(1)))
    yield create_test_case(with_return, without_return, 'lower_returns_1',
                           lower_main_return=True)
340
def test_lower_returns_2():
    """Check that a non-void return ending a subroutine is left
    alone: the expected output is the input itself.
    """
    body = assign_x('a', const_float(1)) + return_(const_float(1))
    unchanged = make_test_case('sub', 'float', body)
    yield create_test_case(unchanged, unchanged, 'lower_returns_2',
                           lower_sub_return=True)
351
def test_lower_returns_3():
    """Check lowering when one return is buried in nested ifs and
    another ends the function.

    The trailing return has to be lowered as well, because once the
    final return is appended it no longer sits at the end of the
    function.
    """
    source_body = (complex_if('', return_(const_float(1))) +
                   return_(const_float(2)))
    source = make_test_case('sub', 'float', source_body)

    prologue = (declare_execute_flag() +
                declare_return_value() +
                declare_return_flag())
    nested_return = complex_if('', lowered_return(const_float(1)))
    guarded_return = if_execute_flag(lowered_return(const_float(2)))
    lowered = make_test_case(
        'sub', 'float',
        prologue + nested_return + guarded_return + final_return())

    yield create_test_case(source, lowered, 'lower_returns_3',
                           lower_sub_return=True)
373
def test_lower_returns_4():
    """Check lowering of returns that appear in both the then- and
    the else-branch of a single if-statement.
    """
    source_body = simple_if('a',
                            return_(const_float(1)),
                            return_(const_float(2)))
    source = make_test_case('sub', 'float', source_body)

    prologue = (declare_execute_flag() +
                declare_return_value() +
                declare_return_flag())
    lowered_if = simple_if('a',
                           lowered_return(const_float(1)),
                           lowered_return(const_float(2)))
    lowered = make_test_case('sub', 'float',
                             prologue + lowered_if + final_return())

    yield create_test_case(source, lowered, 'lower_returns_4',
                           lower_sub_return=True)
392
def test_lower_unified_returns():
    """Check that returns ending both branches of an if are hoisted
    out of the if and then lowered when pull_out_jumps is True.

    The expected output contains no extra temporary variables, which
    verifies that the hoisted return is lowered in the same pass as
    the other returns.
    """
    source_body = (
        complex_if('a', return_()) +
        simple_if('b', simple_if('c', return_(), return_())))
    source = make_test_case('main', 'void', source_body)

    prologue = declare_execute_flag() + declare_return_flag()
    first_part = complex_if('a', lowered_return())
    # Both branches of the inner if are emptied; the single lowered
    # return follows it.
    hoisted = simple_if('c', [], []) + lowered_return()
    second_part = if_execute_flag(simple_if('b', hoisted))
    lowered = make_test_case('main', 'void',
                             prologue + first_part + second_part)

    yield create_test_case(source, lowered, 'lower_unified_returns',
                           lower_main_return=True, pull_out_jumps=True)
415
def test_lower_pulled_out_jump():
    """If one branch of an if ends in a jump, and control cannot
    fall out the bottom of the other branch, and pull_out_jumps is
    True, then the jump is lifted outside the if.

    Verify that this lowering occurs during the same pass as the
    lowering of other jumps by checking that extra temporary
    variables aren't generated.
    """
    input_sexp = make_test_case('main', 'void', (
            complex_if('a', return_()) +
            loop(simple_if('b', simple_if('c', break_(), continue_()),
                           return_())) +
            assign_x('d', const_float(1))
            ))
    # Note: optimization produces two other effects: the break
    # gets lifted out of the if statements, and the code after the
    # loop gets guarded so that it only executes if the return
    # flag is clear.
    expected_sexp = make_test_case('main', 'void', (
            declare_execute_flag() +
            declare_return_flag() +
            complex_if('a', lowered_return()) +
            if_execute_flag(
                loop(simple_if('b', simple_if('c', [], continue_()),
                               lowered_return_simple()) +
                     break_()) +

                if_return_flag(assign_x('return_flag', const_bool(1)) +
                               assign_x('execute_flag', const_bool(0)),
                               assign_x('d', const_float(1))))
            ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_pulled_out_jump',
        lower_main_return=True, pull_out_jumps=True)
451
def test_lower_breaks_1():
    """Check that an unconditional break at the bottom of a loop is
    not lowered: the expected output is the input itself.
    """
    loop_body = assign_x('a', const_float(1)) + break_()
    unchanged = make_test_case('main', 'void', loop(loop_body))
    yield create_test_case(unchanged, unchanged, 'lower_breaks_1',
                           lower_break=True)
463
def test_lower_breaks_2():
    """Check that a conditional break at the bottom of a loop is not
    lowered when it sits in the then-clause: the expected output is
    the input itself.
    """
    loop_body = (assign_x('a', const_float(1)) +
                 simple_if('b', break_()))
    unchanged = make_test_case('main', 'void', loop(loop_body))
    yield create_test_case(unchanged, unchanged, 'lower_breaks_2',
                           lower_break=True)
475
def test_lower_breaks_3():
    """Check that a conditional break at the bottom of a loop is not
    lowered when it sits in the then-clause, even with statements
    ahead of the break inside that clause.
    """
    then_clause = assign_x('c', const_float(1)) + break_()
    loop_body = (assign_x('a', const_float(1)) +
                 simple_if('b', then_clause))
    unchanged = make_test_case('main', 'void', loop(loop_body))
    yield create_test_case(unchanged, unchanged, 'lower_breaks_3',
                           lower_break=True)
489
def test_lower_breaks_4():
    """Check that a conditional break at the bottom of a loop is not
    lowered when it sits in the else-clause: the expected output is
    the input itself.
    """
    loop_body = (assign_x('a', const_float(1)) +
                 simple_if('b', [], break_()))
    unchanged = make_test_case('main', 'void', loop(loop_body))
    yield create_test_case(unchanged, unchanged, 'lower_breaks_4',
                           lower_break=True)
501
def test_lower_breaks_5():
    """Check that a conditional break at the bottom of a loop is not
    lowered when it sits in the else-clause, even with statements
    ahead of the break inside that clause.
    """
    else_clause = assign_x('c', const_float(1)) + break_()
    loop_body = (assign_x('a', const_float(1)) +
                 simple_if('b', [], else_clause))
    unchanged = make_test_case('main', 'void', loop(loop_body))
    yield create_test_case(unchanged, unchanged, 'lower_breaks_5',
                           lower_break=True)
515
def test_lower_breaks_6():
    """Check a loop that mixes conditional breaks and continues with
    an unconditional break at its end.

    The trailing unconditional break must be lowered too, because it
    stops being the last statement of the loop once the final break
    is appended.
    """
    source_jumps = (complex_if('b', continue_()) +
                    complex_if('c', break_()))
    source_loop = loop(simple_if('a', source_jumps) + break_())
    source = make_test_case('main', 'void', source_loop)

    flag_setup = declare_break_flag()
    lowered_jumps = (complex_if('b', lowered_continue()) +
                     if_execute_flag(complex_if('c', lowered_break())))
    lowered_loop = loop(declare_execute_flag() +
                        simple_if('a', lowered_jumps) +
                        if_execute_flag(lowered_break_simple()) +
                        final_break())
    lowered = make_test_case('main', 'void', flag_setup + lowered_loop)

    yield create_test_case(source, lowered, 'lower_breaks_6',
                           lower_break=True, lower_continue=True)
541
def test_lower_guarded_conditional_break():
    """Check that a conditional break ending a loop is lowered when
    earlier continue lowering wraps it in an if(execute_flag).

    Such a break is normally left alone; being placed inside the
    execute_flag guard is what forces it to be lowered.
    """
    source_loop = loop(complex_if('a', continue_()) +
                       simple_if('b', break_()))
    source = make_test_case('main', 'void', source_loop)

    flag_setup = declare_break_flag()
    guarded_break = if_execute_flag(simple_if('b', lowered_break()))
    lowered_loop = loop(declare_execute_flag() +
                        complex_if('a', lowered_continue()) +
                        guarded_break +
                        final_break())
    lowered = make_test_case('main', 'void', flag_setup + lowered_loop)

    yield create_test_case(source, lowered,
                           'lower_guarded_conditional_break',
                           lower_break=True, lower_continue=True)
561
def test_remove_continue_at_end_of_loop():
    """Check that a redundant continue ending a loop body is dropped.

    The case is generated with every lowering flag left at its
    default of False.
    """
    with_continue = make_test_case(
        'main', 'void',
        loop(assign_x('a', const_float(1)) + continue_()))
    without_continue = make_test_case(
        'main', 'void',
        loop(assign_x('a', const_float(1))))
    yield create_test_case(with_continue, without_continue,
                           'remove_continue_at_end_of_loop')
574
def test_lower_return_void_at_end_of_loop():
    """Check lowering of a void return that ends a loop body.

    Three cases are yielded: no flags (input expected unchanged),
    lower_main_return alone, and lower_main_return together with
    lower_break.  The last two expect the same lowered output.
    """
    source_loop = loop(assign_x('a', const_float(1)) + return_())
    source = make_test_case('main', 'void',
                            source_loop + assign_x('b', const_float(2)))

    prologue = declare_execute_flag() + declare_return_flag()
    lowered_loop = loop(assign_x('a', const_float(1)) +
                        lowered_return_simple() +
                        break_())
    after_loop = if_return_flag(
        assign_x('return_flag', const_bool(1)) +
        assign_x('execute_flag', const_bool(0)),
        assign_x('b', const_float(2)))
    lowered = make_test_case('main', 'void',
                             prologue + lowered_loop + after_loop)

    yield create_test_case(source, source,
                           'return_void_at_end_of_loop_lower_nothing')
    yield create_test_case(source, lowered,
                           'return_void_at_end_of_loop_lower_return',
                           lower_main_return=True)
    yield create_test_case(
        source, lowered,
        'return_void_at_end_of_loop_lower_return_and_break',
        lower_main_return=True, lower_break=True)
601
def test_lower_return_non_void_at_end_of_loop():
    """Check lowering of a non-void return that ends a loop body.

    Three cases are yielded: no flags (input expected unchanged),
    lower_sub_return alone, and lower_sub_return together with
    lower_break.  The last two expect the same lowered output.
    """
    source_loop = loop(assign_x('a', const_float(1)) +
                       return_(const_float(2)))
    source = make_test_case(
        'sub', 'float',
        source_loop +
        assign_x('b', const_float(3)) +
        return_(const_float(4)))

    prologue = (declare_execute_flag() +
                declare_return_value() +
                declare_return_flag())
    lowered_loop = loop(assign_x('a', const_float(1)) +
                        lowered_return_simple(const_float(2)) +
                        break_())
    # The self-assignment of return_value is given as an already
    # formatted string rather than as a nested list.
    returned_in_loop = (
        assign_x('return_value', '(var_ref return_value)') +
        assign_x('return_flag', const_bool(1)) +
        assign_x('execute_flag', const_bool(0)))
    fell_out_of_loop = (assign_x('b', const_float(3)) +
                        lowered_return(const_float(4)))
    after_loop = if_return_flag(returned_in_loop, fell_out_of_loop)
    lowered = make_test_case(
        'sub', 'float',
        prologue + lowered_loop + after_loop + final_return())

    yield create_test_case(source, source,
                           'return_non_void_at_end_of_loop_lower_nothing')
    yield create_test_case(source, lowered,
                           'return_non_void_at_end_of_loop_lower_return',
                           lower_sub_return=True)
    yield create_test_case(
        source, lowered,
        'return_non_void_at_end_of_loop_lower_return_and_break',
        lower_sub_return=True, lower_break=True)
633
# Every test-case generator defined in this module.  Each entry is a
# generator function that yields the tuples built by create_test_case.
CASES = [
    test_lower_breaks_1,
    test_lower_breaks_2,
    test_lower_breaks_3,
    test_lower_breaks_4,
    test_lower_breaks_5,
    test_lower_breaks_6,
    test_lower_guarded_conditional_break,
    test_lower_pulled_out_jump,
    test_lower_return_non_void_at_end_of_loop,
    test_lower_return_void_at_end_of_loop,
    test_lower_returns_1,
    test_lower_returns_2,
    test_lower_returns_3,
    test_lower_returns_4,
    test_lower_returns_main,
    test_lower_returns_sub,
    test_lower_unified_returns,
    test_remove_continue_at_end_of_loop,
]
644