1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.python.client.session.Session."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import collections
21import os
22import random
23import sys
24import threading
25import time
26import warnings
27
28import numpy as np
29import six
30from six.moves import xrange  # pylint: disable=redefined-builtin
31
32from tensorflow.core.framework import attr_value_pb2
33from tensorflow.core.lib.core import error_codes_pb2
34from tensorflow.core.protobuf import config_pb2
35from tensorflow.python.client import session
36from tensorflow.python.eager import context
37from tensorflow.python.eager import def_function
38from tensorflow.python.framework import config
39from tensorflow.python.framework import constant_op
40from tensorflow.python.framework import device as framework_device_lib
41from tensorflow.python.framework import dtypes
42from tensorflow.python.framework import errors
43from tensorflow.python.framework import function
44from tensorflow.python.framework import importer
45from tensorflow.python.framework import ops
46from tensorflow.python.framework import sparse_tensor
47from tensorflow.python.framework import tensor_util
48from tensorflow.python.framework import test_util
49from tensorflow.python.framework import versions
50from tensorflow.python.ops import array_ops
51from tensorflow.python.ops import control_flow_ops
52from tensorflow.python.ops import data_flow_ops
53from tensorflow.python.ops import gen_control_flow_ops
54# Import gradients to resolve circular imports
55from tensorflow.python.ops import gradients  # pylint: disable=unused-import
56from tensorflow.python.ops import gradients_impl
57from tensorflow.python.ops import math_ops
58# Import resource_variable_ops for the variables-to-tensor implicit conversion.
59from tensorflow.python.ops import resource_variable_ops  # pylint: disable=unused-import
60from tensorflow.python.ops import state_ops
61from tensorflow.python.ops import variables
62from tensorflow.python.platform import googletest
63from tensorflow.python.training import server_lib
64from tensorflow.python.util import compat
65
66try:
67  import attr  # pylint:disable=g-import-not-at-top
68except ImportError:
69  attr = None
70
71try:
72  from frozendict import frozendict  # pylint:disable=g-import-not-at-top
73except ImportError:
74  frozendict = dict  # pylint:disable=invalid-name
75
# Module-level alias so tests can name `defaultdict` next to `frozendict`
# when building and type-checking fetch containers.
defaultdict = collections.defaultdict  # pylint:disable=invalid-name
77
78
79class SessionTest(test_util.TensorFlowTestCase):
80
81  def setUp(self):
82    super(SessionTest, self).setUp()
83    warnings.simplefilter('always')
84
85  def testUseExistingGraph(self):
86    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
87      a = constant_op.constant(6.0, shape=[1, 1])
88      b = constant_op.constant(7.0, shape=[1, 1])
89      c = math_ops.matmul(a, b, name='matmul')
90    with session.Session(graph=g):
91      result = c.eval()
92      self.assertAllEqual(result, [[42.0]])
93
94  def testUseDefaultGraph(self):
95    with ops.Graph().as_default(), ops.device('/cpu:0'):
96      a = constant_op.constant(6.0, shape=[1, 1])
97      b = constant_op.constant(7.0, shape=[1, 1])
98      c = math_ops.matmul(a, b, name='matmul')
99      with session.Session():
100        result = c.eval()
101        self.assertAllEqual(result, [[42.0]])
102
103  def testCreate(self):
104    with session.Session():
105      inp = constant_op.constant(10.0, shape=[2, 3], name='W1')
106      copy = array_ops.identity(inp)
107      # Test with feed.
108      # TODO(mrry): Investigate why order='F' didn't work.
109      arr = np.asarray([[0, 1, 2], [3, 4, 5]], dtype=np.float32, order='C')
110      copy_val = copy.eval({'W1:0': arr})
111      self.assertAllEqual(arr, copy_val)
112      # Test without feed.
113      copy_val = copy.eval()
114      self.assertAllEqual(
115          np.asarray(
116              [[10.0, 10.0, 10.0], [10.0, 10.0, 10.0]], dtype=np.float32),
117          copy_val)
118
119  def testManyCPUs(self):
120    with session.Session(
121        config=config_pb2.ConfigProto(device_count={
122            'CPU': 2, 'GPU': 0
123        })) as sess:
124      inp = constant_op.constant(10.0, name='W1')
125      self.assertAllEqual(inp, 10.0)
126
127      num_cpu_devices = 0
128      num_gpu_devices = 0
129      for device in sess.list_devices():
130        device_type = framework_device_lib.DeviceSpec.from_string(
131            device.name).device_type
132        if device_type == 'CPU':
133          num_cpu_devices += 1
134        elif device_type == 'GPU':
135          num_gpu_devices += 1
136      self.assertEqual(2, num_cpu_devices)
137      self.assertEqual(0, num_gpu_devices)
138
139  def testPerSessionThreads(self):
140    with session.Session(
141        config=config_pb2.ConfigProto(use_per_session_threads=True)):
142      inp = constant_op.constant(10.0, name='W1')
143      self.assertAllEqual(inp, 10.0)
144
145  def testSessionInterOpThreadPool(self):
146    config_pb = config_pb2.ConfigProto()
147    pool = config_pb.session_inter_op_thread_pool.add()
148    with session.Session(config=config_pb) as s:
149      inp = constant_op.constant(10.0, name='W1')
150      results = s.run([inp])
151      self.assertAllEqual([10.0], results)
152
153    pool = config_pb.session_inter_op_thread_pool.add()
154    pool.num_threads = 1
155    with session.Session(config=config_pb) as s:
156      inp = constant_op.constant(20.0, name='W2')
157      results = s.run([inp])
158      self.assertAllEqual([20.0], results)
159
160    pool = config_pb.session_inter_op_thread_pool.add()
161    pool.num_threads = 1
162    pool.global_name = 't1'
163    run_options = config_pb2.RunOptions()
164    run_options.inter_op_thread_pool = (
165        len(config_pb.session_inter_op_thread_pool) - 1)
166    with session.Session(config=config_pb) as s:
167      inp = constant_op.constant(30.0, name='W2')
168      results = s.run([inp], options=run_options)
169      self.assertAllEqual([30.0], results)
170
171  def testErrorsReported(self):
172    with session.Session() as s:
173      constant_op.constant(10.0, name='W1')
174      with self.assertRaises(ValueError):
175        s.run('foo:0')
176
177  def testErrorPayload(self):
178    with session.Session():
179      a = array_ops.placeholder(dtypes.float32)
180      with self.assertRaisesOpError(lambda e: e.op == a.op):
181        a.eval()
182
183  def testErrorCodeWithNoNodeDef(self):
184    with session.Session() as s:
185      a = array_ops.placeholder(dtypes.float32, shape=[])
186      b = array_ops.placeholder(dtypes.float32, shape=[])
187      r1 = math_ops.add(a, b)
188
189      def exc_predicate(e):
190        return (e.op is None and e.node_def is None and
191                e.error_code == error_codes_pb2.INVALID_ARGUMENT)
192
193      with self.assertRaisesOpError(exc_predicate):
194        # Run with a bogus handle.
195        s.partial_run('foo', r1, feed_dict={a: 1, b: 2})
196
197  def testErrorBasedOn(self):
198    with session.Session() as sess:
199      a = constant_op.constant(0.0, shape=[2, 3])
200      # NOTE(mrry): The original_op is nonsense, but used here to test that the
201      #   errors are reported correctly.
202      with sess.graph._original_op(a.op):
203        b = array_ops.identity(a, name='id')
204      with sess.graph._original_op(b.op):
205        c = array_ops.placeholder(dtypes.float32)
206
207      def exc_predicate(e):
208        return (e.op == c.op and e.op._original_op == b.op and
209                e.op._original_op._original_op == a.op)
210
211      with self.assertRaisesOpError(exc_predicate):
212        c.eval()
213
214  def testFetchNone(self):
215    with session.Session() as s:
216      a = constant_op.constant(1.0)
217      with self.assertRaises(TypeError):
218        s.run(None)
219      with self.assertRaises(TypeError):
220        s.run([None])
221      with self.assertRaises(TypeError):
222        s.run({'b': None})
223      with self.assertRaises(TypeError):
224        s.run({'a': a, 'b': None})
225
226  def testFetchSingleton(self):
227    with session.Session() as sess:
228      a = constant_op.constant(42.0)
229      res = sess.run(a)
230      self.assertEqual(42.0, res)
231      res = sess.run(a.op)  # An op, not a tensor.
232      self.assertIsNone(res)
233      tensor_runner = sess.make_callable(a)
234      res = tensor_runner()
235      self.assertEqual(42.0, res)
236      op_runner = sess.make_callable(a.op)
237      res = op_runner()
238      self.assertIsNone(res)
239
240  def testFetchSingletonByName(self):
241    with session.Session() as sess:
242      a = constant_op.constant(42.0)
243      res = sess.run(a.name)
244      self.assertEqual(42.0, res)
245      res = sess.run(a.op)  # An op, not a tensor.
246      self.assertIsNone(res)
247
248  def testFetchList(self):
249    with session.Session() as sess:
250      a = constant_op.constant(42.0)
251      b = control_flow_ops.no_op()  # An op, not a tensor.
252      c = constant_op.constant(44.0)
253      v = variables.Variable([54.0])
254      assign = v.assign([63.0])
255      res = sess.run([a, b, c, a.name, assign.op])
256      self.assertIsInstance(res, list)
257      self.assertEqual([42.0, None, 44.0, 42.0, None], res)
258      list_runner = sess.make_callable([a, b, c, a.name, assign.op])
259      res = list_runner()
260      self.assertIsInstance(res, list)
261      self.assertEqual([42.0, None, 44.0, 42.0, None], res)
262
263  def testFetchTuple(self):
264    with session.Session() as sess:
265      a = constant_op.constant(42.0)
266      b = control_flow_ops.no_op()  # An op, not a tensor.
267      c = constant_op.constant(44.0)
268      res = sess.run((a, b, c, a.name))
269      self.assertIsInstance(res, tuple)
270      self.assertEqual((42.0, None, 44.0, 42.0), res)
271      tuple_runner = sess.make_callable((a, b, c, a.name))
272      res = tuple_runner()
273      self.assertIsInstance(res, tuple)
274      self.assertEqual((42.0, None, 44.0, 42.0), res)
275
276  def testFetchNamedTuple(self):
277    # pylint: disable=invalid-name
278    ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
279    # pylint: enable=invalid-name
280    with session.Session() as sess:
281      a = constant_op.constant(42.0)
282      b = control_flow_ops.no_op()  # An op, not a tensor.
283      c = constant_op.constant(44.0)
284      res = sess.run(ABC(a, b, c))
285      self.assertIsInstance(res, ABC)
286      self.assertEqual(42.0, res.a)
287      self.assertIsNone(res.b)
288      self.assertEqual(44.0, res.c)
289      namedtuple_runner = sess.make_callable(ABC(a, b, c))
290      res = namedtuple_runner()
291      self.assertIsInstance(res, ABC)
292      self.assertEqual(42.0, res.a)
293      self.assertIsNone(res.b)
294      self.assertEqual(44.0, res.c)
295
296  def testFetchDict(self):
297    with session.Session() as sess:
298      a = constant_op.constant(42.0)
299      b = control_flow_ops.no_op()  # An op, not a tensor.
300      c = constant_op.constant(44.0)
301      res = sess.run({'a': a, 'b': b, 'c': c})
302      self.assertIsInstance(res, dict)
303      self.assertEqual(42.0, res['a'])
304      self.assertIsNone(res['b'])
305      self.assertEqual(44.0, res['c'])
306
307  def testFetchOrderedDict(self):
308    with session.Session() as sess:
309      a = constant_op.constant(42.0)
310      b = control_flow_ops.no_op()  # An op, not a tensor.
311      c = constant_op.constant(44.0)
312      res = sess.run(collections.OrderedDict([(3, a), (2, b), (1, c)]))
313      self.assertIsInstance(res, collections.OrderedDict)
314      self.assertEqual([3, 2, 1], list(res.keys()))
315      self.assertEqual(42.0, res[3])
316      self.assertIsNone(res[2])
317      self.assertEqual(44.0, res[1])
318
319  @test_util.run_v1_only('b/120545219')
320  def testFetchAttrs(self):
321    if attr is None:
322      self.skipTest('attr module is unavailable.')
323
324    @attr.s
325    class SampleAttr(object):
326      field1 = attr.ib()
327      field2 = attr.ib()
328
329    val1 = np.array([1.2, 3.4, 5.6])
330    val2 = np.array([[1, 2], [4, 3]])
331    val3 = np.array([10, 20, 30])
332
333    t1 = constant_op.constant(val1)
334    t2 = constant_op.constant(val2)
335
336    sample = SampleAttr(t1, t2)
337    with session.Session() as sess:
338      result = sess.run(sample)
339      self.assertIsInstance(result, SampleAttr)
340      self.assertAllEqual(val1, result.field1)
341      self.assertAllEqual(val2, result.field2)
342
343      result = sess.run(sample, feed_dict={sample.field1: val3})
344      self.assertIsInstance(result, SampleAttr)
345      self.assertAllEqual(val3, result.field1)
346      self.assertAllEqual(val2, result.field2)
347
348  @test_util.run_v1_only('b/120545219')
349  def testFetchNestedAttrs(self):
350    if attr is None:
351      self.skipTest('attr module is unavailable.')
352
353    @attr.s
354    class SampleAttr(object):
355      field0 = attr.ib()
356      field1 = attr.ib()
357
358    v1 = 10
359    v2 = 20
360    v3 = np.float32(1.2)
361    v4 = np.float32(3.4)
362    v5 = np.float64(100.001)
363    v6 = np.float64(-23.451)
364    arr1 = np.array([1.2, 6.7, 3.4])
365    arr2 = np.array([7, 11, 3])
366    sample = SampleAttr(
367        SampleAttr(
368            SampleAttr(constant_op.constant(v1), constant_op.constant(v2)),
369            SampleAttr(constant_op.constant(arr1), constant_op.constant(arr2))),
370        {'A': SampleAttr(constant_op.constant(v3), constant_op.constant(v4)),
371         'B': [SampleAttr(constant_op.constant(v5), constant_op.constant(v6))]})
372
373    with session.Session() as sess:
374      result = sess.run(sample)
375      self.assertIsInstance(result, SampleAttr)
376      self.assertIsInstance(result.field0, SampleAttr)
377      self.assertIsInstance(result.field0.field0, SampleAttr)
378      self.assertIsInstance(result.field0.field1, SampleAttr)
379      self.assertIsInstance(result.field0.field1.field0, np.ndarray)
380      self.assertAllEqual(arr1, result.field0.field1.field0)
381      self.assertIsInstance(result.field0.field1.field1, np.ndarray)
382      self.assertAllEqual(arr2, result.field0.field1.field1)
383      self.assertIsInstance(result.field1, dict)
384      self.assertIn('A', result.field1)
385      self.assertIn('B', result.field1)
386      self.assertIsInstance(result.field1['A'], SampleAttr)
387      self.assertAllEqual(
388          [v3, v4],
389          [result.field1['A'].field0, result.field1['A'].field1])
390      self.assertIsInstance(result.field1['B'], list)
391      self.assertEqual(1, len(result.field1['B']))
392      self.assertIsInstance(result.field1['B'][0], SampleAttr)
393      self.assertAllEqual(
394          [v5, v6],
395          [result.field1['B'][0].field0, result.field1['B'][0].field1])
396
397  def testFetchNestingEmptyOneLevel(self):
398    with session.Session() as sess:
399      a_val = 11.0
400      a = constant_op.constant(a_val)
401
402      res = sess.run([[], tuple(), {}])
403      self.assertIsInstance(res, list)
404      self.assertEqual(3, len(res))
405      self.assertIsInstance(res[0], list)
406      self.assertEqual(0, len(res[0]))
407      self.assertIsInstance(res[1], tuple)
408      self.assertEqual(0, len(res[1]))
409      self.assertIsInstance(res[2], dict)
410      self.assertEqual(0, len(res[2]))
411
412      res = sess.run([[], tuple(), {}, a])
413      self.assertIsInstance(res, list)
414      self.assertEqual(4, len(res))
415      self.assertIsInstance(res[0], list)
416      self.assertEqual(0, len(res[0]))
417      self.assertIsInstance(res[1], tuple)
418      self.assertEqual(0, len(res[1]))
419      self.assertIsInstance(res[2], dict)
420      self.assertEqual(0, len(res[2]))
421      self.assertEqual(a_val, res[3])
422
423  def testFetchNestingOneLevel(self):
424    with session.Session() as sess:
425      # pylint: disable=invalid-name
426      ABC = collections.namedtuple('ABC', ['a', 'b', 'c'])
427      DEFGHI = collections.namedtuple('DEFGHI', ['d', 'e', 'f', 'g', 'h', 'i'])
428      # pylint: enable=invalid-name
429      a_val = 42.0
430      b_val = None
431      c_val = 44.0
432      a = constant_op.constant(a_val)
433      b = control_flow_ops.no_op()  # An op, not a tensor.
434      c = constant_op.constant(c_val)
435      test_dct = {'a': a.name, 'c': c, 'b': b}
436      test_dct_types = [dict, frozendict, defaultdict]
437      # List of lists, tuples, namedtuple, dict, frozendict, and defaultdict
438      res = sess.run([
439          [a, b, c],
440          (a, b, c),
441          ABC(a=a, b=b, c=c),
442          dict(test_dct),
443          frozendict(test_dct),
444          defaultdict(str, test_dct),
445      ])
446      self.assertIsInstance(res, list)
447      self.assertEqual(6, len(res))
448      self.assertIsInstance(res[0], list)
449      self.assertEqual(3, len(res[0]))
450      self.assertEqual(a_val, res[0][0])
451      self.assertEqual(b_val, res[0][1])
452      self.assertEqual(c_val, res[0][2])
453      self.assertIsInstance(res[1], tuple)
454      self.assertEqual(3, len(res[1]))
455      self.assertEqual(a_val, res[1][0])
456      self.assertEqual(b_val, res[1][1])
457      self.assertEqual(c_val, res[1][2])
458      self.assertIsInstance(res[2], ABC)
459      self.assertEqual(a_val, res[2].a)
460      self.assertEqual(b_val, res[2].b)
461      self.assertEqual(c_val, res[2].c)
462      for expected_type, r in zip(test_dct_types, res[3:]):
463        self.assertIsInstance(r, expected_type)
464        self.assertEqual(3, len(r))
465        self.assertEqual(a_val, r['a'])
466        self.assertEqual(b_val, r['b'])
467        self.assertEqual(c_val, r['c'])
468      self.assertEqual(res[5].default_factory, str)
469      # Tuple of lists, tuples, namedtuple, dict, frozendict, and defaultdict
470      res = sess.run(([a, b, c], (a.name, b, c), ABC(a=a, b=b,
471                                                     c=c), dict(test_dct),
472                      frozendict(test_dct), defaultdict(str, test_dct)))
473      self.assertIsInstance(res, tuple)
474      self.assertEqual(6, len(res))
475      self.assertIsInstance(res[0], list)
476      self.assertEqual(3, len(res[0]))
477      self.assertEqual(a_val, res[0][0])
478      self.assertEqual(b_val, res[0][1])
479      self.assertEqual(c_val, res[0][2])
480      self.assertIsInstance(res[1], tuple)
481      self.assertEqual(3, len(res[1]))
482      self.assertEqual(a_val, res[1][0])
483      self.assertEqual(b_val, res[1][1])
484      self.assertEqual(c_val, res[1][2])
485      self.assertIsInstance(res[2], ABC)
486      self.assertEqual(a_val, res[2].a)
487      self.assertEqual(b_val, res[2].b)
488      self.assertEqual(c_val, res[2].c)
489      for expected_type, r in zip(test_dct_types, res[3:]):
490        self.assertIsInstance(r, expected_type)
491        self.assertEqual(3, len(r))
492        self.assertEqual(a_val, r['a'])
493        self.assertEqual(b_val, r['b'])
494        self.assertEqual(c_val, r['c'])
495      self.assertEqual(res[5].default_factory, str)
496
497      # Namedtuple of lists, tuples, namedtuples, dict, frozendict, defaultdict
498      res = sess.run(
499          DEFGHI(
500              d=[a, b, c],
501              e=(a, b, c),
502              f=ABC(a=a.name, b=b, c=c),
503              g=dict(test_dct),
504              h=frozendict(test_dct),
505              i=defaultdict(str, test_dct)))
506      self.assertIsInstance(res, DEFGHI)
507      self.assertIsInstance(res.d, list)
508      self.assertEqual(3, len(res.d))
509      self.assertEqual(a_val, res.d[0])
510      self.assertEqual(b_val, res.d[1])
511      self.assertEqual(c_val, res.d[2])
512      self.assertIsInstance(res.e, tuple)
513      self.assertEqual(3, len(res.e))
514      self.assertEqual(a_val, res.e[0])
515      self.assertEqual(b_val, res.e[1])
516      self.assertEqual(c_val, res.e[2])
517      self.assertIsInstance(res.f, ABC)
518      self.assertEqual(a_val, res.f.a)
519      self.assertEqual(b_val, res.f.b)
520      self.assertEqual(c_val, res.f.c)
521      self.assertIsInstance(res.g, dict)
522      self.assertEqual(3, len(res.g))
523      self.assertEqual(a_val, res.g['a'])
524      self.assertEqual(b_val, res.g['b'])
525      self.assertEqual(c_val, res.g['c'])
526      self.assertIsInstance(res.h, frozendict)
527      self.assertEqual(3, len(res.h))
528      self.assertEqual(a_val, res.h['a'])
529      self.assertEqual(b_val, res.h['b'])
530      self.assertEqual(c_val, res.h['c'])
531      self.assertIsInstance(res.i, defaultdict)
532      self.assertEqual(3, len(res.i))
533      self.assertEqual(a_val, res.i['a'])
534      self.assertEqual(b_val, res.i['b'])
535      self.assertEqual(c_val, res.i['c'])
536      self.assertEqual(res.i.default_factory, str)
537      # Dict of lists, tuples, namedtuples, dict, frozendict, defaultdict
538      res = sess.run({
539          'd': [a, b, c],
540          'e': (a, b, c),
541          'f': ABC(a=a, b=b, c=c),
542          'g': dict(test_dct),
543          'h': frozendict(test_dct),
544          'i': defaultdict(str, test_dct),
545      })
546      self.assertIsInstance(res, dict)
547      self.assertEqual(6, len(res))
548      self.assertIsInstance(res['d'], list)
549      self.assertEqual(3, len(res['d']))
550      self.assertEqual(a_val, res['d'][0])
551      self.assertEqual(b_val, res['d'][1])
552      self.assertEqual(c_val, res['d'][2])
553      self.assertIsInstance(res['e'], tuple)
554      self.assertEqual(3, len(res['e']))
555      self.assertEqual(a_val, res['e'][0])
556      self.assertEqual(b_val, res['e'][1])
557      self.assertEqual(c_val, res['e'][2])
558      self.assertIsInstance(res['f'], ABC)
559      self.assertEqual(a_val, res['f'].a)
560      self.assertEqual(b_val, res['f'].b)
561      self.assertEqual(c_val, res['f'].c)
562      for expected_type, r_key in zip(test_dct_types, ('g', 'h', 'i')):
563        r = res[r_key]
564        self.assertIsInstance(r, expected_type)
565        self.assertEqual(3, len(r))
566        self.assertEqual(a_val, r['a'])
567        self.assertEqual(b_val, r['b'])
568        self.assertEqual(c_val, r['c'])
569      self.assertEqual(res['i'].default_factory, str)
570
571  def testFetchTensorObject(self):
572    with session.Session() as s:
573      a = constant_op.constant(1.0, shape=[1, 2])
574      b = constant_op.constant(2.0, shape=[2, 3])
575      c = math_ops.matmul(a, b)
576      results_with_list = s.run([c])
577      self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_list[0])
578      results_with_single = s.run(c)
579      self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_single)
580      results_with_get = c.eval()
581      self.assertAllEqual([[4.0, 4.0, 4.0]], results_with_get)
582      a_val, b_val = s.run([a, b])  # Test multiple fetches.
583      self.assertAllEqual([[1.0, 1.0]], a_val)
584      self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]], b_val)
585      results_with_dict = s.run({'a': [a], 'b': b, 'z': [a, b]})
586      self.assertAllEqual([[1.0, 1.0]], results_with_dict['a'][0])
587      self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
588                          results_with_dict['b'])
589      self.assertAllEqual(results_with_dict['a'][0], results_with_dict['z'][0])
590      self.assertAllEqual(results_with_dict['b'], results_with_dict['z'][1])
591
592      # Test nested structures
593      results_with_nested_list = s.run([[[a, b], b], a, [a, b]])
594      self.assertAllEqual([[1.0, 1.0]], results_with_nested_list[0][0][0])
595      self.assertAllEqual([[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
596                          results_with_nested_list[0][0][1])
597      self.assertAllEqual(results_with_nested_list[0][0][0],
598                          results_with_nested_list[1])
599      self.assertAllEqual(results_with_nested_list[1],
600                          results_with_nested_list[2][0])
601      self.assertAllEqual(results_with_nested_list[0][0][1],
602                          results_with_nested_list[0][1])
603      self.assertAllEqual(results_with_nested_list[0][1],
604                          results_with_nested_list[2][1])
605
606  def testFetchScalar(self):
607    with session.Session() as s:
608      for scalar in np.int32, np.int64, np.float16, np.float32, np.float64:
609        x = scalar(7)
610        y = scalar(8)
611        tf_x = constant_op.constant(x, shape=[])
612        tf_y = constant_op.constant(y)
613        tf_xy = math_ops.add(tf_x, tf_y)
614        # Single fetch
615        xy = s.run(tf_xy)
616        self.assertEqual(scalar, type(xy))
617        self.assertEqual(x + y, xy)
618        # List fetch
619        xy, = s.run([tf_xy])
620        self.assertEqual(scalar, type(xy))
621        self.assertEqual(x + y, xy)
622        # Dict fetch
623        xy = s.run({'xy': tf_xy})['xy']
624        self.assertEqual(scalar, type(xy))
625        self.assertEqual(x + y, xy)
626        # Nested list fetch
627        xy = s.run([[[tf_xy]], tf_xy, [tf_xy]])
628        self.assertAllEqual(xy, [[[x + y]], x + y, [x + y]])
629        self.assertEqual(scalar, type(xy[0][0][0]))
630        self.assertEqual(scalar, type(xy[1]))
631        self.assertEqual(scalar, type(xy[2][0]))
632
633  def testFetchOperationObject(self):
634    with session.Session() as s:
635      a = constant_op.constant(1.0, shape=[1, 2])
636      v = variables.Variable(a, name='testFetchOperationObject_v')
637      s.run(v.initializer)
638      v_val = s.run(v)
639      self.assertAllEqual([[1.0, 1.0]], v_val)
640
641  def testFetchSparseTensor(self):
642    with session.Session() as s:
643      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
644      values = np.array([1.0, 2.0]).astype(np.float32)
645      shape = np.array([7, 9, 2]).astype(np.int64)
646      sp = sparse_tensor.SparseTensor(
647          constant_op.constant(indices), constant_op.constant(values),
648          constant_op.constant(shape))
649      # Single fetch, use as tuple
650      sp_out = s.run(sp)
651      indices_out, values_out, shape_out = sp_out
652      self.assertAllEqual(indices_out, indices)
653      self.assertAllEqual(values_out, values)
654      self.assertAllEqual(shape_out, shape)
655      # Single fetch, use as SparseTensorValue
656      sp_out = s.run(sp)
657      self.assertAllEqual(sp_out.indices, indices)
658      self.assertAllEqual(sp_out.values, values)
659      self.assertAllEqual(sp_out.dense_shape, shape)
660      # Tuple fetch, use as tuple
661      indices_out, values_out, shape_out = s.run(sp)
662      self.assertAllEqual(indices_out, indices)
663      self.assertAllEqual(values_out, values)
664      self.assertAllEqual(shape_out, shape)
665      # List fetch, use as tuple
666      (indices_out, values_out, shape_out), = s.run([sp])
667      self.assertAllEqual(indices_out, indices)
668      self.assertAllEqual(values_out, values)
669      self.assertAllEqual(shape_out, shape)
670      # List fetch, use as SparseTensorValue
671      sp_out, = s.run([sp])
672      self.assertAllEqual(sp_out.indices, indices)
673      self.assertAllEqual(sp_out.values, values)
674      self.assertAllEqual(sp_out.dense_shape, shape)
675      # Dict fetch (single value), use as tuple
676      indices_out, values_out, shape_out = s.run({'sp': sp})['sp']
677      self.assertAllEqual(indices_out, indices)
678      self.assertAllEqual(values_out, values)
679      self.assertAllEqual(shape_out, shape)
680      # Dict fetch (list value), use as tuple
681      (indices_out, values_out, shape_out), = s.run({'sp': [sp]})['sp']
682      self.assertAllEqual(indices_out, indices)
683      self.assertAllEqual(values_out, values)
684      self.assertAllEqual(shape_out, shape)
685      # Dict fetch, use as SparseTensorValue
686      sp_out = s.run({'sp': sp})['sp']
687      self.assertAllEqual(sp_out.indices, indices)
688      self.assertAllEqual(sp_out.values, values)
689      self.assertAllEqual(sp_out.dense_shape, shape)
690      # Nested list fetch use as tuple
691      sp_out = s.run([[[sp]], sp])
692      indices_out, values_out, shape_out = sp_out[0][0][0]
693      self.assertAllEqual(indices_out, indices)
694      self.assertAllEqual(values_out, values)
695      self.assertAllEqual(shape_out, shape)
696      indices_out, values_out, shape_out = sp_out[1]
697      self.assertAllEqual(indices_out, indices)
698      self.assertAllEqual(values_out, values)
699      self.assertAllEqual(shape_out, shape)
700      # Nested list fetch, use as SparseTensorValue
701      sp_out = s.run([[[sp]], sp])
702      self.assertAllEqual(sp_out[0][0][0].indices, indices)
703      self.assertAllEqual(sp_out[0][0][0].values, values)
704      self.assertAllEqual(sp_out[0][0][0].dense_shape, shape)
705      self.assertAllEqual(sp_out[1].indices, indices)
706      self.assertAllEqual(sp_out[1].values, values)
707      self.assertAllEqual(sp_out[1].dense_shape, shape)
708
709  def testFeedSparseTensor(self):
710    with session.Session() as s:
711      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
712      values = np.array([1.0, 2.0]).astype(np.float32)
713      shape = np.array([7, 9, 2]).astype(np.int64)
714      sp = sparse_tensor.SparseTensor(
715          array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
716          array_ops.placeholder(dtype=np.float32, shape=(2,)),
717          array_ops.placeholder(dtype=np.int64, shape=(3,)),
718      )
719      sp_indices = array_ops.identity(sp.indices)
720      sp_values = array_ops.identity(sp.values)
721      sp_shape = array_ops.identity(sp.dense_shape)
722      sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
723      # Feed with tuple
724      indices_out, values_out, shape_out = s.run(
725          [sp_indices, sp_values, sp_shape], {
726              sp: (indices, values, shape)
727          })
728      self.assertAllEqual(indices_out, indices)
729      self.assertAllEqual(values_out, values)
730      self.assertAllEqual(shape_out, shape)
731      # Feed with tuple, fetch sp directly
732      sp_out = s.run(sp, {sp: (indices, values, shape)})
733      self.assertAllEqual(sp_out.indices, indices)
734      self.assertAllEqual(sp_out.values, values)
735      self.assertAllEqual(sp_out.dense_shape, shape)
736      # Feed with SparseTensorValue
737      indices_out, values_out, shape_out = s.run(
738          [sp_indices, sp_values, sp_shape], {
739              sp: sparse_tensor.SparseTensorValue(indices, values, shape)
740          })
741      self.assertAllEqual(indices_out, indices)
742      self.assertAllEqual(values_out, values)
743      self.assertAllEqual(shape_out, shape)
744      # Feed with SparseTensorValue, fetch SparseTensorValue
745      sp2_out = s.run(sp2, {
746          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
747      })
748      self.assertAllEqual(sp2_out.indices, indices)
749      self.assertAllEqual(sp2_out.values, values)
750      self.assertAllEqual(sp2_out.dense_shape, shape)
751      # Feed SparseTensorValue and fetch sp directly.
752      sp_out = s.run(sp, {
753          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
754      })
755      self.assertAllEqual(sp_out.indices, indices)
756      self.assertAllEqual(sp_out.values, values)
757      self.assertAllEqual(sp_out.dense_shape, shape)
758
759  def testFeedSparsePlaceholder(self):
760    with session.Session() as s:
761      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
762      values = np.array([1.0, 2.0]).astype(np.float32)
763      shape = np.array([7, 9, 2]).astype(np.int64)
764      sp = array_ops.sparse_placeholder(dtype=np.float32, name='placeholder1')
765      sp_indices = array_ops.identity(sp.indices)
766      sp_values = array_ops.identity(sp.values)
767      sp_shape = array_ops.identity(sp.dense_shape)
768      sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
769      # Feed with tuple
770      indices_out, values_out, shape_out = s.run(
771          [sp_indices, sp_values, sp_shape], {
772              sp: (indices, values, shape)
773          })
774      self.assertAllEqual(indices_out, indices)
775      self.assertAllEqual(values_out, values)
776      self.assertAllEqual(shape_out, shape)
777      # Feed with SparseTensorValue
778      indices_out, values_out, shape_out = s.run(
779          [sp_indices, sp_values, sp_shape], {
780              sp: sparse_tensor.SparseTensorValue(indices, values, shape)
781          })
782      self.assertAllEqual(indices_out, indices)
783      self.assertAllEqual(values_out, values)
784      self.assertAllEqual(shape_out, shape)
785      # Feed with SparseTensorValue, fetch SparseTensorValue
786      sp2_out = s.run(sp2, {
787          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
788      })
789      self.assertAllEqual(sp2_out.indices, indices)
790      self.assertAllEqual(sp2_out.values, values)
791      self.assertAllEqual(sp2_out.dense_shape, shape)
792
793  def testFeedSparsePlaceholderPartialShape(self):
794    with session.Session() as s:
795      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
796      values = np.array([1.0, 2.0]).astype(np.float32)
797      shape = np.array([7, 9, 2]).astype(np.int64)
798      sp = array_ops.sparse_placeholder(
799          shape=[None, 9, 2], dtype=np.float32, name='placeholder1')
800      sp_indices = array_ops.identity(sp.indices)
801      sp_values = array_ops.identity(sp.values)
802      sp_shape = array_ops.identity(sp.dense_shape)
803      sp2 = sparse_tensor.SparseTensor(sp_indices, sp_values, sp_shape)
804      # Feed with tuple
805      indices_out, values_out, shape_out = s.run(
806          [sp_indices, sp_values, sp_shape], {
807              sp: (indices, values, shape)
808          })
809      self.assertAllEqual(indices_out, indices)
810      self.assertAllEqual(values_out, values)
811      self.assertAllEqual(shape_out, shape)
812      # Feed with SparseTensorValue
813      indices_out, values_out, shape_out = s.run(
814          [sp_indices, sp_values, sp_shape], {
815              sp: sparse_tensor.SparseTensorValue(indices, values, shape)
816          })
817      self.assertAllEqual(indices_out, indices)
818      self.assertAllEqual(values_out, values)
819      self.assertAllEqual(shape_out, shape)
820      # Feed with SparseTensorValue, fetch SparseTensorValue
821      sp2_out = s.run(sp2, {
822          sp: sparse_tensor.SparseTensorValue(indices, values, shape)
823      })
824      self.assertAllEqual(sp2_out.indices, indices)
825      self.assertAllEqual(sp2_out.values, values)
826      self.assertAllEqual(sp2_out.dense_shape, shape)
827
828  def testFeedSparsePlaceholderConstantShape(self):
829    with session.Session() as s:
830      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
831      values = np.array([1.0, 2.0]).astype(np.float32)
832      shape = np.array([7, 9, 2]).astype(np.int64)
833      sp = array_ops.sparse_placeholder(
834          dtype=np.float32, shape=shape, name='placeholder1')
835      self.assertAllEqual(sp.dense_shape.eval(session=s), shape)
836      self.assertAllEqual(tensor_util.constant_value(sp.shape), shape)
837      sp_indices = array_ops.identity(sp.indices)
838      sp_values = array_ops.identity(sp.values)
839      sp_shape = array_ops.identity(sp.dense_shape)
840      # Feed with tuple
841      indices_out, values_out, shape_out = s.run(
842          [sp_indices, sp_values, sp_shape], {
843              sp: (indices, values)
844          })
845      self.assertAllEqual(indices_out, indices)
846      self.assertAllEqual(values_out, values)
847      self.assertAllEqual(shape_out, shape)
848
849  def testFetchIndexedSlices(self):
850    with session.Session() as s:
851      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
852      values = np.array([1.0, 2.0]).astype(np.float32)
853      dense_shape = np.array([7, 9, 2]).astype(np.int64)
854      ind = ops.IndexedSlices(
855          constant_op.constant(values), constant_op.constant(indices),
856          constant_op.constant(dense_shape))
857      # Single fetch, use as tuple
858      ind_out = s.run(ind)
859      values_out, indices_out, dense_shape_out = ind_out
860      self.assertAllEqual(values_out, values)
861      self.assertAllEqual(indices_out, indices)
862      self.assertAllEqual(dense_shape_out, dense_shape)
863      # Single fetch, use as IndexedSlicesValue
864      ind_out = s.run(ind)
865      self.assertAllEqual(ind_out.values, values)
866      self.assertAllEqual(ind_out.indices, indices)
867      self.assertAllEqual(ind_out.dense_shape, dense_shape)
868      # Tuple fetch, use as tuple
869      values_out, indices_out, dense_shape_out = s.run(ind)
870      self.assertAllEqual(values_out, values)
871      self.assertAllEqual(indices_out, indices)
872      self.assertAllEqual(dense_shape_out, dense_shape)
873      # List fetch, use as tuple
874      (values_out, indices_out, dense_shape_out), = s.run([ind])
875      self.assertAllEqual(values_out, values)
876      self.assertAllEqual(indices_out, indices)
877      self.assertAllEqual(dense_shape_out, dense_shape)
878      # List fetch, use as IndexedSlicesValue
879      ind_out, = s.run([ind])
880      self.assertAllEqual(ind_out.values, values)
881      self.assertAllEqual(ind_out.indices, indices)
882      self.assertAllEqual(ind_out.dense_shape, dense_shape)
883
884  def testFeedIndexedSlices(self):
885    with session.Session() as s:
886      values = np.array([1.0, 2.0]).astype(np.float32)
887      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
888      dense_shape = np.array([7, 9, 2]).astype(np.int64)
889      ind = ops.IndexedSlices(
890          array_ops.placeholder(dtype=np.float32, shape=(2,)),
891          array_ops.placeholder(dtype=np.int64, shape=(2, 3)),
892          array_ops.placeholder(dtype=np.int64, shape=(3,)),
893      )
894      ind_values = array_ops.identity(ind.values)
895      ind_indices = array_ops.identity(ind.indices)
896      ind_dense_shape = array_ops.identity(ind.dense_shape)
897      ind2 = ops.IndexedSlices(ind_values, ind_indices, ind_dense_shape)
898      # Feed with tuple
899      values_out, indices_out, dense_shape_out = s.run(
900          [ind_values, ind_indices, ind_dense_shape], {
901              ind: (values, indices, dense_shape)
902          })
903      self.assertAllEqual(values_out, values)
904      self.assertAllEqual(indices_out, indices)
905      self.assertAllEqual(dense_shape_out, dense_shape)
906      # Feed with IndexedSlicesValue
907      values_out, indices_out, dense_shape_out = s.run(
908          [ind_values, ind_indices, ind_dense_shape], {
909              ind: ops.IndexedSlicesValue(values, indices, dense_shape)
910          })
911      self.assertAllEqual(values_out, values)
912      self.assertAllEqual(indices_out, indices)
913      self.assertAllEqual(dense_shape_out, dense_shape)
914      # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
915      ind2_out = s.run(ind2, {
916          ind: ops.IndexedSlicesValue(values, indices, dense_shape)
917      })
918      self.assertAllEqual(ind2_out.values, values)
919      self.assertAllEqual(ind2_out.indices, indices)
920      self.assertAllEqual(ind2_out.dense_shape, dense_shape)
921
922  def testFetchIndexedSlicesWithoutDenseShape(self):
923    with session.Session() as s:
924      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
925      values = np.array([1.0, 2.0]).astype(np.float32)
926      dense_shape = None
927      ind = ops.IndexedSlices(
928          constant_op.constant(values), constant_op.constant(indices), None)
929      # Single fetch, use as tuple
930      ind_out = s.run(ind)
931      values_out, indices_out, dense_shape_out = ind_out
932      self.assertAllEqual(values_out, values)
933      self.assertAllEqual(indices_out, indices)
934      self.assertAllEqual(dense_shape_out, dense_shape)
935      # Single fetch, use as IndexedSlicesValue
936      ind_out = s.run(ind)
937      self.assertAllEqual(ind_out.values, values)
938      self.assertAllEqual(ind_out.indices, indices)
939      self.assertAllEqual(ind_out.dense_shape, dense_shape)
940      # Tuple fetch, use as tuple
941      values_out, indices_out, dense_shape_out = s.run(ind)
942      self.assertAllEqual(values_out, values)
943      self.assertAllEqual(indices_out, indices)
944      self.assertAllEqual(dense_shape_out, dense_shape)
945      # List fetch, use as tuple
946      (values_out, indices_out, dense_shape_out), = s.run([ind])
947      self.assertAllEqual(values_out, values)
948      self.assertAllEqual(indices_out, indices)
949      self.assertAllEqual(dense_shape_out, dense_shape)
950      # List fetch, use as IndexedSlicesValue
951      ind_out, = s.run([ind])
952      self.assertAllEqual(ind_out.values, values)
953      self.assertAllEqual(ind_out.indices, indices)
954      self.assertAllEqual(ind_out.dense_shape, dense_shape)
955
956  def testFeedIndexedSlicesWithoutDenseShape(self):
957    with session.Session() as s:
958      values = np.array([1.0, 2.0]).astype(np.float32)
959      indices = np.array([[3, 2, 0], [4, 5, 1]]).astype(np.int64)
960      dense_shape = None
961      ind = ops.IndexedSlices(
962          array_ops.placeholder(dtype=np.float32, shape=(2,)),
963          array_ops.placeholder(dtype=np.int64, shape=(2, 3)), None)
964      ind_values = array_ops.identity(ind.values)
965      ind_indices = array_ops.identity(ind.indices)
966      ind2 = ops.IndexedSlices(ind_values, ind_indices)
967      # Feed with tuple
968      values_out, indices_out = s.run([ind_values, ind_indices], {
969          ind: (values, indices)
970      })
971      self.assertAllEqual(values_out, values)
972      self.assertAllEqual(indices_out, indices)
973      # Feed with IndexedSlicesValue
974      values_out, indices_out = s.run([ind_values, ind_indices], {
975          ind: ops.IndexedSlicesValue(values, indices, dense_shape)
976      })
977      self.assertAllEqual(values_out, values)
978      self.assertAllEqual(indices_out, indices)
979      # Feed with IndexedSlicesValue, fetch IndexedSlicesValue
980      ind2_out = s.run(ind2, {
981          ind: ops.IndexedSlicesValue(values, indices, dense_shape)
982      })
983      self.assertAllEqual(ind2_out.values, values)
984      self.assertAllEqual(ind2_out.indices, indices)
985      self.assertAllEqual(ind2_out.dense_shape, dense_shape)
986
  def testExtendWithStatelessOperations(self):
    """Stateless ops added after a run are usable by a later run."""
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      # Each entry is 1*2 + 1*2 = 4.
      c_val = s.run(c)
      self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
      # `d` and `e` are created after the first run, so the next run must
      # pick up graph additions made since then.
      d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
      e = math_ops.matmul(c, d)
      # Extend will happen here.
      e_val = s.run(e)
      # 4*1 + 4*2 + 4*3 = 24.
      self.assertAllEqual([[24.0]], e_val)
999
  def testExtendWithStatefulOperations(self):
    """A variable keeps its value across runs that extend the graph."""
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='testExtendWithStatefulOperations_v')
      v.initializer.run()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      # New ops created after `v` was initialized.
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      # Extend will happen here.
      e_val = e.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      # Extending the graph must not reset the variable.
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      # Only running the assign op changes the stored value.
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
1020
  def testExtendWithGroupBy(self):
    """A group() spanning ops from before and after an Extend runs correctly."""
    with session.Session() as s:
      a = constant_op.constant(1.0, shape=[1, 2])
      p = variables.Variable(a, name='testExtendWithGroupBy_p')
      a_val = a.eval()  # Force an Extend after this op.
      self.assertAllEqual([[1.0, 1.0]], a_val)

      b = constant_op.constant(2.0, shape=[1, 2])
      q = variables.Variable(b, name='testExtendWithGroupBy_q')
      # Extend will happen here.
      # `p` predates the previous run and `q` does not; the group depends on
      # both initializers.
      init = control_flow_ops.group(p.initializer, q.initializer)
      s.run(init)
      p_val, q_val = s.run([p, q])

      self.assertAllEqual([[1.0, 1.0]], p_val)
      self.assertAllEqual([[2.0, 2.0]], q_val)
1037
1038  def testTensorGetMethod(self):
1039    with session.Session():
1040      a = constant_op.constant(1.0, shape=[1, 2])
1041      b = constant_op.constant(2.0, shape=[2, 3])
1042      c = math_ops.matmul(a, b)
1043
1044      c_val = c.eval()
1045      self.assertAllEqual([[4.0, 4.0, 4.0]], c_val)
1046
1047      fed_c_val = c.eval(feed_dict={a.name: [[4.0, 4.0]]})
1048      self.assertAllEqual([[16.0, 16.0, 16.0]], fed_c_val)
1049
1050  @test_util.run_v1_only('b/120545219')
1051  def testOperationRunMethod(self):
1052    with session.Session():
1053      a = constant_op.constant(1.0, shape=[1, 2])
1054      b = constant_op.constant(2.0, shape=[1, 2], name='b')
1055      v = variables.VariableV1(a, a.dtype)
1056      assign_a_to_v = state_ops.assign(v, a)
1057
1058      assign_a_to_v.eval()
1059
1060      v_val = v.eval()
1061      self.assertAllEqual([[1.0, 1.0]], v_val)
1062
1063      assign_b_to_v = state_ops.assign(v, b)
1064
1065      assign_b_to_v.eval()
1066      v_val = v.eval()
1067      self.assertAllEqual([[2.0, 2.0]], v_val)
1068
1069      assign_b_to_v.eval(feed_dict={'b:0': [[3.0, 3.0]]})
1070      v_val = v.eval()
1071      self.assertAllEqual([[3.0, 3.0]], v_val)
1072
  def testDefaultGraph(self):
    """Inside `with Session()`, ops land in the session's graph."""
    with session.Session() as s:
      # Entering the session makes its graph the default graph.
      self.assertEqual(ops.get_default_graph(), s.graph)
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      self.assertEqual(ops.get_default_graph(), a.graph)
      self.assertEqual(ops.get_default_graph(), b.graph)
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='testDefaultGraph_v')
      v.initializer.run()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      # Ops added after the first runs; evaluating them extends the graph.
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      e_val = e.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      # The variable is unchanged until the assign op is run.
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
      # The default graph is still the session's graph at the end.
      self.assertEqual(ops.get_default_graph(), s.graph)
