#!/usr/bin/python3

# Copyright 2017, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""NN model compiler

Contains class definitions and utility functions for compiling models and
examples into NDK-based CTS and VTS unit tests.

Used by example_generator.py and spec_visualizer.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import copy
from functools import reduce
import argparse
import io
import itertools
import os
import re
import sys
import traceback
import numpy as np


def GetJointStr(l, sep=", ", method=str):
    """Convert every element of l with `method` and join the results with `sep`."""
    return sep.join([method(i) for i in l])


# Print in C float literal format
def PrettyPrintAsFloat(x):
    """Format x as a C/C++ float literal, e.g. 1 -> "1.0f", 2.5 -> "2.5f"."""
    s = str(float(x))
    if s.find(".") >= 0 or s.find("e") >= 0:
        return s + "f"
    else:
        # A bare integer such as "1f" is not a valid C float literal.
        return s + ".0f"


# Transform from original type to float32
def Dequantize(v, ty):
    """Map values of type `ty` back to real values.

    Computes (v - ty.zeroPoint) * ty.scale, additionally multiplied by the
    per-channel scales when ty.extraParams is a SymmPerChannelQuantParams.
    A scale of 0 means "not quantized" and is skipped.

    The arithmetic uses in-place operators, so a numpy array argument is
    modified in place and must have a float dtype (DataTypeConverter passes a
    fresh float32 copy).
    """
    v -= ty.zeroPoint
    if ty.scale != 0:
        v *= ty.scale
    if isinstance(ty.extraParams, SymmPerChannelQuantParams):
        v *= ty.extraParams.GetScalesBroadcastArray(ty.dimensions)
    return v


# Transform float32 to target data type
def Quantize(v, ty):
    """Map real values `v` into the representation of type `ty`.

    Divides by ty.scale (when non-zero) and by the per-channel scales (when
    present), adds ty.zeroPoint, rounds to int for non-float types, and then
    saturates to the range of the types listed below. `v` may be a Python
    number or a float numpy array; a numpy array may be modified in place.
    """
    if ty.scale != 0:
        v /= ty.scale
    if isinstance(ty.extraParams, SymmPerChannelQuantParams):
        v = v / ty.extraParams.GetScalesBroadcastArray(ty.dimensions)
    v += ty.zeroPoint
    if not ty.IsFloat():
        v = np.round(v)
        v = v.astype(int)

    # Saturate to the representable range of the target type.
    if ty.type == "TENSOR_QUANT8_ASYMM":
        v = np.minimum(np.maximum(v, 0), 255)
    elif ty.type == "TENSOR_QUANT16_ASYMM":
        v = np.minimum(np.maximum(v, 0), 65535)
    elif ty.type == "TENSOR_QUANT8_SYMM_PER_CHANNEL":
        v = np.minimum(np.maximum(v, -127), 127)
    elif ty.type == "UINT32":
        v = np.maximum(v, 0)
    elif ty.type == "TENSOR_QUANT8_ASYMM_SIGNED":
        v = np.minimum(np.maximum(v, -128), 127)
    return v


# Tracking objects inside a model with a unique name
class NamedObject:
    # Names already handed out. Subclasses that define their own
    # `existingNames` get a separate namespace.
    existingNames = set()

    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        """Build a name by joining the non-empty, non-None `args` with `sep`.

        Unless `skipRenaming` is set, a numeric suffix (`sep` + number,
        counting up from `startsFrom`) is appended until the name is unused in
        this class's namespace. With `showZero` the first candidate already
        carries the suffix.
        """
        name = GetJointStr([i for i in args if i is not None and i != ""], sep=sep)
        if skipRenaming:
            self.name = name
            return
        # make the name unique by renaming with a suffix number
        uniqueName = name if showZero is False else name + sep + str(startsFrom)
        while uniqueName in self.__class__.existingNames:
            startsFrom += 1
            uniqueName = name + sep + str(startsFrom)
        self.__class__.existingNames.add(uniqueName)
        self.name = uniqueName

    def __str__(self):
        return self.name
    __repr__ = __str__

    # Since names are unique, objects with the same name are considered equal
    def __eq__(self, other):
        return isinstance(other, NamedObject) and self.name == other.name

    def __ne__(self, other):
        return not self.__eq__(other)

    def __hash__(self):
        return hash(self.name)

    # Ordering by name gives a stable order when dumping code.
    def __lt__(self, other):
        return self.name < other.name


# Types, operands should all have a unique name since they share the same namespace
class NamedVariable(NamedObject):
    existingNames = set()

    def __init__(self, *args, sep="_", showZero=False, startsFrom=0, skipRenaming=False):
        NamedObject.__init__(self, *args, sep=sep, showZero=showZero,
                             startsFrom=startsFrom, skipRenaming=skipRenaming)


# Global variables in the spec namespace such as CreateModel, is_ignored, and examples
class GlobalVariable(NamedVariable):
    def __init__(self, *args, skipRenaming=False):
        # Shares NamedVariable's namespace; duplicates are suffixed from _2 on.
        NamedObject.__init__(self, *args, startsFrom=1, skipRenaming=skipRenaming)


# Each test should have a unique name, but will not conflict with variables
class NamedTest(NamedObject):
    existingNames = set()

    def __init__(self, *args, startsFrom=0, skipRenaming=False):
        # NOTE: the `startsFrom` argument is ignored; suffixes always count
        # from 1 so duplicated test names become name_2, name_3, ...
        NamedObject.__init__(self, *args, startsFrom=1, skipRenaming=skipRenaming)


