1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for tensorflow.python.client.session.Session.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import collections 21import os 22import random 23import sys 24import threading 25import time 26import warnings 27 28import numpy as np 29import six 30from six.moves import xrange # pylint: disable=redefined-builtin 31 32from tensorflow.core.framework import attr_value_pb2 33from tensorflow.core.lib.core import error_codes_pb2 34from tensorflow.core.protobuf import config_pb2 35from tensorflow.python.client import session 36from tensorflow.python.eager import context 37from tensorflow.python.eager import def_function 38from tensorflow.python.framework import config 39from tensorflow.python.framework import constant_op 40from tensorflow.python.framework import device as framework_device_lib 41from tensorflow.python.framework import dtypes 42from tensorflow.python.framework import errors 43from tensorflow.python.framework import function 44from tensorflow.python.framework import importer 45from tensorflow.python.framework import ops 46from tensorflow.python.framework import sparse_tensor 47from tensorflow.python.framework import tensor_util 48from tensorflow.python.framework import test_util 49from 
tensorflow.python.framework import versions 50from tensorflow.python.ops import array_ops 51from tensorflow.python.ops import control_flow_ops 52from tensorflow.python.ops import data_flow_ops 53from tensorflow.python.ops import gen_control_flow_ops 54# Import gradients to resolve circular imports 55from tensorflow.python.ops import gradients # pylint: disable=unused-import 56from tensorflow.python.ops import gradients_impl 57from tensorflow.python.ops import math_ops 58# Import resource_variable_ops for the variables-to-tensor implicit conversion. 59from tensorflow.python.ops import resource_variable_ops # pylint: disable=unused-import 60from tensorflow.python.ops import state_ops 61from tensorflow.python.ops import variables 62from tensorflow.python.platform import googletest 63from tensorflow.python.training import server_lib 64from tensorflow.python.util import compat 65 66try: 67 import attr # pylint:disable=g-import-not-at-top 68except ImportError: 69 attr = None 70 71try: 72 from frozendict import frozendict # pylint:disable=g-import-not-at-top 73except ImportError: 74 frozendict = dict # pylint:disable=invalid-name 75 76defaultdict = collections.defaultdict # pylint:disable=invalid-name 77 78 79class SessionTest(test_util.TensorFlowTestCase): 80 81 def setUp(self): 82 super(SessionTest, self).setUp() 83 warnings.simplefilter('always') 84 85 def testUseExistingGraph(self): 86 with ops.Graph().as_default() as g, ops.device('/cpu:0'): 87 a = constant_op.constant(6.0, shape=[1, 1]) 88 b = constant_op.constant(7.0, shape=[1, 1]) 89 c = math_ops.matmul(a, b, name='matmul') 90 with session.Session(graph=g): 91 result = c.eval() 92 self.assertAllEqual(result, [[42.0]]) 93 94 def testUseDefaultGraph(self): 95 with ops.Graph().as_default(), ops.device('/cpu:0'): 96 a = constant_op.constant(6.0, shape=[1, 1]) 97 b = constant_op.constant(7.0, shape=[1, 1]) 98 c = math_ops.matmul(a, b, name='matmul') 99 with session.Session(): 100 result = c.eval() 101 
self.assertAllEqual(result, [[42.0]]) 102 103 def testCreate(self): 104 with session.Session(): 105 inp = constant_op.constant(10.0, shape=[2, 3], name='W1') 106 copy = array_ops.identity(inp) 107 # Test with feed. 108 # TODO(mrry): Investigate why order='F' didn't work. 109 arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C') 110 copy_val = copy.eval({'W1:0': arr}) 111 self.assertAllEqual(arr, copy_val) 112 # Test without feed. 113 copy_val = copy.eval() 114 self.assertAllEqual( 115 np.asarray( 116 [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32), 117 copy_val) 118 119 def testManyCPUs(self): 120 with session.Session( 121 config=config_pb2.ConfigProto(device_count={ 122 'CPU': 2, 'GPU': 0 123 })) as sess: 124 inp = constant_op.constant(10.0, name='W1') 125 self.assertAllEqual(inp, 10.0) 126 127 num_cpu_devices = 0 128 num_gpu_devices = 0 129 for device in sess.list_devices(): 130 device_type = framework_device_lib.DeviceSpec.from_string( 131 device.name).device_type 132 if device_type == 'CPU': 133 num_cpu_devices += 1 134 elif device_type == 'GPU': 135 num_gpu_devices += 1 136 self.assertEqual(2, num_cpu_devices) 137 self.assertEqual(0, num_gpu_devices) 138 139 def testPerSessionThreads(self): 140 with session.Session( 141 config=config_pb2.ConfigProto(use_per_session_threads=True)): 142 inp = constant_op.constant(10.0, name='W1') 143 self.assertAllEqual(inp, 10.0) 144 145 def testSessionInterOpThreadPool(self): 146 config_pb = config_pb2.ConfigProto() 147 pool = config_pb.session_inter_op_thread_pool.add() 148 with session.Session(config=config_pb) as s: 149 inp = constant_op.constant(10.0, name='W1') 150 results = s.run([inp]) 151 self.assertAllEqual([10.0], results) 152 153 pool = config_pb.session_inter_op_thread_pool.add() 154 pool.num_threads = 1 155 with session.Session(config=config_pb) as s: 156 inp = constant_op.constant(20.0, name='W2') 157 results = s.run([inp]) 158 self.assertAllEqual([20.0], results) 159 160 pool = 
config_pb.session_inter_op_thread_pool.add() 161 pool.num_threads = 1 162 pool.global_name = 't1' 163 run_options = config_pb2.RunOptions() 164 run_options.inter_op_thread_pool = ( 165 len(config_pb.session_inter_op_thread_pool) - 1) 166 with session.Session(config=config_pb) as s: 167 inp = constant_op.constant(30.0, name='W2') 168 results = s.run([inp], options=run_options) 169 self.assertAllEqual([30.0], results) 170 171 def testErrorsReported(self): 172 with session.Session() as s: 173 constant_op.constant(10.0, name='W1') 174 with self.assertRaises(ValueError): 175 s.run('foo:0') 176 177 def testErrorPayload(self): 178 with session.Session(): 179 a = array_ops.placeholder(dtypes.float32) 180 with self.assertRaisesOpError(lambda e: e.op == a.op): 181 a.eval() 182 183 def testErrorCodeWithNoNodeDef(self): 184 with session.Session() as s: 185 a = array_ops.placeholder(dtypes.float32, shape=[]) 186 b = array_ops.placeholder(dtypes.float32, shape=[]) 187 r1 = math_ops.add(a, b) 188 189 def exc_predicate(e): 190 return (e.op is None and e.node_def is None and 191 e.error_code == error_codes_pb2.INVALID_ARGUMENT) 192 193 with self.assertRaisesOpError(exc_predicate): 194 # Run with a bogus handle. 195 s.partial_run('foo', r1, feed_dict={a: 1, b: 2}) 196 197 def testErrorBasedOn(self): 198 with session.Session() as sess: 199 a = constant_op.constant(0.0, shape=[2, 3]) 200 # NOTE(mrry): The original_op is nonsense, but used here to test that the 201 # errors are reported correctly. 
202 with sess.graph._original_op(a.op): 203 b = array_ops.identity(a, name='id') 204 with sess.graph._original_op(b.op): 205 c = array_ops.placeholder(dtypes.float32) 206 207 def exc_predicate(e): 208 return (e.op == c.op and e.op._original_op == b.op and 209 e.op._original_op._original_op == a.op) 210 211 with self.assertRaisesOpError(exc_predicate): 212 c.eval() 213 214 def testFetchNone(self): 215 with session.Session() as s: 216 a = constant_op.constant(1.0) 217 with self.assertRaises(TypeError): 218 s.run(None) 219 with self.assertRaises(TypeError): 220 s.run([None]) 221 with self.assertRaises(TypeError): 222 s.run({'b': None}) 223 with self.assertRaises(TypeError): 224 s.run({'a': a, 'b': None}) 225 226 def testFetchSingleton(self): 227 with session.Session() as sess: 228 a = constant_op.constant(42.0) 229 res = sess.run(a) 230 self.assertEqual(42.0, res) 231 res = sess.run(a.op) # An op, not a tensor. 232 self.assertIsNone(res) 233 tensor_runner = sess.make_callable(a) 234 res = tensor_runner() 235 self.assertEqual(42.0, res) 236 op_runner = sess.make_callable(a.op) 237 res = op_runner() 238 self.assertIsNone(res) 239 240 def testFetchSingletonByName(self): 241 with session.Session() as sess: 242 a = constant_op.constant(42.0) 243 res = sess.run(a.name) 244 self.assertEqual(42.0, res) 245 res = sess.run(a.op) # An op, not a tensor. 246 self.assertIsNone(res) 247 248 def testFetchList(self): 249 with session.Session() as sess: 250 a = constant_op.constant(42.0) 251 b = control_flow_ops.no_op() # An op, not a tensor. 
252 c = constant_op.constant(44.0) 253 v = variables.Variable([54.0]) 254 assign = v.assign([63.0]) 255 res = sess.run([a, b, c, a.name, assign.op]) 256 self.assertIsInstance(res, list) 257 self.assertEqual([42.0, None, 44.0, 42.0, None], res) 258 list_runner = sess.make_callable([a, b, c, a.name, assign.op]) 259 res = list_runner() 260 self.assertIsInstance(res, list) 261 self.assertEqual([42.0, None, 44.0, 42.0, None], res) 262 263 def testFetchTuple(self): 264 with session.Session() as sess: 265 a = constant_op.constant(42.0) 266 b = control_flow_ops.no_op() # An op, not a tensor. 267 c = constant_op.constant(44.0) 268 res = sess.run((a, b, c, a.name)) 269 self.assertIsInstance(res, tuple) 270 self.assertEqual((42.0, None, 44.0, 42.0), res) 271 tuple_runner = sess.make_callable((a, b, c, a.name)) 272 res = tuple_runner() 273 self.assertIsInstance(res, tuple) 274 self.assertEqual((42.0, None, 44.0, 42.0), res) 275 276 def testFetchNamedTuple(self): 277 # pylint: disable=invalid-name 278 ABC = collections.namedtuple('ABC', ['a', 'b', 'c']) 279 # pylint: enable=invalid-name 280 with session.Session() as sess: 281 a = constant_op.constant(42.0) 282 b = control_flow_ops.no_op() # An op, not a tensor. 283 c = constant_op.constant(44.0) 284 res = sess.run(ABC(a, b, c)) 285 self.assertIsInstance(res, ABC) 286 self.assertEqual(42.0, res.a) 287 self.assertIsNone(res.b) 288 self.assertEqual(44.0, res.c) 289 namedtuple_runner = sess.make_callable(ABC(a, b, c)) 290 res = namedtuple_runner() 291 self.assertIsInstance(res, ABC) 292 self.assertEqual(42.0, res.a) 293 self.assertIsNone(res.b) 294 self.assertEqual(44.0, res.c) 295 296 def testFetchDict(self): 297 with session.Session() as sess: 298 a = constant_op.constant(42.0) 299 b = control_flow_ops.no_op() # An op, not a tensor. 
300 c = constant_op.constant(44.0) 301 res = sess.run({'a': a, 'b': b, 'c': c}) 302 self.assertIsInstance(res, dict) 303 self.assertEqual(42.0, res['a']) 304 self.assertIsNone(res['b']) 305 self.assertEqual(44.0, res['c']) 306 307 def testFetchOrderedDict(self): 308 with session.Session() as sess: 309 a = constant_op.constant(42.0) 310 b = control_flow_ops.no_op() # An op, not a tensor. 311 c = constant_op.constant(44.0) 312 res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)])) 313 self.assertIsInstance(res, collections.OrderedDict) 314 self.assertEqual([3, 2, 1], list(res.keys())) 315 self.assertEqual(42.0, res[3]) 316 self.assertIsNone(res[2]) 317 self.assertEqual(44.0, res[1]) 318 319 @test_util.run_v1_only('b/120545219') 320 def testFetchAttrs(self): 321 if attr is None: 322 self.skipTest('attr module is unavailable.') 323 324 @attr.s 325 class SampleAttr(object): 326 field1 = attr.ib() 327 field2 = attr.ib() 328 329 val1 = np.array([1.2, 3.4, 5.6]) 330 val2 = np.array([[1, 2], [4, 3]]) 331 val3 = np.array([10, 20, 30]) 332 333 t1 = constant_op.constant(val1) 334 t2 = constant_op.constant(val2) 335 336 sample = SampleAttr(t1, t2) 337 with session.Session() as sess: 338 result = sess.run(sample) 339 self.assertIsInstance(result, SampleAttr) 340 self.assertAllEqual(val1, result.field1) 341 self.assertAllEqual(val2, result.field2) 342 343 result = sess.run(sample, feed_dict={sample.field1: val3}) 344 self.assertIsInstance(result, SampleAttr) 345 self.assertAllEqual(val3, result.field1) 346 self.assertAllEqual(val2, result.field2) 347 348 @test_util.run_v1_only('b/120545219') 349 def testFetchNestedAttrs(self): 350 if attr is None: 351 self.skipTest('attr module is unavailable.') 352 353 @attr.s 354 class SampleAttr(object): 355 field0 = attr.ib() 356 field1 = attr.ib() 357 358 v1 = 10 359 v2 = 20 360 v3 = np.float32(1.2) 361 v4 = np.float32(3.4) 362 v5 = np.float64(100.001) 363 v6 = np.float64(-23.451) 364 arr1 = np.array([1.2, 6.7, 3.4]) 365 arr2 = 
np.array([7, 11, 3]) 366 sample = SampleAttr( 367 SampleAttr( 368 SampleAttr(constant_op.constant(v1), constant_op.constant(v2)), 369 SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))), 370 {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)), 371 'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]}) 372 373 with session.Session() as sess: 374 result = sess.run(sample) 375 self.assertIsInstance(result, SampleAttr) 376 self.assertIsInstance(result.field0, SampleAttr) 377 self.assertIsInstance(result.field0.field0, SampleAttr) 378 self.assertIsInstance(result.field0.field1, SampleAttr) 379 self.assertIsInstance(result.field0.field1.field0, np.ndarray) 380 self.assertAllEqual(arr1, result.field0.field1.field0) 381 self.assertIsInstance(result.field0.field1.field1, np.ndarray) 382 self.assertAllEqual(arr2, result.field0.field1.field1) 383 self.assertIsInstance(result.field1, dict) 384 self.assertIn('A', result.field1) 385 self.assertIn('B', result.field1) 386 self.assertIsInstance(result.field1['A'], SampleAttr) 387 self.assertAllEqual( 388 [v3, v4], 389 [result.field1['A'].field0, result.field1['A'].field1]) 390 self.assertIsInstance(result.field1['B'], list) 391 self.assertEqual(1, len(result.field1['B'])) 392 self.assertIsInstance(result.field1['B'][0], SampleAttr) 393 self.assertAllEqual( 394 [v5, v6], 395 [result.field1['B'][0].field0, result.field1['B'][0].field1]) 396 397 def testFetchNestingEmptyOneLevel(self): 398 with session.Session() as sess: 399 a_val = 11.0 400 a = constant_op.constant(a_val) 401 402 res = sess.run([[], tuple(), {}]) 403 self.assertIsInstance(res, list) 404 self.assertEqual(3, len(res)) 405 self.assertIsInstance(res[0], list) 406 self.assertEqual(0, len(res[0])) 407 self.assertIsInstance(res[1], tuple) 408 self.assertEqual(0, len(res[1])) 409 self.assertIsInstance(res[2], dict) 410 self.assertEqual(0, len(res[2])) 411 412 res = sess.run([[], tuple(), {}, a]) 413 self.assertIsInstance(res, 
list) 414 self.assertEqual(4, len(res)) 415 self.assertIsInstance(res[0], list) 416 self.assertEqual(0, len(res[0])) 417 self.assertIsInstance(res[1], tuple) 418 self.assertEqual(0, len(res[1])) 419 self.assertIsInstance(res[2], dict) 420 self.assertEqual(0, len(res[2])) 421 self.assertEqual(a_val, res[3]) 422 423 def testFetchNestingOneLevel(self): 424 with session.Session() as sess: 425 # pylint: disable=invalid-name 426 ABC = collections.namedtuple('ABC', ['a', 'b', 'c']) 427 DEFGHI = collections.namedtuple('DEFGHI', ['d', 'e', 'f', 'g', 'h', 'i']) 428 # pylint: enable=invalid-name 429 a_val = 42.0 430 b_val = None 431 c_val = 44.0 432 a = constant_op.constant(a_val) 433 b = control_flow_ops.no_op() # An op, not a tensor. 434 c = constant_op.constant(c_val) 435 test_dct = {'a': a.name, 'c': c, 'b': b} 436 test_dct_types = [dict, frozendict, defaultdict] 437 # List of lists, tuples, namedtuple, dict, frozendict, and defaultdict 438 res = sess.run([ 439 [a, b, c], 440 (a, b, c), 441 ABC(a=a, b=b, c=c), 442 dict(test_dct), 443 frozendict(test_dct), 444 defaultdict(str, test_dct), 445 ]) 446 self.assertIsInstance(res, list) 447 self.assertEqual(6, len(res)) 448 self.assertIsInstance(res[0], list) 449 self.assertEqual(3, len(res[0])) 450 self.assertEqual(a_val, res[0][0]) 451 self.assertEqual(b_val, res[0][1]) 452 self.assertEqual(c_val, res[0][2]) 453 self.assertIsInstance(res[1], tuple) 454 self.assertEqual(3, len(res[1])) 455 self.assertEqual(a_val, res[1][0]) 456 self.assertEqual(b_val, res[1][1]) 457 self.assertEqual(c_val, res[1][2]) 458 self.assertIsInstance(res[2], ABC) 459 self.assertEqual(a_val, res[2].a) 460 self.assertEqual(b_val, res[2].b) 461 self.assertEqual(c_val, res[2].c) 462 for expected_type, r in zip(test_dct_types, res[3:]): 463 self.assertIsInstance(r, expected_type) 464 self.assertEqual(3, len(r)) 465 self.assertEqual(a_val, r['a']) 466 self.assertEqual(b_val, r['b']) 467 self.assertEqual(c_val, r['c']) 468 
self.assertEqual(res[5].default_factory, str) 469 # Tuple of lists, tuples, namedtuple, dict, frozendict, and defaultdict 470 res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b, 471 c=c), dict(test_dct), 472 frozendict(test_dct), defaultdict(str, test_dct))) 473 self.assertIsInstance(res, tuple) 474 self.assertEqual(6, len(res)) 475 self.assertIsInstance(res[0], list) 476 self.assertEqual(3, len(res[0])) 477 self.assertEqual(a_val, res[0][0]) 478 self.assertEqual(b_val, res[0][1]) 479 self.assertEqual(c_val, res[0][2]) 480 self.assertIsInstance(res[1], tuple) 481 self.assertEqual(3, len(res[1])) 482 self.assertEqual(a_val, res[1][0]) 483 self.assertEqual(b_val, res[1][1]) 484 self.assertEqual(c_val, res[1][2]) 485 self.assertIsInstance(res[2], ABC) 486 self.assertEqual(a_val, res[2].a) 487 self.assertEqual(b_val, res[2].b) 488 self.assertEqual(c_val, res[2].c) 489 for expected_type, r in zip(test_dct_types, res[3:]): 490 self.assertIsInstance(r, expected_type) 491 self.assertEqual(3, len(r)) 492 self.assertEqual(a_val, r['a']) 493 self.assertEqual(b_val, r['b']) 494 self.assertEqual(c_val, r['c']) 495 self.assertEqual(res[5].default_factory, str) 496 497 # Namedtuple of lists, tuples, namedtuples, dict, frozendict, defaultdict 498 res = sess.run( 499 DEFGHI( 500 d=[a, b, c], 501 e=(a, b, c), 502 f=ABC(a=a.name, b=b, c=c), 503 g=dict(test_dct), 504 h=frozendict(test_dct), 505 i=defaultdict(str, test_dct))) 506 self.assertIsInstance(res, DEFGHI) 507 self.assertIsInstance(res.d, list) 508 self.assertEqual(3, len(res.d)) 509 self.assertEqual(a_val, res.d[0]) 510 self.assertEqual(b_val, res.d[1]) 511 self.assertEqual(c_val, res.d[2]) 512 self.assertIsInstance(res.e, tuple) 513 self.assertEqual(3, len(res.e)) 514 self.assertEqual(a_val, res.e[0]) 515 self.assertEqual(b_val, res.e[1]) 516 self.assertEqual(c_val, res.e[2]) 517 self.assertIsInstance(res.f, ABC) 518 self.assertEqual(a_val, res.f.a) 519 self.assertEqual(b_val, res.f.b) 520 self.assertEqual(c_val, 
res.f.c) 521 self.assertIsInstance(res.g, dict) 522 self.assertEqual(3, len(res.g)) 523 self.assertEqual(a_val, res.g['a']) 524 self.assertEqual(b_val, res.g['b']) 525 self.assertEqual(c_val, res.g['c']) 526 self.assertIsInstance(res.h, frozendict) 527 self.assertEqual(3, len(res.h)) 528 self.assertEqual(a_val, res.h['a']) 529 self.assertEqual(b_val, res.h['b']) 530 self.assertEqual(c_val, res.h['c']) 531 self.assertIsInstance(res.i, defaultdict) 532 self.assertEqual(3, len(res.i)) 533 self.assertEqual(a_val, res.i['a']) 534 self.assertEqual(b_val, res.i['b']) 535 self.assertEqual(c_val, res.i['c']) 536 self.assertEqual(res.i.default_factory, str) 537 # Dict of lists, tuples, namedtuples, dict, frozendict, defaultdict 538 res = sess.run({ 539 'd': [a, b, c], 540 'e': (a, b, c), 541 'f': ABC(a=a, b=b, c=c), 542 'g': dict(test_dct), 543 'h': frozendict(test_dct), 544 'i': defaultdict(str, test_dct), 545 }) 546 self.assertIsInstance(res, dict) 547 self.assertEqual(6, len(res)) 548 self.assertIsInstance(res['d'], list) 549 self.assertEqual(3, len(res['d'])) 550 self.assertEqual(a_val, res['d'][0]) 551 self.assertEqual(b_val, res['d'][1]) 552 self.assertEqual(c_val, res['d'][2]) 553 self.assertIsInstance(res['e'], tuple) 554 self.assertEqual(3, len(res['e'])) 555 self.assertEqual(a_val, res['e'][0]) 556 self.assertEqual(b_val, res['e'][1]) 557 self.assertEqual(c_val, res['e'][2]) 558 self.assertIsInstance(res['f'], ABC) 559 self.assertEqual(a_val, res['f'].a) 560 self.assertEqual(b_val, res['f'].b) 561 self.assertEqual(c_val, res['f'].c) 562 for expected_type, r_key in zip(test_dct_types, ('g', 'h', 'i')): 563 r = res[r_key] 564 self.assertIsInstance(r, expected_type) 565 self.assertEqual(3, len(r)) 566 self.assertEqual(a_val, r['a']) 567 self.assertEqual(b_val, r['b']) 568 self.assertEqual(c_val, r['c']) 569 self.assertEqual(res['i'].default_factory, str) 570 571 def testFetchTensorObject(self): 572 with session.Session() as s: 573 a = constant_op.constant(1.0, 
shape=[1, 2]) 574 b = constant_op.constant(2.0, shape=[2, 3]) 575 c = math_ops.matmul(a, b) 576 results_with_list = s.run([c]) 577 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0]) 578 results_with_single = s.run(c) 579 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single) 580 results_with_get = c.eval() 581 self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get) 582 a_val, b_val = s.run([a, b]) # Test multiple fetches. 583 self.assertAllEqual([[1.0, 1.0]], a_val) 584 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val) 585 results_with_dict = s.run({'a': [a], 'b': b, 'z': [a, b]}) 586 self.assertAllEqual([[1.0, 1.0]], results_with_dict['a'][0]) 587 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], 588 results_with_dict['b']) 589 self.assertAllEqual(results_with_dict['a'][0], results_with_dict['z'][0]) 590 self.assertAllEqual(results_with_dict['b'], results_with_dict['z'][1]) 591 592 # Test nested structures 593 results_with_nested_list = s.run([[[a, b], b], a, [a, b]]) 594 self.assertAllEqual([[1.0, 1.0]], results_with_nested_list[0][0][0]) 595 self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], 596 results_with_nested_list[0][0][1]) 597 self.assertAllEqual(results_with_nested_list[0][0][0], 598 results_with_nested_list[1]) 599 self.assertAllEqual(results_with_nested_list[1], 600 results_with_nested_list[2][0]) 601 self.assertAllEqual(results_with_nested_list[0][0][1], 602 results_with_nested_list[0][1]) 603 self.assertAllEqual(results_with_nested_list[0][1], 604 results_with_nested_list[2][1]) 605 606 def testFetchScalar(self): 607 with session.Session() as s: 608 for scalar in np.int32, np.int64, np.float16, np.float32, np.float64: 609 x = scalar(7) 610 y = scalar(8) 611 tf_x = constant_op.constant(x, shape=[]) 612 tf_y = constant_op.constant(y) 613 tf_xy = math_ops.add(tf_x, tf_y) 614 # Single fetch 615 xy = s.run(tf_xy) 616 self.assertEqual(scalar, type(xy)) 617 self.assertEqual(x + y, xy) 618 # List fetch 619 xy, = 
s.run([tf_xy]) 620 self.assertEqual(scalar, type(xy)) 621 self.assertEqual(x + y, xy) 622 # Dict fetch 623 xy = s.run({'xy': tf_xy})['xy'] 624 self.assertEqual(scalar, type(xy)) 625 self.assertEqual(x + y, xy) 626 # Nested list fetch 627 xy = s.run([[[tf_xy]], tf_xy, [tf_xy]]) 628 self.assertAllEqual(xy, [[[x + y]], x + y, [x + y]]) 629 self.assertEqual(scalar, type(xy[0][0][0])) 630 self.assertEqual(scalar, type(xy[1])) 631 self.assertEqual(scalar, type(xy[2][0])) 632 633 def testFetchOperationObject(self): 634 with session.Session() as s: 635 a = constant_op.constant(1.0, shape=[1, 2]) 636 v = variables.Variable(a, name='testFetchOperationObject_v') 637 s.run(v.initializer) 638 v_val = s.run(v) 639 self.assertAllEqual([[1.0, 1.0]], v_val) 640 641 def testFetchSparseTensor(self): 642 with session.Session() as s: 643 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 644 values = np.array([1.0, 2.0]).astype(np.float32) 645 shape = np.array([7, 9, 2]).astype(np.int64) 646 sp = sparse_tensor.SparseTensor( 647 constant_op.constant(indices), constant_op.constant(values), 648 constant_op.constant(shape)) 649 # Single fetch, use as tuple 650 sp_out = s.run(sp) 651 indices_out, values_out, shape_out = sp_out 652 self.assertAllEqual(indices_out, indices) 653 self.assertAllEqual(values_out, values) 654 self.assertAllEqual(shape_out, shape) 655 # Single fetch, use as SparseTensorValue 656 sp_out = s.run(sp) 657 self.assertAllEqual(sp_out.indices, indices) 658 self.assertAllEqual(sp_out.values, values) 659 self.assertAllEqual(sp_out.dense_shape, shape) 660 # Tuple fetch, use as tuple 661 indices_out, values_out, shape_out = s.run(sp) 662 self.assertAllEqual(indices_out, indices) 663 self.assertAllEqual(values_out, values) 664 self.assertAllEqual(shape_out, shape) 665 # List fetch, use as tuple 666 (indices_out, values_out, shape_out), = s.run([sp]) 667 self.assertAllEqual(indices_out, indices) 668 self.assertAllEqual(values_out, values) 669 
self.assertAllEqual(shape_out, shape) 670 # List fetch, use as SparseTensorValue 671 sp_out, = s.run([sp]) 672 self.assertAllEqual(sp_out.indices, indices) 673 self.assertAllEqual(sp_out.values, values) 674 self.assertAllEqual(sp_out.dense_shape, shape) 675 # Dict fetch (single value), use as tuple 676 indices_out, values_out, shape_out = s.run({'sp': sp})['sp'] 677 self.assertAllEqual(indices_out, indices) 678 self.assertAllEqual(values_out, values) 679 self.assertAllEqual(shape_out, shape) 680 # Dict fetch (list value), use as tuple 681 (indices_out, values_out, shape_out), = s.run({'sp': [sp]})['sp'] 682 self.assertAllEqual(indices_out, indices) 683 self.assertAllEqual(values_out, values) 684 self.assertAllEqual(shape_out, shape) 685 # Dict fetch, use as SparseTensorValue 686 sp_out = s.run({'sp': sp})['sp'] 687 self.assertAllEqual(sp_out.indices, indices) 688 self.assertAllEqual(sp_out.values, values) 689 self.assertAllEqual(sp_out.dense_shape, shape) 690 # Nested list fetch use as tuple 691 sp_out = s.run([[[sp]], sp]) 692 indices_out, values_out, shape_out = sp_out[0][0][0] 693 self.assertAllEqual(indices_out, indices) 694 self.assertAllEqual(values_out, values) 695 self.assertAllEqual(shape_out, shape) 696 indices_out, values_out, shape_out = sp_out[1] 697 self.assertAllEqual(indices_out, indices) 698 self.assertAllEqual(values_out, values) 699 self.assertAllEqual(shape_out, shape) 700 # Nested list fetch, use as SparseTensorValue 701 sp_out = s.run([[[sp]], sp]) 702 self.assertAllEqual(sp_out[0][0][0].indices, indices) 703 self.assertAllEqual(sp_out[0][0][0].values, values) 704 self.assertAllEqual(sp_out[0][0][0].dense_shape, shape) 705 self.assertAllEqual(sp_out[1].indices, indices) 706 self.assertAllEqual(sp_out[1].values, values) 707 self.assertAllEqual(sp_out[1].dense_shape, shape) 708 709 def testFeedSparseTensor(self): 710 with session.Session() as s: 711 indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64) 712 values = np.array([1.0, 
2.0]).astype(np.float32) 713 shape = np.array([7, 9, 2]).astype(np.int64) 714 sp = sparse_tensor.SparseTensor( 715 array_ops.placeholder(dtype=np.int64, shape=(2, 3)), 716 array_ops.placeholder(dtype=np.float32, shape=(2,)), 717 array_ops.placeholder(dtype=np.int64, shape=(3,)), 718 ) 719 sp_indices = array_ops.identity(sp.indices) 720 sp_values = array_ops.identity(sp.values) 721 sp_shape = array_ops.identity(sp.dense_shape) 722 sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape) 723 # Feed with tuple 724 indices_out, values_out, shape_out = s.run( 725 [sp_indices, sp_values, sp_shape], { 726 sp: (indices, values, shape) 727 }) 728 self.assertAllEqual(indices_out, indices) 729 self.assertAllEqual(values_out, values) 730 self.assertAllEqual(shape_out, shape) 731 # Feed with tuple, fetch sp directly 732 sp_out = s.run(sp, {sp: (indices, values, shape)}) 733 self.assertAllEqual(sp_out.indices, indices) 734 self.assertAllEqual(sp_out.values, values) 735 self.assertAllEqual(sp_out.dense_shape, shape) 736 # Feed with SparseTensorValue 737 indices_out, values_out, shape_out = s.run( 738 [sp_indices, sp_values, sp_shape], { 739 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 740 }) 741 self.assertAllEqual(indices_out, indices) 742 self.assertAllEqual(values_out, values) 743 self.assertAllEqual(shape_out, shape) 744 # Feed with SparseTensorValue, fetch SparseTensorValue 745 sp2_out = s.run(sp2, { 746 sp: sparse_tensor.SparseTensorValue(indices, values, shape) 747 }) 748 self.assertAllEqual(sp2_out.indices, indices) 749 self.assertAllEqual(sp2_out.values, values) 750 self.assertAllEqual(sp2_out.dense_shape, shape) 751 # Feed SparseTensorValue and fetch sp directly. 
# (cont.) Preceding sparse-tensor test: feed a SparseTensorValue and check that
# the fetched value's indices, values and dense_shape round-trip.
      sp_out = s.run(sp, {
          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
      })
      self.assertAllEqual(sp_out.indices, indices)
      self.assertAllEqual(sp_out.values, values)
      self.assertAllEqual(sp_out.dense_shape, shape)

  def testFeedSparsePlaceholder(self):
    """Feeds a sparse_placeholder with a tuple and with a SparseTensorValue."""
    with session.Session() as s:
      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
      values = np.array([1.0, 2.0]).astype(np.float32)
      shape = np.array([7, 9, 2]).astype(np.int64)
      sp = array_ops.sparse_placeholder(dtype=np.float32, name='placeholder1')
      sp_indices = array_ops.identity(sp.indices)
      sp_values = array_ops.identity(sp.values)
      sp_shape = array_ops.identity(sp.dense_shape)
      sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
      # Feed with tuple
      indices_out, values_out, shape_out = s.run(
          [sp_indices, sp_values, sp_shape], {
              sp: (indices, values, shape)
          })
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(shape_out, shape)
      # Feed with SparseTensorValue
      indices_out, values_out, shape_out = s.run(
          [sp_indices, sp_values, sp_shape], {
              sp: sparse_tensor.SparseTensorValue(indices, values, shape)
          })
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(shape_out, shape)
      # Feed with SparseTensorValue, fetch SparseTensorValue
      sp2_out = s.run(sp2, {
          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
      })
      self.assertAllEqual(sp2_out.indices, indices)
      self.assertAllEqual(sp2_out.values, values)
      self.assertAllEqual(sp2_out.dense_shape, shape)

  def testFeedSparsePlaceholderPartialShape(self):
    """Feeds a sparse_placeholder declared with partial shape [None, 9, 2]."""
    with session.Session() as s:
      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
      values = np.array([1.0, 2.0]).astype(np.float32)
      shape = np.array([7, 9, 2]).astype(np.int64)
      sp = array_ops.sparse_placeholder(
          shape=[None, 9, 2], dtype=np.float32, name='placeholder1')
      sp_indices = array_ops.identity(sp.indices)
      sp_values = array_ops.identity(sp.values)
      sp_shape = array_ops.identity(sp.dense_shape)
      sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
      # Feed with tuple
      indices_out, values_out, shape_out = s.run(
          [sp_indices, sp_values, sp_shape], {
              sp: (indices, values, shape)
          })
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(shape_out, shape)
      # Feed with SparseTensorValue
      indices_out, values_out, shape_out = s.run(
          [sp_indices, sp_values, sp_shape], {
              sp: sparse_tensor.SparseTensorValue(indices, values, shape)
          })
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(shape_out, shape)
      # Feed with SparseTensorValue, fetch SparseTensorValue
      sp2_out = s.run(sp2, {
          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
      })
      self.assertAllEqual(sp2_out.indices, indices)
      self.assertAllEqual(sp2_out.values, values)
      self.assertAllEqual(sp2_out.dense_shape, shape)

  def testFeedSparsePlaceholderConstantShape(self):
    """Feeds only (indices, values) when the placeholder shape is fully known."""
    with session.Session() as s:
      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
      values = np.array([1.0, 2.0]).astype(np.float32)
      shape = np.array([7, 9, 2]).astype(np.int64)
      sp = array_ops.sparse_placeholder(
          dtype=np.float32, shape=shape, name='placeholder1')
      # With a constant shape, dense_shape can be evaluated without a feed.
      self.assertAllEqual(sp.dense_shape.eval(session=s), shape)
      self.assertAllEqual(tensor_util.constant_value(sp.shape), shape)
      sp_indices = array_ops.identity(sp.indices)
      sp_values = array_ops.identity(sp.values)
      sp_shape = array_ops.identity(sp.dense_shape)
      # Feed with tuple
      indices_out, values_out, shape_out = s.run(
          [sp_indices, sp_values, sp_shape], {
              sp: (indices, values)
          })
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(shape_out, shape)

  def testFetchIndexedSlices(self):
    """Fetches an IndexedSlices both as a tuple and as an IndexedSlicesValue."""
    with session.Session() as s:
      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
      values = np.array([1.0, 2.0]).astype(np.float32)
      dense_shape = np.array([7, 9, 2]).astype(np.int64)
      ind = ops.IndexedSlices(
          constant_op.constant(values), constant_op.constant(indices),
          constant_op.constant(dense_shape))
      # Single fetch, use as tuple
      ind_out = s.run(ind)
      values_out, indices_out, dense_shape_out = ind_out
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(dense_shape_out, dense_shape)
      # Single fetch, use as IndexedSlicesValue
      ind_out = s.run(ind)
      self.assertAllEqual(ind_out.values, values)
      self.assertAllEqual(ind_out.indices, indices)
      self.assertAllEqual(ind_out.dense_shape, dense_shape)
      # Tuple fetch, use as tuple
      values_out, indices_out, dense_shape_out = s.run(ind)
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(dense_shape_out, dense_shape)
      # List fetch, use as tuple
      (values_out, indices_out, dense_shape_out), = s.run([ind])
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(dense_shape_out, dense_shape)
      # List fetch, use as IndexedSlicesValue
      ind_out, = s.run([ind])
      self.assertAllEqual(ind_out.values, values)
      self.assertAllEqual(ind_out.indices, indices)
      self.assertAllEqual(ind_out.dense_shape, dense_shape)

  def testFeedIndexedSlices(self):
    """Feeds an IndexedSlices whose three components are placeholders."""
    with session.Session() as s:
      values = np.array([1.0, 2.0]).astype(np.float32)
      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
      dense_shape = np.array([7, 9, 2]).astype(np.int64)
      ind = ops.IndexedSlices(
          array_ops.placeholder(dtype=np.float32, shape=(2,)),
          array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
          array_ops.placeholder(dtype=np.int64, shape=(3,)),
      )
      ind_values = array_ops.identity(ind.values)
      ind_indices = array_ops.identity(ind.indices)
      ind_dense_shape = array_ops.identity(ind.dense_shape)
      ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape)
      # Feed with tuple
      values_out, indices_out, dense_shape_out = s.run(
          [ind_values, ind_indices, ind_dense_shape], {
              ind: (values, indices, dense_shape)
          })
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(dense_shape_out, dense_shape)
      # Feed with IndexedSlicesValue
      values_out, indices_out, dense_shape_out = s.run(
          [ind_values, ind_indices, ind_dense_shape], {
              ind: ops.IndexedSlicesValue(values, indices, dense_shape)
          })
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(dense_shape_out, dense_shape)
      # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
      ind2_out = s.run(ind2, {
          ind: ops.IndexedSlicesValue(values, indices, dense_shape)
      })
      self.assertAllEqual(ind2_out.values, values)
      self.assertAllEqual(ind2_out.indices, indices)
      self.assertAllEqual(ind2_out.dense_shape, dense_shape)

  def testFetchIndexedSlicesWithoutDenseShape(self):
    """Fetches an IndexedSlices whose dense_shape is None."""
    with session.Session() as s:
      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
      values = np.array([1.0, 2.0]).astype(np.float32)
      dense_shape = None
      ind = ops.IndexedSlices(
          constant_op.constant(values), constant_op.constant(indices), None)
      # Single fetch, use as tuple
      ind_out = s.run(ind)
      values_out, indices_out, dense_shape_out = ind_out
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(dense_shape_out, dense_shape)
      # Single fetch, use as IndexedSlicesValue
      ind_out = s.run(ind)
