1# Lint as: python2, python3 2# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Tests for lite.py.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import io 23import logging 24import os 25import tempfile 26 27from absl.testing import parameterized 28import numpy as np 29import six 30from six.moves import range 31from tensorflow import keras 32 33from tensorflow.lite.python import lite 34from tensorflow.lite.python import lite_constants 35from tensorflow.lite.python.convert import ConverterError 36from tensorflow.lite.python.convert import mlir_quantize 37from tensorflow.lite.python.interpreter import Interpreter 38from tensorflow.python.client import session 39from tensorflow.python.eager import context 40from tensorflow.python.eager import def_function 41from tensorflow.python.framework import constant_op 42from tensorflow.python.framework import convert_to_constants 43from tensorflow.python.framework import dtypes 44from tensorflow.python.framework import ops 45from tensorflow.python.framework import test_util 46from tensorflow.python.ops import array_ops 47from tensorflow.python.ops import gen_array_ops 48from tensorflow.python.ops import math_ops 49from tensorflow.python.ops import nn_ops 50from tensorflow.python.ops import random_ops 
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.training_util import write_graph


class LiteTest(test_util.TensorFlowTestCase):
  """Base class of all the tests in this module."""


class TestModels(LiteTest):

  def assertValidDebugInfo(self, debug_info):
    """Asserts that `debug_info` records this test file.

    Only file base names are compared, and only the presence of this file's
    name (and the absence of the v2 test's name) is checked, so the assertion
    does not depend on how the individual nodes were created.

    Args:
      debug_info: object exposing a `files` iterable of file paths.
    """
    recorded_names = {os.path.basename(path) for path in debug_info.files}
    self.assertIn('lite_test.py', recorded_names)
    self.assertNotIn('lite_v2_test.py', recorded_names)


class FromConstructor(TestModels):

  # Tests invalid constructors using a dummy value for the GraphDef.
  def testInvalidConstructor(self):
    expected_message = ('If input_tensors and output_tensors are None, both '
                        'input_arrays_with_shape and output_arrays must be '
                        'defined.')

    # Each callable omits one of the two array arguments that are required
    # when no tensors are given; both must fail with the same message.
    invalid_constructions = (
        # `output_arrays` is not defined.
        lambda: lite.TFLiteConverter(
            None, None, [], input_arrays_with_shape=[('input', [3, 9])]),
        # `input_arrays_with_shape` is not defined.
        lambda: lite.TFLiteConverter(None, [], None, output_arrays=['output']),
    )
    for construct in invalid_constructions:
      with self.assertRaises(ValueError) as raised:
        construct()
      self.assertEqual(expected_message, str(raised.exception))

  # Tests valid constructors using a dummy value for the GraphDef.
97 def testValidConstructor(self): 98 converter = lite.TFLiteConverter( 99 None, 100 None, 101 None, 102 input_arrays_with_shape=[('input', [3, 9])], 103 output_arrays=['output']) 104 self.assertFalse(converter._has_valid_tensors()) 105 self.assertEqual(converter.get_input_arrays(), ['input']) 106 107 with self.assertRaises(ValueError) as error: 108 converter._set_batch_size(1) 109 self.assertEqual( 110 'The batch size cannot be set for this model. Please use ' 111 'input_shapes parameter.', str(error.exception)) 112 113 converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor']) 114 self.assertTrue(converter._has_valid_tensors()) 115 116 def testRedundantArgumentsWarning(self): 117 """Test if the warning message when there are redundant arguments.""" 118 with ops.Graph().as_default(): 119 in_tensor = array_ops.placeholder( 120 shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor') 121 out_tensor = math_ops.add(in_tensor, in_tensor, name='add') 122 sess = session.Session() 123 124 frozen_graph_def = ( 125 convert_to_constants.convert_variables_to_constants_from_session_graph( 126 sess, sess.graph_def, ['add'])) 127 128 # Convert model and ensure model is not None. 129 log = io.BytesIO() if six.PY2 else io.StringIO() 130 handler = logging.StreamHandler(log) 131 logging.root.addHandler(handler) 132 converter = lite.TFLiteConverter(frozen_graph_def, [in_tensor], 133 [out_tensor], 134 [('in_tensor', [2, 16, 16, 3])], ['add']) 135 136 input_warning_message = 'input_arrays_with_shape will be ignored' 137 output_warning_message = 'output_arrays will be ignored' 138 139 # Convert model and ensure model is not None. 
140 tflite_model = converter.convert() 141 self.assertIsNotNone(tflite_model) 142 self.assertIn(input_warning_message, log.getvalue()) 143 self.assertIn(output_warning_message, log.getvalue()) 144 logging.root.removeHandler(handler) 145 146 def testShapeOverriding(self): 147 """Test a shape overriding case via the constructor.""" 148 with ops.Graph().as_default(): 149 in_tensor = array_ops.placeholder( 150 shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor') 151 math_ops.add(in_tensor, in_tensor, name='add') 152 sess = session.Session() 153 154 frozen_graph_def = ( 155 convert_to_constants.convert_variables_to_constants_from_session_graph( 156 sess, sess.graph_def, ['add'])) 157 158 # Convert model and ensure model is not None. 159 converter = lite.TFLiteConverter(frozen_graph_def, None, None, 160 [('in_tensor', [2, 16, 16, 3])], ['add']) 161 tflite_model = converter.convert() 162 self.assertIsNotNone(tflite_model) 163 164 # Check values from converted model. 165 interpreter = Interpreter(model_content=tflite_model) 166 interpreter.allocate_tensors() 167 168 input_details = interpreter.get_input_details() 169 self.assertLen(input_details, 1) 170 self.assertEqual('in_tensor', input_details[0]['name']) 171 self.assertEqual(np.float32, input_details[0]['dtype']) 172 self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape']) 173 self.assertEqual((0., 0.), input_details[0]['quantization']) 174 175 output_details = interpreter.get_output_details() 176 self.assertLen(output_details, 1) 177 self.assertEqual('add', output_details[0]['name']) 178 self.assertEqual(np.float32, output_details[0]['dtype']) 179 self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape']) 180 self.assertEqual((0., 0.), output_details[0]['quantization']) 181 182 def testPartialShapeOverriding(self): 183 """Test a partial shape overriding case via the constructor.""" 184 with ops.Graph().as_default(): 185 in_tensor_a = array_ops.placeholder( 186 shape=[None, 16, 16, 3], 
dtype=dtypes.float32, name='in_tensor_a') 187 in_tensor_b = array_ops.placeholder( 188 shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_b') 189 math_ops.add(in_tensor_a, in_tensor_b, name='add') 190 sess = session.Session() 191 192 frozen_graph_def = ( 193 convert_to_constants.convert_variables_to_constants_from_session_graph( 194 sess, sess.graph_def, ['add'])) 195 196 # Convert model and ensure model is not None. 197 converter = lite.TFLiteConverter(frozen_graph_def, None, None, 198 [('in_tensor_a', [2, 16, 16, 3])], ['add']) 199 # There is an unhandled Placeholder op. 200 with self.assertRaises(ConverterError): 201 converter.convert() 202 203 def testInvalidShapeOverriding(self): 204 """Test an invalid shape overriding case via the constructor.""" 205 with ops.Graph().as_default(): 206 in_tensor = array_ops.placeholder( 207 shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor') 208 math_ops.add(in_tensor, in_tensor, name='add') 209 sess = session.Session() 210 211 frozen_graph_def = ( 212 convert_to_constants.convert_variables_to_constants_from_session_graph( 213 sess, sess.graph_def, ['add'])) 214 215 # Convert model and ensure model is not None. 216 converter = lite.TFLiteConverter(frozen_graph_def, None, None, 217 [('wrong_tensor', [2, 16, 16, 3])], 218 ['add']) 219 with self.assertRaises(ConverterError): 220 converter.convert() 221 222 223class FromSessionTest(TestModels, parameterized.TestCase): 224 225 def testFloatModel(self): 226 with ops.Graph().as_default(): 227 in_tensor = array_ops.placeholder( 228 shape=[1, 16, 16, 3], dtype=dtypes.float32) 229 out_tensor = in_tensor + in_tensor 230 sess = session.Session() 231 232 # Convert model and ensure model is not None. 233 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 234 [out_tensor]) 235 tflite_model = converter.convert() 236 self.assertIsNotNone(tflite_model) 237 238 # Check values from converted model. 
# testFloatModel (continued): inspect the converted model's tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testFloatModelQuantizedInput(self):
    # Float model whose input is requested as uint8 via inference_input_type
    # and quantized_input_stats, while inference_type stays float32.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_input_type = dtypes.uint8
    converter.inference_type = dtypes.float32
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    # NOTE(review): presumably (scale, zero_point) derived from the
    # (mean, std_dev) = (0., 1.) stats above — confirm against Interpreter docs.
    self.assertEqual((1., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])  # float

  def testForgottenCallToAllocateTensors(self):
    # Setting an input before allocate_tensors() is expected to raise.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()
    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