# An operand type: data type string, dimensions, quantization scale/zeroPoint
# and optional extra parameters (e.g. SymmPerChannelQuantParams). Instances are
# interned through Type.GetType.
class Type(NamedVariable):
    # Cache of all created types, keyed by their string signature.
    typesMap = dict()
    # Operand type name -> C++ element type used in the generated code.
    typeLookup = {
        "INT32": "int32_t",
        "UINT32": "uint32_t",
        "FLOAT32": "float",
        "FLOAT16": "_Float16",
        "TENSOR_INT32": "int32_t",
        "TENSOR_FLOAT16": "_Float16",
        "TENSOR_FLOAT32": "float",
        "TENSOR_QUANT8_ASYMM": "uint8_t",
        "TENSOR_QUANT8_SYMM": "int8_t",
        "BOOL": "bool8",
        "TENSOR_QUANT16_ASYMM": "uint16_t",
        "TENSOR_QUANT16_SYMM": "int16_t",
        "TENSOR_BOOL8": "bool8",
        "TENSOR_QUANT8_SYMM_PER_CHANNEL": "int8_t",
        "TENSOR_QUANT8_ASYMM_SIGNED": "int8_t",
        # "OEM_SCALAR": this is service-defined.
        "TENSOR_OEM_BYTE": "uint8_t",
        "SUBGRAPH": "uint32_t",  # Index into TestModel::referenced.
    }

    # types are named as "type0", "type1", ...
    def __init__(self, vt, dimensions, scale, zeroPoint, name="type", skipRenaming=False,
                 extraParams=None):
        NamedVariable.__init__(self, name, sep="", showZero=True, skipRenaming=skipRenaming)
        self.type = vt
        self.dimensions = dimensions
        self.scale = float(scale)
        self.zeroPoint = int(zeroPoint)
        self.extraParams = extraParams

    # Factory for Type object, only create a new Type if requested type does
    # not have a match with all existing types
    @staticmethod
    def GetType(vt, dimensions, scale=0, zeroPoint=0, extraParams=None):
        assert isinstance(dimensions, (list, tuple)), \
            'dimensions must be a list or tuple, got {}'.format(type(dimensions))
        key = ",".join([vt, str(dimensions), str(scale), str(zeroPoint), str(extraParams)])
        if key not in Type.typesMap:
            Type.typesMap[key] = Type(vt, dimensions, scale, zeroPoint, extraParams=extraParams)
        return Type.typesMap[key]

    @staticmethod
    def GetAllTypes():
        # sort to ensure a stable order when dumping the code
        return sorted(Type.typesMap.values())

    # For backward-compatibility
    @staticmethod
    def GetTypeFromString(vt, shape, extraParams=None):
        """Build a Type from a legacy shape string such as "{1, 2}, 0.5f, 127"."""
        dimensions, scale, zeroPoint = Type.GetParsedShape(shape)
        scale = float(scale)
        zeroPoint = int(zeroPoint)
        return Type.GetType(vt, dimensions, scale, zeroPoint, extraParams)

    # For backward-compatibility
    @staticmethod
    def GetParsedShape(shape):
        """Parse a legacy shape string into (dimensions, scale, zeroPoint).

        Accepted forms: "" / "{}" (scalar), "{d0, d1, ...}",
        "{d0, ...}, scale" and "{d0, ...}, scale, zeroPoint". A trailing "f"
        on the scale is stripped. Scale and zeroPoint are returned as strings
        (possibly with surrounding whitespace), "0" when absent.
        """
        if (shape != "" and shape != "{}"):
            left, sep, right = shape.partition('{')
            real_shape, sep, right = right.partition('}')
            # Skip empty entries so that a scalar with quantization parameters,
            # e.g. "{}, 0.5f, 0", parses instead of failing on int("").
            shape = [int(x) for x in real_shape.split(",") if x.strip() != ""]
            # right now looks like ", 0.5f, 127", ", 0.5f" or ""
            scale, sep, zero_point = right.rpartition(',')
            if scale == "":
                if zero_point == "":
                    return shape, "0", "0"
                # Only a scale was given; strip its "f" suffix like below.
                return shape, zero_point.replace("f", ""), "0"
            left, sep, scale = scale.partition(',')
            return shape, scale.replace("f", ""), zero_point
        else:
            return [], "0", "0"

    def GetNumberOfElements(self):
        return reduce(lambda x, y: x * y, self.dimensions, 1)

    def GetCppTypeString(self):
        return Type.typeLookup[self.type]

    def IsFloat(self):
        return self.GetCppTypeString() in ["float", "_Float16"]

    def IsBool(self):
        return self.GetCppTypeString() == "bool8"

    def IsScalar(self):
        return not self.type.startswith("TENSOR_")

    def GetElementByteSize(self):
        cppTypeString = self.GetCppTypeString()
        if cppTypeString in ["uint8_t", "int8_t", "bool8"]:
            return 1
        elif cppTypeString in ["int16_t", "uint16_t", "_Float16"]:
            return 2
        else:
            return 4

    def GetByteSize(self):
        return self.GetElementByteSize() * self.GetNumberOfElements()

    def GetDimensionsString(self):
        return "{" + GetJointStr(self.dimensions) + "}"

    # NOTE: extraParams is not part of the signature tuple.
    def GetSignatureTuple(self):
        return (self.type, self.dimensions, self.scale, self.zeroPoint)

    def ToUnspecifiedDim(self):
        """Return the same type with every dimension set to 0 (unknown).

        extraParams (e.g. per-channel quantization) is preserved.
        """
        return Type.GetType(self.type, [0] * len(self.dimensions), self.scale, self.zeroPoint,
                            extraParams=self.extraParams)


# To track implicitly convertible parameter types
class ImplicitParameter():
    @staticmethod
    def ImplicitConvertion(value):
        """Wrap a plain Python value in the first compatible parameter subclass.

        Operands are returned unchanged; subclasses are tried in definition
        order via their IsCompatible() predicate.
        """
        if isinstance(value, Operand):
            return value
        for implicitType in ImplicitParameter.__subclasses__():
            if implicitType.IsCompatible(value):
                return implicitType("param", value)
        assert False, "%s not supported for implicit parameter"%value


# ExtraParams with per-channel quantization.
class SymmPerChannelQuantParams():
    """Per-channel symmetric quantization parameters.

    `scales` holds one scale per channel along axis `channelDim`.
    """

    def __init__(self, channelDim, scales, hide=False):
        self.channelDim = channelDim
        self.scales = scales
        self.hide = hide

    def GetScalesBroadcastArray(self, dimensions):
        """Return the scales reshaped so they broadcast along channelDim.

        The result has rank len(dimensions), with len(scales) at channelDim
        and 1 everywhere else.
        """
        bshape = [1] * len(dimensions)
        bshape[self.channelDim] = len(self.scales)
        return np.array(self.scales).reshape(bshape)

    def _GetScalesInitializer(self):
        # Use the shared float-literal formatter so integer scales become
        # e.g. "1.0f" rather than the invalid C++ literal "1f".
        return ", ".join(PrettyPrintAsFloat(x) for x in self.scales)

    def GetConstructor(self):
        return "SymmPerChannelQuantParams({%s},%d)" % (
            self._GetScalesInitializer(), self.channelDim)

    def GetVtsSetter(self):
        return "channelQuant"

    def GetVtsConstructor(self):
        return "SymmPerChannelQuantParams{.scales={%s}, .channelDim=%d}" % (
            self._GetScalesInitializer(), self.channelDim)


# An operand that can be fed into operations. Also, an operand is always
# declared before operations.
class Operand(NamedVariable):

    def __init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None):
        """Create an operand.

        `opType` is either a legacy type string (then `value` is the legacy
        shape string and `backward` holds the actual value), or a tuple that is
        forwarded to Type.GetType.
        """
        NamedVariable.__init__(self, name, sep="", skipRenaming=skipRenaming)
        if type(opType) is str:
            self.type = Type.GetTypeFromString(opType, value, extraParams)
            value = backward
        else:
            self.type = Type.GetType(*opType, extraParams=extraParams)
        self.SetValue(value)
        self.lifetime = "TEMPORARY_VARIABLE"
        self.model_index = None
        # Operations producing (ins) / consuming (outs) this operand; filled in
        # by Model.SetOperandInsAndOuts.
        self.ins = []
        self.outs = []
        self.mayBeInternal = True

    def SetValue(self, value):
        """Store `value`; a single non-list/tuple value is wrapped in a list."""
        self.value = value if type(value) is list or type(value) is tuple or value is None \
            else [value]
        return self

    def SetValueFromNumpy(self, value):
        self.value = value.flatten().tolist()
        return self

    def GetValueAsNumpy(self):
        return np.array(self.value).reshape(self.type.dimensions)

    # Print value as cpp-style list initialization
    def GetListInitialization(self):
        if self.value is None:
            return "{}"
        elif self.type.IsFloat():
            return "{%s}"%(GetJointStr(self.value, method=PrettyPrintAsFloat))
        elif self.type.IsBool():
            return "{%s}"%(GetJointStr(self.value, method=lambda v: "true" if v else "false"))
        else:
            return "{%s}"%(GetJointStr(self.value, method=lambda x: str(int(x))))

    def ToUnspecifiedDim(self):
        # Remember the concrete dimensions before erasing them from the type.
        self.dimensions = self.type.dimensions
        self.type = self.type.ToUnspecifiedDim()

    def ConvertTo(self, DerivedClass, name=None):
        """Return a copy of this operand as an instance of DerivedClass.

        The name is kept unless `name` is given, and renaming is skipped. The
        value is copied unless DerivedClass is an Internal operand. The
        "never internal" flag is carried over.
        """
        assert issubclass(DerivedClass, Operand)
        name = self.name if name is None else name
        newop = DerivedClass(name, self.type.GetSignatureTuple(), skipRenaming=True,
                             extraParams=self.type.extraParams)
        if not issubclass(DerivedClass, Internal):
            newop.SetValue(self.value)
        if not self.mayBeInternal:
            assert not issubclass(DerivedClass, Internal)
            newop.ShouldNeverBeInternal()
        return newop

    def ShouldNeverBeInternal(self):
        self.mayBeInternal = False
        return self


