1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for lite.py."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import os
22import tempfile
23import numpy as np
24
25from tensorflow.lite.python import lite
26from tensorflow.lite.python import lite_constants
27from tensorflow.lite.python.interpreter import Interpreter
28from tensorflow.python import keras
29from tensorflow.python.client import session
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import test_util
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import math_ops
35from tensorflow.python.ops import nn_ops
36from tensorflow.python.ops import variable_scope
37from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
38from tensorflow.python.platform import gfile
39from tensorflow.python.platform import resource_loader
40from tensorflow.python.platform import test
41from tensorflow.python.saved_model import saved_model
42from tensorflow.python.training.training_util import write_graph
43
44
class FromConstructor(test_util.TensorFlowTestCase):
  """Calls the `lite.TFLiteConverter` constructor directly.

  Every case passes `None` as the GraphDef; only argument validation and the
  array-name / tensor bookkeeping are exercised.
  """

  def testInvalidConstructor(self):
    """Omitting one of the array-name arguments (without tensors) fails."""
    expected_message = ('If input_tensors and output_tensors are None, both '
                        'input_arrays_with_shape and output_arrays must be '
                        'defined.')

    # Each entry is (positional args, keyword args) for the constructor.
    invalid_calls = [
        # `output_arrays` is missing.
        ((None, None, []), {'input_arrays_with_shape': [('input', [3, 9])]}),
        # `input_arrays_with_shape` is missing.
        ((None, [], None), {'output_arrays': ['output']}),
    ]
    for positional, keywords in invalid_calls:
      with self.assertRaises(ValueError) as ctx:
        lite.TFLiteConverter(*positional, **keywords)
      self.assertEqual(expected_message, str(ctx.exception))

  def testValidConstructor(self):
    """Array names alone, or tensor lists alone, build a converter."""
    names_only = lite.TFLiteConverter(
        None, None, None,
        input_arrays_with_shape=[('input', [3, 9])],
        output_arrays=['output'])
    self.assertFalse(names_only._has_valid_tensors())
    self.assertEqual(names_only.get_input_arrays(), ['input'])

    # Overriding the batch size is rejected for a names-only converter.
    with self.assertRaises(ValueError) as ctx:
      names_only._set_batch_size(1)
    self.assertEqual(
        'The batch size cannot be set for this model. Please use '
        'input_shapes parameter.', str(ctx.exception))

    with_tensors = lite.TFLiteConverter(
        None, ['input_tensor'], ['output_tensor'])
    self.assertTrue(with_tensors._has_valid_tensors())
