1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
17 #define TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
18 
19 #include <memory>
20 #include <string>
21 #include <vector>
22 
23 #include "tensorflow_lite_support/codegen/code_generator.h"
24 #include "tensorflow_lite_support/codegen/utils.h"
25 #include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
26 #include "tensorflow/lite/schema/schema_generated.h"
27 
28 namespace tflite {
29 namespace support {
30 namespace codegen {
31 
32 namespace details_android_java {
33 
/// The intermediate data structure for generating code from TensorMetadata.
/// Should only be used as const reference when created.
///
/// All scalar fields have default member initializers, so a default-constructed
/// TensorInfo never carries indeterminate values. The "optional" indices
/// default to -1, the documented "not applicable" sentinel.
struct TensorInfo {
  std::string name;
  /// Presumably `name` in UpperCamelCase form — confirm in the generator.
  std::string upper_camel_name;
  std::string content_type;
  /// Presumably the Java wrapper type (e.g. "TensorImage" / "TensorBuffer",
  /// as in ModelInfo's helper examples) — confirm in the generator.
  std::string wrapper_type;
  /// Presumably the Java processor type (e.g. "ImageProcessor" /
  /// "TensorProcessor") — confirm in the generator.
  std::string processor_type;
  /// Whether this tensor is a model input (presumably an output otherwise).
  bool is_input = false;
  /// Optional. Set to -1 if not applicable.
  int normalization_unit = -1;
  /// Optional. Set to -1 if associated_axis_label is empty.
  int associated_axis_label_index = -1;
  /// Optional. Set to -1 if associated_value_label is empty.
  int associated_value_label_index = -1;
};
50 
51 /// The intermediate data structure for generating code from ModelMetadata.
52 /// Should only be used as const reference when created.
53 struct ModelInfo {
54   std::string package_name;
55   std::string model_asset_path;
56   std::string model_class_name;
57   std::string model_versioned_name;
58   std::vector<TensorInfo> inputs;
59   std::vector<TensorInfo> outputs;
60   // Extra helper fields. For models with inputs "a", "b" and outputs "x", "y":
61   std::string input_type_param_list;
62   // e.g. "TensorImage a, TensorBuffer b"
63   std::string inputs_list;
64   // e.g. "a, b"
65   std::string postprocessor_type_param_list;
66   // e.g. "ImageProcessor xPostprocessor, TensorProcessor yPostprocessor"
67   std::string postprocessors_list;
68   // e.g. "xPostprocessor, yPostprocessor"
69 };
70 
71 }  // namespace details_android_java
72 
/// File extension appended to generated Java source files.
/// (`static` only spells out the internal linkage that a namespace-scope
/// constexpr variable already has, so every includer gets its own copy.)
static constexpr char JAVA_EXT[] = ".java";
74 
/// Generates Android support code and modules (in Java) based on TFLite
/// metadata.
class AndroidJavaGenerator : public CodeGenerator {
 public:
  /// Creates an AndroidJavaGenerator.
  /// Args:
  /// - module_root: The root of the destination Java module.
  explicit AndroidJavaGenerator(const std::string& module_root);

  /// Generates files. Returns the file paths and contents.
  /// Args:
  /// - model: The TFLite model with Metadata filled.
  /// - package_name: The name of the Java package which generated classes
  /// belong to.
  /// - model_class_name: A readable name of the generated wrapper class, such
  /// as "ImageClassifier", "MobileNetV2" or "MyModel".
  /// - model_asset_path: The relative path to the model file within the
  /// assets.
  /// NOTE(review): whether `model` may be null is not visible here — confirm.
  // TODO(b/141225157): Automatically generate model_class_name.
  GenerationResult Generate(const Model* model, const std::string& package_name,
                            const std::string& model_class_name,
                            const std::string& model_asset_path);

  /// Generates files and returns the file paths and contents.
  /// Identical to the overload above, except that the model is provided as
  /// raw binary flatbuffer content that has not been parsed yet.
  /// NOTE(review): no buffer length is passed, so `model_storage` presumably
  /// must point to a complete, valid model buffer — confirm.
  GenerationResult Generate(const char* model_storage,
                            const std::string& package_name,
                            const std::string& model_class_name,
                            const std::string& model_asset_path);

  /// Returns the error message for this generator.
  /// NOTE(review): presumably read from `err_`; whether reading it clears the
  /// stored message is not visible from this header — confirm.
  std::string GetErrorMessage();

 private:
  /// Root of the destination Java module (presumably the `module_root`
  /// constructor argument).
  const std::string module_root_;
  /// Presumably collects errors reported during generation.
  ErrorReporter err_;
};
111 
112 }  // namespace codegen
113 }  // namespace support
114 }  // namespace tflite
115 
116 #endif  // TENSORFLOW_LITE_SUPPORT_CODEGEN_ANDROID_JAVA_GENERATOR_H_
117