1# Lint as: python2, python3
2# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for lite.py."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import io
23import logging
24import os
25import tempfile
26
27from absl.testing import parameterized
28import numpy as np
29import six
30from six.moves import range
31from tensorflow import keras
32
33from tensorflow.lite.python import lite
34from tensorflow.lite.python import lite_constants
35from tensorflow.lite.python.convert import ConverterError
36from tensorflow.lite.python.convert import mlir_quantize
37from tensorflow.lite.python.interpreter import Interpreter
38from tensorflow.python.client import session
39from tensorflow.python.eager import context
40from tensorflow.python.eager import def_function
41from tensorflow.python.framework import constant_op
42from tensorflow.python.framework import convert_to_constants
43from tensorflow.python.framework import dtypes
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import test_util
46from tensorflow.python.ops import array_ops
47from tensorflow.python.ops import gen_array_ops
48from tensorflow.python.ops import math_ops
49from tensorflow.python.ops import nn_ops
50from tensorflow.python.ops import random_ops
51from tensorflow.python.ops import variable_scope
52from tensorflow.python.ops import variables
53from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
54from tensorflow.python.platform import gfile
55from tensorflow.python.platform import resource_loader
56from tensorflow.python.platform import test
57from tensorflow.python.saved_model import saved_model
58from tensorflow.python.training.training_util import write_graph
59
60
61class LiteTest(test_util.TensorFlowTestCase):
62  """Base class of all the tests in this module."""
63
64
65class TestModels(LiteTest):
66
67  def assertValidDebugInfo(self, debug_info):
68    """Verify the DebugInfo is valid."""
69    file_names = set()
70    for file_path in debug_info.files:
71      file_names.add(os.path.basename(file_path))
72    # To make the test independent on how the nodes are created, we only assert
73    # the name of this test file.
74    self.assertIn('lite_test.py', file_names)
75    self.assertNotIn('lite_v2_test.py', file_names)
76
77
78class FromConstructor(TestModels):
79
80  # Tests invalid constructors using a dummy value for the GraphDef.
81  def testInvalidConstructor(self):
82    message = ('If input_tensors and output_tensors are None, both '
83               'input_arrays_with_shape and output_arrays must be defined.')
84
85    # `output_arrays` is not defined.
86    with self.assertRaises(ValueError) as error:
87      lite.TFLiteConverter(
88          None, None, [], input_arrays_with_shape=[('input', [3, 9])])
89    self.assertEqual(message, str(error.exception))
90
91    # `input_arrays_with_shape` is not defined.
92    with self.assertRaises(ValueError) as error:
93      lite.TFLiteConverter(None, [], None, output_arrays=['output'])
94    self.assertEqual(message, str(error.exception))
95
96  # Tests valid constructors using a dummy value for the GraphDef.
97  def testValidConstructor(self):
98    converter = lite.TFLiteConverter(
99        None,
100        None,
101        None,
102        input_arrays_with_shape=[('input', [3, 9])],
103        output_arrays=['output'])
104    self.assertFalse(converter._has_valid_tensors())
105    self.assertEqual(converter.get_input_arrays(), ['input'])
106
107    with self.assertRaises(ValueError) as error:
108      converter._set_batch_size(1)
109    self.assertEqual(
110        'The batch size cannot be set for this model. Please use '
111        'input_shapes parameter.', str(error.exception))
112
113    converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor'])
114    self.assertTrue(converter._has_valid_tensors())
115
116  def testRedundantArgumentsWarning(self):
117    """Test if the warning message when there are redundant arguments."""
118    with ops.Graph().as_default():
119      in_tensor = array_ops.placeholder(
120          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
121      out_tensor = math_ops.add(in_tensor, in_tensor, name='add')
122      sess = session.Session()
123
124    frozen_graph_def = (
125        convert_to_constants.convert_variables_to_constants_from_session_graph(
126            sess, sess.graph_def, ['add']))
127
128    # Convert model and ensure model is not None.
129    log = io.BytesIO() if six.PY2 else io.StringIO()
130    handler = logging.StreamHandler(log)
131    logging.root.addHandler(handler)
132    converter = lite.TFLiteConverter(frozen_graph_def, [in_tensor],
133                                     [out_tensor],
134                                     [('in_tensor', [2, 16, 16, 3])], ['add'])
135
136    input_warning_message = 'input_arrays_with_shape will be ignored'
137    output_warning_message = 'output_arrays will be ignored'
138
139    # Convert model and ensure model is not None.
140    tflite_model = converter.convert()
141    self.assertIsNotNone(tflite_model)
142    self.assertIn(input_warning_message, log.getvalue())
143    self.assertIn(output_warning_message, log.getvalue())
144    logging.root.removeHandler(handler)
145
146  def testShapeOverriding(self):
147    """Test a shape overriding case via the constructor."""
148    with ops.Graph().as_default():
149      in_tensor = array_ops.placeholder(
150          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
151      math_ops.add(in_tensor, in_tensor, name='add')
152      sess = session.Session()
153
154    frozen_graph_def = (
155        convert_to_constants.convert_variables_to_constants_from_session_graph(
156            sess, sess.graph_def, ['add']))
157
158    # Convert model and ensure model is not None.
159    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
160                                     [('in_tensor', [2, 16, 16, 3])], ['add'])
161    tflite_model = converter.convert()
162    self.assertIsNotNone(tflite_model)
163
164    # Check values from converted model.
165    interpreter = Interpreter(model_content=tflite_model)
166    interpreter.allocate_tensors()
167
168    input_details = interpreter.get_input_details()
169    self.assertLen(input_details, 1)
170    self.assertEqual('in_tensor', input_details[0]['name'])
171    self.assertEqual(np.float32, input_details[0]['dtype'])
172    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
173    self.assertEqual((0., 0.), input_details[0]['quantization'])
174
175    output_details = interpreter.get_output_details()
176    self.assertLen(output_details, 1)
177    self.assertEqual('add', output_details[0]['name'])
178    self.assertEqual(np.float32, output_details[0]['dtype'])
179    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
180    self.assertEqual((0., 0.), output_details[0]['quantization'])
181
182  def testPartialShapeOverriding(self):
183    """Test a partial shape overriding case via the constructor."""
184    with ops.Graph().as_default():
185      in_tensor_a = array_ops.placeholder(
186          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_a')
187      in_tensor_b = array_ops.placeholder(
188          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor_b')
189      math_ops.add(in_tensor_a, in_tensor_b, name='add')
190      sess = session.Session()
191
192    frozen_graph_def = (
193        convert_to_constants.convert_variables_to_constants_from_session_graph(
194            sess, sess.graph_def, ['add']))
195
196    # Convert model and ensure model is not None.
197    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
198                                     [('in_tensor_a', [2, 16, 16, 3])], ['add'])
199    # There is an unhandled Placeholder op.
200    with self.assertRaises(ConverterError):
201      converter.convert()
202
203  def testInvalidShapeOverriding(self):
204    """Test an invalid shape overriding case via the constructor."""
205    with ops.Graph().as_default():
206      in_tensor = array_ops.placeholder(
207          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='in_tensor')
208      math_ops.add(in_tensor, in_tensor, name='add')
209      sess = session.Session()
210
211    frozen_graph_def = (
212        convert_to_constants.convert_variables_to_constants_from_session_graph(
213            sess, sess.graph_def, ['add']))
214
215    # Convert model and ensure model is not None.
216    converter = lite.TFLiteConverter(frozen_graph_def, None, None,
217                                     [('wrong_tensor', [2, 16, 16, 3])],
218                                     ['add'])
219    with self.assertRaises(ConverterError):
220      converter.convert()
221
222
223class FromSessionTest(TestModels, parameterized.TestCase):
224
225  def testFloatModel(self):
226    with ops.Graph().as_default():
227      in_tensor = array_ops.placeholder(
228          shape=[1, 16, 16, 3], dtype=dtypes.float32)
229      out_tensor = in_tensor + in_tensor
230      sess = session.Session()
231
232    # Convert model and ensure model is not None.
233    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
234                                                  [out_tensor])
235    tflite_model = converter.convert()
236    self.assertIsNotNone(tflite_model)
237
238    # Check values from converted model.
239    interpreter = Interpreter(model_content=tflite_model)
240    interpreter.allocate_tensors()
241
242    input_details = interpreter.get_input_details()
243    self.assertLen(input_details, 1)
244    self.assertEqual('Placeholder', input_details[0]['name'])
245    self.assertEqual(np.float32, input_details[0]['dtype'])
246    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
247    self.assertEqual((0., 0.), input_details[0]['quantization'])
248
249    output_details = interpreter.get_output_details()
250    self.assertLen(output_details, 1)
251    self.assertEqual('add', output_details[0]['name'])
252    self.assertEqual(np.float32, output_details[0]['dtype'])
253    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
254    self.assertEqual((0., 0.), output_details[0]['quantization'])
255
256  def testFloatModelQuantizedInput(self):
257    with ops.Graph().as_default():
258      in_tensor = array_ops.placeholder(
259          shape=[1, 16, 16, 3], dtype=dtypes.float32)
260      out_tensor = in_tensor + in_tensor
261      sess = session.Session()
262
263    # Convert model and ensure model is not None.
264    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
265                                                  [out_tensor])
266    converter.inference_input_type = dtypes.uint8
267    converter.inference_type = dtypes.float32
268    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
269    tflite_model = converter.convert()
270    self.assertIsNotNone(tflite_model)
271
272    # Check values from converted model.
273    interpreter = Interpreter(model_content=tflite_model)
274    interpreter.allocate_tensors()
275
276    input_details = interpreter.get_input_details()
277    self.assertLen(input_details, 1)
278    self.assertEqual('Placeholder', input_details[0]['name'])
279    self.assertEqual(np.uint8, input_details[0]['dtype'])
280    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
281    self.assertEqual((1., 0.), input_details[0]['quantization'])
282
283    output_details = interpreter.get_output_details()
284    self.assertLen(output_details, 1)
285    self.assertEqual('add', output_details[0]['name'])
286    self.assertEqual(np.float32, output_details[0]['dtype'])
287    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
288    self.assertEqual((0., 0.), output_details[0]['quantization'])  # float
289
290  def testForgottenCallToAllocateTensors(self):
291    with ops.Graph().as_default():
292      in_tensor = array_ops.placeholder(
293          shape=[1, 16, 16, 3], dtype=dtypes.float32)
294      out_tensor = in_tensor + in_tensor
295      sess = session.Session()
296    # Convert model and ensure model is not None.
297    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
298                                                  [out_tensor])
299    tflite_model = converter.convert()
300    self.assertIsNotNone(tflite_model)
301
302    # Check values from converted model.
303    interpreter = Interpreter(model_content=tflite_model)
304    input_index = interpreter.get_input_details()[0]['index']
305    dummy_tensor = np.ones(shape=[1, 16, 16, 3], dtype=np.float32)
306    with self.assertRaises(ValueError):
307      interpreter.set_tensor(input_index, dummy_tensor)
308
309  @parameterized.named_parameters(
310      ('_INT8InputOutput', False, False, dtypes.int8),
311      ('_UINT8InputOutput', False, False, dtypes.uint8),
312      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
313      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
314      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
315      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
316      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
317      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
318  def testIntegerQuantizationWithUnsupportedOps(self,
319                                                is_int_only,
320                                                is_int16_quantize,
321                                                inference_input_output_type,
322                                                enable_mlir_quantizer=False):
323    with ops.Graph().as_default():
324      in_tensor_a = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
325      in_tensor_b = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
326      # ceil kernel does not support int8 nor int16 types neither.
327      left = math_ops.ceil(in_tensor_a)
328      out_tensor_b = math_ops.tanh(in_tensor_b)
329      add = math_ops.add(left, out_tensor_b)
330      # ceil kernel does not support int8 nor int16 types neither.
331      out_tensor_a = math_ops.ceil(add)
332      sess = session.Session()
333
334    def calibration_gen():
335      for _ in range(5):
336        yield [
337            np.random.uniform(-1, 1, size=(3)).astype(np.float32),
338            np.random.uniform(-1, 1, size=(3)).astype(np.float32)
339        ]
340
341    quantized_converter = lite.TFLiteConverter.from_session(
342        sess, [in_tensor_a, in_tensor_b], [out_tensor_a, out_tensor_b])
343    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
344    quantized_converter.representative_dataset = calibration_gen
345    if is_int_only:
346      if is_int16_quantize:
347        quantized_converter.target_spec.supported_ops = [
348            lite.OpsSet.\
349            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
350            lite.OpsSet.TFLITE_BUILTINS
351        ]
352      else:
353        quantized_converter.target_spec.supported_ops = [
354            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS
355        ]
356    else:
357      if is_int16_quantize:
358        quantized_converter.target_spec.supported_ops = [
359            lite.OpsSet.\
360            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
361            lite.OpsSet.TFLITE_BUILTINS
362        ]
363      else:
364        quantized_converter.target_spec.supported_ops = [
365            lite.OpsSet.TFLITE_BUILTINS
366        ]
367
368    quantized_converter.inference_input_type = inference_input_output_type
369    quantized_converter.inference_output_type = inference_input_output_type
370    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
371    quantized_tflite_model = quantized_converter.convert()
372    self.assertIsNotNone(quantized_tflite_model)
373
374    expected_dtype = inference_input_output_type.as_numpy_dtype
375    # Allow float32 for fallback on non-quantizable op.
376    expected_ceil_dtype = (
377        expected_dtype if enable_mlir_quantizer else dtypes.float32)
378
379    interpreter = Interpreter(model_content=quantized_tflite_model)
380    interpreter.allocate_tensors()
381    input_details = interpreter.get_input_details()
382    self.assertLen(input_details, 2)
383    self.assertEqual(input_details[0]['dtype'], expected_ceil_dtype)
384    self.assertEqual(input_details[1]['dtype'], expected_dtype)
385    output_details = interpreter.get_output_details()
386    self.assertLen(output_details, 2)
387    self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype)
388    self.assertEqual(output_details[1]['dtype'], expected_dtype)
389
390  @parameterized.named_parameters(
391      ('EnableMlirConverter', True),  # enable mlir
392      ('DisableMlirConverter', False))  # disable mlir
393  def testString(self, enable_mlir_converter):
394    with ops.Graph().as_default():
395      in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
396      out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
397      sess = session.Session()
398
399    # Convert model and ensure model is not None.
400    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
401                                                  [out_tensor])
402    converter.experimental_new_converter = enable_mlir_converter
403    tflite_model = converter.convert()
404    self.assertIsNotNone(tflite_model)
405
406    # Check values from converted model.
407    interpreter = Interpreter(model_content=tflite_model)
408    interpreter.allocate_tensors()
409
410    input_details = interpreter.get_input_details()
411    self.assertLen(input_details, 1)
412    self.assertEqual('Placeholder', input_details[0]['name'])
413    self.assertEqual(np.string_, input_details[0]['dtype'])
414    self.assertAllEqual([4], input_details[0]['shape'])
415
416    output_details = interpreter.get_output_details()
417    self.assertLen(output_details, 1)
418    self.assertEqual('Reshape', output_details[0]['name'])
419    self.assertEqual(np.string_, output_details[0]['dtype'])
420    self.assertAllEqual([2, 2], output_details[0]['shape'])
421    # TODO(b/122659643): Test setting/getting string data via the python
422    # interpreter API after support has been added.
423
424  def testIntermediateInputArray(self):
425    """Convert a model from an intermediate input array."""
426    with ops.Graph().as_default():
427      in_tensor_init = array_ops.placeholder(
428          shape=[1, 16, 16, 3], dtype=dtypes.float32)
429      in_tensor_final = in_tensor_init + in_tensor_init
430      out_tensor = in_tensor_final + in_tensor_final
431      sess = session.Session()
432
433    # Convert model and ensure model is not None.
434    converter = lite.TFLiteConverter.from_session(sess, [in_tensor_final],
435                                                  [out_tensor])
436    tflite_model = converter.convert()
437    self.assertIsNotNone(tflite_model)
438
439    # Check values from converted model.
440    interpreter = Interpreter(model_content=tflite_model)
441    interpreter.allocate_tensors()
442
443    input_details = interpreter.get_input_details()
444    self.assertLen(input_details, 1)
445    self.assertEqual('add', input_details[0]['name'])
446    self.assertEqual(np.float32, input_details[0]['dtype'])
447    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
448    self.assertEqual((0., 0.), input_details[0]['quantization'])
449
450    output_details = interpreter.get_output_details()
451    self.assertLen(output_details, 1)
452    self.assertEqual('add_1', output_details[0]['name'])
453    self.assertEqual(np.float32, output_details[0]['dtype'])
454    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
455    self.assertEqual((0., 0.), output_details[0]['quantization'])
456
457  def testSizeNoneInvalid(self):
458    with ops.Graph().as_default():
459      in_tensor = array_ops.placeholder(dtype=dtypes.float32)
460      out_tensor = in_tensor + in_tensor
461      sess = session.Session()
462
463    # Test None as shape when dynamic shapes are disabled. Run with TOCO in
464    # order to invoke shape checking code.
465    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
466                                                  [out_tensor])
467    converter.experimental_new_converter = False
468    with self.assertRaises(ValueError) as error:
469      converter.convert()
470    self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
471                     str(error.exception))
472
473  @parameterized.named_parameters(
474      ('EnableMlirConverter', True),  # enable mlir
475      ('DisableMlirConverter', False))  # disable mlir
476  def testScalarValid(self, enable_mlir_converter):
477    # Construct a graph using a scalar (empty shape) input.
478    with ops.Graph().as_default():
479      in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
480      out_tensor = in_tensor + in_tensor
481      sess = session.Session()
482
483    # Test conversion with the scalar input shape.
484    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
485                                                  [out_tensor])
486    converter.experimental_new_converter = enable_mlir_converter
487    tflite_model = converter.convert()
488    self.assertIsNotNone(tflite_model)
489
490    # Check values from converted model.
491    interpreter = Interpreter(model_content=tflite_model)
492    interpreter.allocate_tensors()
493
494    input_details = interpreter.get_input_details()
495    self.assertLen(input_details, 1)
496    self.assertEqual('Placeholder', input_details[0]['name'])
497    self.assertEqual(np.float32, input_details[0]['dtype'])
498    self.assertEmpty(input_details[0]['shape'])
499
500    output_details = interpreter.get_output_details()
501    self.assertLen(output_details, 1)
502    self.assertEqual('add', output_details[0]['name'])
503    self.assertEqual(np.float32, output_details[0]['dtype'])
504    self.assertEmpty(input_details[0]['shape'])
505
506    # Validate inference using the scalar inputs/outputs.
507    test_input = np.array(4.0, dtype=np.float32)
508    expected_output = np.array(8.0, dtype=np.float32)
509    interpreter.set_tensor(input_details[0]['index'], test_input)
510    interpreter.invoke()
511
512    output_data = interpreter.get_tensor(output_details[0]['index'])
513    self.assertEqual(expected_output, output_data)
514
515  def testSizeInvalid(self):
516    with ops.Graph().as_default():
517      in_tensor = array_ops.placeholder(
518          shape=[1, None, 16, 3], dtype=dtypes.float32)
519      out_tensor = in_tensor + in_tensor
520      sess = session.Session()
521
522    # Test invalid shape. None after 1st dimension. Run with TOCO in order to
523    # invoke shape checking code.
524    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
525                                                  [out_tensor])
526    converter.experimental_new_converter = False
527    with self.assertRaises(ValueError) as error:
528      converter.convert()
529    self.assertEqual(
530        'None is only supported in the 1st dimension. Tensor '
531        '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
532        str(error.exception))
533
534  def testSizeNone(self):
535    with ops.Graph().as_default():
536      in_tensor = array_ops.placeholder(
537          shape=[1, None, 16, 3], dtype=dtypes.float32)
538      out_tensor = in_tensor + in_tensor
539      sess = session.Session()
540
541    # Test None after 1st dimension.
542    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
543                                                  [out_tensor])
544    tflite_model = converter.convert()
545
546    # Check values from converted model.
547    interpreter = Interpreter(model_content=tflite_model)
548    input_details = interpreter.get_input_details()
549    self.assertLen(input_details, 1)
550    self.assertEqual('Placeholder', input_details[0]['name'])
551    self.assertEqual(np.float32, input_details[0]['dtype'])
552    self.assertAllEqual([1, 1, 16, 3], input_details[0]['shape'])
553    self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature'])
554    self.assertEqual((0., 0.), input_details[0]['quantization'])
555
556    # Resize tensor with strict checking.
557    with self.assertRaises(RuntimeError) as error:
558      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
559    self.assertIn(
560        'ResizeInputTensorStrict only allows mutating unknown dimensions '
561        'identified by -1.', str(error.exception))
562
563    # Resize tensor and invoke.
564    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
565    interpreter.allocate_tensors()
566    interpreter.invoke()
567
568    input_details = interpreter.get_input_details()
569    self.assertLen(input_details, 1)
570    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
571    self.assertAllEqual([1, -1, 16, 3], input_details[0]['shape_signature'])
572
573    output_details = interpreter.get_output_details()
574    self.assertAllEqual([1, -1, 16, 3], output_details[0]['shape_signature'])
575
576  def testResizeTensorInputStrict(self):
577    # Ensures that resize_tensor_input(strict=True) works as expected.
578    with ops.Graph().as_default():
579      in_tensor = array_ops.placeholder(
580          shape=[1, 16, 16, 3], dtype=dtypes.float32)
581      out_tensor = in_tensor + in_tensor
582      sess = session.Session()
583
584    # Convert model and ensure model is not None.
585    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
586                                                  [out_tensor])
587    tflite_model = converter.convert()
588    self.assertIsNotNone(tflite_model)
589
590    # Check values from converted model.
591    interpreter = Interpreter(model_content=tflite_model)
592
593    # Resize incorrect value.
594    with self.assertRaises(RuntimeError) as error:
595      interpreter.resize_tensor_input(0, [3, 16, 16, 3], strict=True)
596    self.assertIn(
597        'ResizeInputTensorStrict only allows mutating unknown dimensions '
598        'identified by -1.', str(error.exception))
599
600    # Resize correct value.
601    interpreter.resize_tensor_input(0, [1, 16, 16, 3], strict=True)
602    interpreter.allocate_tensors()
603
604  def testBatchSizeValid(self):
605    with ops.Graph().as_default():
606      in_tensor = array_ops.placeholder(
607          shape=[None, 16, 16, 3], dtype=dtypes.float32)
608      out_tensor = in_tensor + in_tensor
609      sess = session.Session()
610
611    # Convert model and ensure model is not None.
612    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
613                                                  [out_tensor])
614    tflite_model = converter.convert()
615    self.assertIsNotNone(tflite_model)
616
617    # Check values from converted model.
618    interpreter = Interpreter(model_content=tflite_model)
619    interpreter.allocate_tensors()
620
621    input_details = interpreter.get_input_details()
622    self.assertLen(input_details, 1)
623    self.assertEqual('Placeholder', input_details[0]['name'])
624    self.assertEqual(np.float32, input_details[0]['dtype'])
625    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
626    self.assertEqual((0., 0.), input_details[0]['quantization'])
627
628    output_details = interpreter.get_output_details()
629    self.assertLen(output_details, 1)
630    self.assertEqual('add', output_details[0]['name'])
631    self.assertEqual(np.float32, output_details[0]['dtype'])
632    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
633    self.assertEqual((0., 0.), output_details[0]['quantization'])
634
635  def testBatchSizeNonZero(self):
636    with ops.Graph().as_default():
637      in_tensor_1 = array_ops.placeholder(
638          shape=[None, 4], dtype=dtypes.float32, name='input1')
639      in_tensor_2 = array_ops.placeholder(
640          shape=[4, 10], dtype=dtypes.float32, name='input2')
641      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2)
642      sess = session.Session()
643
644    # Convert model and ensure model is not None.
645    converter = lite.TFLiteConverter.from_session(sess,
646                                                  [in_tensor_1, in_tensor_2],
647                                                  [out_tensor])
648    tflite_model = converter.convert()
649    self.assertIsNotNone(tflite_model)
650
651    # Check values from converted model.
652    interpreter = Interpreter(model_content=tflite_model)
653    interpreter.allocate_tensors()
654
655    input_details = interpreter.get_input_details()
656    self.assertLen(input_details, 2)
657    self.assertEqual('input1', input_details[0]['name'])
658    self.assertAllEqual([1, 4], input_details[0]['shape'])
659    self.assertEqual('input2', input_details[1]['name'])
660    self.assertAllEqual([4, 10], input_details[1]['shape'])
661
662  def testFreezeGraph(self):
663    with ops.Graph().as_default():
664      in_tensor = array_ops.placeholder(
665          shape=[1, 16, 16, 3], dtype=dtypes.float32)
666      var = variable_scope.get_variable(
667          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
668      # Get the second output to ensure freezing properly processes tensor names
669      # like 'X:1'.
670      out_tensor = nn_ops.top_k(in_tensor + var, name='top_k')[1]
671      sess = session.Session()
672      sess.run(_global_variables_initializer())
673
674    # Convert model and ensure model is not None.
675    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
676                                                  [out_tensor])
677    tflite_model = converter.convert()
678    self.assertIsNotNone(tflite_model)
679
680    # Check values from converted model.
681    interpreter = Interpreter(model_content=tflite_model)
682    interpreter.allocate_tensors()
683
684    input_details = interpreter.get_input_details()
685    self.assertLen(input_details, 1)
686    self.assertEqual('Placeholder', input_details[0]['name'])
687    self.assertEqual(np.float32, input_details[0]['dtype'])
688    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
689    self.assertEqual((0., 0.), input_details[0]['quantization'])
690
691    output_details = interpreter.get_output_details()
692    self.assertLen(output_details, 1)
693    self.assertEqual('top_k:1', output_details[0]['name'])
694    self.assertEqual(np.int32, output_details[0]['dtype'])
695    self.assertAllEqual([1, 16, 16, 1], output_details[0]['shape'])
696    self.assertEqual((0., 0.), output_details[0]['quantization'])
697
698  def testGraphviz(self):
699    with ops.Graph().as_default():
700      in_tensor = array_ops.placeholder(
701          shape=[1, 16, 16, 3], dtype=dtypes.float32)
702      out_tensor = in_tensor + in_tensor
703      sess = session.Session()
704
705    # Convert model and ensure model is not None.
706    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
707                                                  [out_tensor])
708    converter.output_format = lite_constants.GRAPHVIZ_DOT
709    graphviz_output = converter.convert()
710    self.assertIsNotNone(graphviz_output)
711
712  @parameterized.named_parameters(
713      ('EnableMlirConverter', True),  # enable mlir
714      ('DisableMlirConverter', False))  # disable mlir
715  def testDumpGraphviz(self, enable_mlir_converter):
716    with ops.Graph().as_default():
717      in_tensor = array_ops.placeholder(
718          shape=[1, 16, 16, 3], dtype=dtypes.float32)
719      out_tensor = in_tensor + in_tensor
720      sess = session.Session()
721
722    # Convert model and ensure model is not None.
723    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
724                                                  [out_tensor])
725    converter.experimental_new_converter = enable_mlir_converter
726    graphviz_dir = self.get_temp_dir()
727    converter.dump_graphviz_dir = graphviz_dir
728    tflite_model = converter.convert()
729    self.assertIsNotNone(tflite_model)
730
731    # Ensure interpreter is able to allocate and check graphviz data.
732    interpreter = Interpreter(model_content=tflite_model)
733    interpreter.allocate_tensors()
734
735    num_items_graphviz = len(os.listdir(graphviz_dir))
736    self.assertIsNotNone(num_items_graphviz)
737    self.assertIsNotNone(
738        os.path.exists(os.path.join(graphviz_dir, 'toco_AT_IMPORT.dot')))
739    self.assertIsNotNone(
740        os.path.exists(
741            os.path.join(graphviz_dir, 'toco_AFTER_TRANSFORMATIONS.dot')))
742
743    # new converter doesn't support `dump_graphviz_video` flag
744    if not enable_mlir_converter:
745      # Convert model and ensure model is not None.
746      converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
747                                                    [out_tensor])
748      converter.experimental_new_converter = enable_mlir_converter
749      graphviz_dir = self.get_temp_dir()
750      converter.dump_graphviz_dir = graphviz_dir
751      converter.dump_graphviz_video = True
752      tflite_model = converter.convert()
753      self.assertIsNotNone(tflite_model)
754
755      # Ensure graphviz folder has more data after using video flag.
756      num_items_graphviz_video = len(os.listdir(graphviz_dir))
757      self.assertGreater(num_items_graphviz_video, num_items_graphviz)
758
759  def testDumpConversionSummary(self):
760    with ops.Graph().as_default():
761      in_tensor = array_ops.placeholder(
762          shape=[1, 16, 16, 3], dtype=dtypes.float32)
763      out_tensor = in_tensor + in_tensor
764      sess = session.Session()
765
766    # Convert model and ensure model is not None.
767    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
768                                                  [out_tensor])
769    log_dir = self.get_temp_dir()
770    converter.conversion_summary_dir = log_dir
771    tflite_model = converter.convert()
772    self.assertIsNotNone(tflite_model)
773
774    self.assertNotEmpty(os.listdir(log_dir))
775
776  def testDumpConversionSummaryWithOldConverter(self):
777    with ops.Graph().as_default():
778      in_tensor = array_ops.placeholder(
779          shape=[1, 16, 16, 3], dtype=dtypes.float32)
780      out_tensor = in_tensor + in_tensor
781      sess = session.Session()
782
783    # Convert model and ensure model is not None.
784    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
785                                                  [out_tensor])
786    converter.experimental_new_converter = False
787    log_dir = self.get_temp_dir()
788    converter.conversion_summary_dir = log_dir
789    tflite_model = converter.convert()
790    self.assertIsNotNone(tflite_model)
791    # Check nothing is generated under the conversion summary path.
792    num_items_conversion_summary = len(os.listdir(log_dir))
793    self.assertEqual(num_items_conversion_summary, 0)
794
795  @parameterized.named_parameters(
796      ('EnableMlirConverter', True),  # enable mlir
797      ('DisableMlirConverter', False))  # disable mlir
798  def testQuantizeDynamicRange(self, enable_mlir_converter):
799    np.random.seed(0)
800    with ops.Graph().as_default():
801      # We need the tensor to have more than 1024 elements for quantize_weights
802      # to kick in. Thus, the [33, 33] shape.
803      in_tensor_1 = array_ops.placeholder(
804          shape=[33, 33], dtype=dtypes.float32, name='inputA')
805      in_tensor_2 = constant_op.constant(
806          np.random.uniform(low=-10., high=10., size=(33, 33)),
807          shape=[33, 33],
808          dtype=dtypes.float32,
809          name='inputB')
810      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
811      sess = session.Session()
812
813    # Convert float model.
814    float_converter = lite.TFLiteConverter.from_session(
815        sess, [in_tensor_1], [out_tensor])
816    float_converter.experimental_new_converter = enable_mlir_converter
817    float_tflite_model = float_converter.convert()
818    self.assertIsNotNone(float_tflite_model)
819
820    # Convert quantized weights model.
821    quantized_converter = lite.TFLiteConverter.from_session(
822        sess, [in_tensor_1], [out_tensor])
823
824    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
825    quantized_converter.experimental_new_converter = enable_mlir_converter
826    quantized_tflite_model = quantized_converter.convert()
827    self.assertIsNotNone(quantized_tflite_model)
828
829    # Ensure that the quantized weights tflite model is smaller.
830    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
831
832  @parameterized.named_parameters(
833      ('EnableMlirConverter', True),  # enable mlir
834      ('DisableMlirConverter', False))  # disable mlir
835  def testQuantizeDynamicRangeDeprecatedPostTrainingQuantizeAttribute(
836      self, enable_mlir_converter):
837    with ops.Graph().as_default():
838      in_tensor_1 = array_ops.placeholder(
839          shape=[33, 33], dtype=dtypes.float32, name='inputA')
840      in_tensor_2 = constant_op.constant(
841          np.random.uniform(low=-10., high=10., size=(33, 33)),
842          shape=[33, 33],
843          dtype=dtypes.float32,
844          name='inputB')
845      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
846      sess = session.Session()
847
848    quantized_converter = lite.TFLiteConverter.from_session(
849        sess, [in_tensor_1], [out_tensor])
850    self.assertFalse(quantized_converter.post_training_quantize)
851    quantized_converter.experimental_new_converter = enable_mlir_converter
852
853    quantized_converter.post_training_quantize = True
854    self.assertTrue(quantized_converter.post_training_quantize)
855    self.assertEqual(quantized_converter.optimizations, [lite.Optimize.DEFAULT])
856
857    quantized_tflite_model = quantized_converter.convert()
858    self.assertIsNotNone(quantized_tflite_model)
859
860  def _getIntegerQuantizeModel(self):
861    np.random.seed(0)
862    inp = array_ops.placeholder(
863        dtype=dtypes.float32, shape=(1, 5, 5, 3), name='input')
864    conv = nn_ops.conv2d(
865        inp,
866        filter=array_ops.ones([3, 3, 3, 16]),
867        strides=[1, 1, 1, 1],
868        padding='SAME')
869    output = nn_ops.relu(conv, name='output')
870
871    def calibration_gen():
872      for _ in range(5):
873        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]
874
875    return (inp, output, calibration_gen)
876
877  @parameterized.named_parameters(
878      ('EnableMlirConverter', True),  # enable mlir
879      ('DisableMlirConverter', False))  # disable mlir
880  def testQuantizeInt8AllowFloat(self, enable_mlir_converter):
881    with ops.Graph().as_default():
882      inp, output, calibration_gen = self._getIntegerQuantizeModel()
883      sess = session.Session()
884
885    # Convert float model.
886    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
887    float_tflite_model = float_converter.convert()
888    self.assertIsNotNone(float_tflite_model)
889
890    # Convert quantized model.
891    quantized_converter = lite.TFLiteConverter.from_session(
892        sess, [inp], [output])
893    quantized_converter.experimental_new_converter = enable_mlir_converter
894    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
895    quantized_converter.representative_dataset = calibration_gen
896    quantized_tflite_model = quantized_converter.convert()
897    self.assertIsNotNone(quantized_tflite_model)
898
899    # The default input and output types should be float.
900    interpreter = Interpreter(model_content=quantized_tflite_model)
901    interpreter.allocate_tensors()
902    input_details = interpreter.get_input_details()
903    self.assertLen(input_details, 1)
904    self.assertEqual(np.float32, input_details[0]['dtype'])
905    output_details = interpreter.get_output_details()
906    self.assertLen(output_details, 1)
907    self.assertEqual(np.float32, output_details[0]['dtype'])
908
909    # Ensure that the quantized weights tflite model is smaller.
910    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
911
912  @parameterized.named_parameters(
913      # Quantize model to Int8: with enable mlir
914      ('UseTfliteBuiltinsIntEnableMLIR',
915       [lite.OpsSet.TFLITE_BUILTINS_INT8], True),
916      # Quantize model to Int8: with disable mlir
917      ('UseTfliteBuiltinsIntDisableMLIR',
918       [lite.OpsSet.TFLITE_BUILTINS_INT8], False),
919      # Quantize model to Int16: with disable mlir
920      ('UseTfliteBuiltinsInt16DisableMLIR',
921       [lite.OpsSet.\
922       EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8],
923       False),
924      ('UseTfliteBuiltinsInt16EnableMLIR',
925       [lite.OpsSet.\
926       EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8],
927       True))
928  def testQuantizeInt8And16x8(self, supported_ops, enable_mlir_converter):
929    with ops.Graph().as_default():
930      inp, output, calibration_gen = self._getIntegerQuantizeModel()
931      sess = session.Session()
932
933    # Convert float model.
934    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
935    float_converter.experimental_new_converter = enable_mlir_converter
936    float_tflite_model = float_converter.convert()
937    self.assertIsNotNone(float_tflite_model)
938
939    # Convert model by specifying target spec (instead of optimizations), since
940    # when targeting an integer only backend, quantization is mandatory.
941    quantized_converter = lite.TFLiteConverter.from_session(
942        sess, [inp], [output])
943    quantized_converter.experimental_new_converter = enable_mlir_converter
944    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
945    quantized_converter.target_spec.supported_ops = supported_ops
946    quantized_converter.representative_dataset = calibration_gen
947    quantized_tflite_model = quantized_converter.convert()
948    self.assertIsNotNone(quantized_tflite_model)
949
950    # The default input and output types should be float.
951    interpreter = Interpreter(model_content=quantized_tflite_model)
952    interpreter.allocate_tensors()
953    input_details = interpreter.get_input_details()
954    self.assertLen(input_details, 1)
955    self.assertEqual(np.float32, input_details[0]['dtype'])
956    output_details = interpreter.get_output_details()
957    self.assertLen(output_details, 1)
958    self.assertEqual(np.float32, output_details[0]['dtype'])
959
960    # Ensure that the quantized weights tflite model is smaller.
961    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
962
963  @parameterized.named_parameters(
964      ('EnableMlirConverter', True),  # enable mlir
965      ('DisableMlirConverter', False))  # disable mlir
966  def testQuantizeInt8InputOutput(self, enable_mlir_converter):
967    with ops.Graph().as_default():
968      inp, output, calibration_gen = self._getIntegerQuantizeModel()
969      sess = session.Session()
970
971    # Convert float model.
972    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
973    float_converter.experimental_new_converter = enable_mlir_converter
974    float_tflite_model = float_converter.convert()
975    self.assertIsNotNone(float_tflite_model)
976
977    # Convert quantized weights model.
978    quantized_converter = lite.TFLiteConverter.from_session(
979        sess, [inp], [output])
980    quantized_converter.experimental_new_converter = enable_mlir_converter
981    quantized_converter.inference_input_type = dtypes.int8
982    quantized_converter.inference_output_type = dtypes.int8
983    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
984    quantized_converter.representative_dataset = calibration_gen
985    quantized_tflite_model = quantized_converter.convert()
986    self.assertIsNotNone(quantized_tflite_model)
987
988    # The input and output types should be int8.
989    interpreter = Interpreter(model_content=quantized_tflite_model)
990    interpreter.allocate_tensors()
991    input_details = interpreter.get_input_details()
992    self.assertLen(input_details, 1)
993    self.assertEqual(np.int8, input_details[0]['dtype'])
994    output_details = interpreter.get_output_details()
995    self.assertLen(output_details, 1)
996    self.assertEqual(np.int8, output_details[0]['dtype'])
997
998    # Ensure that the quantized weights tflite model is smaller.
999    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
1000
1001  @parameterized.named_parameters(
1002      ('EnableMlirConverter', True),  # enable mlir
1003      ('DisableMlirConverter', False))  # disable mlir
1004  def testInvalidQuantizeInt8(self, enable_mlir_converter):
1005    np.random.seed(0)
1006    with ops.Graph().as_default():
1007      # We need the tensor to have more than 1024 elements for quantize_weights
1008      # to kick in. Thus, the [33, 33] shape.
1009      in_tensor_1 = array_ops.placeholder(
1010          shape=[33, 33], dtype=dtypes.float32, name='inputA')
1011      in_tensor_2 = constant_op.constant(
1012          np.random.uniform(low=-10., high=10., size=(33, 33)),
1013          shape=[33, 33],
1014          dtype=dtypes.float32,
1015          name='inputB')
1016      out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
1017      sess = session.Session()
1018
1019    # Attempt to convert to quantized weights model.
1020    quantized_converter = lite.TFLiteConverter.from_session(
1021        sess, [in_tensor_1], [out_tensor])
1022    quantized_converter.experimental_new_converter = enable_mlir_converter
1023    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
1024    # Restricting to int8 type only
1025    quantized_converter.target_spec.supported_types = [dtypes.int8]
1026    # A representative dataset is required for full fixed point quantization.
1027    with self.assertRaises(ValueError) as error:
1028      quantized_converter.convert()
1029    self.assertEqual(
1030        'representative_dataset is required when specifying '
1031        'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception))
1032
1033  @parameterized.named_parameters(
1034      ('EnableMlirConverter', True),  # enable mlir
1035      ('DisableMlirConverter', False))  # disable mlir
1036  def testQuantizeUInt8(self, enable_mlir_converter):
1037    with ops.Graph().as_default():
1038      in_tensor_1 = array_ops.placeholder(
1039          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
1040      in_tensor_2 = array_ops.placeholder(
1041          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
1042      out_tensor = array_ops.fake_quant_with_min_max_args(
1043          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
1044      sess = session.Session()
1045
1046    # Convert model and ensure model is not None.
1047    converter = lite.TFLiteConverter.from_session(sess,
1048                                                  [in_tensor_1, in_tensor_2],
1049                                                  [out_tensor])
1050    converter.inference_type = dtypes.uint8
1051    converter.quantized_input_stats = {
1052        'inputA': (0., 1.),
1053        'inputB': (0., 1.)
1054    }  # mean, std_dev
1055    converter.experimental_new_converter = enable_mlir_converter
1056    tflite_model = converter.convert()
1057    self.assertIsNotNone(tflite_model)
1058
1059    # Check values from converted model.
1060    interpreter = Interpreter(model_content=tflite_model)
1061    interpreter.allocate_tensors()
1062
1063    input_details = interpreter.get_input_details()
1064    self.assertLen(input_details, 2)
1065    self.assertEqual('inputA', input_details[0]['name'])
1066    self.assertEqual(np.uint8, input_details[0]['dtype'])
1067    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1068    self.assertEqual((1., 0.), input_details[0]['quantization'])
1069
1070    self.assertEqual('inputB', input_details[1]['name'])
1071    self.assertEqual(np.uint8, input_details[1]['dtype'])
1072    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
1073    self.assertEqual((1., 0.), input_details[1]['quantization'])
1074
1075    output_details = interpreter.get_output_details()
1076    self.assertLen(output_details, 1)
1077    self.assertEqual(np.uint8, output_details[0]['dtype'])
1078    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1079    self.assertGreater(output_details[0]['quantization'][0], 0)  # scale
1080
1081  def testQuantizeUInt8UsingDefaultRangeStats(self):
1082    with ops.Graph().as_default():
1083      in_tensor = array_ops.placeholder(
1084          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1085      out_tensor = in_tensor + in_tensor
1086      sess = session.Session()
1087
1088    # Convert model and ensure model is not None.
1089    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1090                                                  [out_tensor])
1091    converter.inference_type = dtypes.uint8
1092    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
1093    converter.default_ranges_stats = (0, 6)  # min, max
1094    tflite_model = converter.convert()
1095    self.assertIsNotNone(tflite_model)
1096
1097    # Check values from converted model.
1098    interpreter = Interpreter(model_content=tflite_model)
1099    interpreter.allocate_tensors()
1100
1101    input_details = interpreter.get_input_details()
1102    self.assertLen(input_details, 1)
1103    self.assertEqual('Placeholder', input_details[0]['name'])
1104    self.assertEqual(np.uint8, input_details[0]['dtype'])
1105    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1106    self.assertEqual((1., 0.), input_details[0]['quantization'])
1107
1108    output_details = interpreter.get_output_details()
1109    self.assertLen(output_details, 1)
1110    self.assertEqual('add', output_details[0]['name'])
1111    self.assertEqual(np.uint8, output_details[0]['dtype'])
1112    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1113    self.assertGreater(output_details[0]['quantization'][0], 0)  # scale
1114
1115  @parameterized.named_parameters(
1116      # Quantize to Float16 even if rep data provided.
1117      ('UseRepresentativeData', True, False, True, False, False, False, False),
1118      # Quantize to Float16 if no rep data provided.
1119      ('NoRepresentativeData', False, False, True, False, False, False, False),
1120      # Post training quantization if both rep data and int8 included.
1121      ('UseSampleDataIncludeInt8', True, True, False, False, True, False, False
1122      ),
1123      # Quantize to Float16 even if rep data provided with mlir.
1124      ('UseRepresentativeDataMlir', True, False, True, False, False, True, False
1125      ),
1126      # Quantize to Float16 if no rep data provided with mlir.
1127      ('NoRepresentativeDataMlir', False, False, True, False, False, True, False
1128      ),
1129      # Post training quantization if both rep data and int8 included with mlir.
1130      ('SampleDataIncludeInt8Mlir', True, True, False, False, True, True, False
1131      ),
1132      # Same as above, but using MLIR quantizer
1133      ('SampleDataIncludeInt8MlirQuant', True, True, False, False, True, True,
1134       True))
1135  def testQuantizeFloat16(self, use_rep_data, include_int8,
1136                          is_float16_quantized, is_error,
1137                          is_post_training_quantized, enable_mlir_converter,
1138                          enable_mlir_quantizer):
1139    with ops.Graph().as_default():
1140      inp, output, calibration_gen = self._getIntegerQuantizeModel()
1141      sess = session.Session()
1142
1143    bias_idx = 1 if enable_mlir_converter else 0
1144    bias_name = 'Conv2D' if enable_mlir_converter else 'Conv2D_bias'
1145
1146    # Convert float model.
1147    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
1148    float_converter.experimental_new_converter = enable_mlir_converter
1149    float_tflite_model = float_converter.convert()
1150    self.assertIsNotNone(float_tflite_model)
1151    interpreter = Interpreter(model_content=float_tflite_model)
1152    interpreter.allocate_tensors()
1153    self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'],
1154                     bias_name)
1155    self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
1156                     dtypes.float32)
1157
1158    # MLIR quantizer has different bias index.
1159    if enable_mlir_quantizer:
1160      bias_idx = 2
1161
1162    # Convert model to quantized version
1163    quantized_converter = lite.TFLiteConverter.from_session(
1164        sess, [inp], [output])
1165    quantized_converter.experimental_new_converter = enable_mlir_converter
1166    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
1167    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
1168    quantized_converter.target_spec.supported_types = [dtypes.float16]
1169    if include_int8:
1170      quantized_converter.target_spec.supported_types.append(dtypes.int8)
1171    if use_rep_data:
1172      quantized_converter.representative_dataset = calibration_gen
1173
1174    if is_error:
1175      with self.assertRaises(ValueError) as error:
1176        quantized_converter.convert()
1177      self.assertEqual(
1178          'representative_dataset is required when specifying '
1179          'TFLITE_BUILTINS_INT8 or INT8 supported types.', str(error.exception))
1180
1181    else:
1182      quantized_tflite_model = quantized_converter.convert()
1183      self.assertIsNotNone(quantized_tflite_model)
1184      interpreter = Interpreter(model_content=quantized_tflite_model)
1185      interpreter.allocate_tensors()
1186      self.assertEqual(interpreter.get_tensor_details()[bias_idx]['name'],
1187                       bias_name)
1188
1189      if is_float16_quantized:
1190        # Verify that bias constant is float16 type.
1191        self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
1192                         dtypes.float16)
1193      elif is_post_training_quantized:
1194        # Verify that bias constants is int32 type.
1195        self.assertEqual(interpreter.get_tensor_details()[bias_idx]['dtype'],
1196                         dtypes.int32)
1197      else:
1198        raise ValueError('Invalid test options.')
1199
1200  @parameterized.named_parameters(
1201      ('EnableMlirConverter', True),  # enable mlir
1202      ('DisableMlirConverter', False))  # disable mlir
1203  def testInvalidQuantizeFloat16(self, enable_mlir_converter):
1204    with ops.Graph().as_default():
1205      inp, output, _ = self._getIntegerQuantizeModel()
1206      sess = session.Session()
1207
1208    # Specify float16 quantization
1209    quantized_converter = lite.TFLiteConverter.from_session(
1210        sess, [inp], [output])
1211    quantized_converter.experimental_new_converter = enable_mlir_converter
1212    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
1213    quantized_converter.target_spec.supported_types = [dtypes.float16]
1214    # Specify only int8 builtin ops
1215    quantized_converter.target_spec.supported_ops = [
1216        lite.OpsSet.TFLITE_BUILTINS_INT8
1217    ]
1218    with self.assertRaises(ValueError) as error:
1219      quantized_converter.convert()
1220    self.assertEqual(
1221        'TFLITE_BUILTINS_INT8 requires smallest supported type to be INT8.',
1222        str(error.exception))
1223
1224  @parameterized.named_parameters(
1225      ('InferenceType_INT8', dtypes.int8),
1226      ('InferenceType_UINT8', dtypes.uint8))
1227  def testInvalidQuantizeQATModelRequiresInputStats(self, quantized_type):
1228    with ops.Graph().as_default():
1229      in_tensor = array_ops.placeholder(
1230          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1231      out_tensor = array_ops.fake_quant_with_min_max_args(
1232          in_tensor + in_tensor, min=0., max=1.)
1233      sess = session.Session()
1234
1235    quantized_converter = lite.TFLiteConverter.from_session(
1236        sess, [in_tensor], [out_tensor])
1237
1238    with self.assertRaises(ValueError) as error:
1239      quantized_converter.inference_type = quantized_type
1240      quantized_converter.convert()
1241    self.assertEqual(
1242        'The `quantized_input_stats` flag must be defined when either '
1243        '`inference_type` flag or `inference_input_type` flag is set to '
1244        'tf.int8 or tf.uint8. Currently, `inference_type=tf.{}` and '
1245        '`inference_input_type=None`.'.format(quantized_type.name),
1246        str(error.exception))
1247
1248    with self.assertRaises(ValueError) as error:
1249      quantized_converter.inference_type = dtypes.float32
1250      quantized_converter.inference_input_type = quantized_type
1251      quantized_converter.convert()
1252    self.assertEqual(
1253        'The `quantized_input_stats` flag must be defined when either '
1254        '`inference_type` flag or `inference_input_type` flag is set to '
1255        'tf.int8 or tf.uint8. Currently, `inference_type=tf.float32` and '
1256        '`inference_input_type=tf.{}`.'.format(quantized_type.name),
1257        str(error.exception))
1258
1259    quantized_converter.inference_type = quantized_type
1260    quantized_converter.inference_input_type = quantized_type
1261
1262    input_arrays = quantized_converter.get_input_arrays()
1263    quantized_converter.quantized_input_stats = {
1264        input_arrays[0]: (0., 1.)
1265    }
1266    quantized_converter.convert()
1267
1268  def testInvalidQuantizeQATModelMissingInputStats(self):
1269    with ops.Graph().as_default():
1270      in_tensor_1 = array_ops.placeholder(
1271          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
1272      in_tensor_2 = array_ops.placeholder(
1273          shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
1274      out_tensor = array_ops.fake_quant_with_min_max_args(
1275          in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
1276      sess = session.Session()
1277
1278    # Convert model and ensure model is not None.
1279    converter = lite.TFLiteConverter.from_session(sess,
1280                                                  [in_tensor_1, in_tensor_2],
1281                                                  [out_tensor])
1282    converter.inference_type = dtypes.uint8
1283    converter.quantized_input_stats = {'inputA': (0., 1.)}  # mean, std_dev
1284    with self.assertRaises(ValueError) as error:
1285      converter.convert()
1286    self.assertEqual(
1287        'Quantization input stats are not available for input tensors '
1288        '\'inputB\'.', str(error.exception))
1289
1290  def testTrainingTimeAndPostTrainingCalibrateAndQuantize(self):
1291    with ops.Graph().as_default():
1292      inp, output, calibration_gen = self._getIntegerQuantizeModel()
1293      sess = session.Session()
1294
1295    # Convert float model.
1296    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
1297    float_tflite_model = float_converter.convert()
1298    self.assertIsNotNone(float_tflite_model)
1299
1300    converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
1301
1302    # extra flags to trigger training time quantization conversion
1303    converter.inference_type = dtypes.int8
1304    converter.inference_input_type = dtypes.float32
1305    converter.inference_output_type = dtypes.float32
1306    input_arrays = converter.get_input_arrays()
1307    converter.quantized_input_stats = {
1308        input_arrays[0]: (0., 1.)
1309    }
1310    # trigger post-training quantization
1311    converter.optimizations = [lite.Optimize.DEFAULT]
1312    converter.representative_dataset = calibration_gen
1313    converter.experimental_new_quantizer = True
1314    quantized_tflite_model = converter.convert()
1315    self.assertIsNotNone(quantized_tflite_model)
1316    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
1317
1318    # calibration only api
1319    converter._experimental_calibrate_only = True
1320    calibrated_tflite = converter.convert()
1321    quantized_tflite_model = mlir_quantize(
1322        calibrated_tflite, fully_quantize=True)
1323    interpreter = Interpreter(model_content=quantized_tflite_model)
1324    interpreter.allocate_tensors()
1325    input_details = interpreter.get_input_details()
1326    self.assertEqual(np.int8, input_details[0]['dtype'])
1327    self.assertEqual((1., 0.), input_details[0]['quantization'])
1328
1329    output_details = interpreter.get_output_details()
1330    self.assertEqual(np.int8, output_details[0]['dtype'])
1331
1332  def testFloatTocoConverter(self):
1333    """Tests deprecated test TocoConverter."""
1334    with ops.Graph().as_default():
1335      in_tensor = array_ops.placeholder(
1336          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1337      out_tensor = in_tensor + in_tensor
1338      sess = session.Session()
1339
1340    # Convert model and ensure model is not None.
1341    converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
1342    tflite_model = converter.convert()
1343    self.assertIsNotNone(tflite_model)
1344
1345    # Ensure the interpreter is able to load.
1346    interpreter = Interpreter(model_content=tflite_model)
1347    interpreter.allocate_tensors()
1348
1349  def testMultipleOutputNodeNames(self):
1350    """Tests converting a graph with an op that have multiple outputs."""
1351    with ops.Graph().as_default():
1352      input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
1353      out0, out1, out2, out3 = array_ops.split(
1354          input_tensor, [1, 1, 1, 1], axis=0)
1355      sess = session.Session()
1356
1357    # Convert model and ensure model is not None.
1358    converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
1359                                                  [out0, out1, out2, out3])
1360    tflite_model = converter.convert()
1361    self.assertIsNotNone(tflite_model)
1362
1363    # Check values from converted model.
1364    interpreter = Interpreter(model_content=tflite_model)
1365    interpreter.allocate_tensors()
1366
1367    input_details = interpreter.get_input_details()
1368    self.assertLen(input_details, 1)
1369    interpreter.set_tensor(input_details[0]['index'],
1370                           np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
1371    interpreter.invoke()
1372
1373    output_details = interpreter.get_output_details()
1374    self.assertLen(output_details, 4)
1375    self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index']))
1376    self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index']))
1377    self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index']))
1378    self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index']))
1379
1380  @parameterized.named_parameters(
1381      ('EnableMlirConverter', True),  # enable mlir
1382      ('DisableMlirConverter', False))  # disable mlir
1383  @test_util.run_in_graph_and_eager_modes
1384  def testFunctions(self, enable_mlir_converter):
1385    """Tests tf.function in 1.X."""
1386
1387    @def_function.function
1388    def plus_placeholder(x, placeholder):
1389      return x + placeholder
1390
1391    with ops.Graph().as_default():
1392      placeholder = array_ops.placeholder(
1393          dtype=dtypes.float32, shape=[1], name='input')
1394      variable_node = variables.Variable(1.0, name='variable_node')
1395      defun_node = plus_placeholder(variable_node, placeholder)
1396      output_node = math_ops.multiply(defun_node, 2.0, name='output_node')
1397
1398      # Initialize variables in the model.
1399      sess = session.Session()
1400      sess.run(variables.variables_initializer([variable_node]))
1401
1402    # Convert model and ensure model is not None.
1403    converter = lite.TFLiteConverter.from_session(sess, [placeholder],
1404                                                  [output_node])
1405    converter.experimental_new_converter = enable_mlir_converter
1406    tflite_model = converter.convert()
1407    self.assertIsNotNone(tflite_model)
1408
1409    # Check values from converted model.
1410    interpreter = Interpreter(model_content=tflite_model)
1411    interpreter.allocate_tensors()
1412
1413    input_details = interpreter.get_input_details()
1414    self.assertLen(input_details, 1)
1415    self.assertEqual('input', input_details[0]['name'])
1416    self.assertEqual(np.float32, input_details[0]['dtype'])
1417    self.assertAllEqual([1], input_details[0]['shape'])
1418    self.assertEqual((0., 0.), input_details[0]['quantization'])
1419
1420    output_details = interpreter.get_output_details()
1421    self.assertLen(output_details, 1)
1422    self.assertEqual('output_node', output_details[0]['name'])
1423    self.assertEqual(np.float32, output_details[0]['dtype'])
1424    self.assertAllEqual([1], output_details[0]['shape'])
1425    self.assertEqual((0., 0.), output_details[0]['quantization'])
1426
1427  def testInferenceInputOutputTypeFloatDefault(self):
1428    with ops.Graph().as_default():
1429      in_tensor = array_ops.placeholder(
1430          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1431      out_tensor = in_tensor + in_tensor
1432      sess = session.Session()
1433
1434    # Convert model and ensure model is not None.
1435    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1436                                                  [out_tensor])
1437    tflite_model = converter.convert()
1438    self.assertIsNotNone(tflite_model)
1439
1440    # Check values from converted model.
1441    interpreter = Interpreter(model_content=tflite_model)
1442    interpreter.allocate_tensors()
1443
1444    input_details = interpreter.get_input_details()
1445    self.assertLen(input_details, 1)
1446    self.assertEqual('Placeholder', input_details[0]['name'])
1447    self.assertEqual(np.float32, input_details[0]['dtype'])
1448    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1449
1450    output_details = interpreter.get_output_details()
1451    self.assertLen(output_details, 1)
1452    self.assertEqual('add', output_details[0]['name'])
1453    self.assertEqual(np.float32, output_details[0]['dtype'])
1454    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1455
1456  def testInferenceInputOutputTypeQuantizedUint8Default(self):
1457    with ops.Graph().as_default():
1458      in_tensor = array_ops.placeholder(
1459          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1460      out_tensor = array_ops.fake_quant_with_min_max_args(
1461          in_tensor + in_tensor, min=0., max=1., name='output')
1462      sess = session.Session()
1463
1464    # Convert model and ensure model is not None.
1465    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1466                                                  [out_tensor])
1467    converter.inference_type = dtypes.uint8
1468    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
1469    tflite_model = converter.convert()
1470    self.assertIsNotNone(tflite_model)
1471
1472    # Check values from converted model.
1473    interpreter = Interpreter(model_content=tflite_model)
1474    interpreter.allocate_tensors()
1475
1476    input_details = interpreter.get_input_details()
1477    self.assertLen(input_details, 1)
1478    self.assertEqual('Placeholder', input_details[0]['name'])
1479    self.assertEqual(np.uint8, input_details[0]['dtype'])
1480    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1481
1482    output_details = interpreter.get_output_details()
1483    self.assertLen(output_details, 1)
1484    self.assertEqual('output', output_details[0]['name'])
1485    self.assertEqual(np.uint8, output_details[0]['dtype'])
1486    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1487
1488  def testReusingConverterWithDifferentPostTrainingQuantization(self):
1489    with ops.Graph().as_default():
1490      in_tensor = array_ops.placeholder(
1491          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1492      out_tensor = array_ops.fake_quant_with_min_max_args(
1493          in_tensor + in_tensor, min=0., max=1., name='output')
1494      sess = session.Session()
1495
1496    # Convert model and ensure model is not None.
1497    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1498                                                  [out_tensor])
1499
1500    converter.post_training_quantize = True
1501    tflite_model = converter.convert()
1502    self.assertIsNotNone(tflite_model)
1503
1504    converter.post_training_quantize = False
1505    tflite_model = converter.convert()
1506    self.assertIsNotNone(tflite_model)
1507
1508  def testResizeWithShape(self):
1509    with ops.Graph().as_default():
1510      # Construct a graph with a dynamically shapped input and an internal node
1511      # that relies on the output of that input's shape.
1512      in_tensor = array_ops.placeholder(
1513          shape=[None, None], dtype=dtypes.float32)
1514      in_tensor2 = [[1, 2], [3, 4]]
1515      out_tensor = array_ops.reshape(in_tensor2, array_ops.shape(in_tensor))
1516      sess = session.Session()
1517
1518    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
1519                                                  [out_tensor])
1520    tflite_model = converter.convert()
1521
1522    # Check values from converted model.
1523    interpreter = Interpreter(model_content=tflite_model)
1524    input_details = interpreter.get_input_details()
1525    self.assertLen(input_details, 1)
1526    self.assertAllEqual([1, 1], input_details[0]['shape'])
1527    self.assertAllEqual([-1, -1], input_details[0]['shape_signature'])
1528
1529    # Resize tensor and invoke.
1530    interpreter.resize_tensor_input(0, [4])
1531    interpreter.allocate_tensors()
1532    interpreter.invoke()
1533
1534    # The output should be reshaped properly according to the resized input.
1535    output_details = interpreter.get_output_details()
1536    self.assertLen(output_details, 1)
1537    self.assertEqual(np.int32, output_details[0]['dtype'])
1538    self.assertAllEqual([4], output_details[0]['shape'])
1539    output_data = interpreter.get_tensor(output_details[0]['index'])
1540    self.assertAllEqual([1, 2, 3, 4], output_data)
1541
1542  def testResizingIntermediateDynamicTensor(self):
1543    # This is a regression test for the case where shape of dynamic output
1544    # tensors changes between invocations.
1545    # See also https://github.com/tensorflow/tensorflow/issues/26549
1546    with ops.Graph().as_default():
1547      input_tensor = array_ops.placeholder(shape=[1, 1], dtype=dtypes.float32)
1548      input2_tensor = array_ops.placeholder(shape=[1], dtype=dtypes.float32)
1549
1550      # The bug is triggered only when dynamic tensor is intermediate. Putting
1551      # some other ops around it.
1552      neg = math_ops.negative(input2_tensor)
1553      padding = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int32)
1554      output_tensor = array_ops.pad(input_tensor, padding) + neg
1555
1556      sess = session.Session()
1557
1558    converter = lite.TFLiteConverter.from_session(
1559        sess, [input_tensor, padding, input2_tensor], [output_tensor])
1560    tflite_model = converter.convert()
1561
1562    interpreter = Interpreter(model_content=tflite_model)
1563    interpreter.allocate_tensors()
1564
1565    input_details = interpreter.get_input_details()
1566    interpreter.set_tensor(input_details[1]['index'],
1567                           np.array([[1, 1], [1, 1]], dtype=np.int32))
1568    interpreter.invoke()
1569
1570    # Without the fix, invocation will fail when changing the shape of
1571    # intermediate dynamic tensors.
1572    interpreter.set_tensor(input_details[1]['index'],
1573                           np.array([[2, 2], [2, 2]], dtype=np.int32))
1574    interpreter.invoke()
1575
1576  def testGraphDebugInfo(self):
1577    """Test a session has debug info captured."""
1578
1579    @def_function.function
1580    def plus_placeholder(x, placeholder):
1581      return x + placeholder
1582
1583    with ops.Graph().as_default():
1584      placeholder = array_ops.placeholder(
1585          dtype=dtypes.float32, shape=[1], name='input')
1586      variable_node = variables.Variable(1.0, name='variable_node')
1587      defun_node = plus_placeholder(variable_node, placeholder)
1588      output_node = math_ops.multiply(defun_node, 2.0, name='output_node')
1589
1590      # Initialize variables in the model.
1591      sess = session.Session()
1592      sess.run(variables.variables_initializer([variable_node]))
1593
1594    converter = lite.TFLiteConverter.from_session(sess, [placeholder],
1595                                                  [output_node])
1596    converter.convert()
1597    self.assertValidDebugInfo(converter._debug_info)
1598
1599    # Check the add node in the inlined function is included.
1600    func = sess.graph.as_graph_def().library.function[0].signature.name
1601    self.assertIn(('add@' + six.ensure_str(func)), converter._debug_info.traces)
1602
1603  def testOutputOnlyModel(self):
1604    with ops.Graph().as_default():
1605      out_tensor = random_ops.random_normal(shape=[3])
1606      sess = session.Session()
1607
1608    # Convert model and ensure model is not None.
1609    converter = lite.TFLiteConverter.from_session(sess, [], [out_tensor])
1610    converter.target_spec.supported_ops = [
1611        lite.OpsSet.TFLITE_BUILTINS,
1612        lite.OpsSet.SELECT_TF_OPS,
1613    ]
1614
1615    # Empty input array is a valid input.
1616    self.assertTrue(converter._has_valid_tensors())
1617
1618    tflite_model = converter.convert()
1619    self.assertIsNotNone(tflite_model)
1620
1621
1622class FromFrozenGraphFile(LiteTest):
1623
1624  def testFloat(self):
1625    with ops.Graph().as_default():
1626      in_tensor = array_ops.placeholder(
1627          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1628      _ = in_tensor + in_tensor
1629      sess = session.Session()
1630
1631    # Write graph to file.
1632    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1633    write_graph(sess.graph_def, '', graph_def_file, False)
1634    sess.close()
1635
1636    # Convert model and ensure model is not None.
1637    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
1638                                                       ['Placeholder'], ['add'])
1639    tflite_model = converter.convert()
1640    self.assertIsNotNone(tflite_model)
1641
1642    # Check values from converted model.
1643    interpreter = Interpreter(model_content=tflite_model)
1644    interpreter.allocate_tensors()
1645
1646    input_details = interpreter.get_input_details()
1647    self.assertLen(input_details, 1)
1648    self.assertEqual('Placeholder', input_details[0]['name'])
1649    self.assertEqual(np.float32, input_details[0]['dtype'])
1650    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1651    self.assertEqual((0., 0.), input_details[0]['quantization'])
1652
1653    output_details = interpreter.get_output_details()
1654    self.assertLen(output_details, 1)
1655    self.assertEqual('add', output_details[0]['name'])
1656    self.assertEqual(np.float32, output_details[0]['dtype'])
1657    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1658    self.assertEqual((0., 0.), output_details[0]['quantization'])
1659
1660  def testFloatWithShapesArray(self):
1661    """Test a shape overriding case."""
1662    with ops.Graph().as_default():
1663      in_tensor = array_ops.placeholder(
1664          shape=[None, 16, 16, 3], dtype=dtypes.float32)
1665      _ = in_tensor + in_tensor
1666      sess = session.Session()
1667
1668    # Write graph to file.
1669    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1670    write_graph(sess.graph_def, '', graph_def_file, False)
1671    sess.close()
1672
1673    # Convert model and ensure model is not None.
1674    converter = lite.TFLiteConverter.from_frozen_graph(
1675        graph_def_file, ['Placeholder'], ['add'],
1676        input_shapes={'Placeholder': [2, 16, 16, 3]})
1677    tflite_model = converter.convert()
1678    self.assertIsNotNone(tflite_model)
1679
1680    # Check values from converted model.
1681    interpreter = Interpreter(model_content=tflite_model)
1682    interpreter.allocate_tensors()
1683
1684    input_details = interpreter.get_input_details()
1685    self.assertLen(input_details, 1)
1686    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
1687
1688  def testInvalidShapesArray(self):
1689    """Test an invalid shape overriding case, which has a wrong input name."""
1690    with ops.Graph().as_default():
1691      in_tensor = array_ops.placeholder(
1692          shape=[None, 16, 16, 3], dtype=dtypes.float32)
1693      _ = in_tensor + in_tensor
1694      sess = session.Session()
1695
1696    # Write graph to file.
1697    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1698    write_graph(sess.graph_def, '', graph_def_file, False)
1699    sess.close()
1700
1701    # Convert model and ensure model is not None.
1702    with self.assertRaises(ValueError):
1703      lite.TFLiteConverter.from_frozen_graph(
1704          graph_def_file, ['Placeholder'], ['add'],
1705          input_shapes={'wrong_input': [2, 16, 16, 3]})
1706
1707  def testPartialShapesArray(self):
1708    """Test a shape overriding case, with the only one input among two."""
1709    with ops.Graph().as_default():
1710      a = array_ops.placeholder(
1711          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='a')
1712      b = array_ops.placeholder(
1713          shape=[None, 16, 16, 3], dtype=dtypes.float32, name='b')
1714      _ = math_ops.add(a, b, name='add')
1715      sess = session.Session()
1716
1717    # Write graph to file.
1718    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1719    write_graph(sess.graph_def, '', graph_def_file, False)
1720    sess.close()
1721
1722    # Convert model and ensure model is not None.
1723    converter = lite.TFLiteConverter.from_frozen_graph(
1724        graph_def_file, ['a', 'b'], ['add'], input_shapes={'a': [2, 16, 16, 3]})
1725    tflite_model = converter.convert()
1726    self.assertIsNotNone(tflite_model)
1727
1728    # Check values from converted model.
1729    interpreter = Interpreter(model_content=tflite_model)
1730    interpreter.allocate_tensors()
1731
1732    input_details = interpreter.get_input_details()
1733    self.assertLen(input_details, 2)
1734    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
1735    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
1736
1737  def testFreezeGraph(self):
1738    with ops.Graph().as_default():
1739      in_tensor = array_ops.placeholder(
1740          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1741      var = variable_scope.get_variable(
1742          'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
1743      _ = in_tensor + var
1744      sess = session.Session()
1745
1746    # Write graph to file.
1747    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1748    write_graph(sess.graph_def, '', graph_def_file, False)
1749    sess.close()
1750
1751    # Ensure the graph with variables cannot be converted.
1752    with self.assertRaises(ValueError) as error:
1753      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
1754                                             ['add'])
1755    self.assertEqual('Please freeze the graph using freeze_graph.py.',
1756                     str(error.exception))
1757
1758  def testPbtxt(self):
1759    with ops.Graph().as_default():
1760      in_tensor = array_ops.placeholder(
1761          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1762      _ = in_tensor + in_tensor
1763      sess = session.Session()
1764
1765    # Write graph to file.
1766    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
1767    write_graph(sess.graph_def, '', graph_def_file, True)
1768    sess.close()
1769
1770    # Convert model and ensure model is not None.
1771    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
1772                                                       ['Placeholder'], ['add'])
1773    tflite_model = converter.convert()
1774    self.assertIsNotNone(tflite_model)
1775
1776    # Check values from converted model.
1777    interpreter = Interpreter(model_content=tflite_model)
1778    interpreter.allocate_tensors()
1779
1780    input_details = interpreter.get_input_details()
1781    self.assertLen(input_details, 1)
1782    self.assertEqual('Placeholder', input_details[0]['name'])
1783    self.assertEqual(np.float32, input_details[0]['dtype'])
1784    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1785    self.assertEqual((0., 0.), input_details[0]['quantization'])
1786
1787    output_details = interpreter.get_output_details()
1788    self.assertLen(output_details, 1)
1789    self.assertEqual('add', output_details[0]['name'])
1790    self.assertEqual(np.float32, output_details[0]['dtype'])
1791    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1792    self.assertEqual((0., 0.), output_details[0]['quantization'])
1793
1794  def testInvalidFileNotFound(self):
1795    with self.assertRaises(IOError) as error:
1796      lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'],
1797                                             ['add'])
1798    self.assertEqual('File \'invalid_file\' does not exist.',
1799                     str(error.exception))
1800
1801  def testInvalidFileBadData(self):
1802    graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
1803    with gfile.Open(graph_def_file, 'wb') as temp_file:
1804      temp_file.write('bad data')
1805      temp_file.flush()
1806
1807    # Attempts to convert the invalid model.
1808    with self.assertRaises(IOError) as error:
1809      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
1810                                             ['add'])
1811    self.assertEqual(
1812        'Unable to parse input file \'{}\'.'.format(graph_def_file),
1813        str(error.exception))
1814
1815  def testFloatTocoConverter(self):
1816    with ops.Graph().as_default():
1817      in_tensor = array_ops.placeholder(
1818          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1819      _ = in_tensor + in_tensor
1820      sess = session.Session()
1821
1822    # Write graph to file.
1823    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1824    write_graph(sess.graph_def, '', graph_def_file, False)
1825    sess.close()
1826
1827    # Convert model and ensure model is not None.
1828    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
1829                                                     ['Placeholder'], ['add'])
1830    tflite_model = converter.convert()
1831    self.assertIsNotNone(tflite_model)
1832
1833    # Ensure the model is able to load.
1834    interpreter = Interpreter(model_content=tflite_model)
1835    interpreter.allocate_tensors()
1836
1837  def testGraphDebugInfo(self):
1838    """Test a frozen graph doesn't have debug info captured."""
1839    with ops.Graph().as_default():
1840      in_tensor = array_ops.placeholder(
1841          shape=[1, 16, 16, 3], dtype=dtypes.float32)
1842      _ = in_tensor + in_tensor
1843      sess = session.Session()
1844
1845    # Write graph to file.
1846    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
1847    write_graph(sess.graph_def, '', graph_def_file, False)
1848    sess.close()
1849
1850    # Convert model and ensure model is not None.
1851    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
1852                                                     ['Placeholder'], ['add'])
1853    converter.convert()
1854    # GraphDebugInfo should be none for frozen graph.
1855    self.assertFalse(converter._debug_info)
1856
1857
1858class FromFrozenGraphObjectDetection(LiteTest):
1859
1860  def _initObjectDetectionArgs(self):
1861    # Initializes the arguments required for the object detection model.
1862    # Looks for the model file which is saved in a different location internally
1863    # and externally.
1864    filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
1865    if not os.path.exists(filename):
1866      filename = os.path.join(
1867          resource_loader.get_root_dir_with_all_resources(),
1868          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
1869      if not os.path.exists(filename):
1870        raise IOError("File '{0}' does not exist.".format(filename))
1871
1872    self._graph_def_file = filename
1873    self._input_arrays = ['normalized_input_image_tensor']
1874    self._output_arrays = [
1875        'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
1876        'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
1877    ]
1878    self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}
1879
1880  def testTFLiteGraphDef(self):
1881    # Tests the object detection model that cannot be loaded in TensorFlow.
1882    self._initObjectDetectionArgs()
1883
1884    converter = lite.TFLiteConverter.from_frozen_graph(self._graph_def_file,
1885                                                       self._input_arrays,
1886                                                       self._output_arrays,
1887                                                       self._input_shapes)
1888    converter.allow_custom_ops = True
1889    tflite_model = converter.convert()
1890    self.assertIsNotNone(tflite_model)
1891
1892    # Check values from converted model.
1893    interpreter = Interpreter(model_content=tflite_model)
1894    interpreter.allocate_tensors()
1895
1896    input_details = interpreter.get_input_details()
1897    self.assertLen(input_details, 1)
1898    self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
1899    self.assertEqual(np.float32, input_details[0]['dtype'])
1900    self.assertAllEqual([1, 300, 300, 3], input_details[0]['shape'])
1901    self.assertEqual((0., 0.), input_details[0]['quantization'])
1902
1903    output_details = interpreter.get_output_details()
1904    self.assertLen(output_details, 4)
1905    self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
1906    self.assertEqual(np.float32, output_details[0]['dtype'])
1907    self.assertAllEqual([1, 10, 4], output_details[0]['shape'])
1908    self.assertEqual((0., 0.), output_details[0]['quantization'])
1909
1910    self.assertEqual('TFLite_Detection_PostProcess:1',
1911                     output_details[1]['name'])
1912    self.assertAllEqual([1, 10], output_details[1]['shape'])
1913    self.assertEqual('TFLite_Detection_PostProcess:2',
1914                     output_details[2]['name'])
1915    self.assertAllEqual([1, 10], output_details[2]['shape'])
1916    self.assertEqual('TFLite_Detection_PostProcess:3',
1917                     output_details[3]['name'])
1918    self.assertAllEqual([1], output_details[3]['shape'])
1919
1920
1921class FromSavedModelTest(TestModels):
1922
1923  def _createSavedModel(self, shape):
1924    """Create a simple SavedModel."""
1925    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
1926    with ops.Graph().as_default():
1927      with session.Session() as sess:
1928        in_tensor_1 = array_ops.placeholder(
1929            shape=shape, dtype=dtypes.float32, name='inputB')
1930        in_tensor_2 = array_ops.placeholder(
1931            shape=shape, dtype=dtypes.float32, name='inputA')
1932        out_tensor = in_tensor_1 + in_tensor_2
1933        inputs = {'x': in_tensor_1, 'y': in_tensor_2}
1934        outputs = {'z': out_tensor}
1935        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
1936    return saved_model_dir
1937
1938  def testSimpleModel(self):
1939    """Test a SavedModel."""
1940    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
1941
1942    # Convert model and ensure model is not None.
1943    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
1944    tflite_model = converter.convert()
1945    self.assertIsNotNone(tflite_model)
1946
1947    interpreter = Interpreter(model_content=tflite_model)
1948    interpreter.allocate_tensors()
1949
1950    input_details = interpreter.get_input_details()
1951    self.assertLen(input_details, 2)
1952    self.assertStartsWith(input_details[0]['name'], 'inputA')
1953    self.assertEqual(np.float32, input_details[0]['dtype'])
1954    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
1955    self.assertEqual((0., 0.), input_details[0]['quantization'])
1956
1957    self.assertStartsWith(input_details[1]['name'], 'inputB')
1958    self.assertEqual(np.float32, input_details[1]['dtype'])
1959    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
1960    self.assertEqual((0., 0.), input_details[1]['quantization'])
1961
1962    output_details = interpreter.get_output_details()
1963    self.assertLen(output_details, 1)
1964    self.assertStartsWith(output_details[0]['name'], 'add')
1965    self.assertEqual(np.float32, output_details[0]['dtype'])
1966    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
1967    self.assertEqual((0., 0.), output_details[0]['quantization'])
1968
1969  def testOldConverterWarning(self):
1970    """Test if the warning message when using TOCO is logged."""
1971    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
1972    log = io.BytesIO() if six.PY2 else io.StringIO()
1973    handler = logging.StreamHandler(log)
1974    logging.root.addHandler(handler)
1975    warning_message = 'Please consider switching to the new converter'
1976    # Convert model and ensure model is not None.
1977    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
1978    converter.experimental_new_converter = False
1979    tflite_model = converter.convert()
1980    self.assertIsNotNone(tflite_model)
1981    self.assertIn(warning_message, log.getvalue())
1982    logging.root.removeHandler(handler)
1983
1984  def testNewConverterOptOut(self):
1985    """Test if the opt out message when using New converter is logged."""
1986    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
1987    log = io.BytesIO() if six.PY2 else io.StringIO()
1988    handler = logging.StreamHandler(log)
1989    logging.root.addHandler(handler)
1990    optout_message = ('Using experimental converter: '
1991                      'If you encountered a problem')
1992    # Convert model and ensure model is not None.
1993    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
1994    tflite_model = converter.convert()
1995    self.assertIsNotNone(tflite_model)
1996    self.assertIn(optout_message, log.getvalue())
1997    logging.root.removeHandler(handler)
1998
1999  def testNoneBatchSize(self):
2000    """Test a SavedModel, with None in input tensor's shape."""
2001    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])
2002
2003    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
2004    tflite_model = converter.convert()
2005    self.assertIsNotNone(tflite_model)
2006
2007    # Check values from converted model.
2008    interpreter = Interpreter(model_content=tflite_model)
2009    interpreter.allocate_tensors()
2010
2011    input_details = interpreter.get_input_details()
2012    self.assertLen(input_details, 2)
2013    self.assertStartsWith(input_details[0]['name'], 'inputA')
2014    self.assertEqual(np.float32, input_details[0]['dtype'])
2015    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
2016    self.assertEqual((0., 0.), input_details[0]['quantization'])
2017
2018    self.assertStartsWith(input_details[1]['name'], 'inputB')
2019    self.assertEqual(np.float32, input_details[1]['dtype'])
2020    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
2021    self.assertEqual((0., 0.), input_details[1]['quantization'])
2022
2023    output_details = interpreter.get_output_details()
2024    self.assertLen(output_details, 1)
2025    self.assertStartsWith(output_details[0]['name'], 'add')
2026    self.assertEqual(np.float32, output_details[0]['dtype'])
2027    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
2028    self.assertEqual((0., 0.), output_details[0]['quantization'])
2029
2030  def testOrderInputArrays(self):
2031    """Test a SavedModel ordering of input arrays."""
2032    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2033
2034    converter = lite.TFLiteConverter.from_saved_model(
2035        saved_model_dir, input_arrays=['inputB', 'inputA'])
2036    tflite_model = converter.convert()
2037    self.assertIsNotNone(tflite_model)
2038
2039    # Check values from converted model.
2040    interpreter = Interpreter(model_content=tflite_model)
2041    interpreter.allocate_tensors()
2042
2043    input_details = interpreter.get_input_details()
2044    self.assertLen(input_details, 2)
2045    self.assertStartsWith(input_details[0]['name'], 'inputA')
2046    self.assertEqual(np.float32, input_details[0]['dtype'])
2047    self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
2048    self.assertEqual((0., 0.), input_details[0]['quantization'])
2049
2050    self.assertStartsWith(input_details[1]['name'], 'inputB')
2051    self.assertEqual(np.float32, input_details[1]['dtype'])
2052    self.assertAllEqual([1, 16, 16, 3], input_details[1]['shape'])
2053    self.assertEqual((0., 0.), input_details[1]['quantization'])
2054
2055    output_details = interpreter.get_output_details()
2056    self.assertLen(output_details, 1)
2057    self.assertStartsWith(output_details[0]['name'], 'add')
2058    self.assertEqual(np.float32, output_details[0]['dtype'])
2059    self.assertAllEqual([1, 16, 16, 3], output_details[0]['shape'])
2060    self.assertEqual((0., 0.), output_details[0]['quantization'])
2061
2062  def testShapeOverriding(self):
2063    """Test a SavedModel with the input_shapes arugment."""
2064    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])
2065
2066    # Convert model and ensure model is not None.
2067    converter = lite.TFLiteConverter.from_saved_model(
2068        saved_model_dir,
2069        input_shapes={
2070            'inputA': [2, 16, 16, 3],
2071            'inputB': [2, 16, 16, 3]
2072        })
2073    tflite_model = converter.convert()
2074    self.assertIsNotNone(tflite_model)
2075
2076    interpreter = Interpreter(model_content=tflite_model)
2077    interpreter.allocate_tensors()
2078
2079    input_details = interpreter.get_input_details()
2080    self.assertLen(input_details, 2)
2081    self.assertStartsWith(input_details[0]['name'], 'inputA')
2082    self.assertEqual(np.float32, input_details[0]['dtype'])
2083    self.assertAllEqual([2, 16, 16, 3], input_details[0]['shape'])
2084    self.assertEqual((0., 0.), input_details[0]['quantization'])
2085
2086    self.assertStartsWith(input_details[1]['name'], 'inputB')
2087    self.assertEqual(np.float32, input_details[1]['dtype'])
2088    self.assertAllEqual([2, 16, 16, 3], input_details[1]['shape'])
2089    self.assertEqual((0., 0.), input_details[1]['quantization'])
2090
2091    output_details = interpreter.get_output_details()
2092    self.assertLen(output_details, 1)
2093    self.assertStartsWith(output_details[0]['name'], 'add')
2094    self.assertEqual(np.float32, output_details[0]['dtype'])
2095    self.assertAllEqual([2, 16, 16, 3], output_details[0]['shape'])
2096    self.assertEqual((0., 0.), output_details[0]['quantization'])
2097
2098  def testWrongInputShapes(self):
2099    """Test a SavedModel with a wrong name in the input_shapes argument."""
2100    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2101
2102    # Check case where input shape is given.
2103    with self.assertRaises(ValueError):
2104      lite.TFLiteConverter.from_saved_model(
2105          saved_model_dir,
2106          input_arrays=['inputA'],
2107          input_shapes={'wrong_input': [1, 16, 16, 3]})
2108
2109  def testSubsetInputShaapes(self):
2110    """Test a SavedModel with a subset of the input array names of the model."""
2111    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2112
2113    # Check case where input shape is given.
2114    converter = lite.TFLiteConverter.from_saved_model(
2115        saved_model_dir,
2116        input_arrays=['inputA'],
2117        input_shapes={'inputA': [1, 16, 16, 3]})
2118
2119    # Since we only partially specify the input, this is not allowed.
2120    with self.assertRaises(ConverterError):
2121      _ = converter.convert()
2122
2123    # Check case where input shape is None.
2124    converter = lite.TFLiteConverter.from_saved_model(
2125        saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})
2126
2127    # Since we only partially specify the input, this is not allowed.
2128    with self.assertRaises(ConverterError):
2129      _ = converter.convert()
2130
2131  def testSimpleModelTocoConverter(self):
2132    """Test a SavedModel with deprecated TocoConverter."""
2133    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2134
2135    # Convert model and ensure model is not None.
2136    converter = lite.TocoConverter.from_saved_model(saved_model_dir)
2137    tflite_model = converter.convert()
2138    self.assertIsNotNone(tflite_model)
2139
2140    # Ensure the model is able to load.
2141    interpreter = Interpreter(model_content=tflite_model)
2142    interpreter.allocate_tensors()
2143
2144  def testGraphDebugInfo(self):
2145    """Test a SavedModel has debug info captured."""
2146    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
2147    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
2148    converter.convert()
2149    self.assertValidDebugInfo(converter._debug_info)
2150
2151
2152class MyAddLayer(keras.layers.Layer):
2153
2154  def __init__(self, increment, **kwargs):
2155    super(MyAddLayer, self).__init__(**kwargs)
2156    self._increment = increment
2157
2158  def call(self, inputs):
2159    return inputs + self._increment
2160
2161  def get_config(self):
2162    config = super(MyAddLayer, self).get_config()
2163    config['increment'] = self._increment
2164    return config
2165
2166
2167class FromKerasFile(TestModels, parameterized.TestCase):
2168
2169  def setUp(self):
2170    super(FromKerasFile, self).setUp()
2171    self._keras_file = None
2172    self._custom_objects = None
2173    if not context.executing_eagerly():
2174      keras.backend.clear_session()
2175
2176  def tearDown(self):
2177    if self._keras_file:
2178      os.remove(self._keras_file)
2179    super(FromKerasFile, self).tearDown()
2180
2181  def _getSequentialModel(self, include_custom_layer=False):
2182    model = keras.models.Sequential()
2183    model.add(keras.layers.Dense(2, input_shape=(3,)))
2184    if include_custom_layer:
2185      model.add(MyAddLayer(1.0))
2186    model.add(keras.layers.RepeatVector(3))
2187    model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
2188    model.compile(
2189        loss=keras.losses.MSE,
2190        optimizer='sgd',
2191        metrics=[keras.metrics.categorical_accuracy],
2192        sample_weight_mode='temporal')
2193    x = np.random.random((1, 3))
2194    y = np.random.random((1, 3, 3))
2195    model.train_on_batch(x, y)
2196    model.predict(x)
2197
2198    try:
2199      fd, self._keras_file = tempfile.mkstemp('.h5')
2200      keras.models.save_model(model, self._keras_file)
2201    finally:
2202      os.close(fd)
2203
2204    if include_custom_layer:
2205      self._custom_objects = {'MyAddLayer': MyAddLayer}
2206
2207  @parameterized.named_parameters(('_graph', context.graph_mode),
2208                                  ('_eager', context.eager_mode))
2209  def testSequentialModel(self, test_context):
2210    """Test a Sequential tf.keras model with default inputs."""
2211    with test_context():
2212      self._getSequentialModel()
2213
2214      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2215      tflite_model = converter.convert()
2216      self.assertIsNotNone(tflite_model)
2217
2218    # Check tensor details of converted model.
2219    interpreter = Interpreter(model_content=tflite_model)
2220    interpreter.allocate_tensors()
2221
2222    input_details = interpreter.get_input_details()
2223    self.assertLen(input_details, 1)
2224    self.assertEndsWith(input_details[0]['name'], 'dense_input')
2225    self.assertEqual(np.float32, input_details[0]['dtype'])
2226    self.assertAllEqual([1, 3], input_details[0]['shape'])
2227    self.assertEqual((0., 0.), input_details[0]['quantization'])
2228
2229    output_details = interpreter.get_output_details()
2230    self.assertLen(output_details, 1)
2231    self.assertEqual(np.float32, output_details[0]['dtype'])
2232    self.assertAllEqual([1, 3, 3], output_details[0]['shape'])
2233    self.assertEqual((0., 0.), output_details[0]['quantization'])
2234
2235    # Check inference of converted model.
2236    input_data = np.array([[1, 2, 3]], dtype=np.float32)
2237    interpreter.set_tensor(input_details[0]['index'], input_data)
2238    interpreter.invoke()
2239    tflite_result = interpreter.get_tensor(output_details[0]['index'])
2240
2241    keras_model = keras.models.load_model(self._keras_file)
2242    keras_result = keras_model.predict(input_data)
2243
2244    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
2245
2246  @parameterized.named_parameters(('_graph', context.graph_mode),
2247                                  ('_eager', context.eager_mode))
2248  def testCustomLayer(self, test_context):
2249    """Test a Sequential tf.keras model with default inputs."""
2250    with test_context():
2251      self._getSequentialModel(include_custom_layer=True)
2252
2253      converter = lite.TFLiteConverter.from_keras_model_file(
2254          self._keras_file, custom_objects=self._custom_objects)
2255      tflite_model = converter.convert()
2256      self.assertIsNotNone(tflite_model)
2257
2258    # Check tensor details of converted model.
2259    interpreter = Interpreter(model_content=tflite_model)
2260    interpreter.allocate_tensors()
2261
2262    input_details = interpreter.get_input_details()
2263    output_details = interpreter.get_output_details()
2264
2265    # Check inference of converted model.
2266    input_data = np.array([[1, 2, 3]], dtype=np.float32)
2267    interpreter.set_tensor(input_details[0]['index'], input_data)
2268    interpreter.invoke()
2269    tflite_result = interpreter.get_tensor(output_details[0]['index'])
2270
2271    keras_model = keras.models.load_model(
2272        self._keras_file, custom_objects=self._custom_objects)
2273    keras_result = keras_model.predict(input_data)
2274
2275    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
2276
2277  def testSequentialModelInputArray(self):
2278    """Test a Sequential tf.keras model testing input arrays argument."""
2279    ops.disable_eager_execution()
2280    self._getSequentialModel()
2281
2282    # Invalid input array raises error.
2283    with self.assertRaises(ValueError) as error:
2284      lite.TFLiteConverter.from_keras_model_file(
2285          self._keras_file, input_arrays=['invalid-input'])
2286    self.assertEqual("Invalid tensors 'invalid-input' were found.",
2287                     str(error.exception))
2288
2289    # Valid input array.
2290    converter = lite.TFLiteConverter.from_keras_model_file(
2291        self._keras_file, input_arrays=['dense_input'])
2292    tflite_model = converter.convert()
2293    self.assertIsNotNone(tflite_model)
2294
2295  def testSequentialModelInputShape(self):
2296    """Test a Sequential tf.keras model testing input shapes argument."""
2297    self._getSequentialModel()
2298
2299    # Passing in shape of invalid input array raises error.
2300    with self.assertRaises(ValueError) as error:
2301      converter = lite.TFLiteConverter.from_keras_model_file(
2302          self._keras_file, input_shapes={'invalid-input': [2, 3]})
2303    self.assertEqual(
2304        "Invalid tensor 'invalid-input' found in tensor shapes map.",
2305        str(error.exception))
2306
2307    # Passing in shape of valid input array.
2308    converter = lite.TFLiteConverter.from_keras_model_file(
2309        self._keras_file, input_shapes={'dense_input': [2, 3]})
2310    tflite_model = converter.convert()
2311    self.assertIsNotNone(tflite_model)
2312
2313    # Check input shape from converted model.
2314    interpreter = Interpreter(model_content=tflite_model)
2315    interpreter.allocate_tensors()
2316
2317    input_details = interpreter.get_input_details()
2318    self.assertLen(input_details, 1)
2319    self.assertEndsWith(input_details[0]['name'], 'dense_input')
2320    self.assertAllEqual([2, 3], input_details[0]['shape'])
2321
2322  def testSequentialModelOutputArray(self):
2323    """Test a Sequential tf.keras model testing output arrays argument."""
2324    ops.disable_eager_execution()
2325    self._getSequentialModel()
2326
2327    # Invalid output array raises error.
2328    with self.assertRaises(ValueError) as error:
2329      lite.TFLiteConverter.from_keras_model_file(
2330          self._keras_file, output_arrays=['invalid-output'])
2331    self.assertEqual("Invalid tensors 'invalid-output' were found.",
2332                     str(error.exception))
2333
2334    # Valid output array.
2335    converter = lite.TFLiteConverter.from_keras_model_file(
2336        self._keras_file, output_arrays=['time_distributed/Reshape_1'])
2337    tflite_model = converter.convert()
2338    self.assertIsNotNone(tflite_model)
2339
2340  @parameterized.named_parameters(('_graph', context.graph_mode),
2341                                  ('_eager', context.eager_mode))
2342  def testFunctionalModel(self, test_context):
2343    """Test a Functional tf.keras model with default inputs."""
2344    with test_context():
2345      inputs = keras.layers.Input(shape=(3,), name='input')
2346      x = keras.layers.Dense(2)(inputs)
2347      output = keras.layers.Dense(3)(x)
2348
2349      model = keras.models.Model(inputs, output)
2350      model.compile(
2351          loss=keras.losses.MSE,
2352          optimizer='sgd',
2353          metrics=[keras.metrics.categorical_accuracy])
2354      x = np.random.random((1, 3))
2355      y = np.random.random((1, 3))
2356      model.train_on_batch(x, y)
2357
2358      model.predict(x)
2359      fd, self._keras_file = tempfile.mkstemp('.h5')
2360      try:
2361        keras.models.save_model(model, self._keras_file)
2362      finally:
2363        os.close(fd)
2364
2365      # Convert to TFLite model.
2366      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2367      tflite_model = converter.convert()
2368      self.assertIsNotNone(tflite_model)
2369
2370    # Check tensor details of converted model.
2371    interpreter = Interpreter(model_content=tflite_model)
2372    interpreter.allocate_tensors()
2373
2374    input_details = interpreter.get_input_details()
2375    self.assertLen(input_details, 1)
2376    self.assertEqual('input', input_details[0]['name'])
2377    self.assertEqual(np.float32, input_details[0]['dtype'])
2378    self.assertAllEqual([1, 3], input_details[0]['shape'])
2379    self.assertEqual((0., 0.), input_details[0]['quantization'])
2380
2381    output_details = interpreter.get_output_details()
2382    self.assertLen(output_details, 1)
2383    self.assertEqual(np.float32, output_details[0]['dtype'])
2384    self.assertAllEqual([1, 3], output_details[0]['shape'])
2385    self.assertEqual((0., 0.), output_details[0]['quantization'])
2386
2387    # Check inference of converted model.
2388    input_data = np.array([[1, 2, 3]], dtype=np.float32)
2389    interpreter.set_tensor(input_details[0]['index'], input_data)
2390    interpreter.invoke()
2391    tflite_result = interpreter.get_tensor(output_details[0]['index'])
2392
2393    keras_model = keras.models.load_model(self._keras_file)
2394    keras_result = keras_model.predict(input_data)
2395
2396    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
2397
2398  def _getFunctionalModelMultipleInputs(self):
2399    a = keras.layers.Input(shape=(3,), name='input_a')
2400    b = keras.layers.Input(shape=(3,), name='input_b')
2401    dense = keras.layers.Dense(4, name='dense')
2402    c = dense(a)
2403    d = dense(b)
2404    e = keras.layers.Dropout(0.5, name='dropout')(c)
2405
2406    model = keras.models.Model([a, b], [d, e])
2407    model.compile(
2408        loss=keras.losses.MSE,
2409        optimizer='sgd',
2410        metrics=[keras.metrics.mae],
2411        loss_weights=[1., 0.5])
2412
2413    input_a_np = np.random.random((10, 3))
2414    input_b_np = np.random.random((10, 3))
2415    output_d_np = np.random.random((10, 4))
2416    output_e_np = np.random.random((10, 4))
2417    model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])
2418
2419    model.predict([input_a_np, input_b_np], batch_size=5)
2420    fd, self._keras_file = tempfile.mkstemp('.h5')
2421    try:
2422      keras.models.save_model(model, self._keras_file)
2423    finally:
2424      os.close(fd)
2425
2426  def testFunctionalModelMultipleInputs(self):
2427    """Test a Functional tf.keras model with multiple inputs and outputs."""
2428    self._getFunctionalModelMultipleInputs()
2429
2430    # Convert to TFLite model.
2431    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2432    tflite_model = converter.convert()
2433    self.assertIsNotNone(tflite_model)
2434
2435    # Check values from converted model.
2436    interpreter = Interpreter(model_content=tflite_model)
2437    interpreter.allocate_tensors()
2438
2439    input_details = interpreter.get_input_details()
2440    self.assertLen(input_details, 2)
2441    self.assertEndsWith(input_details[0]['name'], 'input_a')
2442    self.assertEqual(np.float32, input_details[0]['dtype'])
2443    self.assertAllEqual([1, 3], input_details[0]['shape'])
2444    self.assertEqual((0., 0.), input_details[0]['quantization'])
2445
2446    self.assertEndsWith(input_details[1]['name'], 'input_b')
2447    self.assertEqual(np.float32, input_details[1]['dtype'])
2448    self.assertAllEqual([1, 3], input_details[1]['shape'])
2449    self.assertEqual((0., 0.), input_details[1]['quantization'])
2450
2451    output_details = interpreter.get_output_details()
2452    self.assertLen(output_details, 2)
2453    self.assertEqual(np.float32, output_details[0]['dtype'])
2454    self.assertAllEqual([1, 4], output_details[0]['shape'])
2455    self.assertEqual((0., 0.), output_details[0]['quantization'])
2456
2457    self.assertEqual(np.float32, output_details[1]['dtype'])
2458    self.assertAllEqual([1, 4], output_details[1]['shape'])
2459    self.assertEqual((0., 0.), output_details[1]['quantization'])
2460
2461  def testShapeOverriding(self):
2462    """Test a Functional tf.keras model with input shape overriding."""
2463    self._getFunctionalModelMultipleInputs()
2464
2465    # Convert to TFLite model.
2466    converter = lite.TFLiteConverter.from_keras_model_file(
2467        self._keras_file, input_shapes={
2468            'input_a': {2, 3},
2469            'input_b': {2, 3}
2470        })
2471    tflite_model = converter.convert()
2472    self.assertIsNotNone(tflite_model)
2473
2474    # Check values from converted model.
2475    interpreter = Interpreter(model_content=tflite_model)
2476    interpreter.allocate_tensors()
2477
2478    input_details = interpreter.get_input_details()
2479    self.assertLen(input_details, 2)
2480    self.assertEndsWith(input_details[0]['name'], 'input_a')
2481    self.assertEqual(np.float32, input_details[0]['dtype'])
2482    self.assertAllEqual([2, 3], input_details[0]['shape'])
2483    self.assertEqual((0., 0.), input_details[0]['quantization'])
2484
2485    self.assertEndsWith(input_details[1]['name'], 'input_b')
2486    self.assertEqual(np.float32, input_details[1]['dtype'])
2487    self.assertAllEqual([2, 3], input_details[1]['shape'])
2488    self.assertEqual((0., 0.), input_details[1]['quantization'])
2489
2490    output_details = interpreter.get_output_details()
2491    self.assertLen(output_details, 2)
2492    self.assertEqual(np.float32, output_details[0]['dtype'])
2493    self.assertAllEqual([2, 4], output_details[0]['shape'])
2494    self.assertEqual((0., 0.), output_details[0]['quantization'])
2495
2496    self.assertEqual(np.float32, output_details[1]['dtype'])
2497    self.assertAllEqual([2, 4], output_details[1]['shape'])
2498    self.assertEqual((0., 0.), output_details[1]['quantization'])
2499
2500  def testPartialShapeOverriding(self):
2501    """Test a Functional tf.keras model with partial input shape overriding."""
2502    self._getFunctionalModelMultipleInputs()
2503
2504    # Convert to TFLite model.
2505    converter = lite.TFLiteConverter.from_keras_model_file(
2506        self._keras_file, input_shapes={'input_a': {2, 3}})
2507    tflite_model = converter.convert()
2508    self.assertIsNotNone(tflite_model)
2509
2510    # Check values from converted model.
2511    interpreter = Interpreter(model_content=tflite_model)
2512    interpreter.allocate_tensors()
2513
2514    input_details = interpreter.get_input_details()
2515    self.assertLen(input_details, 2)
2516    self.assertEndsWith(input_details[0]['name'], 'input_a')
2517    self.assertEqual(np.float32, input_details[0]['dtype'])
2518    self.assertAllEqual([2, 3], input_details[0]['shape'])
2519    self.assertEqual((0., 0.), input_details[0]['quantization'])
2520
2521    self.assertEndsWith(input_details[1]['name'], 'input_b')
2522    self.assertEqual(np.float32, input_details[1]['dtype'])
2523    self.assertAllEqual([1, 3], input_details[1]['shape'])
2524    self.assertEqual((0., 0.), input_details[1]['quantization'])
2525
2526    output_details = interpreter.get_output_details()
2527    self.assertLen(output_details, 2)
2528    self.assertEqual(np.float32, output_details[0]['dtype'])
2529    self.assertAllEqual([1, 4], output_details[0]['shape'])
2530    self.assertEqual((0., 0.), output_details[0]['quantization'])
2531
2532    self.assertEqual(np.float32, output_details[1]['dtype'])
2533    self.assertAllEqual([2, 4], output_details[1]['shape'])
2534    self.assertEqual((0., 0.), output_details[1]['quantization'])
2535
2536  def testWrongShapeOverriding(self):
2537    """Test a Functional tf.keras model with wrong input shape overriding."""
2538    self._getFunctionalModelMultipleInputs()
2539
2540    # Convert to TFLite model.
2541    with self.assertRaises(ValueError):
2542      lite.TFLiteConverter.from_keras_model_file(
2543          self._keras_file, input_shapes={'wrong_input': {2, 3}})
2544
2545  def testFunctionalSequentialModel(self):
2546    """Test a Functional tf.keras model containing a Sequential model."""
2547    model = keras.models.Sequential()
2548    model.add(keras.layers.Dense(2, input_shape=(3,)))
2549    model.add(keras.layers.RepeatVector(3))
2550    model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
2551    model = keras.models.Model(model.input, model.output)
2552
2553    model.compile(
2554        loss=keras.losses.MSE,
2555        optimizer='sgd',
2556        metrics=[keras.metrics.categorical_accuracy],
2557        sample_weight_mode='temporal')
2558    x = np.random.random((1, 3))
2559    y = np.random.random((1, 3, 3))
2560    model.train_on_batch(x, y)
2561    model.predict(x)
2562
2563    model.predict(x)
2564    fd, self._keras_file = tempfile.mkstemp('.h5')
2565    try:
2566      keras.models.save_model(model, self._keras_file)
2567    finally:
2568      os.close(fd)
2569
2570    # Convert to TFLite model.
2571    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2572    tflite_model = converter.convert()
2573    self.assertIsNotNone(tflite_model)
2574
2575    # Check tensor details of converted model.
2576    interpreter = Interpreter(model_content=tflite_model)
2577    interpreter.allocate_tensors()
2578
2579    input_details = interpreter.get_input_details()
2580    self.assertLen(input_details, 1)
2581    self.assertEndsWith(input_details[0]['name'], 'dense_input')
2582    self.assertEqual(np.float32, input_details[0]['dtype'])
2583    self.assertAllEqual([1, 3], input_details[0]['shape'])
2584    self.assertEqual((0., 0.), input_details[0]['quantization'])
2585
2586    output_details = interpreter.get_output_details()
2587    self.assertLen(output_details, 1)
2588    self.assertEqual(np.float32, output_details[0]['dtype'])
2589    self.assertAllEqual([1, 3, 3], output_details[0]['shape'])
2590    self.assertEqual((0., 0.), output_details[0]['quantization'])
2591
2592    # Check inference of converted model.
2593    input_data = np.array([[1, 2, 3]], dtype=np.float32)
2594    interpreter.set_tensor(input_details[0]['index'], input_data)
2595    interpreter.invoke()
2596    tflite_result = interpreter.get_tensor(output_details[0]['index'])
2597
2598    keras_model = keras.models.load_model(self._keras_file)
2599    keras_result = keras_model.predict(input_data)
2600
2601    np.testing.assert_almost_equal(tflite_result, keras_result, 5)
2602
2603  def testSequentialModelTocoConverter(self):
2604    """Test a Sequential tf.keras model with deprecated TocoConverter."""
2605    self._getSequentialModel()
2606
2607    converter = lite.TocoConverter.from_keras_model_file(self._keras_file)
2608    tflite_model = converter.convert()
2609    self.assertIsNotNone(tflite_model)
2610
2611    # Ensure the model is able to load.
2612    interpreter = Interpreter(model_content=tflite_model)
2613    interpreter.allocate_tensors()
2614
2615  @parameterized.named_parameters(('_graph', context.graph_mode),
2616                                  ('_eager', context.eager_mode))
2617  def testGraphDebugInfo(self, test_context):
2618    """Test a Sequential tf.keras model has debug info captured."""
2619    with test_context():
2620      self._getSequentialModel()
2621      converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2622      converter.convert()
2623      self.assertValidDebugInfo(converter._debug_info)
2624
2625  def testSparsifyModel(self):
2626    self._getSequentialModel()
2627
2628    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2629    converter.optimizations = {lite.Optimize.EXPERIMENTAL_SPARSITY}
2630    tflite_model = converter.convert()
2631    self.assertTrue(tflite_model)
2632
2633  def testSparsifyQuantizedModel(self):
2634    self._getSequentialModel()
2635
2636    converter = lite.TFLiteConverter.from_keras_model_file(self._keras_file)
2637    converter.optimizations = {
2638        lite.Optimize.DEFAULT, lite.Optimize.EXPERIMENTAL_SPARSITY
2639    }
2640    tflite_model = converter.convert()
2641    self.assertIsNotNone(tflite_model)
2642
2643
2644class GrapplerTest(TestModels, parameterized.TestCase):
2645
2646  def testConstantFolding(self):
2647    ops.disable_eager_execution()
2648    # Constant folding handles the tf.broadcast_to operation which was not
2649    # supported by the TFLite at the time this test was added.
2650    with ops.Graph().as_default():
2651      in_tensor = array_ops.placeholder(shape=[3, 3], dtype=dtypes.float32)
2652      y_const = constant_op.constant([1., 2., 3.])
2653      y_broadcast = gen_array_ops.broadcast_to(y_const, [3, 3])
2654      out_tensor = math_ops.matmul(in_tensor, y_broadcast, name='output')
2655      sess = session.Session()
2656
2657    # Convert model.
2658    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
2659                                                  [out_tensor])
2660    tflite_model = converter.convert()
2661
2662    # Check values from converted model.
2663    interpreter = Interpreter(model_content=tflite_model)
2664    interpreter.allocate_tensors()
2665
2666    input_details = interpreter.get_input_details()
2667    self.assertLen(input_details, 1)
2668    self.assertEqual('Placeholder', input_details[0]['name'])
2669    self.assertEqual(np.float32, input_details[0]['dtype'])
2670    self.assertAllEqual([3, 3], input_details[0]['shape'])
2671
2672    output_details = interpreter.get_output_details()
2673    self.assertLen(output_details, 1)
2674    self.assertEqual('output', output_details[0]['name'])
2675    self.assertEqual(np.float32, output_details[0]['dtype'])
2676    self.assertAllEqual([3, 3], output_details[0]['shape'])
2677
2678  @parameterized.named_parameters(
2679      ('EnableMlirConverter', True),  # enable mlir
2680      ('DisableMlirConverter', False))  # disable mlir
2681  def testInputNodeIsNotFolded(self, enable_mlir_converter):
2682    ops.disable_eager_execution()
2683    # Constant folding handles the tf.broadcast_to operation which was not
2684    # supported by the TFLite at the time this test was added.
2685    with ops.Graph().as_default():
2686      in_tensor = array_ops.placeholder(shape=[3], dtype=dtypes.float32)
2687      y_const = constant_op.constant([1., 2., 3.])
2688      y_add = y_const + y_const
2689      out_tensor = in_tensor * y_add
2690      sess = session.Session()
2691
2692    # Convert model.
2693    converter = lite.TFLiteConverter.from_session(sess, [in_tensor, y_const],
2694                                                  [out_tensor])
2695    converter.experimental_new_converter = enable_mlir_converter
2696    tflite_model = converter.convert()
2697
2698    # Check values from converted model.
2699    interpreter = Interpreter(model_content=tflite_model)
2700    interpreter.allocate_tensors()
2701
2702    input_details = interpreter.get_input_details()
2703    self.assertLen(input_details, 2)
2704    self.assertEqual('Placeholder', input_details[0]['name'])
2705    self.assertEqual('Const', input_details[1]['name'])
2706
2707  def testGrapplerConstFolding(self):
2708    # Constant folding converts the following add operation to tf.broadcast_to
2709    # operation which was not supported by the TFLite at the time this test was
2710    # added.
2711    @def_function.function
2712    def plus_placeholder(x, placeholder):
2713      return x + placeholder
2714
2715    with ops.Graph().as_default():
2716      in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
2717      out_tensor = plus_placeholder(
2718          array_ops.zeros([2, 2, 2]),
2719          array_ops.reshape(in_tensor, shape=[2, 2]))
2720      sess = session.Session()
2721
2722    # Convert model.
2723    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
2724                                                  [out_tensor])
2725    tflite_model = converter.convert()
2726
2727    # Check values from converted model.
2728    interpreter = Interpreter(model_content=tflite_model)
2729    interpreter.allocate_tensors()
2730
2731    input_details = interpreter.get_input_details()
2732    self.assertLen(input_details, 1)
2733    self.assertEqual('Placeholder', input_details[0]['name'])
2734
2735
2736class ImportOpsUtilTest(LiteTest):
2737
2738  def testGetPotentiallySupportedOps(self):
2739    self.assertIsNotNone(lite.get_potentially_supported_ops())
2740
2741
2742class DefaultConverterAttrsTest(LiteTest):
2743
2744  def testAttrs(self):
2745    with ops.Graph().as_default():
2746      in_tensor = array_ops.placeholder(shape=[2, 2], dtype=dtypes.float32)
2747      out_tensor = in_tensor + in_tensor
2748      sess = session.Session()
2749
2750    # Convert model.
2751    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
2752                                                  [out_tensor])
2753
2754    # Assert output format.
2755    self.assertEqual(converter.output_format, lite_constants.TFLITE)
2756
2757    # Assert the default inference type is float.
2758    self.assertEqual(converter.inference_type, dtypes.float32)
2759
2760    # Assert the default inference type overrides are None.
2761    self.assertIsNone(converter.inference_input_type)
2762    self.assertIsNone(converter.inference_output_type)
2763
2764    # Assert the default quantization options are not set.
2765    self.assertEqual(converter.quantized_input_stats, {})
2766    self.assertIsNone(converter.default_ranges_stats)
2767    self.assertFalse(converter.reorder_across_fake_quant)
2768    self.assertFalse(converter.change_concat_input_ranges)
2769
2770    # Assert dropping control dependency is enabled by default.
2771    self.assertIsNotNone(converter.drop_control_dependency)
2772
2773    # Assert dumping extra information is disabled by default.
2774    self.assertIsNone(converter.dump_graphviz_dir)
2775    self.assertFalse(converter.dump_graphviz_video)
2776    self.assertIsNone(converter.conversion_summary_dir)
2777
2778
2779class ControlFlowV1OpsTest(LiteTest):
2780
2781  def testConverterErrorOnControlFlowV1Ops(self):
2782    graph_def_file = resource_loader.get_path_to_datafile(
2783        'testdata/control_flow_v1.pbtxt')
2784    input_arrays = ['a', 'b', 'c', 'd']
2785    output_arrays = ['Merge']
2786
2787    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
2788                                                       input_arrays,
2789                                                       output_arrays)
2790    with self.assertRaises(ConverterError) as error:
2791      converter.convert()
2792    self.assertIn(
2793        'Failed to functionalize Control Flow V1 ops. Consider using Control '
2794        'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
2795        'tf/compat/v1/enable_control_flow_v2.', str(error.exception))
2796
2797
2798if __name__ == '__main__':
2799  test.main()
2800