1096
  def _testDefaultGraphInThread(self, constructed_event, continue_event, i):
    """Per-thread body for testDefaultGraphWithThreads.

    Builds and runs a small graph inside a fresh Session and checks that the
    thread's default graph is the session's graph throughout.

    Args:
      constructed_event: threading.Event set once this thread's initial
        graph (matmul plus variable) has been built.
      continue_event: threading.Event this thread waits on before running
        anything, so that all threads finish construction first.
      i: Thread index, used to give each variable a distinct name.
    """
    with session.Session() as s:
      self.assertEqual(ops.get_default_graph(), s.graph)
      a = constant_op.constant(1.0, shape=[1, 2])
      b = constant_op.constant(2.0, shape=[2, 3])
      c = math_ops.matmul(a, b)
      v = variables.Variable(c, name='var_%d' % i)

      # Block here until all threads have constructed their graph.
      constructed_event.set()
      continue_event.wait()

      assign_c_to_v = state_ops.assign(v, c)
      v.initializer.run()
      assign_c_to_v.eval()
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      d = constant_op.constant(3.0, shape=[2, 3])
      e = math_ops.matmul(a, d)
      assign_e_to_v = state_ops.assign(v, e)
      e_val = e.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], e_val)
      # The variable is unchanged until the assign op is run.
      v_val = v.eval()
      self.assertAllEqual([[4.0, 4.0, 4.0]], v_val)
      s.run(assign_e_to_v)
      v_val = v.eval()
      self.assertAllEqual([[6.0, 6.0, 6.0]], v_val)
      self.assertEqual(ops.get_default_graph(), s.graph)
