1#!/usr/bin/python3
2
3# Copyright 2017, The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""Slicing the input Model file
17
Invoked by ml/nn/runtime/test/specs/slicing.sh; this Python code is
not intended to be invoked directly by users. See that script for
details on how the slicing tool is used.
21
22This script does the following work:
23
Perform a topological sort similar to the test generator, except that:
* It stops at the N-th operation it encounters,
* It renames the output of the N-th operation to a model output and
  names that as the output of the model, and
* It emits only the inputs and weights used by the submodel.
29
30"""
31
32from __future__ import absolute_import
33from __future__ import division
34from __future__ import print_function
35import argparse
36from functools import reduce
37import math
38import os
39import struct
40import sys
41import contextlib
42import test_generator
43import pprint
44# Stuff from test generator
45from test_generator import Configuration
46from test_generator import Example
47from test_generator import Float32Scalar
48from test_generator import Float32Vector
49from test_generator import IgnoredOutput
50from test_generator import Input
51from test_generator import Int32Scalar
52from test_generator import Int32Vector
53from test_generator import Internal
54from test_generator import Model
55from test_generator import Output
56from test_generator import Parameter
57from test_generator import SmartOpen
58
59
60# Take a model from command line
def import_source():
  """Parse the command line and execute the spec file it names.

  Executing the spec presumably registers the model(s) it defines through
  test_generator side effects (the __main__ block reads Model.models
  afterwards) -- confirm against test_generator.

  Returns:
    A (model_path, example_path, number) tuple: the output model file, the
    output example file ("-" by default for both), and the number of
    operations to keep in the sliced model, as an int.

  Exits with status 1 if the spec file does not exist. argparse exits with
  status 2 on malformed arguments, e.g. a non-integer --number.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("spec", help="the spec file")
  parser.add_argument(
      "-n", "--number",
      type=int,
      help="number of operations in the sliced model. Default = 1",
      default=1)
  parser.add_argument(
      "-m", "--model", help="the output model file", default="-")
  parser.add_argument(
      "-e", "--example", help="the output example file", default="-")
  args = parser.parse_args()

  if not os.path.exists(args.spec):
    # Diagnostics go to stderr so they never mix with generated output,
    # which defaults to stdout ("-").
    print("cannot find file %s" % args.spec, file=sys.stderr)
    sys.exit(1)

  test_generator.FileNames.specFile = os.path.basename(args.spec)
  # Read and close the spec before running it. Compile with the real path
  # so tracebacks from a broken spec point at the spec file.
  with open(args.spec) as spec_file:
    source = spec_file.read()
  # NOTE: the spec is trusted, developer-authored Python. exec() with no
  # explicit namespaces uses this module's globals (so the spec sees Model,
  # Input, ...) and this function's locals.
  exec(compile(source, args.spec, "exec"))

  return (args.model, args.example, args.number)
82
83
84# Slice till the Nth op the topological sort finds
85# the output of that op becomes the output of the model
class slicing:
  """Keep the first N operations of a compiled model and re-emit them.

  Feed every operation to format_as_py_op() in order, then call
  format_operands(), dump() and dump_example() to write the sliced model
  definition and its example.
  """

  def __init__(self, threshold):
    """threshold: how many operations (with a Python definition) to keep."""
    self.__nr_op_seen = 0
    self.__threshold = threshold
    # Outputs of the last kept operation; they become the sliced model's
    # outputs.
    self.__last_outs = []
    # One "model = model.<op>" line per kept operation.
    self.__all_formatted_ops = []
    # Every operand read or written by a kept operation.
    self.__referenced_operands = set()

  def format_as_py_op(self, op):
    """Record op if it is among the first `threshold` formattable ops.

    Operations whose PyDefinition() is None are skipped and do not count
    towards the threshold.

    Returns:
      True if op was kept; False if it has no Python definition or lies past
      the threshold.
    """
    fmt = op.PyDefinition()
    if fmt is None:
      return False
    self.__nr_op_seen += 1
    if self.__nr_op_seen > self.__threshold:
      return False
    self.__last_outs = op.outs
    self.__referenced_operands.update(op.ins)
    self.__referenced_operands.update(op.outs)
    self.__all_formatted_ops.append("model = model.%s" % fmt)
    return True

  def dump(self, model_file):
    """Write the recorded operation lines to model_file, one per line."""
    for line in self.__all_formatted_ops:
      print(line, file=model_file)

  def dump_example(self, example_file):
    """Write output aliases, then the example data, to example_file.

    For each output of the last kept operation, this writes an
    "aliased_output0 = <name>" alias. It also maps the output's name to its
    element count in the override dict passed to Example.py_dump (presumably
    used to size the output data -- confirm in test_generator).
    """
    alias_def = """\
# Alias for the output variable {operand_name}
aliased_output{number} = {operand_name}
"""
    override = {}
    for lo in self.__last_outs:
      override[str(lo)] = lo.type.GetNumberOfElements()
      # Only one output is supported as of now, so the alias is always 0.
      fields = {'operand_name': str(lo), 'number': 0}
      print(alias_def.format(**fields), file=example_file)
    Example.py_dump(example_file, override, self.__referenced_operands)

  def format_operands(self, model):
    """Return Python definitions of the operands used by the kept ops.

    Operands are emitted in model.operands order, one definition per line.
    Parameters carry their values as initializers. Non-parameter outputs of
    the last kept operation are declared as IgnoredOutput. Every other
    operand keeps its own class name.
    """
    # Built once rather than once per operand.
    last_outs = set(self.__last_outs)
    op_definitions = []
    for o in model.operands:
      if o not in self.__referenced_operands:
        continue
      ty = o.type
      op_def = """{op_name} = {operand}("{op_name}", "{element_type}", "{shape}" """
      if isinstance(o, test_generator.Parameter):
        op_def += """, {initializer})"""
        init = o.value
        py_operand_name = "Parameter"
      else:
        op_def += ")"
        init = []
        py_operand_name = ("IgnoredOutput" if o in last_outs
                           else o.__class__.__name__)

      fields = {
          "element_type": ty.type,
          "shape": ty.GetRawShape(),
          "op_name": str(o),
          "operand": py_operand_name,
          "initializer": init
      }
      op_definitions.append(op_def.format(**fields))
    return "\n".join(op_definitions)
156
157
if __name__ == "__main__":
  (model_path, example_path, number) = import_source()
  s = slicing(int(number))

  # The slicing tool only supports one single model per spec file. Check up
  # front so a spec without a model yields a clear error instead of an
  # IndexError after the output file has already been started.
  if not Model.models:
    print("no model defined in spec file %s" % test_generator.FileNames.specFile,
          file=sys.stderr)
    sys.exit(1)

  with SmartOpen(model_path) as model_file:
    spec_file = " (from: %s)" % (test_generator.FileNames.specFile)
    print("# Generated file%s. Do not edit" % (spec_file), file=model_file)
    print("model = Model()", file=model_file)
    compiled_model = Model.models[0].Compile()
    for op in compiled_model.operations:
      s.format_as_py_op(op)
    print(s.format_operands(compiled_model), file=model_file)
    s.dump(model_file)
  with SmartOpen(example_path) as example_file:
    s.dump_example(example_file)
174