# (cont.) testFetchIndexedSlicesWithoutDenseShape: every fetch form is expected
# to report dense_shape as None.
      self.assertAllEqual(ind_out.values, values)
      self.assertAllEqual(ind_out.indices, indices)
      self.assertAllEqual(ind_out.dense_shape, dense_shape)
      # Tuple fetch, use as tuple
      values_out, indices_out, dense_shape_out = s.run(ind)
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(dense_shape_out, dense_shape)
      # List fetch, use as tuple
      (values_out, indices_out, dense_shape_out), = s.run([ind])
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(indices_out, indices)
      self.assertAllEqual(dense_shape_out, dense_shape)
      # List fetch, use as IndexedSlicesValue
      ind_out, = s.run([ind])
      self.assertAllEqual(ind_out.values, values)
      self.assertAllEqual(ind_out.indices, indices)
      self.assertAllEqual(ind_out.dense_shape, dense_shape)

  def testFeedIndexedSlicesWithoutDenseShape(self):
    """Feeds an IndexedSlices (values/indices placeholders, dense_shape None)."""
    with session.Session() as s:
      values = np.array([1.0, 2.0]).astype(np.float32)
      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
      dense_shape = None
      ind = ops.IndexedSlices(
          array_ops.placeholder(dtype=np.float32, shape=(2,)),
          array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None)
      ind_values = array_ops.identity(ind.values)
      ind_indices = array_ops.identity(ind.indices)
      ind2 = ops.IndexedSlices(ind_values, ind_indices)
      # Feed with tuple
      values_out, indices_out = s.run([ind_values, ind_indices], {
          ind: (values, indices)
      })
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(indices_out, indices)
      # Feed with IndexedSlicesValue
      values_out, indices_out = s.run([ind_values, ind_indices], {
          ind: ops.IndexedSlicesValue(values, indices, dense_shape)
      })
      self.assertAllEqual(values_out, values)
      self.assertAllEqual(indices_out, indices)
      # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
      ind2_out = s.run(ind2, {
          ind: ops.IndexedSlicesValue(values, indices, dense_shape)
      })
      self.assertAllEqual(ind2_out.values, values)
      self.assertAllEqual(ind2_out.indices, indices)
      self.assertAllEqual(ind2_out.dense_shape, dense_shape)

  def testExtendWithStatelessOperations(self):
    """Runs stateless ops that were added after the session already ran."""
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      c_val = s.run(c)
      self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
      d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
      e = math_ops.matmul(c, d)
      # Extend will happen here.
      e_val = s.run(e)
      self.assertAllEqual([[24.0]], e_val)

  def testExtendWithStatefulOperations(self):
    """Adds an assign op after a variable has been initialized and read.

    The variable keeps its value until the new assign op is actually run.
    """
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='testExtendWithStatefulOperations_v')
      v.initializer.run()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      # Extend will happen here.
      e_val = e.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)

  def testExtendWithGroupBy(self):
    """Groups initializers of variables created before and after an Extend."""
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      p = variables.Variable(a, name='testExtendWithGroupBy_p')
      a_val = a.eval()  # Force an Extend after this op.
      self.assertAllEqual([[1.0, 1.0]], a_val)

      b = constant_op.constant(2.0, shape=[1, 2])
      q = variables.Variable(b, name='testExtendWithGroupBy_q')
      # Extend will happen here.