# Base class of user-defined input/output operand
class InOut(Operand):

    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        # `backward` is passed as Operand's `value` (the legacy shape string
        # when opType is a string); the actual data comes later via Feed().
        Operand.__init__(self, name, opType, backward, None, skipRenaming=skipRenaming,
                         extraParams=extraParams)
        self.lifetime = "SUBGRAPH_INPUT"
        self.index = 0

    def Feed(self, value):
        """Set the value, looking it up by this operand when given a dict."""
        self.SetValue(value[self] if type(value) is dict else value)
        return self


# A user-declared input operand
class Input(InOut):
    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        InOut.__init__(self, name, opType, backward, skipRenaming=skipRenaming,
                       extraParams=extraParams)
        self.lifetime = "SUBGRAPH_INPUT"


# A user-declared output operand
class Output(InOut):
    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        InOut.__init__(self, name, opType, backward, skipRenaming=skipRenaming,
                       extraParams=extraParams)
        self.lifetime = "SUBGRAPH_OUTPUT"


# An output that we don't want to
# compare the results
class IgnoredOutput(Output):
    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        Output.__init__(self, name, opType, backward, skipRenaming=skipRenaming,
                        extraParams=extraParams)
        self.lifetime = "SUBGRAPH_OUTPUT"

    def Feed(self, value):
        # The fed value is discarded; the output is filled with zeros.
        self.value = [0] * self.type.GetNumberOfElements()
        return self


# An explicitly declared parameter
class Parameter(Operand):
    def __init__(self, name, opType, value, backward=None, skipRenaming=False, extraParams=None):
        Operand.__init__(self, name, opType, value, backward, skipRenaming=skipRenaming,
                         extraParams=extraParams)
        self.initializer = NamedVariable(str(self) + "_init")
        # Check the stored value, not the `value` argument: with a legacy
        # string opType the argument is the shape string and the actual value
        # comes from `backward`.
        if self.value is None:
            self.lifetime = "NO_VALUE"
        elif Configuration.useSHM():
            self.lifetime = "CONSTANT_REFERENCE"
        else:
            self.lifetime = "CONSTANT_COPY"


# A shortcut for parameters of INT32
class Int32Scalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("INT32", []), int(value))

    @staticmethod
    def IsCompatible(value):
        # `type(...) is int` deliberately excludes bool.
        return type(value) is int


# A shortcut for parameters of FLOAT16
class Float16Scalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("FLOAT16", []), float(value))

    @staticmethod
    def IsCompatible(value):
        # Never chosen implicitly; Python floats map to Float32Scalar.
        return False


# A shortcut for parameters of FLOAT32
class Float32Scalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("FLOAT32", []), float(value))

    @staticmethod
    def IsCompatible(value):
        return type(value) is float


# A shortcut for parameters of BOOL
class BoolScalar(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("BOOL", []), bool(value))

    @staticmethod
    def IsCompatible(value):
        return type(value) is bool


# A shortcut for parameter of 1-D TENSOR_INT32
class Int32Vector(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("TENSOR_INT32", [len(value)]), [int(v) for v in value])

    @staticmethod
    def IsCompatible(value):
        if type(value) is not list and type(value) is not tuple:
            return False
        return all(type(i) is int for i in value)


# A shortcut for parameter of 1-D TENSOR_FLOAT32
class Float32Vector(Parameter, ImplicitParameter):
    def __init__(self, name, value):
        Parameter.__init__(self, name, ("TENSOR_FLOAT32", [len(value)]), [float(v) for v in value])

    @staticmethod
    def IsCompatible(value):
        if type(value) is not list and type(value) is not tuple:
            return False
        return all(type(i) is float for i in value)


# A shortcut for a SUBGRAPH parameter
class SubgraphReference(Parameter, ImplicitParameter):
    def __init__(self, name, model):
        Parameter.__init__(self, name, ("SUBGRAPH", []), model)
        self.lifetime = "SUBGRAPH"
        # An unnamed referenced model takes the name of this operand.
        if model.name is None:
            model.name = name

    @staticmethod
    def IsCompatible(value):
        return type(value) is Model


# An explicitly declared intermediate result
class Internal(Operand):
    def __init__(self, name, opType, backward=None, skipRenaming=False, extraParams=None):
        Operand.__init__(self, name, opType, backward, None, skipRenaming=skipRenaming,
                         extraParams=extraParams)
        self.lifetime = "TEMPORARY_VARIABLE"


# An operation in a model, does not need a name
class Operation:

    def __init__(self, optype, ins, outs):
        self.optype = optype
        self.SetInputs(ins)
        self.SetOutputs(outs)

    # for the ease of debugging
    def __str__(self):
        insString = GetJointStr(self.ins)
        outsString = GetJointStr(self.outs)
        return "Operation %s: [%s] -> [%s]"%(self.optype, insString, outsString)
    __repr__ = __str__

    def SetInputs(self, ins):
        # Plain Python values (ints, floats, lists, Models) become parameters.
        self.ins = [ImplicitParameter.ImplicitConvertion(i) for i in ins]
        return self

    def SetOutputs(self, outs):
        self.outs = list(outs)
        return self