82
83
@test_util.run_v1_only('b/120545219')
class FromSessionTest(test_util.TensorFlowTestCase):
  """Tests for converting live graphs via `TFLiteConverter.from_session`."""

  def testFloat(self):
    """Converts a float32 add graph and checks input/output details."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor
    sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testString(self):
    """Converts a string reshape graph; only tensor metadata is checked."""
    in_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.string)
    out_tensor = array_ops.reshape(in_tensor, shape=[2, 2])
    sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.string_, input_details[0]['dtype'])
    self.assertTrue(([4] == input_details[0]['shape']).all())

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('Reshape', output_details[0]['name'])
    self.assertEqual(np.string_, output_details[0]['dtype'])
    self.assertTrue(([2, 2] == output_details[0]['shape']).all())
    # TODO(b/122659643): Test setting/getting string data via the python
    # interpreter API after support has been added.

  def testQuantization(self):
    """Converts a fake-quant graph to uint8 and checks quantization params."""
    in_tensor_1 = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
    in_tensor_2 = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
    out_tensor = array_ops.fake_quant_with_min_max_args(
        in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
    sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1, in_tensor_2], [out_tensor])
    converter.inference_type = lite_constants.QUANTIZED_UINT8
    converter.quantized_input_stats = {
        'inputA': (0., 1.),
        'inputB': (0., 1.)
    }  # mean, std_dev
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self.assertEqual('inputA', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    self.assertEqual((1., 0.),
                     input_details[0]['quantization'])  # scale, zero_point

    self.assertEqual('inputB', input_details[1]['name'])
    self.assertEqual(np.uint8, input_details[1]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[1]['shape']).all())
    self.assertEqual((1., 0.),
                     input_details[1]['quantization'])  # scale, zero_point

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('output', output_details[0]['name'])
    self.assertEqual(np.uint8, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    self.assertGreater(output_details[0]['quantization'][0], 0)  # scale

  def testQuantizationInvalid(self):
    """Conversion fails when an input is missing its quantization stats."""
    in_tensor_1 = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
    in_tensor_2 = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
    out_tensor = array_ops.fake_quant_with_min_max_args(
        in_tensor_1 + in_tensor_2, min=0., max=1., name='output')
    sess = session.Session()

    # Provide stats for 'inputA' only; conversion must reject 'inputB'.
    converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1, in_tensor_2], [out_tensor])
    converter.inference_type = lite_constants.QUANTIZED_UINT8
    converter.quantized_input_stats = {'inputA': (0., 1.)}  # mean, std_dev
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual(
        'Quantization input stats are not available for input tensors '
        '\'inputB\'.', str(error.exception))

  def testIntermediateInputArray(self):
    """Convert a model from an intermediate input array."""
    in_tensor_init = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    in_tensor_final = in_tensor_init + in_tensor_init
    out_tensor = in_tensor_final + in_tensor_final
    sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor_final],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('add', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add_1', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testSizeNoneInvalid(self):
    """An input placeholder with no shape at all is rejected by convert()."""
    in_tensor = array_ops.placeholder(dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor
    sess = session.Session()

    # Test None as shape.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
                     str(error.exception))

  def testScalarValid(self):
    """A rank-0 (scalar) input converts and runs inference correctly."""
    # Construct a graph using a scalar (empty shape) input.
    in_tensor = array_ops.placeholder(dtype=dtypes.float32, shape=[])
    out_tensor = in_tensor + in_tensor
    sess = session.Session()

    # Test conversion with the scalar input shape.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    # Rank is checked via len(): `([] == shape).all()` would also pass for a
    # length-1 shape, because [] broadcasts to an empty comparison result.
    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertEqual(0, len(input_details[0]['shape']))

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertEqual(0, len(output_details[0]['shape']))

    # Validate inference using the scalar inputs/outputs.
    test_input = np.array(4.0, dtype=np.float32)
    expected_output = np.array(8.0, dtype=np.float32)
    interpreter.set_tensor(input_details[0]['index'], test_input)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    self.assertTrue((expected_output == output_data).all())

  def testSizeInvalid(self):
    """A None dimension anywhere but the first is rejected by convert()."""
    in_tensor = array_ops.placeholder(
        shape=[1, None, 16, 3], dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor
    sess = session.Session()

    # Test invalid shape. None after 1st dimension.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    with self.assertRaises(ValueError) as error:
      converter.convert()
    self.assertEqual(
        'None is only supported in the 1st dimension. Tensor '
        '\'Placeholder\' has invalid shape \'[1, None, 16, 3]\'.',
        str(error.exception))

  def testBatchSizeValid(self):
    """A None batch dimension converts and is reported as 1."""
    in_tensor = array_ops.placeholder(
        shape=[None, 16, 16, 3], dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor
    sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  def testFreezeGraph(self):
    """A graph reading an initialized variable converts with one input."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    var = variable_scope.get_variable(
        'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
    out_tensor = in_tensor + var
    sess = session.Session()
    sess.run(_global_variables_initializer())

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.float32, input_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    self.assertEqual((0., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    self.assertEqual((0., 0.), output_details[0]['quantization'])

  # TODO(nupurgarg): Verify value of contents in GraphViz.
  def testGraphviz(self):
    """The GRAPHVIZ_DOT output format produces non-empty output."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor
    sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.output_format = lite_constants.GRAPHVIZ_DOT
    graphviz_output = converter.convert()
    self.assertTrue(graphviz_output)

  # TODO(nupurgarg): Verify value of contents in GraphViz.
  def testDumpGraphviz(self):
    """`dump_graphviz_dir` writes files; `dump_graphviz_video` adds more."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor
    sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    graphviz_dir = self.get_temp_dir()
    converter.dump_graphviz_dir = graphviz_dir
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Ensure interpreter is able to allocate and check graphviz data.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    num_items_graphviz = len(os.listdir(graphviz_dir))
    self.assertGreater(num_items_graphviz, 0)

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    # NOTE(review): get_temp_dir() presumably returns the same per-test
    # directory here, so the count below is cumulative — confirm intent.
    graphviz_dir = self.get_temp_dir()
    converter.dump_graphviz_dir = graphviz_dir
    converter.dump_graphviz_video = True
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Ensure graphviz folder has more data after using video flag.
    num_items_graphviz_video = len(os.listdir(graphviz_dir))
    self.assertGreater(num_items_graphviz_video, num_items_graphviz)

  def testInferenceInputType(self):
    """A uint8 inference input type keeps a float32 output."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor
    sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_input_type = lite_constants.QUANTIZED_UINT8
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    self.assertEqual((1., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.float32, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())

  def testDefaultRangesStats(self):
    """With default_ranges_stats, a graph without fake-quant becomes uint8."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor
    sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                  [out_tensor])
    converter.inference_type = lite_constants.QUANTIZED_UINT8
    converter.quantized_input_stats = {'Placeholder': (0., 1.)}  # mean, std_dev
    converter.default_ranges_stats = (0, 6)  # min, max
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('Placeholder', input_details[0]['name'])
    self.assertEqual(np.uint8, input_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
    self.assertEqual((1., 0.), input_details[0]['quantization'])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self.assertEqual('add', output_details[0]['name'])
    self.assertEqual(np.uint8, output_details[0]['dtype'])
    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
    self.assertGreater(output_details[0]['quantization'][0], 0)  # scale

  def testPostTrainingQuantizeDeprecatedAttribute(self):
    """The deprecated `post_training_quantize` flag maps to OPTIMIZE_FOR_SIZE."""
    # Seed like the sibling tests so the random weights are reproducible.
    np.random.seed(0)
    in_tensor_1 = array_ops.placeholder(
        shape=[33, 33], dtype=dtypes.float32, name='inputA')
    in_tensor_2 = constant_op.constant(
        np.random.uniform(low=-10., high=10., size=(33, 33)),
        shape=[33, 33],
        dtype=dtypes.float32,
        name='inputB')
    out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
    sess = session.Session()

    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1], [out_tensor])
    self.assertFalse(quantized_converter.post_training_quantize)

    quantized_converter.post_training_quantize = True
    self.assertTrue(quantized_converter.post_training_quantize)
    self.assertEqual(quantized_converter.optimizations,
                     [lite.Optimize.OPTIMIZE_FOR_SIZE])

    quantized_tflite = quantized_converter.convert()
    self.assertTrue(quantized_tflite)

  def testPostTrainingQuantize(self):
    """Weight-only quantization yields a smaller model than float."""
    np.random.seed(0)
    # We need the tensor to have more than 1024 elements for quantize_weights
    # to kick in. Thus, the [33, 33] shape.
    in_tensor_1 = array_ops.placeholder(
        shape=[33, 33], dtype=dtypes.float32, name='inputA')
    in_tensor_2 = constant_op.constant(
        np.random.uniform(low=-10., high=10., size=(33, 33)),
        shape=[33, 33],
        dtype=dtypes.float32,
        name='inputB')
    out_tensor = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
    sess = session.Session()

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
                                                        [out_tensor])
    float_tflite = float_converter.convert()
    self.assertTrue(float_tflite)

    # Convert quantized weights model.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [in_tensor_1], [out_tensor])
    quantized_converter.optimizations = [lite.Optimize.OPTIMIZE_FOR_SIZE]
    quantized_tflite = quantized_converter.convert()
    self.assertTrue(quantized_tflite)

    # Ensure that the quantized weights tflite model is smaller.
    self.assertLess(len(quantized_tflite), len(float_tflite))

  def testPostTrainingCalibrateAndQuantize(self):
    """Calibrating with a representative dataset yields a smaller model."""
    np.random.seed(0)
    # Create a mobilenet like model.
    output_channel = 16
    depth_multiplier = 1
    inp = array_ops.placeholder(dtype=dtypes.float32, shape=(1, 5, 5, 3))
    conv = nn_ops.conv2d(
        inp,
        filter=array_ops.zeros([3, 3, 3, output_channel]),
        strides=[1, 1, 1, 1],
        padding='SAME')
    dconv = nn_ops.depthwise_conv2d_native(
        conv,
        filter=array_ops.zeros(
            [16, 16, output_channel, output_channel * depth_multiplier]),
        strides=[1, 1, 1, 1],
        padding='SAME')
    pool = nn_ops.pool(
        dconv, window_shape=[2, 2], pooling_type='AVG', padding='SAME')
    max_pool = nn_ops.pool(
        pool, window_shape=[2, 2], pooling_type='MAX', padding='SAME')
    output = nn_ops.softmax(max_pool)

    def calibration_gen():
      """Yields 10 single-element input lists of random (1, 5, 5, 3) data."""
      for _ in range(10):
        yield [np.random.uniform(-1, 1, size=(1, 5, 5, 3)).astype(np.float32)]

    sess = session.Session()

    # Convert float model.
    float_converter = lite.TFLiteConverter.from_session(sess, [inp], [output])
    float_tflite = float_converter.convert()
    self.assertTrue(float_tflite)

    # Convert with calibration data from the representative dataset.
    quantized_converter = lite.TFLiteConverter.from_session(
        sess, [inp], [output])
    quantized_converter.optimizations = [lite.Optimize.OPTIMIZE_FOR_SIZE]
    quantized_converter.representative_dataset = lite.RepresentativeDataset(
        calibration_gen)
    quantized_tflite = quantized_converter.convert()
    self.assertTrue(quantized_tflite)

    # Ensure that the calibrated, quantized tflite model is smaller.
    self.assertLess(len(quantized_tflite), len(float_tflite))

  def testFloatTocoConverter(self):
    """Tests deprecated test TocoConverter."""
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor
    sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Ensure the interpreter is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

  def testMultipleOutputNodeNames(self):
    """Tests converting a graph with an op that has multiple outputs."""
    input_tensor = array_ops.placeholder(shape=[4], dtype=dtypes.float32)
    out0, out1, out2, out3 = array_ops.split(input_tensor, [1, 1, 1, 1], axis=0)
    sess = session.Session()

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_session(sess, [input_tensor],
                                                  [out0, out1, out2, out3])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    interpreter.set_tensor(input_details[0]['index'],
                           np.asarray([1.0, 2.0, 3.0, 4.0], dtype=np.float32))
    interpreter.invoke()

    output_details = interpreter.get_output_details()
    self.assertEqual(4, len(output_details))
    self.assertEqual(1.0, interpreter.get_tensor(output_details[0]['index']))
    self.assertEqual(2.0, interpreter.get_tensor(output_details[1]['index']))
    self.assertEqual(3.0, interpreter.get_tensor(output_details[2]['index']))
    self.assertEqual(4.0, interpreter.get_tensor(output_details[3]['index']))
