1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/no_micro_features_data.h"
17 #include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h"
18 #include "tensorflow/lite/experimental/micro/examples/micro_speech/micro_features/yes_micro_features_data.h"
19 #include "tensorflow/lite/experimental/micro/kernels/all_ops_resolver.h"
20 #include "tensorflow/lite/experimental/micro/micro_error_reporter.h"
21 #include "tensorflow/lite/experimental/micro/micro_interpreter.h"
22 #include "tensorflow/lite/experimental/micro/testing/micro_test.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24 #include "tensorflow/lite/version.h"
25
26 TF_LITE_MICRO_TESTS_BEGIN
27
TF_LITE_MICRO_TEST(TestInvoke)28 TF_LITE_MICRO_TEST(TestInvoke) {
29 // Set up logging.
30 tflite::MicroErrorReporter micro_error_reporter;
31 tflite::ErrorReporter* error_reporter = µ_error_reporter;
32
33 // Map the model into a usable data structure. This doesn't involve any
34 // copying or parsing, it's a very lightweight operation.
35 const tflite::Model* model =
36 ::tflite::GetModel(g_tiny_conv_micro_features_model_data);
37 if (model->version() != TFLITE_SCHEMA_VERSION) {
38 error_reporter->Report(
39 "Model provided is schema version %d not equal "
40 "to supported version %d.\n",
41 model->version(), TFLITE_SCHEMA_VERSION);
42 }
43
44 // This pulls in all the operation implementations we need.
45 tflite::ops::micro::AllOpsResolver resolver;
46
47 // Create an area of memory to use for input, output, and intermediate arrays.
48 const int tensor_arena_size = 10 * 1024;
49 uint8_t tensor_arena[tensor_arena_size];
50 tflite::SimpleTensorAllocator tensor_allocator(tensor_arena,
51 tensor_arena_size);
52
53 // Build an interpreter to run the model with.
54 tflite::MicroInterpreter interpreter(model, resolver, &tensor_allocator,
55 error_reporter);
56
57 // Get information about the memory area to use for the model's input.
58 TfLiteTensor* input = interpreter.input(0);
59
60 // Make sure the input has the properties we expect.
61 TF_LITE_MICRO_EXPECT_NE(nullptr, input);
62 TF_LITE_MICRO_EXPECT_EQ(4, input->dims->size);
63 TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
64 TF_LITE_MICRO_EXPECT_EQ(49, input->dims->data[1]);
65 TF_LITE_MICRO_EXPECT_EQ(40, input->dims->data[2]);
66 TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, input->type);
67
68 // Copy a spectrogram created from a .wav audio file of someone saying "Yes",
69 // into the memory area used for the input.
70 const uint8_t* yes_features_data = g_yes_micro_f2e59fea_nohash_1_data;
71 for (int i = 0; i < input->bytes; ++i) {
72 input->data.uint8[i] = yes_features_data[i];
73 }
74
75 // Run the model on this input and make sure it succeeds.
76 TfLiteStatus invoke_status = interpreter.Invoke();
77 if (invoke_status != kTfLiteOk) {
78 error_reporter->Report("Invoke failed\n");
79 }
80 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
81
82 // Get the output from the model, and make sure it's the expected size and
83 // type.
84 TfLiteTensor* output = interpreter.output(0);
85 TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
86 TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
87 TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]);
88 TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, output->type);
89
90 // There are four possible classes in the output, each with a score.
91 const int kSilenceIndex = 0;
92 const int kUnknownIndex = 1;
93 const int kYesIndex = 2;
94 const int kNoIndex = 3;
95
96 // Make sure that the expected "Yes" score is higher than the other classes.
97 uint8_t silence_score = output->data.uint8[kSilenceIndex];
98 uint8_t unknown_score = output->data.uint8[kUnknownIndex];
99 uint8_t yes_score = output->data.uint8[kYesIndex];
100 uint8_t no_score = output->data.uint8[kNoIndex];
101 TF_LITE_MICRO_EXPECT_GT(yes_score, silence_score);
102 TF_LITE_MICRO_EXPECT_GT(yes_score, unknown_score);
103 TF_LITE_MICRO_EXPECT_GT(yes_score, no_score);
104
105 // Now test with a different input, from a recording of "No".
106 const uint8_t* no_features_data = g_no_micro_f9643d42_nohash_4_data;
107 for (int i = 0; i < input->bytes; ++i) {
108 input->data.uint8[i] = no_features_data[i];
109 }
110
111 // Run the model on this "No" input.
112 invoke_status = interpreter.Invoke();
113 if (invoke_status != kTfLiteOk) {
114 error_reporter->Report("Invoke failed\n");
115 }
116 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
117
118 // Get the output from the model, and make sure it's the expected size and
119 // type.
120 output = interpreter.output(0);
121 TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
122 TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
123 TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]);
124 TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, output->type);
125
126 // Make sure that the expected "No" score is higher than the other classes.
127 silence_score = output->data.uint8[kSilenceIndex];
128 unknown_score = output->data.uint8[kUnknownIndex];
129 yes_score = output->data.uint8[kYesIndex];
130 no_score = output->data.uint8[kNoIndex];
131 TF_LITE_MICRO_EXPECT_GT(no_score, silence_score);
132 TF_LITE_MICRO_EXPECT_GT(no_score, unknown_score);
133 TF_LITE_MICRO_EXPECT_GT(no_score, yes_score);
134
135 error_reporter->Report("Ran successfully\n");
136 }
137
138 TF_LITE_MICRO_TESTS_END
139