# Main interface
class Model:
    # Every Model ever constructed, in creation order.
    models = list()

    def __init__(self, name=None):
        self.name = name
        self.operations = []
        self.operands = []
        self.isRelaxed = False
        self.compiled = False
        self.dumped = False
        self.version = FileNames.version
        self.referenced_models = None
        Model.models.append(self)

    def WithSuffix(self, *args):
        self.createFunctionName = GlobalVariable("CreateModel", self.name, *args)
        self.createTestFunctionName = GlobalVariable("createTestModel", self.name, *args)
        self.isIgnoredFunctionName = GlobalVariable("is_ignored", self.name, *args)
        return self

    def AddOperand(self, operand):
        # Operands compare by name, so an equally-named operand is not re-added.
        if operand not in self.operands:
            self.operands.append(operand)
        return self

    # Makes sure the model contains all (and only) the given inputs in the
    # specified order.
    def IdentifyInputs(self, *args):
        for arg in args:
            self.AddOperand(arg)
        inputs = tuple(self.GetInputs())
        assert inputs == args, '{} vs {}'.format(inputs, args)
        return self

    # Makes sure the model contains all (and only) the given outputs in the
    # specified order.
    def IdentifyOutputs(self, *args):
        for arg in args:
            self.AddOperand(arg)
        outputs = tuple(self.GetOutputs())
        assert outputs == args, '{} vs {}'.format(outputs, args)
        return self

    def AddOperation(self, operation):
        self.operations.append(operation)
        for i in operation.ins:
            self.AddOperand(i)
        for o in operation.outs:
            self.AddOperand(o)
        return self

    def Operation(self, op_name, *args):
        # Outputs are attached afterwards with To().
        return self.AddOperation(Operation(op_name, args, []))

    def To(self, *args):
        """Set the outputs of the most recently added operation.

        Accepts either the outputs as separate arguments or a single list or
        tuple of outputs.
        """
        assert len(self.operations) > 0
        if type(args[0]) is tuple or type(args[0]) is list:
            outs = args[0]
        else:
            outs = args
        self.operations[-1].SetOutputs(outs)
        for o in outs:
            self.AddOperand(o)
        return self

    def RelaxedExecution(self, isRelaxed):
        self.isRelaxed = isRelaxed
        return self

    # Sets the version of the model in compliance tests. Set to None to disable the test.
    def IntroducedIn(self, ver):
        self.version = ver
        return self

    def GetTypes(self):
        return sorted(set(op.type for op in self.operands))

    def GetInputs(self):
        return [i for i in self.operands if isinstance(i, Input)]

    def GetOutputs(self):
        return [o for o in self.operands if isinstance(o, Output)]

    def GetInputsIndex(self):
        return [i for i, op in enumerate(self.operands) if isinstance(op, Input)]

    def GetOutputsIndex(self):
        return [o for o, op in enumerate(self.operands) if isinstance(op, Output)]

    def GetIndexOfOperands(self, operands):
        return [self.operands.index(i) for i in operands]

    def GetIgnoredOutputs(self):
        return [o for o in self.operands if isinstance(o, IgnoredOutput)]

    def GetParameters(self):
        return [p for p in self.operands if isinstance(p, Parameter)]

    def GetReferencedModels(self):
        assert self.compiled
        return self.referenced_models

    # Map each target to the model's own operand object with the same name.
    def GetEquivalentOperands(self, targets):
        return [self.operands[self.operands.index(t)] for t in targets]

    # Replace the model's operands by the equally-named objects in `targets`.
    def UpdateEquivalentOperands(self, targets):
        for t in targets:
            self.operands[self.operands.index(t)] = t
        return self

    def SetOperandIndex(self):
        for ind, i in enumerate(self.GetInputs()):
            i.index = ind
        for ind, o in enumerate(self.GetOutputs()):
            o.index = ind
        for ind, op in enumerate(self.operands):
            op.model_index = ind
        return self

    def SetOperandInsAndOuts(self):
        """Link every operand to its producing (ins) and consuming (outs) operations."""
        for operand in self.operands:
            operand.ins = list()
            operand.outs = list()
        for operation in self.operations:
            operation.ins = self.GetEquivalentOperands(operation.ins)
            operation.outs = self.GetEquivalentOperands(operation.outs)
            for i in operation.ins:
                i.outs.append(operation)
            for o in operation.outs:
                o.ins.append(operation)
        return self

    # Depth-first visit; an operation that is visited but still present in
    # `deps` is on the current path, so reaching it again means a cycle.
    def TopologicalSortHelper(self, op, deps, visited):
        if op in visited:
            assert op not in deps, "Cycle detected in the graph"
        else:
            visited.add(op)
            for i in deps[op]:
                self.TopologicalSortHelper(i, deps, visited)
            self.operations.append(op)
            deps.pop(op)

    # Topological sort of the operations, and detect if there is a cycle in the graph
    def TopologicalSort(self):
        # deps[consumer] lists the operations producing the consumer's inputs.
        deps = {op: list() for op in self.operations}
        for operand in self.operands:
            for consumer in operand.outs:
                for producer in operand.ins:
                    deps[consumer].append(producer)
        operations = self.operations.copy()
        self.operations = []
        visited = set()
        for op in operations:
            self.TopologicalSortHelper(op, deps, visited)

    def CompileReferencedModels(self, referenced_models, referenced_model_to_index):
        """Compile referenced subgraph models and replace SUBGRAPH operand values by indices.

        Models are deduplicated by identity; indices follow first-encounter
        order in `referenced_models`.
        """
        for operand in self.operands:
            if operand.lifetime != "SUBGRAPH":
                continue
            model = operand.value[0]
            key = id(model)
            if key not in referenced_model_to_index:
                referenced_model_to_index[key] = len(referenced_model_to_index)
                referenced_models.append(model)
                model.Compile(referenced_models, referenced_model_to_index)
            operand.value = [referenced_model_to_index[key]]

    def Compile(self, referenced_models=None, referenced_model_to_index=None):
        if self.compiled:
            return self
        if referenced_models is None:
            # This is the main model.
            referenced_models = []
            referenced_model_to_index = {}
            self.referenced_models = referenced_models
        self.SetOperandIndex()
        self.SetOperandInsAndOuts()
        self.TopologicalSort()
        self.CompileReferencedModels(referenced_models, referenced_model_to_index)
        # Do not check compliance for relaxed mode tests.
        if self.isRelaxed:
            self.IntroducedIn(None)
        self.compiled = True
        return self

    # feedDict is an (inputs, outputs) pair forwarded to each operand's Feed().
    def Feed(self, feedDict):
        for i in self.GetInputs():
            i.Feed(feedDict[0])
        for o in self.GetOutputs():
            o.Feed(feedDict[1])
        return self


# To track implicitly convertible variation types
class ImplicitVariation:
    @staticmethod
    def ImplicitConvertion(value):
        """Turn `value` (or `(value, *identifyArgs)`) into a ModelVariation.

        ModelVariation instances are returned unchanged. Otherwise subclasses
        are tried in definition order; the first whose IsCompatible() accepts
        the leading element is constructed from it, and any remaining
        elements are forwarded to Identify().
        """
        if isinstance(value, ModelVariation):
            return value
        value = value if type(value) is tuple or type(value) is list else [value]
        for implicitType in ImplicitVariation.__subclasses__():
            if implicitType.IsCompatible(value[0]):
                var = implicitType(value[0])
                if len(value) > 1:
                    var.Identify(*value[1:])
                return var
        assert False, "%s not supported for implicit variation"%value[0]


# An exception indicating that the current variation list should be skipped.
class SkipVariation(Exception):
    pass


# The base class for model variations
class ModelVariation:
    # Whether the variation may be applied to models that reference subgraphs.
    supportsSubgraphs = False

    def __init__(self, name=None):
        # Operand -> transformation argument, see IdentifyOperands().
        self.targetOperands = {}
        self.name = name

    # Apply the model variation.
    def ApplyTo(self, model):
        assert not model.compiled
        assert not model.dumped

        if not self.supportsSubgraphs:
            containsSubgraphs = any(operand.lifetime == "SUBGRAPH" for operand in model.operands)
            assert not containsSubgraphs, "Variation {} does not support subgraphs".format(
                self.__class__.__name__)

        if not self.targetOperands:
            self.AutoIdentify(model)

        # Transform operands and model.
        targets = model.GetEquivalentOperands(sorted(self.targetOperands.keys()))
        model.UpdateEquivalentOperands(
            [self.TransformOperand(op, self.targetOperands[op]) for op in targets])
        model = self.TransformModel(model)
        return model

    # `args` is either a dict {operand: argument} or an iterable of operands
    # (each then mapped to None).
    def IdentifyOperands(self, args=None):
        if args is None:
            return self
        self.targetOperands = args if type(args) is dict else {i: None for i in args}
        return self

    # NOTE: paramArgs is accepted for interface compatibility but unused here.
    def Identify(self, operandArgs=None, paramArgs=None):
        self.IdentifyOperands(operandArgs)
        return self

    # Set variation to its default name
    def SetToDefaultName(self):
        self.name = ""
        return self

    # Automatically select the target operand list
    def AutoIdentify(self, model):
        return self

    # Transform operands that are marked by IdentifyOperands()
    def TransformOperand(self, op, arg=None):
        return op

    # Transform the model
    def TransformModel(self, model):
        return model


# Default variation that does nothing
class DefaultVariation(ModelVariation):
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)