1125
1126  def testDefaultGraphWithThreads(self):
1127    # Fork ten threads that use their thread-local default graph.
1128    threads = []
1129    constructed_events = [threading.Event() for _ in range(10)]
1130    continue_event = threading.Event()
1131    for i, constructed_event in enumerate(constructed_events):
1132      t = self.checkedThread(
1133          target=self._testDefaultGraphInThread,
1134          args=(constructed_event, continue_event, i))
1135      threads.append(t)
1136    for t in threads:
1137      t.start()
1138    for constructed_event in constructed_events:
1139      constructed_event.wait()
1140    continue_event.set()
1141    for t in threads:
1142      t.join()
1143
1144  def testParallelRun(self):
1145    with session.Session() as sess:
1146      c = constant_op.constant(5.0)
1147      ev = threading.Event()
1148
1149      def run_step():
1150        ev.wait()
1151        val = c.eval(session=sess)
1152        self.assertEqual(val, 5.0)
1153
1154      threads = [self.checkedThread(target=run_step) for _ in range(100)]
1155      for t in threads:
1156        t.start()
1157      ev.set()
1158      for t in threads:
1159        t.join()
1160
  @staticmethod
  def _build_graph():
    """Performs a burst of construction on the current default graph.

    Used by the parallel-run tests to mutate the graph while other threads
    call `Session.run`. Each of the 10 iterations adds placeholders, a
    `while_loop`, and its gradients. The first iteration snapshots the graph
    as a GraphDef; every later iteration imports that snapshot back in.
    """
    # Random start delay so concurrent callers interleave differently.
    time.sleep(random.random() * 0.1)
    # Do some graph construction. Try to exercise non-trivial paths.
    graph = ops.get_default_graph()
    gdef = None
    for _ in range(10):
      x = array_ops.placeholder(dtype=dtypes.float32)
      # Exercise colocation constraints.
      with ops.colocate_with(x):
        y = array_ops.placeholder(dtype=dtypes.float32)
      # Exercise explicit device placement together with control flow.
      with ops.device('/cpu:0'):
        z = control_flow_ops.while_loop(
            lambda x, y: x < 10, lambda x, y: (x + 1, x * y), [x, y])
      # Presumably attaches the private `_a` attr to ops created in this
      # scope (gradient ops and imported nodes) -- NOTE(review): confirm.
      with graph._attr_scope({'_a': attr_value_pb2.AttrValue(b=False)}):
        gradients_impl.gradients(z, [x, y])
        if gdef is None:
          gdef = graph.as_graph_def()
        else:
          importer.import_graph_def(gdef, name='import')