628
629
630@test_util.run_v1_only('b/120545219')
631class FromFrozenGraphFile(test_util.TensorFlowTestCase):
632
633  def testFloat(self):
634    in_tensor = array_ops.placeholder(
635        shape=[1, 16, 16, 3], dtype=dtypes.float32)
636    _ = in_tensor + in_tensor
637    sess = session.Session()
638
639    # Write graph to file.
640    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
641    write_graph(sess.graph_def, '', graph_def_file, False)
642    sess.close()
643
644    # Convert model and ensure model is not None.
645    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
646                                                       ['Placeholder'], ['add'])
647    tflite_model = converter.convert()
648    self.assertTrue(tflite_model)
649
650    # Check values from converted model.
651    interpreter = Interpreter(model_content=tflite_model)
652    interpreter.allocate_tensors()
653
654    input_details = interpreter.get_input_details()
655    self.assertEqual(1, len(input_details))
656    self.assertEqual('Placeholder', input_details[0]['name'])
657    self.assertEqual(np.float32, input_details[0]['dtype'])
658    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
659    self.assertEqual((0., 0.), input_details[0]['quantization'])
660
661    output_details = interpreter.get_output_details()
662    self.assertEqual(1, len(output_details))
663    self.assertEqual('add', output_details[0]['name'])
664    self.assertEqual(np.float32, output_details[0]['dtype'])
665    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
666    self.assertEqual((0., 0.), output_details[0]['quantization'])
667
668  def testFloatWithShapesArray(self):
669    in_tensor = array_ops.placeholder(
670        shape=[1, 16, 16, 3], dtype=dtypes.float32)
671    _ = in_tensor + in_tensor
672    sess = session.Session()
673
674    # Write graph to file.
675    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
676    write_graph(sess.graph_def, '', graph_def_file, False)
677    sess.close()
678
679    # Convert model and ensure model is not None.
680    converter = lite.TFLiteConverter.from_frozen_graph(
681        graph_def_file, ['Placeholder'], ['add'],
682        input_shapes={'Placeholder': [1, 16, 16, 3]})
683    tflite_model = converter.convert()
684    self.assertTrue(tflite_model)
685
686    # Check values from converted model.
687    interpreter = Interpreter(model_content=tflite_model)
688    interpreter.allocate_tensors()
689
690    input_details = interpreter.get_input_details()
691    self.assertEqual(1, len(input_details))
692    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
693
694  def testFreezeGraph(self):
695    in_tensor = array_ops.placeholder(
696        shape=[1, 16, 16, 3], dtype=dtypes.float32)
697    var = variable_scope.get_variable(
698        'weights', shape=[1, 16, 16, 3], dtype=dtypes.float32)
699    _ = in_tensor + var
700    sess = session.Session()
701
702    # Write graph to file.
703    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
704    write_graph(sess.graph_def, '', graph_def_file, False)
705    sess.close()
706
707    # Ensure the graph with variables cannot be converted.
708    with self.assertRaises(ValueError) as error:
709      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
710                                             ['add'])
711    self.assertEqual('Please freeze the graph using freeze_graph.py.',
712                     str(error.exception))
713
714  def testPbtxt(self):
715    in_tensor = array_ops.placeholder(
716        shape=[1, 16, 16, 3], dtype=dtypes.float32)
717    _ = in_tensor + in_tensor
718    sess = session.Session()
719
720    # Write graph to file.
721    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pbtxt')
722    write_graph(sess.graph_def, '', graph_def_file, True)
723    sess.close()
724
725    # Convert model and ensure model is not None.
726    converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
727                                                       ['Placeholder'], ['add'])
728    tflite_model = converter.convert()
729    self.assertTrue(tflite_model)
730
731    # Check values from converted model.
732    interpreter = Interpreter(model_content=tflite_model)
733    interpreter.allocate_tensors()
734
735    input_details = interpreter.get_input_details()
736    self.assertEqual(1, len(input_details))
737    self.assertEqual('Placeholder', input_details[0]['name'])
738    self.assertEqual(np.float32, input_details[0]['dtype'])
739    self.assertTrue(([1, 16, 16, 3] == input_details[0]['shape']).all())
740    self.assertEqual((0., 0.), input_details[0]['quantization'])
741
742    output_details = interpreter.get_output_details()
743    self.assertEqual(1, len(output_details))
744    self.assertEqual('add', output_details[0]['name'])
745    self.assertEqual(np.float32, output_details[0]['dtype'])
746    self.assertTrue(([1, 16, 16, 3] == output_details[0]['shape']).all())
747    self.assertEqual((0., 0.), output_details[0]['quantization'])
748
749  def testInvalidFileNotFound(self):
750    with self.assertRaises(IOError) as error:
751      lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'],
752                                             ['add'])
753    self.assertEqual('File \'invalid_file\' does not exist.',
754                     str(error.exception))
755
756  def testInvalidFileBadData(self):
757    graph_def_file = os.path.join(self.get_temp_dir(), 'invalid_file')
758    with gfile.Open(graph_def_file, 'wb') as temp_file:
759      temp_file.write('bad data')
760      temp_file.flush()
761
762    # Attempts to convert the invalid model.
763    with self.assertRaises(IOError) as error:
764      lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
765                                             ['add'])
766    self.assertEqual(
767        'Unable to parse input file \'{}\'.'.format(graph_def_file),
768        str(error.exception))
769
770  # TODO(nupurgarg): Test model loading in open source.
771  def _initObjectDetectionArgs(self):
772    # Initializes the arguments required for the object detection model.
773    # Looks for the model file which is saved in a different location internally
774    # and externally.
775    filename = resource_loader.get_path_to_datafile('testdata/tflite_graph.pb')
776    if not os.path.exists(filename):
777      filename = os.path.join(
778          resource_loader.get_root_dir_with_all_resources(),
779          '../tflite_mobilenet_ssd_quant_protobuf/tflite_graph.pb')
780      if not os.path.exists(filename):
781        raise IOError("File '{0}' does not exist.".format(filename))
782
783    self._graph_def_file = filename
784    self._input_arrays = ['normalized_input_image_tensor']
785    self._output_arrays = [
786        'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
787        'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
788    ]
789    self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}
790
791  def testTFLiteGraphDef(self):
792    # Tests the object detection model that cannot be loaded in TensorFlow.
793    self._initObjectDetectionArgs()
794
795    converter = lite.TFLiteConverter.from_frozen_graph(
796        self._graph_def_file, self._input_arrays, self._output_arrays,
797        self._input_shapes)
798    converter.allow_custom_ops = True
799    tflite_model = converter.convert()
800    self.assertTrue(tflite_model)
801
802    # Check values from converted model.
803    interpreter = Interpreter(model_content=tflite_model)
804    interpreter.allocate_tensors()
805
806    input_details = interpreter.get_input_details()
807    self.assertEqual(1, len(input_details))
808    self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
809    self.assertEqual(np.float32, input_details[0]['dtype'])
810    self.assertTrue(([1, 300, 300, 3] == input_details[0]['shape']).all())
811    self.assertEqual((0., 0.), input_details[0]['quantization'])
812
813    output_details = interpreter.get_output_details()
814    self.assertEqual(4, len(output_details))
815    self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
816    self.assertEqual(np.float32, output_details[0]['dtype'])
817    self.assertTrue(([1, 10, 4] == output_details[0]['shape']).all())
818    self.assertEqual((0., 0.), output_details[0]['quantization'])
819
820    self.assertEqual('TFLite_Detection_PostProcess:1',
821                     output_details[1]['name'])
822    self.assertTrue(([1, 10] == output_details[1]['shape']).all())
823    self.assertEqual('TFLite_Detection_PostProcess:2',
824                     output_details[2]['name'])
825    self.assertTrue(([1, 10] == output_details[2]['shape']).all())
826    self.assertEqual('TFLite_Detection_PostProcess:3',
827                     output_details[3]['name'])
828    self.assertTrue(([1] == output_details[3]['shape']).all())
829
830  def testTFLiteGraphDefMissingShape(self):
831    # Tests invalid cases for the model that cannot be loaded in TensorFlow.
832    self._initObjectDetectionArgs()
833
834    # Missing `input_shapes`.
835    with self.assertRaises(ValueError) as error:
836      lite.TFLiteConverter.from_frozen_graph(
837          self._graph_def_file, self._input_arrays, self._output_arrays)
838    self.assertEqual('input_shapes must be defined for this model.',
839                     str(error.exception))
840
841  def testTFLiteGraphDefInvalidShape(self):
842    # Tests invalid cases for the model that cannot be loaded in TensorFlow.
843    self._initObjectDetectionArgs()
844
845    # `input_shapes` does not contain the names in `input_arrays`.
846    with self.assertRaises(ValueError) as error:
847      lite.TFLiteConverter.from_frozen_graph(
848          self._graph_def_file,
849          self._input_arrays,
850          self._output_arrays,
851          input_shapes={'invalid-value': [1, 19]})
852    self.assertEqual(
853        'input_shapes must contain a value for each item in input_array.',
854        str(error.exception))
855
856  def testFloatTocoConverter(self):
857    in_tensor = array_ops.placeholder(
858        shape=[1, 16, 16, 3], dtype=dtypes.float32)
859    _ = in_tensor + in_tensor
860    sess = session.Session()
861
862    # Write graph to file.
863    graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
864    write_graph(sess.graph_def, '', graph_def_file, False)
865    sess.close()
866
867    # Convert model and ensure model is not None.
868    converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
869                                                     ['Placeholder'], ['add'])
870    tflite_model = converter.convert()
871    self.assertTrue(tflite_model)
872
873    # Ensure the model is able to load.
874    interpreter = Interpreter(model_content=tflite_model)
875    interpreter.allocate_tensors()
876
877
@test_util.run_v1_only('b/120545219')
class FromSavedModelTest(test_util.TensorFlowTestCase):
  """Tests for `TFLiteConverter.from_saved_model` (and `TocoConverter`)."""

  def _createSavedModel(self, shape):
    """Create a simple SavedModel.

    The model adds two float32 placeholders named 'inputB' and 'inputA'
    (signature keys 'x' and 'y') and exposes their sum under key 'z'.

    Args:
      shape: Shape used for both placeholders; may contain None.

    Returns:
      The SavedModel directory, located inside the test's temp directory.
    """
    saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
    with session.Session() as sess:
      in_tensor_1 = array_ops.placeholder(
          shape=shape, dtype=dtypes.float32, name='inputB')
      in_tensor_2 = array_ops.placeholder(
          shape=shape, dtype=dtypes.float32, name='inputA')
      out_tensor = in_tensor_1 + in_tensor_2
      inputs = {'x': in_tensor_1, 'y': in_tensor_2}
      outputs = {'z': out_tensor}
      saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
    return saved_model_dir

  def _assertFloatTensor(self, details, name, shape):
    """Asserts an unquantized float32 tensor with the given name and shape.

    Args:
      details: One entry of the interpreter's input or output details.
      name: Expected tensor name.
      shape: Expected shape as a list of ints.
    """
    self.assertEqual(name, details['name'])
    self.assertEqual(np.float32, details['dtype'])
    # assertAllEqual gives a readable failure on a rank mismatch, whereas
    # `(list == ndarray).all()` errors out instead of failing.
    self.assertAllEqual(shape, details['shape'])
    self.assertEqual((0., 0.), details['quantization'])

  def _assertSimpleModelDetails(self, tflite_model):
    """Loads `tflite_model` and checks the tensors of `_createSavedModel`.

    Both inputs (ordered inputA, inputB) and the single 'add' output are
    expected to be unquantized float32 tensors of shape [1, 16, 16, 3].

    Args:
      tflite_model: Serialized TFLite model returned by `convert()`.
    """
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self._assertFloatTensor(input_details[0], 'inputA', [1, 16, 16, 3])
    self._assertFloatTensor(input_details[1], 'inputB', [1, 16, 16, 3])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self._assertFloatTensor(output_details[0], 'add', [1, 16, 16, 3])

  def testSimpleModel(self):
    """Test a SavedModel."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    self._assertSimpleModelDetails(tflite_model)

  def testNoneBatchSize(self):
    """Test a SavedModel, with None in input tensor's shape."""
    saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])

    converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # The None batch dimension is expected to appear as 1 after conversion.
    self._assertSimpleModelDetails(tflite_model)

  def testOrderInputArrays(self):
    """Test a SavedModel ordering of input arrays."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir, input_arrays=['inputB', 'inputA'])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Even though inputB is requested first, the converted model is expected
    # to list its inputs as inputA, inputB.
    self._assertSimpleModelDetails(tflite_model)

  def testSubsetInputArrays(self):
    """Test a SavedModel with a subset of the input array names of the model."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Check case where input shape is given.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir,
        input_arrays=['inputA'],
        input_shapes={'inputA': [1, 16, 16, 3]})

    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check case where input shape is None.
    converter = lite.TFLiteConverter.from_saved_model(
        saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})

    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  def testSimpleModelTocoConverter(self):
    """Test a SavedModel with deprecated TocoConverter."""
    saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])

    # Convert model and ensure model is not None.
    converter = lite.TocoConverter.from_saved_model(saved_model_dir)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Ensure the model is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