# Convert operand data type
class DataTypeConverter(ModelVariation, ImplicitVariation):
    supportsSubgraphs = True

    def __init__(self, targetType=None, name=None, scale=None, zeroPoint=None):
        ModelVariation.__init__(self, name=name)
        if targetType is not None:
            assert DataTypeConverter.IsCompatible(targetType)
        self.targetType = targetType
        self.scale = scale
        self.zeroPoint = zeroPoint

    @staticmethod
    def IsCompatible(value):
        # Non-string values are simply not compatible (instead of raising on
        # .lower()), so other implicit variations still get a chance.
        return isinstance(value, str) and \
            value.lower() in ["float16", "int32", "quant8", "quant8_signed"]

    def SetToDefaultName(self):
        if self.targetType is not None:
            self.name = self.targetType.lower()
            return self
        # Name after the leading type string of the (non-nested) target args.
        targetTypes = list(zip(*(arg for arg in self.targetOperands.values()
                                 if type(arg) is not DataTypeConverter)))[0]
        if "TENSOR_QUANT8_SYMM_PER_CHANNEL" in targetTypes:
            self.name = "channelQuant8"
        elif "TENSOR_QUANT8_ASYMM" in targetTypes:
            self.name = "quant8"
        elif "TENSOR_QUANT8_ASYMM_SIGNED" in targetTypes:
            self.name = "quant8_signed"
        elif "TENSOR_INT32" in targetTypes:
            self.name = "int32"
        elif "TENSOR_FLOAT16" in targetTypes:
            self.name = "float16"
        else:
            self.name = "float32"
        return self

    def AutoIdentify(self, model):
        """Target every FLOAT32 tensor (and scalar, except for quant8 targets).

        SUBGRAPH operands are targeted with a nested DataTypeConverter so the
        referenced models are converted too.
        """
        if self.targetType is not None:
            # IsCompatible() is case-insensitive, so compare case-insensitively.
            targetType = self.targetType.lower()
            if targetType == "quant8" or targetType == "quant8_signed":
                if targetType == "quant8":
                    tensorType = "TENSOR_QUANT8_ASYMM"
                else:
                    tensorType = "TENSOR_QUANT8_ASYMM_SIGNED"
                assert self.scale is not None
                assert self.zeroPoint is not None
                tensorType = [tensorType, self.scale, self.zeroPoint]
                scalarType = None  # Not supported.
            else:
                tensorType = ["TENSOR_" + self.targetType.upper()]
                scalarType = [self.targetType.upper()]
            # By default, select all the float32 tensors/scalars
            targets = dict()
            targets.update({op: DataTypeConverter(self.targetType, self.name,
                                                  self.scale, self.zeroPoint)
                            for op in model.operands if op.type.type == "SUBGRAPH"})
            targets.update({op: tensorType
                            for op in model.operands if op.type.type == "TENSOR_FLOAT32"})
            if scalarType is not None:
                targets.update({op: scalarType
                                for op in model.operands if op.type.type == "FLOAT32"})
            self.Identify(targets)
        return self

    def TransformOperand(self, op, arg=None):
        """Retype `op` per `arg` = [typeName, (scale, zeroPoint)...] and requantize its value."""
        if type(arg) is DataTypeConverter:
            # Handle nested SUBGRAPHs
            assert len(op.value) == 1
            assert type(op.value[0]) is Model
            op.value[0] = arg.ApplyTo(op.value[0])
            return op
        if len(arg) == 1:
            typeTuple = (arg[0], op.type.dimensions)
        else:
            typeTuple = (arg[0], op.type.dimensions, *arg[1:])
        # To handle Internal operands
        if op.value is None or op.type.GetNumberOfElements() == 0:
            op.type = Type.GetType(*typeTuple)
        else:
            # Go through real (float32) values: old type -> float -> new type.
            v = Dequantize(op.GetValueAsNumpy().astype(np.float32), op.type)
            op.type = Type.GetType(*typeTuple)
            v = Quantize(v, op.type)
            op.SetValueFromNumpy(v)
        return op


# Convert model to turn on/off relaxed computation
class RelaxedModeConverter(ModelVariation, ImplicitVariation):
    supportsSubgraphs = True

    def __init__(self, isRelaxed=True, name=None):
        ModelVariation.__init__(self, name=name)
        if isinstance(isRelaxed, bool):
            self.isRelaxed = isRelaxed
        else:
            assert RelaxedModeConverter.IsCompatible(isRelaxed.lower())
            self.isRelaxed = True

    @staticmethod
    def IsCompatible(value):
        return isinstance(value, str) and value.lower() in ["relaxed"]

    def SetToDefaultName(self):
        self.name = "relaxed" if self.isRelaxed else "float"
        return self

    def TransformModel(self, model):
        model.RelaxedExecution(self.isRelaxed)
        return model


# Convert data layout between "NHWC" and "NCHW"
class DataLayoutConverter(ModelVariation, ImplicitVariation):

    def __init__(self, targetLayout="nchw", name=None):
        ModelVariation.__init__(self, name=name)
        self.targetLayout = targetLayout.lower()
        assert DataLayoutConverter.IsCompatible(self.targetLayout)
        self.perm = (0, 3, 1, 2) if self.targetLayout == "nchw" else (0, 2, 3, 1)
        # Value written into BOOL layout operands: True for NCHW.
        self.param = True if self.targetLayout == "nchw" else False

    @staticmethod
    def IsCompatible(value):
        return isinstance(value, str) and value.lower() in ["nhwc", "nchw"]

    def SetToDefaultName(self):
        self.name = self.targetLayout
        return self

    def TransformOperand(self, op, arg=None):
        """Permute 4-D tensors, 4-element 1-D vectors, or set the BOOL layout flag."""
        if len(op.type.dimensions) == 4:
            # To handle Internal operands
            if op.value is not None and op.type.GetNumberOfElements() != 0:
                op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm))
            newDim = [op.type.dimensions[i] for i in self.perm]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        elif len(op.type.dimensions) == 1 and op.value is not None and len(op.value) == 4:
            op.SetValueFromNumpy(op.GetValueAsNumpy()[list(self.perm)])
        elif op.type.type == "BOOL":
            op.SetValue(self.param)
        else:
            assert False, "%s not supported by DataLayoutConverter"%op
        return op


# Convert data by transposing and removing axis
class AxisConverter(ModelVariation):

    def __init__(self, origin, target, dim, drop=(), name=None):
        """Move axis `origin` to `target` in `dim`-D tensors, then drop axes in `drop`.

        `drop` may be a single int or an iterable of ints; negative axes are
        normalized to non-negative ones.
        """
        ModelVariation.__init__(self, name=name)
        self.origin = origin
        self.target = target
        assert all(i >= -dim and i < dim for i in [self.origin, self.target])
        self.dim = dim
        self.perm = list(range(dim))
        self.perm.insert(target if target >= 0 else target + dim, self.perm.pop(origin))
        self.drop = [drop] if type(drop) is int else list(drop)
        assert all(i >= -dim and i < dim for i in self.drop)
        self.drop = [i if i >= 0 else i + dim for i in self.drop]
        assert target not in self.drop and target + dim not in self.drop

    def SetToDefaultName(self):
        axis = self.target if self.target >= 0 else self.target + self.dim
        # Account for dropped axes preceding the target axis.
        axis -= sum(i < axis for i in self.drop)
        neg = "" if self.target >= 0 else "_neg"
        self.name = "dim%d_axis%d%s"%(self.dim - len(self.drop), axis, neg)
        return self

    def TransposeAxis(self, op):
        if op.type.type == "INT32":
            # The INT32 operand holds the axis; point it at the target.
            op.SetValue(self.target)
        elif len(op.type.dimensions) == self.dim:
            # To handle Internal operands
            if op.value is not None:
                op.SetValueFromNumpy(op.GetValueAsNumpy().transpose(self.perm))
            newDim = [op.type.dimensions[i] for i in self.perm]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        else:
            assert False, "%s not supported by AxisConverter"%op
        return op

    def RemoveAxis(self, op):
        if op.type.type == "INT32":
            # Shift the axis value to account for removed dimensions.
            if op.value[0] >= 0:
                op.SetValue(op.value[0] - sum(i < op.value[0] for i in self.drop))
            else:
                op.SetValue(op.value[0] + sum(i > (op.value[0] + self.dim) for i in self.drop))
        elif len(op.type.dimensions) == self.dim:
            if op.value is not None:
                val = op.GetValueAsNumpy()
                # Keep index 0 of each dropped axis; drop highest axes first so
                # lower axis numbers stay valid.
                for i in sorted(self.drop, reverse=True):
                    val = np.take(val, 0, axis=i)
                op.SetValueFromNumpy(val)
            newDim = [op.type.dimensions[i] for i in range(self.dim) if i not in self.drop]
            op.type = Type.GetType(op.type.type, newDim, op.type.scale, op.type.zeroPoint)
        else:
            assert False, "%s not supported by AxisConverter"%op
        return op

    def TransformOperand(self, op, arg=None):
        op = self.TransposeAxis(op)
        op = self.RemoveAxis(op)
        return op