1180
1181  @test_util.run_v1_only('b/120545219')
1182  def testParallelRunAndSingleBuild(self):
1183    with session.Session() as sess:
1184      c = constant_op.constant(5.0)
1185      stop = threading.Event()
1186
1187      def run_loop():
1188        while not stop.is_set():
1189          time.sleep(random.random() * 0.1)
1190          self.assertEqual(sess.run(c), 5.0)
1191
1192      threads = [self.checkedThread(target=run_loop) for _ in range(10)]
1193      for t in threads:
1194        t.start()
1195
1196      SessionTest._build_graph()
1197
1198      stop.set()
1199      for t in threads:
1200        t.join()
1201
1202  @test_util.run_v1_only('b/120545219')
1203  def testParallelRunAndParallelBuild(self):
1204    with session.Session() as sess:
1205      c = constant_op.constant(5.0)
1206      stop = threading.Event()
1207
1208      def run_loop():
1209        while not stop.is_set():
1210          time.sleep(random.random() * 0.1)
1211          self.assertEqual(sess.run(c), 5.0)
1212
1213      run_threads = [self.checkedThread(target=run_loop) for _ in range(10)]
1214      for t in run_threads:
1215        t.start()
1216
1217      build_threads = [self.checkedThread(target=SessionTest._build_graph)
1218                       for _ in range(10)]
1219      for t in build_threads:
1220        t.start()
1221      for t in build_threads:
1222        t.join()
1223
1224      # Let the run_threads run until the build threads are finished.
1225      stop.set()
1226      for t in run_threads:
1227        t.join()
1228
1229  def testRunFeedDict(self):
1230    with session.Session() as s:
1231      x = array_ops.zeros([2])
1232
1233      y = s.run(2 * x, feed_dict={x: np.ones(2).astype(np.float32)})
1234      self.assertAllEqual(y, 2 * np.ones(2))
1235
1236      y = s.run(2 * x, feed_dict={x.name: np.ones(2).astype(np.float32)})
1237      self.assertAllEqual(y, 2 * np.ones(2))
1238
1239      y = s.run(2 * x, feed_dict={x: [1, 1]})
1240      assert (y == 2 * np.ones(2)).all()
1241
1242      # Test nested tuple keys
1243      z = (((array_ops.zeros([2]),),), array_ops.zeros([2]),
1244           (array_ops.zeros([2]),))
1245      result = [z[0][0][0] * 2, z[1] * 2, z[2][0] * 2]
1246      values = (((np.array([1, 1]),),), np.array([2, 2]), (np.array([3, 3]),))
1247      result_value = s.run(result, feed_dict={z: values})
1248      self.assertAllEqual(result_value[0], 2 * np.ones(2))
1249      self.assertAllEqual(result_value[1], 2 * np.array([2, 2]))
1250      self.assertAllEqual(result_value[2], 2 * np.array([3, 3]))
1251
  def testGraphDef(self):
    """Session.graph_def reflects nodes as they are added to the graph."""
    with session.Session() as sess:
      # A fresh graph has only version information and no nodes.
      self.assertProtoEquals('versions { producer: %d min_consumer: %d }' %
                             (versions.GRAPH_DEF_VERSION,
                              versions.GRAPH_DEF_VERSION_MIN_CONSUMER),
                             sess.graph_def)
      # Each constant adds exactly one node.
      c = constant_op.constant(5.0, name='c')
      self.assertEqual(len(sess.graph_def.node), 1)
      d = constant_op.constant(6.0, name='d')
      self.assertEqual(len(sess.graph_def.node), 2)
      self.assertAllEqual(c, 5.0)
      self.assertAllEqual(d, 6.0)
      # Nodes can still be added after the graph has been run.
      e = constant_op.constant(7.0, name='e')
      self.assertEqual(len(sess.graph_def.node), 3)
      self.assertAllEqual(e, 7.0)