303 interpreter = Interpreter(model_content=tflite_model) 304 input_index = interpreter.get_input_details()[0]['index'] 305 dummy_tensor = np.ones(shape=[1, 16, 16, 3], dtype=np.float32) 306 with self.assertRaises(ValueError): 307 interpreter.set_tensor(input_index, dummy_tensor) 308 309 @parameterized.named_parameters( 310 ('_INT8InputOutput', False, False, dtypes.int8), 311 ('_UINT8InputOutput', False, False, dtypes.uint8), 312 ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16), 313 ('_IntOnly_INT8InputOutput', True, False, dtypes.int8), 314 ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8), 315 ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16), 316 ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True), 317 ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True)) 318 def testIntegerQuantizationWithUnsupportedOps(self, 319 is_int_only, 320 is_int16_quantize, 321 inference_input_output_type, 322 enable_mlir_quantizer=False): 323 with ops.Graph().as_default(): 324 in_tensor_a = array_ops.placeholder(shape=[3], dtype=dtypes.float32) 325 in_tensor_b = array_ops.placeholder(shape=[3], dtype=dtypes.float32) 326 # ceil kernel does not support int8 nor int16 types neither. 327 left = math_ops.ceil(in_tensor_a) 328 out_tensor_b = math_ops.tanh(in_tensor_b) 329 add = math_ops.add(left, out_tensor_b) 330 # ceil kernel does not support int8 nor int16 types neither. 
331 out_tensor_a = math_ops.ceil(add) 332 sess = session.Session() 333 334 def calibration_gen(): 335 for _ in range(5): 336 yield [ 337 np.random.uniform(-1, 1, size=(3)).astype(np.float32), 338 np.random.uniform(-1, 1, size=(3)).astype(np.float32) 339 ] 340 341 quantized_converter = lite.TFLiteConverter.from_session( 342 sess, [in_tensor_a, in_tensor_b], [out_tensor_a, out_tensor_b]) 343 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 344 quantized_converter.representative_dataset = calibration_gen 345 if is_int_only: 346 if is_int16_quantize: 347 quantized_converter.target_spec.supported_ops = [ 348 lite.OpsSet.\ 349 EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8, 350 lite.OpsSet.TFLITE_BUILTINS 351 ] 352 else: 353 quantized_converter.target_spec.supported_ops = [ 354 lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS 355 ] 356 else: 357 if is_int16_quantize: 358 quantized_converter.target_spec.supported_ops = [ 359 lite.OpsSet.\ 360 EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8, 361 lite.OpsSet.TFLITE_BUILTINS 362 ] 363 else: 364 quantized_converter.target_spec.supported_ops = [ 365 lite.OpsSet.TFLITE_BUILTINS 366 ] 367 368 quantized_converter.inference_input_type = inference_input_output_type 369 quantized_converter.inference_output_type = inference_input_output_type 370 quantized_converter.experimental_new_quantizer = enable_mlir_quantizer 371 quantized_tflite_model = quantized_converter.convert() 372 self.assertIsNotNone(quantized_tflite_model) 373 374 expected_dtype = inference_input_output_type.as_numpy_dtype 375 # Allow float32 for fallback on non-quantizable op. 
376 expected_ceil_dtype = ( 377 expected_dtype if enable_mlir_quantizer else dtypes.float32) 378 379 interpreter = Interpreter(model_content=quantized_tflite_model) 380 interpreter.allocate_tensors() 381 input_details = interpreter.get_input_details() 382 self.assertLen(input_details, 2) 383 self.assertEqual(input_details[0]['dtype'], expected_ceil_dtype) 384 self.assertEqual(input_details[1]['dtype'], expected_dtype) 385 output_details = interpreter.get_output_details() 386 self.assertLen(output_details, 2) 387 self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype) 388 self.assertEqual(output_details[1]['dtype'], expected_dtype) 389 390 @parameterized.named_parameters( 391 ('EnableMlirConverter', True), # enable mlir 392 ('DisableMlirConverter', False)) # disable mlir 393 def testString(self, enable_mlir_converter): 394 with ops.Graph().as_default(): 395 in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string) 396 out_tensor = array_ops.reshape(in_tensor, shape=[2, 2]) 397 sess = session.Session() 398 399 # Convert model and ensure model is not None. 400 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 401 [out_tensor]) 402 converter.experimental_new_converter = enable_mlir_converter 403 tflite_model = converter.convert() 404 self.assertIsNotNone(tflite_model) 405 406 # Check values from converted model. 
407 interpreter = Interpreter(model_content=tflite_model) 408 interpreter.allocate_tensors() 409 410 input_details = interpreter.get_input_details() 411 self.assertLen(input_details, 1) 412 self.assertEqual('Placeholder', input_details[0]['name']) 413 self.assertEqual(np.string_, input_details[0]['dtype']) 414 self.assertAllEqual([4], input_details[0]['shape']) 415 416 output_details = interpreter.get_output_details() 417 self.assertLen(output_details, 1) 418 self.assertEqual('Reshape', output_details[0]['name']) 419 self.assertEqual(np.string_, output_details[0]['dtype']) 420 self.assertAllEqual([2, 2], output_details[0]['shape']) 421 # TODO(b/122659643): Test setting/getting string data via the python 422 # interpreter API after support has been added. 423 424 def testIntermediateInputArray(self): 425 """Convert a model from an intermediate input array.""" 426 with ops.Graph().as_default(): 427 in_tensor_init = array_ops.placeholder( 428 shape=[1, 16, 16, 3], dtype=dtypes.float32) 429 in_tensor_final = in_tensor_init + in_tensor_init 430 out_tensor = in_tensor_final + in_tensor_final 431 sess = session.Session() 432 433 # Convert model and ensure model is not None. 434 converter = lite.TFLiteConverter.from_session(sess, [in_tensor_final], 435 [out_tensor]) 436 tflite_model = converter.convert() 437 self.assertIsNotNone(tflite_model) 438 439 # Check values from converted model. 
440 interpreter = Interpreter(model_content=tflite_model) 441 interpreter.allocate_tensors() 442 443 input_details = interpreter.get_input_details() 444 self.assertLen(input_details, 1) 445 self.assertEqual('add', input_details[0]['name']) 446 self.assertEqual(np.float32, input_details[0]['dtype']) 447 self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape']) 448 self.assertEqual((0., 0.), input_details[0]['quantization']) 449 450 output_details = interpreter.get_output_details() 451 self.assertLen(output_details, 1) 452 self.assertEqual('add_1', output_details[0]['name']) 453 self.assertEqual(np.float32, output_details[0]['dtype']) 454 self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape']) 455 self.assertEqual((0., 0.), output_details[0]['quantization']) 456 457 def testSizeNoneInvalid(self): 458 with ops.Graph().as_default(): 459 in_tensor = array_ops.placeholder(dtype=dtypes.float32) 460 out_tensor = in_tensor + in_tensor 461 sess = session.Session() 462 463 # Test None as shape when dynamic shapes are disabled. Run with TOCO in 464 # order to invoke shape checking code. 465 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 466 [out_tensor]) 467 converter.experimental_new_converter = False 468 with self.assertRaises(ValueError) as error: 469 converter.convert() 470 self.assertEqual('Provide an input shape for input array \'Placeholder\'.', 471 str(error.exception)) 472 473 @parameterized.named_parameters( 474 ('EnableMlirConverter', True), # enable mlir 475 ('DisableMlirConverter', False)) # disable mlir 476 def testScalarValid(self, enable_mlir_converter): 477 # Construct a graph using a scalar (empty shape) input. 478 with ops.Graph().as_default(): 479 in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[]) 480 out_tensor = in_tensor + in_tensor 481 sess = session.Session() 482 483 # Test conversion with the scalar input shape. 
484 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 485 [out_tensor]) 486 converter.experimental_new_converter = enable_mlir_converter 487 tflite_model = converter.convert() 488 self.assertIsNotNone(tflite_model) 489 490 # Check values from converted model. 491 interpreter = Interpreter(model_content=tflite_model) 492 interpreter.allocate_tensors() 493 494 input_details = interpreter.get_input_details() 495 self.assertLen(input_details, 1) 496 self.assertEqual('Placeholder', input_details[0]['name']) 497 self.assertEqual(np.float32, input_details[0]['dtype']) 498 self.assertEmpty(input_details[0]['shape']) 499 500 output_details = interpreter.get_output_details() 501 self.assertLen(output_details, 1) 502 self.assertEqual('add', output_details[0]['name']) 503 self.assertEqual(np.float32, output_details[0]['dtype']) 504 self.assertEmpty(input_details[0]['shape']) 505 506 # Validate inference using the scalar inputs/outputs. 507 test_input = np.array(4.0, dtype=np.float32) 508 expected_output = np.array(8.0, dtype=np.float32) 509 interpreter.set_tensor(input_details[0]['index'], test_input) 510 interpreter.invoke() 511 512 output_data = interpreter.get_tensor(output_details[0]['index']) 513 self.assertEqual(expected_output, output_data) 514 515 def testSizeInvalid(self): 516 with ops.Graph().as_default(): 517 in_tensor = array_ops.placeholder( 518 shape=[1, None, 16, 3], dtype=dtypes.float32) 519 out_tensor = in_tensor + in_tensor 520 sess = session.Session() 521 522 # Test invalid shape. None after 1st dimension. Run with TOCO in order to 523 # invoke shape checking code. 524 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 525 [out_tensor]) 526 converter.experimental_new_converter = False 527 with self.assertRaises(ValueError) as error: 528 converter.convert() 529 self.assertEqual( 530 'None is only supported in the 1st dimension. 
Tensor ' 531 '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.', 532 str(error.exception)) 533 534 def testSizeNone(self): 535 with ops.Graph().as_default(): 536 in_tensor = array_ops.placeholder( 537 shape=[1, None, 16, 3], dtype=dtypes.float32) 538 out_tensor = in_tensor + in_tensor 539 sess = session.Session() 540 541 # Test None after 1st dimension. 542 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 543 [out_tensor]) 544 tflite_model = converter.convert() 545 546 # Check values from converted model. 547 interpreter = Interpreter(model_content=tflite_model) 548 input_details = interpreter.get_input_details() 549 self.assertLen(input_details, 1) 550 self.assertEqual('Placeholder', input_details[0]['name']) 551 self.assertEqual(np.float32, input_details[0]['dtype']) 552 self.assertAllEqual([1, 1, 16, 3], input_details[0]['shape']) 553 self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature']) 554 self.assertEqual((0., 0.), input_details[0]['quantization']) 555 556 # Resize tensor with strict checking. 557 with self.assertRaises(RuntimeError) as error: 558 interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True) 559 self.assertIn( 560 'ResizeInputTensorStrict only allows mutating unknown dimensions ' 561 'identified by -1.', str(error.exception)) 562 563 # Resize tensor and invoke. 564 interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True) 565 interpreter.allocate_tensors() 566 interpreter.invoke() 567 568 input_details = interpreter.get_input_details() 569 self.assertLen(input_details, 1) 570 self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape']) 571 self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature']) 572 573 output_details = interpreter.get_output_details() 574 self.assertAllEqual([1, -1, 16, 3], output_details[0]['shape_signature']) 575 576 def testResizeTensorInputStrict(self): 577 # Ensures that resize_tensor_input(strict=True) works as expected. 
# testResizeTensorInputStrict (continued).
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)

    # Resize incorrect value.
    # Every dimension of the placeholder is static, so a strict resize to a
    # different batch size must be rejected.
    with self.assertRaises(RuntimeError) as error:
      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
    self.assertIn(
        'ResizeInputTensorStrict only allows mutating unknown dimensions '
        'identified by -1.', str(error.exception))

    # Resize correct value.
    # (Identical to the placeholder's shape, so strict mode accepts it.)
    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
    interpreter.allocate_tensors()

  def testBatchSizeValid(self):
    # A None batch dimension is expected to come out of conversion as 1.
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testBatchSizeNonZero(self):
    # Only the input with a None leading dimension gets batch size 1; the
    # fully static second input keeps its shape.
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[None, 4], dtype=dtypes.float32, name='input1')
      in_tensor_2 = array_ops.placeholder(
          shape=[4, 10], dtype=dtypes.float32, name='input2')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2)
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess,
                                                  [in_tensor_1, in_tensor_2],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertEqual('input1', input_details[0]['name'])
    self.assertAllEqual([1, 4], input_details[0]['shape'])
    self.assertEqual('input2', input_details[1]['name'])
    self.assertAllEqual([4, 10], input_details[1]['shape'])

  def testFreezeGraph(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      var = variable_scope.get_variable(
          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
      # Get the second output to ensure freezing properly processes tensor names
      # like 'X:1'.
      out_tensor = nn_ops.top_k(in_tensor + var, name='top_k')[1]
      sess = session.Session()
      sess.run(_global_variables_initializer())

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
# testFreezeGraph (continued): inspect the frozen model's tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    # The second top_k output (indices) is the model output, keeping its ':1'
    # suffix in the tensor name.
    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('top_k:1', output_details[0]['name'])
    self.assertEqual(np.int32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 1], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testGraphviz(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.output_format = lite_constants.GRAPHVIZ_DOT
    graphviz_output = converter.convert()
    self.assertIsNotNone(graphviz_output)

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testDumpGraphviz(self, enable_mlir_converter):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = enable_mlir_converter
    graphviz_dir = self.get_temp_dir()
    converter.dump_graphviz_dir = graphviz_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure interpreter is able to allocate and check graphviz data.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    num_items_graphviz = len(os.listdir(graphviz_dir))
    # NOTE(review): these three assertIsNotNone checks can never fail (len()
    # returns an int and os.path.exists() returns a bool); they only show the
    # calls do not raise. Consider assertTrue once it is confirmed that both
    # converters write these files.
    self.assertIsNotNone(num_items_graphviz)
    self.assertIsNotNone(
        os.path.exists(os.path.join(graphviz_dir, 'toco_AT_IMPORT.dot')))
    self.assertIsNotNone(
        os.path.exists(
            os.path.join(graphviz_dir, 'toco_AFTER_TRANSFORMATIONS.dot')))

    # new converter doesn't support `dump_graphviz_video` flag
    if not enable_mlir_converter:
      # Convert model and ensure model is not None.
      converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                    [out_tensor])
      converter.experimental_new_converter = enable_mlir_converter
      graphviz_dir = self.get_temp_dir()
      converter.dump_graphviz_dir = graphviz_dir
      converter.dump_graphviz_video = True
      tflite_model = converter.convert()
      self.assertIsNotNone(tflite_model)

      # Ensure graphviz folder has more data after using video flag.
      num_items_graphviz_video = len(os.listdir(graphviz_dir))
      self.assertGreater(num_items_graphviz_video, num_items_graphviz)

  def testDumpConversionSummary(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    log_dir = self.get_temp_dir()
    converter.conversion_summary_dir = log_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # With the default converter, some summary output must be written.
    self.assertNotEmpty(os.listdir(log_dir))

  def testDumpConversionSummaryWithOldConverter(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.experimental_new_converter = False
    log_dir = self.get_temp_dir()
    converter.conversion_summary_dir = log_dir
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)
    # Check nothing is generated under the conversion summary path.
    num_items_conversion_summary = len(os.listdir(log_dir))
    self.assertEqual(num_items_conversion_summary, 0)

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testQuantizeDynamicRange(self, enable_mlir_converter):
    np.random.seed(0)
    with ops.Graph().as_default():
      # We need the tensor to have more than 1024 elements for quantize_weights
      # to kick in. Thus, the [33, 33] shape.
      in_tensor_1 = array_ops.placeholder(
          shape=[33, 33], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = constant_op.constant(
          np.random.uniform(low=-10., high=10., size=(33, 33)),
          shape=[33, 33],
          dtype=dtypes.float32,
          name='inputB')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
      sess = session.Session()

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1], [out_tensor])
    float_converter.experimental_new_converter = enable_mlir_converter
    float_tflite_model = float_converter.convert()
    self.assertIsNotNone(float_tflite_model)

    # Convert quantized weights model.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1], [out_tensor])

    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
    quantized_converter.experimental_new_converter = enable_mlir_converter
    quantized_tflite_model = quantized_converter.convert()
    self.assertIsNotNone(quantized_tflite_model)

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))

  @parameterized.named_parameters(
      ('EnableMlirConverter', True),  # enable mlir
      ('DisableMlirConverter', False))  # disable mlir
  def testQuantizeDynamicRangeDeprecatedPostTrainingQuantizeAttribute(
      self, enable_mlir_converter):
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[33, 33], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = constant_op.constant(
          np.random.uniform(low=-10., high=10., size=(33, 33)),
          shape=[33, 33],
          dtype=dtypes.float32,
          name='inputB')
      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
      sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1], [out_tensor])
    self.assertFalse(quantized_converter.post_training_quantize)
    quantized_converter.experimental_new_converter = enable_mlir_converter

    # Setting the deprecated attribute is expected to be reflected as
    # `optimizations == [Optimize.DEFAULT]`.
    quantized_converter.post_training_quantize = True
    self.assertTrue(quantized_converter.post_training_quantize)
    self.assertEqual(quantized_converter.optimizations, [lite.Optimize.DEFAULT])

    quantized_tflite_model = quantized_converter.convert()