# Convert Output based on activation
class ActivationConverter(ModelVariation, ImplicitVariation):
    # (Enum, low, high)
    actMap = {
        "none": (0, None, None),
        "relu": (1, 0.0, None),
        "relu1": (2, -1.0, 1.0),
        "relu6": (3, 0.0, 6.0),
    }

    def __init__(self, act="relu", name=None):
        ModelVariation.__init__(self, name=name)
        self.act = act.lower()
        assert ActivationConverter.IsCompatible(self.act)
        self.enum = ActivationConverter.actMap[self.act][0]
        self.low = ActivationConverter.actMap[self.act][1]
        self.high = ActivationConverter.actMap[self.act][2]

    @staticmethod
    def IsCompatible(value):
        return isinstance(value, str) and value.lower() in ActivationConverter.actMap.keys()

    def SetToDefaultName(self):
        self.name = self.act
        return self

    def TransformOperand(self, op, arg=None):
        """Set the activation enum operand, or clamp an expected output to the activation range."""
        if op.type.type == "INT32":  # activation enum
            return op.SetValue(self.enum)
        else:
            assert isinstance(op, Output)
            v = op.GetValueAsNumpy()
            # Bounds are real values; quantize them into the output's type.
            if self.low is not None:
                low = Quantize(self.low, op.type)
                v = np.maximum(v, low)
            if self.high is not None:
                high = Quantize(self.high, op.type)
                v = np.minimum(v, high)
            return op.SetValueFromNumpy(v)


# Convert all constant tensors as model inputs.
class AllTensorsAsInputsConverter(ModelVariation):
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        self.name = "all_tensors_as_inputs"
        return self

    def TransformModel(self, model):
        if len(model.operations) != 1:
            raise SkipVariation

        # Find all constant tensors.
        tensorParams = [
            p for p in model.operands
            if type(p) is Parameter and not p.type.IsScalar() and p.value is not None
        ]
        if not tensorParams:
            raise SkipVariation

        # Convert to model inputs.
        model.UpdateEquivalentOperands([op.ConvertTo(Input) for op in tensorParams])
        return model


def CompatibleWithADD(op):
    """Whether `op` can be produced by an ADD: rank <= 4, has a non-empty value, supported type."""
    return (len(op.type.dimensions) <= 4 and
            op.value is not None and
            len(op.value) > 0 and
            op.type.type in ["TENSOR_FLOAT32", "TENSOR_QUANT8_ASYMM",
                             "TENSOR_FLOAT16", "TENSOR_QUANT8_ASYMM_SIGNED"])


# Add a dummy ADD operation before each model input to make it as an internal operand.
class AllInputsAsInternalCoverter(ModelVariation):
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        self.name = "all_inputs_as_internal"
        return self

    def TransformModel(self, model):
        if len(model.operations) != 1:
            raise SkipVariation

        # Find all input tensors that can be an output of the ADD operation.
        modelInputs = [i for i in model.GetInputs() if CompatibleWithADD(i) and i.mayBeInternal]
        if not modelInputs:
            raise SkipVariation

        # Make every input an output of a dummy operation: input_new ADD dummy = input.
        for op in modelInputs:
            newInput = op.ConvertTo(Input, name=op.name + "_new")
            # The dummy holds the zero point, i.e. a real value of 0.
            dummyParam = Parameter("dummy", (op.type.type, [1], op.type.scale, op.type.zeroPoint),
                                   [op.type.zeroPoint])
            model.Operation("ADD", newInput, dummyParam, 0).To(op)

        # Convert to internal operands.
        model.UpdateEquivalentOperands([op.ConvertTo(Internal) for op in modelInputs])
        return model


# Add a dummy ADD operation after each model output to make it as an internal operand.
class AllOutputsAsInternalCoverter(ModelVariation):
    """Variation that routes each eligible model output through a dummy ADD.

    Every output accepted by CompatibleWithADD gets a new Output
    "<name>_new" computed as ADD(output, dummy), where dummy is a
    one-element Parameter holding the output's zero point. The original
    output then becomes an Internal operand. Only single-operation models
    are handled; SkipVariation is raised otherwise, or when no output is
    eligible.
    """
    supportsSubgraphs = True

    def __init__(self, name=None):
        ModelVariation.__init__(self, name=name)

    def SetToDefaultName(self):
        self.name = "all_outputs_as_internal"
        return self

    def TransformModel(self, model):
        """Append a dummy ADD after each eligible output of `model`."""
        if len(model.operations) != 1:
            raise SkipVariation

        # Find all output tensors that can be an input to an ADD operation.
        modelOutputs = [o for o in model.GetOutputs() if CompatibleWithADD(o)]
        if not modelOutputs:
            raise SkipVariation

        # Make every output an input of a dummy operation: output ADD dummy = output_new.
        # The dummy holds the zero point (the quantized representation of 0),
        # so the ADD does not change the values.
        for op in modelOutputs:
            newOutput = op.ConvertTo(Output, name=op.name + "_new")
            dummyParam = Parameter("dummy", (op.type.type, [1], op.type.scale, op.type.zeroPoint),
                                   [op.type.zeroPoint])
            model.Operation("ADD", op, dummyParam, 0).To(newOutput)

        # Convert to internal operands.
        model.UpdateEquivalentOperands([op.ConvertTo(Internal) for op in modelOutputs])
        return model

