#!/usr/bin/python3

# Copyright 2017, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NN model compiler

Compile models and examples into NDK-based CTS unit tests
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from functools import reduce
import math
import os
import struct
import sys
import contextlib
import pprint

@contextlib.contextmanager
def smart_open(filename=None):
    # Context manager that yields an open write handle for `filename`,
    # or sys.stdout when filename is empty/None/'-'. Only closes the
    # handle if it actually opened a file (never closes stdout).
    if filename and filename != '-':
        fh = open(filename, 'w')
    else:
        fh = sys.stdout

    try:
        yield fh
    finally:
        if fh is not sys.stdout:
            fh.close()

class Phase(object):
    # Ordered registry pairing model objects with the C++ statement text
    # that declares them. Used for the "operands" phase of code emission.
    def __init__(self):
        self.__objects = []          # objects, in registration order
        self.__contents = []         # matching C++ declaration strings
        self.__dict_of_objects = {}  # obj.ID() -> obj, for lookup by serial

    def append(self, obj, x):
        # Register `obj` with its declaration string `x`.
        self.__objects.append(obj)
        self.__contents.append(x)
        self.__dict_of_objects[obj.ID()] = obj

    def dump(self, filename):
        # Print each registered declaration as an indented C++ statement.
        # `filename` is actually an open file handle, not a path.
        for x in self.__contents:
            print (" " + x + ";", file=filename)

    def objects(self):
        return self.__objects

    def search(self, i):
        # Look up a registered object by its numeric ID; raises KeyError
        # if `i` was never registered.
        return self.__dict_of_objects[i]

# Tracking objects inside a model with a not necessarily unique name and
# an unique number
class NamedObject(object):
    __serial = 0  # class-wide counter; every NamedObject gets a fresh ID

    def __init__(self, name = "NamedObject"):
        self.__name = name
        self.__id = NamedObject.serial()
        NamedObject.__serial += 1

    def ID(self):
        return self.__id

    def serial():
        # NOTE(review): static-style method with no `self`/@staticmethod —
        # only callable as NamedObject.serial(), which is how it is used.
        return NamedObject.__serial

    def get_name(self):
        return self.__name

    def __str__(self):
        return self.get_name()

    def __hash__(self):
        # Hash by unique serial so objects with equal names stay distinct
        # in sets/dicts.
        return self.__id

# Object that can be traversed during topological sorting phase
class Traversable(object):
    def traversable(self):
        return True

class Nontraversable(object):
    def traversable(self):
        return False

# Object that can take input from other objects
class Uses(object):
    all_uses = set()  # global registry of every consumer node (for topo sort)
    def __init__(self, ins = []):
        # Mutable default is safe here: `ins` is copied, never mutated.
        self.ins = ins.copy()
        Uses.all_uses.add(self)
        for i in ins:
            i.outs.append(self)  # back-link: producer -> this consumer

# Object that other objects takes its definition from
class Definitions(object):
    def __init__(self, outs = []):
        # Mutable default is safe here: `outs` is copied, never mutated.
        self.outs = outs.copy()
        for o in outs:
            o.ins.append(self)  # back-link: consumer -> this producer

class TypeLookup:
    # Maps NNAPI operand type names to the C++ element type used in
    # generated test data. All methods are static-style (called on the
    # class, never on an instance).
    __type_lookup = {
        "INT32": "int32_t",
        "UINT32": "uint32_t",
        "FLOAT32": "float",
        "TENSOR_INT32": "int32_t",
        "TENSOR_FLOAT32": "float",
        "TENSOR_QUANT8_ASYMM": "uint8_t",
#        "OEM_SCALAR": this is service-defined.
        "TENSOR_OEM_BYTE": "uint8_t",
    }

    def get_cpptype(nnapi_type):
        # Raises KeyError for unknown NNAPI type names.
        return TypeLookup.__type_lookup[nnapi_type]

    def is_float(nnapi_type):
        return TypeLookup.get_cpptype(nnapi_type) == "float"

    def get_size(nnapi_type):
        # Element size in bytes: 1 for uint8_t, otherwise 4.
        return 1 if TypeLookup.get_cpptype(nnapi_type) == "uint8_t" else 4