self.assertIsNotNone(quantized_tflite_model) 859 860 def _getIntegerQuantizeModel(self): 861 np.random.seed(0) 862 inp = array_ops.placeholder( 863 dtype=dtypes.float32, shape=(1, 5, 5, 3), name='input') 864 conv = nn_ops.conv2d( 865 inp, 866 filter=array_ops.ones([3, 3, 3, 16]), 867 strides=[1, 1, 1, 1], 868 padding='SAME') 869 output = nn_ops.relu(conv, name='output') 870 871 def calibration_gen(): 872 for _ in range(5): 873 yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)] 874 875 return (inp, output, calibration_gen) 876 877 @parameterized.named_parameters( 878 ('EnableMlirConverter', True), # enable mlir 879 ('DisableMlirConverter', False)) # disable mlir 880 def testQuantizeInt8AllowFloat(self, enable_mlir_converter): 881 with ops.Graph().as_default(): 882 inp, output, calibration_gen = self._getIntegerQuantizeModel() 883 sess = session.Session() 884 885 # Convert float model. 886 float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 887 float_tflite_model = float_converter.convert() 888 self.assertIsNotNone(float_tflite_model) 889 890 # Convert quantized model. 891 quantized_converter = lite.TFLiteConverter.from_session( 892 sess, [inp], [output]) 893 quantized_converter.experimental_new_converter = enable_mlir_converter 894 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 895 quantized_converter.representative_dataset = calibration_gen 896 quantized_tflite_model = quantized_converter.convert() 897 self.assertIsNotNone(quantized_tflite_model) 898 899 # The default input and output types should be float. 
900 interpreter = Interpreter(model_content=quantized_tflite_model) 901 interpreter.allocate_tensors() 902 input_details = interpreter.get_input_details() 903 self.assertLen(input_details, 1) 904 self.assertEqual(np.float32, input_details[0]['dtype']) 905 output_details = interpreter.get_output_details() 906 self.assertLen(output_details, 1) 907 self.assertEqual(np.float32, output_details[0]['dtype']) 908 909 # Ensure that the quantized weights tflite model is smaller. 910 self.assertLess(len(quantized_tflite_model), len(float_tflite_model)) 911 912 @parameterized.named_parameters( 913 # Quantize model to Int8: with enable mlir 914 ('UseTfliteBuiltinsIntEnableMLIR', 915 [lite.OpsSet.TFLITE_BUILTINS_INT8], True), 916 # Quantize model to Int8: with disable mlir 917 ('UseTfliteBuiltinsIntDisableMLIR', 918 [lite.OpsSet.TFLITE_BUILTINS_INT8], False), 919 # Quantize model to Int16: with disable mlir 920 ('UseTfliteBuiltinsInt16DisableMLIR', 921 [lite.OpsSet.\ 922 EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8], 923 False), 924 ('UseTfliteBuiltinsInt16EnableMLIR', 925 [lite.OpsSet.\ 926 EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8], 927 True)) 928 def testQuantizeInt8And16x8(self, supported_ops, enable_mlir_converter): 929 with ops.Graph().as_default(): 930 inp, output, calibration_gen = self._getIntegerQuantizeModel() 931 sess = session.Session() 932 933 # Convert float model. 934 float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 935 float_converter.experimental_new_converter = enable_mlir_converter 936 float_tflite_model = float_converter.convert() 937 self.assertIsNotNone(float_tflite_model) 938 939 # Convert model by specifying target spec (instead of optimizations), since 940 # when targeting an integer only backend, quantization is mandatory. 
941 quantized_converter = lite.TFLiteConverter.from_session( 942 sess, [inp], [output]) 943 quantized_converter.experimental_new_converter = enable_mlir_converter 944 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 945 quantized_converter.target_spec.supported_ops = supported_ops 946 quantized_converter.representative_dataset = calibration_gen 947 quantized_tflite_model = quantized_converter.convert() 948 self.assertIsNotNone(quantized_tflite_model) 949 950 # The default input and output types should be float. 951 interpreter = Interpreter(model_content=quantized_tflite_model) 952 interpreter.allocate_tensors() 953 input_details = interpreter.get_input_details() 954 self.assertLen(input_details, 1) 955 self.assertEqual(np.float32, input_details[0]['dtype']) 956 output_details = interpreter.get_output_details() 957 self.assertLen(output_details, 1) 958 self.assertEqual(np.float32, output_details[0]['dtype']) 959 960 # Ensure that the quantized weights tflite model is smaller. 961 self.assertLess(len(quantized_tflite_model), len(float_tflite_model)) 962 963 @parameterized.named_parameters( 964 ('EnableMlirConverter', True), # enable mlir 965 ('DisableMlirConverter', False)) # disable mlir 966 def testQuantizeInt8InputOutput(self, enable_mlir_converter): 967 with ops.Graph().as_default(): 968 inp, output, calibration_gen = self._getIntegerQuantizeModel() 969 sess = session.Session() 970 971 # Convert float model. 972 float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 973 float_converter.experimental_new_converter = enable_mlir_converter 974 float_tflite_model = float_converter.convert() 975 self.assertIsNotNone(float_tflite_model) 976 977 # Convert quantized weights model. 
978 quantized_converter = lite.TFLiteConverter.from_session( 979 sess, [inp], [output]) 980 quantized_converter.experimental_new_converter = enable_mlir_converter 981 quantized_converter.inference_input_type = dtypes.int8 982 quantized_converter.inference_output_type = dtypes.int8 983 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 984 quantized_converter.representative_dataset = calibration_gen 985 quantized_tflite_model = quantized_converter.convert() 986 self.assertIsNotNone(quantized_tflite_model) 987 988 # The input and output types should be int8. 989 interpreter = Interpreter(model_content=quantized_tflite_model) 990 interpreter.allocate_tensors() 991 input_details = interpreter.get_input_details() 992 self.assertLen(input_details, 1) 993 self.assertEqual(np.int8, input_details[0]['dtype']) 994 output_details = interpreter.get_output_details() 995 self.assertLen(output_details, 1) 996 self.assertEqual(np.int8, output_details[0]['dtype']) 997 998 # Ensure that the quantized weights tflite model is smaller. 999 self.assertLess(len(quantized_tflite_model), len(float_tflite_model)) 1000 1001 @parameterized.named_parameters( 1002 ('EnableMlirConverter', True), # enable mlir 1003 ('DisableMlirConverter', False)) # disable mlir 1004 def testInvalidQuantizeInt8(self, enable_mlir_converter): 1005 np.random.seed(0) 1006 with ops.Graph().as_default(): 1007 # We need the tensor to have more than 1024 elements for quantize_weights 1008 # to kick in. Thus, the [33, 33] shape. 1009 in_tensor_1 = array_ops.placeholder( 1010 shape=[33, 33], dtype=dtypes.float32, name='inputA') 1011 in_tensor_2 = constant_op.constant( 1012 np.random.uniform(low=-10., high=10., size=(33, 33)), 1013 shape=[33, 33], 1014 dtype=dtypes.float32, 1015 name='inputB') 1016 out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') 1017 sess = session.Session() 1018 1019 # Attempt to convert to quantized weights model. 
1020 quantized_converter = lite.TFLiteConverter.from_session( 1021 sess, [in_tensor_1], [out_tensor]) 1022 quantized_converter.experimental_new_converter = enable_mlir_converter 1023 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 1024 # Restricting to int8 type only 1025 quantized_converter.target_spec.supported_types = [dtypes.int8] 1026 # A representative dataset is required for full fixed point quantization. 1027 with self.assertRaises(ValueError) as error: 1028 quantized_converter.convert() 1029 self.assertEqual( 1030 'representative_dataset is required when specifying ' 1031 'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception)) 1032 1033 @parameterized.named_parameters( 1034 ('EnableMlirConverter', True), # enable mlir 1035 ('DisableMlirConverter', False)) # disable mlir 1036 def testQuantizeUInt8(self, enable_mlir_converter): 1037 with ops.Graph().as_default(): 1038 in_tensor_1 = array_ops.placeholder( 1039 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') 1040 in_tensor_2 = array_ops.placeholder( 1041 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') 1042 out_tensor = array_ops.fake_quant_with_min_max_args( 1043 in_tensor_1 + in_tensor_2, min=0., max=1., name='output') 1044 sess = session.Session() 1045 1046 # Convert model and ensure model is not None. 1047 converter = lite.TFLiteConverter.from_session(sess, 1048 [in_tensor_1, in_tensor_2], 1049 [out_tensor]) 1050 converter.inference_type = dtypes.uint8 1051 converter.quantized_input_stats = { 1052 'inputA': (0., 1.), 1053 'inputB': (0., 1.) 1054 } # mean, std_dev 1055 converter.experimental_new_converter = enable_mlir_converter 1056 tflite_model = converter.convert() 1057 self.assertIsNotNone(tflite_model) 1058 1059 # Check values from converted model. 
1060 interpreter = Interpreter(model_content=tflite_model) 1061 interpreter.allocate_tensors() 1062 1063 input_details = interpreter.get_input_details() 1064 self.assertLen(input_details, 2) 1065 self.assertEqual('inputA', input_details[0]['name']) 1066 self.assertEqual(np.uint8, input_details[0]['dtype']) 1067 self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape']) 1068 self.assertEqual((1., 0.), input_details[0]['quantization']) 1069 1070 self.assertEqual('inputB', input_details[1]['name']) 1071 self.assertEqual(np.uint8, input_details[1]['dtype']) 1072 self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape']) 1073 self.assertEqual((1., 0.), input_details[1]['quantization']) 1074 1075 output_details = interpreter.get_output_details() 1076 self.assertLen(output_details, 1) 1077 self.assertEqual(np.uint8, output_details[0]['dtype']) 1078 self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape']) 1079 self.assertGreater(output_details[0]['quantization'][0], 0) # scale 1080 1081 def testQuantizeUInt8UsingDefaultRangeStats(self): 1082 with ops.Graph().as_default(): 1083 in_tensor = array_ops.placeholder( 1084 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1085 out_tensor = in_tensor + in_tensor 1086 sess = session.Session() 1087 1088 # Convert model and ensure model is not None. 1089 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 1090 [out_tensor]) 1091 converter.inference_type = dtypes.uint8 1092 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev 1093 converter.default_ranges_stats = (0, 6) # min, max 1094 tflite_model = converter.convert() 1095 self.assertIsNotNone(tflite_model) 1096 1097 # Check values from converted model. 
1098 interpreter = Interpreter(model_content=tflite_model) 1099 interpreter.allocate_tensors() 1100 1101 input_details = interpreter.get_input_details() 1102 self.assertLen(input_details, 1) 1103 self.assertEqual('Placeholder', input_details[0]['name']) 1104 self.assertEqual(np.uint8, input_details[0]['dtype']) 1105 self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape']) 1106 self.assertEqual((1., 0.), input_details[0]['quantization']) 1107 1108 output_details = interpreter.get_output_details() 1109 self.assertLen(output_details, 1) 1110 self.assertEqual('add', output_details[0]['name']) 1111 self.assertEqual(np.uint8, output_details[0]['dtype']) 1112 self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape']) 1113 self.assertGreater(output_details[0]['quantization'][0], 0) # scale 1114 1115 @parameterized.named_parameters( 1116 # Quantize to Float16 even if rep data provided. 1117 ('UseRepresentativeData', True, False, True, False, False, False, False), 1118 # Quantize to Float16 if no rep data provided. 1119 ('NoRepresentativeData', False, False, True, False, False, False, False), 1120 # Post training quantization if both rep data and int8 included. 1121 ('UseSampleDataIncludeInt8', True, True, False, False, True, False, False 1122 ), 1123 # Quantize to Float16 even if rep data provided with mlir. 1124 ('UseRepresentativeDataMlir', True, False, True, False, False, True, False 1125 ), 1126 # Quantize to Float16 if no rep data provided with mlir. 1127 ('NoRepresentativeDataMlir', False, False, True, False, False, True, False 1128 ), 1129 # Post training quantization if both rep data and int8 included with mlir. 
1130 ('SampleDataIncludeInt8Mlir', True, True, False, False, True, True, False 1131 ), 1132 # Same as above, but using MLIR quantizer 1133 ('SampleDataIncludeInt8MlirQuant', True, True, False, False, True, True, 1134 True)) 1135 def testQuantizeFloat16(self, use_rep_data, include_int8, 1136 is_float16_quantized, is_error, 1137 is_post_training_quantized, enable_mlir_converter, 1138 enable_mlir_quantizer): 1139 with ops.Graph().as_default(): 1140 inp, output, calibration_gen = self._getIntegerQuantizeModel() 1141 sess = session.Session() 1142 1143 bias_idx = 1 if enable_mlir_converter else 0 1144 bias_name = 'Conv2D' if enable_mlir_converter else 'Conv2D_bias' 1145 1146 # Convert float model. 1147 float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 1148 float_converter.experimental_new_converter = enable_mlir_converter 1149 float_tflite_model = float_converter.convert() 1150 self.assertIsNotNone(float_tflite_model) 1151 interpreter = Interpreter(model_content=float_tflite_model) 1152 interpreter.allocate_tensors() 1153 self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'], 1154 bias_name) 1155 self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'], 1156 dtypes.float32) 1157 1158 # MLIR quantizer has different bias index. 
1159 if enable_mlir_quantizer: 1160 bias_idx = 2 1161 1162 # Convert model to quantized version 1163 quantized_converter = lite.TFLiteConverter.from_session( 1164 sess, [inp], [output]) 1165 quantized_converter.experimental_new_converter = enable_mlir_converter 1166 quantized_converter.experimental_new_quantizer = enable_mlir_quantizer 1167 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 1168 quantized_converter.target_spec.supported_types = [dtypes.float16] 1169 if include_int8: 1170 quantized_converter.target_spec.supported_types.append(dtypes.int8) 1171 if use_rep_data: 1172 quantized_converter.representative_dataset = calibration_gen 1173 1174 if is_error: 1175 with self.assertRaises(ValueError) as error: 1176 quantized_converter.convert() 1177 self.assertEqual( 1178 'representative_dataset is required when specifying ' 1179 'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception)) 1180 1181 else: 1182 quantized_tflite_model = quantized_converter.convert() 1183 self.assertIsNotNone(quantized_tflite_model) 1184 interpreter = Interpreter(model_content=quantized_tflite_model) 1185 interpreter.allocate_tensors() 1186 self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'], 1187 bias_name) 1188 1189 if is_float16_quantized: 1190 # Verify that bias constant is float16 type. 1191 self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'], 1192 dtypes.float16) 1193 elif is_post_training_quantized: 1194 # Verify that bias constants is int32 type. 
1195 self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'], 1196 dtypes.int32) 1197 else: 1198 raise ValueError('Invalid test options.') 1199 1200 @parameterized.named_parameters( 1201 ('EnableMlirConverter', True), # enable mlir 1202 ('DisableMlirConverter', False)) # disable mlir 1203 def testInvalidQuantizeFloat16(self, enable_mlir_converter): 1204 with ops.Graph().as_default(): 1205 inp, output, _ = self._getIntegerQuantizeModel() 1206 sess = session.Session() 1207 1208 # Specify float16 quantization 1209 quantized_converter = lite.TFLiteConverter.from_session( 1210 sess, [inp], [output]) 1211 quantized_converter.experimental_new_converter = enable_mlir_converter 1212 quantized_converter.optimizations = [lite.Optimize.DEFAULT] 1213 quantized_converter.target_spec.supported_types = [dtypes.float16] 1214 # Specify only int8 builtin ops 1215 quantized_converter.target_spec.supported_ops = [ 1216 lite.OpsSet.TFLITE_BUILTINS_INT8 1217 ] 1218 with self.assertRaises(ValueError) as error: 1219 quantized_converter.convert() 1220 self.assertEqual( 1221 'TFLITE_BUILTINS_INT8 requires smallest supported type to be INT8.', 1222 str(error.exception)) 1223 1224 @parameterized.named_parameters( 1225 ('InferenceType_INT8', dtypes.int8), 1226 ('InferenceType_UINT8', dtypes.uint8)) 1227 def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type): 1228 with ops.Graph().as_default(): 1229 in_tensor = array_ops.placeholder( 1230 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1231 out_tensor = array_ops.fake_quant_with_min_max_args( 1232 in_tensor + in_tensor, min=0., max=1.) 
1233 sess = session.Session() 1234 1235 quantized_converter = lite.TFLiteConverter.from_session( 1236 sess, [in_tensor], [out_tensor]) 1237 1238 with self.assertRaises(ValueError) as error: 1239 quantized_converter.inference_type = quantized_type 1240 quantized_converter.convert() 1241 self.assertEqual( 1242 'The `quantized_input_stats` flag must be defined when either ' 1243 '`inference_type` flag or `inference_input_type` flag is set to ' 1244 'tf.int8 or tf.uint8. Currently, `inference_type=tf.{}` and ' 1245 '`inference_input_type=None`.'.format(quantized_type.name), 1246 str(error.exception)) 1247 1248 with self.assertRaises(ValueError) as error: 1249 quantized_converter.inference_type = dtypes.float32 1250 quantized_converter.inference_input_type = quantized_type 1251 quantized_converter.convert() 1252 self.assertEqual( 1253 'The `quantized_input_stats` flag must be defined when either ' 1254 '`inference_type` flag or `inference_input_type` flag is set to ' 1255 'tf.int8 or tf.uint8. Currently, `inference_type=tf.float32` and ' 1256 '`inference_input_type=tf.{}`.'.format(quantized_type.name), 1257 str(error.exception)) 1258 1259 quantized_converter.inference_type = quantized_type 1260 quantized_converter.inference_input_type = quantized_type 1261 1262 input_arrays = quantized_converter.get_input_arrays() 1263 quantized_converter.quantized_input_stats = { 1264 input_arrays[0]: (0., 1.) 1265 } 1266 quantized_converter.convert() 1267 1268 def testInvalidQuantizeQATModelMissingInputStats(self): 1269 with ops.Graph().as_default(): 1270 in_tensor_1 = array_ops.placeholder( 1271 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') 1272 in_tensor_2 = array_ops.placeholder( 1273 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') 1274 out_tensor = array_ops.fake_quant_with_min_max_args( 1275 in_tensor_1 + in_tensor_2, min=0., max=1., name='output') 1276 sess = session.Session() 1277 1278 # Convert model and ensure model is not None. 
1279 converter = lite.TFLiteConverter.from_session(sess, 1280 [in_tensor_1, in_tensor_2], 1281 [out_tensor]) 1282 converter.inference_type = dtypes.uint8 1283 converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev 1284 with self.assertRaises(ValueError) as error: 1285 converter.convert() 1286 self.assertEqual( 1287 'Quantization input stats are not available for input tensors ' 1288 '\'inputB\'.', str(error.exception)) 1289 1290 def testTrainingTimeAndPostTrainingCalibrateAndQuantize(self): 1291 with ops.Graph().as_default(): 1292 inp, output, calibration_gen = self._getIntegerQuantizeModel() 1293 sess = session.Session() 1294 1295 # Convert float model. 1296 float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 1297 float_tflite_model = float_converter.convert() 1298 self.assertIsNotNone(float_tflite_model) 1299 1300 converter = lite.TFLiteConverter.from_session(sess, [inp], [output]) 1301 1302 # extra flags to trigger training time quantization conversion 1303 converter.inference_type = dtypes.int8 1304 converter.inference_input_type = dtypes.float32 1305 converter.inference_output_type = dtypes.float32 1306 input_arrays = converter.get_input_arrays() 1307 converter.quantized_input_stats = { 1308 input_arrays[0]: (0., 1.) 
1309 } 1310 # trigger post-training quantization 1311 converter.optimizations = [lite.Optimize.DEFAULT] 1312 converter.representative_dataset = calibration_gen 1313 converter.experimental_new_quantizer = True 1314 quantized_tflite_model = converter.convert() 1315 self.assertIsNotNone(quantized_tflite_model) 1316 self.assertLess(len(quantized_tflite_model), len(float_tflite_model)) 1317 1318 # calibration only api 1319 converter._experimental_calibrate_only = True 1320 calibrated_tflite = converter.convert() 1321 quantized_tflite_model = mlir_quantize( 1322 calibrated_tflite, fully_quantize=True) 1323 interpreter = Interpreter(model_content=quantized_tflite_model) 1324 interpreter.allocate_tensors() 1325 input_details = interpreter.get_input_details() 1326 self.assertEqual(np.int8, input_details[0]['dtype']) 1327 self.assertEqual((1., 0.), input_details[0]['quantization']) 1328 1329 output_details = interpreter.get_output_details() 1330 self.assertEqual(np.int8, output_details[0]['dtype']) 1331 1332 def testFloatTocoConverter(self): 1333 """Tests deprecated test TocoConverter.""" 1334 with ops.Graph().as_default(): 1335 in_tensor = array_ops.placeholder( 1336 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1337 out_tensor = in_tensor + in_tensor 1338 sess = session.Session() 1339 1340 # Convert model and ensure model is not None. 1341 converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) 1342 tflite_model = converter.convert() 1343 self.assertIsNotNone(tflite_model) 1344 1345 # Ensure the interpreter is able to load. 
1346 interpreter = Interpreter(model_content=tflite_model) 1347 interpreter.allocate_tensors() 1348 1349 def testMultipleOutputNodeNames(self): 1350 """Tests converting a graph with an op that have multiple outputs.""" 1351 with ops.Graph().as_default(): 1352 input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32) 1353 out0, out1, out2, out3 = array_ops.split( 1354 input_tensor, [1, 1, 1, 1], axis=0) 1355 sess = session.Session() 1356 1357 # Convert model and ensure model is not None. 1358 converter = lite.TFLiteConverter.from_session(sess, [input_tensor], 1359 [out0, out1, out2, out3]) 1360 tflite_model = converter.convert() 1361 self.assertIsNotNone(tflite_model) 1362 1363 # Check values from converted model. 1364 interpreter = Interpreter(model_content=tflite_model) 1365 interpreter.allocate_tensors() 1366 1367 input_details = interpreter.get_input_details() 1368 self.assertLen(input_details, 1) 1369 interpreter.set_tensor(input_details[0]['index'], 1370 np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) 1371 interpreter.invoke() 1372 1373 output_details = interpreter.get_output_details() 1374 self.assertLen(output_details, 4) 1375 self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index'])) 1376 self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index'])) 1377 self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index'])) 1378 self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index'])) 1379 1380 @parameterized.named_parameters( 1381 ('EnableMlirConverter', True), # enable mlir 1382 ('DisableMlirConverter', False)) # disable mlir 1383 @test_util.run_in_graph_and_eager_modes 1384 def testFunctions(self, enable_mlir_converter): 1385 """Tests tf.function in 1.X.""" 1386 1387 @def_function.function 1388 def plus_placeholder(x, placeholder): 1389 return x + placeholder 1390 1391 with ops.Graph().as_default(): 1392 placeholder = array_ops.placeholder( 1393 dtype=dtypes.float32, shape=[1], name='input') 
1394 variable_node = variables.Variable(1.0, name='variable_node') 1395 defun_node = plus_placeholder(variable_node, placeholder) 1396 output_node = math_ops.multiply(defun_node, 2.0, name='output_node') 1397 1398 # Initialize variables in the model. 1399 sess = session.Session() 1400 sess.run(variables.variables_initializer([variable_node])) 1401 1402 # Convert model and ensure model is not None. 1403 converter = lite.TFLiteConverter.from_session(sess, [placeholder], 1404 [output_node]) 1405 converter.experimental_new_converter = enable_mlir_converter 1406 tflite_model = converter.convert() 1407 self.assertIsNotNone(tflite_model) 1408 1409 # Check values from converted model. 1410 interpreter = Interpreter(model_content=tflite_model) 1411 interpreter.allocate_tensors() 1412 1413 input_details = interpreter.get_input_details() 1414 self.assertLen(input_details, 1) 1415 self.assertEqual('input', input_details[0]['name']) 1416 self.assertEqual(np.float32, input_details[0]['dtype']) 1417 self.assertAllEqual([1], input_details[0]['shape']) 1418 self.assertEqual((0., 0.), input_details[0]['quantization']) 1419 1420 output_details = interpreter.get_output_details() 1421 self.assertLen(output_details, 1) 1422 self.assertEqual('output_node', output_details[0]['name']) 1423 self.assertEqual(np.float32, output_details[0]['dtype']) 1424 self.assertAllEqual([1], output_details[0]['shape']) 1425 self.assertEqual((0., 0.), output_details[0]['quantization']) 1426 1427 def testInferenceInputOutputTypeFloatDefault(self): 1428 with ops.Graph().as_default(): 1429 in_tensor = array_ops.placeholder( 1430 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1431 out_tensor = in_tensor + in_tensor 1432 sess = session.Session() 1433 1434 # Convert model and ensure model is not None. 1435 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 1436 [out_tensor]) 1437 tflite_model = converter.convert() 1438 self.assertIsNotNone(tflite_model) 1439 1440 # Check values from converted model. 
1441 interpreter = Interpreter(model_content=tflite_model) 1442 interpreter.allocate_tensors() 1443 1444 input_details = interpreter.get_input_details() 1445 self.assertLen(input_details, 1) 1446 self.assertEqual('Placeholder', input_details[0]['name']) 1447 self.assertEqual(np.float32, input_details[0]['dtype']) 1448 self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape']) 1449 1450 output_details = interpreter.get_output_details() 1451 self.assertLen(output_details, 1) 1452 self.assertEqual('add', output_details[0]['name']) 1453 self.assertEqual(np.float32, output_details[0]['dtype']) 1454 self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape']) 1455 1456 def testInferenceInputOutputTypeQuantizedUint8Default(self): 1457 with ops.Graph().as_default(): 1458 in_tensor = array_ops.placeholder( 1459 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1460 out_tensor = array_ops.fake_quant_with_min_max_args( 1461 in_tensor + in_tensor, min=0., max=1., name='output') 1462 sess = session.Session() 1463 1464 # Convert model and ensure model is not None. 1465 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 1466 [out_tensor]) 1467 converter.inference_type = dtypes.uint8 1468 converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev 1469 tflite_model = converter.convert() 1470 self.assertIsNotNone(tflite_model) 1471 1472 # Check values from converted model. 
1473 interpreter = Interpreter(model_content=tflite_model) 1474 interpreter.allocate_tensors() 1475 1476 input_details = interpreter.get_input_details() 1477 self.assertLen(input_details, 1) 1478 self.assertEqual('Placeholder', input_details[0]['name']) 1479 self.assertEqual(np.uint8, input_details[0]['dtype']) 1480 self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape']) 1481 1482 output_details = interpreter.get_output_details() 1483 self.assertLen(output_details, 1) 1484 self.assertEqual('output', output_details[0]['name']) 1485 self.assertEqual(np.uint8, output_details[0]['dtype']) 1486 self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape']) 1487 1488 def testReusingConverterWithDifferentPostTrainingQuantization(self): 1489 with ops.Graph().as_default(): 1490 in_tensor = array_ops.placeholder( 1491 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1492 out_tensor = array_ops.fake_quant_with_min_max_args( 1493 in_tensor + in_tensor, min=0., max=1., name='output') 1494 sess = session.Session() 1495 1496 # Convert model and ensure model is not None. 1497 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 1498 [out_tensor]) 1499 1500 converter.post_training_quantize = True 1501 tflite_model = converter.convert() 1502 self.assertIsNotNone(tflite_model) 1503 1504 converter.post_training_quantize = False 1505 tflite_model = converter.convert() 1506 self.assertIsNotNone(tflite_model) 1507 1508 def testResizeWithShape(self): 1509 with ops.Graph().as_default(): 1510 # Construct a graph with a dynamically shapped input and an internal node 1511 # that relies on the output of that input's shape. 
1512 in_tensor = array_ops.placeholder( 1513 shape=[None, None], dtype=dtypes.float32) 1514 in_tensor2 = [[1, 2], [3, 4]] 1515 out_tensor = array_ops.reshape(in_tensor2, array_ops.shape(in_tensor)) 1516 sess = session.Session() 1517 1518 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 1519 [out_tensor]) 1520 tflite_model = converter.convert() 1521 1522 # Check values from converted model. 1523 interpreter = Interpreter(model_content=tflite_model) 1524 input_details = interpreter.get_input_details() 1525 self.assertLen(input_details, 1) 1526 self.assertAllEqual([1, 1], input_details[0]['shape']) 1527 self.assertAllEqual([-1, -1], input_details[0]['shape_signature']) 1528 1529 # Resize tensor and invoke. 1530 interpreter.resize_tensor_input(0, [4]) 1531 interpreter.allocate_tensors() 1532 interpreter.invoke() 1533 1534 # The output should be reshaped properly according to the resized input. 1535 output_details = interpreter.get_output_details() 1536 self.assertLen(output_details, 1) 1537 self.assertEqual(np.int32, output_details[0]['dtype']) 1538 self.assertAllEqual([4], output_details[0]['shape']) 1539 output_data = interpreter.get_tensor(output_details[0]['index']) 1540 self.assertAllEqual([1, 2, 3, 4], output_data) 1541 1542 def testResizingIntermediateDynamicTensor(self): 1543 # This is a regression test for the case where shape of dynamic output 1544 # tensors changes between invocations. 1545 # See also https://github.com/tensorflow/tensorflow/issues/26549 1546 with ops.Graph().as_default(): 1547 input_tensor = array_ops.placeholder(shape=[1, 1], dtype=dtypes.float32) 1548 input2_tensor = array_ops.placeholder(shape=[1], dtype=dtypes.float32) 1549 1550 # The bug is triggered only when dynamic tensor is intermediate. Putting 1551 # some other ops around it. 
1552 neg = math_ops.negative(input2_tensor) 1553 padding = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int32) 1554 output_tensor = array_ops.pad(input_tensor, padding) + neg 1555 1556 sess = session.Session() 1557 1558 converter = lite.TFLiteConverter.from_session( 1559 sess, [input_tensor, padding, input2_tensor], [output_tensor]) 1560 tflite_model = converter.convert() 1561 1562 interpreter = Interpreter(model_content=tflite_model) 1563 interpreter.allocate_tensors() 1564 1565 input_details = interpreter.get_input_details() 1566 interpreter.set_tensor(input_details[1]['index'], 1567 np.array([[1, 1], [1, 1]], dtype=np.int32)) 1568 interpreter.invoke() 1569 1570 # Without the fix, invocation will fail when changing the shape of 1571 # intermediate dynamic tensors. 1572 interpreter.set_tensor(input_details[1]['index'], 1573 np.array([[2, 2], [2, 2]], dtype=np.int32)) 1574 interpreter.invoke() 1575 1576 def testGraphDebugInfo(self): 1577 """Test a session has debug info captured.""" 1578 1579 @def_function.function 1580 def plus_placeholder(x, placeholder): 1581 return x + placeholder 1582 1583 with ops.Graph().as_default(): 1584 placeholder = array_ops.placeholder( 1585 dtype=dtypes.float32, shape=[1], name='input') 1586 variable_node = variables.Variable(1.0, name='variable_node') 1587 defun_node = plus_placeholder(variable_node, placeholder) 1588 output_node = math_ops.multiply(defun_node, 2.0, name='output_node') 1589 1590 # Initialize variables in the model. 1591 sess = session.Session() 1592 sess.run(variables.variables_initializer([variable_node])) 1593 1594 converter = lite.TFLiteConverter.from_session(sess, [placeholder], 1595 [output_node]) 1596 converter.convert() 1597 self.assertValidDebugInfo(converter._debug_info) 1598 1599 # Check the add node in the inlined function is included. 
1600 func = sess.graph.as_graph_def().library.function[0].signature.name 1601 self.assertIn(('add@' + six.ensure_str(func)), converter._debug_info.traces) 1602 1603 def testOutputOnlyModel(self): 1604 with ops.Graph().as_default(): 1605 out_tensor = random_ops.random_normal(shape=[3]) 1606 sess = session.Session() 1607 1608 # Convert model and ensure model is not None. 1609 converter = lite.TFLiteConverter.from_session(sess, [], [out_tensor]) 1610 converter.target_spec.supported_ops = [ 1611 lite.OpsSet.TFLITE_BUILTINS, 1612 lite.OpsSet.SELECT_TF_OPS, 1613 ] 1614 1615 # Empty input array is a valid input. 1616 self.assertTrue(converter._has_valid_tensors()) 1617 1618 tflite_model = converter.convert() 1619 self.assertIsNotNone(tflite_model) 1620 1621 1622class FromFrozenGraphFile(LiteTest): 1623 1624 def testFloat(self): 1625 with ops.Graph().as_default(): 1626 in_tensor = array_ops.placeholder( 1627 shape=[1, 16, 16, 3], dtype=dtypes.float32) 1628 _ = in_tensor + in_tensor 1629 sess = session.Session() 1630 1631 # Write graph to file. 1632 graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb') 1633 write_graph(sess.graph_def, '', graph_def_file, False) 1634 sess.close() 1635 1636 # Convert model and ensure model is not None. 1637 converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file, 1638 ['Placeholder'], ['add']) 1639 tflite_model = converter.convert() 1640 self.assertIsNotNone(tflite_model) 1641 1642 # Check values from converted model. 
# (testFloat, continued.) Inspect the converted model's input/output tensors.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testFloatWithShapesArray(self):
    """Test a shape overriding case."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None. The None batch dimension is
    # overridden to 2 via input_shapes.
    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, ['Placeholder'], ['add'],
        input_shapes={'Placeholder': [2, 16, 16, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])

  def testInvalidShapesArray(self):
    """Test an invalid shape overriding case, which has a wrong input name."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # An input_shapes key that names no input array must raise ValueError.
    with self.assertRaises(ValueError):
      lite.TFLiteConverter.from_frozen_graph(
          graph_def_file, ['Placeholder'], ['add'],
          input_shapes={'wrong_input': [2, 16, 16, 3]})

  def testPartialShapesArray(self):
    """Test a shape override given for only one of the model's two inputs."""
    with ops.Graph().as_default():
      a = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='a')
      b = array_ops.placeholder(
          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='b')
      _ = math_ops.add(a, b, name='add')
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_frozen_graph(
        graph_def_file, ['a', 'b'], ['add'], input_shapes={'a': [2, 16, 16, 3]})
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model: 'a' takes the overridden shape, while
    # 'b' (not overridden) shows its None batch dimension as 1.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])

  def testFreezeGraph(self):
    """A graph that still contains variables is rejected as not frozen."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      var = variable_scope.get_variable(
          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + var
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Ensure the graph with variables cannot be converted.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
                                             ['add'])
    self.assertEqual('Please freeze the graph using freeze_graph.py.',
                     str(error.exception))

  def testPbtxt(self):
    """Same as testFloat, but the graph is written in text (.pbtxt) format."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file (as_text=True).
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
    write_graph(sess.graph_def, '', graph_def_file, True)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
                                                       ['Placeholder'], ['add'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 1)
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testInvalidFileNotFound(self):
    """A missing graph file raises IOError naming the file."""
    with self.assertRaises(IOError) as error:
      lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'],
                                             ['add'])
    self.assertEqual('File \'invalid_file\' does not exist.',
                     str(error.exception))

  def testInvalidFileBadData(self):
    """A file that is not a parseable GraphDef raises IOError."""
    graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
    with gfile.Open(graph_def_file, 'wb') as temp_file:
      temp_file.write('bad data')
      temp_file.flush()

    # Attempts to convert the invalid model.
    with self.assertRaises(IOError) as error:
      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
                                             ['add'])
    self.assertEqual(
        'Unable to parse input file \'{}\'.'.format(graph_def_file),
        str(error.exception))

  def testFloatTocoConverter(self):
    """The deprecated TocoConverter alias still converts a frozen graph."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
                                                     ['Placeholder'], ['add'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure the model is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

  def testGraphDebugInfo(self):
    """Test a frozen graph doesn't have debug info captured."""
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
                                                     ['Placeholder'], ['add'])
    converter.convert()
    # GraphDebugInfo should be none for frozen graph.
    self.assertFalse(converter._debug_info)


