1# Lint as: python2, python3 2# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Tests for lite.py functionality related to TensorFlow 2.0.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import os 23 24from absl.testing import parameterized 25from six.moves import zip 26 27from tensorflow.lite.python.interpreter import Interpreter 28from tensorflow.python.eager import def_function 29from tensorflow.python.framework import test_util 30from tensorflow.python.ops import variables 31from tensorflow.python.training.tracking import tracking 32 33 34class ModelTest(test_util.TensorFlowTestCase, parameterized.TestCase): 35 """Base test class for TensorFlow Lite 2.x model tests.""" 36 37 def _evaluateTFLiteModel(self, tflite_model, input_data, input_shapes=None): 38 """Evaluates the model on the `input_data`. 39 40 Args: 41 tflite_model: TensorFlow Lite model. 42 input_data: List of EagerTensor const ops containing the input data for 43 each input tensor. 44 input_shapes: List of tuples representing the `shape_signature` and the 45 new shape of each input tensor that has unknown dimensions. 46 47 Returns: 48 [np.ndarray] 49 """ 50 interpreter = Interpreter(model_content=tflite_model) 51 input_details = interpreter.get_input_details() 52 if input_shapes: 53 for idx, (shape_signature, final_shape) in enumerate(input_shapes): 54 self.assertTrue( 55 (input_details[idx]['shape_signature'] == shape_signature).all()) 56 index = input_details[idx]['index'] 57 interpreter.resize_tensor_input(index, final_shape, strict=True) 58 interpreter.allocate_tensors() 59 60 output_details = interpreter.get_output_details() 61 input_details = interpreter.get_input_details() 62 63 for input_tensor, tensor_data in zip(input_details, input_data): 64 interpreter.set_tensor(input_tensor['index'], tensor_data.numpy()) 65 interpreter.invoke() 66 return [ 67 interpreter.get_tensor(details['index']) for details in output_details 68 ] 69 70 def _evaluateTFLiteModelUsingSignatureDef(self, tflite_model, method_name, 71 inputs): 72 """Evaluates the model on the `inputs`. 73 74 Args: 75 tflite_model: TensorFlow Lite model. 76 method_name: Exported Method name of the SavedModel. 77 inputs: Map from input tensor names in the SignatureDef to tensor value. 78 79 Returns: 80 Dictionary of outputs. 81 Key is the output name in the SignatureDef 'method_name' 82 Value is the output value 83 """ 84 interpreter = Interpreter(model_content=tflite_model) 85 signature_runner = interpreter.get_signature_runner(method_name) 86 return signature_runner(**inputs) 87 88 def _getSimpleVariableModel(self): 89 root = tracking.AutoTrackable() 90 root.v1 = variables.Variable(3.) 91 root.v2 = variables.Variable(2.) 92 root.f = def_function.function(lambda x: root.v1 * root.v2 * x) 93 return root 94 95 def _getMultiFunctionModel(self): 96 97 class BasicModel(tracking.AutoTrackable): 98 """Basic model with multiple functions.""" 99 100 def __init__(self): 101 self.y = None 102 self.z = None 103 104 @def_function.function 105 def add(self, x): 106 if self.y is None: 107 self.y = variables.Variable(2.) 108 return x + self.y 109 110 @def_function.function 111 def sub(self, x): 112 if self.z is None: 113 self.z = variables.Variable(3.) 114 return x - self.z 115 116 @def_function.function 117 def mul_add(self, x, y): 118 if self.z is None: 119 self.z = variables.Variable(3.) 120 return x * self.z + y 121 122 return BasicModel() 123 124 def _assertValidDebugInfo(self, debug_info): 125 """Verify the DebugInfo is valid.""" 126 file_names = set() 127 for file_path in debug_info.files: 128 file_names.add(os.path.basename(file_path)) 129 # To make the test independent on how the nodes are created, we only assert 130 # the name of this test file. 131 self.assertIn('lite_v2_test.py', file_names) 132 self.assertNotIn('lite_test.py', file_names) 133