1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python console command to invoke TOCO from serialized protos."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import argparse
21import sys
22
23# We need to import pywrap_tensorflow prior to the toco wrapper.
24# pylint: disable=invalid-import-order,g-bad-import-order
25from tensorflow.python import pywrap_tensorflow  # pylint: disable=unused-import
26from tensorflow.python import _pywrap_toco_api
27from tensorflow.python.platform import app
28
# Parsed command-line arguments. Starts as None; main() assigns the argparse
# namespace here before execute() reads it.
FLAGS = None
30
31
def execute(unused_args):
  """Runs the converter.

  Reads the serialized inputs named by the module-level `FLAGS`, passes them
  to `_pywrap_toco_api.TocoConvert`, and writes the returned value to
  `FLAGS.model_output_file`.

  Args:
    unused_args: Arguments supplied by the caller; not used by this function.

  Raises:
    SystemExit: Always, with exit code 0, once the output has been written.
  """

  def _read_binary(path):
    """Returns the full contents of the file at `path`, read in binary mode."""
    with open(path, "rb") as input_handle:
      return input_handle.read()

  model_str = _read_binary(FLAGS.model_proto_file)
  toco_str = _read_binary(FLAGS.toco_proto_file)
  input_str = _read_binary(FLAGS.model_input_file)

  # The debug info file is optional; None is passed when no path was given.
  debug_info_str = None
  if FLAGS.debug_proto_file:
    debug_info_str = _read_binary(FLAGS.debug_proto_file)

  enable_mlir_converter = FLAGS.enable_mlir_converter

  output_str = _pywrap_toco_api.TocoConvert(
      model_str,
      toco_str,
      input_str,
      False,  # extended_return
      debug_info_str,
      enable_mlir_converter)

  # Use a context manager so the output file is closed before the process
  # exits, rather than relying on garbage collection of an anonymous handle.
  with open(FLAGS.model_output_file, "wb") as output_file:
    output_file.write(output_str)
  sys.exit(0)
59
60
def main():
  """Parses the command line and dispatches to `execute` through `app.run`.

  The parsed argparse namespace is stored in the module-level `FLAGS`. Any
  arguments the parser does not recognize are forwarded to `app.run`,
  preceded by the program name from `sys.argv[0]`.
  """
  global FLAGS

  arg_parser = argparse.ArgumentParser(
      description="Invoke toco using protos as input.")

  # Required positional arguments, registered in command-line order.
  positional_specs = (
      ("model_proto_file",
       "File containing serialized proto that describes the model."),
      ("toco_proto_file",
       "File containing serialized proto describing how TOCO should run."),
      ("model_input_file",
       "Input model is read from this file."),
      ("model_output_file",
       "Result of applying TOCO conversion is written here."),
  )
  for arg_name, arg_help in positional_specs:
    arg_parser.add_argument(arg_name, type=str, help=arg_help)

  # Optional flags.
  debug_help = ("File containing serialized `GraphDebugInfo` proto that "
                "describes logging information.")
  arg_parser.add_argument(
      "--debug_proto_file", type=str, default="", help=debug_help)

  mlir_help = ("Boolean indicating whether to enable MLIR-based conversion "
               "instead of TOCO conversion. (default False)")
  arg_parser.add_argument(
      "--enable_mlir_converter", action="store_true", help=mlir_help)

  FLAGS, leftover_args = arg_parser.parse_known_args()

  forwarded_argv = [sys.argv[0]]
  forwarded_argv.extend(leftover_args)
  app.run(main=execute, argv=forwarded_argv)
94
95
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
  main()
98