# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for lite.py functionality related to select TF op usage."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.lite.python import lite
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


@test_util.run_v1_only('b/120545219')
class FromSessionTest(test_util.TensorFlowTestCase):
  """Tests converting session-based graphs with select TF ops enabled."""

  def testFlexMode(self):
    """Converts an add graph with SELECT_TF_OPS and checks interpreter errors.

    The model is converted with `lite.OpsSet.SELECT_TF_OPS` as the target op
    set. The plain `Interpreter` (with no Flex delegate applied) is then
    expected to raise a RuntimeError from `allocate_tensors()`, which confirms
    that the converted model contains regular TensorFlow ops.
    """
    in_tensor = array_ops.placeholder(
        shape=[1, 16, 16, 3], dtype=dtypes.float32)
    out_tensor = in_tensor + in_tensor

    # Use the session as a context manager so it is always closed and its
    # resources released, even if conversion raises.
    with session.Session() as sess:
      # Convert model and ensure model is not None.
      converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
                                                    [out_tensor])
      converter.target_ops = {lite.OpsSet.SELECT_TF_OPS}
      tflite_model = converter.convert()
    self.assertTrue(tflite_model)

    # Ensures the model contains TensorFlow ops.
    # TODO(nupurgarg): Check values once there is a Python delegate interface.
    interpreter = Interpreter(model_content=tflite_model)
    with self.assertRaises(RuntimeError) as error:
      interpreter.allocate_tensors()
    self.assertIn(
        'Regular TensorFlow ops are not supported by this interpreter. Make '
        'sure you invoke the Flex delegate before inference.',
        str(error.exception))


if __name__ == '__main__':
  test.main()