# (cont.) testExtendWithGroupBy: run the grouped initializers, then read both
# variables back.
      init = control_flow_ops.group(p.initializer, q.initializer)
      s.run(init)
      p_val, q_val = s.run([p, q])

      self.assertAllEqual([[1.0, 1.0]], p_val)
      self.assertAllEqual([[2.0, 2.0]], q_val)

  def testTensorGetMethod(self):
    """Checks Tensor.eval(), with and without a feed keyed by tensor name."""
    with session.Session():
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)

      c_val = c.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)

      fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]})
      self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val)

  @test_util.run_v1_only('b/120545219')
  def testOperationRunMethod(self):
    """Evaluates assign ops, including one fed through the name 'b:0'."""
    with session.Session():
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[1, 2], name='b')
      v = variables.VariableV1(a, a.dtype)
      assign_a_to_v = state_ops.assign(v, a)

      assign_a_to_v.eval()

      v_val = v.eval()
      self.assertAllEqual([[1.0, 1.0]], v_val)

      assign_b_to_v = state_ops.assign(v, b)

      assign_b_to_v.eval()
      v_val = v.eval()
      self.assertAllEqual([[2.0, 2.0]], v_val)

      # Feeding 'b:0' replaces b's value in this run, so v becomes [[3, 3]].
      assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]})
      v_val = v.eval()
      self.assertAllEqual([[3.0, 3.0]], v_val)

  def testDefaultGraph(self):
    """Checks that the session's graph is the default graph while entered."""
    with session.Session() as s:
      self.assertEqual(ops.get_default_graph(), s.graph)
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      self.assertEqual(ops.get_default_graph(), a.graph)
      self.assertEqual(ops.get_default_graph(), b.graph)
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='testDefaultGraph_v')
      v.initializer.run()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      e_val = e.eval()
