1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for lite.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22import tempfile 23import numpy as np 24 25from tensorflow.lite.python import lite 26from tensorflow.lite.python import lite_constants 27from tensorflow.lite.python.interpreter import Interpreter 28from tensorflow.python import keras 29from tensorflow.python.client import session 30from tensorflow.python.framework import constant_op 31from tensorflow.python.framework import dtypes 32from tensorflow.python.framework import test_util 33from tensorflow.python.ops import array_ops 34from tensorflow.python.ops import math_ops 35from tensorflow.python.ops import nn_ops 36from tensorflow.python.ops import variable_scope 37from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer 38from tensorflow.python.platform import gfile 39from tensorflow.python.platform import resource_loader 40from tensorflow.python.platform import test 41from tensorflow.python.saved_model import saved_model 42from tensorflow.python.training.training_util import write_graph 43 44 45class FromConstructor(test_util.TensorFlowTestCase): 46 47 # Tests invalid constructors using a dummy value for the GraphDef. 48 def testInvalidConstructor(self): 49 message = ('If input_tensors and output_tensors are None, both ' 50 'input_arrays_with_shape and output_arrays must be defined.') 51 52 # `output_arrays` is not defined. 53 with self.assertRaises(ValueError) as error: 54 lite.TFLiteConverter( 55 None, None, [], input_arrays_with_shape=[('input', [3, 9])]) 56 self.assertEqual(message, str(error.exception)) 57 58 # `input_arrays_with_shape` is not defined. 59 with self.assertRaises(ValueError) as error: 60 lite.TFLiteConverter(None, [], None, output_arrays=['output']) 61 self.assertEqual(message, str(error.exception)) 62 63 # Tests valid constructors using a dummy value for the GraphDef. 64 def testValidConstructor(self): 65 converter = lite.TFLiteConverter( 66 None, 67 None, 68 None, 69 input_arrays_with_shape=[('input', [3, 9])], 70 output_arrays=['output']) 71 self.assertFalse(converter._has_valid_tensors()) 72 self.assertEqual(converter.get_input_arrays(), ['input']) 73 74 with self.assertRaises(ValueError) as error: 75 converter._set_batch_size(1) 76 self.assertEqual( 77 'The batch size cannot be set for this model. Please use ' 78 'input_shapes parameter.', str(error.exception)) 79 80 converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor']) 81 self.assertTrue(converter._has_valid_tensors()) 82 83 84@test_util.run_v1_only('b/120545219') 85class FromSessionTest(test_util.TensorFlowTestCase): 86 87 def testFloat(self): 88 in_tensor = array_ops.placeholder( 89 shape=[1, 16, 16, 3], dtype=dtypes.float32) 90 out_tensor = in_tensor + in_tensor 91 sess = session.Session() 92 93 # Convert model and ensure model is not None. 94 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 95 [out_tensor]) 96 tflite_model = converter.convert() 97 self.assertTrue(tflite_model) 98 99 # Check values from converted model. 100 interpreter = Interpreter(model_content=tflite_model) 101 interpreter.allocate_tensors() 102 103 input_details = interpreter.get_input_details() 104 self.assertEqual(1, len(input_details)) 105 self.assertEqual('Placeholder', input_details[0]['name']) 106 self.assertEqual(np.float32, input_details[0]['dtype']) 107 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 108 self.assertEqual((0., 0.), input_details[0]['quantization']) 109 110 output_details = interpreter.get_output_details() 111 self.assertEqual(1, len(output_details)) 112 self.assertEqual('add', output_details[0]['name']) 113 self.assertEqual(np.float32, output_details[0]['dtype']) 114 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 115 self.assertEqual((0., 0.), output_details[0]['quantization']) 116 117 def testString(self): 118 in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string) 119 out_tensor = array_ops.reshape(in_tensor, shape=[2, 2]) 120 sess = session.Session() 121 122 # Convert model and ensure model is not None. 123 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 124 [out_tensor]) 125 tflite_model = converter.convert() 126 self.assertTrue(tflite_model) 127 128 # Check values from converted model. 129 interpreter = Interpreter(model_content=tflite_model) 130 interpreter.allocate_tensors() 131 132 input_details = interpreter.get_input_details() 133 self.assertEqual(1, len(input_details)) 134 self.assertEqual('Placeholder', input_details[0]['name']) 135 self.assertEqual(np.string_, input_details[0]['dtype']) 136 self.assertTrue(([4] == input_details[0]['shape']).all()) 137 138 output_details = interpreter.get_output_details() 139 self.assertEqual(1, len(output_details)) 140 self.assertEqual('Reshape', output_details[0]['name']) 141 self.assertEqual(np.string_, output_details[0]['dtype']) 142 self.assertTrue(([2, 2] == output_details[0]['shape']).all()) 143 # TODO(b/122659643): Test setting/getting string data via the python 144 # interpreter API after support has been added. 145 146 def testQuantization(self): 147 in_tensor_1 = array_ops.placeholder( 148 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') 149 in_tensor_2 = array_ops.placeholder( 150 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') 151 out_tensor = array_ops.fake_quant_with_min_max_args( 152 in_tensor_1 + in_tensor_2, min=0., max=1., name='output') 153 sess = session.Session() 154 155 # Convert model and ensure model is not None. 156 converter = lite.TFLiteConverter.from_session( 157 sess, [in_tensor_1, in_tensor_2], [out_tensor]) 158 converter.inference_type = lite_constants.QUANTIZED_UINT8 159 converter.quantized_input_stats = { 160 'inputA': (0., 1.), 161 'inputB': (0., 1.) 162 } # mean, std_dev 163 tflite_model = converter.convert() 164 self.assertTrue(tflite_model) 165 166 # Check values from converted model. 167 interpreter = Interpreter(model_content=tflite_model) 168 interpreter.allocate_tensors() 169 170 input_details = interpreter.get_input_details() 171 self.assertEqual(2, len(input_details)) 172 self.assertEqual('inputA', input_details[0]['name']) 173 self.assertEqual(np.uint8, input_details[0]['dtype']) 174 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 175 self.assertEqual((1., 0.), 176 input_details[0]['quantization']) # scale, zero_point 177 178 self.assertEqual('inputB', input_details[1]['name']) 179 self.assertEqual(np.uint8, input_details[1]['dtype']) 180 self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) 181 self.assertEqual((1., 0.), 182 input_details[1]['quantization']) # scale, zero_point 183 184 output_details = interpreter.get_output_details() 185 self.assertEqual(1, len(output_details)) 186 self.assertEqual('output', output_details[0]['name']) 187 self.assertEqual(np.uint8, output_details[0]['dtype']) 188 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 189 self.assertTrue(output_details[0]['quantization'][0] > 0) # scale 190 191 def testQuantizationInvalid(self): 192 in_tensor_1 = array_ops.placeholder( 193 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') 194 in_tensor_2 = array_ops.placeholder( 195 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') 196 out_tensor = array_ops.fake_quant_with_min_max_args( 197 in_tensor_1 + in_tensor_2, min=0., max=1., name='output') 198 sess = session.Session() 199 200 # Convert model and ensure model is not None. 201 converter = lite.TFLiteConverter.from_session( 202 sess, [in_tensor_1, in_tensor_2], [out_tensor]) 203 converter.inference_type = lite_constants.QUANTIZED_UINT8 204 converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev 205 with self.assertRaises(ValueError) as error: 206 converter.convert() 207 self.assertEqual( 208 'Quantization input stats are not available for input tensors ' 209 '\'inputB\'.', str(error.exception)) 210 211 def testIntermediateInputArray(self): 212 """Convert a model from an intermediate input array.""" 213 in_tensor_init = array_ops.placeholder( 214 shape=[1, 16, 16, 3], dtype=dtypes.float32) 215 in_tensor_final = in_tensor_init + in_tensor_init 216 out_tensor = in_tensor_final + in_tensor_final 217 sess = session.Session() 218 219 # Convert model and ensure model is not None. 220 converter = lite.TFLiteConverter.from_session(sess, [in_tensor_final], 221 [out_tensor]) 222 tflite_model = converter.convert() 223 self.assertTrue(tflite_model) 224 225 # Check values from converted model. 226 interpreter = Interpreter(model_content=tflite_model) 227 interpreter.allocate_tensors() 228 229 input_details = interpreter.get_input_details() 230 self.assertEqual(1, len(input_details)) 231 self.assertEqual('add', input_details[0]['name']) 232 self.assertEqual(np.float32, input_details[0]['dtype']) 233 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 234 self.assertEqual((0., 0.), input_details[0]['quantization']) 235 236 output_details = interpreter.get_output_details() 237 self.assertEqual(1, len(output_details)) 238 self.assertEqual('add_1', output_details[0]['name']) 239 self.assertEqual(np.float32, output_details[0]['dtype']) 240 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 241 self.assertEqual((0., 0.), output_details[0]['quantization']) 242 243 def testSizeNoneInvalid(self): 244 in_tensor = array_ops.placeholder(dtype=dtypes.float32) 245 out_tensor = in_tensor + in_tensor 246 sess = session.Session() 247 248 # Test None as shape. 249 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 250 [out_tensor]) 251 with self.assertRaises(ValueError) as error: 252 converter.convert() 253 self.assertEqual('Provide an input shape for input array \'Placeholder\'.', 254 str(error.exception)) 255 256 def testScalarValid(self): 257 # Construct a graph using a scalar (empty shape) input. 258 in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[]) 259 out_tensor = in_tensor + in_tensor 260 sess = session.Session() 261 262 # Test conversion with the scalar input shape. 263 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 264 [out_tensor]) 265 tflite_model = converter.convert() 266 self.assertTrue(tflite_model) 267 268 # Check values from converted model. 269 interpreter = Interpreter(model_content=tflite_model) 270 interpreter.allocate_tensors() 271 272 input_details = interpreter.get_input_details() 273 self.assertEqual(1, len(input_details)) 274 self.assertEqual('Placeholder', input_details[0]['name']) 275 self.assertEqual(np.float32, input_details[0]['dtype']) 276 self.assertTrue(([] == input_details[0]['shape']).all()) 277 278 output_details = interpreter.get_output_details() 279 self.assertEqual(1, len(output_details)) 280 self.assertEqual('add', output_details[0]['name']) 281 self.assertEqual(np.float32, output_details[0]['dtype']) 282 self.assertTrue(([] == input_details[0]['shape']).all()) 283 284 # Validate inference using the scalar inputs/outputs. 285 test_input = np.array(4.0, dtype=np.float32) 286 expected_output = np.array(8.0, dtype=np.float32) 287 interpreter.set_tensor(input_details[0]['index'], test_input) 288 interpreter.invoke() 289 290 output_data = interpreter.get_tensor(output_details[0]['index']) 291 self.assertTrue((expected_output == output_data).all()) 292 293 def testSizeInvalid(self): 294 in_tensor = array_ops.placeholder( 295 shape=[1, None, 16, 3], dtype=dtypes.float32) 296 out_tensor = in_tensor + in_tensor 297 sess = session.Session() 298 299 # Test invalid shape. None after 1st dimension. 300 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 301 [out_tensor]) 302 with self.assertRaises(ValueError) as error: 303 converter.convert() 304 self.assertEqual( 305 'None is only supported in the 1st dimension. Tensor ' 306 '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.', 307 str(error.exception)) 308 309 def testBatchSizeValid(self): 310 in_tensor = array_ops.placeholder( 311 shape=[None, 16, 16, 3], dtype=dtypes.float32) 312 out_tensor = in_tensor + in_tensor 313 sess = session.Session() 314 315 # Convert model and ensure model is not None. 316 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 317 [out_tensor]) 318 tflite_model = converter.convert() 319 self.assertTrue(tflite_model) 320 321 # Check values from converted model. 322 interpreter = Interpreter(model_content=tflite_model) 323 interpreter.allocate_tensors() 324 325 input_details = interpreter.get_input_details() 326 self.assertEqual(1, len(input_details)) 327 self.assertEqual('Placeholder', input_details[0]['name']) 328 self.assertEqual(np.float32, input_details[0]['dtype']) 329 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 330 self.assertEqual((0., 0.), input_details[0]['quantization']) 331 332 output_details = interpreter.get_output_details() 333 self.assertEqual(1, len(output_details)) 334 self.assertEqual('add', output_details[0]['name']) 335 self.assertEqual(np.float32, output_details[0]['dtype']) 336 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 337 self.assertEqual((0., 0.), output_details[0]['quantization']) 338 339 def testFreezeGraph(self): 340 in_tensor = array_ops.placeholder( 341 shape=[1, 16, 16, 3], dtype=dtypes.float32) 342 var = variable_scope.get_variable( 343 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) 344 out_tensor = in_tensor + var 345 sess = session.Session() 346 sess.run(_global_variables_initializer()) 347 348 # Convert model and ensure model is not None. 349 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 350 [out_tensor]) 351 tflite_model = converter.convert() 352 self.assertTrue(tflite_model) 353 354 # Check values from converted model. 355 interpreter = Interpreter(model_content=tflite_model) 356 interpreter.allocate_tensors() 357 358 input_details = interpreter.get_input_details() 359 self.assertEqual(1, len(input_details)) 360 self.assertEqual('Placeholder', input_details[0]['name']) 361 self.assertEqual(np.float32, input_details[0]['dtype']) 362 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 363 self.assertEqual((0., 0.), input_details[0]['quantization']) 364 365 output_details = interpreter.get_output_details() 366 self.assertEqual(1, len(output_details)) 367 self.assertEqual('add', output_details[0]['name']) 368 self.assertEqual(np.float32, output_details[0]['dtype']) 369 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 370 self.assertEqual((0., 0.), output_details[0]['quantization']) 371 372 # TODO(nupurgarg): Verify value of contents in GraphViz. 373 def testGraphviz(self): 374 in_tensor = array_ops.placeholder( 375 shape=[1, 16, 16, 3], dtype=dtypes.float32) 376 out_tensor = in_tensor + in_tensor 377 sess = session.Session() 378 379 # Convert model and ensure model is not None. 380 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 381 [out_tensor]) 382 converter.output_format = lite_constants.GRAPHVIZ_DOT 383 graphviz_output = converter.convert() 384 self.assertTrue(graphviz_output) 385 386 # TODO(nupurgarg): Verify value of contents in GraphViz. 387 def testDumpGraphviz(self): 388 in_tensor = array_ops.placeholder( 389 shape=[1, 16, 16, 3], dtype=dtypes.float32) 390 out_tensor = in_tensor + in_tensor 391 sess = session.Session() 392 393 # Convert model and ensure model is not None. 394 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 395 [out_tensor]) 396 graphviz_dir = self.get_temp_dir() 397 converter.dump_graphviz_dir = graphviz_dir 398 tflite_model = converter.convert() 399 self.assertTrue(tflite_model) 400 401 # Ensure interpreter is able to allocate and check graphviz data. 402 interpreter = Interpreter(model_content=tflite_model) 403 interpreter.allocate_tensors() 404 405 num_items_graphviz = len(os.listdir(graphviz_dir)) 406 self.assertTrue(num_items_graphviz) 407 408 # Convert model and ensure model is not None. 409 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 410 [out_tensor]) 411 graphviz_dir = self.get_temp_dir() 412 converter.dump_graphviz_dir = graphviz_dir 413 converter.dump_graphviz_video = True 414 tflite_model = converter.convert() 415 self.assertTrue(tflite_model) 416 417 # Ensure graphviz folder has more data after using video flag. 418 num_items_graphviz_video = len(os.listdir(graphviz_dir)) 419 self.assertTrue(num_items_graphviz_video > num_items_graphviz) 420 421 def testInferenceInputType(self): 422 in_tensor = array_ops.placeholder( 423 shape=[1, 16, 16, 3], dtype=dtypes.float32) 424 out_tensor = in_tensor + in_tensor 425 sess = session.Session() 426 427 # Convert model and ensure model is not None. 428 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 429 [out_tensor]) 430 converter.inference_input_type = lite_constants.QUANTIZED_UINT8 431 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev 432 tflite_model = converter.convert() 433 self.assertTrue(tflite_model) 434 435 # Check values from converted model. 436 interpreter = Interpreter(model_content=tflite_model) 437 interpreter.allocate_tensors() 438 439 input_details = interpreter.get_input_details() 440 self.assertEqual(1, len(input_details)) 441 self.assertEqual('Placeholder', input_details[0]['name']) 442 self.assertEqual(np.uint8, input_details[0]['dtype']) 443 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 444 self.assertEqual((1., 0.), input_details[0]['quantization']) 445 446 output_details = interpreter.get_output_details() 447 self.assertEqual(1, len(output_details)) 448 self.assertEqual('add', output_details[0]['name']) 449 self.assertEqual(np.float32, output_details[0]['dtype']) 450 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 451 452 def testDefaultRangesStats(self): 453 in_tensor = array_ops.placeholder( 454 shape=[1, 16, 16, 3], dtype=dtypes.float32) 455 out_tensor = in_tensor + in_tensor 456 sess = session.Session() 457 458 # Convert model and ensure model is not None. 459 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 460 [out_tensor]) 461 converter.inference_type = lite_constants.QUANTIZED_UINT8 462 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev 463 converter.default_ranges_stats = (0, 6) # min, max 464 tflite_model = converter.convert() 465 self.assertTrue(tflite_model) 466 467 # Check values from converted model. 468 interpreter = Interpreter(model_content=tflite_model) 469 interpreter.allocate_tensors() 470 471 input_details = interpreter.get_input_details() 472 self.assertEqual(1, len(input_details)) 473 self.assertEqual('Placeholder', input_details[0]['name']) 474 self.assertEqual(np.uint8, input_details[0]['dtype']) 475 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 476 self.assertEqual((1., 0.), input_details[0]['quantization']) 477 478 output_details = interpreter.get_output_details() 479 self.assertEqual(1, len(output_details)) 480 self.assertEqual('add', output_details[0]['name']) 481 self.assertEqual(np.uint8, output_details[0]['dtype']) 482 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 483 self.assertTrue(output_details[0]['quantization'][0] > 0) # scale 484 485 def testPostTrainingQuantizeDeprecatedAttribute(self): 486 in_tensor_1 = array_ops.placeholder( 487 shape=[33, 33], dtype=dtypes.float32, name='inputA') 488 in_tensor_2 = constant_op.constant( 489 np.random.uniform(low=-10., high=10., size=(33, 33)), 490 shape=[33, 33], 491 dtype=dtypes.float32, 492 name='inputB') 493 out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') 494 sess = session.Session() 495 496 quantized_converter = lite.TFLiteConverter.from_session( 497 sess, [in_tensor_1], [out_tensor]) 498 self.assertFalse(quantized_converter.post_training_quantize) 499 500 quantized_converter.post_training_quantize = True 501 self.assertTrue(quantized_converter.post_training_quantize) 502 self.assertEqual(quantized_converter.optimizations, 503 [lite.Optimize.OPTIMIZE_FOR_SIZE]) 504 505 quantized_tflite = quantized_converter.convert() 506 self.assertTrue(quantized_tflite) 507 508 def testPostTrainingQuantize(self): 509 np.random.seed(0) 510 # We need the tensor to have more than 1024 elements for quantize_weights 511 # to kick in. Thus, the [33, 33] shape. 512 in_tensor_1 = array_ops.placeholder( 513 shape=[33, 33], dtype=dtypes.float32, name='inputA') 514 in_tensor_2 = constant_op.constant( 515 np.random.uniform(low=-10., high=10., size=(33, 33)), 516 shape=[33, 33], 517 dtype=dtypes.float32, 518 name='inputB') 519 out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') 520 sess = session.Session() 521 522 # Convert float model. 523 float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1], 524 [out_tensor]) 525 float_tflite = float_converter.convert() 526 self.assertTrue(float_tflite) 527 528 # Convert quantized weights model. 529 quantized_converter = lite.TFLiteConverter.from_session( 530 sess, [in_tensor_1], [out_tensor]) 531 quantized_converter.optimizations = [lite.Optimize.OPTIMIZE_FOR_SIZE] 532 quantized_tflite = quantized_converter.convert() 533 self.assertTrue(quantized_tflite) 534 535 # Ensure that the quantized weights tflite model is smaller. 536 self.assertTrue(len(quantized_tflite) < len(float_tflite)) 537 538 def testPostTrainingCalibrateAndQuantize(self): 539 np.random.seed(0) 540 # Create a mobilenet like model. 541 output_channel = 16 542 depth_multiplier = 1 543 inp = array_ops.placeholder(dtype=dtypes.float32, shape=(1, 5, 5, 3)) 544 conv = nn_ops.conv2d( 545 inp, 546 filter=array_ops.zeros([3, 3, 3, output_channel]), 547 strides=[1, 1, 1, 1], 548 padding='SAME') 549 dconv = nn_ops.depthwise_conv2d_native( 550 conv, 551 filter=array_ops.zeros( 552 [16, 16, output_channel, output_channel * depth_multiplier]), 553 strides=[1, 1, 1, 1], 554 padding='SAME') 555 pool = nn_ops.pool( 556 dconv, window_shape=[2, 2], pooling_type='AVG', padding='SAME') 557 max_pool = nn_ops.pool( 558 pool, window_shape=[2, 2], pooling_type='MAX', padding='SAME') 559 output = nn_ops.softmax(max_pool) 560 561 def calibration_gen(): 562 for _ in range(10): 563 yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)] 564 565 sess = session.Session() 566 567 # Convert float model. 568 float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 569 float_tflite = float_converter.convert() 570 self.assertTrue(float_tflite) 571 572 # Convert quantized weights model. 573 quantized_converter = lite.TFLiteConverter.from_session( 574 sess, [inp], [output]) 575 quantized_converter.optimizations = [lite.Optimize.OPTIMIZE_FOR_SIZE] 576 quantized_converter.representative_dataset = lite.RepresentativeDataset( 577 calibration_gen) 578 quantized_tflite = quantized_converter.convert() 579 self.assertTrue(quantized_tflite) 580 581 # Ensure that the quantized weights tflite model is smaller. 582 self.assertTrue(len(quantized_tflite) < len(float_tflite)) 583 584 def testFloatTocoConverter(self): 585 """Tests deprecated test TocoConverter.""" 586 in_tensor = array_ops.placeholder( 587 shape=[1, 16, 16, 3], dtype=dtypes.float32) 588 out_tensor = in_tensor + in_tensor 589 sess = session.Session() 590 591 # Convert model and ensure model is not None. 592 converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) 593 tflite_model = converter.convert() 594 self.assertTrue(tflite_model) 595 596 # Ensure the interpreter is able to load. 597 interpreter = Interpreter(model_content=tflite_model) 598 interpreter.allocate_tensors() 599 600 def testMultipleOutputNodeNames(self): 601 """Tests converting a graph with an op that have multiple outputs.""" 602 input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32) 603 out0, out1, out2, out3 = array_ops.split(input_tensor, [1, 1, 1, 1], axis=0) 604 sess = session.Session() 605 606 # Convert model and ensure model is not None. 607 converter = lite.TFLiteConverter.from_session(sess, [input_tensor], 608 [out0, out1, out2, out3]) 609 tflite_model = converter.convert() 610 self.assertTrue(tflite_model) 611 612 # Check values from converted model. 613 interpreter = Interpreter(model_content=tflite_model) 614 interpreter.allocate_tensors() 615 616 input_details = interpreter.get_input_details() 617 self.assertEqual(1, len(input_details)) 618 interpreter.set_tensor(input_details[0]['index'], 619 np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) 620 interpreter.invoke() 621 622 output_details = interpreter.get_output_details() 623 self.assertEqual(4, len(output_details)) 624 self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index'])) 625 self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index'])) 626 self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index'])) 627 self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index'])) 628 629 630@test_util.run_v1_only('b/120545219') 631class FromFrozenGraphFile(test_util.TensorFlowTestCase): 632 633 def testFloat(self): 634 in_tensor = array_ops.placeholder( 635 shape=[1, 16, 16, 3], dtype=dtypes.float32) 636 _ = in_tensor + in_tensor 637 sess = session.Session() 638 639 # Write graph to file. 640 graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') 641 write_graph(sess.graph_def, '', graph_def_file, False) 642 sess.close() 643 644 # Convert model and ensure model is not None. 645 converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file, 646 ['Placeholder'], ['add']) 647 tflite_model = converter.convert() 648 self.assertTrue(tflite_model) 649 650 # Check values from converted model. 651 interpreter = Interpreter(model_content=tflite_model) 652 interpreter.allocate_tensors() 653 654 input_details = interpreter.get_input_details() 655 self.assertEqual(1, len(input_details)) 656 self.assertEqual('Placeholder', input_details[0]['name']) 657 self.assertEqual(np.float32, input_details[0]['dtype']) 658 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 659 self.assertEqual((0., 0.), input_details[0]['quantization']) 660 661 output_details = interpreter.get_output_details() 662 self.assertEqual(1, len(output_details)) 663 self.assertEqual('add', output_details[0]['name']) 664 self.assertEqual(np.float32, output_details[0]['dtype']) 665 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 666 self.assertEqual((0., 0.), output_details[0]['quantization']) 667 668 def testFloatWithShapesArray(self): 669 in_tensor = array_ops.placeholder( 670 shape=[1, 16, 16, 3], dtype=dtypes.float32) 671 _ = in_tensor + in_tensor 672 sess = session.Session() 673 674 # Write graph to file. 675 graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') 676 write_graph(sess.graph_def, '', graph_def_file, False) 677 sess.close() 678 679 # Convert model and ensure model is not None. 680 converter = lite.TFLiteConverter.from_frozen_graph( 681 graph_def_file, ['Placeholder'], ['add'], 682 input_shapes={'Placeholder': [1, 16, 16, 3]}) 683 tflite_model = converter.convert() 684 self.assertTrue(tflite_model) 685 686 # Check values from converted model. 687 interpreter = Interpreter(model_content=tflite_model) 688 interpreter.allocate_tensors() 689 690 input_details = interpreter.get_input_details() 691 self.assertEqual(1, len(input_details)) 692 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 693 694 def testFreezeGraph(self): 695 in_tensor = array_ops.placeholder( 696 shape=[1, 16, 16, 3], dtype=dtypes.float32) 697 var = variable_scope.get_variable( 698 'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32) 699 _ = in_tensor + var 700 sess = session.Session() 701 702 # Write graph to file. 703 graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') 704 write_graph(sess.graph_def, '', graph_def_file, False) 705 sess.close() 706 707 # Ensure the graph with variables cannot be converted. 708 with self.assertRaises(ValueError) as error: 709 lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'], 710 ['add']) 711 self.assertEqual('Please freeze the graph using freeze_graph.py.', 712 str(error.exception)) 713 714 def testPbtxt(self): 715 in_tensor = array_ops.placeholder( 716 shape=[1, 16, 16, 3], dtype=dtypes.float32) 717 _ = in_tensor + in_tensor 718 sess = session.Session() 719 720 # Write graph to file. 721 graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt') 722 write_graph(sess.graph_def, '', graph_def_file, True) 723 sess.close() 724 725 # Convert model and ensure model is not None. 726 converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file, 727 ['Placeholder'], ['add']) 728 tflite_model = converter.convert() 729 self.assertTrue(tflite_model) 730 731 # Check values from converted model. 732 interpreter = Interpreter(model_content=tflite_model) 733 interpreter.allocate_tensors() 734 735 input_details = interpreter.get_input_details() 736 self.assertEqual(1, len(input_details)) 737 self.assertEqual('Placeholder', input_details[0]['name']) 738 self.assertEqual(np.float32, input_details[0]['dtype']) 739 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 740 self.assertEqual((0., 0.), input_details[0]['quantization']) 741 742 output_details = interpreter.get_output_details() 743 self.assertEqual(1, len(output_details)) 744 self.assertEqual('add', output_details[0]['name']) 745 self.assertEqual(np.float32, output_details[0]['dtype']) 746 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 747 self.assertEqual((0., 0.), output_details[0]['quantization']) 748 749 def testInvalidFileNotFound(self): 750 with self.assertRaises(IOError) as error: 751 lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'], 752 ['add']) 753 self.assertEqual('File \'invalid_file\' does not exist.', 754 str(error.exception)) 755 756 def testInvalidFileBadData(self): 757 graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file') 758 with gfile.Open(graph_def_file, 'wb') as temp_file: 759 temp_file.write('bad data') 760 temp_file.flush() 761 762 # Attempts to convert the invalid model. 763 with self.assertRaises(IOError) as error: 764 lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'], 765 ['add']) 766 self.assertEqual( 767 'Unable to parse input file \'{}\'.'.format(graph_def_file), 768 str(error.exception)) 769 770 # TODO(nupurgarg): Test model loading in open source. 771 def _initObjectDetectionArgs(self): 772 # Initializes the arguments required for the object detection model. 773 # Looks for the model file which is saved in a different location internally 774 # and externally. 775 filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb') 776 if not os.path.exists(filename): 777 filename = os.path.join( 778 resource_loader.get_root_dir_with_all_resources(), 779 '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb') 780 if not os.path.exists(filename): 781 raise IOError("File '{0}' does not exist.".format(filename)) 782 783 self._graph_def_file = filename 784 self._input_arrays = ['normalized_input_image_tensor'] 785 self._output_arrays = [ 786 'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1', 787 'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3' 788 ] 789 self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]} 790 791 def testTFLiteGraphDef(self): 792 # Tests the object detection model that cannot be loaded in TensorFlow. 793 self._initObjectDetectionArgs() 794 795 converter = lite.TFLiteConverter.from_frozen_graph( 796 self._graph_def_file, self._input_arrays, self._output_arrays, 797 self._input_shapes) 798 converter.allow_custom_ops = True 799 tflite_model = converter.convert() 800 self.assertTrue(tflite_model) 801 802 # Check values from converted model. 803 interpreter = Interpreter(model_content=tflite_model) 804 interpreter.allocate_tensors() 805 806 input_details = interpreter.get_input_details() 807 self.assertEqual(1, len(input_details)) 808 self.assertEqual('normalized_input_image_tensor', input_details[0]['name']) 809 self.assertEqual(np.float32, input_details[0]['dtype']) 810 self.assertTrue(([1, 300, 300, 3] == input_details[0]['shape']).all()) 811 self.assertEqual((0., 0.), input_details[0]['quantization']) 812 813 output_details = interpreter.get_output_details() 814 self.assertEqual(4, len(output_details)) 815 self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name']) 816 self.assertEqual(np.float32, output_details[0]['dtype']) 817 self.assertTrue(([1, 10, 4] == output_details[0]['shape']).all()) 818 self.assertEqual((0., 0.), output_details[0]['quantization']) 819 820 self.assertEqual('TFLite_Detection_PostProcess:1', 821 output_details[1]['name']) 822 self.assertTrue(([1, 10] == output_details[1]['shape']).all()) 823 self.assertEqual('TFLite_Detection_PostProcess:2', 824 output_details[2]['name']) 825 self.assertTrue(([1, 10] == output_details[2]['shape']).all()) 826 self.assertEqual('TFLite_Detection_PostProcess:3', 827 output_details[3]['name']) 828 self.assertTrue(([1] == output_details[3]['shape']).all()) 829 830 def testTFLiteGraphDefMissingShape(self): 831 # Tests invalid cases for the model that cannot be loaded in TensorFlow. 832 self._initObjectDetectionArgs() 833 834 # Missing `input_shapes`. 835 with self.assertRaises(ValueError) as error: 836 lite.TFLiteConverter.from_frozen_graph( 837 self._graph_def_file, self._input_arrays, self._output_arrays) 838 self.assertEqual('input_shapes must be defined for this model.', 839 str(error.exception)) 840 841 def testTFLiteGraphDefInvalidShape(self): 842 # Tests invalid cases for the model that cannot be loaded in TensorFlow. 843 self._initObjectDetectionArgs() 844 845 # `input_shapes` does not contain the names in `input_arrays`. 846 with self.assertRaises(ValueError) as error: 847 lite.TFLiteConverter.from_frozen_graph( 848 self._graph_def_file, 849 self._input_arrays, 850 self._output_arrays, 851 input_shapes={'invalid-value': [1, 19]}) 852 self.assertEqual( 853 'input_shapes must contain a value for each item in input_array.', 854 str(error.exception)) 855 856 def testFloatTocoConverter(self): 857 in_tensor = array_ops.placeholder( 858 shape=[1, 16, 16, 3], dtype=dtypes.float32) 859 _ = in_tensor + in_tensor 860 sess = session.Session() 861 862 # Write graph to file. 863 graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') 864 write_graph(sess.graph_def, '', graph_def_file, False) 865 sess.close() 866 867 # Convert model and ensure model is not None. 868 converter = lite.TocoConverter.from_frozen_graph(graph_def_file, 869 ['Placeholder'], ['add']) 870 tflite_model = converter.convert() 871 self.assertTrue(tflite_model) 872 873 # Ensure the model is able to load. 874 interpreter = Interpreter(model_content=tflite_model) 875 interpreter.allocate_tensors() 876 877 878@test_util.run_v1_only('b/120545219') 879class FromSavedModelTest(test_util.TensorFlowTestCase): 880 881 def _createSavedModel(self, shape): 882 """Create a simple SavedModel.""" 883 saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel') 884 with session.Session() as sess: 885 in_tensor_1 = array_ops.placeholder( 886 shape=shape, dtype=dtypes.float32, name='inputB') 887 in_tensor_2 = array_ops.placeholder( 888 shape=shape, dtype=dtypes.float32, name='inputA') 889 out_tensor = in_tensor_1 + in_tensor_2 890 inputs = {'x': in_tensor_1, 'y': in_tensor_2} 891 outputs = {'z': out_tensor} 892 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 893 return saved_model_dir 894 895 def testSimpleModel(self): 896 """Test a SavedModel.""" 897 saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) 898 899 # Convert model and ensure model is not None. 900 converter = lite.TFLiteConverter.from_saved_model(saved_model_dir) 901 tflite_model = converter.convert() 902 self.assertTrue(tflite_model) 903 904 interpreter = Interpreter(model_content=tflite_model) 905 interpreter.allocate_tensors() 906 907 input_details = interpreter.get_input_details() 908 self.assertEqual(2, len(input_details)) 909 self.assertEqual('inputA', input_details[0]['name']) 910 self.assertEqual(np.float32, input_details[0]['dtype']) 911 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 912 self.assertEqual((0., 0.), input_details[0]['quantization']) 913 914 self.assertEqual('inputB', input_details[1]['name']) 915 self.assertEqual(np.float32, input_details[1]['dtype']) 916 self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) 917 self.assertEqual((0., 0.), input_details[1]['quantization']) 918 919 output_details = interpreter.get_output_details() 920 self.assertEqual(1, len(output_details)) 921 self.assertEqual('add', output_details[0]['name']) 922 self.assertEqual(np.float32, output_details[0]['dtype']) 923 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 924 self.assertEqual((0., 0.), output_details[0]['quantization']) 925 926 def testNoneBatchSize(self): 927 """Test a SavedModel, with None in input tensor's shape.""" 928 saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3]) 929 930 converter = lite.TFLiteConverter.from_saved_model(saved_model_dir) 931 tflite_model = converter.convert() 932 self.assertTrue(tflite_model) 933 934 # Check values from converted model. 935 interpreter = Interpreter(model_content=tflite_model) 936 interpreter.allocate_tensors() 937 938 input_details = interpreter.get_input_details() 939 self.assertEqual(2, len(input_details)) 940 self.assertEqual('inputA', input_details[0]['name']) 941 self.assertEqual(np.float32, input_details[0]['dtype']) 942 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 943 self.assertEqual((0., 0.), input_details[0]['quantization']) 944 945 self.assertEqual('inputB', input_details[1]['name']) 946 self.assertEqual(np.float32, input_details[1]['dtype']) 947 self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) 948 self.assertEqual((0., 0.), input_details[1]['quantization']) 949 950 output_details = interpreter.get_output_details() 951 self.assertEqual(1, len(output_details)) 952 self.assertEqual('add', output_details[0]['name']) 953 self.assertEqual(np.float32, output_details[0]['dtype']) 954 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 955 self.assertEqual((0., 0.), output_details[0]['quantization']) 956 957 def testOrderInputArrays(self): 958 """Test a SavedModel ordering of input arrays.""" 959 saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) 960 961 converter = lite.TFLiteConverter.from_saved_model( 962 saved_model_dir, input_arrays=['inputB', 'inputA']) 963 tflite_model = converter.convert() 964 self.assertTrue(tflite_model) 965 966 # Check values from converted model. 967 interpreter = Interpreter(model_content=tflite_model) 968 interpreter.allocate_tensors() 969 970 input_details = interpreter.get_input_details() 971 self.assertEqual(2, len(input_details)) 972 self.assertEqual('inputA', input_details[0]['name']) 973 self.assertEqual(np.float32, input_details[0]['dtype']) 974 self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all()) 975 self.assertEqual((0., 0.), input_details[0]['quantization']) 976 977 self.assertEqual('inputB', input_details[1]['name']) 978 self.assertEqual(np.float32, input_details[1]['dtype']) 979 self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all()) 980 self.assertEqual((0., 0.), input_details[1]['quantization']) 981 982 output_details = interpreter.get_output_details() 983 self.assertEqual(1, len(output_details)) 984 self.assertEqual('add', output_details[0]['name']) 985 self.assertEqual(np.float32, output_details[0]['dtype']) 986 self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all()) 987 self.assertEqual((0., 0.), output_details[0]['quantization']) 988 989 def testSubsetInputArrays(self): 990 """Test a SavedModel with a subset of the input array names of the model.""" 991 saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) 992 993 # Check case where input shape is given. 994 converter = lite.TFLiteConverter.from_saved_model( 995 saved_model_dir, 996 input_arrays=['inputA'], 997 input_shapes={'inputA': [1, 16, 16, 3]}) 998 999 tflite_model = converter.convert() 1000 self.assertTrue(tflite_model) 1001 1002 # Check case where input shape is None. 1003 converter = lite.TFLiteConverter.from_saved_model( 1004 saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None}) 1005 1006 tflite_model = converter.convert() 1007 self.assertTrue(tflite_model) 1008 1009 def testSimpleModelTocoConverter(self): 1010 """Test a SavedModel with deprecated TocoConverter.""" 1011 saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) 1012 1013 # Convert model and ensure model is not None. 1014 converter = lite.TocoConverter.from_saved_model(saved_model_dir) 1015 tflite_model = converter.convert() 1016 self.assertTrue(tflite_model) 1017 1018 # Ensure the model is able to load. 1019 interpreter = Interpreter(model_content=tflite_model) 1020 interpreter.allocate_tensors() 1021 1022 1023@test_util.run_v1_only('b/120545219') 1024class FromKerasFile(test_util.TensorFlowTestCase): 1025 1026 def setUp(self): 1027 keras.backend.clear_session() 1028 1029 def _getSequentialModel(self): 1030 with session.Session().as_default(): 1031 model = keras.models.Sequential() 1032 model.add(keras.layers.Dense(2, input_shape=(3,))) 1033 model.add(keras.layers.RepeatVector(3)) 1034 model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) 1035 model.compile( 1036 loss=keras.losses.MSE, 1037 optimizer=keras.optimizers.RMSprop(), 1038 metrics=[keras.metrics.categorical_accuracy], 1039 sample_weight_mode='temporal') 1040 x = np.random.random((1, 3)) 1041 y = np.random.random((1, 3, 3)) 1042 model.train_on_batch(x, y) 1043 model.predict(x) 1044 1045 try: 1046 fd, keras_file = tempfile.mkstemp('.h5') 1047 keras.models.save_model(model, keras_file) 1048 finally: 1049 os.close(fd) 1050 return keras_file 1051 1052 def testSequentialModel(self): 1053 """Test a Sequential tf.keras model with default inputs.""" 1054 keras_file = self._getSequentialModel() 1055 1056 converter = lite.TFLiteConverter.from_keras_model_file(keras_file) 1057 tflite_model = converter.convert() 1058 self.assertTrue(tflite_model) 1059 1060 # Check tensor details of converted model. 1061 interpreter = Interpreter(model_content=tflite_model) 1062 interpreter.allocate_tensors() 1063 1064 input_details = interpreter.get_input_details() 1065 self.assertEqual(1, len(input_details)) 1066 self.assertEqual('dense_input', input_details[0]['name']) 1067 self.assertEqual(np.float32, input_details[0]['dtype']) 1068 self.assertTrue(([1, 3] == input_details[0]['shape']).all()) 1069 self.assertEqual((0., 0.), input_details[0]['quantization']) 1070 1071 output_details = interpreter.get_output_details() 1072 self.assertEqual(1, len(output_details)) 1073 self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) 1074 self.assertEqual(np.float32, output_details[0]['dtype']) 1075 self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) 1076 self.assertEqual((0., 0.), output_details[0]['quantization']) 1077 1078 # Check inference of converted model. 1079 input_data = np.array([[1, 2, 3]], dtype=np.float32) 1080 interpreter.set_tensor(input_details[0]['index'], input_data) 1081 interpreter.invoke() 1082 tflite_result = interpreter.get_tensor(output_details[0]['index']) 1083 1084 keras_model = keras.models.load_model(keras_file) 1085 keras_result = keras_model.predict(input_data) 1086 1087 np.testing.assert_almost_equal(tflite_result, keras_result, 5) 1088 os.remove(keras_file) 1089 1090 def testSequentialModelInputArray(self): 1091 """Test a Sequential tf.keras model testing input arrays argument.""" 1092 keras_file = self._getSequentialModel() 1093 1094 # Invalid input array raises error. 1095 with self.assertRaises(ValueError) as error: 1096 lite.TFLiteConverter.from_keras_model_file( 1097 keras_file, input_arrays=['invalid-input']) 1098 self.assertEqual("Invalid tensors 'invalid-input' were found.", 1099 str(error.exception)) 1100 1101 # Valid input array. 1102 converter = lite.TFLiteConverter.from_keras_model_file( 1103 keras_file, input_arrays=['dense_input']) 1104 tflite_model = converter.convert() 1105 os.remove(keras_file) 1106 self.assertTrue(tflite_model) 1107 1108 def testSequentialModelInputShape(self): 1109 """Test a Sequential tf.keras model testing input shapes argument.""" 1110 keras_file = self._getSequentialModel() 1111 1112 # Passing in shape of invalid input array raises error. 1113 with self.assertRaises(ValueError) as error: 1114 converter = lite.TFLiteConverter.from_keras_model_file( 1115 keras_file, input_shapes={'invalid-input': [2, 3]}) 1116 self.assertEqual( 1117 "Invalid tensor 'invalid-input' found in tensor shapes map.", 1118 str(error.exception)) 1119 1120 # Passing in shape of valid input array. 1121 converter = lite.TFLiteConverter.from_keras_model_file( 1122 keras_file, input_shapes={'dense_input': [2, 3]}) 1123 tflite_model = converter.convert() 1124 os.remove(keras_file) 1125 self.assertTrue(tflite_model) 1126 1127 # Check input shape from converted model. 1128 interpreter = Interpreter(model_content=tflite_model) 1129 interpreter.allocate_tensors() 1130 1131 input_details = interpreter.get_input_details() 1132 self.assertEqual(1, len(input_details)) 1133 self.assertEqual('dense_input', input_details[0]['name']) 1134 self.assertTrue(([2, 3] == input_details[0]['shape']).all()) 1135 1136 def testSequentialModelOutputArray(self): 1137 """Test a Sequential tf.keras model testing output arrays argument.""" 1138 keras_file = self._getSequentialModel() 1139 1140 # Invalid output array raises error. 1141 with self.assertRaises(ValueError) as error: 1142 lite.TFLiteConverter.from_keras_model_file( 1143 keras_file, output_arrays=['invalid-output']) 1144 self.assertEqual("Invalid tensors 'invalid-output' were found.", 1145 str(error.exception)) 1146 1147 # Valid output array. 1148 converter = lite.TFLiteConverter.from_keras_model_file( 1149 keras_file, output_arrays=['time_distributed/Reshape_1']) 1150 tflite_model = converter.convert() 1151 os.remove(keras_file) 1152 self.assertTrue(tflite_model) 1153 1154 def testFunctionalModel(self): 1155 """Test a Functional tf.keras model with default inputs.""" 1156 with session.Session().as_default(): 1157 inputs = keras.layers.Input(shape=(3,), name='input') 1158 x = keras.layers.Dense(2)(inputs) 1159 output = keras.layers.Dense(3)(x) 1160 1161 model = keras.models.Model(inputs, output) 1162 model.compile( 1163 loss=keras.losses.MSE, 1164 optimizer=keras.optimizers.RMSprop(), 1165 metrics=[keras.metrics.categorical_accuracy]) 1166 x = np.random.random((1, 3)) 1167 y = np.random.random((1, 3)) 1168 model.train_on_batch(x, y) 1169 1170 model.predict(x) 1171 fd, keras_file = tempfile.mkstemp('.h5') 1172 try: 1173 keras.models.save_model(model, keras_file) 1174 finally: 1175 os.close(fd) 1176 1177 # Convert to TFLite model. 1178 converter = lite.TFLiteConverter.from_keras_model_file(keras_file) 1179 tflite_model = converter.convert() 1180 self.assertTrue(tflite_model) 1181 1182 # Check tensor details of converted model. 1183 interpreter = Interpreter(model_content=tflite_model) 1184 interpreter.allocate_tensors() 1185 1186 input_details = interpreter.get_input_details() 1187 self.assertEqual(1, len(input_details)) 1188 self.assertEqual('input', input_details[0]['name']) 1189 self.assertEqual(np.float32, input_details[0]['dtype']) 1190 self.assertTrue(([1, 3] == input_details[0]['shape']).all()) 1191 self.assertEqual((0., 0.), input_details[0]['quantization']) 1192 1193 output_details = interpreter.get_output_details() 1194 self.assertEqual(1, len(output_details)) 1195 self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) 1196 self.assertEqual(np.float32, output_details[0]['dtype']) 1197 self.assertTrue(([1, 3] == output_details[0]['shape']).all()) 1198 self.assertEqual((0., 0.), output_details[0]['quantization']) 1199 1200 # Check inference of converted model. 1201 input_data = np.array([[1, 2, 3]], dtype=np.float32) 1202 interpreter.set_tensor(input_details[0]['index'], input_data) 1203 interpreter.invoke() 1204 tflite_result = interpreter.get_tensor(output_details[0]['index']) 1205 1206 keras_model = keras.models.load_model(keras_file) 1207 keras_result = keras_model.predict(input_data) 1208 1209 np.testing.assert_almost_equal(tflite_result, keras_result, 5) 1210 os.remove(keras_file) 1211 1212 def testFunctionalModelMultipleInputs(self): 1213 """Test a Functional tf.keras model with multiple inputs and outputs.""" 1214 with session.Session().as_default(): 1215 a = keras.layers.Input(shape=(3,), name='input_a') 1216 b = keras.layers.Input(shape=(3,), name='input_b') 1217 dense = keras.layers.Dense(4, name='dense') 1218 c = dense(a) 1219 d = dense(b) 1220 e = keras.layers.Dropout(0.5, name='dropout')(c) 1221 1222 model = keras.models.Model([a, b], [d, e]) 1223 model.compile( 1224 loss=keras.losses.MSE, 1225 optimizer=keras.optimizers.RMSprop(), 1226 metrics=[keras.metrics.mae], 1227 loss_weights=[1., 0.5]) 1228 1229 input_a_np = np.random.random((10, 3)) 1230 input_b_np = np.random.random((10, 3)) 1231 output_d_np = np.random.random((10, 4)) 1232 output_e_np = np.random.random((10, 4)) 1233 model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) 1234 1235 model.predict([input_a_np, input_b_np], batch_size=5) 1236 fd, keras_file = tempfile.mkstemp('.h5') 1237 try: 1238 keras.models.save_model(model, keras_file) 1239 finally: 1240 os.close(fd) 1241 1242 # Convert to TFLite model. 1243 converter = lite.TFLiteConverter.from_keras_model_file(keras_file) 1244 tflite_model = converter.convert() 1245 self.assertTrue(tflite_model) 1246 1247 os.remove(keras_file) 1248 1249 # Check values from converted model. 1250 interpreter = Interpreter(model_content=tflite_model) 1251 interpreter.allocate_tensors() 1252 1253 input_details = interpreter.get_input_details() 1254 self.assertEqual(2, len(input_details)) 1255 self.assertEqual('input_a', input_details[0]['name']) 1256 self.assertEqual(np.float32, input_details[0]['dtype']) 1257 self.assertTrue(([1, 3] == input_details[0]['shape']).all()) 1258 self.assertEqual((0., 0.), input_details[0]['quantization']) 1259 1260 self.assertEqual('input_b', input_details[1]['name']) 1261 self.assertEqual(np.float32, input_details[1]['dtype']) 1262 self.assertTrue(([1, 3] == input_details[1]['shape']).all()) 1263 self.assertEqual((0., 0.), input_details[1]['quantization']) 1264 1265 output_details = interpreter.get_output_details() 1266 self.assertEqual(2, len(output_details)) 1267 self.assertEqual('dense_1/BiasAdd', output_details[0]['name']) 1268 self.assertEqual(np.float32, output_details[0]['dtype']) 1269 self.assertTrue(([1, 4] == output_details[0]['shape']).all()) 1270 self.assertEqual((0., 0.), output_details[0]['quantization']) 1271 1272 self.assertEqual('dropout/Identity', output_details[1]['name']) 1273 self.assertEqual(np.float32, output_details[1]['dtype']) 1274 self.assertTrue(([1, 4] == output_details[1]['shape']).all()) 1275 self.assertEqual((0., 0.), output_details[1]['quantization']) 1276 1277 def testFunctionalSequentialModel(self): 1278 """Test a Functional tf.keras model containing a Sequential model.""" 1279 with session.Session().as_default(): 1280 model = keras.models.Sequential() 1281 model.add(keras.layers.Dense(2, input_shape=(3,))) 1282 model.add(keras.layers.RepeatVector(3)) 1283 model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) 1284 model = keras.models.Model(model.input, model.output) 1285 1286 model.compile( 1287 loss=keras.losses.MSE, 1288 optimizer=keras.optimizers.RMSprop(), 1289 metrics=[keras.metrics.categorical_accuracy], 1290 sample_weight_mode='temporal') 1291 x = np.random.random((1, 3)) 1292 y = np.random.random((1, 3, 3)) 1293 model.train_on_batch(x, y) 1294 model.predict(x) 1295 1296 model.predict(x) 1297 fd, keras_file = tempfile.mkstemp('.h5') 1298 try: 1299 keras.models.save_model(model, keras_file) 1300 finally: 1301 os.close(fd) 1302 1303 # Convert to TFLite model. 1304 converter = lite.TFLiteConverter.from_keras_model_file(keras_file) 1305 tflite_model = converter.convert() 1306 self.assertTrue(tflite_model) 1307 1308 # Check tensor details of converted model. 1309 interpreter = Interpreter(model_content=tflite_model) 1310 interpreter.allocate_tensors() 1311 1312 input_details = interpreter.get_input_details() 1313 self.assertEqual(1, len(input_details)) 1314 self.assertEqual('dense_input', input_details[0]['name']) 1315 self.assertEqual(np.float32, input_details[0]['dtype']) 1316 self.assertTrue(([1, 3] == input_details[0]['shape']).all()) 1317 self.assertEqual((0., 0.), input_details[0]['quantization']) 1318 1319 output_details = interpreter.get_output_details() 1320 self.assertEqual(1, len(output_details)) 1321 self.assertEqual('time_distributed/Reshape_1', output_details[0]['name']) 1322 self.assertEqual(np.float32, output_details[0]['dtype']) 1323 self.assertTrue(([1, 3, 3] == output_details[0]['shape']).all()) 1324 self.assertEqual((0., 0.), output_details[0]['quantization']) 1325 1326 # Check inference of converted model. 1327 input_data = np.array([[1, 2, 3]], dtype=np.float32) 1328 interpreter.set_tensor(input_details[0]['index'], input_data) 1329 interpreter.invoke() 1330 tflite_result = interpreter.get_tensor(output_details[0]['index']) 1331 1332 keras_model = keras.models.load_model(keras_file) 1333 keras_result = keras_model.predict(input_data) 1334 1335 np.testing.assert_almost_equal(tflite_result, keras_result, 5) 1336 os.remove(keras_file) 1337 1338 def testSequentialModelTocoConverter(self): 1339 """Test a Sequential tf.keras model with deprecated TocoConverter.""" 1340 keras_file = self._getSequentialModel() 1341 1342 converter = lite.TocoConverter.from_keras_model_file(keras_file) 1343 tflite_model = converter.convert() 1344 self.assertTrue(tflite_model) 1345 1346 # Ensure the model is able to load. 1347 interpreter = Interpreter(model_content=tflite_model) 1348 interpreter.allocate_tensors() 1349 1350 1351if __name__ == '__main__': 1352 test.main() 1353