1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Converts a TFLite model to a TFLite Micro model (C++ Source)."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from absl import app
22from absl import flags
23
24from tensorflow.lite.python import util
25
FLAGS = flags.FLAGS

# --- Required flags (enforced by the mark_flag_as_required calls below). ---

# Path of the .tflite flatbuffer; main() opens it in binary mode.
flags.DEFINE_string("input_tflite_file", None,
                    "Full path name to the input TFLite model file.")
# Destination for the generated C++ source text.
flags.DEFINE_string(
    "output_source_file", None,
    "Full path name to the output TFLite Micro model (C++ Source) file.")
# Destination for the generated header text.
flags.DEFINE_string("output_header_file", None,
                    "Full filepath of the output C header file.")
# Forwarded as `array_name` to util.convert_bytes_to_c_source.
flags.DEFINE_string("array_variable_name", None,
                    "Name to use for the C data array variable.")

# --- Optional flags, forwarded unchanged to util.convert_bytes_to_c_source. ---

# Forwarded as `max_line_width`.
flags.DEFINE_integer("line_width", 80, "Width to use for formatting.")
# Forwarded as `include_guard`; stays None unless set on the command line
# (how None is interpreted is decided by the util helper, not here).
flags.DEFINE_string("include_guard", None,
                    "Name to use for the C header include guard.")
# Forwarded as `include_path`; stays None unless set on the command line.
flags.DEFINE_string("include_path", None,
                    "Optional path to include in generated source file.")
# Forwarded as `use_tensorflow_license`.
flags.DEFINE_boolean(
    "use_tensorflow_license", False,
    "Whether to prefix the generated files with the TF Apache2 license.")

flags.mark_flag_as_required("input_tflite_file")
flags.mark_flag_as_required("output_source_file")
flags.mark_flag_as_required("output_header_file")
flags.mark_flag_as_required("array_variable_name")
50
51
def main(_):
  """Entry point invoked by absl `app.run` once flags have been parsed.

  Loads the raw bytes of the model named by --input_tflite_file, has
  `util.convert_bytes_to_c_source` render them as a C data array, and stores
  the returned source text and header text in --output_source_file and
  --output_header_file respectively.

  Args:
    _: Leftover command-line arguments supplied by `app.run`; ignored.
  """
  with open(FLAGS.input_tflite_file, "rb") as model_file:
    model_bytes = model_file.read()

  # Everything except the payload comes straight from the command-line flags.
  conversion_options = dict(
      array_name=FLAGS.array_variable_name,
      max_line_width=FLAGS.line_width,
      include_guard=FLAGS.include_guard,
      include_path=FLAGS.include_path,
      use_tensorflow_license=FLAGS.use_tensorflow_license,
  )
  source_text, header_text = util.convert_bytes_to_c_source(
      data=model_bytes, **conversion_options)

  # Source first, then header: each file is opened, written and closed in turn.
  outputs = (
      (FLAGS.output_source_file, source_text),
      (FLAGS.output_header_file, header_text),
  )
  for output_path, text in outputs:
    with open(output_path, "w") as output_file:
      output_file.write(text)
69
70
if __name__ == "__main__":
  # absl parses the command-line flags defined above, then calls main(argv).
  app.run(main=main)
73