1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python console command to invoke TOCO from serialized protos."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import argparse
21import sys
22from tensorflow.lite.toco.python import tensorflow_wrap_toco
23from tensorflow.python.platform import app
24
# Parsed command-line namespace. main() assigns it from argparse before
# app.run() invokes execute(), which reads the four file paths from it.
FLAGS = None


def _read_binary(path):
  """Returns the entire contents of the file at `path`, read in binary mode."""
  with open(path, "rb") as f:
    return f.read()


def execute(unused_args):
  """Runs a TOCO conversion using the file paths stored in FLAGS.

  Reads the serialized model proto, the serialized TOCO proto and the input
  model from disk. Passes the three byte strings to
  `tensorflow_wrap_toco.TocoConvert` and writes its result to
  `FLAGS.model_output_file`. Then exits the process with status 0.

  Args:
    unused_args: Argument list passed in by `app.run`; ignored.
  """
  model_str = _read_binary(FLAGS.model_proto_file)
  toco_str = _read_binary(FLAGS.toco_proto_file)
  input_str = _read_binary(FLAGS.model_input_file)

  output_str = tensorflow_wrap_toco.TocoConvert(model_str, toco_str, input_str)
  # Close (and therefore flush) the output file deterministically before
  # sys.exit() below instead of relying on garbage collection of an unnamed
  # file object.
  with open(FLAGS.model_output_file, "wb") as f:
    f.write(output_str)
  sys.exit(0)
36
37
def main():
  """Parses the positional file arguments and hands control to `execute`.

  The four positional paths are stored in the module-level FLAGS. Any
  arguments argparse does not recognize are forwarded to `app.run` after the
  program name.
  """
  global FLAGS
  # (argument name, help text) for each required positional argument, in
  # the order they must appear on the command line.
  positional_specs = (
      ("model_proto_file",
       "File containing serialized proto that describes the model."),
      ("toco_proto_file",
       "File containing serialized proto describing how TOCO should run."),
      ("model_input_file", "Input model is read from this file."),
      ("model_output_file",
       "Result of applying TOCO conversion is written here."),
  )

  arg_parser = argparse.ArgumentParser(
      description="Invoke toco using protos as input.")
  for arg_name, arg_help in positional_specs:
    arg_parser.add_argument(arg_name, type=str, help=arg_help)

  FLAGS, leftover = arg_parser.parse_known_args()

  # app.run expects argv[0] to be the program name, followed by whatever
  # argparse left unconsumed.
  forwarded_argv = [sys.argv[0]]
  forwarded_argv.extend(leftover)
  app.run(main=execute, argv=forwarded_argv)


if __name__ == "__main__":
  main()
64