# coding=utf-8
#
# Copyright © 2011, 2018 Intel Corporation
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice (including the next
# paragraph) shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

from sexps import *


def make_test_case(f_name, ret_type, body):
    """Create a simple optimization test case consisting of a single
    function with the given name, return type, and body.

    Global declarations are automatically created for any undeclared
    variables that are referenced by the function.  All undeclared
    variables are assumed to be floats.

    A variable used as the target of an ``assign`` is declared ``out``;
    a variable that appears in any other ``var_ref`` is declared ``in``.
    If the same name occurs in both roles, the last occurrence reached
    in a depth-first walk of ``body`` decides its direction.  Names
    introduced by a ``(declare ...)`` statement inside a statement list
    are treated as declared for the rest of that list and everything
    nested in it.

    Returns a list of top-level sexps: the generated global declarations
    followed by the ``(function ...)`` definition.
    """
    check_sexp(body)
    declarations = {}

    def make_declarations(sexp, already_declared=()):
        # Only lists can contain variable references.
        if not isinstance(sexp, list):
            return
        if len(sexp) == 2 and sexp[0] == 'var_ref':
            name = sexp[1]
            if name not in already_declared:
                declarations[name] = ['declare', ['in'], 'float', name]
        elif len(sexp) == 4 and sexp[0] == 'assign':
            lhs = sexp[2]
            assert lhs[0] == 'var_ref'
            if lhs[1] not in already_declared:
                declarations[lhs[1]] = ['declare', ['out'], 'float', lhs[1]]
            # The left-hand side is deliberately not walked as a var_ref,
            # so a pure assignment target stays 'out'.
            make_declarations(sexp[3], already_declared)
        else:
            # Copy so that declarations found in this list are visible to
            # the remainder of this list (and anything nested in it), but
            # do not leak back into the caller's set.
            already_declared = set(already_declared)
            for s in sexp:
                if isinstance(s, list) and len(s) >= 4 and s[0] == 'declare':
                    already_declared.add(s[3])
                else:
                    make_declarations(s, already_declared)

    make_declarations(body)
    return list(declarations.values()) + [
        ['function', f_name, ['signature', ret_type, ['parameters'], body]]]


# The following functions can be used to build expressions.

def const_float(value):
    """Create an expression representing the given floating point value.

    The value is rendered with six digits after the decimal point.
    """
    return ['constant', 'float', ['{0:.6f}'.format(value)]]


def const_bool(value):
    """Create an expression representing the given boolean value.

    If value is not a boolean, it is converted to a boolean.  So, for
    instance, const_bool(1) is equivalent to const_bool(True).
    """
    return ['constant', 'bool', ['{0}'.format(1 if value else 0)]]


def gt_zero(var_name):
    """Construct the expression var_name > 0.

    It is emitted in the equivalent form (0.0 < var_name).
    """
    return ['expression', 'bool', '<', const_float(0), ['var_ref', var_name]]


# The following functions can be used to build complex control flow
# statements.  All of these functions return statement lists (even
# those which only create a single statement), so that statements can
# be sequenced together using the '+' operator.

def return_(value=None):
    """Create a return statement.

    With no value this is a void return; otherwise ``value`` (an
    expression sexp) is the returned value.
    """
    if value is not None:
        return [['return', value]]
    else:
        return [['return']]


def break_():
    """Create a break statement.

    The result is a statement list whose single statement is the bare
    symbol 'break', so it can be concatenated with other statement
    lists using '+'.
    """
    return ['break']


def continue_():
    """Create a continue statement.

    The result is a statement list whose single statement is the bare
    symbol 'continue', so it can be concatenated with other statement
    lists using '+'.
    """
    return ['continue']


def simple_if(var_name, then_statements, else_statements=None):
    """Create a statement of the form

    if (var_name > 0.0) {
       <then_statements>
    } else {
       <else_statements>
    }

    else_statements may be omitted, in which case the else branch is
    empty.
    """
    if else_statements is None:
        else_statements = []
    check_sexp(then_statements)
    check_sexp(else_statements)
    return [['if', gt_zero(var_name), then_statements, else_statements]]


def loop(statements):
    """Create a loop containing the given statements as its loop
    body.
    """
    check_sexp(statements)
    return [['loop', statements]]


def declare_temp(var_type, var_name):
    """Create a declaration of the form

    (declare (temporary) <var_type> <var_name>)
    """
    return [['declare', ['temporary'], var_type, var_name]]


def assign_x(var_name, value):
    """Create a statement that assigns <value> to the variable
    <var_name>.  The assignment uses the mask (x).
    """
    check_sexp(value)
    return [['assign', ['x'], ['var_ref', var_name], value]]


def complex_if(var_prefix, statements):
    """Create a statement of the form

    if (<var_prefix>a > 0.0) {
       if (<var_prefix>b > 0.0) {
          <statements>
       }
    }

    This is useful in testing jump lowering, because if <statements>
    ends in a jump, lower_jumps.cpp won't try to combine this
    construct with the code that follows it, as it might do for a
    simple if.

    All variables used in the if statement are prefixed with
    var_prefix.  This can be used to ensure uniqueness.
    """
    check_sexp(statements)
    return simple_if(var_prefix + 'a', simple_if(var_prefix + 'b', statements))


def declare_execute_flag():
    """Create the statements that lower_jumps.cpp uses to declare and
    initialize the temporary boolean execute_flag (initially True).
    """
    return declare_temp('bool', 'execute_flag') + \
        assign_x('execute_flag', const_bool(True))


def declare_return_flag():
    """Create the statements that lower_jumps.cpp uses to declare and
    initialize the temporary boolean return_flag (initially False).
    """
    return declare_temp('bool', 'return_flag') + \
        assign_x('return_flag', const_bool(False))


def declare_return_value():
    """Create the statement that lower_jumps.cpp uses to declare the
    temporary variable return_value.  No initializing assignment is
    emitted.  Assume that return_value is a float.
    """
    return declare_temp('float', 'return_value')


def declare_break_flag():
    """Create the statements that lower_jumps.cpp uses to declare and
    initialize the temporary boolean break_flag (initially False).
    """
    return declare_temp('bool', 'break_flag') + \
        assign_x('break_flag', const_bool(False))


def lowered_return_simple(value=None):
    """Create the statements that lower_jumps.cpp lowers a return
    statement to, in situations where it does not need to clear the
    execute flag.

    If value is given, it is first stored into return_value; then
    return_flag is set to True.
    """
    # Compare against None (as return_ does) rather than relying on
    # truthiness, so both helpers agree on what "no value" means.
    if value is not None:
        result = assign_x('return_value', value)
    else:
        result = []
    return result + assign_x('return_flag', const_bool(True))


def lowered_return(value=None):
    """Create the statements that lower_jumps.cpp lowers a return
    statement to, in situations where it needs to clear the execute
    flag.
    """
    return lowered_return_simple(value) + \
        assign_x('execute_flag', const_bool(False))


def lowered_continue():
    """Create the statement that lower_jumps.cpp lowers a continue
    statement to.
    """
    return assign_x('execute_flag', const_bool(False))


def lowered_break_simple():
    """Create the statement that lower_jumps.cpp lowers a break
    statement to, in situations where it does not need to clear the
    execute flag.
    """
    return assign_x('break_flag', const_bool(True))


def lowered_break():
    """Create the statements that lower_jumps.cpp lowers a break
    statement to, in situations where it needs to clear the execute
    flag.
    """
    return lowered_break_simple() + assign_x('execute_flag', const_bool(False))


def if_execute_flag(statements):
    """Wrap statements in an if test so that they will only execute if
    execute_flag is True.
    """
    check_sexp(statements)
    return [['if', ['var_ref', 'execute_flag'], statements, []]]


def if_return_flag(then_statements, else_statements):
    """Wrap statements in an if test with return_flag as the condition.
    """
    check_sexp(then_statements)
    check_sexp(else_statements)
    return [['if', ['var_ref', 'return_flag'], then_statements, else_statements]]


def if_not_return_flag(statements):
    """Wrap statements in an if test so that they will only execute if
    return_flag is False.
    """
    check_sexp(statements)
    return [['if', ['var_ref', 'return_flag'], [], statements]]


def final_return():
    """Create the return statement that lower_jumps.cpp places at the
    end of a function when lowering returns.
    """
    return [['return', ['var_ref', 'return_value']]]


def final_break():
    """Create the conditional break statement that lower_jumps.cpp
    places at the end of a loop body when lowering breaks.
    """
    return [['if', ['var_ref', 'break_flag'], break_(), []]]


def bash_quote(*args):
    """Quote the arguments appropriately so that bash will understand
    each argument as a single word.

    Words made only of letters, digits and the characters '@%_-+=:,./'
    are passed through unchanged; the empty word becomes ''; anything
    else is wrapped in single quotes, with embedded single quotes
    escaped as '"'"'.
    """
    safe_punctuation = '@%_-+=:,./'

    def quote_word(word):
        if not word:
            return "''"
        if all(c.isalpha() or c.isdigit() or c in safe_punctuation
               for c in word):
            return word
        return "'{0}'".format(word.replace("'", "'\"'\"'"))

    return ' '.join(quote_word(word) for word in args)


def create_test_case(input_sexp, expected_sexp, test_name,
                     pull_out_jumps=False, lower_sub_return=False,
                     lower_main_return=False, lower_continue=False,
                     lower_break=False):
    """Create a test case that verifies that do_lower_jumps transforms
    the given code in the expected way.

    Returns a tuple (test_name, optimization, input_str, expected_output)
    where optimization is the do_lower_jumps(...) invocation string with
    each flag rendered as 0 or 1, and the two sexps are stringified after
    their declarations are sorted.
    """
    check_sexp(input_sexp)
    check_sexp(expected_sexp)
    input_str = sexp_to_string(sort_decls(input_sexp))
    # XXX: don't stringify this
    expected_output = sexp_to_string(sort_decls(expected_sexp))
    optimization = (
        'do_lower_jumps({0:d}, {1:d}, {2:d}, {3:d}, {4:d})'.format(
            pull_out_jumps, lower_sub_return, lower_main_return,
            lower_continue, lower_break))

    return (test_name, optimization, input_str, expected_output)


def test_lower_returns_main():
    """Test that do_lower_jumps respects the lower_main_return flag in deciding
    whether to lower returns in the main function.
    """
    input_sexp = make_test_case('main', 'void', (
        complex_if('', return_())
    ))
    expected_sexp = make_test_case('main', 'void', (
        declare_execute_flag() +
        declare_return_flag() +
        complex_if('', lowered_return())
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_returns_main_true',
        lower_main_return=True)
    yield create_test_case(
        input_sexp, input_sexp, 'lower_returns_main_false',
        lower_main_return=False)


def test_lower_returns_sub():
    """Test that do_lower_jumps respects the lower_sub_return flag in deciding
    whether to lower returns in subroutines.
    """
    input_sexp = make_test_case('sub', 'void', (
        complex_if('', return_())
    ))
    expected_sexp = make_test_case('sub', 'void', (
        declare_execute_flag() +
        declare_return_flag() +
        complex_if('', lowered_return())
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_returns_sub_true',
        lower_sub_return=True)
    yield create_test_case(
        input_sexp, input_sexp, 'lower_returns_sub_false',
        lower_sub_return=False)


def test_lower_returns_1():
    """Test that a void return at the end of a function is eliminated."""
    input_sexp = make_test_case('main', 'void', (
        assign_x('a', const_float(1)) +
        return_()
    ))
    expected_sexp = make_test_case('main', 'void', (
        assign_x('a', const_float(1))
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_returns_1', lower_main_return=True)


def test_lower_returns_2():
    """Test that lowering is not performed on a non-void return at the end of
    a subroutine.
    """
    input_sexp = make_test_case('sub', 'float', (
        assign_x('a', const_float(1)) +
        return_(const_float(1))
    ))
    yield create_test_case(
        input_sexp, input_sexp, 'lower_returns_2', lower_sub_return=True)


def test_lower_returns_3():
    """Test lowering of returns when there is one nested inside a complex
    structure of ifs, and one at the end of a function.

    In this case, the latter return needs to be lowered because it will not be
    at the end of the function once the final return is inserted.
    """
    input_sexp = make_test_case('sub', 'float', (
        complex_if('', return_(const_float(1))) +
        return_(const_float(2))
    ))
    expected_sexp = make_test_case('sub', 'float', (
        declare_execute_flag() +
        declare_return_value() +
        declare_return_flag() +
        complex_if('', lowered_return(const_float(1))) +
        if_execute_flag(lowered_return(const_float(2))) +
        final_return()
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_returns_3', lower_sub_return=True)


def test_lower_returns_4():
    """Test that returns are properly lowered when they occur in both branches
    of an if-statement.
    """
    input_sexp = make_test_case('sub', 'float', (
        simple_if('a', return_(const_float(1)),
                  return_(const_float(2)))
    ))
    expected_sexp = make_test_case('sub', 'float', (
        declare_execute_flag() +
        declare_return_value() +
        declare_return_flag() +
        simple_if('a', lowered_return(const_float(1)),
                  lowered_return(const_float(2))) +
        final_return()
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_returns_4', lower_sub_return=True)


def test_lower_unified_returns():
    """If both branches of an if statement end in a return, and pull_out_jumps
    is True, then those returns should be lifted outside the if and then
    properly lowered.

    Verify that this lowering occurs during the same pass as the lowering of
    other returns by checking that extra temporary variables aren't generated.
    """
    input_sexp = make_test_case('main', 'void', (
        complex_if('a', return_()) +
        simple_if('b', simple_if('c', return_(), return_()))
    ))
    expected_sexp = make_test_case('main', 'void', (
        declare_execute_flag() +
        declare_return_flag() +
        complex_if('a', lowered_return()) +
        if_execute_flag(simple_if('b', (simple_if('c', [], []) +
                                        lowered_return())))
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_unified_returns',
        lower_main_return=True, pull_out_jumps=True)


def test_lower_pulled_out_jump():
    """If one branch of an if ends in a jump, and control cannot
    fall out the bottom of the other branch, and pull_out_jumps is
    True, then the jump is lifted outside the if.

    Verify that this lowering occurs during the same pass as the
    lowering of other jumps by checking that extra temporary
    variables aren't generated.
    """
    input_sexp = make_test_case('main', 'void', (
        complex_if('a', return_()) +
        loop(simple_if('b', simple_if('c', break_(), continue_()),
                       return_())) +
        assign_x('d', const_float(1))
    ))
    # Note: optimization produces two other effects: the break
    # gets lifted out of the if statements, and the code after the
    # loop gets guarded so that it only executes if the return
    # flag is clear.
    expected_sexp = make_test_case('main', 'void', (
        declare_execute_flag() +
        declare_return_flag() +
        complex_if('a', lowered_return()) +
        if_execute_flag(
            loop(simple_if('b', simple_if('c', [], continue_()),
                           lowered_return_simple()) +
                 break_()) +
            if_return_flag(assign_x('return_flag', const_bool(1)) +
                           assign_x('execute_flag', const_bool(0)),
                           assign_x('d', const_float(1))))
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_pulled_out_jump',
        lower_main_return=True, pull_out_jumps=True)


def test_lower_breaks_1():
    """If a loop contains an unconditional break at the bottom of it, it should
    not be lowered.
    """
    input_sexp = make_test_case('main', 'void', (
        loop(assign_x('a', const_float(1)) +
             break_())
    ))
    expected_sexp = input_sexp
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_breaks_1', lower_break=True)


def test_lower_breaks_2():
    """If a loop contains a conditional break at the bottom of it, it should
    not be lowered if it is in the then-clause.
    """
    input_sexp = make_test_case('main', 'void', (
        loop(assign_x('a', const_float(1)) +
             simple_if('b', break_()))
    ))
    expected_sexp = input_sexp
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_breaks_2', lower_break=True)


def test_lower_breaks_3():
    """If a loop contains a conditional break at the bottom of it, it should
    not be lowered if it is in the then-clause, even if there are statements
    preceding the break.
    """
    input_sexp = make_test_case('main', 'void', (
        loop(assign_x('a', const_float(1)) +
             simple_if('b', (assign_x('c', const_float(1)) +
                             break_())))
    ))
    expected_sexp = input_sexp
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_breaks_3', lower_break=True)


def test_lower_breaks_4():
    """If a loop contains a conditional break at the bottom of it, it should
    not be lowered if it is in the else-clause.
    """
    input_sexp = make_test_case('main', 'void', (
        loop(assign_x('a', const_float(1)) +
             simple_if('b', [], break_()))
    ))
    expected_sexp = input_sexp
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_breaks_4', lower_break=True)


def test_lower_breaks_5():
    """If a loop contains a conditional break at the bottom of it, it should
    not be lowered if it is in the else-clause, even if there are statements
    preceding the break.
    """
    input_sexp = make_test_case('main', 'void', (
        loop(assign_x('a', const_float(1)) +
             simple_if('b', [], (assign_x('c', const_float(1)) +
                                 break_())))
    ))
    expected_sexp = input_sexp
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_breaks_5', lower_break=True)


def test_lower_breaks_6():
    """If a loop contains conditional breaks and continues, and ends in an
    unconditional break, then the unconditional break needs to be lowered,
    because it will no longer be at the end of the loop after the final break
    is added.
    """
    input_sexp = make_test_case('main', 'void', (
        loop(simple_if('a', (complex_if('b', continue_()) +
                             complex_if('c', break_()))) +
             break_())
    ))
    expected_sexp = make_test_case('main', 'void', (
        declare_break_flag() +
        loop(declare_execute_flag() +
             simple_if(
                 'a',
                 (complex_if('b', lowered_continue()) +
                  if_execute_flag(
                      complex_if('c', lowered_break())))) +
             if_execute_flag(lowered_break_simple()) +
             final_break())
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_breaks_6', lower_break=True,
        lower_continue=True)


def test_lower_guarded_conditional_break():
    """Normally a conditional break at the end of a loop isn't lowered, however
    if the conditional break gets placed inside an if(execute_flag) because of
    earlier lowering of continues, then the break needs to be lowered.
    """
    input_sexp = make_test_case('main', 'void', (
        loop(complex_if('a', continue_()) +
             simple_if('b', break_()))
    ))
    expected_sexp = make_test_case('main', 'void', (
        declare_break_flag() +
        loop(declare_execute_flag() +
             complex_if('a', lowered_continue()) +
             if_execute_flag(simple_if('b', lowered_break())) +
             final_break())
    ))
    yield create_test_case(
        input_sexp, expected_sexp, 'lower_guarded_conditional_break',
        lower_break=True, lower_continue=True)


def test_remove_continue_at_end_of_loop():
    """Test that a redundant continue-statement at the end of a loop is
    removed.
    """
    input_sexp = make_test_case('main', 'void', (
        loop(assign_x('a', const_float(1)) +
             continue_())
    ))
    expected_sexp = make_test_case('main', 'void', (
        loop(assign_x('a', const_float(1)))
    ))
    yield create_test_case(input_sexp, expected_sexp, 'remove_continue_at_end_of_loop')


def test_lower_return_void_at_end_of_loop():
    """Test that a return of void at the end of a loop is properly lowered."""
    input_sexp = make_test_case('main', 'void', (
        loop(assign_x('a', const_float(1)) +
             return_()) +
        assign_x('b', const_float(2))
    ))
    expected_sexp = make_test_case('main', 'void', (
        declare_execute_flag() +
        declare_return_flag() +
        loop(assign_x('a', const_float(1)) +
             lowered_return_simple() +
             break_()) +
        if_return_flag(assign_x('return_flag', const_bool(1)) +
                       assign_x('execute_flag', const_bool(0)),
                       assign_x('b', const_float(2)))
    ))
    yield create_test_case(
        input_sexp, input_sexp, 'return_void_at_end_of_loop_lower_nothing')
    yield create_test_case(
        input_sexp, expected_sexp, 'return_void_at_end_of_loop_lower_return',
        lower_main_return=True)
    yield create_test_case(
        input_sexp, expected_sexp,
        'return_void_at_end_of_loop_lower_return_and_break',
        lower_main_return=True, lower_break=True)


def test_lower_return_non_void_at_end_of_loop():
    """Test that a non-void return at the end of a loop is properly lowered."""
    input_sexp = make_test_case('sub', 'float', (
        loop(assign_x('a', const_float(1)) +
             return_(const_float(2))) +
        assign_x('b', const_float(3)) +
        return_(const_float(4))
    ))
    expected_sexp = make_test_case('sub', 'float', (
        declare_execute_flag() +
        declare_return_value() +
        declare_return_flag() +
        loop(assign_x('a', const_float(1)) +
             lowered_return_simple(const_float(2)) +
             break_()) +
        # The expected output contains the self-assignment
        # return_value = return_value; build it as a structured var_ref
        # like every other reference rather than a pre-rendered string.
        if_return_flag(assign_x('return_value', ['var_ref', 'return_value']) +
                       assign_x('return_flag', const_bool(1)) +
                       assign_x('execute_flag', const_bool(0)),
                       assign_x('b', const_float(3)) +
                       lowered_return(const_float(4))) +
        final_return()
    ))
    yield create_test_case(
        input_sexp, input_sexp, 'return_non_void_at_end_of_loop_lower_nothing')
    yield create_test_case(
        input_sexp, expected_sexp,
        'return_non_void_at_end_of_loop_lower_return', lower_sub_return=True)
    yield create_test_case(
        input_sexp, expected_sexp,
        'return_non_void_at_end_of_loop_lower_return_and_break',
        lower_sub_return=True, lower_break=True)


CASES = [
    test_lower_breaks_1, test_lower_breaks_2, test_lower_breaks_3,
    test_lower_breaks_4, test_lower_breaks_5, test_lower_breaks_6,
    test_lower_guarded_conditional_break, test_lower_pulled_out_jump,
    test_lower_return_non_void_at_end_of_loop,
    test_lower_return_void_at_end_of_loop,
    test_lower_returns_1, test_lower_returns_2, test_lower_returns_3,
    test_lower_returns_4, test_lower_returns_main, test_lower_returns_sub,
    test_lower_unified_returns, test_remove_continue_at_end_of_loop,
]