class Type(object):
    # Interned operand type: identical (element-type, shape) pairs share
    # one numeric ID so the generated C++ declares each OperandType once.
    __types = {}
    __type_serial = 0 # types have their own numbering
    def __init__(self, vt = None, shape = None):
        self.__vt = vt        # NNAPI element type string, e.g. "TENSOR_FLOAT32"
        self.__shape = shape  # raw shape string, e.g. "{1, 2}, 0.5f, 0"
        if vt is None or shape is None:
            # NOTE(review): this partially-constructed form leaves __id
            # unset, so __hash__ would raise — presumably never hashed;
            # verify against callers.
            self.__name = None
            return

        key = str(self)
        if key not in Type.__types:
            # First time this (vt, shape) combination is seen: intern it.
            self.__id = Type.__type_serial
            Type.__types[str(self)] = self
            Type.__type_serial += 1
        else:
            # Reuse the ID of the previously interned identical type.
            self.__id = Type.__types[key].__id
        self.__name = "type" + str(self.__id)

    def get_shape(self):
        return self.__shape

    def get_element_type(self):
        return self.__vt

    def get_name(self):
        return self.__name

    def __str__(self):
        return (", ".join([self.__vt, self.__shape]))

    def __hash__(self):
        return self.__id

    def dump(filename):
        # Static-style: emit one C++ OperandType declaration per interned
        # type, sorted by the type's string key for stable output.
        for key, value in sorted(Type.__types.items()):
            print (" OperandType " + str(value.__name) + "(Type::" + str(key) + ");", file=filename)

    def get_raw_shape(self):
        return self.__shape

    def get_parsed_shape(self):
        # Parse shape
        # Splits the raw shape string into (dims, scale, zero_point),
        # all returned as strings. Handles the forms:
        #   "{1, 2}"                 -> dims, "0", "0"
        #   "{1, 2}, 0.5f"           -> dims, "0.5", "0"  (scale only)
        #   "{1, 2}, 0.5f, 127"      -> dims, "0.5", "127"
        if (self.__shape != "" and self.__shape != "{}"):
            left, sep, right = self.__shape.partition('{')
            real_shape, sep, right = right.partition('}')
            shape = [int(x) for x in real_shape.split(",")]
            # left now looks like "0.0f, 127.5f, "
            scale, sep, zero_point = right.rpartition(',')
            if scale == "":
                if zero_point == "":
                    return real_shape, "0", "0"
                return real_shape, zero_point, "0"
            left, sep, scale = scale.partition(',')
            # strip the C float suffix so scale reads as a plain number
            return real_shape, scale.replace("f", ""), zero_point
        else:
            return "", "0", "0"

    def get_nr_elements(self):
        # Parse shape
        # Product of all dimensions; 1 for scalars / empty shapes.
        nr_elements = 1
        real_shape, scale, zero_point = self.get_parsed_shape()

        if (real_shape != "" and real_shape != "{}"):
            shape = [int(x) for x in real_shape.split(",")]
            nr_elements = reduce((lambda x, y: x*y), shape)
        return nr_elements

    def get_size(self):
        # Total byte size of one operand of this type.
        element_size = TypeLookup.get_size(self.__vt)
        return self.get_nr_elements() * element_size

# A value is a typed, named object
class Value(NamedObject):
    def __init__(self, name, vt):
        NamedObject.__init__(self, name)
        self.type = vt

# An operand that can be fed into operations. Also, an operand is always
# declared before operations.
class Operand(Value):
    # All operand declarations in string
    operands = Phase()

    def __init__(self, name, vt):
        Value.__init__(self, name, vt)
        # Register the C++ addOperand statement for the emission phase.
        def_string = (
            "auto " + self.get_name() + " = "\
            "model->addOperand(&" + vt.get_name() + ")")
        Operand.operands.append(self, def_string)

    # By default, produce nothing (when asked by the Topological Sort phase)
    def Definition(self):
        pass

    def Reference(self):
        return NamedObject.__str__(self)

    # Print a set of operands in curly braces
    def print_operands(operands):
        # Static-style: list of reference strings, in iteration order.
        return [ x.Reference() for x in operands ]

    # Defined with the model or not
    def is_weight(self):
        return False

