1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Generates Android Java sources from a TFLite model with metadata.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import shutil 23import sys 24from absl import app 25from absl import flags 26from absl import logging 27 28from tensorflow_lite_support.codegen.python import _pywrap_codegen 29 30FLAGS = flags.FLAGS 31 32flags.DEFINE_string('model', None, 'Path to model (.tflite) flatbuffer file.') 33flags.DEFINE_string('destination', None, 'Path of destination of generation.') 34flags.DEFINE_string('package_name', 'org.tensorflow.lite.support', 35 'Name of generated java package to put the wrapper class.') 36flags.DEFINE_string( 37 'model_class_name', 'MyModel', 38 'Name of generated wrapper class (should not contain package name).') 39flags.DEFINE_string( 40 'model_asset_path', '', 41 '(Optional) Path to the model in generated assets/ dir. If not set, ' 42 'generator will use base name of input model.' 43) 44 45 46def get_model_buffer(path): 47 if not os.path.isfile(path): 48 logging.error('Cannot find model at path %s.', path) 49 with open(path, 'rb') as f: 50 buf = f.read() 51 return buf 52 53 54def prepare_directory_for_file(file_path): 55 target_dir = os.path.dirname(file_path) 56 if not os.path.exists(target_dir): 57 os.makedirs(target_dir) 58 return 59 if not os.path.isdir(target_dir): 60 logging.error('Cannot write to %s', target_dir) 61 62 63def run_main(argv): 64 """Main function of the codegen.""" 65 66 if len(argv) > 1: 67 logging.error('None flag arguments found: [%s]', ', '.join(argv[1:])) 68 69 codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination) 70 model_buffer = get_model_buffer(FLAGS.model) 71 model_asset_path = FLAGS.model_asset_path 72 if not model_asset_path: 73 model_asset_path = os.path.basename(FLAGS.model) 74 result = codegen.generate(model_buffer, FLAGS.package_name, 75 FLAGS.model_class_name, model_asset_path) 76 error_message = codegen.get_error_message().strip() 77 if error_message: 78 logging.error(error_message) 79 if not result.files: 80 logging.error('Generation failed!') 81 return 82 83 for each in result.files: 84 prepare_directory_for_file(each.path) 85 with open(each.path, 'w') as f: 86 f.write(each.content) 87 88 logging.info('Generation succeeded!') 89 model_asset_path = os.path.join(FLAGS.destination, 'src/main/assets', 90 model_asset_path) 91 prepare_directory_for_file(model_asset_path) 92 shutil.copy(FLAGS.model, model_asset_path) 93 logging.info('Model copied into assets!') 94 95 96# Simple wrapper to make the code pip-friendly 97def main(): 98 flags.mark_flag_as_required('model') 99 flags.mark_flag_as_required('destination') 100 app.run(main=run_main, argv=sys.argv) 101 102 103if __name__ == '__main__': 104 app.run(main) 105