# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for writing tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.platform import test
from tensorflow.python.training import coordinator
from tensorflow.python.training import queue_runner_impl


class Base(test.TestCase):
  """Test-case base class with helpers for evaluating and comparing tensors."""

  def eval(self, tensors):
    """Evaluates `tensors` in a cached session with queue runners running.

    A fresh coordinator is created, the session's queue runners are started
    under it, and the fetch is executed. The runner threads are always asked
    to stop and are joined afterwards, even when the run raises.

    Args:
      tensors: Whatever is passed straight through as the fetches argument of
        `sess.run`.

    Returns:
      The value(s) returned by `sess.run(tensors)`.
    """
    with self.cached_session() as session:
      thread_coord = coordinator.Coordinator()
      runner_threads = queue_runner_impl.start_queue_runners(
          sess=session, coord=thread_coord)
      try:
        fetched = session.run(tensors)
      finally:
        # Shut down the queue-runner threads whether or not the run succeeded,
        # so no background threads outlive this call.
        thread_coord.request_stop()
        thread_coord.join(runner_threads)
    return fetched

  def assertTensorsEqual(self, tensor_0, tensor_1):
    """Asserts that two tensors evaluate to equal values.

    Both tensors are fetched in a single `eval` call and the resulting values
    are compared with `assertAllEqual`.
    """
    value_0, value_1 = self.eval([tensor_0, tensor_1])
    self.assertAllEqual(value_0, value_1)

  def assertLabeledTensorsEqual(self, tensor_0, tensor_1):
    """Asserts that two labeled tensors have equal axes and equal values.

    The `.axes` attributes are compared first, so an axes mismatch fails
    before anything is evaluated; then the underlying `.tensor` attributes
    are compared via `assertTensorsEqual`.
    """
    self.assertEqual(tensor_0.axes, tensor_1.axes)
    self.assertTensorsEqual(tensor_0.tensor, tensor_1.tensor)