# A user-declared input operand
class Input(Operand, Definitions, Traversable):
    # for enumerating inputs
    __next_number = 0
    # Holds reference to all Inputs; used by Topological sort as starting nodes.
    __inputs = set()

    def __init__(self, name, vt, shape, increase_next_number=True):
        Operand.__init__(self, name, Type(vt, shape))
        Definitions.__init__(self)
        Input.__inputs.add(self)
        self.number = Input.__next_number
        # Parameters (constants) reuse the current number without
        # consuming it, so real model inputs stay densely numbered.
        if increase_next_number is True:
            Input.__next_number += 1

    def lifetime(self):
        return "MODEL_INPUT"

    def is_internal(self):
        return False

    def get_inputs(exclude_internal = None):
        # Static-style. With any non-None argument, filter out internal
        # inputs (i.e. Parameters); otherwise return the full set.
        if exclude_internal is not None:
            external = { x for x in Input.__inputs if not x.is_internal() }
            return external
        else:
            return Input.__inputs

# A user-declared output operand
class Output(Operand, Uses, Nontraversable):
    # for enumerating outputs
    __next_number = 0
    __outputs = []

    def __init__(self, name, vt, shape):
        Operand.__init__(self, name, Type(vt, shape))
        Uses.__init__(self)
        Output.__outputs.append(self)
        self.number = Output.__next_number
        Output.__next_number += 1

    def lifetime(self):
        return "MODEL_OUTPUT"

    # return all unique outputs in the original order
    def get_outputs():
        # Static-style order-preserving dedup: `saw.add(x) or True` keeps
        # x the first time (add returns None, falsy) and records it.
        saw = set()
        unique = [x for x in Output.__outputs if x not in saw and (saw.add(x) or True)]
        return unique

# An output that we don't want to compare the results
class IgnoredOutput(Output):
    __ignored = set()
    def __init__(self, name, vt, shape):
        Output.__init__(self, name, vt, shape)
        IgnoredOutput.__ignored.add(self)
    def gen_ignored():
        # Static-style: emit the C++ is_ignored() helper listing the
        # output numbers to skip during result comparison.
        ignored_func = """
bool is_ignored(int i) {
 static std::set<int> ignore = {%s};
 return ignore.find(i) != ignore.end();
}""" % ", ".join([str(x.number) for x in IgnoredOutput.__ignored])
        return ignored_func

class ModelArgument:
    # Extra C++ parameters appended to the generated CreateModel()
    # signature, e.g. for weight memory handles.
    __arguments = []

    def __init__(self, arg_type, arg_name):
        self.__arg_type = arg_type
        self.__arg_name = arg_name
        ModelArgument.__arguments.append(" ".join([arg_type, arg_name]))

    def get_arg_type(self):
        return self.__arg_type

    def get_arg_name(self):
        return self.__arg_name

    def get_arguments():
        # Static-style: list of "type name" strings in declaration order.
        return ModelArgument.__arguments

    def lifetime(self):
        # NOTE(review): a lifetime() on an argument holder looks out of
        # place next to the Operand lifetimes — confirm it is ever called.
        return "CONSTANT_COPY"

# Print in C float literal format
def pretty_print_as_float(x):
    # "1.5" -> "1.5f", "1e-3" -> "1e-3f", "1" -> "1.0f" (via float()).
    s = str(float(x))
    if s.find(".") >= 0 or s.find("e") >= 0:
        return s + "f"
    else:
        return s + ".0f"

