1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Generates Android Java sources from a TFLite model with metadata."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import shutil
23import sys
24from absl import app
25from absl import flags
26from absl import logging
27
28from tensorflow_lite_support.codegen.python import _pywrap_codegen
29
FLAGS = flags.FLAGS

# Command-line interface of the generator. --model and --destination are
# marked as required by main() before the flags are parsed.
flags.DEFINE_string(
    name='model',
    default=None,
    help='Path to model (.tflite) flatbuffer file.')
flags.DEFINE_string(
    name='destination',
    default=None,
    help='Path of destination of generation.')
flags.DEFINE_string(
    name='package_name',
    default='org.tensorflow.lite.support',
    help='Name of generated java package to put the wrapper class.')
flags.DEFINE_string(
    name='model_class_name',
    default='MyModel',
    help='Name of generated wrapper class (should not contain package name).')
flags.DEFINE_string(
    name='model_asset_path',
    default='',
    help=('(Optional) Path to the model in generated assets/ dir. If not set, '
          'generator will use base name of input model.'))
44
45
def get_model_buffer(path):
  """Returns the raw bytes of the model file located at `path`.

  An error is logged when `path` is not a regular file; the subsequent
  `open` call then raises the corresponding OSError.

  Args:
    path: Filesystem path of the .tflite model.

  Returns:
    The file contents as `bytes`.
  """
  is_regular_file = os.path.isfile(path)
  if not is_regular_file:
    logging.error('Cannot find model at path %s.', path)
  with open(path, 'rb') as model_file:
    content = model_file.read()
  return content
52
53
def prepare_directory_for_file(file_path):
  """Ensures the directory that will contain `file_path` exists.

  Missing parent directories are created. If the parent path exists but is
  not a directory, an error is logged and the function returns without
  raising; writing `file_path` afterwards will then fail.

  Args:
    file_path: Path of a file that is about to be written.
  """
  target_dir = os.path.dirname(file_path)
  if not target_dir:
    # A bare file name lives in the current working directory, which already
    # exists; os.makedirs('') would raise FileNotFoundError.
    return
  if not os.path.exists(target_dir):
    os.makedirs(target_dir)
    return
  if not os.path.isdir(target_dir):
    logging.error('Cannot write to %s', target_dir)
61
62
def run_main(argv):
  """Main function of the codegen.

  Generates the Android wrapper sources for the model given by --model,
  writes every generated file to disk, and copies the model into the
  `src/main/assets` directory under --destination.

  Args:
    argv: Arguments left over after absl flag parsing. argv[0] is the program
      name; any further entries are unexpected and only reported.

  Returns:
    None on success, or 1 when the generator produced no files. absl's
    `app.run` passes this value to `sys.exit`, so a failed generation yields
    a non-zero exit status.
  """
  if len(argv) > 1:
    # Positional arguments are not used by this tool; report them but go on.
    logging.error('Non-flag arguments found: [%s]', ', '.join(argv[1:]))

  codegen = _pywrap_codegen.AndroidJavaGenerator(FLAGS.destination)
  model_buffer = get_model_buffer(FLAGS.model)
  # Fall back to the input model's file name when no asset path was given.
  model_asset_path = FLAGS.model_asset_path or os.path.basename(FLAGS.model)
  result = codegen.generate(model_buffer, FLAGS.package_name,
                            FLAGS.model_class_name, model_asset_path)
  error_message = codegen.get_error_message().strip()
  if error_message:
    logging.error(error_message)
  if not result.files:
    logging.error('Generation failed!')
    return 1

  for each in result.files:
    prepare_directory_for_file(each.path)
    # Explicit UTF-8 so the output does not depend on the platform's locale
    # encoding if the generated content contains non-ASCII characters.
    with open(each.path, 'w', encoding='utf-8') as f:
      f.write(each.content)

  logging.info('Generation succeeded!')
  asset_file_path = os.path.join(FLAGS.destination, 'src', 'main', 'assets',
                                 model_asset_path)
  prepare_directory_for_file(asset_file_path)
  shutil.copy(FLAGS.model, asset_file_path)
  logging.info('Model copied into assets!')
94
95
96# Simple wrapper to make the code pip-friendly
def main():
  """Console entry point: declares the required flags, then runs the codegen.

  `--model` and `--destination` are marked required before absl parses
  `sys.argv`; control is then handed to `run_main`.
  """
  for required_flag in ('model', 'destination'):
    flags.mark_flag_as_required(required_flag)
  app.run(run_main, argv=sys.argv)
101
102
if __name__ == '__main__':
  # main() takes no arguments and calls app.run itself after marking the
  # required flags, so invoke it directly. app.run(main) would call
  # main(argv) and fail with a TypeError.
  main()
105