# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Converts a TFLite model to a TFLite Micro model (C++ Source).

Reads the raw bytes of --input_tflite_file, renders them as a C data array
named --array_variable_name via `util.convert_bytes_to_c_source`, and writes
the resulting source and header text to --output_source_file and
--output_header_file.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl import app
from absl import flags

from tensorflow.lite.python import util

FLAGS = flags.FLAGS

# Input / output locations.
flags.DEFINE_string("input_tflite_file", None,
                    "Full path name to the input TFLite model file.")
flags.DEFINE_string(
    "output_source_file", None,
    "Full path name to the output TFLite Micro model (C++ Source) file.")
flags.DEFINE_string("output_header_file", None,
                    "Full filepath of the output C header file.")

# Formatting of the generated code; each value is forwarded unchanged to
# util.convert_bytes_to_c_source in main().
flags.DEFINE_string("array_variable_name", None,
                    "Name to use for the C data array variable.")
flags.DEFINE_integer("line_width", 80, "Width to use for formatting.")
# Optional: left as None when not given on the command line.
flags.DEFINE_string("include_guard", None,
                    "Name to use for the C header include guard.")
flags.DEFINE_string("include_path", None,
                    "Optional path to include in generated source file.")
flags.DEFINE_boolean(
    "use_tensorflow_license", False,
    "Whether to prefix the generated files with the TF Apache2 license.")

# These flags default to None; marking them required makes absl flag parsing
# fail when they are omitted, instead of main() receiving None.
flags.mark_flag_as_required("input_tflite_file")
flags.mark_flag_as_required("output_source_file")
# Remaining mandatory flags. absl flag parsing rejects the command line if
# either one is left unset.
flags.mark_flags_as_required(["output_header_file", "array_variable_name"])


def main(_):
  """Converts the input .tflite file into generated C source and header files.

  The whole input file is read as bytes and passed to
  `util.convert_bytes_to_c_source` together with the formatting flags. The two
  returned strings are then written (text mode, overwriting any existing
  file) to --output_source_file and --output_header_file, in that order.

  Args:
    _: The argv list handed over by `app.run`; unused.
  """
  with open(FLAGS.input_tflite_file, "rb") as model_file:
    model_bytes = model_file.read()

  generated_source, generated_header = util.convert_bytes_to_c_source(
      data=model_bytes,
      array_name=FLAGS.array_variable_name,
      max_line_width=FLAGS.line_width,
      include_guard=FLAGS.include_guard,
      include_path=FLAGS.include_path,
      use_tensorflow_license=FLAGS.use_tensorflow_license)

  # Source first, then header -- same order as the flags are declared.
  outputs = (
      (FLAGS.output_source_file, generated_source),
      (FLAGS.output_header_file, generated_header),
  )
  for output_path, contents in outputs:
    with open(output_path, "w") as output_file:
      output_file.write(contents)


if __name__ == "__main__":
  app.run(main)