class Parameter(Input):
    # A constant operand baked into the model via setOperandValue.
    # TODO seems wrong that's an Input.
    def __init__(self, name, vt, shape, initializer):
        # increase_next_number=False: constants don't consume input numbers.
        Input.__init__(self, name, vt, shape, False)
        self.initializer = initializer
        self.cpptype = TypeLookup.get_cpptype(vt)
    def is_internal(self):
        return True
    def Definition(self):
        # Emit a static C++ initializer array plus the setOperandValue call.
        init_name = self.get_name() + "_init"
        initializer = [str(x) for x in self.initializer]
        if self.cpptype == "float":
            initializer = [ pretty_print_as_float(x) for x in initializer]
        init = self.cpptype + " " + init_name + "[]"
        init = "static " + init + " = {" + ", ".join(initializer) + "};"
        args = [ self.get_name(), init_name,
            "sizeof(" + self.cpptype + ") * " + str(len(self.initializer)) ]
        stmt = "\n ".join([init,
            "model->setOperandValue(" + ", ".join(args)+");"])
        return stmt
    def is_weight(self):
        return True
    def lifetime(self):
        # Weights may live in shared memory instead of being copied.
        if Configuration.useSHM():
            return "CONSTANT_REFERENCE"
        else:
            return "CONSTANT_COPY"

class Int32Scalar(Parameter):
    # Convenience wrapper: a single INT32 scalar constant.
    def __init__(self, name, value):
        Parameter.__init__(self, name, "INT32", "{}", [value])

class Float32Scalar(Parameter):
    # Convenience wrapper: a single FLOAT32 scalar constant.
    def __init__(self, name, value):
        Parameter.__init__(self, name, "FLOAT32", "{}", [value])

# A compiler-generated intermediate result from an operation
class IntermediateResult(Operand, Definitions, Uses, Traversable):
    def __init__(self, src: Value):
        # Named "tmpN" after the current global serial; inherits src's type.
        tmp_name = "tmp" + str(NamedObject.serial())
        Operand.__init__(self, tmp_name, src.type)
        Definitions.__init__(self)
        Uses.__init__(self, [src])

    def lifetime(self):
        return "TEMPORARY_VARIABLE"

# An explicitly declared intermediate result
class Internal(Operand, Definitions, Uses, Traversable):
    def __init__(self, name, vt, shape):
        Operand.__init__(self, name, Type(vt, shape))
        Definitions.__init__(self)
        Uses.__init__(self)

    def lifetime(self):
        return "TEMPORARY_VARIABLE"

# An operation in a model
class Operation(Definitions, Uses, Traversable):
    def __init__(self, optype, ins, outs):
        # Operation "type" is borrowed from its first input operand.
        self.type = ins[0].type
        Definitions.__init__(self, outs)
        Uses.__init__(self, ins)
        self.optype = optype

    def __str__(self):
        inputs = [ str(x) for x in self.ins ]
        return "Operation:" + self.optype + " " + ", ".join(inputs)

    def Reference(self):
        return "operation" + str(self.ID());

    def Definition(self):
        # C++ addOperation statement with input/output operand lists.
        inputs = Operand.print_operands(self.ins);
        outputs = Operand.print_operands(self.outs);
        return "model->addOperation(ANEURALNETWORKS_"+self.optype+", " + \
            "{"+", ".join(inputs)+"}, {" + ", ".join(outputs) + "});"

    # Get Python-ish dump for the op
    def PyDefinition(self):
        py_op_string = """Operation("{optype}", {inputs}).To({outputs})"""
        inputs = [str(x) for x in Operand.print_operands(self.ins)]
        inputs = ", ".join(inputs)
        assert len(self.outs) <= 1  # slicing tool only handles one output
        outputs = str(Operand.print_operands(self.outs)[0])
        ops = {"optype": self.optype, "inputs": inputs, "outputs": outputs}
        return py_op_string.format(**ops)

