#!/usr/bin/python3

# Copyright 2017, The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Slicing the input Model file

Invoked by ml/nn/runtime/test/specs/slicing.sh; this Python code is
not intended to be invoked directly by users. See that script for
details on how the slicing tool is used.

This script does the following work:

Perform a topological sort similar to the test generator, except that:
* It stops at the N-th operation it encounters,
* It renames the output of the N-th operation to a model output,
* It names that operand as the output of the model, and
* It emits only the inputs and weights used by the sliced submodel.

"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
from functools import reduce
import math
import os
import struct
import sys
import contextlib
import test_generator
import pprint
# Names from the test generator. Several are not referenced in this file
# directly; presumably they are used by the spec file that import_source()
# executes, which runs with this module's globals.
from test_generator import Configuration
from test_generator import Example
from test_generator import Float32Scalar
from test_generator import Float32Vector
from test_generator import IgnoredOutput
from test_generator import Input
from test_generator import Int32Scalar
from test_generator import Int32Vector
from test_generator import Internal
from test_generator import Model
from test_generator import Output
from test_generator import Parameter
from test_generator import SmartOpen


# Take a model from command line
def import_source():
    """Parse the command line and execute the spec file it names.

    Executing the spec is what registers its model with test_generator
    (the main section below reads Model.models[0] afterwards).

    Returns:
        A (model_path, example_path, number) tuple:
        * model_path: the output model file ("-" by default).
        * example_path: the output example file ("-" by default).
        * number: the number of operations to keep, as an int >= 1.

    Exits with status 1 (message on stderr) if the spec file does not
    exist. argparse exits with status 2 if --number is not a positive
    integer.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("spec", help="the spec file")
    parser.add_argument(
        "-n", "--number",
        help="number of operations in the sliced model. Default = 1",
        type=int,
        default=1)
    parser.add_argument(
        "-m", "--model", help="the output model file", default="-")
    parser.add_argument(
        "-e", "--example", help="the output example file", default="-")
    args = parser.parse_args()

    # A non-positive threshold would keep no operations and silently
    # produce an empty model.
    if args.number < 1:
        parser.error("--number must be at least 1")

    if not os.path.exists(args.spec):
        # Diagnostics go to stderr: stdout is the default output stream
        # for both the model and the example files.
        print("cannot find file %s" % args.spec, file=sys.stderr)
        sys.exit(1)

    test_generator.FileNames.specFile = os.path.basename(args.spec)
    # Read via a context manager so the handle is closed deterministically.
    with open(args.spec) as _spec_file:
        _spec_source = _spec_file.read()
    # exec() without explicit namespaces runs the spec with this module's
    # globals (so the test_generator names imported above resolve) and a
    # snapshot of this function's locals. Compiling with the spec path makes
    # tracebacks from a broken spec point at the spec file.
    # NOTE(review): spec files are trusted, checked-in test inputs; never
    # point this tool at untrusted files.
    exec(compile(_spec_source, args.spec, "exec"))

    return (args.model, args.example, args.number)


# Slice till the Nth op the topological sort finds
# the output of that op becomes the output of the model
class slicing:
    """Keep the first `threshold` operations of a compiled model.

    Operations are fed one at a time to format_as_py_op(). Only those
    whose PyDefinition() is not None are counted. The outputs of the last
    kept operation become the outputs of the sliced model. Only operands
    referenced by kept operations are emitted.
    """

    def __init__(self, threshold):
        self.__nr_op_seen = 0  # counted (non-None PyDefinition) ops seen
        self.__threshold = threshold  # max number of ops to keep
        self.__last_outs = []  # outputs of the last kept op
        self.__all_formatted_ops = []  # "model = model.<op>" lines
        self.__referenced_operands = set()  # operands used by kept ops

    def format_as_py_op(self, op):
        """Record `op` if it is still within the slicing threshold.

        Returns:
            None if op.PyDefinition() is None (the op is neither emitted
            nor counted); False if the op lies past the threshold; True
            if the op was recorded.
        """
        fmt = op.PyDefinition()
        if fmt is None:
            return None
        self.__nr_op_seen += 1
        if self.__nr_op_seen > self.__threshold:
            return False
        self.__last_outs = op.outs
        self.__referenced_operands.update(op.ins)
        self.__referenced_operands.update(op.outs)
        self.__all_formatted_ops.append("model = model.%s" % fmt)
        return True

    def dump(self, model_file):
        """Write the recorded `model = model.<op>` lines to model_file."""
        for line in self.__all_formatted_ops:
            print(line, file=model_file)

    def dump_example(self, example_file):
        """Write the example section for the sliced model to example_file.

        For each output of the last kept op, this writes an alias line
        `aliased_output0 = <operand>` and records the operand's element
        count in an override map. It then delegates to Example.py_dump()
        with that map and the set of referenced operands. Presumably the
        override tells py_dump how many values to emit for the new
        outputs; confirm in test_generator.
        """
        override = {}
        alias_def = """\
# Alias for the output variable {operand_name}
aliased_output{number} = {operand_name}
"""
        # Make alias for the output variable
        for lo in self.__last_outs:
            override[str(lo)] = lo.type.GetNumberOfElements()
            op = {
                'operand_name': str(lo),
                # Only one output is supported as of now.
                # NOTE(review): with several outputs every alias is
                # named aliased_output0, so the last one wins.
                'number': 0,
            }
            print(alias_def.format(**op), file=example_file)
        Example.py_dump(example_file, override, self.__referenced_operands)

    def format_operands(self, model):
        """Return Python source defining the operands of the sliced model.

        Emits one assignment per operand of `model` that a kept operation
        references, in model.operands order, joined with newlines:
        * Parameter operands are emitted as Parameter(...) with their
          value as the initializer.
        * Outputs of the last kept op are emitted as IgnoredOutput(...).
        * Any other operand keeps its own class name.
        """
        # Built once rather than once per operand.
        last_outs = set(self.__last_outs)
        op_definitions = []
        for o in model.operands:
            if o not in self.__referenced_operands:
                continue
            ty = o.type
            op_def = """{op_name} = {operand}("{op_name}", "{element_type}", "{shape}" """
            if isinstance(o, test_generator.Parameter):
                op_def += """, {initializer})"""
                init = o.value
                py_operand_name = "Parameter"
            else:
                op_def += ")"
                init = []
                if o in last_outs:
                    py_operand_name = "IgnoredOutput"
                else:
                    py_operand_name = o.__class__.__name__

            op = {
                "element_type": ty.type,
                "shape": ty.GetRawShape(),
                "op_name": str(o),
                "operand": py_operand_name,
                "initializer": init,
            }
            op_definitions.append(op_def.format(**op))
        return "\n".join(op_definitions)


if __name__ == "__main__":
    (model_path, example_path, number) = import_source()
    s = slicing(int(number))

    with SmartOpen(model_path) as model_file:
        spec_file = " (from: %s)" % (test_generator.FileNames.specFile)
        print("# Generated file%s. Do not edit" % (spec_file), file=model_file)
        print("model = Model()", file=model_file)
        # The slicing tool only supports a single model per spec file.
        model = Model.models[0].Compile()
        for op in model.operations:
            s.format_as_py_op(op)
        print(s.format_operands(model), file=model_file)
        s.dump(model_file)
    with SmartOpen(example_path) as example_file:
        s.dump_example(example_file)