1267
1268  def testUseAfterClose(self):
1269    with session.Session() as sess:
1270      c = constant_op.constant(5.0)
1271      self.assertAllEqual(sess.run(c), 5.0)
1272    with self.assertRaisesWithPredicateMatch(
1273        RuntimeError, lambda e: 'Attempted to use a closed Session.' in str(e)):
1274      sess.run(c)
1275
1276  def testUseAfterCloseConcurrent(self):
1277    with session.Session() as sess:
1278      c = constant_op.constant(5.0)
1279      self.assertAllEqual(sess.run(c), 5.0)
1280
1281      def update_thread():
1282        with self.assertRaisesWithPredicateMatch(
1283            RuntimeError,
1284            lambda e: 'Attempted to use a closed Session.' in str(e)):
1285          while True:
1286            sess.run(c)
1287
1288      t = threading.Thread(target=update_thread)
1289      t.start()
1290      time.sleep(0.1)
1291      sess.close()
1292      t.join()
1293
1294  def testUseEmptyGraph(self):
1295    with session.Session() as sess:
1296      with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'):
1297        sess.run([])
1298      with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'):
1299        sess.run(())
1300      with self.assertRaisesRegex(RuntimeError, 'The Session graph is empty.'):
1301        sess.run({})
1302
1303  @test_util.run_v1_only('b/120545219')
1304  def testNotEntered(self):
1305    # pylint: disable=protected-access
1306    self.assertIsNone(ops._default_session_stack.get_default())
1307    # pylint: enable=protected-access
1308    with ops.device('/cpu:0'):
1309      sess = session.Session()
1310      c_1 = constant_op.constant(5.0)
1311      with sess.graph.as_default():
1312        c_2 = constant_op.constant(5.0)
1313      self.assertEqual(c_1.graph, c_2.graph)
1314      self.assertEqual(sess.run(c_2), 5.0)
1315      with self.assertRaisesWithPredicateMatch(
1316          ValueError, lambda e: 'No default session is registered.' in str(e)):
1317        c_2.eval()
1318
1319  @test_util.run_v1_only('b/120545219')
1320  def testInteractive(self):
1321    with ops.device('/cpu:0'):
1322      sess = session.InteractiveSession()
1323      a = constant_op.constant(1.0, shape=[1, 2])
1324      b = constant_op.constant(2.0, shape=[2, 3])
1325      c = math_ops.matmul(a, b)
1326      self.assertAllEqual([[4.0, 4.0, 4.0]], c)
1327      d = constant_op.constant([1.0, 2.0, 3.0], shape=[3, 1])
1328      e = math_ops.matmul(c, d)
1329      self.assertAllEqual([[24.0]], e)
1330      sess.close()
1331
  @test_util.run_v1_only('b/120545219')
  def testMultipleInteractiveSessionsWarning(self):
    """Only a second concurrently open InteractiveSession emits a warning."""
    # Reinitialize the global state to ensure that the expected warnings will
    # be emitted.
    session.InteractiveSession._active_session_count = 0  # pylint: disable=protected-access

    sess = session.InteractiveSession()
    sess.run(constant_op.constant(4.0))  # Run so that the session is "opened".
    sess.close()
    # Opening and closing interactive sessions serially should not warn.
    with warnings.catch_warnings(record=True) as w:
      sess = session.InteractiveSession()
      sess.close()
    self.assertEqual(0, len(w))

    # Opening one session while none is active should not warn either.
    with warnings.catch_warnings(record=True) as w:
      sess = session.InteractiveSession()
    self.assertEqual(0, len(w))
    # Opening a second one while `sess` is still open warns exactly once.
    # NOTE(review): no warnings.simplefilter('always') is installed, so a
    # repeat of this test in the same process may have the warning filtered
    # by the default "once per location" rule -- confirm if this flakes.
    with warnings.catch_warnings(record=True) as w:
      sess2 = session.InteractiveSession()
    self.assertEqual(1, len(w))
    self.assertIn('An interactive session is already active. This can cause '
                  'out-of-memory errors in some cases. You must explicitly '
                  'call `InteractiveSession.close()` to release resources '
                  'held by the other session(s).', str(w[0].message))
    sess2.close()
    sess.close()