# Main interface
class Model(object):
    # Fluent builder used by spec files: each op method creates an
    # Operation, remembers it as __currentOp, and returns self so the
    # spec can chain .To(output).
    __isRelaxed = False

    def __init__(self):
        self.__currentOp = None

    # TODO turn this into generic binary operations
    def Add(self, i1: Value, i2 = None) -> Operation:
        ins = [i1]
        if i2 is not None:
            ins.append(i2)
        if self.__currentOp is not None:
            # Chain onto the previous op through a fresh temporary.
            ir = IntermediateResult(self.__currentOp)
            self.__currentOp = ir
            ins.append(self.__currentOp)

        op = Operation("ADD", ins, [])

        self.__currentOp = op
        return self

    def Operation(self, op_name, *args):
        # Generic escape hatch: any op type with positional input operands.
        ins = [i for i in args]
        outs = []
        op = Operation(op_name, ins, outs)
        self.__currentOp = op
        return self

    def RawAdd(self, i1: Value, i2: Value, o = None) -> Operation:
        ins = [i1, i2]
        outs = []
        if o is not None:
            outs = [o]
        op = Operation("ADD", ins, outs)

        self.__currentOp = op
        return self

    # See CpuExecutor::executeOperation() for the arguments of each op
    def AveragePool(self, input, padding, stride_width, stride_height, filter_width, filter_height, activation):
        ins = [input, padding, stride_width,
            stride_height, filter_width, filter_height, activation]
        outs = []
        op = Operation("AVERAGE_POOL_2D", ins, outs)
        self.__currentOp = op
        return self

    def Concatenation(self, *args):
        ins = [i for i in args]
        outs = []
        op = Operation("CONCATENATION", ins, outs)
        self.__currentOp = op
        return self

    def Conv(self, filter, bias, input, padding, stride_width, stride_height, activation):
        ins = [filter, bias, input, padding, stride_width,
            stride_height, activation]
        outs = []
        op = Operation("CONV_2D", ins, outs)
        self.__currentOp = op
        return self

    def DepthWiseConv(self, filter, bias, input, padding, stride_width, stride_height, depth_multiplier, activation):
        ins = [filter, bias, input, padding, stride_width,
            stride_height, depth_multiplier, activation]
        outs = []
        op = Operation("DEPTHWISE_CONV_2D", ins, outs)
        self.__currentOp = op
        return self

    def FullyConnected(self, input, weights, bias, activation):
        ins = [input, weights, bias, activation]
        outs = []
        op = Operation("FULLY_CONNECTED", ins, outs)
        self.__currentOp = op
        return self

    def Logistic(self, input):
        ins = [input]
        outs = []
        op = Operation("LOGISTIC", ins, outs)
        self.__currentOp = op
        return self

    def L2Pool(self, input, padding, stride_width, stride_height, filter_width, filter_height, activation):
        ins = [input, padding, stride_width,
            stride_height, filter_width, filter_height, activation]
        outs = []
        op = Operation("L2_POOL_2D", ins, outs)
        self.__currentOp = op
        return self

    def MaxPool(self, input, padding, stride_width, stride_height, filter_width, filter_height, activation):
        ins = [input, padding, stride_width,
            stride_height, filter_width, filter_height, activation]
        outs = []
        op = Operation("MAX_POOL_2D", ins, outs)
        self.__currentOp = op
        return self

    def SoftMax(self, input, beta):
        ins = [input, beta]
        outs = []
        op = Operation("SOFTMAX", ins, outs)
        self.__currentOp = op
        return self

    def Reshape(self, input, shape):
        ins = [input, shape]
        outs = []
        op = Operation("RESHAPE", ins, outs)
        self.__currentOp = op
        return self

    def Out(self, o):
        # Wire the pending op's output(s); accepts a single operand or a
        # list/tuple of operands.
        if (type(o) is list or type(o) is tuple):
            for i in o:
                self.__currentOp.outs.append(i)
                i.ins.append(self.__currentOp)
        else:
            self.__currentOp.outs.append(o)
            o.ins.append(self.__currentOp)
        return self

    def To(self, o:Value):
        # Like Out(), but also terminates the current op chain.
        ret = Model.Out(self, o)
        self.__currentOp = None
        return self

    def RelaxedExecution(self, isRelaxed):
        # Model-wide flag (class attribute): allow FP16 computation.
        Model.__isRelaxed = isRelaxed
        return self

    def isRelaxed():
        # Static-style accessor for the relaxed-execution flag.
        return Model.__isRelaxed


class FileNames:
    # Basename of the spec file being processed; used in generated headers.
    SpecFile = ""