# (cont.) testDefaultGraph: v only changes once assign_e_to_v is run.
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
      self.assertEqual(ops.get_default_graph(), s.graph)

  def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
    """Builds and runs a small variable graph in the calling thread.

    Args:
      constructed_event: Event that is `set()` once the graph has been built.
      continue_event: Event that is `wait()`ed on before anything is run.
      i: Integer used to build the variable name 'var_%d'.
    """
    with session.Session() as s:
      self.assertEqual(ops.get_default_graph(), s.graph)
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='var_%d' % i)

      # Block here until all threads have constructed their graph.
      constructed_event.set()
      continue_event.wait()

      assign_c_to_v = state_ops.assign(v, c)
      v.initializer.run()
      assign_c_to_v.eval()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      e_val = e.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
      self.assertEqual(ops.get_default_graph(), s.graph)

  def testDefaultGraphWithThreads(self):
    """Runs _testDefaultGraphInThread concurrently in ten threads."""
    # Fork ten threads that use their thread-local default graph.
# (cont.) testDefaultGraphWithThreads: start every thread, wait until each has
# built its graph, then release them all together.
    threads = []
    constructed_events = [threading.Event() for _ in range(10)]
    continue_event = threading.Event()
    for i, constructed_event in enumerate(constructed_events):
      t = self.checkedThread(
          target=self._testDefaultGraphInThread,
          args=(constructed_event, continue_event, i))
      threads.append(t)
    for t in threads:
      t.start()
    for constructed_event in constructed_events:
      constructed_event.wait()
    continue_event.set()
    for t in threads:
      t.join()

  def testParallelRun(self):
    """Evaluates one constant from 100 threads sharing a single session."""
    with session.Session() as sess:
      c = constant_op.constant(5.0)
      ev = threading.Event()

      def run_step():
        ev.wait()
        val = c.eval(session=sess)
        self.assertEqual(val, 5.0)

      threads = [self.checkedThread(target=run_step) for _ in range(100)]
      for t in threads:
        t.start()
      # Release all waiting threads at once.
      ev.set()
      for t in threads:
        t.join()

  @staticmethod
  def _build_graph():
    """Adds while-loops/gradients to the default graph and re-imports it.

    Used by the parallel run/build tests as a graph-mutating workload.
    """
    time.sleep(random.random() * 0.1)
    # Do some graph construction. Try to exercise non-trivial paths.
