# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tflite_convert.py."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np
from tensorflow import keras

from tensorflow.core.framework import graph_pb2
from tensorflow.lite.python import test_util as tflite_test_util
from tensorflow.lite.python import tflite_convert
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.python import tf2
from tensorflow.python.client import session
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.importer import import_graph_def
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.saved_model.save import save
from tensorflow.python.training.tracking import tracking
from tensorflow.python.training.training_util import write_graph


class TestModels(test_util.TensorFlowTestCase):
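  """Base class with helpers to build models and invoke tflite_convert."""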

  def _getFilepath(self, filename):
    return os.path.join(self.get_temp_dir(), filename)

  def _run(self,
           flags_str,
           should_succeed,
           expected_ops_in_converted_model=None,
           expected_output_shapes=None):
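    """Runs the tflite_convert binary with `flags_str` and checks the result.

    When conversion succeeds, optionally verifies that the converted model
    contains every op in `expected_ops_in_converted_model` and that its
    output shapes match `expected_output_shapes`.
    """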
    output_file = os.path.join(self.get_temp_dir(), 'model.tflite')
    tflite_bin = resource_loader.get_path_to_datafile('tflite_convert')
    cmdline = '{0} --output_file={1} {2}'.format(tflite_bin, output_file,
                                                 flags_str)

    exitcode = os.system(cmdline)
    if exitcode == 0:
      with gfile.Open(output_file, 'rb') as model_file:
        content = model_file.read()
      self.assertEqual(content is not None, should_succeed)
      if expected_ops_in_converted_model:
        op_set = tflite_test_util.get_ops_list(content)
        for opname in expected_ops_in_converted_model:
          self.assertIn(opname, op_set)
      if expected_output_shapes:
        output_shapes = tflite_test_util.get_output_shapes(content)
        self.assertEqual(output_shapes, expected_output_shapes)
      os.remove(output_file)
    else:
      self.assertFalse(should_succeed)

  def _getKerasModelFile(self):
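    """Trains a tiny Keras Sequential model and saves it to an HDF5 file."""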
    x = np.array([[1.], [2.]])
    y = np.array([[2.], [4.]])

    model = keras.models.Sequential([
        keras.layers.Dropout(0.2, input_shape=(1,)),
        keras.layers.Dense(1),
    ])
    model.compile(optimizer='sgd', loss='mean_squared_error')
    model.fit(x, y, epochs=1)

    keras_file = self._getFilepath('model.h5')
    keras.models.save_model(model, keras_file)
    return keras_file

  def _getKerasFunctionalModelFile(self):
    """Returns a functional Keras model with output shapes [[1, 1], [1, 2]]."""
    input_tensor = keras.layers.Input(shape=(1,))
    output1 = keras.layers.Dense(1, name='b')(input_tensor)
    output2 = keras.layers.Dense(2, name='a')(input_tensor)
    model = keras.models.Model(inputs=input_tensor, outputs=[output1, output2])

    keras_file = self._getFilepath('functional_model.h5')
    keras.models.save_model(model, keras_file)
    return keras_file


class TfLiteConvertV1Test(TestModels):
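  """Tests the TF1 conversion path of the tflite_convert CLI."""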

  def _run(self,
           flags_str,
           should_succeed,
           expected_ops_in_converted_model=None):
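    """Appends --enable_v1_converter under TF2, then runs the base `_run`."""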
    if tf2.enabled():
      flags_str += ' --enable_v1_converter'
    super(TfLiteConvertV1Test, self)._run(flags_str, should_succeed,
                                          expected_ops_in_converted_model)

  def testFrozenGraphDef(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = self._getFilepath('model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    flags_str = ('--graph_def_file={0} --input_arrays={1} '
                 '--output_arrays={2}'.format(graph_def_file, 'Placeholder',
                                              'add'))
    self._run(flags_str, should_succeed=True)
    os.remove(graph_def_file)

  # Run `tflite_convert` explicitly with the legacy converter.
  # Until the new converter is enabled by default, this flag has no real
  # effect.
  def testFrozenGraphDefWithLegacyConverter(self):
    with ops.Graph().as_default():
      in_tensor = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32)
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = self._getFilepath('model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    flags_str = (
        '--graph_def_file={0} --input_arrays={1} '
        '--output_arrays={2} --experimental_new_converter=false'.format(
            graph_def_file, 'Placeholder', 'add'))
    self._run(flags_str, should_succeed=True)
    os.remove(graph_def_file)

  def testFrozenGraphDefNonPlaceholder(self):
    with ops.Graph().as_default():
      in_tensor = random_ops.random_normal(shape=[1, 16, 16, 3], name='random')
      _ = in_tensor + in_tensor
      sess = session.Session()

    # Write graph to file.
    graph_def_file = self._getFilepath('model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    flags_str = ('--graph_def_file={0} --input_arrays={1} '
                 '--output_arrays={2}'.format(graph_def_file, 'random', 'add'))
    self._run(flags_str, should_succeed=True)
    os.remove(graph_def_file)

  def testQATFrozenGraphDefInt8(self):
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
      _ = array_ops.fake_quant_with_min_max_args(
          in_tensor_1 + in_tensor_2, min=0., max=1., name='output',
          num_bits=16)  # INT8 inference type works with 16-bit fake quant.
      sess = session.Session()

    # Write graph to file.
    graph_def_file = self._getFilepath('model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    flags_str = ('--inference_type=INT8 --std_dev_values=128,128 '
                 '--mean_values=128,128 '
                 '--graph_def_file={0} --input_arrays={1},{2} '
                 '--output_arrays={3}'.format(graph_def_file, 'inputA',
                                              'inputB', 'output'))
    self._run(flags_str, should_succeed=True)
    os.remove(graph_def_file)

  def testQATFrozenGraphDefUInt8(self):
    with ops.Graph().as_default():
      in_tensor_1 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
      in_tensor_2 = array_ops.placeholder(
          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
      _ = array_ops.fake_quant_with_min_max_args(
          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
      sess = session.Session()

    # Write graph to file.
    graph_def_file = self._getFilepath('model.pb')
    write_graph(sess.graph_def, '', graph_def_file, False)
    sess.close()

    # Define converter flags
    flags_str = ('--std_dev_values=128,128 --mean_values=128,128 '
                 '--graph_def_file={0} --input_arrays={1} '
                 '--output_arrays={2}'.format(graph_def_file, 'inputA,inputB',
                                              'output'))

    # Set inference_type UINT8 and (default) inference_input_type UINT8
    flags_str_1 = flags_str + ' --inference_type=UINT8'
    self._run(flags_str_1, should_succeed=True)

    # Set inference_type UINT8 and inference_input_type FLOAT
    flags_str_2 = flags_str_1 + ' --inference_input_type=FLOAT'
    self._run(flags_str_2, should_succeed=True)

    os.remove(graph_def_file)

  def testSavedModel(self):
    saved_model_dir = self._getFilepath('model')
    with ops.Graph().as_default():
      with session.Session() as sess:
        in_tensor = array_ops.placeholder(
            shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
        out_tensor = in_tensor + in_tensor
        inputs = {'x': in_tensor}
        outputs = {'z': out_tensor}
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)

    flags_str = '--saved_model_dir={}'.format(saved_model_dir)
    self._run(flags_str, should_succeed=True)

  def _createSavedModelWithCustomOp(self, opname='CustomAdd'):
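    """Builds a SavedModel whose single Add op is renamed to `opname`.

    Returns:
      A tuple of (saved_model_dir, custom_opdefs_str).
    """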
    custom_opdefs_str = (
        'name: \'' + opname + '\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
        'input_arg: {name: \'Input2\' type: DT_FLOAT} output_arg: {name: '
        '\'Output\' type: DT_FLOAT}')

    # Create a graph that has one add op.
    new_graph = graph_pb2.GraphDef()
    with ops.Graph().as_default():
      with session.Session() as sess:
        in_tensor = array_ops.placeholder(
            shape=[1, 16, 16, 3], dtype=dtypes.float32, name='input')
        out_tensor = in_tensor + in_tensor
        inputs = {'x': in_tensor}
        outputs = {'z': out_tensor}

        new_graph.CopyFrom(sess.graph_def)

    # Rename Add op name to opname.
    for node in new_graph.node:
      if node.op.startswith('Add'):
        node.op = opname
        del node.attr['T']

    # Register custom op defs to import modified graph def.
    register_custom_opdefs([custom_opdefs_str])

    # Store saved model.
    saved_model_dir = self._getFilepath('model')
    with ops.Graph().as_default():
      with session.Session() as sess:
        import_graph_def(new_graph, name='')
        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return (saved_model_dir, custom_opdefs_str)

  def testEnsureCustomOpdefsFlag(self):
    saved_model_dir, _ = self._createSavedModelWithCustomOp()

    # Ensure --custom_opdefs.
    flags_str = ('--saved_model_dir={0} --allow_custom_ops '
                 '--experimental_new_converter'.format(saved_model_dir))
    self._run(flags_str, should_succeed=False)

  def testSavedModelWithCustomOpdefsFlag(self):
    saved_model_dir, custom_opdefs_str = self._createSavedModelWithCustomOp()

    # Valid conversion.
    flags_str = (
        '--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops '
        '--experimental_new_converter'.format(saved_model_dir,
                                              custom_opdefs_str))
    self._run(
        flags_str,
        should_succeed=True,
        expected_ops_in_converted_model=['CustomAdd'])

  def testSavedModelWithFlex(self):
    saved_model_dir, custom_opdefs_str = self._createSavedModelWithCustomOp(
        opname='CustomAdd2')

    # Valid conversion. OpDef already registered.
    flags_str = ('--saved_model_dir={0} --allow_custom_ops '
                 '--custom_opdefs="{1}" '
                 '--experimental_new_converter '
                 '--experimental_select_user_tf_ops=CustomAdd2 '
                 '--target_ops=TFLITE_BUILTINS,SELECT_TF_OPS'.format(
                     saved_model_dir, custom_opdefs_str))
    self._run(
        flags_str,
        should_succeed=True,
        expected_ops_in_converted_model=['FlexCustomAdd2'])

  def testSavedModelWithInvalidCustomOpdefsFlag(self):
    saved_model_dir, _ = self._createSavedModelWithCustomOp()

    invalid_custom_opdefs_str = (
        'name: \'CustomAdd\' input_arg: {name: \'Input1\' type: DT_FLOAT} '
        'output_arg: {name: \'Output\' type: DT_FLOAT}')

    # Conversion should fail because the custom opdef is invalid.
    flags_str = (
        '--saved_model_dir={0} --custom_opdefs="{1}" --allow_custom_ops '
        '--experimental_new_converter'.format(saved_model_dir,
                                              invalid_custom_opdefs_str))
    self._run(flags_str, should_succeed=False)

  def testKerasFile(self):
    keras_file = self._getKerasModelFile()

    flags_str = '--keras_model_file={}'.format(keras_file)
    self._run(flags_str, should_succeed=True)
    os.remove(keras_file)

  def testKerasFileMLIR(self):
    keras_file = self._getKerasModelFile()

    flags_str = (
        '--keras_model_file={} --experimental_new_converter'.format(keras_file))
    self._run(flags_str, should_succeed=True)
    os.remove(keras_file)

  def testConversionSummary(self):
    keras_file = self._getKerasModelFile()
    log_dir = self.get_temp_dir()

    flags_str = ('--keras_model_file={} --experimental_new_converter  '
                 '--conversion_summary_dir={}'.format(keras_file, log_dir))
    self._run(flags_str, should_succeed=True)
    os.remove(keras_file)

    num_items_conversion_summary = len(os.listdir(log_dir))
    self.assertTrue(num_items_conversion_summary)

  def testConversionSummaryWithOldConverter(self):
    keras_file = self._getKerasModelFile()
    log_dir = self.get_temp_dir()

    flags_str = ('--keras_model_file={} --experimental_new_converter=false '
                 '--conversion_summary_dir={}'.format(keras_file, log_dir))
    self._run(flags_str, should_succeed=True)
    os.remove(keras_file)

    num_items_conversion_summary = len(os.listdir(log_dir))
    self.assertEqual(num_items_conversion_summary, 0)

  def _initObjectDetectionArgs(self):
    """Initializes the arguments required for the object detection model.

    The model file is stored in different locations for internal and external
    builds, so look for it in both places.
    """
    filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
    if not os.path.exists(filename):
      filename = os.path.join(
          resource_loader.get_root_dir_with_all_resources(),
          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
      if not os.path.exists(filename):
        raise IOError("File '{0}' does not exist.".format(filename))

    self._graph_def_file = filename
    self._input_arrays = 'normalized_input_image_tensor'
    self._output_arrays = (
        'TFLite_Detection_PostProcess,TFLite_Detection_PostProcess:1,'
        'TFLite_Detection_PostProcess:2,TFLite_Detection_PostProcess:3')
    self._input_shapes = '1,300,300,3'

  def testObjectDetection(self):
    """Tests object detection model through TOCO."""
    self._initObjectDetectionArgs()
    flags_str = ('--graph_def_file={0} --input_arrays={1} '
                 '--output_arrays={2} --input_shapes={3} '
                 '--allow_custom_ops'.format(self._graph_def_file,
                                             self._input_arrays,
                                             self._output_arrays,
                                             self._input_shapes))
    self._run(flags_str, should_succeed=True)

  def testObjectDetectionMLIR(self):
    """Tests object detection model through MLIR converter."""
    self._initObjectDetectionArgs()
    custom_opdefs_str = (
        'name: \'TFLite_Detection_PostProcess\' '
        'input_arg: { name: \'raw_outputs/box_encodings\' type: DT_FLOAT } '
        'input_arg: { name: \'raw_outputs/class_predictions\' type: DT_FLOAT } '
        'input_arg: { name: \'anchors\' type: DT_FLOAT } '
        'output_arg: { name: \'TFLite_Detection_PostProcess\' type: DT_FLOAT } '
        'output_arg: { name: \'TFLite_Detection_PostProcess:1\' '
        'type: DT_FLOAT } '
        'output_arg: { name: \'TFLite_Detection_PostProcess:2\' '
        'type: DT_FLOAT } '
        'output_arg: { name: \'TFLite_Detection_PostProcess:3\' '
        'type: DT_FLOAT } '
        'attr : { name: \'h_scale\' type: \'float\'} '
        'attr : { name: \'max_classes_per_detection\' type: \'int\'} '
        'attr : { name: \'max_detections\' type: \'int\'} '
        'attr : { name: \'nms_iou_threshold\' type: \'float\'} '
        'attr : { name: \'nms_score_threshold\' type: \'float\'} '
        'attr : { name: \'num_classes\' type: \'int\'} '
        'attr : { name: \'w_scale\' type: \'float\'} '
        'attr : { name: \'x_scale\' type: \'float\'} '
        'attr : { name: \'y_scale\' type: \'float\'}')

    flags_str = ('--graph_def_file={0} --input_arrays={1} '
                 '--output_arrays={2} --input_shapes={3} '
                 '--custom_opdefs="{4}"'.format(self._graph_def_file,
                                                self._input_arrays,
                                                self._output_arrays,
                                                self._input_shapes,
                                                custom_opdefs_str))

    # Ensure --allow_custom_ops.
    flags_str_final = ('{} --allow_custom_ops').format(flags_str)
    self._run(flags_str_final, should_succeed=False)

    # Ensure --experimental_new_converter.
    flags_str_final = ('{} --experimental_new_converter').format(flags_str)
    self._run(flags_str_final, should_succeed=False)

    # Valid conversion.
    flags_str_final = ('{} --allow_custom_ops '
                       '--experimental_new_converter').format(flags_str)
    self._run(
        flags_str_final,
        should_succeed=True,
        expected_ops_in_converted_model=['TFLite_Detection_PostProcess'])

  def testObjectDetectionMLIRWithFlex(self):
    """Tests object detection model through MLIR converter with Flex ops."""
    self._initObjectDetectionArgs()

    flags_str = ('--graph_def_file={0} --input_arrays={1} '
                 '--output_arrays={2} --input_shapes={3}'.format(
                     self._graph_def_file, self._input_arrays,
                     self._output_arrays, self._input_shapes))

    # Valid conversion.
    flags_str_final = (
        '{} --allow_custom_ops '
        '--experimental_new_converter '
        '--experimental_select_user_tf_ops=TFLite_Detection_PostProcess '
        '--target_ops=TFLITE_BUILTINS,SELECT_TF_OPS').format(flags_str)
    self._run(
        flags_str_final,
        should_succeed=True,
        expected_ops_in_converted_model=['FlexTFLite_Detection_PostProcess'])


class TfLiteConvertV2Test(TestModels):
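  """Tests the TF2 conversion path of the tflite_convert CLI."""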

  @test_util.run_v2_only
  def testSavedModel(self):
    input_data = constant_op.constant(1., shape=[1])
    root = tracking.AutoTrackable()
    root.f = def_function.function(lambda x: 2. * x)
    to_save = root.f.get_concrete_function(input_data)

    saved_model_dir = self._getFilepath('model')
    save(root, saved_model_dir, to_save)

    flags_str = '--saved_model_dir={}'.format(saved_model_dir)
    self._run(flags_str, should_succeed=True)

  @test_util.run_v2_only
  def testKerasFile(self):
    keras_file = self._getKerasModelFile()

    flags_str = '--keras_model_file={}'.format(keras_file)
    self._run(flags_str, should_succeed=True)
    os.remove(keras_file)

  @test_util.run_v2_only
  def testKerasFileMLIR(self):
    keras_file = self._getKerasModelFile()

    flags_str = (
        '--keras_model_file={} --experimental_new_converter'.format(keras_file))
    self._run(flags_str, should_succeed=True)
    os.remove(keras_file)

  @test_util.run_v2_only
  def testFunctionalKerasModel(self):
    keras_file = self._getKerasFunctionalModelFile()

    flags_str = '--keras_model_file={}'.format(keras_file)
    self._run(flags_str, should_succeed=True,
              expected_output_shapes=[[1, 1], [1, 2]])
    os.remove(keras_file)

  @test_util.run_v2_only
  def testFunctionalKerasModelMLIR(self):
    keras_file = self._getKerasFunctionalModelFile()

    flags_str = (
        '--keras_model_file={} --experimental_new_converter'.format(keras_file))
    self._run(flags_str, should_succeed=True,
              expected_output_shapes=[[1, 1], [1, 2]])
    os.remove(keras_file)

  def testMissingRequired(self):
    self._run('--invalid_args', should_succeed=False)

  def testMutuallyExclusive(self):
    self._run(
        '--keras_model_file=model.h5 --saved_model_dir=/tmp/',
        should_succeed=False)


class ArgParserTest(test_util.TensorFlowTestCase):
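  """Tests the argument parser returned by tflite_convert._get_parser."""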

  def test_without_experimental_new_converter(self):
    args = [
        '--saved_model_dir=/tmp/saved_model/',
        '--output_file=/tmp/output.tflite',
    ]

    # Note that when the flag parses to None, the converter uses the default
    # value, which is True.

    # V1 parser.
    parser = tflite_convert._get_parser(use_v2_converter=False)
    parsed_args = parser.parse_args(args)
    self.assertIsNone(parsed_args.experimental_new_converter)
    self.assertFalse(parsed_args.experimental_new_quantizer)

    # V2 parser.
    parser = tflite_convert._get_parser(use_v2_converter=True)
    parsed_args = parser.parse_args(args)
    self.assertIsNone(parsed_args.experimental_new_converter)
    self.assertFalse(parsed_args.experimental_new_quantizer)

  def test_experimental_new_converter(self):
    args = [
        '--saved_model_dir=/tmp/saved_model/',
        '--output_file=/tmp/output.tflite',
        '--experimental_new_converter',
    ]

    # V1 parser.
    parser = tflite_convert._get_parser(use_v2_converter=False)
    parsed_args = parser.parse_args(args)
    self.assertTrue(parsed_args.experimental_new_converter)

    # V2 parser.
    parser = tflite_convert._get_parser(use_v2_converter=True)
    parsed_args = parser.parse_args(args)
    self.assertTrue(parsed_args.experimental_new_converter)

  def test_experimental_new_converter_true(self):
    args = [
        '--saved_model_dir=/tmp/saved_model/',
        '--output_file=/tmp/output.tflite',
        '--experimental_new_converter=true',
    ]

    # V1 parser.
    parser = tflite_convert._get_parser(use_v2_converter=False)
    parsed_args = parser.parse_args(args)
    self.assertTrue(parsed_args.experimental_new_converter)

    # V2 parser.
    parser = tflite_convert._get_parser(use_v2_converter=True)
    parsed_args = parser.parse_args(args)
    self.assertTrue(parsed_args.experimental_new_converter)

  def test_experimental_new_converter_false(self):
    args = [
        '--saved_model_dir=/tmp/saved_model/',
        '--output_file=/tmp/output.tflite',
        '--experimental_new_converter=false',
    ]

    # V1 parser.
    parser = tflite_convert._get_parser(use_v2_converter=False)
    parsed_args = parser.parse_args(args)
    self.assertFalse(parsed_args.experimental_new_converter)

    # V2 parser.
    parser = tflite_convert._get_parser(use_v2_converter=True)
    parsed_args = parser.parse_args(args)
    self.assertFalse(parsed_args.experimental_new_converter)

  def test_experimental_new_quantizer(self):
    args = [
        '--saved_model_dir=/tmp/saved_model/',
        '--output_file=/tmp/output.tflite',
        '--experimental_new_quantizer',
    ]

    # V1 parser.
    parser = tflite_convert._get_parser(use_v2_converter=False)
    parsed_args = parser.parse_args(args)
    self.assertTrue(parsed_args.experimental_new_quantizer)

    # V2 parser.
    parser = tflite_convert._get_parser(use_v2_converter=True)
    parsed_args = parser.parse_args(args)
    self.assertTrue(parsed_args.experimental_new_quantizer)


if __name__ == '__main__':
  test.main()