class Example():
    # Registry of (input_dict, output_dict) example pairs from the spec.
    __examples = []
    def __init__(self, list_of_examples):
        Example.__examples.append(list_of_examples)

    def dump_dict(d):
        # Static-style: render {operand_number: {values...}} C++ pairs.
        # Keys may be Operand objects (use .number) or raw ints.
        ret = []
        for k, v in d.items():
            key = str(k)
            suffix = "f"
            if type(k) is not int:
                key = str(k.number)
                if not TypeLookup.is_float(k.type.get_element_type()):
                    suffix = ""
            # Append the float suffix only to literals that contain a dot.
            init = ", ".join(
                [str(i) + (suffix if str(i).find(".") != -1 else "") for i in v])
            ret.append("{%s, {%s}}" % (key, init))
        return ", ".join(ret)

    def dump_mixed_types(d):
        # Static-style: split an example dict into per-element-type dicts
        # and render them as a C++ MixedTyped aggregate initializer.
        ret = []

        float32_dict = {}
        int32_dict = {}
        uint8_dict = {}

        for k, v in d.items():
            key_id = k.ID() if type(k) is not int else k
            ty = Operand.operands.search(key_id).type.get_element_type()
            # find out type of the operand addressed by the key
            if (ty == "TENSOR_FLOAT32"):
                float32_dict[k] = v
            elif (ty == "TENSOR_INT32"):
                int32_dict[k] = v
            elif (ty == "TENSOR_OEM_BYTE"):
                uint8_dict[k] = v
            elif (ty == "TENSOR_QUANT8_ASYMM"):
                uint8_dict[k] = v
            else:
                print ("Unhandled type %s"%ty, file = sys.stderr)
                # NOTE(review): `assert 0 and "msg"` always fails but hides
                # the message; `assert 0, "msg"` was probably intended.
                assert 0 and "unsupported example type"

        tuple_init = """\
{{ // See tools/test_generator/include/TestHarness.h:MixedTyped
 // int -> FLOAT32 map
 {{{float32_dict}}},
 // int -> INT32 map
 {{{int32_dict}}},
 // int -> QUANT8_ASYMM map
 {{{uint8_dict}}}
}}"""
        tuple_contents = {
            'float32_dict': Example.dump_dict(float32_dict),
            'int32_dict': Example.dump_dict(int32_dict),
            'uint8_dict': Example.dump_dict(uint8_dict)
        }
        return tuple_init.format(**tuple_contents)


    def dump(example_file):
        # Static-style: emit every registered example as a C++ aggregate.
        if len(Example.__examples) > 0:
            spec_file = " (from: %s)" % (FileNames.SpecFile)
            print ('// Generated file%s. Do not edit' % (spec_file),
                file = example_file)
        for i, o in Example.__examples:
            print ('// Begin of an example', file = example_file)
            print ('{', file = example_file)
            inputs = Example.dump_mixed_types(i)
            outputs = Example.dump_mixed_types(o)
            print ('//Input(s)\n%s,' % inputs , file = example_file)
            print ('//Output(s)\n%s' % outputs, file = example_file)
            print ('}, // End of an example', file = example_file)

    # Similar to dump_dict, but in python. Used by the slicing tool
    # if referenced is not None, only print operands that are present there
    def py_dump_dict(d, referenced):
        ret = []
        for k, v in d.items():
            if referenced != None and k not in referenced:
                continue
            key = str(k)
            init = pprint.pformat(v)
            ret.append("%s: %s" % (key, init))
        return ", ".join(ret)

    # similar to dump, but in python. Used by the slicing tool
    # if referenced is not None, only print operands that are present there
    def py_dump(example_file, override, referenced):
        if len(Example.__examples) > 0:
            # NOTE(review): example_no is never incremented, so every
            # emitted example is named "input0" — confirm whether multiple
            # examples per spec are expected here.
            example_no = 0
            example_template = """\
input{no} = {{{inputs}}}
# Only executed during data collection phase
if collecting_data is True:
 Example((input{no}, {{{outputs}}}))
"""
            for i, o in Example.__examples:
                print ('# Begin of an example', file = example_file)
                inputs = Example.py_dump_dict(i, referenced)
                output_list = []
                for k, v in override.items():
                    # Outputs are placeholders: zero-filled to length v.
                    output_list.append("%s: [0] * %d" % (k, v))
                outputs = ",".join(output_list)

                # TODO: handle >1 outputs
                for k, v in o.items():
                    assert k.number == 0
                example_contents = {
                    'no': example_no,
                    'inputs': inputs,
                    'outputs': outputs
                }
                print (example_template.format(**example_contents), file = example_file)