# (cont.) _build_graph: each iteration adds a while_loop plus its gradients,
# then either snapshots the GraphDef (first pass) or re-imports that snapshot.
    graph = ops.get_default_graph()
    gdef = None
    for _ in range(10):
      x = array_ops.placeholder(dtype=dtypes.float32)
      with ops.colocate_with(x):
        y = array_ops.placeholder(dtype=dtypes.float32)
      with ops.device('/cpu:0'):
        z = control_flow_ops.while_loop(
            lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
      with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
        gradients_impl.gradients(z, [x, y])
      if gdef is None:
        gdef = graph.as_graph_def()
      else:
        importer.import_graph_def(gdef, name='import')

  @test_util.run_v1_only('b/120545219')
  def testParallelRunAndSingleBuild(self):
    """Runs a fetch in ten threads while the main thread extends the graph."""
    with session.Session() as sess:
      c = constant_op.constant(5.0)
      stop = threading.Event()

      def run_loop():
        while not stop.is_set():
          time.sleep(random.random() * 0.1)
          self.assertEqual(sess.run(c), 5.0)

      threads = [self.checkedThread(target=run_loop) for _ in range(10)]
      for t in threads:
        t.start()

      SessionTest._build_graph()

      stop.set()
      for t in threads:
        t.join()

  @test_util.run_v1_only('b/120545219')
  def testParallelRunAndParallelBuild(self):
    """Runs a fetch in ten threads while ten other threads extend the graph."""
    with session.Session() as sess:
      c = constant_op.constant(5.0)
      stop = threading.Event()

      def run_loop():
        while not stop.is_set():
          time.sleep(random.random() * 0.1)
          self.assertEqual(sess.run(c), 5.0)

      run_threads = [self.checkedThread(target=run_loop) for _ in range(10)]
      for t in run_threads:
        t.start()

      build_threads = [self.checkedThread(target=SessionTest._build_graph)
                       for _ in range(10)]
      for t in build_threads:
        t.start()
      for t in build_threads:
        t.join()

      # Let the run_threads run until the build threads are finished.
1225 stop.set() 1226 for t in run_threads: 1227 t.join() 1228 1229 def testRunFeedDict(self): 1230 with session.Session() as s: 1231 x = array_ops.zeros([2]) 1232 1233 y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)}) 1234 self.assertAllEqual(y, 2 * np.ones(2)) 1235 1236 y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)}) 1237 self.assertAllEqual(y, 2 * np.ones(2)) 1238 1239 y = s.run(2 * x, feed_dict={x: [1, 1]}) 1240 assert (y == 2 * np.ones(2)).all() 1241 1242 # Test nested tuple keys 1243 z = (((array_ops.zeros([2]),),), array_ops.zeros([2]), 1244 (array_ops.zeros([2]),)) 1245 result = [z[0][0][0] * 2, z[1] * 2, z[2][0] * 2] 1246 values = (((np.array([1, 1]),),), np.array([2, 2]), (np.array([3, 3]),)) 1247 result_value = s.run(result, feed_dict={z: values}) 1248 self.assertAllEqual(result_value[0], 2 * np.ones(2)) 1249 self.assertAllEqual(result_value[1], 2 * np.array([2, 2])) 1250 self.assertAllEqual(result_value[2], 2 * np.array([3, 3])) 1251 1252 def testGraphDef(self): 1253 with session.Session() as sess: 1254 self.assertProtoEquals('versions { producer: %d min_consumer: %d }' % 1255 (versions.GRAPH_DEF_VERSION, 1256 versions.GRAPH_DEF_VERSION_MIN_CONSUMER), 1257 sess.graph_def) 1258 c = constant_op.constant(5.0, name='c') 1259 self.assertEqual(len(sess.graph_def.node), 1) 1260 d = constant_op.constant(6.0, name='d') 1261 self.assertEqual(len(sess.graph_def.node), 2) 1262 self.assertAllEqual(c, 5.0) 1263 self.assertAllEqual(d, 6.0) 1264 e = constant_op.constant(7.0, name='e') 1265 self.assertEqual(len(sess.graph_def.node), 3) 1266 self.assertAllEqual(e, 7.0) 1267 1268 def testUseAfterClose(self): 1269 with session.Session() as sess: 1270 c = constant_op.constant(5.0) 1271 self.assertAllEqual(sess.run(c), 5.0) 1272 with self.assertRaisesWithPredicateMatch( 1273 RuntimeError, lambda e: 'Attempted to use a closed Session.' 
in str(e)): 1274 sess.run(c) 1275 1276 def testUseAfterCloseConcurrent(self): 1277 with session.Session() as sess: 1278 c = constant_op.constant(5.0) 1279 self.assertAllEqual(sess.run(c), 5.0) 1280 1281 def update_thread(): 1282 with self.assertRaisesWithPredicateMatch( 1283 RuntimeError, 1284 lambda e: 'Attempted to use a closed Session.' in str(e)): 1285 while True: 1286 sess.run(c) 1287 1288 t = threading.Thread(target=update_thread) 1289 t.start() 1290 time.sleep(0.1) 1291 sess.close() 1292 t.join() 1293 1294 def testUseEmptyGraph(self): 1295 with session.Session() as sess: 1296 with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'): 1297 sess.run([]) 1298 with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'): 1299 sess.run(()) 1300 with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'): 1301 sess.run({}) 1302 1303 @test_util.run_v1_only('b/120545219') 1304 def testNotEntered(self): 1305 # pylint: disable=protected-access 1306 self.assertIsNone(ops._default_session_stack.get_default()) 1307 # pylint: enable=protected-access 1308 with ops.device('/cpu:0'): 1309 sess = session.Session() 1310 c_1 = constant_op.constant(5.0) 1311 with sess.graph.as_default(): 1312 c_2 = constant_op.constant(5.0) 1313 self.assertEqual(c_1.graph, c_2.graph) 1314 self.assertEqual(sess.run(c_2), 5.0) 1315 with self.assertRaisesWithPredicateMatch( 1316 ValueError, lambda e: 'No default session is registered.' 
in str(e)): 1317 c_2.eval() 1318 1319 @test_util.run_v1_only('b/120545219') 1320 def testInteractive(self): 1321 with ops.device('/cpu:0'): 1322 sess = session.InteractiveSession() 1323 a = constant_op.constant(1.0, shape=[1, 2]) 1324 b = constant_op.constant(2.0, shape=[2, 3]) 1325 c = math_ops.matmul(a, b) 1326 self.assertAllEqual([[4.0, 4.0, 4.0]], c) 1327 d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1]) 1328 e = math_ops.matmul(c, d) 1329 self.assertAllEqual([[24.0]], e) 1330 sess.close() 1331 1332 @test_util.run_v1_only('b/120545219') 1333 def testMultipleInteractiveSessionsWarning(self): 1334 # Reinitialize the global state to ensure that the expected warnings will 1335 # be emitted. 1336 session.InteractiveSession._active_session_count = 0 # pylint: disable=protected-access 1337 1338 sess = session.InteractiveSession() 1339 sess.run(constant_op.constant(4.0)) # Run so that the session is "opened". 1340 sess.close() 1341 # Opening and closing interactive sessions serially should not warn. 1342 with warnings.catch_warnings(record=True) as w: 1343 sess = session.InteractiveSession() 1344 sess.close() 1345 self.assertEqual(0, len(w)) 1346 1347 with warnings.catch_warnings(record=True) as w: 1348 sess = session.InteractiveSession() 1349 self.assertEqual(0, len(w)) 1350 with warnings.catch_warnings(record=True) as w: 1351 sess2 = session.InteractiveSession() 1352 self.assertEqual(1, len(w)) 1353 self.assertIn('An interactive session is already active. This can cause ' 1354 'out-of-memory errors in some cases. You must explicitly ' 1355 'call `InteractiveSession.close()` to release resources ' 1356 'held by the other session(s).', str(w[0].message)) 1357 sess2.close() 1358 sess.close() 1359 1360 @test_util.run_v1_only('b/120545219') 1361 def testInteractivePlacePrunedGraph(self): 1362 sess = session.InteractiveSession() 1363 1364 # Build a graph that has a bad op in it (no kernel). 
1365 # 1366 # This test currently does not link in any GPU kernels, 1367 # which is why placing this is invalid. If at some point 1368 # GPU kernels are added to this test, some other different 1369 # op / device combo should be chosen. 1370 with ops.device('/device:GPU:0'): 1371 a = constant_op.constant(1.0, shape=[1, 2]) 1372 1373 b = constant_op.constant(1.0, shape=[1, 2]) 1374 1375 # Only run the valid op, this should work. 1376 b.eval() 1377 1378 with self.assertRaises(errors.InvalidArgumentError): 1379 a.eval() 1380 sess.close() 1381 1382 @test_util.run_v1_only('b/120545219') 1383 def testDefaultSessionPlacePrunedGraph(self): 1384 sess = session.Session() 1385 1386 # Build a graph that has a bad op in it (no kernel). 1387 # 1388 # This test currently does not link in any GPU kernels, 1389 # which is why placing this is invalid. If at some point 1390 # GPU kernels are added to this test, some other different 1391 # op / device combo should be chosen. 1392 with ops.device('/device:GPU:0'): 1393 _ = constant_op.constant(1.0, shape=[1, 2]) 1394 1395 b = constant_op.constant(1.0, shape=[1, 2]) 1396 1397 with self.assertRaises(errors.InvalidArgumentError): 1398 # Even though we don't run the bad op, we place the entire 1399 # graph, which should fail with a non-interactive session. 
1400 sess.run(b) 1401 1402 sess.close() 1403 1404 def testSharedGraph(self): 1405 with ops.Graph().as_default() as g, ops.device('/cpu:0'): 1406 a = constant_op.constant(1.0, shape=[1, 2]) 1407 b = constant_op.constant(2.0, shape=[2, 3]) 1408 c = math_ops.matmul(a, b) 1409 1410 with session.Session(graph=g) as sess1: 1411 with session.Session(graph=g) as sess2: 1412 self.assertAllEqual(sess1.run(c), sess2.run(c)) 1413 1414 def testDuplicatedInputs(self): 1415 with session.Session() as sess: 1416 a = constant_op.constant(1.0, shape=[1, 2]) 1417 b = constant_op.constant(2.0, shape=[1, 3]) 1418 a_val, b_val, a2_val = sess.run([a, b, a]) 1419 self.assertAllEqual(a_val, [[1.0, 1.0]]) 1420 self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]]) 1421 self.assertAllEqual(a2_val, [[1.0, 1.0]]) 1422 1423 def testFeedAndFetch(self): 1424 with session.Session() as sess: 1425 for dtype in [ 1426 dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32, 1427 dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool, 1428 dtypes.complex64, dtypes.complex128 1429 ]: 1430 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1431 np_dtype = dtype.as_numpy_dtype 1432 1433 feed_t = array_ops.placeholder(dtype=dtype, shape=shape) 1434 out_t = array_ops.identity(feed_t) 1435 1436 np_array = np.random.randint(-10, 10, shape) 1437 1438 if dtype == dtypes.bool: 1439 np_array = np_array > 0 1440 elif dtype == dtypes.complex64: 1441 np_array = np.sqrt(np_array.astype(np_dtype)) 1442 elif dtype == dtypes.complex64: 1443 np_array = np.sqrt(np_array.astype(np_dtype)) 1444 else: 1445 np_array = np_array.astype(np_dtype) 1446 1447 self.assertAllEqual(np_array, 1448 sess.run(out_t, feed_dict={ 1449 feed_t: np_array 1450 })) 1451 # Check that we can also get the feed back. 1452 self.assertAllEqual(np_array, 1453 sess.run(feed_t, feed_dict={ 1454 feed_t: np_array 1455 })) 1456 # Also check that we can get both back. 
# (cont.) testFeedAndFetch: fetch both tensors in one run, then repeat the
# check through a callable built with make_callable.
          out_v, feed_v = sess.run(
              [out_t, feed_t], feed_dict={
                  feed_t: np_array
              })
          self.assertAllEqual(np_array, out_v)
          self.assertAllEqual(np_array, feed_v)

          feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t])
          out_v, feed_v = feed_fetch_runner(np_array)
          self.assertAllEqual(np_array, out_v)
          self.assertAllEqual(np_array, feed_v)

  def testMakeCallableOnTensorWithRunOptions(self):
    """A tensor callable made with accept_options=True honors RunOptions."""
    with session.Session() as sess:
      a = constant_op.constant(42.0)
      tensor_runner = sess.make_callable(a, accept_options=True)
      run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
      res = tensor_runner(options=run_options, run_metadata=run_metadata)
      self.assertEqual(42.0, res)
      # FULL_TRACE is expected to populate per-device step stats.
      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)

  def testMakeCallableOnOperationWithRunOptions(self):
    """An operation callable made with accept_options=True honors RunOptions."""
    with session.Session() as sess:
      a = variables.Variable(42.0)
      b = state_ops.assign_add(a, 1.0)
      sess.run(a.initializer)
      tensor_runner = sess.make_callable(b.op, accept_options=True)
      run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
      tensor_runner(options=run_options, run_metadata=run_metadata)
      self.assertEqual(43.0, sess.run(a))
      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)

  def testMakeCallableWithFeedListAndRunOptions(self):
    """A callable with a name-based feed_list also accepts RunOptions."""
    with session.Session() as sess:
      ph = array_ops.placeholder(dtypes.float32)
      a = math_ops.add(ph, 1.0)
      tensor_runner = sess.make_callable(
          a, feed_list=[ph.name], accept_options=True)
      run_options = config_pb2.RunOptions(
          trace_level=config_pb2.RunOptions.FULL_TRACE)
      run_metadata = config_pb2.RunMetadata()
      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
      self.assertAllClose(42.0,
                          tensor_runner(
                              41.0,
                              options=run_options,
                              run_metadata=run_metadata))
      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)

  def testOptimizedMakeCallable(self):
    """Builds a callable from CallableOptions three times; calls each 5 times."""
    with session.Session() as sess:
      ph = array_ops.placeholder(dtypes.float32)
      a = math_ops.add(ph, 1.0)
      callable_opts = config_pb2.CallableOptions()
      callable_opts.feed.append(ph.name)
      callable_opts.fetch.append(a.name)
      for _ in range(3):
        callable_fn = sess._make_callable_from_options(callable_opts)
        for _ in range(5):
          self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32)))

  def testOptimizedMakeCallableWithRunMetadata(self):
    """Run options set on CallableOptions produce step stats in RunMetadata."""
    with session.Session() as sess:
      ph = array_ops.placeholder(dtypes.float32)
      a = math_ops.add(ph, 1.0)
      callable_opts = config_pb2.CallableOptions()
      callable_opts.feed.append(ph.name)
      callable_opts.fetch.append(a.name)
      callable_opts.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
      callable_fn = sess._make_callable_from_options(callable_opts)
      run_metadata = config_pb2.RunMetadata()
      self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32),
                                          run_metadata=run_metadata))
      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)

  def testFeedError(self):
    """Feeding a tf.Tensor as a feed value raises TypeError on every path."""
    with session.Session() as sess:
      feed_t = array_ops.placeholder(dtype=dtypes.float32)
      out_t = array_ops.identity(feed_t)
      feed_val = constant_op.constant(5.0)
      with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
        sess.run(out_t, feed_dict={feed_t: feed_val})
      with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
        out_t.eval(feed_dict={feed_t: feed_val})
      with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
