1# Lint as: python2, python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for lite.py functionality related to TensorFlow 2.0."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23
24from absl.testing import parameterized
25import numpy as np
26from six.moves import range
27from six.moves import zip
28import tensorflow as tf
29
30from tensorflow.lite.kernels.hashtable import pywrap_hashtable_ops as hashtable_ops_registerer
31from tensorflow.lite.python import convert
32from tensorflow.lite.python import lite
33from tensorflow.lite.python import lite_v2_test_util
34from tensorflow.lite.python.convert import mlir_quantize
35from tensorflow.lite.python.interpreter import Interpreter
36from tensorflow.lite.python.interpreter import InterpreterWithCustomOps
37from tensorflow.lite.toco import types_pb2 as _types_pb2
38from tensorflow.python.framework import dtypes
39from tensorflow.python.framework import ops
40from tensorflow.python.framework import test_util
41from tensorflow.python.lib.io import file_io
42from tensorflow.python.platform import resource_loader
43from tensorflow.python.platform import test
44from tensorflow.python.saved_model import save_options
45from tensorflow.python.saved_model import saved_model
46from tensorflow.python.saved_model.loader_impl import parse_saved_model
47from tensorflow.python.saved_model.save import save
48from tensorflow.python.training.tracking import tracking
49
50
51class FromConcreteFunctionTest(lite_v2_test_util.ModelTest):
52
53  @test_util.run_v2_only
54  def testTypeInvalid(self):
55    root = self._getSimpleVariableModel()
56    with self.assertRaises(ValueError) as error:
57      _ = lite.TFLiteConverterV2.from_concrete_functions([root.f])
58    self.assertIn('call get_concrete_function', str(error.exception))
59
60  @parameterized.named_parameters(
61      ('EnableMlirConverter', True),  # enable mlir
62      ('DisableMlirConverter', False))  # disable mlir
63  @test_util.run_v2_only
64  def testFloat(self, enable_mlir_converter):
65    root = self._getSimpleVariableModel()
66    input_data = tf.constant(1., shape=[1])
67    concrete_func = root.f.get_concrete_function(input_data)
68
69    # Convert model.
70    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
71    converter.experimental_new_converter = enable_mlir_converter
72    tflite_model = converter.convert()
73
74    # Check output value from converted model.
75    expected_value = root.f(input_data)
76    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
77    self.assertEqual(expected_value.numpy(), actual_value)
78
79  @parameterized.named_parameters(('_INT8InputOutput', dtypes.int8),
80                                  ('_UINT8InputOutput', dtypes.uint8),
81                                  ('_INT16InputOutput', dtypes.int16))
82  @test_util.run_v2_only
83  def testInvalidFloat(self, inference_input_output_type):
84    root = self._getSimpleVariableModel()
85    input_data = tf.constant(1., shape=[1])
86    concrete_func = root.f.get_concrete_function(input_data)
87
88    # Convert model.
89    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
90    with self.assertRaises(ValueError) as error:
91      converter.inference_input_type = inference_input_output_type
92      converter.inference_output_type = inference_input_output_type
93      converter.convert()
94    self.assertEqual(
95        'The inference_input_type and inference_output_type '
96        'must be tf.float32.', str(error.exception))
97
98  @test_util.run_v2_only
99  def testScalarInput(self):
100    root = self._getSimpleVariableModel()
101    input_data = tf.constant(1., shape=[])
102    concrete_func = root.f.get_concrete_function(input_data)
103
104    # Convert model.
105    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
106    tflite_model = converter.convert()
107
108    # Check values from converted model.
109    expected_value = root.f(input_data)
110    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
111    self.assertEqual(expected_value.numpy(), actual_value)
112
113  @test_util.run_v2_only
114  def testMultiFunctionModel(self):
115    """Convert a single model in a multi-functional model."""
116    root = self._getMultiFunctionModel()
117    input_data = tf.constant(1., shape=[1])
118    concrete_func = root.add.get_concrete_function(input_data)
119
120    # Convert model and ensure model is not None.
121    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
122    tflite_model = converter.convert()
123
124    # Check values from converted model.
125    expected_value = root.add(input_data)
126    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
127    self.assertEqual(expected_value.numpy(), actual_value)
128
129  @test_util.run_v2_only
130  def testConvertMultipleFunctions(self):
131    """Convert multiple functions in a multi-functional model."""
132    root = self._getMultiFunctionModel()
133    input_data = tf.constant(1., shape=[1])
134    add_func = root.add.get_concrete_function(input_data)
135    sub_func = root.sub.get_concrete_function(input_data)
136
137    # Try converting multiple functions.
138    converter = lite.TFLiteConverterV2.from_concrete_functions(
139        [add_func, sub_func])
140    with self.assertRaises(ValueError) as error:
141      _ = converter.convert()
142    self.assertIn('can only convert a single ConcreteFunction',
143                  str(error.exception))
144
145  def _getIntegerQuantizeModel(self):
146    np.random.seed(0)
147
148    root = tracking.AutoTrackable()
149
150    @tf.function(
151        input_signature=[tf.TensorSpec(shape=[1, 5, 5, 3], dtype=tf.float32)])
152    def func(inp):
153      conv = tf.nn.conv2d(
154          inp, tf.ones([3, 3, 3, 16]), strides=[1, 1, 1, 1], padding='SAME')
155      output = tf.nn.relu(conv, name='output')
156      return output
157
158    def calibration_gen():
159      for _ in range(5):
160        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]
161
162    root.f = func
163    to_save = root.f.get_concrete_function()
164    return (to_save, calibration_gen)
165
166  @parameterized.named_parameters(
167      ('EnableMlirQuantizer', True),  # enable mlir quantizer
168      ('DisableMlirQuantizer', False))  # disable mlir quantizer
169  def testPostTrainingCalibrateAndQuantize(self, mlir_quantizer):
170    func, calibration_gen = self._getIntegerQuantizeModel()
171
172    # Convert float model.
173    float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
174    float_tflite_model = float_converter.convert()
175    self.assertIsNotNone(float_tflite_model)
176
177    # Convert quantized model.
178    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
179    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
180    quantized_converter.representative_dataset = calibration_gen
181    quantized_converter.experimental_new_quantizer = mlir_quantizer
182    quantized_tflite_model = quantized_converter.convert()
183    self.assertIsNotNone(quantized_tflite_model)
184
185    # The default input and output types should be float.
186    interpreter = Interpreter(model_content=quantized_tflite_model)
187    interpreter.allocate_tensors()
188    input_details = interpreter.get_input_details()
189    self.assertLen(input_details, 1)
190    self.assertEqual(np.float32, input_details[0]['dtype'])
191    output_details = interpreter.get_output_details()
192    self.assertLen(output_details, 1)
193    self.assertEqual(np.float32, output_details[0]['dtype'])
194
195    # Ensure that the quantized weights tflite model is smaller.
196    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
197
198  @parameterized.named_parameters(('_INT8InputOutput', dtypes.int8),
199                                  ('_UINT8InputOutput', dtypes.uint8),
200                                  ('_INT16InputOutput', dtypes.int16))
201  @test_util.run_v2_only
202  def testInvalidPostTrainingDynamicRangeQuantization(
203      self, inference_input_output_type):
204    func, _ = self._getIntegerQuantizeModel()
205
206    # Convert float model.
207    converter = lite.TFLiteConverterV2.from_concrete_functions([func])
208    tflite_model = converter.convert()
209    self.assertTrue(tflite_model)
210
211    # Convert quantized model.
212    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
213    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
214    with self.assertRaises(ValueError) as error:
215      quantized_converter.inference_input_type = inference_input_output_type
216      quantized_converter.inference_output_type = inference_input_output_type
217      quantized_converter.convert()
218    self.assertEqual(
219        'The inference_input_type and inference_output_type '
220        'must be tf.float32.', str(error.exception))
221
222  @parameterized.named_parameters(
223      ('_Default', False, False, dtypes.float32),
224      ('_INT8InputOutput', False, False, dtypes.int8),
225      ('_UINT8InputOutput', False, False, dtypes.uint8),
226      ('_INT16Quantize', False, True, dtypes.float32),
227      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
228      ('_IntOnly', True, False, dtypes.float32),
229      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
230      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
231      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
232      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
233  def testIntegerQuantization(self, is_int_only, is_int16_quantize,
234                              inference_input_output_type):
235    func, calibration_gen = self._getIntegerQuantizeModel()
236
237    # Convert float model.
238    converter = lite.TFLiteConverterV2.from_concrete_functions([func])
239    tflite_model = converter.convert()
240    self.assertTrue(tflite_model)
241
242    # Convert quantized model.
243    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
244    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
245    quantized_converter.representative_dataset = calibration_gen
246    if is_int_only:
247      if is_int16_quantize:
248        quantized_converter.target_spec.supported_ops = [
249            lite.OpsSet.\
250            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8
251        ]
252      else:
253        quantized_converter.target_spec.supported_ops = [
254            lite.OpsSet.TFLITE_BUILTINS_INT8
255        ]
256    else:
257      if is_int16_quantize:
258        quantized_converter.target_spec.supported_ops = [
259            lite.OpsSet.\
260            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
261            lite.OpsSet.TFLITE_BUILTINS
262        ]
263    quantized_converter.inference_input_type = inference_input_output_type
264    quantized_converter.inference_output_type = inference_input_output_type
265    quantized_tflite_model = quantized_converter.convert()
266    self.assertIsNotNone(quantized_tflite_model)
267
268    interpreter = Interpreter(model_content=quantized_tflite_model)
269    interpreter.allocate_tensors()
270    input_details = interpreter.get_input_details()
271    self.assertLen(input_details, 1)
272    self.assertEqual(inference_input_output_type.as_numpy_dtype,
273                     input_details[0]['dtype'])
274    output_details = interpreter.get_output_details()
275    self.assertLen(output_details, 1)
276    self.assertEqual(inference_input_output_type.as_numpy_dtype,
277                     output_details[0]['dtype'])
278
279    # Ensure that the quantized tflite model is smaller.
280    self.assertLess(len(quantized_tflite_model), len(tflite_model))
281
282  @parameterized.named_parameters(
283      ('_INT16Quantize_INT8InputOutput', True, dtypes.int8))
284  def testInvalidIntegerQuantization(self, is_int16_quantize,
285                                     inference_input_output_type):
286    func, calibration_gen = self._getIntegerQuantizeModel()
287
288    # Convert quantized model.
289    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
290    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
291    quantized_converter.representative_dataset = calibration_gen
292    if is_int16_quantize:
293      quantized_converter.target_spec.supported_ops = [
294          lite.OpsSet.\
295          EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
296          lite.OpsSet.TFLITE_BUILTINS
297      ]
298    with self.assertRaises(ValueError) as error:
299      quantized_converter.inference_input_type = dtypes.int8
300      quantized_converter.inference_output_type = dtypes.int8
301      quantized_converter.convert()
302    self.assertEqual(
303        'The inference_input_type and inference_output_type '
304        "must be in ['tf.float32', 'tf.int16'].", str(error.exception))
305
306  def testCalibrateAndQuantizeBuiltinInt16(self):
307    func, calibration_gen = self._getIntegerQuantizeModel()
308
309    # Convert float model.
310    float_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
311    float_tflite_model = float_converter.convert()
312    self.assertIsNotNone(float_tflite_model)
313
314    converter = lite.TFLiteConverterV2.from_concrete_functions([func])
315    # TODO(b/156309549): We should add INT16 to the builtin types.
316    converter.optimizations = [lite.Optimize.DEFAULT]
317    converter.target_spec.supported_ops = [lite.OpsSet.TFLITE_BUILTINS_INT8]
318    converter.representative_dataset = calibration_gen
319    converter._experimental_calibrate_only = True
320    calibrated_tflite = converter.convert()
321    quantized_tflite_model = mlir_quantize(
322        calibrated_tflite, inference_type=_types_pb2.QUANTIZED_INT16)
323
324    self.assertIsNotNone(quantized_tflite_model)
325
326    # The default input and output types should be float.
327    interpreter = Interpreter(model_content=quantized_tflite_model)
328    interpreter.allocate_tensors()
329    input_details = interpreter.get_input_details()
330    self.assertLen(input_details, 1)
331    self.assertEqual(np.float32, input_details[0]['dtype'])
332    output_details = interpreter.get_output_details()
333    self.assertLen(output_details, 1)
334    self.assertEqual(np.float32, output_details[0]['dtype'])
335
336    # Ensure that the quantized weights tflite model is smaller.
337    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
338
339  def _getTrainingTimeQuantizedModel(self):
340
341    class QLinear(tf.keras.layers.Layer):
342
343      def __init__(self, units=3, **kwargs):
344        super(QLinear, self).__init__(**kwargs)
345        self.units = units
346
347      def build(self, input_shape):
348        self.w = self.add_weight(
349            'weight',
350            shape=(input_shape[-1], self.units),
351            initializer='random_normal',
352            trainable=True)
353        self.min_var = self.add_weight(
354            'min',
355            initializer=tf.keras.initializers.Constant(-6.0),
356            trainable=False)
357        self.max_var = self.add_weight(
358            'max',
359            initializer=tf.keras.initializers.Constant(6.0),
360            trainable=False)
361
362      def call(self, inputs):
363        x = tf.quantization.fake_quant_with_min_max_vars(
364            inputs, self.min_var, self.max_var)
365
366        w_fq = tf.quantization.fake_quant_with_min_max_vars(
367            self.w, self.min_var, self.max_var)
368        x = tf.matmul(x, w_fq)
369
370        x = tf.quantization.fake_quant_with_min_max_vars(
371            x, self.min_var, self.max_var)
372
373        return x
374
375    return tf.keras.Sequential(QLinear(3, input_shape=(2,)))
376
377  @parameterized.named_parameters(
378      ('_DefaultFLOAT32InputOutput', dtypes.float32),
379      ('_INT8InputOutput', dtypes.int8), ('_UINT8InputOutput', dtypes.uint8))
380  @test_util.run_v2_only
381  def testTrainingTimeQuantization(self, inference_input_output_type):
382    model = self._getTrainingTimeQuantizedModel()
383
384    float_converter = lite.TFLiteConverterV2.from_keras_model(model)
385    float_tflite_model = float_converter.convert()
386    self.assertIsNotNone(float_tflite_model)
387
388    quantized_converter = lite.TFLiteConverterV2.from_keras_model(model)
389    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
390    quantized_converter.inference_input_type = inference_input_output_type
391    quantized_converter.inference_output_type = inference_input_output_type
392    quantized_tflite_model = quantized_converter.convert()
393    self.assertIsNotNone(quantized_tflite_model)
394
395    interpreter = Interpreter(model_content=quantized_tflite_model)
396    interpreter.allocate_tensors()
397    input_details = interpreter.get_input_details()
398    self.assertLen(input_details, 1)
399    self.assertEqual(inference_input_output_type.as_numpy_dtype,
400                     input_details[0]['dtype'])
401    output_details = interpreter.get_output_details()
402    self.assertLen(output_details, 1)
403    self.assertEqual(inference_input_output_type.as_numpy_dtype,
404                     output_details[0]['dtype'])
405
406    # Ensure that the quantized tflite model is smaller.
407    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
408
409  @test_util.run_v2_only
410  def testNewQuantizer(self):
411    """Test the model quantized by the new converter."""
412    func, calibration_gen = self._getIntegerQuantizeModel()
413
414    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
415    quantized_converter.target_spec.supported_ops = [
416        lite.OpsSet.TFLITE_BUILTINS_INT8
417    ]
418    quantized_converter.representative_dataset = calibration_gen
419
420    # default quantizer
421    quantized_converter.experimental_new_quantizer = False
422    old_tflite = quantized_converter.convert()
423
424    # new quantizer
425    quantized_converter.experimental_new_quantizer = True
426    new_tflite = quantized_converter.convert()
427
428    for _ in range(5):
429      input_data = tf.constant(
430          np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32))
431      old_value = self._evaluateTFLiteModel(old_tflite, [input_data])
432      new_value = self._evaluateTFLiteModel(new_tflite, [input_data])
433      self.assertAllClose(old_value, new_value, atol=1e-01)
434
435  @parameterized.named_parameters(
436      ('EnableMlirConverter', True),  # enable mlir
437      ('DisableMlirConverter', False))  # disable mlir
438  @test_util.run_v2_only
439  def testEmbeddings(self, enable_mlir_converter):
440    """Test model with embeddings."""
441    input_data = tf.constant(
442        np.array(np.random.random_sample((20)), dtype=np.int32))
443
444    class EmbeddingModel(tf.keras.Model):
445
446      def __init__(self):
447        super(EmbeddingModel, self).__init__()
448        self.shared_weights = self.add_weight(
449            'weights',
450            shape=(2000, 300),
451            dtype=tf.float32,
452            initializer=tf.random_normal_initializer(
453                mean=0.0, stddev=300**(-0.5)))
454
455      @tf.function(input_signature=[tf.TensorSpec(shape=(20), dtype=tf.int32)])
456      def func(self, x):
457        return tf.gather(self.shared_weights, x)
458
459    # Building the model.
460    root = EmbeddingModel()
461    concrete_func = root.func.get_concrete_function()
462
463    # Convert model.
464    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
465    converter.experimental_new_converter = enable_mlir_converter
466    tflite_model = converter.convert()
467
468    # Check values from converted model.
469    expected_value = root.func(input_data)
470    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
471    self.assertAllClose(expected_value.numpy(), actual_value[0], atol=1e-05)
472
473  @test_util.run_v2_only
474  def testGraphDebugInfo(self):
475    """Test a concrete function has debug info captured."""
476    root = tracking.AutoTrackable()
477    root.v1 = tf.Variable(3.)
478    root.f = tf.function(lambda x: root.v1 * x)
479    input_data = tf.constant(1., shape=[1])
480    concrete_func = root.f.get_concrete_function(input_data)
481
482    # Convert model.
483    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
484    converter.convert()
485    self._assertValidDebugInfo(converter._debug_info)
486
487  def _getIntegerQuantizationModelWithFlexOp(self):
488    np.random.seed(0)
489
490    root = tracking.AutoTrackable()
491
492    @tf.function(input_signature=[
493        tf.TensorSpec(shape=[3, 3, 3, 3, 3], dtype=tf.float32)
494    ])
495    def func(inp):
496      tanh = tf.math.tanh(inp)
497      # Flex delegate will merge the consecutive conv3d and erf ops into one
498      # Delegate node.
499      conv3d = tf.nn.conv3d(
500          tanh,
501          tf.ones([3, 3, 3, 3, 3]),
502          strides=[1, 1, 1, 1, 1],
503          padding='SAME')
504      erf = tf.math.erf(conv3d)
505      output = tf.math.tanh(erf)
506      return output
507
508    def calibration_gen():
509      for _ in range(5):
510        yield [
511            np.random.uniform(-1, 1, size=(3, 3, 3, 3, 3)).astype(np.float32)
512        ]
513
514    root.f = func
515    return (root.f.get_concrete_function(), calibration_gen)
516
517  @parameterized.named_parameters(
518      ('_Default', False, False, dtypes.float32),
519      ('_INT8InputOutput', False, False, dtypes.int8),
520      ('_UINT8InputOutput', False, False, dtypes.uint8),
521      ('_INT16Quantize', False, True, dtypes.float32),
522      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
523      ('_IntOnly', True, False, dtypes.float32),
524      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
525      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
526      ('_IntOnly_INT16Quantize', True, True, dtypes.float32),
527      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16))
528  @test_util.run_v2_only
529  def testIntegerQuantizationWithFlexOp(self, is_int_only, is_int16_quantize,
530                                        inference_input_output_type):
531    func, calibration_gen = self._getIntegerQuantizationModelWithFlexOp()
532
533    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
534        [func])
535    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
536    quantized_converter.representative_dataset = calibration_gen
537    if is_int_only:
538      if is_int16_quantize:
539        quantized_converter.target_spec.supported_ops = [
540            lite.OpsSet.\
541            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
542            lite.OpsSet.SELECT_TF_OPS
543        ]
544      else:
545        quantized_converter.target_spec.supported_ops = [
546            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.SELECT_TF_OPS
547        ]
548    else:
549      if is_int16_quantize:
550        quantized_converter.target_spec.supported_ops = [
551            lite.OpsSet.\
552            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
553            lite.OpsSet.TFLITE_BUILTINS,
554            lite.OpsSet.SELECT_TF_OPS
555        ]
556      else:
557        quantized_converter.target_spec.supported_ops = [
558            lite.OpsSet.TFLITE_BUILTINS, lite.OpsSet.SELECT_TF_OPS
559        ]
560
561    quantized_converter.inference_input_type = inference_input_output_type
562    quantized_converter.inference_output_type = inference_input_output_type
563    quantized_tflite_model = quantized_converter.convert()
564    self.assertIsNotNone(quantized_tflite_model)
565
566    interpreter = Interpreter(model_content=quantized_tflite_model)
567    interpreter.allocate_tensors()
568    input_details = interpreter.get_input_details()
569    self.assertLen(input_details, 1)
570    self.assertEqual(inference_input_output_type.as_numpy_dtype,
571                     input_details[0]['dtype'])
572    output_details = interpreter.get_output_details()
573    self.assertLen(output_details, 1)
574    self.assertEqual(inference_input_output_type.as_numpy_dtype,
575                     output_details[0]['dtype'])
576
577  def _getIntegerQuantizationModelWithUnsupportedOps(self):
578    np.random.seed(0)
579
580    root = tracking.AutoTrackable()
581
582    @tf.function(input_signature=[
583        tf.TensorSpec(shape=[3], dtype=tf.float32),
584        tf.TensorSpec(shape=[3], dtype=tf.float32)
585    ])
586    def func(a, b):
587      # ceil kernel does not support int8 nor int16 types neither.
588      left = tf.math.ceil(a)
589      right = tf.nn.tanh(b)
590      add = tf.math.add(left, right)
591      # ceil kernel does not support int8 nor int16 types neither.
592      output = tf.math.ceil(add)
593      return (output, right)
594
595    def calibration_gen():
596      for _ in range(5):
597        yield [
598            np.random.uniform(-1, 1, size=(3)).astype(np.float32),
599            np.random.uniform(-1, 1, size=(3)).astype(np.float32)
600        ]
601
602    root.f = func
603    return (root.f.get_concrete_function(), calibration_gen)
604
605  @parameterized.named_parameters(
606      ('_INT8InputOutput', False, False, dtypes.int8),
607      ('_UINT8InputOutput', False, False, dtypes.uint8),
608      ('_INT16Quantize_INT16InputOutput', False, True, dtypes.int16),
609      ('_IntOnly_INT8InputOutput', True, False, dtypes.int8),
610      ('_IntOnly_UINT8InputOutput', True, False, dtypes.uint8),
611      ('_IntOnly_INT16Quantize_INT16InputOutput', True, True, dtypes.int16),
612      ('_IntOnly_INT8InputOutputMlirQuant', True, False, dtypes.int8, True),
613      ('_IntOnly_UINT8InputOutputMlirQuant', True, False, dtypes.uint8, True))
614  @test_util.run_v2_only
615  def testIntegerQuantizationWithUnsupportedOps(self,
616                                                is_int_only,
617                                                is_int16_quantize,
618                                                inference_input_output_type,
619                                                enable_mlir_quantizer=False):
620    func, calib_gen = self._getIntegerQuantizationModelWithUnsupportedOps()
621
622    quantized_converter = tf.lite.TFLiteConverter.from_concrete_functions(
623        [func])
624    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
625    quantized_converter.representative_dataset = calib_gen
626    if is_int_only:
627      if is_int16_quantize:
628        quantized_converter.target_spec.supported_ops = [
629            lite.OpsSet.\
630            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
631            lite.OpsSet.TFLITE_BUILTINS
632        ]
633      else:
634        quantized_converter.target_spec.supported_ops = [
635            lite.OpsSet.TFLITE_BUILTINS_INT8, lite.OpsSet.TFLITE_BUILTINS
636        ]
637    else:
638      if is_int16_quantize:
639        quantized_converter.target_spec.supported_ops = [
640            lite.OpsSet.\
641            EXPERIMENTAL_TFLITE_BUILTINS_ACTIVATIONS_INT16_WEIGHTS_INT8,
642            lite.OpsSet.TFLITE_BUILTINS
643        ]
644      else:
645        quantized_converter.target_spec.supported_ops = [
646            lite.OpsSet.TFLITE_BUILTINS
647        ]
648
649    quantized_converter.inference_input_type = inference_input_output_type
650    quantized_converter.inference_output_type = inference_input_output_type
651    quantized_converter.experimental_new_quantizer = enable_mlir_quantizer
652    quantized_tflite_model = quantized_converter.convert()
653    self.assertIsNotNone(quantized_tflite_model)
654
655    expected_dtype = inference_input_output_type.as_numpy_dtype
656    # Allow float32 for fallback on non-quantizable op.
657    expected_ceil_dtype = (
658        expected_dtype if enable_mlir_quantizer else dtypes.float32)
659
660    interpreter = Interpreter(model_content=quantized_tflite_model)
661    interpreter.allocate_tensors()
662    input_details = interpreter.get_input_details()
663    self.assertLen(input_details, 2)
664    self.assertEqual(input_details[0]['dtype'], expected_ceil_dtype)
665    self.assertEqual(input_details[1]['dtype'], expected_dtype)
666    output_details = interpreter.get_output_details()
667    self.assertLen(output_details, 2)
668    self.assertEqual(output_details[0]['dtype'], expected_ceil_dtype)
669    self.assertEqual(output_details[1]['dtype'], expected_dtype)
670
671  @test_util.run_v2_only
672  def testNewQuantizerNumericVerificationDebugMode(self):
673    """Test the model quantized by the new converter with numeric verify ops."""
674    func, calibration_gen = self._getIntegerQuantizeModel()
675
676    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions([func])
677    quantized_converter.target_spec.supported_ops = [
678        lite.OpsSet.TFLITE_BUILTINS_INT8
679    ]
680    quantized_converter.representative_dataset = calibration_gen
681
682    # Create a TFLite model with new quantizer.
683    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
684    quantized_converter.experimental_new_quantizer = True
685    production_tflite = quantized_converter.convert()
686    # Create a TFLite model with new quantizer and numeric verify ops.
687    quantized_converter._experimental_calibrate_only = True
688    calibrated = quantized_converter.convert()
689    debug_mode_tflite = mlir_quantize(calibrated, enable_numeric_verify=True)
690
691    # Check if adding debug mode should output a different flatbuffer.
692    self.assertNotEqual(production_tflite, debug_mode_tflite)
693
694    # Check if newly added ops are numeric verify ops.
695    input_data = tf.constant(
696        np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32))
697
698    def examine_tflite_model(tflite_content, input_data):
699      interpreter = Interpreter(model_content=tflite_content)
700      interpreter.allocate_tensors()
701      input_details = interpreter.get_input_details()
702      interpreter.set_tensor(input_details[0]['index'], input_data.numpy())
703      interpreter.invoke()
704      tensor_details = interpreter.get_tensor_details()
705      return {
706          details['name']: interpreter.get_tensor(details['index'])
707          for details in interpreter.get_tensor_details()
708      }, tensor_details
709
710    tflite_result, _ = examine_tflite_model(production_tflite, input_data)
711    debug_mode_tflite_result, debug_tensor_details = examine_tflite_model(
712        debug_mode_tflite, input_data)
713
714    # MLIR-based quantizer should output flatbuffer model with `tfl.quantize`.
715    num_production_quantize_ops = len([
716        None for output_tensor_name in tflite_result
717        if 'tfl.quantize' in output_tensor_name
718    ])
719    self.assertEqual(num_production_quantize_ops, 1)
720    # MLIR-based quantizer should output flatbuffer model with `tfl.quantize`.
721    num_debug_quantize_ops = len([
722        None for output_tensor_name in debug_mode_tflite_result
723        if 'tfl.quantize' in output_tensor_name
724    ])
725    # Two numbers should be equal.
726    self.assertEqual(num_production_quantize_ops, num_debug_quantize_ops)
727    # DebugMode TFLite flatbuffer should have NumericVerifyOps more than zero.
728    # The name has the prefix "NumericVerify/{name}:{id}
729    # where {name} is the tensor name of the original quantized op's activation,
730    # and {id} is its tensor id.
731    num_debug_ops = 0
732    for output_tensor_name in debug_mode_tflite_result:
733      if 'NumericVerify' in output_tensor_name:
734        pos_end_prefix = len('NumericVerify/')
735        pos_colon = output_tensor_name.rfind(':')
736        self.assertEqual('NumericVerify/',
737                         output_tensor_name[:pos_end_prefix])
738        tensor_id = int(output_tensor_name[pos_colon+1:])
739        original_tensor_name = output_tensor_name[pos_end_prefix:pos_colon]
740        self.assertEqual(original_tensor_name,
741                         debug_tensor_details[tensor_id]['name'])
742        num_debug_ops += 1
743    self.assertEqual(num_debug_ops, 1)
744    # The number of debug ops should be equal to that of quantized ops.
745    self.assertEqual(num_debug_ops, num_debug_quantize_ops)
746
747
748class FromSavedModelTest(lite_v2_test_util.ModelTest):
749
750  def _createV1SavedModel(self, shape):
751    """Create a simple SavedModel."""
752    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
753    with tf.Graph().as_default():
754      with tf.compat.v1.Session() as sess:
755        in_tensor_1 = tf.compat.v1.placeholder(
756            shape=shape, dtype=tf.float32, name='inputB')
757        in_tensor_2 = tf.compat.v1.placeholder(
758            shape=shape, dtype=tf.float32, name='inputA')
759        variable_node = tf.Variable(1.0, name='variable_node')
760        out_tensor = in_tensor_1 + in_tensor_2 * variable_node
761        inputs = {'x': in_tensor_1, 'y': in_tensor_2}
762        outputs = {'z': out_tensor}
763        sess.run(tf.compat.v1.variables_initializer([variable_node]))
764        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
765    return saved_model_dir
766
767  @test_util.run_v2_only
768  def testV1SimpleModel(self):
769    """Test a SavedModel."""
770    with tf.Graph().as_default():
771      saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])
772
773      # Convert model and ensure model is not None.
774      converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
775      tflite_model = converter.convert()
776      self.assertTrue(tflite_model)
777
778      interpreter = Interpreter(model_content=tflite_model)
779      interpreter.allocate_tensors()
780
781      input_details = interpreter.get_input_details()
782      self.assertLen(input_details, 2)
783      self.assertStartsWith(input_details[0]['name'], 'inputA')
784      self.assertEqual(np.float32, input_details[0]['dtype'])
785      self.assertAllEqual([1, 16, 16, 3], input_details[0]['shape'])
786      self.assertEqual((0., 0.), input_details[0]['quantization'])
787
788      self.assertStartsWith(
789          input_details[1]['name'],
790          'inputB',
791      )
792      self.assertEqual(np.float32, input_details[1]['dtype'])
793      self.assertTrue([1, 16, 16, 3], input_details[1]['shape'])
794      self.assertEqual((0., 0.), input_details[1]['quantization'])
795
796      output_details = interpreter.get_output_details()
797      self.assertLen(output_details, 1)
798      self.assertStartsWith(output_details[0]['name'], 'add')
799      self.assertEqual(np.float32, output_details[0]['dtype'])
800      self.assertTrue([1, 16, 16, 3], output_details[0]['shape'])
801      self.assertEqual((0., 0.), output_details[0]['quantization'])
802
803  @test_util.run_v2_only
804  def testTF1HubFormattedModel(self):
805    """Test a TF1 hub formatted model."""
806    saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])
807
808    # TF1 hub model is based on V1 saved model and they omit the saved model
809    # schema version setting.
810    saved_model_proto = parse_saved_model(saved_model_dir)
811    saved_model_proto.saved_model_schema_version = 0
812
813    saved_model_pb_file_path = os.path.join(saved_model_dir, 'saved_model.pb')
814    with file_io.FileIO(saved_model_pb_file_path, 'wb') as writer:
815      writer.write(saved_model_proto.SerializeToString())
816
817    # Convert model and ensure model is not None.
818    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
819    tflite_model = converter.convert()
820    self.assertTrue(tflite_model)
821
822  def _createV1ModelWithHashTableInitializer(self):
823    # Create a v1 saved model with hash table initializers.
824    tf.compat.v1.disable_eager_execution()
825    saved_model_dir = os.path.join(self.get_temp_dir(),
826                                   'savedmodel_with_hashtable')
827
828    table_initializer = tf.lookup.KeyValueTensorInitializer(
829        keys=['a', 'b', 'c', 'd'],
830        values=[1, 2, 3, 4],
831        key_dtype=tf.string,
832        value_dtype=tf.int64)
833    table = tf.lookup.StaticHashTable(
834        table_initializer, default_value=tf.constant(-1, dtype=tf.int64))
835
836    x = tf.compat.v1.placeholder(tf.string, shape=(), name='input')
837    y = table.lookup(x)
838
839    tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
840    tensor_info_y = tf.compat.v1.saved_model.utils.build_tensor_info(y)
841
842    signature_def_map, init_op, assets_collection = {
843        'serving_default':
844            (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
845                inputs={'x': tensor_info_x},
846                outputs={'y': tensor_info_y},
847                method_name='some_function'))
848    }, tf.compat.v1.tables_initializer(), None
849
850    sess = tf.compat.v1.Session()
851    sess.run(tf.compat.v1.initializers.global_variables())
852
853    builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(
854        saved_model_dir)
855    builder.add_meta_graph_and_variables(
856        sess, [tf.compat.v1.saved_model.tag_constants.SERVING],
857        signature_def_map,
858        main_op=init_op,
859        assets_collection=assets_collection,
860        strip_default_attrs=True)
861    builder.save()
862
863    # Restore TF v2 behavior.
864    tf.compat.v1.reset_default_graph()
865    tf.compat.v1.enable_eager_execution()
866    return saved_model_dir
867
868  @test_util.run_v2_only
869  def testModelWithHashTableInitializer(self):
870    """Test a model with saved_model's session initializer for hash tables."""
871    saved_model_dir = self._createV1ModelWithHashTableInitializer()
872
873    # Convert model and ensure model is not None.
874    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
875    converter.allow_custom_ops = True
876    tflite_model = converter.convert()
877
878    # Check values from converted model.
879    interpreter = InterpreterWithCustomOps(
880        model_content=tflite_model,
881        custom_op_registerers=[hashtable_ops_registerer.HashtableOpsRegisterer])
882    input_details = interpreter.get_input_details()
883    output_details = interpreter.get_output_details()
884
885    input_data = np.array(['a', 'b', 'c', 'z'], dtype=np.string_)
886    interpreter.resize_tensor_input(
887        input_details[0]['index'], [4], strict=False)
888    interpreter.allocate_tensors()
889
890    interpreter.set_tensor(input_details[0]['index'], input_data)
891
892    # Invoke multiple times to ensure the initializer graph runs only once.
893    interpreter.invoke()
894    actual_value = interpreter.get_tensor(output_details[0]['index'])
895    self.assertEqual([1, 2, 3, -1], list(actual_value))
896
897    interpreter.invoke()
898    actual_value = interpreter.get_tensor(output_details[0]['index'])
899    self.assertEqual([1, 2, 3, -1], list(actual_value))
900
901    interpreter.invoke()
902    actual_value = interpreter.get_tensor(output_details[0]['index'])
903    self.assertEqual([1, 2, 3, -1], list(actual_value))
904
905  @test_util.run_v2_only
906  def testConstModel(self):
907    """Test a basic model with functions to make sure functions are inlined."""
908    input_data = tf.constant(1., shape=[1])
909    root = tracking.AutoTrackable()
910    root.f = tf.function(lambda x: 2. * x)
911    to_save = root.f.get_concrete_function(input_data)
912
913    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
914    save(root, save_dir, to_save)
915
916    # Convert model and ensure model is not None.
917    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
918    tflite_model = converter.convert()
919
920    # Check values from converted model.
921    expected_value = root.f(input_data)
922    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
923    self.assertEqual(expected_value.numpy(), actual_value)
924
925  @test_util.run_v2_only
926  def testVariableModel(self):
927    """Test a basic model with Variables with saving/loading the SavedModel."""
928    root = self._getSimpleVariableModel()
929    input_data = tf.constant(1., shape=[1])
930    to_save = root.f.get_concrete_function(input_data)
931
932    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
933    save(root, save_dir, to_save)
934
935    # Convert model and ensure model is not None.
936    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
937    tflite_model = converter.convert()
938
939    # Check values from converted model.
940    expected_value = root.f(input_data)
941    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
942    self.assertEqual(expected_value.numpy(), actual_value)
943
944  @test_util.run_v2_only
945  def testSignatures(self):
946    """Test values for `signature_keys` argument."""
947    root = self._getSimpleVariableModel()
948    input_data = tf.constant(1., shape=[1])
949    to_save = root.f.get_concrete_function(input_data)
950
951    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
952    save(root, save_dir, to_save)
953
954    # Convert model with invalid `signature_keys`.
955    with self.assertRaises(ValueError) as error:
956      _ = lite.TFLiteConverterV2.from_saved_model(
957          save_dir, signature_keys=['INVALID'])
958    self.assertIn("Invalid signature key 'INVALID'", str(error.exception))
959
960    # Convert model with empty `signature_keys`.
961    converter = lite.TFLiteConverterV2.from_saved_model(
962        save_dir, signature_keys=[])
963    tflite_model = converter.convert()
964
965    # Check values from converted model.
966    expected_value = root.f(input_data)
967    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
968    self.assertEqual(expected_value.numpy(), actual_value)
969
970  @test_util.run_v2_only
971  def testSignatureDefs(self):
972    """Test converting SignatureDef is correct and uses SignatureDef API."""
973    root = self._getMultiFunctionModel()
974    input_data_0 = tf.constant(1., shape=[1])
975    input_data_1 = tf.constant(3., shape=[1])
976    mul_add_func = root.mul_add.get_concrete_function(input_data_1,
977                                                      input_data_0)
978
979    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
980    save(root, save_dir, {'mul_add': mul_add_func})
981
982    converter = lite.TFLiteConverterV2.from_saved_model(
983        save_dir, signature_keys=['mul_add'])
984    tflite_model = converter.convert()
985
986    # Check values from converted model.
987    expected_value = root.mul_add(input_data_1, input_data_0)
988    interpreter = Interpreter(model_content=tflite_model)
989    signature_defs = interpreter.get_signature_list()
990    results = self._evaluateTFLiteModelUsingSignatureDef(
991        tflite_model, 'mul_add', {
992            'y': input_data_0,
993            'x': input_data_1
994        })
995    self.assertEqual(list(results.keys()), ['output_0'])
996    self.assertEqual(expected_value.numpy(), results['output_0'])
997
998    # Verify the SignatureDef structure returned is as expected.
999    self.assertEqual(len(signature_defs), 1)
1000    self.assertEqual(list(signature_defs.keys()), ['mul_add'])
1001    self.assertEqual(len(signature_defs.values()), 1)
1002    self.assertEqual(
1003        list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
1004    self.assertCountEqual(signature_defs['mul_add']['inputs'], ['x', 'y'])
1005    self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])
1006
1007  @test_util.run_v2_only
1008  def testSignatureDefsWithDefaultValue(self):
1009    """Test converting SignatureDef is correct and uses SignatureDef API.
1010
1011    This test uses None as method_name to test default behavior.
1012    """
1013    root = self._getMultiFunctionModel()
1014    input_data_0 = tf.constant(1., shape=[1])
1015    input_data_1 = tf.constant(3., shape=[1])
1016    mul_add_func = root.mul_add.get_concrete_function(input_data_1,
1017                                                      input_data_0)
1018
1019    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
1020    save(root, save_dir, {'mul_add': mul_add_func})
1021
1022    converter = lite.TFLiteConverterV2.from_saved_model(
1023        save_dir, signature_keys=['mul_add'])
1024    tflite_model = converter.convert()
1025
1026    # Check values from converted model.
1027    expected_value = root.mul_add(input_data_1, input_data_0)
1028    interpreter = Interpreter(model_content=tflite_model)
1029    signature_defs = interpreter.get_signature_list()
1030    results = self._evaluateTFLiteModelUsingSignatureDef(
1031        tflite_model, None, {
1032            'y': input_data_0,
1033            'x': input_data_1
1034        })
1035    self.assertEqual(list(results.keys()), ['output_0'])
1036    self.assertEqual(expected_value.numpy(), results['output_0'])
1037
1038    # Verify the SignatureDef structure returned is as expected.
1039    self.assertEqual(len(signature_defs), 1)
1040    self.assertEqual(list(signature_defs.keys()), ['mul_add'])
1041    self.assertEqual(len(signature_defs.values()), 1)
1042    self.assertEqual(
1043        list(signature_defs['mul_add'].keys()), ['inputs', 'outputs'])
1044    self.assertCountEqual(signature_defs['mul_add']['inputs'], ['x', 'y'])
1045    self.assertEqual(list(signature_defs['mul_add']['outputs']), ['output_0'])
1046
1047  @test_util.run_v2_only
1048  def testMultipleFunctionModel(self):
1049    """Convert multiple functions in a multi-functional model."""
1050    root = self._getMultiFunctionModel()
1051    input_data = tf.constant(1., shape=[1])
1052    add_func = root.add.get_concrete_function(input_data)
1053    sub_func = root.sub.get_concrete_function(input_data)
1054
1055    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
1056    save(root, save_dir, {'add': add_func, 'sub': sub_func})
1057
1058    # Try converting multiple functions.
1059    with self.assertRaises(ValueError) as error:
1060      _ = lite.TFLiteConverterV2.from_saved_model(save_dir)
1061    self.assertIn('Only support a single signature key.', str(error.exception))
1062
1063  @test_util.run_v2_only
1064  def testNoConcreteFunctionModel(self):
1065    root = self._getMultiFunctionModel()
1066
1067    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
1068    save(root, save_dir)
1069
1070    with self.assertRaises(ValueError) as error:
1071      _ = lite.TFLiteConverterV2.from_saved_model(save_dir)
1072    self.assertIn('Only support a single signature key.', str(error.exception))
1073
1074  @test_util.run_v2_only
1075  def testKerasSequentialModel(self):
1076    """Test a simple sequential tf.Keras model."""
1077    input_data = tf.constant(1., shape=[1, 1])
1078
1079    x = np.array([[1.], [2.]])
1080    y = np.array([[2.], [4.]])
1081
1082    model = tf.keras.models.Sequential([
1083        tf.keras.layers.Dropout(0.2),
1084        tf.keras.layers.Dense(1),
1085    ])
1086    model.compile(optimizer='sgd', loss='mean_squared_error')
1087    model.fit(x, y, epochs=1)
1088
1089    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
1090    save(model, save_dir)
1091
1092    # Convert model and ensure model is not None.
1093    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
1094    tflite_model = converter.convert()
1095
1096    # Check values from converted model.
1097    expected_value = model.predict(input_data)
1098    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
1099    self.assertEqual(expected_value, actual_value)
1100
1101  @test_util.run_v2_only
1102  def testGraphDebugInfo(self):
1103    """Test a SavedModel has debug info captured."""
1104    input_data = tf.constant(1., shape=[1])
1105    root = tracking.AutoTrackable()
1106    root.f = tf.function(lambda x: 2. * x)
1107    to_save = root.f.get_concrete_function(input_data)
1108    options = save_options.SaveOptions(save_debug_info=True)
1109    save_dir = os.path.join(self.get_temp_dir(), 'saved_model')
1110    save(root, save_dir, to_save, options)
1111
1112    # Convert model and ensure model is not None.
1113    converter = lite.TFLiteConverterV2.from_saved_model(save_dir)
1114    converter.convert()
1115    self._assertValidDebugInfo(converter._debug_info)
1116
1117  @test_util.run_v2_only
1118  def testFallbackPath(self):
1119    """Test a SavedModel fallback path using old converter."""
1120    saved_model_dir = self._createV1SavedModel(shape=[1, 16, 16, 3])
1121
1122    # Convert model and ensure model is not None.
1123    converter = lite.TFLiteConverterV2.from_saved_model(saved_model_dir)
1124    converter.experimental_new_converter = False
1125    tflite_model = converter.convert()
1126
1127    self.assertTrue(tflite_model)
1128
1129  @test_util.run_v2_only
1130  def testNonStatefulConvLSTM2D(self):
1131    """Test saved model with non stateful ConvLSTM2D keras layer."""
1132    # Create keras model
1133    model = tf.keras.Sequential([
1134        tf.keras.layers.ConvLSTM2D(
1135            32, (3, 3),
1136            padding='same',
1137            return_sequences=True,
1138            stateful=False,
1139            batch_input_shape=(1, 1, 10, 10, 1))
1140    ])
1141    model.compile()
1142
1143    # Export the keras model to saved model.
1144    saved_model_dir = os.path.join(self.get_temp_dir(), 'conv_lstm_2d')
1145    model.save(saved_model_dir, save_format='tf', include_optimizer=False)
1146
1147    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
1148    converter.target_spec.supported_ops = [
1149        tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS
1150    ]
1151    tflite_model = converter.convert()
1152    self.assertTrue(tflite_model)
1153
1154  def _createUnknownInputShapeModel(self):
1155    """Create a simple SavedModel with unknown input."""
1156    saved_model_dir = os.path.join(self.get_temp_dir(), 'unknown_input_shape')
1157    with tf.Graph().as_default():
1158      with tf.compat.v1.Session() as sess:
1159        unknown_shape = tf.TensorShape(None)
1160        in_tensor = tf.compat.v1.placeholder(
1161            shape=unknown_shape, dtype=tf.float32, name='input')
1162        out_tensor = in_tensor + in_tensor
1163        inputs = {'input': in_tensor}
1164        outputs = {'output': out_tensor}
1165        saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
1166    return saved_model_dir
1167
1168  @test_util.run_v2_only
1169  def testUnknownInputShapeModel(self):
1170    """Test a SavedModel with an unknown input shape."""
1171    saved_model_dir = self._createUnknownInputShapeModel()
1172
1173    converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
1174    tflite_model = converter.convert()
1175    self.assertTrue(tflite_model)
1176
1177    # Check values from converted model.
1178    interpreter = Interpreter(model_content=tflite_model)
1179    input_details = interpreter.get_input_details()
1180    output_details = interpreter.get_output_details()
1181
1182    input_data = np.array([1., 2., 3.], dtype=np.float32)
1183    interpreter.resize_tensor_input(
1184        input_details[0]['index'], [3], strict=False)
1185    interpreter.allocate_tensors()
1186
1187    interpreter.set_tensor(input_details[0]['index'], input_data)
1188    interpreter.invoke()
1189    actual_value = interpreter.get_tensor(output_details[0]['index'])
1190    self.assertEqual([2., 4., 6.], list(actual_value))
1191
1192
1193class FromKerasModelTest(lite_v2_test_util.ModelTest):
1194
1195  @test_util.run_v2_only
1196  def testSequentialModel(self):
1197    """Test a simple sequential tf.Keras model."""
1198    input_data = tf.constant(1., shape=[1, 1])
1199
1200    # Create a simple Keras model.
1201    x = np.array([[1.], [2.]])
1202    y = np.array([[2.], [4.]])
1203
1204    model = tf.keras.models.Sequential([
1205        tf.keras.layers.Dropout(0.2),
1206        tf.keras.layers.Dense(units=1, input_shape=[1])
1207    ])
1208    model.compile(optimizer='sgd', loss='mean_squared_error')
1209    model.fit(x, y, epochs=1)
1210
1211    # Convert model and ensure model is not None.
1212    converter = lite.TFLiteConverterV2.from_keras_model(model)
1213    tflite_model = converter.convert()
1214
1215    # Check values from converted model.
1216    expected_value = model.predict(input_data)
1217    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
1218    self.assertEqual(expected_value, actual_value)
1219
1220  @test_util.run_v2_only
1221  def testSequentialMultiInputOutputModel(self):
1222    """Test a tf.Keras model with multiple inputs and outputs."""
1223    left_input_data = tf.constant(1., shape=[1, 3])
1224    right_input_data = tf.constant(1., shape=[1, 3])
1225
1226    # Create a simple Keras model.
1227    input_a_np = np.random.random((10, 3))
1228    input_b_np = np.random.random((10, 3))
1229    output_c_np = np.random.random((10, 3))
1230    output_d_np = np.random.random((10, 2))
1231
1232    input_a = tf.keras.layers.Input(shape=(3,), name='input_a')
1233    input_b = tf.keras.layers.Input(shape=(3,), name='input_b')
1234
1235    dense = tf.keras.layers.Dense(8, name='dense_1')
1236    interm_a = dense(input_a)
1237    interm_b = dense(input_b)
1238    merged = tf.keras.layers.concatenate([interm_a, interm_b], name='merge')
1239
1240    output_c = tf.keras.layers.Dense(
1241        3, activation='softmax', name='dense_2')(
1242            merged)
1243    output_d = tf.keras.layers.Dense(
1244        2, activation='softmax', name='dense_3')(
1245            merged)
1246
1247    model = tf.keras.models.Model(
1248        inputs=[input_a, input_b], outputs=[output_c, output_d])
1249    model.compile(optimizer='sgd', loss='mean_squared_error')
1250    model.fit([input_a_np, input_b_np], [output_c_np, output_d_np], epochs=1)
1251
1252    # Convert model and ensure model is not None.
1253    converter = lite.TFLiteConverterV2.from_keras_model(model)
1254    tflite_model = converter.convert()
1255
1256    # Check values from converted model.
1257    input_data = [left_input_data, right_input_data]
1258    expected_value = model.predict(input_data)
1259    actual_value = self._evaluateTFLiteModel(tflite_model, input_data)
1260    for tf_result, tflite_result in zip(expected_value, actual_value):
1261      self.assertAllClose(tf_result, tflite_result, atol=1e-05)
1262
1263  @test_util.run_v2_only
1264  def testGraphDebugInfo(self):
1265    """Test a tf.Keras model has debug info captured."""
1266    # Create a simple Keras model.
1267    x = [-1, 0, 1, 2, 3, 4]
1268    y = [-3, -1, 1, 3, 5, 7]
1269    model = tf.keras.models.Sequential(
1270        [tf.keras.layers.Dense(units=1, input_shape=[1])])
1271    model.compile(optimizer='sgd', loss='mean_squared_error')
1272    model.fit(x, y, epochs=1)
1273    converter = lite.TFLiteConverterV2.from_keras_model(model)
1274    converter.convert()
1275    self._assertValidDebugInfo(converter._debug_info)
1276
1277  @test_util.run_v2_only
1278  def testKerasFallbackPath(self):
1279    """Test keras model which failed when exporting to the saved model."""
1280    input_data = tf.constant(
1281        np.array(np.random.random_sample((20)), dtype=np.float32))
1282
1283    class Model(tf.keras.Model):
1284
1285      def __init__(self):
1286        super(Model, self).__init__()
1287        # A None name will cause a failure in exporting to a saved model.
1288        self.shared_weights = self.add_weight(
1289            name=None,
1290            shape=(20, 1),
1291            dtype=tf.float32,
1292            initializer=tf.random_normal_initializer(
1293                mean=0.0, stddev=300**(-0.5)))
1294
1295      def call(self, x):
1296        return tf.add(self.shared_weights, x)
1297
1298    # Building the model.
1299    model = Model()
1300    model.compile(optimizer='sgd', loss='mean_squared_error')
1301    model.fit(input_data, input_data, epochs=1)
1302
1303    # Convert model.
1304    converter = lite.TFLiteConverterV2.from_keras_model(model)
1305    tflite_model = converter.convert()
1306    self.assertTrue(tflite_model)
1307
1308
1309class ControlFlowTest(lite_v2_test_util.ModelTest):
1310
1311  @test_util.run_v2_only
1312  def testCond(self):
1313    input_data = {
1314        'x': tf.constant([1., 2.], shape=[1, 2]),
1315        'b': tf.constant(True)
1316    }
1317
1318    weights = tf.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float32)
1319
1320    def true_fn(x):
1321      return tf.matmul(x, weights)
1322
1323    def false_fn(x):
1324      return tf.add(x, weights)
1325
1326    @tf.function(input_signature=[
1327        tf.TensorSpec(shape=[1, 2], dtype=tf.float32),
1328        tf.TensorSpec(shape=(), dtype=tf.bool)
1329    ])
1330    def model(x, b):
1331      return tf.cond(
1332          b, true_fn=lambda: true_fn(x), false_fn=lambda: false_fn(x))
1333
1334    concrete_func = model.get_concrete_function()
1335
1336    # Convert model.
1337    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
1338    tflite_model = converter.convert()
1339
1340    # Check values from converted model.
1341    expected_value = concrete_func(**input_data)
1342    actual_value = self._evaluateTFLiteModel(
1343        tflite_model, [input_data['x'], input_data['b']])[0]
1344    self.assertAllClose(expected_value, actual_value)
1345
1346  @test_util.run_v2_only
1347  def testConverterErrorOnControlFlowV1Ops(self):
1348    filename = resource_loader.get_path_to_datafile(
1349        'testdata/control_flow_v1_saved_model')
1350    converter = lite.TFLiteConverterV2.from_saved_model(filename)
1351    with self.assertRaises(convert.ConverterError) as error:
1352      converter.convert()
1353    self.assertIn(
1354        'Failed to functionalize Control Flow V1 ops. Consider using Control '
1355        'Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/'
1356        'tf/compat/v1/enable_control_flow_v2.', str(error.exception))
1357
1358  @test_util.run_v2_only
1359  def testStaticRnn(self):
1360    input_data = tf.constant(
1361        np.array(np.random.random_sample((3, 10)), dtype=np.float32))
1362
1363    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(10)
1364
1365    @tf.function(
1366        input_signature=[tf.TensorSpec(shape=[3, 10], dtype=tf.float32)])
1367    def model(x):
1368      seq = tf.split(x, 3, 0)
1369      return tf.compat.v1.nn.static_rnn(
1370          cell, seq, dtype=tf.float32, sequence_length=[1])
1371
1372    concrete_func = model.get_concrete_function()
1373
1374    # Convert model.
1375    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
1376    tflite_model = converter.convert()
1377
1378    # Check values from converted model.
1379    expected_value = concrete_func(input_data)[0]
1380    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
1381    for expected, actual in zip(expected_value, actual_value):
1382      self.assertAllClose(expected, actual)
1383
1384  @test_util.run_v2_only
1385  def testWhileLoop(self):
1386    input_data = tf.constant([1., 2., 3., 4.], shape=[2, 2])
1387
1388    weights = tf.Variable([[0.1, 0.2], [0.3, 0.4]], dtype=tf.float32)
1389
1390    def condition(x):
1391      return tf.reduce_sum(x) < 100
1392
1393    def body(x):
1394      return tf.add(x, weights)
1395
1396    @tf.function(
1397        input_signature=[tf.TensorSpec(shape=[2, 2], dtype=tf.float32)])
1398    def model(x):
1399      return tf.while_loop(condition, body, [x])
1400
1401    concrete_func = model.get_concrete_function()
1402
1403    # Convert model.
1404    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
1405    tflite_model = converter.convert()
1406
1407    # Check values from converted model.
1408    expected_value = concrete_func(input_data)[0]
1409    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
1410    self.assertAllClose(expected_value, actual_value)
1411
1412  @test_util.run_v2_only
1413  def testDynamicRnn(self):
1414    input_data = tf.constant(
1415        np.array(np.random.random_sample((3, 10, 10)), dtype=np.float32))
1416
1417    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(10)
1418
1419    @tf.function(
1420        input_signature=[tf.TensorSpec(shape=[3, 10, 10], dtype=tf.float32)])
1421    def model(x):
1422      return tf.compat.v1.nn.dynamic_rnn(cell, x, dtype=tf.float32)
1423
1424    concrete_func = model.get_concrete_function()
1425
1426    # Convert model.
1427    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
1428    tflite_model = converter.convert()
1429
1430    # Check values from converted model.
1431    expected_value = concrete_func(input_data)
1432    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])
1433    for expected, actual in zip(expected_value, actual_value):
1434      if not isinstance(expected, ops.EagerTensor):
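        # The state output is an LSTMStateTuple; only its cell state `c` is
        # compared against the TFLite output.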
1435        expected = expected.c
1436      self.assertAllClose(expected, actual)
1437
1438  @parameterized.named_parameters(
1439      ('LSTM_BatchSize_None', tf.keras.layers.LSTM, None),
1440      ('SimpleRNN_BatchSize_None', tf.keras.layers.SimpleRNN, None),
1441      ('GRU_BatchSize_None', tf.keras.layers.GRU, None),
1442      ('LSTM_BatchSize_One', tf.keras.layers.LSTM, 1),
1443      ('SimpleRNN_BatchSize_One', tf.keras.layers.SimpleRNN, 1),
1444      ('GRU_BatchSize_One', tf.keras.layers.GRU, 1))
1445  @test_util.run_v2_only
1446  def testKerasRNN(self, rnn_layer, batch_size):
    # This test runs with both `batch_size=1` and `batch_size=None`.
    # When `batch_size=1`, the model converts to a fused RNN, and when
    # `batch_size=None`, it converts to an unfused RNN
    # (the same applies to the tests below).
1451    input_data = tf.constant(
1452        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
1453    rnn_obj = rnn_layer(units=10, input_shape=(10, 10))
1454    model = tf.keras.models.Sequential([
1455        tf.keras.layers.Input(
1456            batch_size=batch_size, shape=(10, 10), name='input'),
1457        rnn_obj,
1458    ])
1459
1460    # Convert model.
1461    converter = lite.TFLiteConverterV2.from_keras_model(model)
1462    tflite_model = converter.convert()
1463    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
1464
1465    # Check values from converted model.
1466    expected_value = model.predict(input_data)
1467    self.assertAllClose(expected_value, actual_value, atol=1e-05)
1468
1469  @parameterized.named_parameters(('LSTM', tf.keras.layers.LSTM),
1470                                  ('SimpleRNN', tf.keras.layers.SimpleRNN),
1471                                  ('GRU', tf.keras.layers.GRU))
1472  @test_util.run_v2_only
1473  def testKerasRNNMultiBatches(self, rnn_layer):
1474    input_data = tf.constant(
1475        np.array(np.random.random_sample((4, 10, 10)), dtype=np.float32))
    # Specify a fixed batch size (4) for the test model.
1477    x = tf.keras.layers.Input(batch_shape=(4, 10, 10))
1478    y = rnn_layer(units=10, input_shape=(10, 10))(x)
1479    model = tf.keras.Model(inputs=[x], outputs=[y])
1480
1481    # Convert model.
1482    converter = lite.TFLiteConverterV2.from_keras_model(model)
1483    tflite_model = converter.convert()
1484    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
1485
1486    # Check values from converted model.
1487    expected_value = model.predict(input_data)
1488    self.assertAllClose(expected_value, actual_value, atol=1e-05)
1489
1490  @parameterized.named_parameters(('BatchSize_None', None),
1491                                  ('BatchSize_One', 1))
1492  @test_util.run_v2_only
1493  def testKerasBidirectionalRNNReturnSequence(self, batch_size):
1494    input_data = tf.constant(
1495        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
1496    model = tf.keras.models.Sequential()
1497    model.add(
1498        tf.keras.layers.Input(
1499            batch_size=batch_size, shape=(10, 10), name='input'))
1500    model.add(
1501        tf.keras.layers.Bidirectional(
1502            tf.keras.layers.LSTM(units=10, return_sequences=True),
1503            input_shape=(10, 10)))
1504    model.add(tf.keras.layers.Flatten())
1505    model.add(tf.keras.layers.Dense(5))
1506    model.add(tf.keras.layers.Activation('softmax'))
1507
1508    # Convert model.
1509    converter = lite.TFLiteConverterV2.from_keras_model(model)
1510    tflite_model = converter.convert()
1511    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
1512
1513    # Check values from converted model.
1514    expected_value = model.predict(input_data)
1515    self.assertAllClose(expected_value, actual_value, atol=1e-05)
1516
1517  @parameterized.named_parameters(('BatchSize_None', None),
1518                                  ('BatchSize_One', 1))
1519  @test_util.run_v2_only
1520  def testKerasBidirectionalRNN(self, batch_size):
1521    input_data = tf.constant(
1522        np.array(np.random.random_sample((1, 10, 10)), dtype=np.float32))
1523    model = tf.keras.models.Sequential()
1524    model.add(
1525        tf.keras.layers.Input(
1526            batch_size=batch_size, shape=(10, 10), name='input'))
1527    model.add(tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=10)))
1528    model.add(tf.keras.layers.Dense(5))
1529    model.add(tf.keras.layers.Activation('softmax'))
1530
1531    # Convert model.
1532    converter = lite.TFLiteConverterV2.from_keras_model(model)
1533    tflite_model = converter.convert()
1534    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
1535
1536    # Check values from converted model.
1537    expected_value = model.predict(input_data)
1538    self.assertAllClose(expected_value, actual_value, atol=1e-05)
1539
1540
1541class GrapplerTest(lite_v2_test_util.ModelTest):
1542
1543  @test_util.run_v2_only
1544  def testConstantFolding(self):
    # Constant folding handles the tf.broadcast_to operation, which was not
    # supported by TFLite at the time this test was added.
    input_data = tf.constant(
        [1., 2., 3., 4., 5., 6., 7., 8., 9.], shape=[3, 3])
1548
1549    @tf.function
1550    def func(x):
1551      y_const = tf.constant([1., 2., 3.])
1552      y_broadcast = tf.broadcast_to(y_const, [3, 3])
1553      return tf.matmul(x, y_broadcast)
1554
1555    root = tracking.AutoTrackable()
1556    root.f = func
1557    concrete_func = root.f.get_concrete_function(input_data)
1558
1559    # Convert model.
1560    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
1561    tflite_model = converter.convert()
1562
1563    # Check values from converted model.
1564    expected_value = root.f(input_data)
1565    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
1566    self.assertAllClose(expected_value, actual_value)
1567
    # Enable hybrid quantization; the result should be the same.
1569    converter.optimizations = [lite.Optimize.DEFAULT]
1570    tflite_model = converter.convert()
1571    actual_value = self._evaluateTFLiteModel(tflite_model, [input_data])[0]
1572    self.assertAllClose(expected_value, actual_value)
1573
1574
1575class UnknownShapes(lite_v2_test_util.ModelTest):
1576
1577  @test_util.run_v2_only
1578  def testMatMul(self):
1579    input_data = tf.constant(
1580        np.array(np.random.random_sample((10, 4)), dtype=np.float32))
1581
1582    @tf.function(
1583        input_signature=[tf.TensorSpec(shape=[None, 4], dtype=tf.float32)])
1584    def model(in_tensor):
1585      shape = tf.shape(in_tensor)
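      # The filled tensor's shape is only known at runtime, so the matmul
      # below must handle a dynamic (unknown) dimension.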
1586      fill = tf.transpose(tf.fill(shape, 1.))
1587      return tf.matmul(fill, in_tensor)
1588
1589    concrete_func = model.get_concrete_function()
1590
1591    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
1592    tflite_model = converter.convert()
1593
1594    # Check values from converted model.
1595    expected_value = concrete_func(input_data)
1596    actual_value = self._evaluateTFLiteModel(
1597        tflite_model, [input_data], input_shapes=[([-1, 4], [10, 4])])[0]
1598    self.assertAllClose(expected_value, actual_value, atol=1e-06)
1599
1600  def _getIntegerQuantizeModelWithUnknownShapes(self):
1601    np.random.seed(0)
1602
1603    @tf.function(
1604        input_signature=[tf.TensorSpec(shape=[None, 33], dtype=tf.float32)])
1605    def model(input_tensor):
1606      """Define a model with tf.MatMul and unknown shapes."""
1607      # We need the tensor to have more than 1024 elements for quantize_weights
1608      # to kick in. Thus, the [33, 33] shape.
1609      const_tensor = tf.constant(
1610          np.random.uniform(low=-10., high=10., size=[33, 33]),
1611          shape=[33, 33],
1612          dtype=tf.float32,
1613          name='inputB')
1614
1615      shape = tf.shape(input_tensor)
1616      fill = tf.transpose(tf.fill(shape, 1.))
1617      mult = tf.matmul(fill, input_tensor)
1618      return tf.matmul(mult, const_tensor)
1619
1620    root = tracking.AutoTrackable()
1621    root.f = model
1622    concrete_func = root.f.get_concrete_function()
1623
1624    def calibration_gen():
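      # Yield calibration batches of varying batch size (5, 10, 15) to
      # exercise the unknown (None) batch dimension.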
1625      for batch in range(5, 20, 5):
1626        for _ in range(5):
1627          yield [np.random.uniform(-1, 1, size=(batch, 33)).astype(np.float32)]
1628
1629    return concrete_func, calibration_gen
1630
1631  @test_util.run_v2_only
1632  def testMatMulQuantize(self):
1633    concrete_func, _ = self._getIntegerQuantizeModelWithUnknownShapes()
1634    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
1635        [concrete_func])
1636    float_tflite_model = float_converter.convert()
1637
1638    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
1639        [concrete_func])
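    # Without a representative dataset, Optimize.DEFAULT applies dynamic-range
    # (weight-only) quantization.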
1640    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
1641    quantized_tflite_model = quantized_converter.convert()
1642
1643    # The default input and output types should be float.
1644    quantized_interpreter = Interpreter(model_content=quantized_tflite_model)
1645    quantized_interpreter.allocate_tensors()
1646    input_details = quantized_interpreter.get_input_details()
1647    self.assertLen(input_details, 1)
1648    self.assertEqual(np.float32, input_details[0]['dtype'])
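    # The unknown batch dimension is preserved as -1 in the shape signature.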
1649    self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])
1650
1651    # Ensure that the quantized weights tflite model is smaller.
1652    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
1653
1654  @test_util.run_v2_only
1655  def testMatMulCalibrateAndQuantize(self):
    concrete_func, calibration_gen = (
        self._getIntegerQuantizeModelWithUnknownShapes())
1658    float_converter = lite.TFLiteConverterV2.from_concrete_functions(
1659        [concrete_func])
1660    float_tflite_model = float_converter.convert()
1661
1662    quantized_converter = lite.TFLiteConverterV2.from_concrete_functions(
1663        [concrete_func])
1664    quantized_converter.optimizations = [lite.Optimize.DEFAULT]
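    # With a representative dataset, the converter calibrates and quantizes
    # activations as well as weights; inputs and outputs stay float by default.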
1665    quantized_converter.representative_dataset = calibration_gen
1666    quantized_tflite_model = quantized_converter.convert()
1667
1668    # The default input and output types should be float.
1669    quantized_interpreter = Interpreter(model_content=quantized_tflite_model)
1670    quantized_interpreter.allocate_tensors()
1671    input_details = quantized_interpreter.get_input_details()
1672    self.assertLen(input_details, 1)
1673    self.assertEqual(np.float32, input_details[0]['dtype'])
1674    self.assertAllEqual([-1, 33], input_details[0]['shape_signature'])
1675
1676    # Ensure that the quantized weights tflite model is smaller.
1677    self.assertLess(len(quantized_tflite_model), len(float_tflite_model))
1678
1679  def testBatchMatMul(self):
1680    input_data_1 = tf.constant(
1681        np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
1682    input_data_2 = tf.constant(
1683        np.array(np.random.random_sample((1, 256, 256)), dtype=np.float32))
1684
1685    @tf.function(input_signature=[
1686        tf.TensorSpec(shape=[None, 256, 256], dtype=tf.float32),
1687        tf.TensorSpec(shape=[None, 256, 256], dtype=tf.float32)
1688    ])
1689    def model(in_tensor_1, in_tensor_2):
1690      return tf.matmul(in_tensor_1, in_tensor_2)
1691
1692    concrete_func = model.get_concrete_function()
1693
1694    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
1695    tflite_model = converter.convert()
1696
1697    # Check values from converted model.
1698    expected_value = concrete_func(input_data_1, input_data_2)
1699    actual_value = self._evaluateTFLiteModel(
1700        tflite_model, [input_data_1, input_data_2],
1701        input_shapes=[([-1, 256, 256], [1, 256, 256])])[0]
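    # The 256-wide inner dimension accumulates floating point error, hence the
    # loose tolerance.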
1702    self.assertAllClose(expected_value, actual_value, atol=4)
1703
1704  def testSizeInvalid(self):
1705
1706    @tf.function(input_signature=[
1707        tf.TensorSpec(shape=[1, None, 16, 3], dtype=tf.float32)
1708    ])
1709    def model(in_tensor):
1710      return in_tensor + in_tensor
1711
1712    concrete_func = model.get_concrete_function()
1713
    # Test an invalid shape: None appears after the 1st dimension. Run with
    # TOCO in order to invoke the shape-checking code.
1716    converter = lite.TFLiteConverterV2.from_concrete_functions([concrete_func])
1717    converter.experimental_new_converter = False
1718    with self.assertRaises(ValueError) as error:
1719      converter.convert()
1720    self.assertEqual(
1721        'None is only supported in the 1st dimension. Tensor '
1722        '\'in_tensor\' has invalid shape \'[1, None, 16, 3]\'.',
1723        str(error.exception))
1724
1725
1726if __name__ == '__main__':
1727  test.main()
1728