1359
1360  @test_util.run_v1_only('b/120545219')
1361  def testInteractivePlacePrunedGraph(self):
1362    sess = session.InteractiveSession()
1363
1364    # Build a graph that has a bad op in it (no kernel).
1365    #
1366    # This test currently does not link in any GPU kernels,
1367    # which is why placing this is invalid.  If at some point
1368    # GPU kernels are added to this test, some other different
1369    # op / device combo should be chosen.
1370    with ops.device('/device:GPU:0'):
1371      a = constant_op.constant(1.0, shape=[1, 2])
1372
1373    b = constant_op.constant(1.0, shape=[1, 2])
1374
1375    # Only run the valid op, this should work.
1376    b.eval()
1377
1378    with self.assertRaises(errors.InvalidArgumentError):
1379      a.eval()
1380    sess.close()
1381
1382  @test_util.run_v1_only('b/120545219')
1383  def testDefaultSessionPlacePrunedGraph(self):
1384    sess = session.Session()
1385
1386    # Build a graph that has a bad op in it (no kernel).
1387    #
1388    # This test currently does not link in any GPU kernels,
1389    # which is why placing this is invalid.  If at some point
1390    # GPU kernels are added to this test, some other different
1391    # op / device combo should be chosen.
1392    with ops.device('/device:GPU:0'):
1393      _ = constant_op.constant(1.0, shape=[1, 2])
1394
1395    b = constant_op.constant(1.0, shape=[1, 2])
1396
1397    with self.assertRaises(errors.InvalidArgumentError):
1398      # Even though we don't run the bad op, we place the entire
1399      # graph, which should fail with a non-interactive session.
1400      sess.run(b)
1401
1402    sess.close()
1403
1404  def testSharedGraph(self):
1405    with ops.Graph().as_default() as g, ops.device('/cpu:0'):
1406      a = constant_op.constant(1.0, shape=[1, 2])
1407      b = constant_op.constant(2.0, shape=[2, 3])
1408      c = math_ops.matmul(a, b)
1409
1410    with session.Session(graph=g) as sess1:
1411      with session.Session(graph=g) as sess2:
1412        self.assertAllEqual(sess1.run(c), sess2.run(c))
1413
1414  def testDuplicatedInputs(self):
1415    with session.Session() as sess:
1416      a = constant_op.constant(1.0, shape=[1, 2])
1417      b = constant_op.constant(2.0, shape=[1, 3])
1418      a_val, b_val, a2_val = sess.run([a, b, a])
1419      self.assertAllEqual(a_val, [[1.0, 1.0]])
1420      self.assertAllEqual(b_val, [[2.0, 2.0, 2.0]])
1421      self.assertAllEqual(a2_val, [[1.0, 1.0]])
1422
1423  def testFeedAndFetch(self):
1424    with session.Session() as sess:
1425      for dtype in [
1426          dtypes.float16, dtypes.float32, dtypes.float64, dtypes.int32,
1427          dtypes.uint8, dtypes.int16, dtypes.int8, dtypes.int64, dtypes.bool,
1428          dtypes.complex64, dtypes.complex128
1429      ]:
1430        for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
1431          np_dtype = dtype.as_numpy_dtype
1432
1433          feed_t = array_ops.placeholder(dtype=dtype, shape=shape)
1434          out_t = array_ops.identity(feed_t)
1435
1436          np_array = np.random.randint(-10, 10, shape)
1437
1438          if dtype == dtypes.bool:
1439            np_array = np_array > 0
1440          elif dtype == dtypes.complex64:
1441            np_array = np.sqrt(np_array.astype(np_dtype))
1442          elif dtype == dtypes.complex64:
1443            np_array = np.sqrt(np_array.astype(np_dtype))
1444          else:
1445            np_array = np_array.astype(np_dtype)
1446
1447          self.assertAllEqual(np_array,
1448                              sess.run(out_t, feed_dict={
1449                                  feed_t: np_array
1450                              }))
1451          # Check that we can also get the feed back.
1452          self.assertAllEqual(np_array,
1453                              sess.run(feed_t, feed_dict={
1454                                  feed_t: np_array
1455                              }))
1456          # Also check that we can get both back.
1457          out_v, feed_v = sess.run(
1458              [out_t, feed_t], feed_dict={
1459                  feed_t: np_array
1460              })
1461          self.assertAllEqual(np_array, out_v)
1462          self.assertAllEqual(np_array, feed_v)
1463
1464          feed_fetch_runner = sess.make_callable([out_t, feed_t], [feed_t])
1465          out_v, feed_v = feed_fetch_runner(np_array)
1466          self.assertAllEqual(np_array, out_v)
1467          self.assertAllEqual(np_array, feed_v)
1468
1469  def testMakeCallableOnTensorWithRunOptions(self):
1470    with session.Session() as sess:
1471      a = constant_op.constant(42.0)
1472      tensor_runner = sess.make_callable(a, accept_options=True)
1473      run_options = config_pb2.RunOptions(
1474          trace_level=config_pb2.RunOptions.FULL_TRACE)
1475      run_metadata = config_pb2.RunMetadata()
1476      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
1477      res = tensor_runner(options=run_options, run_metadata=run_metadata)
1478      self.assertEqual(42.0, res)
1479      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1480
1481  def testMakeCallableOnOperationWithRunOptions(self):
1482    with session.Session() as sess:
1483      a = variables.Variable(42.0)
1484      b = state_ops.assign_add(a, 1.0)
1485      sess.run(a.initializer)
1486      tensor_runner = sess.make_callable(b.op, accept_options=True)
1487      run_options = config_pb2.RunOptions(
1488          trace_level=config_pb2.RunOptions.FULL_TRACE)
1489      run_metadata = config_pb2.RunMetadata()
1490      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
1491      tensor_runner(options=run_options, run_metadata=run_metadata)
1492      self.assertEqual(43.0, sess.run(a))
1493      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1494
1495  def testMakeCallableWithFeedListAndRunOptions(self):
1496    with session.Session() as sess:
1497      ph = array_ops.placeholder(dtypes.float32)
1498      a = math_ops.add(ph, 1.0)
1499      tensor_runner = sess.make_callable(
1500          a, feed_list=[ph.name], accept_options=True)
1501      run_options = config_pb2.RunOptions(
1502          trace_level=config_pb2.RunOptions.FULL_TRACE)
1503      run_metadata = config_pb2.RunMetadata()
1504      self.assertEqual(0, len(run_metadata.step_stats.dev_stats))
1505      self.assertAllClose(42.0,
1506                          tensor_runner(
1507                              41.0,
1508                              options=run_options,
1509                              run_metadata=run_metadata))
1510      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1511
1512  def testOptimizedMakeCallable(self):
1513    with session.Session() as sess:
1514      ph = array_ops.placeholder(dtypes.float32)
1515      a = math_ops.add(ph, 1.0)
1516      callable_opts = config_pb2.CallableOptions()
1517      callable_opts.feed.append(ph.name)
1518      callable_opts.fetch.append(a.name)
1519      for _ in range(3):
1520        callable_fn = sess._make_callable_from_options(callable_opts)
1521        for _ in range(5):
1522          self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32)))
1523
1524  def testOptimizedMakeCallableWithRunMetadata(self):
1525    with session.Session() as sess:
1526      ph = array_ops.placeholder(dtypes.float32)
1527      a = math_ops.add(ph, 1.0)
1528      callable_opts = config_pb2.CallableOptions()
1529      callable_opts.feed.append(ph.name)
1530      callable_opts.fetch.append(a.name)
1531      callable_opts.run_options.trace_level = config_pb2.RunOptions.FULL_TRACE
1532      callable_fn = sess._make_callable_from_options(callable_opts)
1533      run_metadata = config_pb2.RunMetadata()
1534      self.assertEqual([2.0], callable_fn(np.array(1.0, dtype=np.float32),
1535                                          run_metadata=run_metadata))
1536      self.assertGreater(len(run_metadata.step_stats.dev_stats), 0)
1537
1538  def testFeedError(self):
1539    with session.Session() as sess:
1540      feed_t = array_ops.placeholder(dtype=dtypes.float32)
1541      out_t = array_ops.identity(feed_t)
1542      feed_val = constant_op.constant(5.0)
1543      with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
1544        sess.run(out_t, feed_dict={feed_t: feed_val})
1545      with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
1546        out_t.eval(feed_dict={feed_t: feed_val})
1547      with self.assertRaisesRegex(TypeError, 'cannot be a tf.Tensor object'):
1548        out_t.op.run(feed_dict={feed_t: feed_val})
1549
1550  def testFeedPrecisionLossError(self):
1551    with session.Session() as sess:
1552      largest_int64 = np.iinfo(np.int64).max
1553
1554      feed_int_implicit_int32 = constant_op.constant(1)
1555      feed_int_explicit_int32 = constant_op.constant(1, dtype=dtypes.int32)
1556
1557      out_t = constant_op.constant(1.0)
1558
1559      with self.assertRaisesRegex(TypeError,
1560                                  'is not compatible with Tensor type'):
1561        sess.run(out_t, feed_dict={feed_int_implicit_int32: largest_int64})
1562      with self.assertRaisesRegex(TypeError,
1563                                  'is not compatible with Tensor type'):
1564        sess.run(out_t, feed_dict={feed_int_explicit_int32: largest_int64})
1565
1566  def testStringFetch(self):
1567    with session.Session():
1568      for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
1569        size = 1
1570        for s in shape:
1571          size *= s
1572        c_list = np.array(
1573            [compat.as_bytes(str(i)) for i in xrange(size)],
1574            dtype=np.object).reshape(shape) if size > 0 else []
1575        c = constant_op.constant(c_list)
1576        self.assertAllEqual(c, c_list)
1577
1578  def testStringFeed(self):
1579    with session.Session() as sess:
1580      for shape in [(32, 4, 128), (37,), (2, 0, 6), (0, 0, 0)]:
1581        size = 1
1582        for s in shape:
1583          size *= s
1584        c_list = np.array(
1585            [compat.as_bytes(str(i)) for i in xrange(size)],
1586            dtype=np.object).reshape(shape)
1587        feed_t = array_ops.placeholder(dtype=dtypes.string, shape=shape)
1588        c = array_ops.identity(feed_t)
1589        self.assertAllEqual(sess.run(c, feed_dict={feed_t: c_list}), c_list)
1590        self.assertAllEqual(
1591            sess.run(feed_t, feed_dict={
1592                feed_t: c_list
1593            }), c_list)
1594        c_v, feed_v = sess.run([c, feed_t], feed_dict={feed_t: c_list})
1595        self.assertAllEqual(c_v, c_list)
1596        self.assertAllEqual(feed_v, c_list)
1597
1598  def testStringFeedWithNullCharacters(self):
1599    with session.Session():
1600      c_list = [b'\n\x01\x00', b'\n\x00\x01']
1601      feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[2])
1602      c = array_ops.identity(feed_t)
1603      out = c.eval(feed_dict={feed_t: c_list})
1604      self.assertEqual(c_list[0], out[0])
1605      self.assertEqual(c_list[1], out[1])
1606
1607  def testStringFeedWithUnicode(self):
1608    with session.Session():
1609      c_list = [
1610          u'\n\x01\x00', u'\n\x00\x01', u'\u26a3 unicode',
1611          u'\U0001f60e deal with it'
1612      ]
1613      feed_t = array_ops.placeholder(dtype=dtypes.string, shape=[len(c_list)])
1614      c = array_ops.identity(feed_t)
1615
1616      out = c.eval(feed_dict={feed_t: c_list})
1617      for i in range(len(c_list)):
1618        self.assertEqual(c_list[i], out[i].decode('utf-8'))
1619
1620      out = c.eval(feed_dict={feed_t: np.array(c_list, dtype=np.object)})
1621      for i in range(len(c_list)):
1622        self.assertEqual(c_list[i], out[i].decode('utf-8'))
1623
1624  def testInvalidTargetFails(self):
1625    with self.assertRaisesRegex(
1626        errors.NotFoundError,
1627        'No session factory registered for the given session options'):
1628      session.Session('INVALID_TARGET')
1629
1630  def testFetchByNameDifferentStringTypes(self):
1631    with session.Session() as sess:
1632      c = constant_op.constant(42.0, name='c')
1633      d = constant_op.constant(43.0, name=u'd')
1634      e = constant_op.constant(44.0, name=b'e')
1635      f = constant_op.constant(45.0, name=r'f')
1636
1637      self.assertIsInstance(c.name, six.text_type)
1638      self.assertIsInstance(d.name, six.text_type)
1639      self.assertIsInstance(e.name, six.text_type)
1640      self.assertIsInstance(f.name, six.text_type)
1641
1642      self.assertEqual(42.0, sess.run('c:0'))
1643      self.assertEqual(42.0, sess.run(u'c:0'))
1644      self.assertEqual(42.0, sess.run(b'c:0'))
1645      self.assertEqual(42.0, sess.run(r'c:0'))
1646
1647      self.assertEqual(43.0, sess.run('d:0'))
1648      self.assertEqual(43.0, sess.run(u'd:0'))
1649      self.assertEqual(43.0, sess.run(b'd:0'))
1650      self.assertEqual(43.0, sess.run(r'd:0'))
1651
1652      self.assertEqual(44.0, sess.run('e:0'))
1653      self.assertEqual(44.0, sess.run(u'e:0'))
1654      self.assertEqual(44.0, sess.run(b'e:0'))
1655      self.assertEqual(44.0, sess.run(r'e:0'))
1656
1657      self.assertEqual(45.0, sess.run('f:0'))
1658      self.assertEqual(45.0, sess.run(u'f:0'))
1659      self.assertEqual(45.0, sess.run(b'f:0'))
1660      self.assertEqual(45.0, sess.run(r'f:0'))
1661
1662  def testIncorrectGraph(self):
1663    with ops.Graph().as_default() as g_1:
1664      c_1 = constant_op.constant(1.0, name='c')
1665
1666    with ops.Graph().as_default() as g_2:
1667      c_2 = constant_op.constant(2.0, name='c')
1668
1669    self.assertEqual('c', c_1.op.name)
1670    self.assertEqual('c', c_2.op.name)
1671
1672    with session.Session(graph=g_1) as sess_1:
1673      self.assertEqual(1.0, sess_1.run(c_1))
1674      with self.assertRaises(ValueError):
1675        sess_1.run(c_2)
1676      with self.assertRaises(ValueError):
1677        sess_1.run(c_2.op)
1678
1679    with session.Session(graph=g_2) as sess_2:
1680      with self.assertRaises(ValueError):
1681        sess_2.run(c_1)
1682      with self.assertRaises(ValueError):
1683        sess_2.run(c_1.op)
1684      self.assertEqual(2.0, sess_2.run(c_2))
1685
1686  def testFeedDictKeyException(self):
1687    with session.Session() as sess:
1688      a = constant_op.constant(1.0, dtypes.float32, name='a')
1689      with self.assertRaisesRegex(TypeError, 'Cannot interpret feed_dict'):
1690        sess.run(a, feed_dict={'a': [2.0]})
1691
1692  def testPerStepTrace(self):
1693    run_options = config_pb2.RunOptions(
1694        trace_level=config_pb2.RunOptions.SOFTWARE_TRACE)
1695    run_metadata = config_pb2.RunMetadata()
1696
1697    with ops.device('/cpu:0'):
1698      with session.Session() as sess:
1699        sess.run(constant_op.constant(1.0))
1700        self.assertFalse(run_metadata.HasField('step_stats'))
1701
1702        sess.run(constant_op.constant(1.0), run_metadata=run_metadata)
1703        self.assertFalse(run_metadata.HasField('step_stats'))
1704
1705        sess.run(
1706            constant_op.constant(1.0),
1707            options=run_options,
1708            run_metadata=run_metadata)
1709
1710        self.assertTrue(run_metadata.HasField('step_stats'))
1711        self.assertEqual(len(run_metadata.step_stats.dev_stats), 1)
1712
1713  def testRunOptionsRunMetadata(self):
1714    run_options = config_pb2.RunOptions(
1715        trace_level=config_pb2.RunOptions.SOFTWARE_TRACE)
1716    run_metadata = config_pb2.RunMetadata()
1717
1718    with ops.device('/cpu:0'):
1719      with session.Session() as sess:
1720        # all combinations are valid
1721        sess.run(constant_op.constant(1.0), options=None, run_metadata=None)
1722        sess.run(
1723            constant_op.constant(1.0), options=None, run_metadata=run_metadata)
1724        self.assertFalse(run_metadata.HasField('step_stats'))
1725
1726        sess.run(
1727            constant_op.constant(1.0), options=run_options, run_metadata=None)
1728        self.assertFalse(run_metadata.HasField('step_stats'))
1729
1730        sess.run(
1731            constant_op.constant(1.0),
1732            options=run_options,
1733            run_metadata=run_metadata)
1734
1735        self.assertTrue(run_metadata.HasField('step_stats'))
1736        self.assertEqual(len(run_metadata.step_stats.dev_stats), 1)
1737
1738  def testFeedShapeCompatibility(self):
1739    with session.Session() as sess:
1740      some_tensor = constant_op.constant([2.0, 2.0, 2.0, 2.0])
1741      new_shape = constant_op.constant([2, 2])
1742      reshaped_tensor = array_ops.reshape(some_tensor, new_shape)
1743
1744      with self.assertRaisesRegex(ValueError, 'Cannot feed value of shape'):
1745        sess.run(reshaped_tensor, feed_dict={some_tensor: [1.0, 2.0, 3.0]})
1746
1747      with self.assertRaisesRegex(
1748          errors.InvalidArgumentError,
1749          'Input to reshape is a tensor with 4 values, '
1750          'but the requested shape has 21'):
1751        sess.run(reshaped_tensor, feed_dict={new_shape: [3, 7]})
1752
1753  def testInferShapesFalse(self):
1754    with ops.Graph().as_default(), ops.device('/cpu:0'):
1755      a = constant_op.constant([[1, 2]])
1756      sess = session.Session()
1757      self.assertNotIn('_output_shapes', sess.graph_def.node[0].attr)
1758      # Avoid lint error regarding 'unused' var a.
1759      self.assertEqual(a, a)
1760
1761  def testInferShapesTrue(self):
1762    config_pb = config_pb2.ConfigProto(
1763        graph_options=config_pb2.GraphOptions(infer_shapes=True))
1764    with ops.Graph().as_default(), ops.device('/cpu:0'):
1765      a = constant_op.constant([[1, 2]])
1766      sess = session.Session(config=config_pb)
1767      self.assertIn('_output_shapes', sess.graph_def.node[0].attr)
1768      # Avoid lint error regarding 'unused' var a.
1769      self.assertEqual(a, a)
1770
1771  def testBuildCostModel(self):
1772    run_options = config_pb2.RunOptions()
1773    config_pb = config_pb2.ConfigProto(
1774        allow_soft_placement=True,
1775        graph_options=config_pb2.GraphOptions(build_cost_model=100))
1776    with session.Session(config=config_pb) as sess:
1777      with ops.device('/device:GPU:0'):
1778        a = array_ops.placeholder(dtypes.float32, shape=[])
1779        b = math_ops.add(a, a)
1780        c = array_ops.identity(b)
1781        d = math_ops.multiply(c, c)
1782      for step in xrange(120):
1783        run_metadata = config_pb2.RunMetadata()
1784        sess.run(
1785            d,
1786            feed_dict={a: 1.0},
1787            options=run_options,
1788            run_metadata=run_metadata)
1789        if step == 99:
1790          self.assertTrue(run_metadata.HasField('cost_graph'))
1791        else:
1792          self.assertFalse(run_metadata.HasField('cost_graph'))
1793
1794  def runTestOutputPartitionGraphs(self, sess):
1795    run_options = config_pb2.RunOptions(output_partition_graphs=True)
1796    a = constant_op.constant(1)
1797    run_metadata = config_pb2.RunMetadata()
1798    sess.run(a, options=run_options, run_metadata=run_metadata)
1799    self.assertGreater(len(run_metadata.partition_graphs), 0)
1800    sess.run(a, run_metadata=run_metadata)
1801    self.assertEqual(len(run_metadata.partition_graphs), 0)
1802
1803  @test_util.run_v1_only('b/120545219')
1804  def testOutputPartitionGraphsDirect(self):
1805    self.runTestOutputPartitionGraphs(session.Session())
1806
1807  @test_util.run_v1_only('b/120545219')
1808  def testOutputPartitionGraphsDistributed(self):
1809    server = server_lib.Server.create_local_server()
1810    self.runTestOutputPartitionGraphs(session.Session(server.target))
1811
1812  def testNonInteractiveSessionNesting(self):
1813    sess1 = session.Session()
1814    sess1_controller = sess1.as_default()
1815    sess1_controller.__enter__()
1816
1817    sess2 = session.Session()
1818    sess2_controller = sess2.as_default()
1819    sess2_controller.__enter__()
1820
1821    with self.assertRaisesRegex(AssertionError, 'Nesting violated'):
1822      sess1_controller.__exit__(None, None, None)
1823
1824    ops._default_session_stack.reset()
1825
1826  def testInteractiveSessionNesting(self):
1827    sess1 = session.InteractiveSession()
1828    sess2 = session.InteractiveSession()
1829    del sess1
1830    del sess2
1831
1832  @test_util.run_v1_only('b/120545219')
1833  def testAsDefault(self):
1834    c = constant_op.constant(37)
1835    sess = session.Session()
1836    with sess.as_default():
1837      self.assertEqual(37, c.eval())
1838
1839    # Ensure that the session remains valid even when it is not captured.
1840    with session.Session().as_default():
1841      self.assertEqual(37, c.eval())
1842
1843  def testReentry(self):
1844    sess = session.Session()
1845    with self.assertRaisesRegex(RuntimeError, 'not re-entrant'):
1846      with sess:
1847        with sess:
1848          pass
1849
1850  def testInvalidArgument(self):
1851    with self.assertRaisesRegex(TypeError, 'target must be a string'):
1852      session.Session(37)
1853    with self.assertRaisesRegex(TypeError, 'config must be a tf.ConfigProto'):
1854      session.Session(config=37)
1855    with self.assertRaisesRegex(TypeError, 'graph must be a tf.Graph'):
1856      session.Session(graph=37)
1857
1858  @test_util.run_v1_only('b/120545219')
1859  def testTimeoutWithShortOperations(self):
1860    num_epochs = 5
1861    q = data_flow_ops.FIFOQueue(capacity=50, dtypes=[dtypes.int32], shapes=[()])
1862    enqueue_op = q.enqueue_many(constant_op.constant([1, 2]))
1863
1864    # Use a 10-second timeout, which should be longer than any
1865    # non-blocking enqueue_many op.
1866    config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=10000)
1867    with session.Session(config=config_pb) as sess:
1868      for _ in range(num_epochs):
1869        sess.run(enqueue_op)
1870      self.assertEqual(sess.run(q.size()), num_epochs * 2)
1871
1872  @test_util.run_v1_only('b/120545219')
1873  def testRegisterFetchAndFeedConversionFunctions(self):
1874
1875    class SquaredTensor(object):
1876
1877      def __init__(self, tensor):
1878        self.sq = math_ops.square(tensor)
1879
1880    fetch_fn = lambda squared_tensor: ([squared_tensor.sq], lambda val: val[0])
1881    feed_fn1 = lambda feed, feed_val: [(feed.sq, feed_val)]
1882    feed_fn2 = lambda feed: [feed.sq]
1883
1884    session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
1885                                                      feed_fn1, feed_fn2)
1886    with self.assertRaises(ValueError):
1887      session.register_session_run_conversion_functions(SquaredTensor, fetch_fn,
1888                                                        feed_fn1, feed_fn2)
1889    with self.cached_session() as sess:
1890      np1 = np.array([1.0, 1.5, 2.0, 2.5])
1891      np2 = np.array([3.0, 3.5, 4.0, 4.5])
1892      squared_tensor = SquaredTensor(np2)
1893      squared_eval = sess.run(squared_tensor)
1894      self.assertAllClose(np2 * np2, squared_eval)
1895      squared_eval = sess.run(
1896          squared_tensor, feed_dict={
1897              squared_tensor: np1 * np1
1898          })
1899      self.assertAllClose(np1 * np1, squared_eval)
1900      partial_run = sess.partial_run_setup([squared_tensor], [])
1901      squared_eval = sess.partial_run(partial_run, squared_tensor)
1902      self.assertAllClose(np2 * np2, squared_eval)
1903
1904  def testDefaultLogDevicePlacement(self):
1905
1906    class CaptureStderr(str):
1907      """Class to capture stderr from C++ shared library."""
1908
1909      def __enter__(self):
1910        self._esc = compat.as_str('\b')
1911        self._output = compat.as_str('')
1912        self._stderr = sys.stderr
1913        self._fd = self._stderr.fileno()
1914        self._out_pipe, in_pipe = os.pipe()
1915        # Save the original io stream.
1916        self._dup_fd = os.dup(self._fd)
1917        # Replace the original io stream with in pipe.
1918        os.dup2(in_pipe, self._fd)
1919        return self
1920
1921      def __exit__(self, *args):
1922        self._stderr.write(self._esc)
1923        self._stderr.flush()
1924        self.read()
1925        os.close(self._out_pipe)
1926        # Restore the original io stream.
1927        os.dup2(self._dup_fd, self._fd)
1928
1929      def read(self):
1930        while True:
1931          data = os.read(self._out_pipe, 1)
1932          if not data or compat.as_str(data) == self._esc:
1933            break
1934          self._output += compat.as_str(data)
1935
1936      def __str__(self):
1937        return self._output
1938
1939    context.set_log_device_placement(True)
1940    if context.executing_eagerly():
1941      with CaptureStderr() as log:
1942        a = constant_op.constant(1)
1943        b = constant_op.constant(2)
1944        c = a + b
1945        # Ensure if the same kernel with the same arguments is executed then its
1946        # execution is logged.
1947        d = a + b
1948    else:
1949      # Passing the config to the server, but not the session should still
1950      # result in logging device placement.
1951      config_pb = config_pb2.ConfigProto(log_device_placement=True)
1952      server = server_lib.Server.create_local_server(config=config_pb)
1953      a = constant_op.constant(1)
1954      b = constant_op.constant(2)
1955      c = a + b
1956      d = a + b
1957      with session.Session(server.target) as sess:
1958        with CaptureStderr() as log:
1959          c, d = sess.run([c, d])
1960
1961    self.assertEqual(c, 3)
1962    self.assertEqual(d, 3)
1963    # Ensure that we did log device placement.
1964    add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
1965    self.assertEqual(len(add_executions), 2)
1966
1967    @def_function.function
1968    def fn(a, b):
1969      c = a + b
1970      # These two AddV2 cannot use the same argument in tf.function since an
1971      # optimization pass will remove duplicate ops and only run it once.
1972      d = a + c
1973      return c, d
1974
1975    with CaptureStderr() as log:
1976      c, d = self.evaluate(fn(constant_op.constant(1), constant_op.constant(2)))
1977    self.assertEqual(c, 3)
1978    self.assertEqual(d, 4)
1979    # Ensure that we did log device placement.
1980    add_executions = [l for l in str(log).splitlines() if 'AddV2' in l]
1981    self.assertEqual(len(add_executions), 2)
1982
1983  @test_util.run_v1_only('b/120545219')
1984  def testLocalMasterSessionTimeout(self):
1985    # Test that the timeout passed in a config to the session works correctly.
1986    config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=1000)
1987    server = server_lib.Server.create_local_server()
1988    q = data_flow_ops.FIFOQueue(1, dtypes.float32)
1989    dequeued_t = q.dequeue()
1990
1991    with session.Session(server.target, config=config_pb) as sess:
1992      # Intentionally do not run any enqueue_ops so that dequeue will block
1993      # until operation_timeout_in_ms.
1994      with self.assertRaises(errors.DeadlineExceededError):
1995        sess.run(dequeued_t)
1996
1997  @test_util.run_v1_only('b/120545219')
1998  def testDefaultServerTimeout(self):
1999    # Test that the default server config timeout gets used when no Session
2000    # config is provided.
2001    config_pb = config_pb2.ConfigProto(operation_timeout_in_ms=1000)
2002    server = server_lib.Server.create_local_server(config=config_pb)
2003    q = data_flow_ops.FIFOQueue(1, dtypes.float32)
2004    dequeued_t = q.dequeue()
2005
2006    with session.Session(server.target) as sess:
2007      # Intentionally do not run any enqueue_ops so that dequeue will block
2008      # until operation_timeout_in_ms.
2009      with self.assertRaises(errors.DeadlineExceededError):
2010        sess.run(dequeued_t)
2011
2012  def runTestBuildGraphError(self, sess):
2013    # Ensure that errors from building the graph get propagated.
2014    data = array_ops.placeholder(dtypes.float32, shape=[])
2015    # pylint: disable=protected-access
2016    enter_1 = gen_control_flow_ops.enter(data, 'foo_1', False)
2017    enter_2 = gen_control_flow_ops.enter(data, 'foo_2', False)
2018    # pylint: enable=protected-access
2019    res = math_ops.add(enter_1, enter_2)
2020    with self.assertRaisesOpError('has inputs from different frames'):
2021      sess.run(res, feed_dict={data: 1.0})
2022
2023  @test_util.run_v1_only('b/120545219')
2024  def testBuildGraphErrorDirect(self):
2025    self.runTestBuildGraphError(session.Session())
2026
2027  @test_util.run_v1_only('b/120545219')
2028  def testBuildGraphErrorDist(self):
2029    server = server_lib.Server.create_local_server()
2030    self.runTestBuildGraphError(session.Session(server.target))
2031
2032  def testDeviceAttributes(self):
2033    attrs = session._DeviceAttributes(
2034        '/job:worker/replica:0/task:3/device:CPU:2', 'TYPE', 1337, 1000000)
2035    self.assertEqual(1337, attrs.memory_limit_bytes)
2036    self.assertEqual('/job:worker/replica:0/task:3/device:CPU:2', attrs.name)
2037    self.assertEqual('TYPE', attrs.device_type)
2038    self.assertEqual(1000000, attrs.incarnation)
2039    str_repr = '%s' % attrs
2040    self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
2041
2042  def testDeviceAttributesCanonicalization(self):
2043    attrs = session._DeviceAttributes('/job:worker/replica:0/task:3/cpu:1',
2044                                      'TYPE', 1337, 1000000)
2045    self.assertEqual(1337, attrs.memory_limit_bytes)
2046    self.assertEqual('/job:worker/replica:0/task:3/device:CPU:1', attrs.name)
2047    self.assertEqual('TYPE', attrs.device_type)
2048    self.assertEqual(1000000, attrs.incarnation)
2049    str_repr = '%s' % attrs
2050    self.assertTrue(str_repr.startswith('_DeviceAttributes'), str_repr)
2051
2052  def runTestAddFunctionToSession(self, target=''):
2053    """Add a function to a session after the graph has already been run."""
2054
2055    @function.Defun(dtypes.float32)
2056    def foo(x):
2057      return x + 1
2058
2059    x = constant_op.constant(1.0)
2060    with session.Session(target=target) as sess:
2061      sess.run(x)
2062      f = foo(x)
2063      result = sess.run(f)
2064      self.assertEqual(result, 2.0)
2065
2066  @test_util.run_v1_only('b/120545219')
2067  def testAddFunctionToSession(self):
2068    self.runTestAddFunctionToSession()
2069
2070  @test_util.run_v1_only('b/120545219')
2071  def testAddFunctionToGrpcSession(self):
2072    server = server_lib.Server.create_local_server()
2073    self.runTestAddFunctionToSession(server.target)
2074
2075  def testOpenAndCloseGrpcSession(self):
2076    server = server_lib.Server.create_local_server()
2077    with session.Session(server.target):
2078      pass
2079
2080  def testOpenAndCloseSession(self):
2081    with session.Session():
2082      pass
2083
2084  @test_util.run_v1_only('b/120545219')
2085  def testAutoConvertAndCheckData(self):
2086    with self.cached_session() as sess:
2087      a = array_ops.placeholder(dtype=dtypes.string)
2088      with self.assertRaisesRegex(
2089          TypeError, r'Type of feed value 1 with type <(\w+) \'int\'> is not'):
2090        sess.run(a, feed_dict={a: 1})
2091
2092  @test_util.run_v1_only('b/120545219')
2093  def testOptimizerOptions(self):
2094    config.set_optimizer_experimental_options({'min_graph_nodes': -1})
2095
2096    with ops.Graph().as_default():
2097      sess = session.Session()
2098      self.assertEqual(
2099          sess._config.graph_options.rewrite_options.min_graph_nodes, -1)
2100
2101
if __name__ == '__main__':
  # Entry point when this file is executed directly: hand control to the
  # googletest runner.
  googletest.main()
2104