out_t.op.run(feed_dict={feed_t: feed_val}) 1549 1550 def testFeedPrecisionLossError(self): 1551 with session.Session() as sess: 1552 largest_int64 = np.iinfo(np.int64).max 1553 1554 feed_int_implicit_int32 = constant_op.constant(1) 1555 feed_int_explicit_int32 = constant_op.constant(1, dtype=dtypes.int32) 1556 1557 out_t = constant_op.constant(1.0) 1558 1559 with self.assertRaisesRegex(TypeError, 1560 'is not compatible with Tensor type'): 1561 sess.run(out_t, feed_dict={feed_int_implicit_int32: largest_int64}) 1562 with self.assertRaisesRegex(TypeError, 1563 'is not compatible with Tensor type'): 1564 sess.run(out_t, feed_dict={feed_int_explicit_int32: largest_int64}) 1565 1566 def testStringFetch(self): 1567 with session.Session(): 1568 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1569 size = 1 1570 for s in shape: 1571 size *= s 1572 c_list = np.array( 1573 [compat.as_bytes(str(i)) for i in xrange(size)], 1574 dtype=np.object).reshape(shape) if size > 0 else [] 1575 c = constant_op.constant(c_list) 1576 self.assertAllEqual(c, c_list) 1577 1578 def testStringFeed(self): 1579 with session.Session() as sess: 1580 for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]: 1581 size = 1 1582 for s in shape: 1583 size *= s 1584 c_list = np.array( 1585 [compat.as_bytes(str(i)) for i in xrange(size)], 1586 dtype=np.object).reshape(shape) 1587 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape) 1588 c = array_ops.identity(feed_t) 1589 self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list) 1590 self.assertAllEqual( 1591 sess.run(feed_t, feed_dict={ 1592 feed_t: c_list 1593 }), c_list) 1594 c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list}) 1595 self.assertAllEqual(c_v, c_list) 1596 self.assertAllEqual(feed_v, c_list) 1597 1598 def testStringFeedWithNullCharacters(self): 1599 with session.Session(): 1600 c_list = [b'\n\x01\x00', b'\n\x00\x01'] 1601 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[2]) 1602 c = 
array_ops.identity(feed_t) 1603 out = c.eval(feed_dict={feed_t: c_list}) 1604 self.assertEqual(c_list[0], out[0]) 1605 self.assertEqual(c_list[1], out[1]) 1606 1607 def testStringFeedWithUnicode(self): 1608 with session.Session(): 1609 c_list = [ 1610 u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode', 1611 u'\U0001f60e deal with it' 1612 ] 1613 feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)]) 1614 c = array_ops.identity(feed_t) 1615 1616 out = c.eval(feed_dict={feed_t: c_list}) 1617 for i in range(len(c_list)): 1618 self.assertEqual(c_list[i], out[i].decode('utf-8')) 1619 1620 out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=np.object)}) 1621 for i in range(len(c_list)): 1622 self.assertEqual(c_list[i], out[i].decode('utf-8')) 1623 1624 def testInvalidTargetFails(self): 1625 with self.assertRaisesRegex( 1626 errors.NotFoundError, 1627 'No session factory registered for the given session options'): 1628 session.Session('INVALID_TARGET') 1629 1630 def testFetchByNameDifferentStringTypes(self): 1631 with session.Session() as sess: 1632 c = constant_op.constant(42.0, name='c') 1633 d = constant_op.constant(43.0, name=u'd') 1634 e = constant_op.constant(44.0, name=b'e') 1635 f = constant_op.constant(45.0, name=r'f') 1636 1637 self.assertIsInstance(c.name, six.text_type) 1638 self.assertIsInstance(d.name, six.text_type) 1639 self.assertIsInstance(e.name, six.text_type) 1640 self.assertIsInstance(f.name, six.text_type) 1641 1642 self.assertEqual(42.0, sess.run('c:0')) 1643 self.assertEqual(42.0, sess.run(u'c:0')) 1644 self.assertEqual(42.0, sess.run(b'c:0')) 1645 self.assertEqual(42.0, sess.run(r'c:0')) 1646 1647 self.assertEqual(43.0, sess.run('d:0')) 1648 self.assertEqual(43.0, sess.run(u'd:0')) 1649 self.assertEqual(43.0, sess.run(b'd:0')) 1650 self.assertEqual(43.0, sess.run(r'd:0')) 1651 1652 self.assertEqual(44.0, sess.run('e:0')) 1653 self.assertEqual(44.0, sess.run(u'e:0')) 1654 self.assertEqual(44.0, sess.run(b'e:0')) 1655 
# (cont.) testFetchByNameDifferentStringTypes: remaining name spellings for
# 'e' and 'f'.
      self.assertEqual(44.0, sess.run(r'e:0'))

      self.assertEqual(45.0, sess.run('f:0'))
      self.assertEqual(45.0, sess.run(u'f:0'))
      self.assertEqual(45.0, sess.run(b'f:0'))
      self.assertEqual(45.0, sess.run(r'f:0'))

  def testIncorrectGraph(self):
    """Running a tensor or op from another graph raises ValueError."""
    with ops.Graph().as_default() as g_1:
      c_1 = constant_op.constant(1.0, name='c')

    with ops.Graph().as_default() as g_2:
      c_2 = constant_op.constant(2.0, name='c')

    # Same op name in both graphs, so the check cannot rely on names alone.
    self.assertEqual('c', c_1.op.name)
    self.assertEqual('c', c_2.op.name)

    with session.Session(graph=g_1) as sess_1:
      self.assertEqual(1.0, sess_1.run(c_1))
      with self.assertRaises(ValueError):
        sess_1.run(c_2)
      with self.assertRaises(ValueError):
        sess_1.run(c_2.op)

    with session.Session(graph=g_2) as sess_2:
      with self.assertRaises(ValueError):
        sess_2.run(c_1)
      with self.assertRaises(ValueError):
        sess_2.run(c_1.op)
      self.assertEqual(2.0, sess_2.run(c_2))

  def testFeedDictKeyException(self):
    """A feed_dict key of 'a' (op name, not 'a:0') raises TypeError."""
    with session.Session() as sess:
      a = constant_op.constant(1.0, dtypes.float32, name='a')
      with self.assertRaisesRegex(TypeError, 'Cannot interpret feed_dict'):
        sess.run(a, feed_dict={'a': [2.0]})

  def testPerStepTrace(self):
    """step_stats is filled only when RunOptions request a trace."""
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.SOFTWARE_TRACE)
    run_metadata = config_pb2.RunMetadata()

    with ops.device('/cpu:0'):
      with session.Session() as sess:
        sess.run(constant_op.constant(1.0))
        self.assertFalse(run_metadata.HasField('step_stats'))

        sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
        self.assertFalse(run_metadata.HasField('step_stats'))

        sess.run(
            constant_op.constant(1.0),
            options=run_options,
            run_metadata=run_metadata)

        self.assertTrue(run_metadata.HasField('step_stats'))
        self.assertEqual(len(run_metadata.step_stats.dev_stats), 1)

  def testRunOptionsRunMetadata(self):
    """Every combination of None/non-None options and run_metadata is accepted."""
    run_options = config_pb2.RunOptions(
        trace_level=config_pb2.RunOptions.SOFTWARE_TRACE)
    run_metadata = config_pb2.RunMetadata()

    with ops.device('/cpu:0'):
      with session.Session() as sess:
        # all combinations are valid
        sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
        sess.run(
            constant_op.constant(1.0), options=None, run_metadata=run_metadata)
        self.assertFalse(run_metadata.HasField('step_stats'))

        sess.run(
            constant_op.constant(1.0), options=run_options, run_metadata=None)
        self.assertFalse(run_metadata.HasField('step_stats'))

        sess.run(
            constant_op.constant(1.0),
            options=run_options,
            run_metadata=run_metadata)

        self.assertTrue(run_metadata.HasField('step_stats'))
        self.assertEqual(len(run_metadata.step_stats.dev_stats), 1)

  def testFeedShapeCompatibility(self):
    """A wrongly-shaped feed raises ValueError; a bad reshape target fails at run."""
    with session.Session() as sess:
      some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
      new_shape = constant_op.constant([2, 2])
      reshaped_tensor = array_ops.reshape(some_tensor, new_shape)

      with self.assertRaisesRegex(ValueError, 'Cannot feed value of shape'):
        sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]})

      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          'Input to reshape is a tensor with 4 values, '
          'but the requested shape has 21'):
        sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]})

  def testInferShapesFalse(self):
    """With a default config, GraphDef nodes carry no '_output_shapes' attr."""
    with ops.Graph().as_default(), ops.device('/cpu:0'):
      a = constant_op.constant([[1, 2]])
      sess = session.Session()
      self.assertNotIn('_output_shapes', sess.graph_def.node[0].attr)
      # Avoid lint error regarding 'unused' var a.
