1# Lint as: python2, python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for lite.py functionality related to TensorFlow 2.0."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23
24from absl.testing import parameterized
25from six.moves import zip
26
27from tensorflow.lite.python.interpreter import Interpreter
28from tensorflow.python.eager import def_function
29from tensorflow.python.framework import test_util
30from tensorflow.python.ops import variables
31from tensorflow.python.training.tracking import tracking
32
33
class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase):
  """Shared helpers for TensorFlow Lite 2.x model test cases."""

  def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None):
    """Runs `tflite_model` in the interpreter and collects its outputs.

    Args:
      tflite_model: TensorFlow Lite model, handed to the interpreter as
        `model_content`.
      input_data: List of EagerTensor const ops, one per input tensor;
        `.numpy()` is called on each before it is fed to the interpreter.
      input_shapes: Optional list of `(shape_signature, new_shape)` tuples for
        input tensors that have unknown dimensions. Entry `i` is checked
        against, and applied to, the i-th input reported by the interpreter.

    Returns:
      [np.ndarray]: the value of every output tensor, in the order reported by
      the interpreter.
    """
    interpreter = Interpreter(model_content=tflite_model)

    # Resizing has to happen before the tensors are allocated.
    pre_alloc_inputs = interpreter.get_input_details()
    for position, shape_pair in enumerate(input_shapes or ()):
      expected_signature, target_shape = shape_pair
      details = pre_alloc_inputs[position]
      signature_matches = details['shape_signature'] == expected_signature
      self.assertTrue(signature_matches.all())
      interpreter.resize_tensor_input(
          details['index'], target_shape, strict=True)
    interpreter.allocate_tensors()

    # Query the tensor details again now that allocation is done.
    output_specs = interpreter.get_output_details()
    input_specs = interpreter.get_input_details()

    for spec, tensor in zip(input_specs, input_data):
      interpreter.set_tensor(spec['index'], tensor.numpy())
    interpreter.invoke()

    outputs = []
    for spec in output_specs:
      outputs.append(interpreter.get_tensor(spec['index']))
    return outputs

  def _evaluateTFLiteModelUsingSignatureDef(self, tflite_model, method_name,
                                            inputs):
    """Runs one exported signature of `tflite_model` on `inputs`.

    Args:
      tflite_model: TensorFlow Lite model.
      method_name: Exported method name of the SavedModel; used to look up the
        signature runner on the interpreter.
      inputs: Map from input tensor names in the SignatureDef to tensor value;
        expanded into keyword arguments of the signature runner.

    Returns:
      Dictionary of outputs.
      Key is the output name in the SignatureDef 'method_name'
      Value is the output value
    """
    interpreter = Interpreter(model_content=tflite_model)
    runner = interpreter.get_signature_runner(method_name)
    outputs = runner(**inputs)
    return outputs

  def _getSimpleVariableModel(self):
    """Builds a trackable object whose `f(x)` computes `v1 * v2 * x`.

    Returns:
      An `AutoTrackable` with variables `v1` (3.) and `v2` (2.) and a
      `def_function.function` attribute `f`.
    """
    model = tracking.AutoTrackable()
    model.v1 = variables.Variable(3.)
    model.v2 = variables.Variable(2.)
    model.f = def_function.function(lambda x: model.v1 * model.v2 * x)
    return model

  def _getMultiFunctionModel(self):
    """Builds a trackable model exposing `add`, `sub` and `mul_add`.

    Returns:
      A fresh `BasicModel` instance; its variables are created lazily on the
      first call of the function that needs them.
    """

    class BasicModel(tracking.AutoTrackable):
      """Basic model with multiple functions."""

      def __init__(self):
        # Variables start unset and are created on first use by the functions
        # below; `z` is shared between `sub` and `mul_add`.
        self.y = None
        self.z = None

      @def_function.function
      def add(self, x):
        """Returns `x + y`, creating `y` (2.) if it is not set yet."""
        if self.y is None:
          self.y = variables.Variable(2.)
        return x + self.y

      @def_function.function
      def sub(self, x):
        """Returns `x - z`, creating `z` (3.) if it is not set yet."""
        if self.z is None:
          self.z = variables.Variable(3.)
        return x - self.z

      @def_function.function
      def mul_add(self, x, y):
        """Returns `x * z + y`, creating `z` (3.) if it is not set yet."""
        if self.z is None:
          self.z = variables.Variable(3.)
        return x * self.z + y

    return BasicModel()

  def _assertValidDebugInfo(self, debug_info):
    """Checks which test files are referenced by `debug_info.files`.

    Args:
      debug_info: Object whose `files` attribute iterates over file paths.
    """
    base_names = {os.path.basename(path) for path in debug_info.files}
    # Only the file names are inspected so that the check does not depend on
    # how the individual nodes were created.
    self.assertIn('lite_v2_test.py', base_names)
    self.assertNotIn('lite_test.py', base_names)