# An example is always attached to a model, and could have multiple variations
class Example:
    """A set of input/output feeds for one model plus its test variations.

    Every constructed Example registers itself in Example.examples. Test
    names listed in Example.versionOverrides get their model's version
    overridden when dumped.
    """
    examples = []
    versionOverrides = {}

    def __init__(self, *args, model=None, name=None):
        """Create an example for `model` (default: the most recent Model).

        Each positional argument is a feed: either an (inputs, outputs)
        tuple/list, or a single dict mapping operands to values, which is
        split into inputs and outputs using the model's GetInputs() and
        GetOutputs(). Any other argument type fails the assertion.
        """
        self.model = Model.models[-1] if model is None else model
        self.name = name
        self.expectedMultinomialDistributionTolerance = 0
        self.expectFailure = False
        self.testDynamicOutputShape = True
        self.testLifeTimeVariation = True
        self.feedDicts = []
        for feedDict in args:
            if type(feedDict) is tuple or type(feedDict) is list:
                self.feedDicts.append(feedDict)
            elif type(feedDict) is dict:
                self.feedDicts.append((
                    {i: feedDict[i] for i in self.model.GetInputs()},
                    {o: feedDict[o] for o in self.model.GetOutputs()}
                ))
            else:
                assert False
        self.variations = []
        Example.examples.append(self)

    @staticmethod
    def SetVersion(ver, *args):
        """Record `ver` as the version override for each test name in `args`."""
        for name in args:
            Example.versionOverrides[name] = ver

    # Main entrance of test generator
    @staticmethod
    def DumpAllExamples(DumpModel=None, model_fd=None,
                        DumpExample=None, example_fd=None,
                        DumpTest=None, test_fd=None):
        """Merge compatible examples, then dump every remaining example."""
        Example.CombineAllExamples()
        for example in Example.examples:
            example.Dump(DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd)

    # Combine examples with the same model, same name, and same set of variations
    @staticmethod
    def CombineAllExamples():
        """Fold examples sharing (model, name, variations) into the first one.

        The first example with a given key keeps its position in
        Example.examples and absorbs the feeds of later duplicates.
        """
        modelMap = {}
        newExamples = []
        for example in Example.examples:
            key = (example.model, example.name, tuple(tuple(e) for e in example.variations))
            if key in modelMap:
                modelMap[key].Combine(example)
            else:
                modelMap[key] = example
                newExamples.append(example)
        Example.examples = newExamples

    def AddVariations(self, *args, includeDefault=True, defaultName=None):
        """Add one variation group; each generated test picks one entry per group.

        With includeDefault, the group also contains a DefaultVariation
        named `defaultName`. Each argument is converted via
        ImplicitVariation.ImplicitConvertion.
        """
        self.variations.append([DefaultVariation(defaultName)] if includeDefault else [])
        self.variations[-1].extend(ImplicitVariation.ImplicitConvertion(i) for i in args)
        return self

    def AddNchw(self, *args, includeDefault=True, defaultName="nhwc"):
        """Add an NCHW data-layout variation for the given operands."""
        var = DataLayoutConverter("nchw").Identify(args)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddRelaxed(self, isRelaxed=True, includeDefault=True, defaultName=None):
        """Add a relaxed-computation variation."""
        var = RelaxedModeConverter(isRelaxed)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddRelu(self, *args, includeDefault=True, defaultName=None):
        """Add a RELU activation variation for the given operands."""
        var = ActivationConverter("relu").Identify(args)
        self.AddVariations(var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllActivations(self, *args):
        """Add one variation per activation in ActivationConverter.actMap."""
        var = [ActivationConverter(i).Identify(args)
               for i in sorted(ActivationConverter.actMap.keys())]
        self.AddVariations(*var, includeDefault=False)
        return self

    def GuessOriginalAxisAndDim(self, *args):
        """Infer (axis, rank) from the operands an axis variation applies to.

        INT32 operands supply the original axis (their first value); all
        tensor operands must share one rank. The axis defaults to the last
        one and is normalized to a non-negative index.
        """
        origin = None
        dim = None
        for arg in args:
            if arg.type.type == "INT32":
                origin = arg.value[0]
            else:
                if dim is None:
                    dim = len(arg.type.dimensions)
                else:
                    assert dim == len(arg.type.dimensions)
        assert dim is not None
        origin = dim - 1 if origin is None else origin
        origin = origin + dim if origin < 0 else origin
        return origin, dim

    def AddAxis(self, axis, *args, includeDefault=True, defaultName=None):
        """Add variations moving the original axis to each target in `axis`."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        axis = [axis] if type(axis) is int else list(axis)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in axis]
        self.AddVariations(*var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllPositiveAxis(self, *args):
        """Add variations for every non-negative target axis."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in range(dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllAxis(self, *args):
        """Add variations for every target axis in [-dim, dim)."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, a, dim).Identify(args) for a in range(-dim, dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddDims(self, dims, *args, includeDefault=True, defaultName=None):
        """Add variations reducing the operands to each rank in `dims`.

        The original axis is kept; axes other than it are dropped from the
        front until the requested rank is reached.
        """
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        dims = [dims] if type(dims) is int else list(dims)
        drop = list(range(dim))
        drop.pop(origin)
        var = [AxisConverter(origin, origin, dim, drop[0:(dim-i)]).Identify(args) for i in dims]
        self.AddVariations(*var, includeDefault=includeDefault, defaultName=defaultName)
        return self

    def AddAllDims(self, *args):
        """Add variations for every rank from the original rank down to 1."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        drop = list(range(dim))
        drop.pop(origin)
        var = [AxisConverter(origin, origin, dim, drop[0:i]).Identify(args) for i in range(dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllDimsAndPositiveAxis(self, *args):
        """Add variations for every (rank, non-negative axis) combination."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, j, dim, range(i)).Identify(args) \
                for i in range(dim) for j in range(i, dim)]
        self.AddVariations(*var, includeDefault=False)
        return self

    def AddAllDimsAndAxis(self, *args):
        """Add variations for every (rank, axis) combination, both signs."""
        origin, dim = self.GuessOriginalAxisAndDim(*args)
        var = [AxisConverter(origin, k, dim, range(i)).Identify(args) \
                for i in range(dim) for j in range(i, dim) for k in [j, j - dim]]
        self.AddVariations(*var, includeDefault=False)
        return self

    def Combine(self, other):
        """Append `other`'s feeds to this example; both must match exactly."""
        assert self.model is other.model, "Only examples targetting the same model can be combined"
        assert tuple(self.variations) == tuple(other.variations), \
            "Only examples with the same set of variations can be combined"
        assert self.name == other.name, "Only examples with the same name can be combined"
        self.feedDicts.extend(other.feedDicts)
        return self

    def Dump(self, DumpModel, model_fd, DumpExample, example_fd, DumpTest, test_fd):
        """Emit one example/test per (feed, variation combination) pair.

        For single-operation models (unless disabled, and unless a
        multinomial tolerance is set), lifetime variations are appended
        first. Each combination is applied to a deep copy of the model;
        combinations raising SkipVariation are skipped. DumpExample and
        DumpTest are called only when both the callback and its file object
        are given. DumpModel and model_fd are accepted but not used here.
        """
        if self.testLifeTimeVariation and len(self.model.operations) == 1 and \
                self.expectedMultinomialDistributionTolerance == 0:
            self.AddVariations(AllTensorsAsInputsConverter())
            self.AddVariations(AllInputsAsInternalCoverter())
        for vs in self.variations:
            for v in vs:
                if v.name is None:
                    v.SetToDefaultName()

        for feedDict in self.feedDicts:
            self.model.Feed(feedDict)
            for variationList in itertools.product(*self.variations):
                modelOrigin = self.model
                self.model = copy.deepcopy(self.model)

                # Apply variations
                try:
                    for variation in variationList:
                        self.model = variation.ApplyTo(self.model)
                except SkipVariation:
                    self.model = modelOrigin
                    continue

                # Concat names for test and examples
                varNames = [v.name for v in variationList]
                self.testName = NamedTest(FileNames.specName, self.model.name, self.name, *varNames)
                self.examplesName = GlobalVariable("test_model", self.model.name, self.name,
                                                   *varNames)
                if str(self.testName) in Example.versionOverrides:
                    self.model.IntroducedIn(Example.versionOverrides[str(self.testName)])
                self.model.WithSuffix(*varNames).Compile()

                # Dump files
                if DumpExample is not None and example_fd is not None:
                    DumpExample(self, example_fd)
                if DumpTest is not None and test_fd is not None:
                    DumpTest(self, test_fd)

                # Restore model before variation
                self.model = modelOrigin
        return self

    # Specifies the RANDOM_MULTINOMIAL distribution tolerance.
    # If set to greater than zero, the input is compared as log-probabilities
    # to the output and must be within this tolerance to pass.
    def WithMultinomialDistributionTolerance(self, expectedTolerance):
        assert self.expectFailure is False
        self.expectedMultinomialDistributionTolerance = expectedTolerance
        return self

    # Specifies that this example is expected to fail during compilation or execution.
    def ExpectFailure(self):
        assert self.expectedMultinomialDistributionTolerance == 0
        self.expectFailure = True
        return self

    def DisableDynamicOutputShapeVariation(self):
        """Turn off the dynamic-output-shape test variation."""
        self.testDynamicOutputShape = False
        return self

    def DisableLifeTimeVariation(self):
        """Turn off the operand-lifetime variations added by Dump."""
        self.testLifeTimeVariation = False
        return self

class FileNames:
    """Global iteration state over spec files and their output files."""
    specFiles = []
    specNames = []
    exampleFiles = []
    specFile = ""
    specName = ""
    exampleFile = ""
    version = ""
    fileIndex = 0

    @staticmethod
    def InitializeFileLists(spec, example):
        """Collect spec files from `spec` and the matching example targets.

        `spec` is a single file or a directory scanned for "*.mod.py" files.
        Iteration via NextFile() restarts from the first spec file.
        """
        # get all spec files and target files
        if os.path.isfile(spec):
            FileNames.specFiles = [os.path.abspath(spec)]
        elif os.path.isdir(spec):
            FileNames.specFiles = sorted([os.path.abspath(os.path.join(spec, f))
                for f in os.listdir(spec) if f.endswith(".mod.py")])
        else:
            assert False, "%s is neither a file or a directory"%spec
        FileNames.specNames = [re.sub(r"\..*", "", os.path.basename(f))
            for f in FileNames.specFiles]
        FileNames.exampleFiles = FileNames.ParseTargetFiles(example, ".example.cpp")
        # Restart iteration: a stale index from a previous run would skip or
        # misalign files in the new lists.
        FileNames.fileIndex = 0

    @staticmethod
    def ParseTargetFiles(arg, ext):
        """Map `arg` to one output target per spec file.

        None yields None for every spec; a directory yields "<specName><ext>"
        inside it; "-" is passed through unchanged; anything else is used as
        a single shared file path.
        """
        numFiles = len(FileNames.specFiles)
        if arg is None:
            return [None] * numFiles
        absPath = os.path.abspath(arg)
        if os.path.isdir(arg):
            target = [os.path.join(absPath, f + ext) for f in FileNames.specNames]
        elif arg == "-":
            target = ["-"] * numFiles
        else:
            target = [absPath] * numFiles
        return target

    @staticmethod
    def NextFile():
        """Advance to the next spec file and reset per-file global state.

        Returns False once all spec files have been consumed.
        """
        if FileNames.fileIndex >= len(FileNames.specFiles):
            return False
        FileNames.specFile = FileNames.specFiles[FileNames.fileIndex]
        FileNames.specName = FileNames.specNames[FileNames.fileIndex]
        FileNames.exampleFile = FileNames.exampleFiles[FileNames.fileIndex]
        FileNames.fileIndex += 1
        NamedObject.existingNames = set()
        NamedVariable.existingNames = set()
        NamedTest.existingNames = set()
        Type.typesMap = dict()
        Model.models = list()
        Example.examples = list()
        Configuration.use_shm_for_weights = False

        # Extract version from absolute file path.
        versionMatch = re.findall(r"/V\d_\d/", FileNames.specFile)
        if len(versionMatch) == 1:
            FileNames.version = versionMatch[0].strip('/')
        else:
            FileNames.version = None
        return True

class Configuration:
    """Generator-wide switches.

    use_shm_for_weights is reset for every spec file by FileNames.NextFile;
    hook_mode is set from the command line by ParseArgs.
    """
    use_shm_for_weights = False
    hook_mode = False

    @staticmethod
    def useSHM():
        return Configuration.use_shm_for_weights

def GetTestGeneratorMTime():
    """Return the newest mtime of the generator scripts next to this module."""
    tgFiles = ['test_generator.py', 'example_generator.py']
    tgDir = os.path.dirname(__file__)
    return max(os.path.getmtime(os.path.join(tgDir, filename))
               for filename in tgFiles)

def MightNeedRegeneration():
    """Return True if the current spec's example file may be out of date.

    The example file is stale when it does not exist or is not newer than
    both the spec file and the generator scripts. When no example file was
    requested (FileNames.exampleFile is None) there is nothing to compare
    against, so the spec is always processed.
    """
    if FileNames.exampleFile is None:
        # os.path.exists(None) raises TypeError instead of returning False.
        return True
    specTime = os.path.getmtime(FileNames.specFile)
    tgTime = GetTestGeneratorMTime()
    return not os.path.exists(FileNames.exampleFile) or \
           os.path.getmtime(FileNames.exampleFile) <= max(specTime, tgTime)

def Read(filename):
    """Return the full text contents of `filename`."""
    with open(filename) as reader:
        return reader.read()

def AtomicWrite(filename, data):
    """Write `data` to `filename` via a temporary file and a rename."""
    # The temporary file lives next to the destination because
    # os.replace(src, dest) may fail if src and dest are on different
    # filesystems. Renaming over the destination means it is either the old
    # or the complete new content, never a partial write.
    tempFile = filename + '.tmp'
    try:
        with open(tempFile, 'w') as writer:
            writer.write(data)
        os.replace(tempFile, filename)
        tempFile = None
    finally:
        # Clean up the temporary file if writing or renaming failed.
        if tempFile is not None and os.path.exists(tempFile):
            os.remove(tempFile)

def GetExecScope():
    """Return the global namespace in which spec files are executed."""
    return dict(
        ActivationConverter=ActivationConverter,
        AllInputsAsInternalCoverter=AllInputsAsInternalCoverter,
        AllOutputsAsInternalCoverter=AllOutputsAsInternalCoverter,
        AllTensorsAsInputsConverter=AllTensorsAsInputsConverter,
        BoolScalar=BoolScalar,
        Configuration=Configuration,
        DataLayoutConverter=DataLayoutConverter,
        DataTypeConverter=DataTypeConverter,
        Example=Example,
        Float16Scalar=Float16Scalar,
        Float32Scalar=Float32Scalar,
        Float32Vector=Float32Vector,
        IgnoredOutput=IgnoredOutput,
        Input=Input,
        Int32Scalar=Int32Scalar,
        Int32Vector=Int32Vector,
        Internal=Internal,
        Model=Model,
        Operand=Operand,
        Output=Output,
        Parameter=Parameter,
        RelaxedModeConverter=RelaxedModeConverter,
        SubgraphReference=SubgraphReference,
        SymmPerChannelQuantParams=SymmPerChannelQuantParams)

def ArgumentParser():
    """Build the base command-line parser (spec path and --hook)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("spec", help="the spec file or directory")
    parser.add_argument("--hook", help="hook mode", action='store_true')
    return parser

def ParseArgs(parser):
    """Parse sys.argv with `parser` and record hook mode in Configuration."""
    args = parser.parse_args()
    Configuration.hook_mode = args.hook
    return args

def Run(InitializeFiles=None, DumpExample=None):
    """Process every spec file and (re)generate its example file.

    Each spec is executed as Python code in the GetExecScope() namespace
    (spec files are trusted input), its examples are dumped into an
    in-memory buffer, and the buffer is written atomically to the example
    file. In hook mode an out-of-date example file aborts with exit code 1
    instead of being rewritten. Any other exception prints a traceback and
    exits with an error message.
    """
    exec_scope = GetExecScope()
    while FileNames.NextFile():
        try:
            if not MightNeedRegeneration():
                continue
            exec(Read(FileNames.specFile), exec_scope)
            example_buf = io.StringIO() if FileNames.exampleFile else None
            # InitializeFiles is optional (default None), like DumpExample.
            if InitializeFiles is not None:
                InitializeFiles(example_fd=example_buf)
            Example.DumpAllExamples(DumpExample=DumpExample, example_fd=example_buf)
            if FileNames.exampleFile is None:
                continue
            if Configuration.hook_mode and (not os.path.exists(FileNames.exampleFile) or
                                            Read(FileNames.exampleFile) != example_buf.getvalue()):
                print(('\n{filename} is out of date. '
                       'Please run {generate_all_tests_sh} before uploading.\n').format(
                               filename=FileNames.exampleFile,
                               generate_all_tests_sh=os.path.abspath(os.path.join(
                                       os.path.dirname(__file__), '..', '..', 'runtime', 'test',
                                       'specs', 'generate_all_tests.sh'))))
                sys.exit(1)
            AtomicWrite(FileNames.exampleFile, example_buf.getvalue())
        except Exception:
            traceback.print_exc()
            sys.exit("Exception raised when processing {}".format(FileNames.specFile))