1759 self.assertEqual(a, a) 1760 1761 def testInferShapesTrue(self): 1762 config_pb = config_pb2.ConfigProto( 1763 graph_options=config_pb2.GraphOptions(infer_shapes=True)) 1764 with ops.Graph().as_default(), ops.device('/cpu:0'): 1765 a = constant_op.constant([[1, 2]]) 1766 sess = session.Session(config=config_pb) 1767 self.assertIn('_output_shapes', sess.graph_def.node[0].attr) 1768 # Avoid lint error regarding 'unused' var a. 1769 self.assertEqual(a, a) 1770 1771 def testBuildCostModel(self): 1772 run_options = config_pb2.RunOptions() 1773 config_pb = config_pb2.ConfigProto( 1774 allow_soft_placement=True, 1775 graph_options=config_pb2.GraphOptions(build_cost_model=100)) 1776 with session.Session(config=config_pb) as sess: 1777 with ops.device('/device:GPU:0'): 1778 a = array_ops.placeholder(dtypes.float32, shape=[]) 1779 b = math_ops.add(a, a) 1780 c = array_ops.identity(b) 1781 d = math_ops.multiply(c, c) 1782 for step in xrange(120): 1783 run_metadata = config_pb2.RunMetadata() 1784 sess.run( 1785 d, 1786 feed_dict={a: 1.0}, 1787 options=run_options, 1788 run_metadata=run_metadata) 1789 if step == 99: 1790 self.assertTrue(run_metadata.HasField('cost_graph')) 1791 else: 1792 self.assertFalse(run_metadata.HasField('cost_graph')) 1793 1794 def runTestOutputPartitionGraphs(self, sess): 1795 run_options = config_pb2.RunOptions(output_partition_graphs=True) 1796 a = constant_op.constant(1) 1797 run_metadata = config_pb2.RunMetadata() 1798 sess.run(a, options=run_options, run_metadata=run_metadata) 1799 self.assertGreater(len(run_metadata.partition_graphs), 0) 1800 sess.run(a, run_metadata=run_metadata) 1801 self.assertEqual(len(run_metadata.partition_graphs), 0) 1802 1803 @test_util.run_v1_only('b/120545219') 1804 def testOutputPartitionGraphsDirect(self): 1805 self.runTestOutputPartitionGraphs(session.Session()) 1806 1807 @test_util.run_v1_only('b/120545219') 1808 def testOutputPartitionGraphsDistributed(self): 1809 server = server_lib.Server.create_local_server() 
1810 self.runTestOutputPartitionGraphs(session.Session(server.target)) 1811 1812 def testNonInteractiveSessionNesting(self): 1813 sess1 = session.Session() 1814 sess1_controller = sess1.as_default() 1815 sess1_controller.__enter__() 1816 1817 sess2 = session.Session() 1818 sess2_controller = sess2.as_default() 1819 sess2_controller.__enter__() 1820 1821 with self.assertRaisesRegex(AssertionError, 'Nesting violated'): 1822 sess1_controller.__exit__(None, None, None) 1823 1824 ops._default_session_stack.reset() 1825 1826 def testInteractiveSessionNesting(self): 1827 sess1 = session.InteractiveSession() 1828 sess2 = session.InteractiveSession() 1829 del sess1 1830 del sess2 1831 1832 @test_util.run_v1_only('b/120545219') 1833 def testAsDefault(self): 1834 c = constant_op.constant(37) 1835 sess = session.Session() 1836 with sess.as_default(): 1837 self.assertEqual(37, c.eval()) 1838 1839 # Ensure that the session remains valid even when it is not captured. 1840 with session.Session().as_default(): 1841 self.assertEqual(37, c.eval()) 1842 1843 def testReentry(self): 1844 sess = session.Session() 1845 with self.assertRaisesRegex(RuntimeError, 'not re-entrant'): 1846 with sess: 1847 with sess: 1848 pass 1849 1850 def testInvalidArgument(self): 1851 with self.assertRaisesRegex(TypeError, 'target must be a string'): 1852 session.Session(37) 1853 with self.assertRaisesRegex(TypeError, 'config must be a tf.ConfigProto'): 1854 session.Session(config=37) 1855 with self.assertRaisesRegex(TypeError, 'graph must be a tf.Graph'): 1856 session.Session(graph=37) 1857 1858 @test_util.run_v1_only('b/120545219') 1859 def testTimeoutWithShortOperations(self): 1860 num_epochs = 5 1861 q = data_flow_ops.FIFOQueue(capacity=50, dtypes=[dtypes.int32], shapes=[()]) 1862 enqueue_op = q.enqueue_many(constant_op.constant([1, 2])) 1863 1864 # Use a 10-second timeout, which should be longer than any 1865 # non-blocking enqueue_many op. 
1866 config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=10000) 1867 with session.Session(config=config_pb) as sess: 1868 for _ in range(num_epochs): 1869 sess.run(enqueue_op) 1870 self.assertEqual(sess.run(q.size()), num_epochs * 2) 1871 1872 @test_util.run_v1_only('b/120545219') 1873 def testRegisterFetchAndFeedConversionFunctions(self): 1874 1875 class SquaredTensor(object): 1876 1877 def __init__(self, tensor): 1878 self.sq = math_ops.square(tensor) 1879 1880 fetch_fn = lambda squared_tensor: ([squared_tensor.sq], lambda val: val[0]) 1881 feed_fn1 = lambda feed, feed_val: [(feed.sq, feed_val)] 1882 feed_fn2 = lambda feed: [feed.sq] 1883 1884 session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, 1885 feed_fn1, feed_fn2) 1886 with self.assertRaises(ValueError): 1887 session.register_session_run_conversion_functions(SquaredTensor, fetch_fn, 1888 feed_fn1, feed_fn2) 1889 with self.cached_session() as sess: 1890 np1 = np.array([1.0, 1.5, 2.0, 2.5]) 1891 np2 = np.array([3.0, 3.5, 4.0, 4.5]) 1892 squared_tensor = SquaredTensor(np2) 1893 squared_eval = sess.run(squared_tensor) 1894 self.assertAllClose(np2 * np2, squared_eval) 1895 squared_eval = sess.run( 1896 squared_tensor, feed_dict={ 1897 squared_tensor: np1 * np1 1898 }) 1899 self.assertAllClose(np1 * np1, squared_eval) 1900 partial_run = sess.partial_run_setup([squared_tensor], []) 1901 squared_eval = sess.partial_run(partial_run, squared_tensor) 1902 self.assertAllClose(np2 * np2, squared_eval) 1903 1904 def testDefaultLogDevicePlacement(self): 1905 1906 class CaptureStderr(str): 1907 """Class to capture stderr from C++ shared library.""" 1908 1909 def __enter__(self): 1910 self._esc = compat.as_str('\b') 1911 self._output = compat.as_str('') 1912 self._stderr = sys.stderr 1913 self._fd = self._stderr.fileno() 1914 self._out_pipe, in_pipe = os.pipe() 1915 # Save the original io stream. 1916 self._dup_fd = os.dup(self._fd) 1917 # Replace the original io stream with in pipe. 
1918 os.dup2(in_pipe, self._fd) 1919 return self 1920 1921 def __exit__(self, *args): 1922 self._stderr.write(self._esc) 1923 self._stderr.flush() 1924 self.read() 1925 os.close(self._out_pipe) 1926 # Restore the original io stream. 1927 os.dup2(self._dup_fd, self._fd) 1928 1929 def read(self): 1930 while True: 1931 data = os.read(self._out_pipe, 1) 1932 if not data or compat.as_str(data) == self._esc: 1933 break 1934 self._output += compat.as_str(data) 1935 1936 def __str__(self): 1937 return self._output 1938 1939 context.set_log_device_placement(True) 1940 if context.executing_eagerly(): 1941 with CaptureStderr() as log: 1942 a = constant_op.constant(1) 1943 b = constant_op.constant(2) 1944 c = a + b 1945 # Ensure if the same kernel with the same arguments is executed then its 1946 # execution is logged. 1947 d = a + b 1948 else: 1949 # Passing the config to the server, but not the session should still 1950 # result in logging device placement. 1951 config_pb = config_pb2.ConfigProto(log_device_placement=True) 1952 server = server_lib.Server.create_local_server(config=config_pb) 1953 a = constant_op.constant(1) 1954 b = constant_op.constant(2) 1955 c = a + b 1956 d = a + b 1957 with session.Session(server.target) as sess: 1958 with CaptureStderr() as log: 1959 c, d = sess.run([c, d]) 1960 1961 self.assertEqual(c, 3) 1962 self.assertEqual(d, 3) 1963 # Ensure that we did log device placement. 1964 add_executions = [l for l in str(log).splitlines() if 'AddV2' in l] 1965 self.assertEqual(len(add_executions), 2) 1966 1967 @def_function.function 1968 def fn(a, b): 1969 c = a + b 1970 # These two AddV2 cannot use the same argument in tf.function since an 1971 # optimization pass will remove duplicate ops and only run it once. 
1972 d = a + c 1973 return c, d 1974 1975 with CaptureStderr() as log: 1976 c, d = self.evaluate(fn(constant_op.constant(1), constant_op.constant(2))) 1977 self.assertEqual(c, 3) 1978 self.assertEqual(d, 4) 1979 # Ensure that we did log device placement. 1980 add_executions = [l for l in str(log).splitlines() if 'AddV2' in l] 1981 self.assertEqual(len(add_executions), 2) 1982 1983 @test_util.run_v1_only('b/120545219') 1984 def testLocalMasterSessionTimeout(self): 1985 # Test that the timeout passed in a config to the session works correctly. 1986 config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=1000) 1987 server = server_lib.Server.create_local_server() 1988 q = data_flow_ops.FIFOQueue(1, dtypes.float32) 1989 dequeued_t = q.dequeue() 1990 1991 with session.Session(server.target, config=config_pb) as sess: 1992 # Intentionally do not run any enqueue_ops so that dequeue will block 1993 # until operation_timeout_in_ms. 1994 with self.assertRaises(errors.DeadlineExceededError): 1995 sess.run(dequeued_t) 1996 1997 @test_util.run_v1_only('b/120545219') 1998 def testDefaultServerTimeout(self): 1999 # Test that the default server config timeout gets used when no Session 2000 # config is provided. 2001 config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=1000) 2002 server = server_lib.Server.create_local_server(config=config_pb) 2003 q = data_flow_ops.FIFOQueue(1, dtypes.float32) 2004 dequeued_t = q.dequeue() 2005 2006 with session.Session(server.target) as sess: 2007 # Intentionally do not run any enqueue_ops so that dequeue will block 2008 # until operation_timeout_in_ms. 2009 with self.assertRaises(errors.DeadlineExceededError): 2010 sess.run(dequeued_t) 2011 2012 def runTestBuildGraphError(self, sess): 2013 # Ensure that errors from building the graph get propagated. 
2014 data = array_ops.placeholder(dtypes.float32, shape=[]) 2015 # pylint: disable=protected-access 2016 enter_1 = gen_control_flow_ops.enter(data, 'foo_1', False) 2017 enter_2 = gen_control_flow_ops.enter(data, 'foo_2', False) 2018 # pylint: enable=protected-access 2019 res = math_ops.add(enter_1, enter_2) 2020 with self.assertRaisesOpError('has inputs from different frames'): 2021 sess.run(res, feed_dict={data: 1.0}) 2022 2023 @test_util.run_v1_only('b/120545219') 2024 def testBuildGraphErrorDirect(self): 2025 self.runTestBuildGraphError(session.Session()) 2026 2027 @test_util.run_v1_only('b/120545219') 2028 def testBuildGraphErrorDist(self): 2029 server = server_lib.Server.create_local_server() 2030 self.runTestBuildGraphError(session.Session(server.target)) 2031 2032 def testDeviceAttributes(self): 2033 attrs = session._DeviceAttributes( 2034 '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000) 2035 self.assertEqual(1337, attrs.memory_limit_bytes) 2036 self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name) 2037 self.assertEqual('TYPE', attrs.device_type) 2038 self.assertEqual(1000000, attrs.incarnation) 2039 str_repr = '%s' % attrs 2040 self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) 2041 2042 def testDeviceAttributesCanonicalization(self): 2043 attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1', 2044 'TYPE', 1337, 1000000) 2045 self.assertEqual(1337, attrs.memory_limit_bytes) 2046 self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name) 2047 self.assertEqual('TYPE', attrs.device_type) 2048 self.assertEqual(1000000, attrs.incarnation) 2049 str_repr = '%s' % attrs 2050 self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr) 2051 2052 def runTestAddFunctionToSession(self, target=''): 2053 """Add a function to a session after the graph has already been run.""" 2054 2055 @function.Defun(dtypes.float32) 2056 def foo(x): 2057 return x + 1 2058 2059 x = 
constant_op.constant(1.0) 2060 with session.Session(target=target) as sess: 2061 sess.run(x) 2062 f = foo(x) 2063 result = sess.run(f) 2064 self.assertEqual(result, 2.0) 2065 2066 @test_util.run_v1_only('b/120545219') 2067 def testAddFunctionToSession(self): 2068 self.runTestAddFunctionToSession() 2069 2070 @test_util.run_v1_only('b/120545219') 2071 def testAddFunctionToGrpcSession(self): 2072 server = server_lib.Server.create_local_server() 2073 self.runTestAddFunctionToSession(server.target) 2074 2075 def testOpenAndCloseGrpcSession(self): 2076 server = server_lib.Server.create_local_server() 2077 with session.Session(server.target): 2078 pass 2079 2080 def testOpenAndCloseSession(self): 2081 with session.Session(): 2082 pass 2083 2084 @test_util.run_v1_only('b/120545219') 2085 def testAutoConvertAndCheckData(self): 2086 with self.cached_session() as sess: 2087 a = array_ops.placeholder(dtype=dtypes.string) 2088 with self.assertRaisesRegex( 2089 TypeError, r'Type of feed value 1 with type <(\w+) \'int\'> is not'): 2090 sess.run(a, feed_dict={a: 1}) 2091 2092 @test_util.run_v1_only('b/120545219') 2093 def testOptimizerOptions(self): 2094 config.set_optimizer_experimental_options({'min_graph_nodes': -1}) 2095 2096 with ops.Graph().as_default(): 2097 sess = session.Session() 2098 self.assertEqual( 2099 sess._config.graph_options.rewrite_options.min_graph_nodes, -1) 2100 2101 2102if __name__ == '__main__': 2103 googletest.main() 2104