1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tflite_convert.py.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import os 22 23import numpy as np 24from tensorflow import keras 25 26from tensorflow.core.framework import graph_pb2 27from tensorflow.lite.python import test_util as tflite_test_util 28from tensorflow.lite.python import tflite_convert 29from tensorflow.lite.python.convert import register_custom_opdefs 30from tensorflow.python import tf2 31from tensorflow.python.client import session 32from tensorflow.python.eager import def_function 33from tensorflow.python.framework import constant_op 34from tensorflow.python.framework import dtypes 35from tensorflow.python.framework import ops 36from tensorflow.python.framework import test_util 37from tensorflow.python.framework.importer import import_graph_def 38from tensorflow.python.ops import array_ops 39from tensorflow.python.ops import random_ops 40from tensorflow.python.platform import gfile 41from tensorflow.python.platform import resource_loader 42from tensorflow.python.platform import test 43from tensorflow.python.saved_model import saved_model 44from tensorflow.python.saved_model.save import save 45from tensorflow.python.training.tracking import tracking 46from tensorflow.python.training.training_util import write_graph 47 48 49class TestModels(test_util.TensorFlowTestCase): 50 51 def _getFilepath(self, filename): 52 return os.path.join(self.get_temp_dir(), filename) 53 54 def _run(self, 55 flags_str, 56 should_succeed, 57 expected_ops_in_converted_model=None, 58 expected_output_shapes=None): 59 output_file = os.path.join(self.get_temp_dir(), 'model.tflite') 60 tflite_bin = resource_loader.get_path_to_datafile('tflite_convert') 61 cmdline = '{0} --output_file={1} {2}'.format(tflite_bin, output_file, 62 flags_str) 63 64 exitcode = os.system(cmdline) 65 if exitcode == 0: 66 with gfile.Open(output_file, 'rb') as model_file: 67 content = model_file.read() 68 self.assertEqual(content is not None, should_succeed) 69 if expected_ops_in_converted_model: 70 op_set = tflite_test_util.get_ops_list(content) 71 for opname in expected_ops_in_converted_model: 72 self.assertIn(opname, op_set) 73 if expected_output_shapes: 74 output_shapes = tflite_test_util.get_output_shapes(content) 75 self.assertEqual(output_shapes, expected_output_shapes) 76 os.remove(output_file) 77 else: 78 self.assertFalse(should_succeed) 79 80 def _getKerasModelFile(self): 81 x = np.array([[1.], [2.]]) 82 y = np.array([[2.], [4.]]) 83 84 model = keras.models.Sequential([ 85 keras.layers.Dropout(0.2, input_shape=(1,)), 86 keras.layers.Dense(1), 87 ]) 88 model.compile(optimizer='sgd', loss='mean_squared_error') 89 model.fit(x, y, epochs=1) 90 91 keras_file = self._getFilepath('model.h5') 92 keras.models.save_model(model, keras_file) 93 return keras_file 94 95 def _getKerasFunctionalModelFile(self): 96 """Returns a functional Keras model with output shapes [[1, 1], [1, 2]].""" 97 input_tensor = keras.layers.Input(shape=(1,)) 98 output1 = keras.layers.Dense(1, name='b')(input_tensor) 99 output2 = keras.layers.Dense(2, name='a')(input_tensor) 100 model = keras.models.Model(inputs=input_tensor, outputs=[output1, output2]) 101 102 keras_file = self._getFilepath('functional_model.h5') 103 keras.models.save_model(model, keras_file) 104 return keras_file 105 106 107class TfLiteConvertV1Test(TestModels): 108 109 def _run(self, 110 flags_str, 111 should_succeed, 112 expected_ops_in_converted_model=None): 113 if tf2.enabled(): 114 flags_str += ' --enable_v1_converter' 115 super(TfLiteConvertV1Test, self)._run(flags_str, should_succeed, 116 expected_ops_in_converted_model) 117 118 def testFrozenGraphDef(self): 119 with ops.Graph().as_default(): 120 in_tensor = array_ops.placeholder( 121 shape=[1, 16, 16, 3], dtype=dtypes.float32) 122 _ = in_tensor + in_tensor 123 sess = session.Session() 124 125 # Write graph to file. 126 graph_def_file = self._getFilepath('model.pb') 127 write_graph(sess.graph_def, '', graph_def_file, False) 128 sess.close() 129 130 flags_str = ('--graph_def_file={0} --input_arrays={1} ' 131 '--output_arrays={2}'.format(graph_def_file, 'Placeholder', 132 'add')) 133 self._run(flags_str, should_succeed=True) 134 os.remove(graph_def_file) 135 136 # Run `tflite_convert` explicitly with the legacy converter. 137 # Before the new converter is enabled by default, this flag has no real 138 # effects. 139 def testFrozenGraphDefWithLegacyConverter(self): 140 with ops.Graph().as_default(): 141 in_tensor = array_ops.placeholder( 142 shape=[1, 16, 16, 3], dtype=dtypes.float32) 143 _ = in_tensor + in_tensor 144 sess = session.Session() 145 146 # Write graph to file. 147 graph_def_file = self._getFilepath('model.pb') 148 write_graph(sess.graph_def, '', graph_def_file, False) 149 sess.close() 150 151 flags_str = ( 152 '--graph_def_file={0} --input_arrays={1} ' 153 '--output_arrays={2} --experimental_new_converter=false'.format( 154 graph_def_file, 'Placeholder', 'add')) 155 self._run(flags_str, should_succeed=True) 156 os.remove(graph_def_file) 157 158 def testFrozenGraphDefNonPlaceholder(self): 159 with ops.Graph().as_default(): 160 in_tensor = random_ops.random_normal(shape=[1, 16, 16, 3], name='random') 161 _ = in_tensor + in_tensor 162 sess = session.Session() 163 164 # Write graph to file. 165 graph_def_file = self._getFilepath('model.pb') 166 write_graph(sess.graph_def, '', graph_def_file, False) 167 sess.close() 168 169 flags_str = ('--graph_def_file={0} --input_arrays={1} ' 170 '--output_arrays={2}'.format(graph_def_file, 'random', 'add')) 171 self._run(flags_str, should_succeed=True) 172 os.remove(graph_def_file) 173 174 def testQATFrozenGraphDefInt8(self): 175 with ops.Graph().as_default(): 176 in_tensor_1 = array_ops.placeholder( 177 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') 178 in_tensor_2 = array_ops.placeholder( 179 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') 180 _ = array_ops.fake_quant_with_min_max_args( 181 in_tensor_1 + in_tensor_2, min=0., max=1., name='output', 182 num_bits=16) # INT8 inference type works for 16 bits fake quant. 183 sess = session.Session() 184 185 # Write graph to file. 186 graph_def_file = self._getFilepath('model.pb') 187 write_graph(sess.graph_def, '', graph_def_file, False) 188 sess.close() 189 190 flags_str = ('--inference_type=INT8 --std_dev_values=128,128 ' 191 '--mean_values=128,128 ' 192 '--graph_def_file={0} --input_arrays={1},{2} ' 193 '--output_arrays={3}'.format(graph_def_file, 'inputA', 194 'inputB', 'output')) 195 self._run(flags_str, should_succeed=True) 196 os.remove(graph_def_file) 197 198 def testQATFrozenGraphDefUInt8(self): 199 with ops.Graph().as_default(): 200 in_tensor_1 = array_ops.placeholder( 201 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA') 202 in_tensor_2 = array_ops.placeholder( 203 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') 204 _ = array_ops.fake_quant_with_min_max_args( 205 in_tensor_1 + in_tensor_2, min=0., max=1., name='output') 206 sess = session.Session() 207 208 # Write graph to file. 209 graph_def_file = self._getFilepath('model.pb') 210 write_graph(sess.graph_def, '', graph_def_file, False) 211 sess.close() 212 213 # Define converter flags 214 flags_str = ('--std_dev_values=128,128 --mean_values=128,128 ' 215 '--graph_def_file={0} --input_arrays={1} ' 216 '--output_arrays={2}'.format(graph_def_file, 'inputA,inputB', 217 'output')) 218 219 # Set inference_type UINT8 and (default) inference_input_type UINT8 220 flags_str_1 = flags_str + ' --inference_type=UINT8' 221 self._run(flags_str_1, should_succeed=True) 222 223 # Set inference_type UINT8 and inference_input_type FLOAT 224 flags_str_2 = flags_str_1 + ' --inference_input_type=FLOAT' 225 self._run(flags_str_2, should_succeed=True) 226 227 os.remove(graph_def_file) 228 229 def testSavedModel(self): 230 saved_model_dir = self._getFilepath('model') 231 with ops.Graph().as_default(): 232 with session.Session() as sess: 233 in_tensor = array_ops.placeholder( 234 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB') 235 out_tensor = in_tensor + in_tensor 236 inputs = {'x': in_tensor} 237 outputs = {'z': out_tensor} 238 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 239 240 flags_str = '--saved_model_dir={}'.format(saved_model_dir) 241 self._run(flags_str, should_succeed=True) 242 243 def _createSavedModelWithCustomOp(self, opname='CustomAdd'): 244 custom_opdefs_str = ( 245 'name: \'' + opname + '\' input_arg: {name: \'Input1\' type: DT_FLOAT} ' 246 'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: ' 247 '\'Output\' type: DT_FLOAT}') 248 249 # Create a graph that has one add op. 250 new_graph = graph_pb2.GraphDef() 251 with ops.Graph().as_default(): 252 with session.Session() as sess: 253 in_tensor = array_ops.placeholder( 254 shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input') 255 out_tensor = in_tensor + in_tensor 256 inputs = {'x': in_tensor} 257 outputs = {'z': out_tensor} 258 259 new_graph.CopyFrom(sess.graph_def) 260 261 # Rename Add op name to opname. 262 for node in new_graph.node: 263 if node.op.startswith('Add'): 264 node.op = opname 265 del node.attr['T'] 266 267 # Register custom op defs to import modified graph def. 268 register_custom_opdefs([custom_opdefs_str]) 269 270 # Store saved model. 271 saved_model_dir = self._getFilepath('model') 272 with ops.Graph().as_default(): 273 with session.Session() as sess: 274 import_graph_def(new_graph, name='') 275 saved_model.simple_save(sess, saved_model_dir, inputs, outputs) 276 return (saved_model_dir, custom_opdefs_str) 277 278 def testEnsureCustomOpdefsFlag(self): 279 saved_model_dir, _ = self._createSavedModelWithCustomOp() 280 281 # Ensure --custom_opdefs. 282 flags_str = ('--saved_model_dir={0} --allow_custom_ops ' 283 '--experimental_new_converter'.format(saved_model_dir)) 284 self._run(flags_str, should_succeed=False) 285 286 def testSavedModelWithCustomOpdefsFlag(self): 287 saved_model_dir, custom_opdefs_str = self._createSavedModelWithCustomOp() 288 289 # Valid conversion. 290 flags_str = ( 291 '--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops ' 292 '--experimental_new_converter'.format(saved_model_dir, 293 custom_opdefs_str)) 294 self._run( 295 flags_str, 296 should_succeed=True, 297 expected_ops_in_converted_model=['CustomAdd']) 298 299 def testSavedModelWithFlex(self): 300 saved_model_dir, custom_opdefs_str = self._createSavedModelWithCustomOp( 301 opname='CustomAdd2') 302 303 # Valid conversion. OpDef already registered. 304 flags_str = ('--saved_model_dir={0} --allow_custom_ops ' 305 '--custom_opdefs="{1}" ' 306 '--experimental_new_converter ' 307 '--experimental_select_user_tf_ops=CustomAdd2 ' 308 '--target_ops=TFLITE_BUILTINS,SELECT_TF_OPS'.format( 309 saved_model_dir, custom_opdefs_str)) 310 self._run( 311 flags_str, 312 should_succeed=True, 313 expected_ops_in_converted_model=['FlexCustomAdd2']) 314 315 def testSavedModelWithInvalidCustomOpdefsFlag(self): 316 saved_model_dir, _ = self._createSavedModelWithCustomOp() 317 318 invalid_custom_opdefs_str = ( 319 'name: \'CustomAdd\' input_arg: {name: \'Input1\' type: DT_FLOAT} ' 320 'output_arg: {name: \'Output\' type: DT_FLOAT}') 321 322 # Valid conversion. 323 flags_str = ( 324 '--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops ' 325 '--experimental_new_converter'.format(saved_model_dir, 326 invalid_custom_opdefs_str)) 327 self._run(flags_str, should_succeed=False) 328 329 def testKerasFile(self): 330 keras_file = self._getKerasModelFile() 331 332 flags_str = '--keras_model_file={}'.format(keras_file) 333 self._run(flags_str, should_succeed=True) 334 os.remove(keras_file) 335 336 def testKerasFileMLIR(self): 337 keras_file = self._getKerasModelFile() 338 339 flags_str = ( 340 '--keras_model_file={} --experimental_new_converter'.format(keras_file)) 341 self._run(flags_str, should_succeed=True) 342 os.remove(keras_file) 343 344 def testConversionSummary(self): 345 keras_file = self._getKerasModelFile() 346 log_dir = self.get_temp_dir() 347 348 flags_str = ('--keras_model_file={} --experimental_new_converter ' 349 '--conversion_summary_dir={}'.format(keras_file, log_dir)) 350 self._run(flags_str, should_succeed=True) 351 os.remove(keras_file) 352 353 num_items_conversion_summary = len(os.listdir(log_dir)) 354 self.assertTrue(num_items_conversion_summary) 355 356 def testConversionSummaryWithOldConverter(self): 357 keras_file = self._getKerasModelFile() 358 log_dir = self.get_temp_dir() 359 360 flags_str = ('--keras_model_file={} --experimental_new_converter=false ' 361 '--conversion_summary_dir={}'.format(keras_file, log_dir)) 362 self._run(flags_str, should_succeed=True) 363 os.remove(keras_file) 364 365 num_items_conversion_summary = len(os.listdir(log_dir)) 366 self.assertEqual(num_items_conversion_summary, 0) 367 368 def _initObjectDetectionArgs(self): 369 # Initializes the arguments required for the object detection model. 370 # Looks for the model file which is saved in a different location internally 371 # and externally. 372 filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb') 373 if not os.path.exists(filename): 374 filename = os.path.join( 375 resource_loader.get_root_dir_with_all_resources(), 376 '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb') 377 if not os.path.exists(filename): 378 raise IOError("File '{0}' does not exist.".format(filename)) 379 380 self._graph_def_file = filename 381 self._input_arrays = 'normalized_input_image_tensor' 382 self._output_arrays = ( 383 'TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,' 384 'TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3') 385 self._input_shapes = '1,300,300,3' 386 387 def testObjectDetection(self): 388 """Tests object detection model through TOCO.""" 389 self._initObjectDetectionArgs() 390 flags_str = ('--graph_def_file={0} --input_arrays={1} ' 391 '--output_arrays={2} --input_shapes={3} ' 392 '--allow_custom_ops'.format(self._graph_def_file, 393 self._input_arrays, 394 self._output_arrays, 395 self._input_shapes)) 396 self._run(flags_str, should_succeed=True) 397 398 def testObjectDetectionMLIR(self): 399 """Tests object detection model through MLIR converter.""" 400 self._initObjectDetectionArgs() 401 custom_opdefs_str = ( 402 'name: \'TFLite_Detection_PostProcess\' ' 403 'input_arg: { name: \'raw_outputs/box_encodings\' type: DT_FLOAT } ' 404 'input_arg: { name: \'raw_outputs/class_predictions\' type: DT_FLOAT } ' 405 'input_arg: { name: \'anchors\' type: DT_FLOAT } ' 406 'output_arg: { name: \'TFLite_Detection_PostProcess\' type: DT_FLOAT } ' 407 'output_arg: { name: \'TFLite_Detection_PostProcess:1\' ' 408 'type: DT_FLOAT } ' 409 'output_arg: { name: \'TFLite_Detection_PostProcess:2\' ' 410 'type: DT_FLOAT } ' 411 'output_arg: { name: \'TFLite_Detection_PostProcess:3\' ' 412 'type: DT_FLOAT } ' 413 'attr : { name: \'h_scale\' type: \'float\'} ' 414 'attr : { name: \'max_classes_per_detection\' type: \'int\'} ' 415 'attr : { name: \'max_detections\' type: \'int\'} ' 416 'attr : { name: \'nms_iou_threshold\' type: \'float\'} ' 417 'attr : { name: \'nms_score_threshold\' type: \'float\'} ' 418 'attr : { name: \'num_classes\' type: \'int\'} ' 419 'attr : { name: \'w_scale\' type: \'float\'} ' 420 'attr : { name: \'x_scale\' type: \'float\'} ' 421 'attr : { name: \'y_scale\' type: \'float\'}') 422 423 flags_str = ('--graph_def_file={0} --input_arrays={1} ' 424 '--output_arrays={2} --input_shapes={3} ' 425 '--custom_opdefs="{4}"'.format(self._graph_def_file, 426 self._input_arrays, 427 self._output_arrays, 428 self._input_shapes, 429 custom_opdefs_str)) 430 431 # Ensure --allow_custom_ops. 432 flags_str_final = ('{} --allow_custom_ops').format(flags_str) 433 self._run(flags_str_final, should_succeed=False) 434 435 # Ensure --experimental_new_converter. 436 flags_str_final = ('{} --experimental_new_converter').format(flags_str) 437 self._run(flags_str_final, should_succeed=False) 438 439 # Valid conversion. 440 flags_str_final = ('{} --allow_custom_ops ' 441 '--experimental_new_converter').format(flags_str) 442 self._run( 443 flags_str_final, 444 should_succeed=True, 445 expected_ops_in_converted_model=['TFLite_Detection_PostProcess']) 446 447 def testObjectDetectionMLIRWithFlex(self): 448 """Tests object detection model through MLIR converter.""" 449 self._initObjectDetectionArgs() 450 451 flags_str = ('--graph_def_file={0} --input_arrays={1} ' 452 '--output_arrays={2} --input_shapes={3}'.format( 453 self._graph_def_file, self._input_arrays, 454 self._output_arrays, self._input_shapes)) 455 456 # Valid conversion. 457 flags_str_final = ( 458 '{} --allow_custom_ops ' 459 '--experimental_new_converter ' 460 '--experimental_select_user_tf_ops=TFLite_Detection_PostProcess ' 461 '--target_ops=TFLITE_BUILTINS,SELECT_TF_OPS').format(flags_str) 462 self._run( 463 flags_str_final, 464 should_succeed=True, 465 expected_ops_in_converted_model=['FlexTFLite_Detection_PostProcess']) 466 467 468class TfLiteConvertV2Test(TestModels): 469 470 @test_util.run_v2_only 471 def testSavedModel(self): 472 input_data = constant_op.constant(1., shape=[1]) 473 root = tracking.AutoTrackable() 474 root.f = def_function.function(lambda x: 2. * x) 475 to_save = root.f.get_concrete_function(input_data) 476 477 saved_model_dir = self._getFilepath('model') 478 save(root, saved_model_dir, to_save) 479 480 flags_str = '--saved_model_dir={}'.format(saved_model_dir) 481 self._run(flags_str, should_succeed=True) 482 483 @test_util.run_v2_only 484 def testKerasFile(self): 485 keras_file = self._getKerasModelFile() 486 487 flags_str = '--keras_model_file={}'.format(keras_file) 488 self._run(flags_str, should_succeed=True) 489 os.remove(keras_file) 490 491 @test_util.run_v2_only 492 def testKerasFileMLIR(self): 493 keras_file = self._getKerasModelFile() 494 495 flags_str = ( 496 '--keras_model_file={} --experimental_new_converter'.format(keras_file)) 497 self._run(flags_str, should_succeed=True) 498 os.remove(keras_file) 499 500 @test_util.run_v2_only 501 def testFunctionalKerasModel(self): 502 keras_file = self._getKerasFunctionalModelFile() 503 504 flags_str = '--keras_model_file={}'.format(keras_file) 505 self._run(flags_str, should_succeed=True, 506 expected_output_shapes=[[1, 1], [1, 2]]) 507 os.remove(keras_file) 508 509 @test_util.run_v2_only 510 def testFunctionalKerasModelMLIR(self): 511 keras_file = self._getKerasFunctionalModelFile() 512 513 flags_str = ( 514 '--keras_model_file={} --experimental_new_converter'.format(keras_file)) 515 self._run(flags_str, should_succeed=True, 516 expected_output_shapes=[[1, 1], [1, 2]]) 517 os.remove(keras_file) 518 519 def testMissingRequired(self): 520 self._run('--invalid_args', should_succeed=False) 521 522 def testMutuallyExclusive(self): 523 self._run( 524 '--keras_model_file=model.h5 --saved_model_dir=/tmp/', 525 should_succeed=False) 526 527 528class ArgParserTest(test_util.TensorFlowTestCase): 529 530 def test_without_experimental_new_converter(self): 531 args = [ 532 '--saved_model_dir=/tmp/saved_model/', 533 '--output_file=/tmp/output.tflite', 534 ] 535 536 # Note that when the flag parses to None, the converter uses the default 537 # value, which is True. 538 539 # V1 parser. 540 parser = tflite_convert._get_parser(use_v2_converter=False) 541 parsed_args = parser.parse_args(args) 542 self.assertIsNone(parsed_args.experimental_new_converter) 543 self.assertFalse(parsed_args.experimental_new_quantizer) 544 545 # V2 parser. 546 parser = tflite_convert._get_parser(use_v2_converter=True) 547 parsed_args = parser.parse_args(args) 548 self.assertIsNone(parsed_args.experimental_new_converter) 549 self.assertFalse(parsed_args.experimental_new_quantizer) 550 551 def test_experimental_new_converter(self): 552 args = [ 553 '--saved_model_dir=/tmp/saved_model/', 554 '--output_file=/tmp/output.tflite', 555 '--experimental_new_converter', 556 ] 557 558 # V1 parser. 559 parser = tflite_convert._get_parser(use_v2_converter=False) 560 parsed_args = parser.parse_args(args) 561 self.assertTrue(parsed_args.experimental_new_converter) 562 563 # V2 parser. 564 parser = tflite_convert._get_parser(use_v2_converter=True) 565 parsed_args = parser.parse_args(args) 566 self.assertTrue(parsed_args.experimental_new_converter) 567 568 def test_experimental_new_converter_true(self): 569 args = [ 570 '--saved_model_dir=/tmp/saved_model/', 571 '--output_file=/tmp/output.tflite', 572 '--experimental_new_converter=true', 573 ] 574 575 # V1 parser. 576 parser = tflite_convert._get_parser(False) 577 parsed_args = parser.parse_args(args) 578 self.assertTrue(parsed_args.experimental_new_converter) 579 580 # V2 parser. 581 parser = tflite_convert._get_parser(True) 582 parsed_args = parser.parse_args(args) 583 self.assertTrue(parsed_args.experimental_new_converter) 584 585 def test_experimental_new_converter_false(self): 586 args = [ 587 '--saved_model_dir=/tmp/saved_model/', 588 '--output_file=/tmp/output.tflite', 589 '--experimental_new_converter=false', 590 ] 591 592 # V1 parser. 593 parser = tflite_convert._get_parser(use_v2_converter=False) 594 parsed_args = parser.parse_args(args) 595 self.assertFalse(parsed_args.experimental_new_converter) 596 597 # V2 parser. 598 parser = tflite_convert._get_parser(use_v2_converter=True) 599 parsed_args = parser.parse_args(args) 600 self.assertFalse(parsed_args.experimental_new_converter) 601 602 def test_experimental_new_quantizer(self): 603 args = [ 604 '--saved_model_dir=/tmp/saved_model/', 605 '--output_file=/tmp/output.tflite', 606 '--experimental_new_quantizer', 607 ] 608 609 # V1 parser. 610 parser = tflite_convert._get_parser(use_v2_converter=False) 611 parsed_args = parser.parse_args(args) 612 self.assertTrue(parsed_args.experimental_new_quantizer) 613 614 # V2 parser. 615 parser = tflite_convert._get_parser(use_v2_converter=True) 616 parsed_args = parser.parse_args(args) 617 self.assertTrue(parsed_args.experimental_new_quantizer) 618 619if __name__ == '__main__': 620 test.main() 621