def TopologicalSort(format_op):
    # Kahn's algorithm over the Uses/Definitions graph: start from model
    # inputs, emit each node via format_op, release dependents whose
    # remaining dependencies are satisfied. format_op returning False
    # aborts the traversal early.
    start = Input.get_inputs().copy()
    deps = { x: set(x.ins) for x in Uses.all_uses }

    while len(start) > 0:
        cur = start.pop()
        if format_op(cur) is False:
            return
        distinct_outs = set(cur.outs)
        for o in distinct_outs:
            deps[o].remove(cur)
            if len(deps[o]) == 0 and o.traversable():
                start.add(o)

class Configuration:
    # Global switch: store weights in shared memory (CONSTANT_REFERENCE)
    # rather than copying them into the model.
    use_shm_for_weights = False
    def useSHM():
        return Configuration.use_shm_for_weights

# Take a model from command line
def import_source():
    # Parse CLI args, then execute the spec file in this module's
    # namespace; the spec populates the global registries as a side effect.
    parser = argparse.ArgumentParser()
    parser.add_argument("spec", help="the spec file")
    parser.add_argument(
        "-m", "--model", help="the output model file", default="-")
    parser.add_argument(
        "-e", "--example", help="the output example file", default="-")
    args = parser.parse_args()

    if os.path.exists(args.spec):
        FileNames.SpecFile = os.path.basename(args.spec)
        # SECURITY NOTE: exec of arbitrary spec-file contents — acceptable
        # for a trusted build tool, never for untrusted input.
        exec (open(args.spec).read())

    return (args.model, args.example)


def print_cts_op(model_file, op):
    # Emit one node's C++ definition (if any); always continue traversal.
    fmt = op.Definition()
    if fmt is not None:
        print (" %s" % fmt, file = model_file)
    return True

if __name__ == '__main__':
    (model, example) = import_source()
    # Boilerplate
    args = ""
    if len(ModelArgument.get_arguments()) > 0:
        args = ", " + ", ".join(ModelArgument.get_arguments())

    print("Output CTS model: %s" % model, file=sys.stderr)
    print("Output example:" + example, file=sys.stderr)

    with smart_open(model) as model_file:
        spec_file = " (from: %s)" % (FileNames.SpecFile)

        print ('// Generated file%s. Do not edit'%(spec_file), file = model_file)
        print ("void CreateModel(Model *model" + args + ") {", file=model_file)

        # Phase 0: types
        Type.dump(model_file)
        # Phase 1: add operands
        print (" // Phase 1, operands", file=model_file)
        Operand.operands.dump(model_file)

        # Phase 2: operations
        print (" // Phase 2, operations", file=model_file)
        TopologicalSort(lambda x: print_cts_op(model_file, x))

        # Phase 3: add inputs and outputs
        print (" // Phase 3, inputs and outputs", file=model_file)
        # True -> exclude internal inputs (Parameters) from the model I/O list.
        inputs = Operand.print_operands(Input.get_inputs(True));
        outputs = Operand.print_operands(Output.get_outputs());
        print (" model->identifyInputsAndOutputs(\n" +
            " {"+", ".join(inputs)+"},\n {" + ", ".join(outputs) + "});",
            file=model_file)

        # Phase 4: set relaxed execution if needed
        if (Model.isRelaxed()):
            print (" // Phase 4: set relaxed execution", file=model_file)
            print (" model->relaxComputationFloat32toFloat16(true);", file=model_file)

        # Boilerplate
        print (" assert(model->isValid());", file=model_file);
        print ("}", file=model_file)
        print (IgnoredOutput.gen_ignored(), file=model_file)

    with smart_open(example) as example_file:
        Example.dump(example_file)