1021
1022
@test_util.run_v1_only('b/120545219')
class FromKerasFile(test_util.TensorFlowTestCase):
  """Tests for `TFLiteConverter.from_keras_model_file`."""

  def setUp(self):
    # Run the base-class setup, which this override previously skipped, and
    # then clear Keras' global state. The expected tensor names below (e.g.
    # 'dense_1/BiasAdd') presumably rely on Keras' layer-name counters
    # starting fresh in every test.
    super(FromKerasFile, self).setUp()
    keras.backend.clear_session()

  @staticmethod
  def _removeFileIfExists(path):
    """Deletes `path` if it still exists. Registered as a test cleanup."""
    if os.path.exists(path):
      os.remove(path)

  def _saveKerasModel(self, model):
    """Saves `model` to a new temporary .h5 file and returns its path.

    A cleanup is registered so the file is deleted when the test finishes,
    whether it passes or fails.

    Args:
      model: The tf.keras model to save.

    Returns:
      Path of the saved .h5 file.
    """
    # mkstemp stays outside the try: if it raises there is no `fd` to close.
    fd, keras_file = tempfile.mkstemp('.h5')
    # Register the cleanup first so the file is removed even if saving fails.
    self.addCleanup(self._removeFileIfExists, keras_file)
    try:
      keras.models.save_model(model, keras_file)
    finally:
      # Only the path is used; release the descriptor mkstemp opened.
      os.close(fd)
    return keras_file

  def _assertTensorDetails(self, details, name, shape):
    """Asserts an unquantized float32 tensor with the given name and shape.

    Args:
      details: One entry of the interpreter's input or output details.
      name: Expected tensor name.
      shape: Expected shape as a list of ints.
    """
    self.assertEqual(name, details['name'])
    self.assertEqual(np.float32, details['dtype'])
    self.assertAllEqual(shape, details['shape'])
    self.assertEqual((0., 0.), details['quantization'])

  def _assertMatchesKeras(self, interpreter, keras_file):
    """Checks the TFLite and reloaded Keras models agree on one input.

    Feeds [[1, 2, 3]] to the first input of `interpreter` and compares its
    first output with the prediction of the model loaded from `keras_file`,
    to 5 decimal places.

    Args:
      interpreter: An `Interpreter` whose tensors are already allocated.
      keras_file: Path of the .h5 file the TFLite model was converted from.
    """
    input_data = np.array([[1, 2, 3]], dtype=np.float32)
    input_index = interpreter.get_input_details()[0]['index']
    output_index = interpreter.get_output_details()[0]['index']
    interpreter.set_tensor(input_index, input_data)
    interpreter.invoke()
    tflite_result = interpreter.get_tensor(output_index)

    keras_model = keras.models.load_model(keras_file)
    keras_result = keras_model.predict(input_data)

    np.testing.assert_almost_equal(tflite_result, keras_result, 5)

  def _getSequentialModel(self):
    """Builds, trains for one batch and saves a small Sequential model.

    Returns:
      Path of the saved .h5 file. It is removed automatically after the test.
    """
    # Entering the Session makes it the default (as before) and also closes
    # it afterwards instead of leaking it.
    with session.Session():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.RepeatVector(3))
      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
      model.compile(
          loss=keras.losses.MSE,
          optimizer=keras.optimizers.RMSprop(),
          metrics=[keras.metrics.categorical_accuracy],
          sample_weight_mode='temporal')
      x = np.random.random((1, 3))
      y = np.random.random((1, 3, 3))
      model.train_on_batch(x, y)
      model.predict(x)
      return self._saveKerasModel(model)

  def testSequentialModel(self):
    """Test a Sequential tf.keras model with default inputs."""
    keras_file = self._getSequentialModel()

    converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check tensor details of converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self._assertTensorDetails(input_details[0], 'dense_input', [1, 3])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self._assertTensorDetails(output_details[0], 'time_distributed/Reshape_1',
                              [1, 3, 3])

    # Check inference of converted model.
    self._assertMatchesKeras(interpreter, keras_file)

  def testSequentialModelInputArray(self):
    """Test a Sequential tf.keras model testing input arrays argument."""
    keras_file = self._getSequentialModel()

    # Invalid input array raises error.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_keras_model_file(
          keras_file, input_arrays=['invalid-input'])
    self.assertEqual("Invalid tensors 'invalid-input' were found.",
                     str(error.exception))

    # Valid input array.
    converter = lite.TFLiteConverter.from_keras_model_file(
        keras_file, input_arrays=['dense_input'])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  def testSequentialModelInputShape(self):
    """Test a Sequential tf.keras model testing input shapes argument."""
    keras_file = self._getSequentialModel()

    # Passing in shape of invalid input array raises error.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_keras_model_file(
          keras_file, input_shapes={'invalid-input': [2, 3]})
    self.assertEqual(
        "Invalid tensor 'invalid-input' found in tensor shapes map.",
        str(error.exception))

    # Passing in shape of valid input array.
    converter = lite.TFLiteConverter.from_keras_model_file(
        keras_file, input_shapes={'dense_input': [2, 3]})
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check input shape from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self.assertEqual('dense_input', input_details[0]['name'])
    self.assertAllEqual([2, 3], input_details[0]['shape'])

  def testSequentialModelOutputArray(self):
    """Test a Sequential tf.keras model testing output arrays argument."""
    keras_file = self._getSequentialModel()

    # Invalid output array raises error.
    with self.assertRaises(ValueError) as error:
      lite.TFLiteConverter.from_keras_model_file(
          keras_file, output_arrays=['invalid-output'])
    self.assertEqual("Invalid tensors 'invalid-output' were found.",
                     str(error.exception))

    # Valid output array.
    converter = lite.TFLiteConverter.from_keras_model_file(
        keras_file, output_arrays=['time_distributed/Reshape_1'])
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

  def testFunctionalModel(self):
    """Test a Functional tf.keras model with default inputs."""
    with session.Session():
      inputs = keras.layers.Input(shape=(3,), name='input')
      hidden = keras.layers.Dense(2)(inputs)
      output = keras.layers.Dense(3)(hidden)

      model = keras.models.Model(inputs, output)
      model.compile(
          loss=keras.losses.MSE,
          optimizer=keras.optimizers.RMSprop(),
          metrics=[keras.metrics.categorical_accuracy])
      x = np.random.random((1, 3))
      y = np.random.random((1, 3))
      model.train_on_batch(x, y)

      model.predict(x)
      keras_file = self._saveKerasModel(model)

    # Convert to TFLite model.
    converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check tensor details of converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self._assertTensorDetails(input_details[0], 'input', [1, 3])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self._assertTensorDetails(output_details[0], 'dense_1/BiasAdd', [1, 3])

    # Check inference of converted model.
    self._assertMatchesKeras(interpreter, keras_file)

  def testFunctionalModelMultipleInputs(self):
    """Test a Functional tf.keras model with multiple inputs and outputs."""
    with session.Session():
      a = keras.layers.Input(shape=(3,), name='input_a')
      b = keras.layers.Input(shape=(3,), name='input_b')
      dense = keras.layers.Dense(4, name='dense')
      c = dense(a)
      d = dense(b)
      e = keras.layers.Dropout(0.5, name='dropout')(c)

      model = keras.models.Model([a, b], [d, e])
      model.compile(
          loss=keras.losses.MSE,
          optimizer=keras.optimizers.RMSprop(),
          metrics=[keras.metrics.mae],
          loss_weights=[1., 0.5])

      input_a_np = np.random.random((10, 3))
      input_b_np = np.random.random((10, 3))
      output_d_np = np.random.random((10, 4))
      output_e_np = np.random.random((10, 4))
      model.train_on_batch([input_a_np, input_b_np], [output_d_np, output_e_np])

      model.predict([input_a_np, input_b_np], batch_size=5)
      keras_file = self._saveKerasModel(model)

    # Convert to TFLite model.
    converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check values from converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(2, len(input_details))
    self._assertTensorDetails(input_details[0], 'input_a', [1, 3])
    self._assertTensorDetails(input_details[1], 'input_b', [1, 3])

    output_details = interpreter.get_output_details()
    self.assertEqual(2, len(output_details))
    self._assertTensorDetails(output_details[0], 'dense_1/BiasAdd', [1, 4])
    self._assertTensorDetails(output_details[1], 'dropout/Identity', [1, 4])

  def testFunctionalSequentialModel(self):
    """Test a Functional tf.keras model containing a Sequential model."""
    with session.Session():
      model = keras.models.Sequential()
      model.add(keras.layers.Dense(2, input_shape=(3,)))
      model.add(keras.layers.RepeatVector(3))
      model.add(keras.layers.TimeDistributed(keras.layers.Dense(3)))
      model = keras.models.Model(model.input, model.output)

      model.compile(
          loss=keras.losses.MSE,
          optimizer=keras.optimizers.RMSprop(),
          metrics=[keras.metrics.categorical_accuracy],
          sample_weight_mode='temporal')
      x = np.random.random((1, 3))
      y = np.random.random((1, 3, 3))
      model.train_on_batch(x, y)
      model.predict(x)
      keras_file = self._saveKerasModel(model)

    # Convert to TFLite model.
    converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Check tensor details of converted model.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    self.assertEqual(1, len(input_details))
    self._assertTensorDetails(input_details[0], 'dense_input', [1, 3])

    output_details = interpreter.get_output_details()
    self.assertEqual(1, len(output_details))
    self._assertTensorDetails(output_details[0], 'time_distributed/Reshape_1',
                              [1, 3, 3])

    # Check inference of converted model.
    self._assertMatchesKeras(interpreter, keras_file)

  def testSequentialModelTocoConverter(self):
    """Test a Sequential tf.keras model with deprecated TocoConverter."""
    keras_file = self._getSequentialModel()

    converter = lite.TocoConverter.from_keras_model_file(keras_file)
    tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Ensure the model is able to load.
    interpreter = Interpreter(model_content=tflite_model)
    interpreter.allocate_tensors()
1349
1350
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner
  # (tensorflow.python.platform.test) when run as a script.
  test.main()
1353