# TFLite Task library - C++

A flexible and ready-to-use library for common machine learning model types,
such as classification and detection.

## Text Task Libraries

### QuestionAnswerer

The `QuestionAnswerer` API can load
[Mobile BERT](https://tfhub.dev/tensorflow/mobilebert/1) or
[AlBert](https://tfhub.dev/tensorflow/albert_lite_base/1) TFLite models and
answer questions based on a given context.

Use the C++ API to answer questions as follows:

```cc
using tflite::task::text::qa::BertQuestionAnswerer;
using tflite::task::text::qa::QaAnswer;
// Create API handler with Mobile Bert model.
auto qa_client = BertQuestionAnswerer::CreateBertQuestionAnswererFromFile("/path/to/mobileBertModel", "/path/to/vocab");
// Or create API handler with Albert model.
// auto qa_client = BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile("/path/to/alBertModel", "/path/to/sentencePieceModel");

std::string context =
    "Nikola Tesla (Serbian Cyrillic: Никола Тесла; 10 "
    "July 1856 – 7 January 1943) was a Serbian American inventor, electrical "
    "engineer, mechanical engineer, physicist, and futurist best known for his "
    "contributions to the design of the modern alternating current (AC) "
    "electricity supply system.";
std::string question = "When was Nikola Tesla born?";
// Run inference with `context` and a given `question` to the context, and get top-k
// answers ranked by logits.
const std::vector<QaAnswer> answers = qa_client->Answer(context, question);
// Access QaAnswer results.
for (const QaAnswer& item : answers) {
  std::cout << absl::StrFormat("Text: %s logit=%f start=%d end=%d", item.text,
                               item.pos.logit, item.pos.start, item.pos.end)
            << std::endl;
}
// Output:
// Text: 10 July 1856 logit=16.8527 start=17 end=19
// ... (and more)
//
// So the top-1 answer is: "10 July 1856".
```

In the above code, `item.text` is the text content of an answer.
We use a span
with closed interval `[item.pos.start, item.pos.end]` to denote the predicted
tokens in the answer, and `item.pos.logit` is the sum of the span logits, which
serves as the confidence score.

### NLClassifier

The `NLClassifier` API can load any TFLite model for natural language
classification tasks such as language detection or sentiment detection.

The API expects a TFLite model with the following input/output tensors:

Input tensor 0 (type: `kTfLiteString`):

 - input of the model, accepts a string.

Output tensor 0 (type: `kTfLiteUInt8` / `kTfLiteInt8` / `kTfLiteInt16` /
`kTfLiteFloat32` / `kTfLiteFloat64`):

 - output scores for each class. If the type is one of the Int types, the
   scores are dequantized to double.

Output tensor 1, optional (type: `kTfLiteString`):

 - output class name for each class; it must have the same length as the
   scores. If this tensor is not present, the API uses score indices as
   class names.

By default the API tries to find the input/output tensors with the default
configurations in `NLClassifierOptions`, with tensor name prioritized over
tensor index. The options are configurable for different TFLite models.

Use the C++ API to perform language ID classification as follows:

```cc
using tflite::task::text::nlclassifier::NLClassifier;
using tflite::task::core::Category;
auto classifier = NLClassifier::CreateFromFileAndOptions("/path/to/model");
// Or create a customized NLClassifierOptions
// NLClassifierOptions options =
//   {
//     .output_score_tensor_name = myOutputScoreTensorName,
//     .output_label_tensor_name = myOutputLabelTensorName,
//   };
// auto classifier = NLClassifier::CreateFromFileAndOptions("/path/to/model", options);
std::string context = "What language is this?";
std::vector<Category> categories = classifier->Classify(context);
// Access category results.
for (const Category& category : categories) {
  std::cout << absl::StrFormat("Language: %s Probability: %f", category.class_name, category.score)
            << std::endl;
}
// Output:
// Language: en Probability: 0.900000
// ... (and more)
//
// So the top-1 answer is 'en'.
```

## Vision Task Libraries

### Image Classifier

`ImageClassifier` accepts any TFLite image classification model (with optional,
but strongly recommended, TFLite Model Metadata) that conforms to the following
spec:

Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):

 - image input of size `[batch x height x width x channels]`.
 - batch inference is not supported (`batch` is required to be 1).
 - only RGB inputs are supported (`channels` is required to be 3).
 - if type is `kTfLiteFloat32`, `NormalizationOptions` are required to be
   attached to the metadata for input normalization.

At least one output tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`) with:

 - `N` classes and either 2 or 4 dimensions, i.e. `[1 x N]` or
   `[1 x 1 x 1 x N]`
 - optional (but recommended) label map(s) as AssociatedFile-s with type
   TENSOR_AXIS_LABELS, containing one label per line. The first such
   AssociatedFile (if any) is used to fill the `class_name` field of the
   results. The `display_name` field is filled from the AssociatedFile (if
   any) whose locale matches the `display_names_locale` field of the
   `ImageClassifierOptions` used at creation time ("en" by default, i.e.
   English). If none of these are available, only the `index` field of the
   results will be filled.

An example of such a model can be found at:
https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1

Example usage:

```cc
// More options are available (e.g. max number of results to return).
// At the very least, the model must be specified:
ImageClassifierOptions options;
options.mutable_model_file_with_metadata()->set_file_name(
    "/path/to/model.tflite");

// Create an ImageClassifier instance from the options.
StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or =
    ImageClassifier::CreateFromOptions(options);
// Check if an error occurred.
if (!image_classifier_or.ok()) {
  std::cerr << "An error occurred during ImageClassifier creation: "
            << image_classifier_or.status().message();
  return;
}
std::unique_ptr<ImageClassifier> image_classifier =
    std::move(image_classifier_or.value());

// Prepare FrameBuffer input from e.g. image RGBA data, width and height:
std::unique_ptr<FrameBuffer> frame_buffer =
    CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height});

// Run inference:
StatusOr<ClassificationResult> result_or =
    image_classifier->Classify(*frame_buffer);
// Check if an error occurred.
if (!result_or.ok()) {
  std::cerr << "An error occurred during classification: "
            << result_or.status().message();
  return;
}
ClassificationResult result = result_or.value();

// Example value for 'result':
//
// classifications {
//   classes { index: 934 score: 0.95 class_name: "cat" }
//   classes { index: 948 score: 0.007 class_name: "dog" }
//   classes { index: 927 score: 0.003 class_name: "fox" }
//   head_index: 0
// }
```

A CLI demo tool is also available [here][1] for easily trying out this API.

### Object Detector

`ObjectDetector` accepts any object detection TFLite model (with mandatory
TFLite Model Metadata) that conforms to the following spec (e.g. Single Shot
Detectors):

Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):

 - image input of size `[batch x height x width x channels]`.
 - batch inference is not supported (`batch` is required to be 1).
 - only RGB inputs are supported (`channels` is required to be 3).
 - if type is `kTfLiteFloat32`, `NormalizationOptions` are required to be
   attached to the metadata for input normalization.

Output tensors must be the 4 outputs (type: `kTfLiteFloat32`) of a
[`DetectionPostProcess`][2] op, i.e.:

* Locations:

  - of size `[num_results x 4]`, the inner array
    representing bounding boxes in the form [top, left, right, bottom].
  - BoundingBoxProperties are required to be attached to the metadata
    and must specify type=BOUNDARIES and coordinate_type=RATIO.

* Classes:

  - of size `[num_results]`, each value representing the
    integer index of a class.
  - optional (but recommended) label map(s) can be attached as
    AssociatedFile-s with type TENSOR_VALUE_LABELS, containing one label per
    line. The first such AssociatedFile (if any) is used to fill the
    `class_name` field of the results. The `display_name` field is filled
    from the AssociatedFile (if any) whose locale matches the
    `display_names_locale` field of the `ObjectDetectorOptions` used at
    creation time ("en" by default, i.e. English). If none of these are
    available, only the `index` field of the results will be filled.

* Scores:

  - of size `[num_results]`, each value representing the score
    of the detected object.

* Number of results:

  - integer `num_results` as a tensor of size `[1]`

An example of such a model can be found at:
https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1

Example usage:

```cc
// More options are available (e.g. max number of results to return). At the
// very least, the model must be specified:
ObjectDetectorOptions options;
options.mutable_model_file_with_metadata()->set_file_name(
    "/path/to/model.tflite");

// Create an ObjectDetector instance from the options.
StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
    ObjectDetector::CreateFromOptions(options);
// Check if an error occurred.
if (!object_detector_or.ok()) {
  std::cerr << "An error occurred during ObjectDetector creation: "
            << object_detector_or.status().message();
  return;
}
std::unique_ptr<ObjectDetector> object_detector =
    std::move(object_detector_or.value());

// Prepare FrameBuffer input from e.g. image RGBA data, width and height:
std::unique_ptr<FrameBuffer> frame_buffer =
    CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height});

// Run inference:
StatusOr<DetectionResult> result_or = object_detector->Detect(*frame_buffer);
// Check if an error occurred.
if (!result_or.ok()) {
  std::cerr << "An error occurred during detection: "
            << result_or.status().message();
  return;
}
DetectionResult result = result_or.value();

// Example value for 'result':
//
// detections {
//   bounding_box {
//     origin_x: 54
//     origin_y: 398
//     width: 393
//     height: 196
//   }
//   classes { index: 16 score: 0.65 class_name: "cat" }
// }
// detections {
//   bounding_box {
//     origin_x: 602
//     origin_y: 157
//     width: 394
//     height: 447
//   }
//   classes { index: 17 score: 0.45 class_name: "dog" }
// }
```

A CLI demo tool is available [here][3] for easily trying out this API.

### Image Segmenter

`ImageSegmenter` accepts any TFLite model (with optional, but strongly
recommended, TFLite Model Metadata) that conforms to the following spec:

Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):

 - image input of size `[batch x height x width x channels]`.
 - batch inference is not supported (`batch` is required to be 1).
 - only RGB inputs are supported (`channels` is required to be 3).
 - if type is `kTfLiteFloat32`, `NormalizationOptions` are required to be
   attached to the metadata for input normalization.

Output tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):

 - tensor of size `[batch x mask_height x mask_width x num_classes]`, where
   `batch` is required to be 1, `mask_width` and `mask_height` are the
   dimensions of the segmentation masks produced by the model, and
   `num_classes` is the number of classes supported by the model.
 - optional (but recommended) label map(s) can be attached as
   AssociatedFile-s with type TENSOR_AXIS_LABELS, containing one label per
   line. The first such AssociatedFile (if any) is used to fill the
   `class_name` field of the results. The `display_name` field is filled
   from the AssociatedFile (if any) whose locale matches the
   `display_names_locale` field of the `ImageSegmenterOptions` used at
   creation time ("en" by default, i.e. English). If none of these are
   available, only the `index` field of the results will be filled.

An example of such a model can be found at:
https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1

Example usage:

```cc
// More options are available to select between returning a single category
// mask or multiple confidence masks during post-processing.
ImageSegmenterOptions options;
options.mutable_model_file_with_metadata()->set_file_name(
    "/path/to/model.tflite");

// Create an ImageSegmenter instance from the options.
StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
    ImageSegmenter::CreateFromOptions(options);
// Check if an error occurred.
if (!image_segmenter_or.ok()) {
  std::cerr << "An error occurred during ImageSegmenter creation: "
            << image_segmenter_or.status().message();
  return;
}
std::unique_ptr<ImageSegmenter> image_segmenter =
    std::move(image_segmenter_or.value());

// Prepare FrameBuffer input from e.g. image RGBA data, width and height:
std::unique_ptr<FrameBuffer> frame_buffer =
    CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height});

// Run inference:
StatusOr<SegmentationResult> result_or =
    image_segmenter->Segment(*frame_buffer);
// Check if an error occurred.
if (!result_or.ok()) {
  std::cerr << "An error occurred during segmentation: "
            << result_or.status().message();
  return;
}
SegmentationResult result = result_or.value();

// Example value for 'result':
//
// segmentation {
//   width: 257
//   height: 257
//   category_mask: "\x00\x01..."
//   colored_labels { r: 0 g: 0 b: 0 class_name: "background" }
//   colored_labels { r: 128 g: 0 b: 0 class_name: "aeroplane" }
//   ...
//   colored_labels { r: 128 g: 192 b: 0 class_name: "train" }
//   colored_labels { r: 0 g: 64 b: 128 class_name: "tv" }
// }
//
// Where 'category_mask' is a byte buffer of size 'width' x 'height', with the
// value of each pixel representing the class this pixel belongs to (e.g. '\x00'
// means "background", '\x01' means "aeroplane", etc).
// 'colored_labels' provides the label for each possible value, as well as
// suggested RGB components to optionally transform the result into a more
// human-friendly colored image.
//
```

A CLI demo tool is available [here][4] for easily trying out this API.

[1]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
[2]: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
[3]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
[4]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc
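
The Image Segmenter example describes `category_mask` as a byte buffer of size
`width` x `height` whose values index into `colored_labels`. The following is a
minimal sketch of how such a mask could be expanded into an RGB image. It is
not part of the Task library: `ColoredLabel` and `ColorizeCategoryMask` are
hypothetical names introduced here for illustration, and the mask is assumed to
be stored in row-major order with one byte per pixel.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-in for one `colored_labels` entry (not the library type).
struct ColoredLabel {
  std::uint8_t r;
  std::uint8_t g;
  std::uint8_t b;
  std::string class_name;
};

// Expands a category mask (one class index per pixel) into an interleaved
// RGB buffer of size width * height * 3, using the color of each pixel's class.
std::vector<std::uint8_t> ColorizeCategoryMask(
    const std::string& category_mask, int width, int height,
    const std::vector<ColoredLabel>& colored_labels) {
  if (width < 0 || height < 0 ||
      category_mask.size() !=
          static_cast<std::size_t>(width) * static_cast<std::size_t>(height)) {
    throw std::invalid_argument("category_mask size must be width * height");
  }
  std::vector<std::uint8_t> rgb;
  rgb.reserve(category_mask.size() * 3);
  for (const char raw : category_mask) {
    // Interpret each byte as an unsigned class index.
    const std::size_t index = static_cast<unsigned char>(raw);
    if (index >= colored_labels.size()) {
      throw std::out_of_range("class index has no matching colored label");
    }
    const ColoredLabel& label = colored_labels[index];
    rgb.push_back(label.r);
    rgb.push_back(label.g);
    rgb.push_back(label.b);
  }
  assert(rgb.size() == category_mask.size() * 3);
  return rgb;
}
```

With the real API, the inputs would presumably come from the `width`,
`height`, `category_mask` and `colored_labels` fields of a segmentation result;
copying them into the hypothetical `ColoredLabel` struct is left to the caller.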