class FromFrozenGraphObjectDetection(LiteTest):

  def _initObjectDetectionArgs(self):
    # Initializes the arguments required for the object detection model.
    # Looks for the model file which is saved in a different location internally
    # and externally.
1864 filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb') 1865 if not os.path.exists(filename): 1866 filename = os.path.join( 1867 resource_loader.get_root_dir_with_all_resources(), 1868 '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb') 1869 if not os.path.exists(filename): 1870 raise IOError("File '{0}' does not exist.".format(filename)) 1871 1872 self._graph_def_file = filename 1873 self._input_arrays = ['normalized_input_image_tensor'] 1874 self._output_arrays = [ 1875 'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1', 1876 'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3' 1877 ] 1878 self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]} 1879 1880 def testTFLiteGraphDef(self): 1881 # Tests the object detection model that cannot be loaded in TensorFlow. 1882 self._initObjectDetectionArgs() 1883 1884 converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file, 1885 self._input_arrays, 1886 self._output_arrays, 1887 self._input_shapes) 1888 converter.allow_custom_ops = True 1889 tflite_model = converter.convert() 1890 self.assertIsNotNone(tflite_model) 1891 1892 # Check values from converted model. 
1893 interpreter = Interpreter(model_content=tflite_model) 1894 interpreter.allocate_tensors() 1895 1896 input_details = interpreter.get_input_details() 1897 self.assertLen(input_details, 1) 1898 self.assertEqual('normalized_input_image_tensor', input_details[0]['name']) 1899 self.assertEqual(np.float32, input_details[0]['dtype']) 1900 self.assertAllEqual([1, 300, 300, 3], input_details[0]['shape']) 1901 self.assertEqual((0., 0.), input_details[0]['quantization']) 1902 1903 output_details = interpreter.get_output_details() 1904 self.assertLen(output_details, 4) 1905 self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name']) 1906 self.assertEqual(np.float32, output_details[0]['dtype']) 1907 self.assertAllEqual([1, 10, 4], output_details[0]['shape']) 1908 self.assertEqual((0., 0.), output_details[0]['quantization']) 1909 1910 self.assertEqual('TFLite_Detection_PostProcess:1', 1911 output_details[1]['name']) 1912 self.assertAllEqual([1, 10], output_details[1]['shape']) 1913 self.assertEqual('TFLite_Detection_PostProcess:2', 1914 output_details[2]['name']) 1915 self.assertAllEqual([1, 10], output_details[2]['shape']) 1916 self.assertEqual('TFLite_Detection_PostProcess:3', 1917 output_details[3]['name']) 1918 self.assertAllEqual([1], output_details[3]['shape']) 1919 1920 1921class FromSavedModelTest(TestModels): 1922 1923 def _createSavedModel(self, shape): 1924 """Create a simple SavedModel.""" 1925 saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel') 1926 with ops.Graph().as_default(): 1927 with session.Session() as sess: 1928 in_tensor_1 = array_ops.placeholder( 1929 shape=shape, dtype=dtypes.float32, name='inputB') 1930 in_tensor_2 = array_ops.placeholder( 1931 shape=shape, dtype=dtypes.float32, name='inputA') 1932 out_tensor = in_tensor_1 + in_tensor_2 1933 inputs = {'x': in_tensor_1, 'y': in_tensor_2} 1934 outputs = {'z': out_tensor} 1935 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 1936 return 
saved_model_dir 1937 1938 def testSimpleModel(self): 1939 """Test a SavedModel.""" 1940 saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) 1941 1942 # Convert model and ensure model is not None. 1943 converter = lite.TFLiteConverter.from_saved_model(saved_model_dir) 1944 tflite_model = converter.convert() 1945 self.assertIsNotNone(tflite_model) 1946 1947 interpreter = Interpreter(model_content=tflite_model) 1948 interpreter.allocate_tensors() 1949 1950 input_details = interpreter.get_input_details() 1951 self.assertLen(input_details, 2) 1952 self.assertStartsWith(input_details[0]['name'], 'inputA') 1953 self.assertEqual(np.float32, input_details[0]['dtype']) 1954 self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape']) 1955 self.assertEqual((0., 0.), input_details[0]['quantization']) 1956 1957 self.assertStartsWith(input_details[1]['name'], 'inputB') 1958 self.assertEqual(np.float32, input_details[1]['dtype']) 1959 self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape']) 1960 self.assertEqual((0., 0.), input_details[1]['quantization']) 1961 1962 output_details = interpreter.get_output_details() 1963 self.assertLen(output_details, 1) 1964 self.assertStartsWith(output_details[0]['name'], 'add') 1965 self.assertEqual(np.float32, output_details[0]['dtype']) 1966 self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape']) 1967 self.assertEqual((0., 0.), output_details[0]['quantization']) 1968 1969 def testOldConverterWarning(self): 1970 """Test if the warning message when using TOCO is logged.""" 1971 saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) 1972 log = io.BytesIO() if six.PY2 else io.StringIO() 1973 handler = logging.StreamHandler(log) 1974 logging.root.addHandler(handler) 1975 warning_message = 'Please consider switching to the new converter' 1976 # Convert model and ensure model is not None. 
1977 converter = lite.TFLiteConverter.from_saved_model(saved_model_dir) 1978 converter.experimental_new_converter = False 1979 tflite_model = converter.convert() 1980 self.assertIsNotNone(tflite_model) 1981 self.assertIn(warning_message, log.getvalue()) 1982 logging.root.removeHandler(handler) 1983 1984 def testNewConverterOptOut(self): 1985 """Test if the opt out message when using New converter is logged.""" 1986 saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) 1987 log = io.BytesIO() if six.PY2 else io.StringIO() 1988 handler = logging.StreamHandler(log) 1989 logging.root.addHandler(handler) 1990 optout_message = ('Using experimental converter: ' 1991 'If you encountered a problem') 1992 # Convert model and ensure model is not None. 1993 converter = lite.TFLiteConverter.from_saved_model(saved_model_dir) 1994 tflite_model = converter.convert() 1995 self.assertIsNotNone(tflite_model) 1996 self.assertIn(optout_message, log.getvalue()) 1997 logging.root.removeHandler(handler) 1998 1999 def testNoneBatchSize(self): 2000 """Test a SavedModel, with None in input tensor's shape.""" 2001 saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3]) 2002 2003 converter = lite.TFLiteConverter.from_saved_model(saved_model_dir) 2004 tflite_model = converter.convert() 2005 self.assertIsNotNone(tflite_model) 2006 2007 # Check values from converted model. 
# (testNoneBatchSize, continued.) The None batch dimension converts to 1.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testOrderInputArrays(self):
    """Test a SavedModel ordering of input arrays.

    input_arrays is given as ['inputB', 'inputA'], yet the converted model is
    expected to list 'inputA' first.
    """
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir, input_arrays=['inputB', 'inputA'])
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testShapeOverriding(self):
    """Test a SavedModel with the input_shapes argument."""
    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])

    # Convert model and ensure model is not None. Both None batch dimensions
    # are overridden to 2.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir,
        input_shapes={
            'inputA': [2, 16, 16, 3],
            'inputB': [2, 16, 16, 3]
        })
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertLen(input_details, 2)
    self.assertStartsWith(input_details[0]['name'], 'inputA')
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    self.assertStartsWith(input_details[1]['name'], 'inputB')
    self.assertEqual(np.float32, input_details[1]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], input_details[1]['shape'])
    self.assertEqual((0., 0.), input_details[1]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertLen(output_details, 1)
    self.assertStartsWith(output_details[0]['name'], 'add')
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testWrongInputShapes(self):
    """Test a SavedModel with a wrong name in the input_shapes argument."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # A shape keyed by a name not in input_arrays must raise ValueError.
    with self.assertRaises(ValueError):
      lite.TFLiteConverter.from_saved_model(
          saved_model_dir,
          input_arrays=['inputA'],
          input_shapes={'wrong_input': [1, 16, 16, 3]})

  def testSubsetInputShaapes(self):
    """Test a SavedModel with a subset of the input array names of the model."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Check case where input shape is given.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir,
        input_arrays=['inputA'],
        input_shapes={'inputA': [1, 16, 16, 3]})

    # Since we only partially specify the input, this is not allowed.
    with self.assertRaises(ConverterError):
      _ = converter.convert()

    # Check case where input shape is None.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})

    # Since we only partially specify the input, this is not allowed.
    with self.assertRaises(ConverterError):
      _ = converter.convert()

  def testSimpleModelTocoConverter(self):
    """Test a SavedModel with deprecated TocoConverter."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertIsNotNone(tflite_model)

    # Ensure the model is able to load.
2141 interpreter = Interpreter(model_content=tflite_model) 2142 interpreter.allocate_tensors() 2143 2144 def testGraphDebugInfo(self): 2145 """Test a SavedModel has debug info captured.""" 2146 saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3]) 2147 converter = lite.TFLiteConverter.from_saved_model(saved_model_dir) 2148 converter.convert() 2149 self.assertValidDebugInfo(converter._debug_info) 2150 2151 2152class MyAddLayer(keras.layers.Layer): 2153 2154 def __init__(self, increment, **kwargs): 2155 super(MyAddLayer, self).__init__(**kwargs) 2156 self._increment = increment 2157 2158 def call(self, inputs): 2159 return inputs + self._increment 2160 2161 def get_config(self): 2162 config = super(MyAddLayer, self).get_config() 2163 config['increment'] = self._increment 2164 return config 2165 2166 2167class FromKerasFile(TestModels, parameterized.TestCase): 2168 2169 def setUp(self): 2170 super(FromKerasFile, self).setUp() 2171 self._keras_file = None 2172 self._custom_objects = None 2173 if not context.executing_eagerly(): 2174 keras.backend.clear_session() 2175 2176 def tearDown(self): 2177 if self._keras_file: 2178 os.remove(self._keras_file) 2179 super(FromKerasFile, self).tearDown() 2180 2181 def _getSequentialModel(self, include_custom_layer=False): 2182 model = keras.models.Sequential() 2183 model.add(keras.layers.Dense(2, input_shape=(3,))) 2184 if include_custom_layer: 2185 model.add(MyAddLayer(1.0)) 2186 model.add(keras.layers.RepeatVector(3)) 2187 model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) 2188 model.compile( 2189 loss=keras.losses.MSE, 2190 optimizer='sgd', 2191 metrics=[keras.metrics.categorical_accuracy], 2192 sample_weight_mode='temporal') 2193 x = np.random.random((1, 3)) 2194 y = np.random.random((1, 3, 3)) 2195 model.train_on_batch(x, y) 2196 model.predict(x) 2197 2198 try: 2199 fd, self._keras_file = tempfile.mkstemp('.h5') 2200 keras.models.save_model(model, self._keras_file) 2201 finally: 2202 os.close(fd) 2203 2204 
if include_custom_layer: 2205 self._custom_objects = {'MyAddLayer': MyAddLayer} 2206 2207 @parameterized.named_parameters(('_graph', context.graph_mode), 2208 ('_eager', context.eager_mode)) 2209 def testSequentialModel(self, test_context): 2210 """Test a Sequential tf.keras model with default inputs.""" 2211 with test_context(): 2212 self._getSequentialModel() 2213 2214 converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file) 2215 tflite_model = converter.convert() 2216 self.assertIsNotNone(tflite_model) 2217 2218 # Check tensor details of converted model. 2219 interpreter = Interpreter(model_content=tflite_model) 2220 interpreter.allocate_tensors() 2221 2222 input_details = interpreter.get_input_details() 2223 self.assertLen(input_details, 1) 2224 self.assertEndsWith(input_details[0]['name'], 'dense_input') 2225 self.assertEqual(np.float32, input_details[0]['dtype']) 2226 self.assertAllEqual([1, 3], input_details[0]['shape']) 2227 self.assertEqual((0., 0.), input_details[0]['quantization']) 2228 2229 output_details = interpreter.get_output_details() 2230 self.assertLen(output_details, 1) 2231 self.assertEqual(np.float32, output_details[0]['dtype']) 2232 self.assertAllEqual([1, 3, 3], output_details[0]['shape']) 2233 self.assertEqual((0., 0.), output_details[0]['quantization']) 2234 2235 # Check inference of converted model. 
2236 input_data = np.array([[1, 2, 3]], dtype=np.float32) 2237 interpreter.set_tensor(input_details[0]['index'], input_data) 2238 interpreter.invoke() 2239 tflite_result = interpreter.get_tensor(output_details[0]['index']) 2240 2241 keras_model = keras.models.load_model(self._keras_file) 2242 keras_result = keras_model.predict(input_data) 2243 2244 np.testing.assert_almost_equal(tflite_result, keras_result, 5) 2245 2246 @parameterized.named_parameters(('_graph', context.graph_mode), 2247 ('_eager', context.eager_mode)) 2248 def testCustomLayer(self, test_context): 2249 """Test a Sequential tf.keras model with default inputs.""" 2250 with test_context(): 2251 self._getSequentialModel(include_custom_layer=True) 2252 2253 converter = lite.TFLiteConverter.from_keras_model_file( 2254 self._keras_file, custom_objects=self._custom_objects) 2255 tflite_model = converter.convert() 2256 self.assertIsNotNone(tflite_model) 2257 2258 # Check tensor details of converted model. 2259 interpreter = Interpreter(model_content=tflite_model) 2260 interpreter.allocate_tensors() 2261 2262 input_details = interpreter.get_input_details() 2263 output_details = interpreter.get_output_details() 2264 2265 # Check inference of converted model. 2266 input_data = np.array([[1, 2, 3]], dtype=np.float32) 2267 interpreter.set_tensor(input_details[0]['index'], input_data) 2268 interpreter.invoke() 2269 tflite_result = interpreter.get_tensor(output_details[0]['index']) 2270 2271 keras_model = keras.models.load_model( 2272 self._keras_file, custom_objects=self._custom_objects) 2273 keras_result = keras_model.predict(input_data) 2274 2275 np.testing.assert_almost_equal(tflite_result, keras_result, 5) 2276 2277 def testSequentialModelInputArray(self): 2278 """Test a Sequential tf.keras model testing input arrays argument.""" 2279 ops.disable_eager_execution() 2280 self._getSequentialModel() 2281 2282 # Invalid input array raises error. 
2283 with self.assertRaises(ValueError) as error: 2284 lite.TFLiteConverter.from_keras_model_file( 2285 self._keras_file, input_arrays=['invalid-input']) 2286 self.assertEqual("Invalid tensors 'invalid-input' were found.", 2287 str(error.exception)) 2288 2289 # Valid input array. 2290 converter = lite.TFLiteConverter.from_keras_model_file( 2291 self._keras_file, input_arrays=['dense_input']) 2292 tflite_model = converter.convert() 2293 self.assertIsNotNone(tflite_model) 2294 2295 def testSequentialModelInputShape(self): 2296 """Test a Sequential tf.keras model testing input shapes argument.""" 2297 self._getSequentialModel() 2298 2299 # Passing in shape of invalid input array raises error. 2300 with self.assertRaises(ValueError) as error: 2301 converter = lite.TFLiteConverter.from_keras_model_file( 2302 self._keras_file, input_shapes={'invalid-input': [2, 3]}) 2303 self.assertEqual( 2304 "Invalid tensor 'invalid-input' found in tensor shapes map.", 2305 str(error.exception)) 2306 2307 # Passing in shape of valid input array. 2308 converter = lite.TFLiteConverter.from_keras_model_file( 2309 self._keras_file, input_shapes={'dense_input': [2, 3]}) 2310 tflite_model = converter.convert() 2311 self.assertIsNotNone(tflite_model) 2312 2313 # Check input shape from converted model. 2314 interpreter = Interpreter(model_content=tflite_model) 2315 interpreter.allocate_tensors() 2316 2317 input_details = interpreter.get_input_details() 2318 self.assertLen(input_details, 1) 2319 self.assertEndsWith(input_details[0]['name'], 'dense_input') 2320 self.assertAllEqual([2, 3], input_details[0]['shape']) 2321 2322 def testSequentialModelOutputArray(self): 2323 """Test a Sequential tf.keras model testing output arrays argument.""" 2324 ops.disable_eager_execution() 2325 self._getSequentialModel() 2326 2327 # Invalid output array raises error. 
2328 with self.assertRaises(ValueError) as error: 2329 lite.TFLiteConverter.from_keras_model_file( 2330 self._keras_file, output_arrays=['invalid-output']) 2331 self.assertEqual("Invalid tensors 'invalid-output' were found.", 2332 str(error.exception)) 2333 2334 # Valid output array. 2335 converter = lite.TFLiteConverter.from_keras_model_file( 2336 self._keras_file, output_arrays=['time_distributed/Reshape_1']) 2337 tflite_model = converter.convert() 2338 self.assertIsNotNone(tflite_model) 2339 2340 @parameterized.named_parameters(('_graph', context.graph_mode), 2341 ('_eager', context.eager_mode)) 2342 def testFunctionalModel(self, test_context): 2343 """Test a Functional tf.keras model with default inputs.""" 2344 with test_context(): 2345 inputs = keras.layers.Input(shape=(3,), name='input') 2346 x = keras.layers.Dense(2)(inputs) 2347 output = keras.layers.Dense(3)(x) 2348 2349 model = keras.models.Model(inputs, output) 2350 model.compile( 2351 loss=keras.losses.MSE, 2352 optimizer='sgd', 2353 metrics=[keras.metrics.categorical_accuracy]) 2354 x = np.random.random((1, 3)) 2355 y = np.random.random((1, 3)) 2356 model.train_on_batch(x, y) 2357 2358 model.predict(x) 2359 fd, self._keras_file = tempfile.mkstemp('.h5') 2360 try: 2361 keras.models.save_model(model, self._keras_file) 2362 finally: 2363 os.close(fd) 2364 2365 # Convert to TFLite model. 2366 converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file) 2367 tflite_model = converter.convert() 2368 self.assertIsNotNone(tflite_model) 2369 2370 # Check tensor details of converted model. 
2371 interpreter = Interpreter(model_content=tflite_model) 2372 interpreter.allocate_tensors() 2373 2374 input_details = interpreter.get_input_details() 2375 self.assertLen(input_details, 1) 2376 self.assertEqual('input', input_details[0]['name']) 2377 self.assertEqual(np.float32, input_details[0]['dtype']) 2378 self.assertAllEqual([1, 3], input_details[0]['shape']) 2379 self.assertEqual((0., 0.), input_details[0]['quantization']) 2380 2381 output_details = interpreter.get_output_details() 2382 self.assertLen(output_details, 1) 2383 self.assertEqual(np.float32, output_details[0]['dtype']) 2384 self.assertAllEqual([1, 3], output_details[0]['shape']) 2385 self.assertEqual((0., 0.), output_details[0]['quantization']) 2386 2387 # Check inference of converted model. 2388 input_data = np.array([[1, 2, 3]], dtype=np.float32) 2389 interpreter.set_tensor(input_details[0]['index'], input_data) 2390 interpreter.invoke() 2391 tflite_result = interpreter.get_tensor(output_details[0]['index']) 2392 2393 keras_model = keras.models.load_model(self._keras_file) 2394 keras_result = keras_model.predict(input_data) 2395 2396 np.testing.assert_almost_equal(tflite_result, keras_result, 5) 2397 2398 def _getFunctionalModelMultipleInputs(self): 2399 a = keras.layers.Input(shape=(3,), name='input_a') 2400 b = keras.layers.Input(shape=(3,), name='input_b') 2401 dense = keras.layers.Dense(4, name='dense') 2402 c = dense(a) 2403 d = dense(b) 2404 e = keras.layers.Dropout(0.5, name='dropout')(c) 2405 2406 model = keras.models.Model([a, b], [d, e]) 2407 model.compile( 2408 loss=keras.losses.MSE, 2409 optimizer='sgd', 2410 metrics=[keras.metrics.mae], 2411 loss_weights=[1., 0.5]) 2412 2413 input_a_np = np.random.random((10, 3)) 2414 input_b_np = np.random.random((10, 3)) 2415 output_d_np = np.random.random((10, 4)) 2416 output_e_np = np.random.random((10, 4)) 2417 model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np]) 2418 2419 model.predict([input_a_np, input_b_np], 
batch_size=5) 2420 fd, self._keras_file = tempfile.mkstemp('.h5') 2421 try: 2422 keras.models.save_model(model, self._keras_file) 2423 finally: 2424 os.close(fd) 2425 2426 def testFunctionalModelMultipleInputs(self): 2427 """Test a Functional tf.keras model with multiple inputs and outputs.""" 2428 self._getFunctionalModelMultipleInputs() 2429 2430 # Convert to TFLite model. 2431 converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file) 2432 tflite_model = converter.convert() 2433 self.assertIsNotNone(tflite_model) 2434 2435 # Check values from converted model. 2436 interpreter = Interpreter(model_content=tflite_model) 2437 interpreter.allocate_tensors() 2438 2439 input_details = interpreter.get_input_details() 2440 self.assertLen(input_details, 2) 2441 self.assertEndsWith(input_details[0]['name'], 'input_a') 2442 self.assertEqual(np.float32, input_details[0]['dtype']) 2443 self.assertAllEqual([1, 3], input_details[0]['shape']) 2444 self.assertEqual((0., 0.), input_details[0]['quantization']) 2445 2446 self.assertEndsWith(input_details[1]['name'], 'input_b') 2447 self.assertEqual(np.float32, input_details[1]['dtype']) 2448 self.assertAllEqual([1, 3], input_details[1]['shape']) 2449 self.assertEqual((0., 0.), input_details[1]['quantization']) 2450 2451 output_details = interpreter.get_output_details() 2452 self.assertLen(output_details, 2) 2453 self.assertEqual(np.float32, output_details[0]['dtype']) 2454 self.assertAllEqual([1, 4], output_details[0]['shape']) 2455 self.assertEqual((0., 0.), output_details[0]['quantization']) 2456 2457 self.assertEqual(np.float32, output_details[1]['dtype']) 2458 self.assertAllEqual([1, 4], output_details[1]['shape']) 2459 self.assertEqual((0., 0.), output_details[1]['quantization']) 2460 2461 def testShapeOverriding(self): 2462 """Test a Functional tf.keras model with input shape overriding.""" 2463 self._getFunctionalModelMultipleInputs() 2464 2465 # Convert to TFLite model. 
2466 converter = lite.TFLiteConverter.from_keras_model_file( 2467 self._keras_file, input_shapes={ 2468 'input_a': {2, 3}, 2469 'input_b': {2, 3} 2470 }) 2471 tflite_model = converter.convert() 2472 self.assertIsNotNone(tflite_model) 2473 2474 # Check values from converted model. 2475 interpreter = Interpreter(model_content=tflite_model) 2476 interpreter.allocate_tensors() 2477 2478 input_details = interpreter.get_input_details() 2479 self.assertLen(input_details, 2) 2480 self.assertEndsWith(input_details[0]['name'], 'input_a') 2481 self.assertEqual(np.float32, input_details[0]['dtype']) 2482 self.assertAllEqual([2, 3], input_details[0]['shape']) 2483 self.assertEqual((0., 0.), input_details[0]['quantization']) 2484 2485 self.assertEndsWith(input_details[1]['name'], 'input_b') 2486 self.assertEqual(np.float32, input_details[1]['dtype']) 2487 self.assertAllEqual([2, 3], input_details[1]['shape']) 2488 self.assertEqual((0., 0.), input_details[1]['quantization']) 2489 2490 output_details = interpreter.get_output_details() 2491 self.assertLen(output_details, 2) 2492 self.assertEqual(np.float32, output_details[0]['dtype']) 2493 self.assertAllEqual([2, 4], output_details[0]['shape']) 2494 self.assertEqual((0., 0.), output_details[0]['quantization']) 2495 2496 self.assertEqual(np.float32, output_details[1]['dtype']) 2497 self.assertAllEqual([2, 4], output_details[1]['shape']) 2498 self.assertEqual((0., 0.), output_details[1]['quantization']) 2499 2500 def testPartialShapeOverriding(self): 2501 """Test a Functional tf.keras model with partial input shape overriding.""" 2502 self._getFunctionalModelMultipleInputs() 2503 2504 # Convert to TFLite model. 2505 converter = lite.TFLiteConverter.from_keras_model_file( 2506 self._keras_file, input_shapes={'input_a': {2, 3}}) 2507 tflite_model = converter.convert() 2508 self.assertIsNotNone(tflite_model) 2509 2510 # Check values from converted model. 
2511 interpreter = Interpreter(model_content=tflite_model) 2512 interpreter.allocate_tensors() 2513 2514 input_details = interpreter.get_input_details() 2515 self.assertLen(input_details, 2) 2516 self.assertEndsWith(input_details[0]['name'], 'input_a') 2517 self.assertEqual(np.float32, input_details[0]['dtype']) 2518 self.assertAllEqual([2, 3], input_details[0]['shape']) 2519 self.assertEqual((0., 0.), input_details[0]['quantization']) 2520 2521 self.assertEndsWith(input_details[1]['name'], 'input_b') 2522 self.assertEqual(np.float32, input_details[1]['dtype']) 2523 self.assertAllEqual([1, 3], input_details[1]['shape']) 2524 self.assertEqual((0., 0.), input_details[1]['quantization']) 2525 2526 output_details = interpreter.get_output_details() 2527 self.assertLen(output_details, 2) 2528 self.assertEqual(np.float32, output_details[0]['dtype']) 2529 self.assertAllEqual([1, 4], output_details[0]['shape']) 2530 self.assertEqual((0., 0.), output_details[0]['quantization']) 2531 2532 self.assertEqual(np.float32, output_details[1]['dtype']) 2533 self.assertAllEqual([2, 4], output_details[1]['shape']) 2534 self.assertEqual((0., 0.), output_details[1]['quantization']) 2535 2536 def testWrongShapeOverriding(self): 2537 """Test a Functional tf.keras model with wrong input shape overriding.""" 2538 self._getFunctionalModelMultipleInputs() 2539 2540 # Convert to TFLite model. 
2541 with self.assertRaises(ValueError): 2542 lite.TFLiteConverter.from_keras_model_file( 2543 self._keras_file, input_shapes={'wrong_input': {2, 3}}) 2544 2545 def testFunctionalSequentialModel(self): 2546 """Test a Functional tf.keras model containing a Sequential model.""" 2547 model = keras.models.Sequential() 2548 model.add(keras.layers.Dense(2, input_shape=(3,))) 2549 model.add(keras.layers.RepeatVector(3)) 2550 model.add(keras.layers.TimeDistributed(keras.layers.Dense(3))) 2551 model = keras.models.Model(model.input, model.output) 2552 2553 model.compile( 2554 loss=keras.losses.MSE, 2555 optimizer='sgd', 2556 metrics=[keras.metrics.categorical_accuracy], 2557 sample_weight_mode='temporal') 2558 x = np.random.random((1, 3)) 2559 y = np.random.random((1, 3, 3)) 2560 model.train_on_batch(x, y) 2561 model.predict(x) 2562 2563 model.predict(x) 2564 fd, self._keras_file = tempfile.mkstemp('.h5') 2565 try: 2566 keras.models.save_model(model, self._keras_file) 2567 finally: 2568 os.close(fd) 2569 2570 # Convert to TFLite model. 2571 converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file) 2572 tflite_model = converter.convert() 2573 self.assertIsNotNone(tflite_model) 2574 2575 # Check tensor details of converted model. 
2576 interpreter = Interpreter(model_content=tflite_model) 2577 interpreter.allocate_tensors() 2578 2579 input_details = interpreter.get_input_details() 2580 self.assertLen(input_details, 1) 2581 self.assertEndsWith(input_details[0]['name'], 'dense_input') 2582 self.assertEqual(np.float32, input_details[0]['dtype']) 2583 self.assertAllEqual([1, 3], input_details[0]['shape']) 2584 self.assertEqual((0., 0.), input_details[0]['quantization']) 2585 2586 output_details = interpreter.get_output_details() 2587 self.assertLen(output_details, 1) 2588 self.assertEqual(np.float32, output_details[0]['dtype']) 2589 self.assertAllEqual([1, 3, 3], output_details[0]['shape']) 2590 self.assertEqual((0., 0.), output_details[0]['quantization']) 2591 2592 # Check inference of converted model. 2593 input_data = np.array([[1, 2, 3]], dtype=np.float32) 2594 interpreter.set_tensor(input_details[0]['index'], input_data) 2595 interpreter.invoke() 2596 tflite_result = interpreter.get_tensor(output_details[0]['index']) 2597 2598 keras_model = keras.models.load_model(self._keras_file) 2599 keras_result = keras_model.predict(input_data) 2600 2601 np.testing.assert_almost_equal(tflite_result, keras_result, 5) 2602 2603 def testSequentialModelTocoConverter(self): 2604 """Test a Sequential tf.keras model with deprecated TocoConverter.""" 2605 self._getSequentialModel() 2606 2607 converter = lite.TocoConverter.from_keras_model_file(self._keras_file) 2608 tflite_model = converter.convert() 2609 self.assertIsNotNone(tflite_model) 2610 2611 # Ensure the model is able to load. 
2612 interpreter = Interpreter(model_content=tflite_model) 2613 interpreter.allocate_tensors() 2614 2615 @parameterized.named_parameters(('_graph', context.graph_mode), 2616 ('_eager', context.eager_mode)) 2617 def testGraphDebugInfo(self, test_context): 2618 """Test a Sequential tf.keras model has debug info captured.""" 2619 with test_context(): 2620 self._getSequentialModel() 2621 converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file) 2622 converter.convert() 2623 self.assertValidDebugInfo(converter._debug_info) 2624 2625 def testSparsifyModel(self): 2626 self._getSequentialModel() 2627 2628 converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file) 2629 converter.optimizations = {lite.Optimize.EXPERIMENTAL_SPARSITY} 2630 tflite_model = converter.convert() 2631 self.assertTrue(tflite_model) 2632 2633 def testSparsifyQuantizedModel(self): 2634 self._getSequentialModel() 2635 2636 converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file) 2637 converter.optimizations = { 2638 lite.Optimize.DEFAULT, lite.Optimize.EXPERIMENTAL_SPARSITY 2639 } 2640 tflite_model = converter.convert() 2641 self.assertIsNotNone(tflite_model) 2642 2643 2644class GrapplerTest(TestModels, parameterized.TestCase): 2645 2646 def testConstantFolding(self): 2647 ops.disable_eager_execution() 2648 # Constant folding handles the tf.broadcast_to operation which was not 2649 # supported by the TFLite at the time this test was added. 2650 with ops.Graph().as_default(): 2651 in_tensor = array_ops.placeholder(shape=[3, 3], dtype=dtypes.float32) 2652 y_const = constant_op.constant([1., 2., 3.]) 2653 y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3]) 2654 out_tensor = math_ops.matmul(in_tensor, y_broadcast, name='output') 2655 sess = session.Session() 2656 2657 # Convert model. 2658 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 2659 [out_tensor]) 2660 tflite_model = converter.convert() 2661 2662 # Check values from converted model. 
2663 interpreter = Interpreter(model_content=tflite_model) 2664 interpreter.allocate_tensors() 2665 2666 input_details = interpreter.get_input_details() 2667 self.assertLen(input_details, 1) 2668 self.assertEqual('Placeholder', input_details[0]['name']) 2669 self.assertEqual(np.float32, input_details[0]['dtype']) 2670 self.assertAllEqual([3, 3], input_details[0]['shape']) 2671 2672 output_details = interpreter.get_output_details() 2673 self.assertLen(output_details, 1) 2674 self.assertEqual('output', output_details[0]['name']) 2675 self.assertEqual(np.float32, output_details[0]['dtype']) 2676 self.assertAllEqual([3, 3], output_details[0]['shape']) 2677 2678 @parameterized.named_parameters( 2679 ('EnableMlirConverter', True), # enable mlir 2680 ('DisableMlirConverter', False)) # disable mlir 2681 def testInputNodeIsNotFolded(self, enable_mlir_converter): 2682 ops.disable_eager_execution() 2683 # Constant folding handles the tf.broadcast_to operation which was not 2684 # supported by the TFLite at the time this test was added. 2685 with ops.Graph().as_default(): 2686 in_tensor = array_ops.placeholder(shape=[3], dtype=dtypes.float32) 2687 y_const = constant_op.constant([1., 2., 3.]) 2688 y_add = y_const + y_const 2689 out_tensor = in_tensor * y_add 2690 sess = session.Session() 2691 2692 # Convert model. 2693 converter = lite.TFLiteConverter.from_session(sess, [in_tensor, y_const], 2694 [out_tensor]) 2695 converter.experimental_new_converter = enable_mlir_converter 2696 tflite_model = converter.convert() 2697 2698 # Check values from converted model. 
2699 interpreter = Interpreter(model_content=tflite_model) 2700 interpreter.allocate_tensors() 2701 2702 input_details = interpreter.get_input_details() 2703 self.assertLen(input_details, 2) 2704 self.assertEqual('Placeholder', input_details[0]['name']) 2705 self.assertEqual('Const', input_details[1]['name']) 2706 2707 def testGrapplerConstFolding(self): 2708 # Constant folding converts the following add operation to tf.broadcast_to 2709 # operation which was not supported by the TFLite at the time this test was 2710 # added. 2711 @def_function.function 2712 def plus_placeholder(x, placeholder): 2713 return x + placeholder 2714 2715 with ops.Graph().as_default(): 2716 in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32) 2717 out_tensor = plus_placeholder( 2718 array_ops.zeros([2, 2, 2]), 2719 array_ops.reshape(in_tensor, shape=[2, 2])) 2720 sess = session.Session() 2721 2722 # Convert model. 2723 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 2724 [out_tensor]) 2725 tflite_model = converter.convert() 2726 2727 # Check values from converted model. 2728 interpreter = Interpreter(model_content=tflite_model) 2729 interpreter.allocate_tensors() 2730 2731 input_details = interpreter.get_input_details() 2732 self.assertLen(input_details, 1) 2733 self.assertEqual('Placeholder', input_details[0]['name']) 2734 2735 2736class ImportOpsUtilTest(LiteTest): 2737 2738 def testGetPotentiallySupportedOps(self): 2739 self.assertIsNotNone(lite.get_potentially_supported_ops()) 2740 2741 2742class DefaultConverterAttrsTest(LiteTest): 2743 2744 def testAttrs(self): 2745 with ops.Graph().as_default(): 2746 in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32) 2747 out_tensor = in_tensor + in_tensor 2748 sess = session.Session() 2749 2750 # Convert model. 2751 converter = lite.TFLiteConverter.from_session(sess, [in_tensor], 2752 [out_tensor]) 2753 2754 # Assert output format. 
2755 self.assertEqual(converter.output_format, lite_constants.TFLITE) 2756 2757 # Assert the default inference type is float. 2758 self.assertEqual(converter.inference_type, dtypes.float32) 2759 2760 # Assert the default inference type overrides are None. 2761 self.assertIsNone(converter.inference_input_type) 2762 self.assertIsNone(converter.inference_output_type) 2763 2764 # Assert the default quantization options are not set. 2765 self.assertEqual(converter.quantized_input_stats, {}) 2766 self.assertIsNone(converter.default_ranges_stats) 2767 self.assertFalse(converter.reorder_across_fake_quant) 2768 self.assertFalse(converter.change_concat_input_ranges) 2769 2770 # Assert dropping control dependency is enabled by default. 2771 self.assertIsNotNone(converter.drop_control_dependency) 2772 2773 # Assert dumping extra information is disabled by default. 2774 self.assertIsNone(converter.dump_graphviz_dir) 2775 self.assertFalse(converter.dump_graphviz_video) 2776 self.assertIsNone(converter.conversion_summary_dir) 2777 2778 2779class ControlFlowV1OpsTest(LiteTest): 2780 2781 def testConverterErrorOnControlFlowV1Ops(self): 2782 graph_def_file = resource_loader.get_path_to_datafile( 2783 'testdata/control_flow_v1.pbtxt') 2784 input_arrays = ['a', 'b', 'c', 'd'] 2785 output_arrays = ['Merge'] 2786 2787 converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file, 2788 input_arrays, 2789 output_arrays) 2790 with self.assertRaises(ConverterError) as error: 2791 converter.convert() 2792 self.assertIn( 2793 'Failed to functionalize Control Flow V1 ops. Consider using Control ' 2794 'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/' 2795 'tf/compat/v1/enable_control_flow_v2.', str(error.exception)) 2796 2797 2798if __name__ == '__main__': 2799 test.main() 2800