1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=invalid-name
17"""Test utils for tensorflow."""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
import collections
from collections import OrderedDict
import contextlib
import functools
import gc
import itertools
import math
import os
import random
import re
import tempfile
import threading
import unittest
34
35import numpy as np
36import six
37
38_portpicker_import_error = None
39try:
40  import portpicker  # pylint: disable=g-import-not-at-top
41except ImportError as _error:
42  _portpicker_import_error = _error
43  portpicker = None
44
45# pylint: disable=g-import-not-at-top
46from google.protobuf import descriptor_pool
47from google.protobuf import text_format
48
49from tensorflow.core.framework import graph_pb2
50from tensorflow.core.protobuf import config_pb2
51from tensorflow.core.protobuf import rewriter_config_pb2
52from tensorflow.python import pywrap_tensorflow
53from tensorflow.python import tf2
54from tensorflow.python.client import device_lib
55from tensorflow.python.client import session
56from tensorflow.python.eager import context
57from tensorflow.python.eager import def_function
58from tensorflow.python.eager import tape
59from tensorflow.python.framework import device as pydev
60from tensorflow.python.framework import dtypes
61from tensorflow.python.framework import errors
62from tensorflow.python.framework import errors_impl
63from tensorflow.python.framework import importer
64from tensorflow.python.framework import ops
65from tensorflow.python.framework import random_seed
66from tensorflow.python.framework import sparse_tensor
67from tensorflow.python.framework import tensor_shape
68from tensorflow.python.framework import versions
69from tensorflow.python.ops import array_ops
70from tensorflow.python.ops import control_flow_util
71from tensorflow.python.ops import script_ops
72from tensorflow.python.ops import variables
73from tensorflow.python.platform import googletest
74from tensorflow.python.platform import tf_logging as logging
75from tensorflow.python.training import server_lib
76from tensorflow.python.util import compat
77from tensorflow.python.util import deprecation
78from tensorflow.python.util import nest
79from tensorflow.python.util import tf_decorator
80from tensorflow.python.util import tf_inspect
81from tensorflow.python.util.protobuf import compare
82from tensorflow.python.util.tf_export import tf_export
83
84
# If the import below is made available through the BUILD rule, then this
# function is overridden and will instead return True and cause TensorFlow
# graphs to be compiled with XLA.
def is_xla_enabled():
  """Returns whether test graphs should be compiled with XLA.

  This default always returns False. When the module
  `tensorflow.python.framework.is_xla_test_true` is available (e.g. provided
  through the BUILD rule), the import below replaces this function with the
  version defined there.
  """
  return False


try:
  from tensorflow.python.framework.is_xla_test_true import is_xla_enabled  # pylint: disable=g-import-not-at-top
except ImportError:
  # The XLA variant of this module is optional; keep the default above.
  # Only ImportError is expected here, so other failures (and SystemExit /
  # KeyboardInterrupt) are no longer silently swallowed.
  pass
96
97
@tf_export("test.gpu_device_name")
def gpu_device_name():
  """Returns the name of a GPU device if available or the empty string.

  The first local device (in `device_lib.list_local_devices()` order) whose
  type is "GPU" or "SYCL" is reported.
  """
  accelerator_types = ("GPU", "SYCL")
  for device in device_lib.list_local_devices():
    if device.device_type in accelerator_types:
      return compat.as_str(device.name)
  return ""
105
106
def assert_ops_in_graph(expected_ops, graph):
  """Checks that every expected op is present in `graph` with the right type.

  Args:
    expected_ops: `dict<string, string>` mapping op name to op type.
    graph: Graph to check; only the nodes of `graph.as_graph_def()` are read.

  Returns:
    `dict<string, node>` mapping each expected node name to its node.

  Raises:
    ValueError: If an expected op has a different type than expected, or if
      any expected op is missing from the graph.
  """
  graph_def = graph.as_graph_def()
  found = {}
  for node in graph_def.node:
    if node.name not in expected_ops:
      continue
    expected_type = expected_ops[node.name]
    if expected_type != node.op:
      raise ValueError("Expected op for node %s is different. %s vs %s" %
                       (node.name, expected_type, node.op))
    found[node.name] = node
  # `found` only ever holds expected names, so equality means none is missing.
  if set(expected_ops.keys()) != set(found.keys()):
    raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                     (expected_ops.keys(), found.keys()))
  return found
132
133
@tf_export("test.assert_equal_graph_def", v1=[])
def assert_equal_graph_def_v2(expected, actual):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs. Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent. Randomized
  attribute values that may appear in V2 checkpoints are always ignored.

  Args:
    expected: The `GraphDef` we expected.
    actual: The `GraphDef` we have.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  # The shared helper takes (actual, expected); pass by keyword so the
  # argument swap relative to this public signature is explicit.
  assert_equal_graph_def(actual=actual, expected=expected, checkpoint_v2=True)
152
153
@tf_export(v1=["test.assert_equal_graph_def"])
def assert_equal_graph_def_v1(actual, expected, checkpoint_v2=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  Compares two `GraphDef` protos for equality, ignoring versions and ordering of
  nodes, attrs, and control inputs. Node names are used to match up nodes
  between the graphs, so the naming of nodes must be consistent.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: If True, randomized attribute values that appear in V2
      checkpoints are ignored.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  assert_equal_graph_def(actual, expected, checkpoint_v2=checkpoint_v2)
173
174
def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
  """Asserts that two `GraphDef`s are (mostly) the same.

  The comparison itself is delegated to
  `pywrap_tensorflow.EqualGraphDefWrapper` on the serialized protos.

  Args:
    actual: The `GraphDef` we have.
    expected: The `GraphDef` we expected.
    checkpoint_v2: If True, attributes holding randomized V2-checkpoint
      sharded-save names are removed before comparing. This is done on private
      copies; the caller's protos are not modified.

  Raises:
    AssertionError: If the `GraphDef`s do not match.
    TypeError: If either argument is not a `GraphDef`.
  """
  if not isinstance(actual, graph_pb2.GraphDef):
    raise TypeError(
        "Expected tf.GraphDef for actual, got %s" % type(actual).__name__)
  if not isinstance(expected, graph_pb2.GraphDef):
    raise TypeError(
        "Expected tf.GraphDef for expected, got %s" % type(expected).__name__)

  if checkpoint_v2:
    # Stripping edits node attrs in place; work on copies so an assertion
    # helper never mutates its inputs.
    actual_copy = graph_pb2.GraphDef()
    actual_copy.CopyFrom(actual)
    expected_copy = graph_pb2.GraphDef()
    expected_copy.CopyFrom(expected)
    _strip_checkpoint_v2_randomized(actual_copy)
    _strip_checkpoint_v2_randomized(expected_copy)
    actual, expected = actual_copy, expected_copy

  diff = pywrap_tensorflow.EqualGraphDefWrapper(actual.SerializeToString(),
                                                expected.SerializeToString())
  if diff:
    raise AssertionError(compat.as_str(diff))
191
192
def assert_meta_graph_protos_equal(tester, a, b):
  """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

  Each collection is compared entry by entry. When
  `ops.get_collection_proto_type` knows the collection's proto type, its
  serialized entries are parsed and compared as protos. The `graph_def`s are
  compared with `assert_equal_graph_def`, ignoring randomized V2-checkpoint
  attributes, and their `versions` are compared explicitly. All remaining
  fields are then compared with `tester.assertProtoEquals`.

  Note: `a` and `b` are modified in place; their `collection_def` and
  `graph_def` fields are cleared.

  Args:
    tester: Test case providing `assertEqual` and `assertProtoEquals`.
    a: The first `MetaGraphDef`.
    b: The second `MetaGraphDef`.
  """
  # Carefully check the collection_defs
  tester.assertEqual(set(a.collection_def), set(b.collection_def))
  collection_keys = a.collection_def.keys()
  for k in collection_keys:
    a_value = a.collection_def[k]
    b_value = b.collection_def[k]
    proto_type = ops.get_collection_proto_type(k)
    if proto_type:
      a_proto = proto_type()
      b_proto = proto_type()
      # Number of entries in the collections is the same
      tester.assertEqual(
          len(a_value.bytes_list.value), len(b_value.bytes_list.value))
      for (a_value_item, b_value_item) in zip(a_value.bytes_list.value,
                                              b_value.bytes_list.value):
        a_proto.ParseFromString(a_value_item)
        b_proto.ParseFromString(b_value_item)
        tester.assertProtoEquals(a_proto, b_proto)
    else:
      # `assertEquals` is a deprecated alias that Python 3.12 removed.
      tester.assertEqual(a_value, b_value)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("collection_def")
  b.ClearField("collection_def")

  # Check the graph_defs.
  assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
  # Check graph_def versions (ignored by assert_equal_graph_def).
  tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
  # Compared the fields directly, remove their raw values from the
  # proto comparison below.
  a.ClearField("graph_def")
  b.ClearField("graph_def")

  tester.assertProtoEquals(a, b)
230
231
232# Matches attributes named via _SHARDED_SUFFIX in
233# tensorflow/python/training/saver.py
234_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"
235
236
237def _strip_checkpoint_v2_randomized(graph_def):
238  for node in graph_def.node:
239    delete_keys = []
240    for attr_key in node.attr:
241      attr_tensor_value = node.attr[attr_key].tensor
242      if attr_tensor_value and len(attr_tensor_value.string_val) == 1:
243        attr_tensor_string_value = attr_tensor_value.string_val[0]
244        if (attr_tensor_string_value and
245            re.match(_SHARDED_SAVE_OP_PATTERN, str(attr_tensor_string_value))):
246          delete_keys.append(attr_key)
247    for attr_key in delete_keys:
248      del node.attr[attr_key]
249
250
def IsGoogleCudaEnabled():
  """Returns `pywrap_tensorflow.IsGoogleCudaEnabled()`.

  Presumably True when this TensorFlow build has CUDA support (unverified
  here; the answer comes from the native library).
  """
  cuda_enabled = pywrap_tensorflow.IsGoogleCudaEnabled()
  return cuda_enabled
253
254
def CudaSupportsHalfMatMulAndConv():
  """Returns `pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()`.

  As the name suggests, presumably reports whether half-precision MatMul/Conv
  is supported by the CUDA setup; the answer comes from the native library.
  """
  supported = pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()
  return supported
257
258
def IsMklEnabled():
  """Returns `pywrap_tensorflow.IsMklEnabled()`.

  Presumably True when this TensorFlow build uses MKL; the answer comes from
  the native library.
  """
  mkl_enabled = pywrap_tensorflow.IsMklEnabled()
  return mkl_enabled
261
262
def InstallStackTraceHandler():
  """Calls `pywrap_tensorflow.InstallStacktraceHandler()` and returns None.

  What the installed handler does is defined by the native library.
  """
  install_handler = pywrap_tensorflow.InstallStacktraceHandler
  install_handler()
265
266
def NHWCToNCHW(input_tensor):
  """Converts the input from the NHWC format to NCHW.

  Args:
    input_tensor: a 4- or 5-D `Tensor`, or a sequence representing a shape.

  Returns:
    The transposed `Tensor`, or a list holding the permuted shape.
  """
  # Keyed by rank: the trailing channel axis moves to index 1.
  axis_order_by_rank = {4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
  if not isinstance(input_tensor, ops.Tensor):
    perm = axis_order_by_rank[len(input_tensor)]
    return [input_tensor[axis] for axis in perm]
  perm = axis_order_by_rank[input_tensor.shape.ndims]
  return array_ops.transpose(input_tensor, perm)
284
285
def NHWCToNCHW_VECT_C(input_shape_or_tensor):
  """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

  Note: Does not include quantization or type conversion steps, which should
  be applied afterwards.

  Args:
    input_shape_or_tensor: a 4- or 5-D tensor, or a sequence (list, tuple,
      array) representing a shape. A shape sequence is not modified.

  Returns:
    tensor or shape list transformed into NCHW_VECT_C

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not evenly
        divisible by 4.
  """
  permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  # Always work on a fresh list: the channel split below edits it in place,
  # so the caller's shape sequence must not be used directly (that mutated
  # caller lists and failed outright for tuples).
  temp_shape = (
      input_shape_or_tensor.shape.as_list()
      if is_tensor else list(input_shape_or_tensor))
  if temp_shape[-1] % 4 != 0:
    raise ValueError(
        "Last dimension of input must be evenly divisible by 4 to convert to "
        "NCHW_VECT_C.")
  # Split the channel dimension C into (C // 4, 4).
  temp_shape[-1] //= 4
  temp_shape.append(4)
  permutation = permutations[len(temp_shape)]
  if is_tensor:
    t = array_ops.reshape(input_shape_or_tensor, temp_shape)
    return array_ops.transpose(t, permutation)
  else:
    return [temp_shape[a] for a in permutation]
319
320
def NCHW_VECT_CToNHWC(input_shape_or_tensor):
  """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

  Note: Does not include de-quantization or type conversion steps, which should
  be applied beforehand.

  Args:
    input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

  Returns:
    tensor or shape array transformed into NHWC

  Raises:
    ValueError: if last dimension of `input_shape_or_tensor` is not 4.
  """
  perms_by_rank = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
  is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
  if is_tensor:
    input_shape = input_shape_or_tensor.shape.as_list()
  else:
    input_shape = input_shape_or_tensor
  vect_size = input_shape[-1]
  if vect_size != 4:
    raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
  perm = perms_by_rank[len(input_shape)]
  # Drop the trailing vector axis from the permutation, then fold its size
  # back into the (now last) channel dimension.
  nhwc_shape = [input_shape[axis] for axis in perm[:-1]]
  nhwc_shape[-1] = nhwc_shape[-1] * vect_size
  if not is_tensor:
    return nhwc_shape
  transposed = array_ops.transpose(input_shape_or_tensor, perm)
  return array_ops.reshape(transposed, nhwc_shape)
351
352
def NCHWToNHWC(input_tensor):
  """Converts the input from the NCHW format to NHWC.

  Args:
    input_tensor: a 4- or 5-D `Tensor`, or a sequence representing a shape.

  Returns:
    The transposed `Tensor`, or a list holding the permuted shape.
  """
  # Keyed by rank: the channel axis at index 1 moves to the end.
  axis_order_by_rank = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
  if not isinstance(input_tensor, ops.Tensor):
    perm = axis_order_by_rank[len(input_tensor)]
    return [input_tensor[axis] for axis in perm]
  perm = axis_order_by_rank[input_tensor.shape.ndims]
  return array_ops.transpose(input_tensor, perm)
370
371
def skip_if(condition):
  """Skips the decorated function if condition is or evaluates to True.

  A skipped call does not run the function and returns None. It is not
  reported to the test runner as a skip.

  Args:
    condition: Either an expression that can be used in "if not condition"
      statement, or a callable whose result should be a boolean. A callable is
      evaluated on every call of the decorated function.

  Returns:
    A decorator. The wrapped function keeps the decorated function's name,
    docstring and attributes (e.g. `_disable_control_flow_v2`).
  """

  def real_skip_if(fn):

    # functools.wraps preserves __name__/__doc__ and copies fn.__dict__, so
    # marker attributes set by other decorators survive this wrapper.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      skip = condition() if callable(condition) else condition
      if not skip:
        return fn(*args, **kwargs)

    return wrapper

  return real_skip_if
396
397
def enable_c_shapes(fn):
  """Returns `fn` unchanged; kept so existing decorated tests keep working.

  TODO(b/74620627): Remove this.
  """
  unchanged_fn = fn
  return unchanged_fn
401
402
def with_c_shapes(cls):
  """Returns `cls` unchanged; kept so existing decorated classes keep working.

  TODO(b/74620627): Remove this.
  """
  unchanged_cls = cls
  return unchanged_cls
406
407
def enable_control_flow_v2(fn):
  """Decorator for enabling CondV2 and WhileV2 on a test.

  Note this enables using CondV2 and WhileV2 after running the test class's
  setup/teardown methods.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  Args:
    fn: the function to be wrapped

  Returns:
    A wrapper that sets `control_flow_util.ENABLE_CONTROL_FLOW_V2` to True for
    the duration of the call and restores the previous value afterwards.
  """

  def wrapper(*args, **kwargs):
    saved_flag = control_flow_util.ENABLE_CONTROL_FLOW_V2
    control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
    try:
      outcome = fn(*args, **kwargs)
    finally:
      # Restore the flag even if the test fails or raises.
      control_flow_util.ENABLE_CONTROL_FLOW_V2 = saved_flag
    return outcome

  return wrapper
433
434
def with_control_flow_v2(cls):
  """Adds methods that call original methods with WhileV2 and CondV2 enabled.

  Note this enables CondV2 and WhileV2 in new methods after running the test
  class's setup method.

  In addition to this, callers must import the while_v2 module in order to set
  the _while_v2 module in control_flow_ops.

  If a test function has _disable_control_flow_v2 attr set to True (using the
  @disable_control_flow_v2 decorator), the v2 function is not generated for it.
  If control flow v2 is already enabled globally, `cls` is returned untouched.

  Example:

  @test_util.with_control_flow_v2
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    @test_util.disable_control_flow_v2("b/xyzabc")
    def testDisabledForV2(self):
      ...

  Generated class:
  class ControlFlowTest(test.TestCase):

    def testEnabledForV2(self):
      ...

    def testEnabledForV2WithControlFlowV2(self):
      # Enable V2 flags.
      testEnabledForV2(self)
      # Restore V2 flags.

    def testDisabledForV2(self):
      ...

  Args:
    cls: class to decorate

  Returns:
    cls with new test methods added
  """
  if control_flow_util.ENABLE_CONTROL_FLOW_V2:
    return cls

  test_prefix = unittest.TestLoader.testMethodPrefix
  # Snapshot the class namespace: setattr below adds entries to it.
  for attr_name, attr_value in list(cls.__dict__.items()):
    wants_v2_variant = (
        callable(attr_value) and attr_name.startswith(test_prefix) and
        not getattr(attr_value, "_disable_control_flow_v2", False))
    if wants_v2_variant:
      setattr(cls, attr_name + "WithControlFlowV2",
              enable_control_flow_v2(attr_value))
  return cls
488
489
def disable_control_flow_v2(unused_msg):
  """Decorator for a function in a with_control_flow_v2 enabled test class.

  Blocks the function from being run with v2 control flow ops by marking it;
  `with_control_flow_v2` skips functions carrying the marker.

  Args:
    unused_msg: Reason for disabling (documentation only).

  Returns:
    A decorator that sets `_disable_control_flow_v2 = True` on the function
    and returns that same function object.
  """
  del unused_msg  # Only documents the reason; not used at runtime.

  def mark_disabled(func):
    func._disable_control_flow_v2 = True  # pylint: disable=protected-access
    return func

  return mark_disabled
507
508
def assert_no_new_pyobjects_executing_eagerly(f):
  """Decorator for asserting that no new Python objects persist after a test.

  Runs the test multiple times executing eagerly, first as a warmup and then to
  let objects accumulate. The warmup helps ignore caches which do not grow as
  the test is run repeatedly.

  Useful for checking that there are no missing Py_DECREFs in the C exercised by
  a bit of Python.

  Args:
    f: The test method to decorate; it is called as `f(self, **kwargs)`.

  Returns:
    The decorated test method. It returns None, not `f`'s result.
  """

  def decorator(self, **kwargs):
    """Warms up, gets an object count, runs the test, checks for new objects."""
    with context.eager_mode():
      gc.disable()
      # Re-enable the collector even if the test or the leak check fails, so
      # later tests in the same process do not run with GC disabled.
      try:
        # Run the test 2 times as warmup, in an attempt to fill up caches,
        # which should not grow as the test is run repeatedly below.
        #
        # TODO(b/117156879): Running warmup twice is black magic; we have seen
        # tests that fail with 1 warmup run, and pass with 2, on various
        # versions of python2.7.x.
        for _ in range(2):
          f(self, **kwargs)
        gc.collect()
        previous_count = len(gc.get_objects())
        # Always bound, so the check below cannot raise NameError if a default
        # graph only appears while the test runs.
        collection_sizes_before = {}
        if ops.has_default_graph():
          collection_sizes_before = {
              collection: len(ops.get_collection(collection))
              for collection in ops.get_default_graph().collections
          }
        for _ in range(3):
          f(self, **kwargs)
        # Note that gc.get_objects misses anything that isn't subject to
        # garbage collection (C types). Collections are a common source of
        # leaks, so we test for collection sizes explicitly.
        if ops.has_default_graph():
          for collection_key in ops.get_default_graph().collections:
            collection = ops.get_collection(collection_key)
            size_before = collection_sizes_before.get(collection_key, 0)
            if len(collection) > size_before:
              raise AssertionError(
                  ("Collection %s increased in size from "
                   "%d to %d (current items %s).") %
                  (collection_key, size_before, len(collection), collection))
            # Make sure our collection checks don't show up as leaked memory
            # by removing references to temporary variables.
            del collection
            del collection_key
            del size_before
        del collection_sizes_before
        gc.collect()
        # There should be no new Python objects hanging around.
        new_count = len(gc.get_objects())
        # In some cases (specifically on MacOS), new_count is somehow
        # smaller than previous_count.
        # Using plain assert because not all classes using this decorator
        # have assertLessEqual
        assert new_count <= previous_count, (
            "new_count(%d) is not less than or equal to previous_count(%d)" %
            (new_count, previous_count))
      finally:
        gc.enable()

  return decorator
572
573
def assert_no_new_tensors(f):
  """Decorator for asserting that no new Tensors persist after a test.

  Mainly useful for checking that code using the Python C API has correctly
  manipulated reference counts.

  Records the ids of Tensor-like objects (Tensors, Variables, Dimensions and
  TensorShapes) alive before the test. Then it runs the test, clears the
  caches it knows about, runs the garbage collector, and fails if any
  Tensor-like object not seen before is still alive. This catches objects
  something still references (e.g. from missing Py_DECREFs) and
  uncollectable cycles (i.e. Python reference cycles where one of the objects
  has __del__ defined).

  Args:
    f: The test case to run.

  Returns:
    The decorated test case; it returns whatever `f` returns.
  """

  def decorator(self, **kwargs):
    """Finds existing Tensors, runs the test, checks for new Tensors."""
    tracked_types = (ops.Tensor, variables.Variable, tensor_shape.Dimension,
                     tensor_shape.TensorShape)

    def is_tf_object(candidate):
      try:
        return isinstance(candidate, tracked_types)
      except ReferenceError:
        # A dead weak proxy: the object is already gone, so ignore it.
        return False

    ids_before = {id(obj) for obj in gc.get_objects() if is_tf_object(obj)}
    was_eager = context.executing_eagerly()
    # Run the test in a new graph so that collections get cleared when it's
    # done, but inherit the graph key so optimizers behave.
    inherited_graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
    with ops.Graph().as_default():
      ops.get_default_graph()._graph_key = inherited_graph_key  # pylint: disable=protected-access
      if was_eager:
        with context.eager_mode():
          result = f(self, **kwargs)
      else:
        result = f(self, **kwargs)
    # Make an effort to clear caches, which would otherwise look like leaked
    # Tensors.
    context.context()._clear_caches()  # pylint: disable=protected-access
    gc.collect()
    leaked = [
        obj for obj in gc.get_objects()
        if is_tf_object(obj) and id(obj) not in ids_before
    ]
    if leaked:
      raise AssertionError(
          "%d Tensors not deallocated after test: %s" % (len(leaked),
                                                         str(leaked)))
    return result

  return decorator
634
635
def _find_reference_cycle(objects, idx):
  """Logs one reference cycle reachable from `objects[idx]`, if one is found.

  Builds a graph of the object's transitive referrers with
  `gc.get_referrers`, searches it for a cycle, and logs a description of each
  element of the first cycle found via `logging.error`. This is a diagnostic
  helper; `assert_no_garbage_created` calls it on entries of `gc.garbage`.

  Args:
    objects: A sequence of objects (e.g. `gc.garbage`).
    idx: Index into `objects` of the object to start from.

  Returns:
    True if a cycle was found and logged, False otherwise.
  """

  def get_ignore_reason(obj, blacklist):
    """Tests whether an object should be omitted from the dependency graph.

    Args:
      obj: The candidate object.
      blacklist: Tuple of objects to ignore. `build_ref_graph` appends one
        entry per recursion level, so its length also serves as a depth limit.

    Returns:
      A short reason string if `obj` should be ignored, otherwise None.
    """
    if len(blacklist) > 100:
      return "<depth limit>"
    # Frames executing this file belong to the diagnostics themselves.
    if tf_inspect.isframe(obj):
      if "test_util.py" in tf_inspect.getframeinfo(obj)[0]:
        return "<test code>"
    # Identity (not equality) checks, so unrelated equal objects are kept.
    for b in blacklist:
      if b is obj:
        return "<test code>"
    if obj is blacklist:
      return "<test code>"
    return None

  # Note: this function is meant to help with diagnostics. Its output is purely
  # a human-readable representation, so you may freely modify it to suit your
  # needs.
  def describe(obj, blacklist, leaves_only=False):
    """Returns a custom human-readable summary of obj.

    Args:
      obj: the value to describe.
      blacklist: same as blacklist in get_ignore_reason.
      leaves_only: boolean flag used when calling describe recursively. Useful
        for summarizing collections.

    Returns:
      A string. Lists and tuples include a leaves-only summary of each
      element; other containers are summarized by id and size only.
    """
    if get_ignore_reason(obj, blacklist):
      return "{}{}".format(get_ignore_reason(obj, blacklist), type(obj))
    if tf_inspect.isframe(obj):
      return "frame: {}".format(tf_inspect.getframeinfo(obj))
    elif tf_inspect.ismodule(obj):
      return "module: {}".format(obj.__name__)
    else:
      if leaves_only:
        return "{}, {}".format(type(obj), id(obj))
      elif isinstance(obj, list):
        return "list({}): {}".format(
            id(obj), [describe(e, blacklist, leaves_only=True) for e in obj])
      elif isinstance(obj, tuple):
        return "tuple({}): {}".format(
            id(obj), [describe(e, blacklist, leaves_only=True) for e in obj])
      elif isinstance(obj, dict):
        return "dict({}): {} keys".format(id(obj), len(obj.keys()))
      elif tf_inspect.isfunction(obj):
        return "function({}) {}; globals ID: {}".format(
            id(obj), obj.__name__, id(obj.__globals__))
      else:
        return "{}, {}".format(type(obj), id(obj))

  def build_ref_graph(obj, graph, reprs, blacklist):
    """Builds a reference graph as <referrer> -> <list of referents>.

    Args:
      obj: The object to start from. The graph will be built by recursively
        adding its referrers.
      graph: Dict holding the graph to be built. To avoid creating extra
        references, the graph holds object IDs rather than actual objects.
      reprs: Auxiliary structure that maps object IDs to their human-readable
        description.
      blacklist: List of objects to ignore.
    """
    referrers = gc.get_referrers(obj)
    # The `referrers` list itself refers to every referrer; ignore it in the
    # recursive calls below. This also grows the tuple by one per level.
    blacklist = blacklist + (referrers,)

    obj_id = id(obj)
    for r in referrers:
      if get_ignore_reason(r, blacklist) is None:
        r_id = id(r)
        if r_id not in graph:
          graph[r_id] = []
        # Only recurse the first time this edge is seen.
        if obj_id not in graph[r_id]:
          graph[r_id].append(obj_id)
          build_ref_graph(r, graph, reprs, blacklist)
          reprs[r_id] = describe(r, blacklist)

  def find_cycle(el, graph, reprs, path):
    """Finds and prints a single cycle in the dependency graph.

    Performs a depth-first walk from `el` along referrer -> referent edges.
    `path` is the tuple of IDs visited so far on the current branch.

    Returns:
      True once a cycle has been logged; False (or None when `el` has no
      outgoing edges) otherwise.
    """
    if el not in graph:
      return
    for r in graph[el]:
      if r in path:
        logging.error("Reference cycle sample:")
        for p in path + (r,):
          logging.error(reprs.get(p, "unknown object " + str(p)))
        return True
      else:
        if find_cycle(r, graph, reprs, path + (r,)):
          return True
    return False

  obj = objects[idx]
  graph = {}  # referrer ID -> list of referent object IDs
  reprs = {}  # object ID -> description
  # Seed the blacklist with this function's own bookkeeping objects so they
  # are not reported as part of a cycle.
  build_ref_graph(obj, graph, reprs, (objects, graph, reprs, get_ignore_reason,
                                      describe, build_ref_graph, find_cycle))
  for k in graph:
    if find_cycle(k, graph, reprs, ()):
      return True
  return False
737
738
def assert_no_garbage_created(f):
  """Test method decorator to assert that no garbage has been created.

  Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
  cannot be un-set (i.e. will disable garbage collection for any other unit
  tests in the same file/shard). The previous debug flags are restored and the
  collector is re-enabled even when the test or the final assertion fails.

  Args:
    f: The function to decorate.

  Returns:
    The decorated function; it returns whatever `f` returns.
  """

  def _safe_object_str(obj):
    # Uses only the class name and id, so it does not rely on the object's own
    # __str__/__repr__ (which may fail for objects caught in cycles).
    return "<%s %d>" % (obj.__class__.__name__, id(obj))

  def decorator(self, **kwargs):
    """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
    # Force-load `distribution_strategy_context` to prevent GC at
    # test time when using eager. Remove once b/117329403 is resolved.
    tape.distribution_strategy_context.get_strategy()

    gc.disable()
    previous_debug_flags = gc.get_debug()
    gc.set_debug(gc.DEBUG_SAVEALL)
    try:
      gc.collect()
      previous_garbage = len(gc.garbage)
      result = f(self, **kwargs)
      gc.collect()
      new_garbage = len(gc.garbage)
      if new_garbage > previous_garbage:
        logging.error(
            "The decorated test created work for Python's garbage collector, "
            "likely due to a reference cycle. New objects in cycle(s):")
        for i, obj in enumerate(gc.garbage[previous_garbage:]):
          try:
            logging.error("Object %d of %d", i,
                          len(gc.garbage) - previous_garbage)
            logging.error("  Object type: %s", _safe_object_str(obj))
            logging.error(
                "  Referrer types: %s", ", ".join(
                    [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
            logging.error(
                "  Referent types: %s", ", ".join(
                    [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
            logging.error("  Object attribute names: %s", dir(obj))
            logging.error("  Object __str__:")
            logging.error(obj)
            logging.error("  Object __repr__:")
            logging.error(repr(obj))
          except Exception:  # pylint: disable=broad-except
            logging.error("(Exception while printing object)")

        # When garbage is created, this call can help identify reference
        # cycles, which are typically the cause of such garbage.
        for i in range(previous_garbage, new_garbage):
          if _find_reference_cycle(gc.garbage, i):
            break

      # This will fail if any garbage has been created, typically because of a
      # reference cycle.
      self.assertEqual(previous_garbage, new_garbage)
    finally:
      # Restore collector state even if the test or the assertion failed, so
      # later tests are not run with collection disabled.
      # TODO(allenl): Figure out why this debug flag reset doesn't work. It
      # would be nice to be able to decorate arbitrary tests in a large test
      # suite and not hold on to every object in other tests.
      gc.set_debug(previous_debug_flags)
      gc.enable()
    return result

  return decorator
812
813
814def _combine_named_parameters(**kwargs):
815  """Generate combinations based on its keyword arguments.
816
817  Two sets of returned combinations can be concatenated using +.  Their product
818  can be computed using `times()`.
819
820  Args:
821    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
822      `option=the_only_possibility`.
823
824  Returns:
825    a list of dictionaries for each combination. Keys in the dictionaries are
826    the keyword argument names.  Each key has one value - one of the
827    corresponding keyword argument values.
828  """
829  if not kwargs:
830    return [OrderedDict()]
831
832  sort_by_key = lambda k: k[0][0]
833  kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key))
834  first = list(kwargs.items())[0]
835
836  rest = dict(list(kwargs.items())[1:])
837  rest_combined = _combine_named_parameters(**rest)
838
839  key = first[0]
840  values = first[1]
841  if not isinstance(values, list):
842    values = [values]
843
844  combinations = [
845      OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key))
846      for v in values
847      for combined in rest_combined
848  ]
849  return combinations
850
851
def generate_combinations_with_testcase_name(**kwargs):
  """Generates keyword-argument combinations, each with a "testcase_name".

  Calls `_combine_named_parameters` and appends a "testcase_name" entry to each
  resulting `OrderedDict`; named parameterized tests require that key. The
  name is "_test" followed by "_<key>_<value>" for every entry, keeping only
  the alphanumeric characters of each key and of `str(value)`.

  Args:
    **kwargs: keyword arguments of form `option=[possibilities, ...]` or
      `option=the_only_possibility`.

  Returns:
    a list of `OrderedDict`s, one for each combination. Keys are the keyword
    argument names, each mapped to one of its values, followed by
    "testcase_name".
  """

  def _alnum_only(text):
    return "".join(filter(str.isalnum, text))

  named = []
  for combination in _combine_named_parameters(**kwargs):
    assert isinstance(combination, OrderedDict)
    suffix = "".join(
        "_{}_{}".format(_alnum_only(key), _alnum_only(str(value)))
        for key, value in combination.items())
    entries = list(combination.items())
    entries.append(("testcase_name", "_test{}".format(suffix)))
    named.append(OrderedDict(entries))

  return named
883
884
def run_all_in_graph_and_eager_modes(cls):
  """Execute all test methods in the given class with and without eager.

  Every callable attribute defined directly on `cls` whose name starts with
  the unittest test-method prefix is wrapped with
  `run_in_graph_and_eager_modes`. Methods named "testSkipEager*",
  "test_skip_eager*" and `test_session` are left alone.

  Args:
    cls: the test class to modify in place.

  Returns:
    `cls` itself.
  """
  excluded_prefixes = ("testSkipEager", "test_skip_eager")
  test_prefix = unittest.TestLoader.testMethodPrefix
  # Snapshot the namespace because setattr below mutates it.
  for attr_name, attr in list(cls.__dict__.items()):
    if not callable(attr):
      continue
    if not attr_name.startswith(test_prefix):
      continue
    if attr_name.startswith(excluded_prefixes) or attr_name == "test_session":
      continue
    setattr(cls, attr_name, run_in_graph_and_eager_modes(attr))
  return cls
895
896
def run_in_graph_and_eager_modes(func=None,
                                 config=None,
                                 use_gpu=True,
                                 reset_test=True,
                                 assert_no_eager_garbage=False):
  """Execute the decorated test with and without enabling eager execution.

  This function returns a decorator intended to be applied to test methods in
  a `tf.test.TestCase` class. Doing so will cause the contents of the test
  method to be executed twice - once normally, and once with eager execution
  enabled. This allows unittests to confirm the equivalence between eager
  and graph execution (see `tf.enable_eager_execution`).

  For example, consider the following unittest:

  ```python
  class MyTests(tf.test.TestCase):

    @run_in_graph_and_eager_modes
    def test_foo(self):
      x = tf.constant([1, 2])
      y = tf.constant([3, 4])
      z = tf.add(x, y)
      self.assertAllEqual([4, 6], self.evaluate(z))

  if __name__ == "__main__":
    tf.test.main()
  ```

  This test validates that `tf.add()` has the same behavior when computed with
  eager execution enabled as it does when constructing a TensorFlow graph and
  executing the `z` tensor in a session.

  Note that a `unittest.SkipTest` raised during the graph-mode run is
  swallowed, so the eager-mode run still executes afterwards.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.


  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.
    config: An optional config_pb2.ConfigProto to use to configure the session
      when executing graphs.
    use_gpu: If True, attempt to run as many operations as possible on GPU.
      If False, the eager run is placed under `ops.device("/device:CPU:0")`.
    reset_test: If True, tearDown and SetUp the test case between the two
      executions of the test (once with and once without eager execution).
    assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
      collector and asserts that no extra garbage has been created when running
      the test with eager execution enabled. This will fail if there are
      reference cycles (e.g. a = []; a.append(a)). Off by default because some
      tests may create garbage for legitimate reasons (e.g. they define a class
      which inherits from `object`), and because DEBUG_SAVEALL is sticky in some
      Python interpreters (meaning that tests which rely on objects being
      collected elsewhere in the unit test file will not work). Additionally,
      checks that nothing still has a reference to Tensors that the test
      allocated.

  Returns:
    Returns a decorator that will run the decorated test method twice:
    once by constructing and executing a graph in a session and once with
    eager execution enabled.

  Raises:
    ValueError: if the decorator is applied to a class instead of a test
      method (use `run_all_in_graph_and_eager_modes` for classes).
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError(
          "`run_in_graph_and_eager_modes` only supports test methods. "
          "Did you mean to use `run_all_in_graph_and_eager_modes`?")

    def decorated(self, *args, **kwargs):
      # First pass: graph mode inside a test session. A SkipTest raised here
      # is swallowed so that the eager pass below still runs.
      try:
        with context.graph_mode():
          with self.test_session(use_gpu=use_gpu, config=config):
            f(self, *args, **kwargs)
      except unittest.case.SkipTest:
        pass

      # Second-pass body. `args` comes from the enclosing closure; only the
      # keyword arguments are passed explicitly.
      def run_eagerly(self, **kwargs):
        if not use_gpu:
          # Pin the eager run to the CPU when the GPU was not requested.
          with ops.device("/device:CPU:0"):
            f(self, *args, **kwargs)
        else:
          f(self, *args, **kwargs)

      if assert_no_eager_garbage:
        ops.reset_default_graph()
        # Wrap the eager pass in the garbage/tensor-leak checkers.
        run_eagerly = assert_no_new_tensors(
            assert_no_garbage_created(run_eagerly))

      if reset_test:
        # This decorator runs the wrapped test twice.
        # Reset the test environment between runs.
        self.tearDown()
        self._tempdir = None
      # Create a new graph for the eagerly executed version of this test for
      # better isolation.
      graph_for_eager_test = ops.Graph()
      with graph_for_eager_test.as_default(), context.eager_mode():
        if reset_test:
          self.setUp()
        run_eagerly(self, **kwargs)
      # Tear down the temporary graph. NOTE(review): this is not in a
      # `finally`, so it is skipped when the eager pass raises.
      ops.dismantle_graph(graph_for_eager_test)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
1007
1008
def py_func_if_in_function(f):
  """Wraps `f` so that it is staged as a `py_func` while building a function.

  When the default graph is not building a function, `f` is called directly
  and its result is returned. Otherwise the call is staged with
  `script_ops.py_func`. The positional arguments that are `ops.Tensor` or
  `variables.Variable` instances become the py_func inputs, and all other
  arguments are captured as-is. In that case the py_func result is returned
  (declared with no outputs).

  Args:
    f: the callable to wrap.

  Returns:
    The wrapped callable.
  """

  def decorated(*args, **kwds):
    # pylint: disable=protected-access
    if not ops.get_default_graph()._building_function:
      return f(*args, **kwds)
    # pylint: enable=protected-access

    tensor_positions = [
        pos for pos, value in enumerate(args)
        if isinstance(value, (ops.Tensor, variables.Variable))
    ]

    def call_with_inputs(*py_func_inputs):
      # Put the values py_func hands us back into their original positions.
      call_args = list(args)
      for pos, value in zip(tensor_positions, py_func_inputs):
        call_args[pos] = value
      return f(*call_args, **kwds)

    staged_inputs = [args[pos] for pos in tensor_positions]
    return script_ops.py_func(call_with_inputs, staged_inputs, [])

  return tf_decorator.make_decorator(f, decorated)
1031
1032
def also_run_as_tf_function(f):
  """Runs the decorated test twice--once as is, once inside a tf.function.

  Both runs happen under eager mode. The first calls the test directly; the
  second wraps it in `def_function.function` (with autograph disabled) and
  calls that. This exercises the two execution modes supported in tf 2.0.
  Assertions are expected to run inside `tf.py_func`s when traced (see
  `py_func_if_in_function`), and tf.function ensures they run in the proper
  order and with the proper side effects.

  Currently variable creation is not supported in tests annotated with this
  decorator, since it's tricky to ensure the variable doesn't get repeatedly
  created when retracing the tf.function.

  Args:
    f: the test method to be decorated

  Returns:
    The decorated test method, which will run both in eager and inside a
    tf.function.
  """

  def decorated(*args, **kwds):

    def run_test():
      f(*args, **kwds)

    with context.eager_mode():
      # Pass one: plain eager execution.
      run_test()
      # Pass two: the same body, traced and executed as a tf.function.
      # TODO(b/121143941): Remove the autograph override.
      as_function = def_function.function(run_test, autograph=False)
      as_function()

  return decorated
1064
1065
def deprecated_graph_mode_only(func=None):
  """Execute the decorated test in graph mode.

  Intended for tests that are not compatible with eager mode. When TF2
  behavior is enabled, the test body runs inside `context.graph_mode()`, so
  API calls construct graphs instead of executing eagerly. Otherwise the test
  runs unchanged. When applied to a class, `setUp` and every test method
  defined directly on that class are wrapped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function (or test class) to be annotated. If `func` is None, this
      method returns a decorator that can be applied to a function. If `func`
      is not None this returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will run the decorated test method in graph mode.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      # setUp is wrapped too, so fixtures are built under the same mode as
      # the tests.
      class_setup = f.__dict__.get("setUp")
      if class_setup is not None:
        setattr(f, "setUp", decorator(class_setup))

      test_prefix = unittest.TestLoader.testMethodPrefix
      for attr_name, attr in list(f.__dict__.items()):
        if callable(attr) and attr_name.startswith(test_prefix):
          setattr(f, attr_name, decorator(attr))

      return f

    def decorated(self, *args, **kwargs):
      if not tf2.enabled():
        return f(self, *args, **kwargs)
      with context.graph_mode():
        return f(self, *args, **kwargs)

    return decorated

  return decorator if func is None else decorator(func)


# Alias kept for tests written against the older name.
run_deprecated_v1 = deprecated_graph_mode_only
1116
1117
def run_v1_only(reason, func=None):
  """Execute the decorated test only if running in v1 mode.

  Intended for tests of v1-only functionality. When TF2 behavior is enabled,
  the test is skipped with `reason` as the skip message. When applied to a
  class, `setUp` and every test method defined directly on that class are
  wrapped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    reason: string giving a reason for limiting the test to v1 only.
    func: function (or test class) to be annotated. If `func` is None, this
      method returns a decorator that can be applied to a function. If `func`
      is not None this returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.
  """

  def decorator(f):
    if not tf_inspect.isclass(f):

      def decorated(self, *args, **kwargs):
        if tf2.enabled():
          self.skipTest(reason)
        return f(self, *args, **kwargs)

      return decorated

    # Wrapping setUp makes the skip happen before any v1-only fixture work.
    class_setup = f.__dict__.get("setUp")
    if class_setup is not None:
      setattr(f, "setUp", decorator(class_setup))

    test_prefix = unittest.TestLoader.testMethodPrefix
    for attr_name, attr in list(f.__dict__.items()):
      if callable(attr) and attr_name.startswith(test_prefix):
        setattr(f, attr_name, decorator(attr))
    return f

  return decorator(func) if func is not None else decorator
1163
1164
def run_v2_only(func=None):
  """Execute the decorated test only if running in v2 mode.

  This function is intended to be applied to tests that exercise v2 only
  functionality. If the test is run in v1 mode it will simply be skipped.

  `deprecated_graph_mode_only`, `run_v1_only`, `run_v2_only`, and
  `run_in_graph_and_eager_modes` are available decorators for different
  v1/v2/eager/graph combinations.

  Args:
    func: function to be annotated. If `func` is None, this method returns a
      decorator that can be applied to a function. If `func` is not None this
      returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: if the decorator is applied to a class; unlike `run_v1_only`,
      only individual test methods are supported.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_v2_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      if not tf2.enabled():
        self.skipTest("Test is only compatible with v2")

      return f(self, *args, **kwargs)

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
1200
1201
def run_gpu_only(func=None):
  """Execute the decorated test only if a GPU is available.

  Intended for tests that require a GPU. When `is_gpu_available()` reports
  none, the test is skipped.

  Args:
    func: test method to be annotated. If `func` is None, this method returns
      a decorator that can be applied to a function. If `func` is not None
      this returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: if applied to a class instead of a test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_gpu_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      gpu_present = is_gpu_available()
      if not gpu_present:
        self.skipTest("Test requires GPU")
      return f(self, *args, **kwargs)

    return decorated

  return decorator(func) if func is not None else decorator
1233
1234
def run_cuda_only(func=None):
  """Execute the decorated test only if a CUDA GPU is available.

  Intended for tests that require the presence of a CUDA GPU. When
  `is_gpu_available(cuda_only=True)` reports none, the test is skipped.

  Args:
    func: test method to be annotated. If `func` is None, this method returns
      a decorator that can be applied to a function. If `func` is not None
      this returns the decorator applied to `func`.

  Returns:
    Returns a decorator that will conditionally skip the decorated test method.

  Raises:
    ValueError: if applied to a class instead of a test method.
  """

  def decorator(f):
    if tf_inspect.isclass(f):
      raise ValueError("`run_cuda_only` only supports test methods.")

    def decorated(self, *args, **kwargs):
      cuda_present = is_gpu_available(cuda_only=True)
      if not cuda_present:
        self.skipTest("Test requires CUDA GPU")
      return f(self, *args, **kwargs)

    return decorated

  return decorator(func) if func is not None else decorator
1266
1267
@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
  """Returns whether TensorFlow can access a GPU.

  Args:
    cuda_only: limit the search to CUDA gpus. When False, a local "SYCL"
      device is accepted as well.
    min_cuda_compute_capability: a (major,minor) pair that indicates the minimum
      CUDA compute capability required, or None if no requirement. Any
      two-element sequence (tuple or list) is accepted.

  Returns:
    True if a gpu device of the requested kind is available.

  Raises:
    errors_impl.NotFoundError: if listing the local devices raises a
      NotFoundError whose message does not contain both "CUDA" and
      "not find".
  """
  if min_cuda_compute_capability is not None:
    # Normalize to a tuple so it compares against the (major, minor) tuple
    # produced below; comparing a tuple with a list raises TypeError on
    # Python 3.
    min_cuda_compute_capability = tuple(min_cuda_compute_capability)

  def compute_capability_from_device_desc(device_desc):
    """Parses "compute capability: X.Y"; returns (0, 0) if it is absent."""
    # TODO(jingyue): The device description generator has to be in sync with
    # this file. Another option is to put compute capability in
    # DeviceAttributes, but I avoided that to keep DeviceAttributes
    # target-independent. Reconsider this option when we have more things like
    # this to keep in sync.
    # LINT.IfChange
    match = re.search(r"compute capability: (\d+)\.(\d+)", device_desc)
    # LINT.ThenChange(//tensorflow/core/\
    #                 common_runtime/gpu/gpu_device.cc)
    if not match:
      return 0, 0
    return int(match.group(1)), int(match.group(2))

  try:
    for local_device in device_lib.list_local_devices():
      if local_device.device_type == "GPU":
        if (min_cuda_compute_capability is None or
            compute_capability_from_device_desc(
                local_device.physical_device_desc) >=
            min_cuda_compute_capability):
          return True
      if local_device.device_type == "SYCL" and not cuda_only:
        return True
    return False
  except errors_impl.NotFoundError as e:
    # A NotFoundError mentioning both "CUDA" and "not find" is logged and
    # treated as "no GPU"; any other NotFoundError propagates unchanged.
    if not all(x in str(e) for x in ["CUDA", "not find"]):
      raise
    logging.error(str(e))
    return False
1312
1313
@contextlib.contextmanager
def device(use_gpu):
  """Places ops on "/device:GPU:0" if requested and a GPU is available.

  Otherwise ops are placed on "/device:CPU:0".
  """
  use_accelerator = use_gpu and is_gpu_available()
  target = "/device:GPU:0" if use_accelerator else "/device:CPU:0"
  with ops.device(target):
    yield
1323
1324
@contextlib.contextmanager
def use_gpu():
  """Places ops on "/device:GPU:0" when a GPU is available, else on the CPU."""
  gpu_context = device(use_gpu=True)
  with gpu_context:
    yield
1330
1331
@contextlib.contextmanager
def force_gpu():
  """Pins the ops created in this context to "/device:GPU:0"."""
  gpu_device_name = "/device:GPU:0"
  with ops.device(gpu_device_name):
    yield
1337
1338
@contextlib.contextmanager
def force_cpu():
  """Pins the ops created in this context to "/device:CPU:0"."""
  cpu_device_name = "/device:CPU:0"
  with ops.device(cpu_device_name):
    yield
1344
1345
class CapturedWrites(object):
  """Reads back the writes that were captured from a stream into a file.

  `captureWritesToStream` yields an instance of this class; the file at
  `capture_location` receives the redirected writes.
  """

  def __init__(self, capture_location):
    # Path of the file that received the captured writes.
    self.capture_location = capture_location

  def contents(self):
    """Get the captured writes as a single string."""
    with open(self.capture_location) as captured_file:
      return captured_file.read()
1357
1358
class FakeEagerSession(object):
  """Fake session so tests that conditionally use placeholders can use eager.

  There are a number of tests that conditionally use placeholders for shape
  inference. The pattern is demonstrated here:

  ```python
  with self.cached_session() as sess:
    if static_shape:
      y = math_ops.matmul(x, ...)
      feed_dict = {}
    else:
      x_ph = array_ops.placeholder(...)
      y = math_ops.matmul(x_ph, ...)
      feed_dict = {x_ph: x}
    val = sess.run(y, feed_dict=feed_dict)
  ```

  Since the feed_dict is empty when not using placeholders we should be able to
  call self.evaluate(), however this requires rewriting the test case.
  This class should be considered a stop-gap solution to get tests running with
  eager with minimal changes to the actual test.
  """

  def __init__(self, test_case):
    # The test case whose `evaluate()` performs the actual evaluation.
    self._test_case = test_case

  def run(self, fetches, *args, **kwargs):
    """Evaluate `fetches`.

    Fail if additional args are specified. An empty (or None) `feed_dict` is
    accepted and ignored.

    Args:
      fetches: A Tensor or a nested list/tuple of Tensors.
      *args: Positional arguments
      **kwargs: Keyword arguments

    Raises:
      RuntimeError: If args or kwargs are specified.

    Returns:
      Whatever the wrapped test case's `evaluate(fetches)` returns (numpy
      values for Tensors).
    """
    feed_dict = kwargs.pop("feed_dict", {})
    if feed_dict:
      raise RuntimeError(
          "feed_dict is not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy())")

    if args or kwargs:
      raise RuntimeError(
          "Optional args are not supported when eager execution is enabled "
          "(in this case, sess.run(t) is shorthand for t.numpy())")

    return self._test_case.evaluate(fetches)
1414
1415
class ErrorLoggingSession(session.Session):
  """A `session.Session` whose `run()` logs exceptions before re-raising.

  `errors.OutOfRangeError` is re-raised without being logged.
  """

  def run(self, *args, **kwargs):
    try:
      return super(ErrorLoggingSession, self).run(*args, **kwargs)
    except errors.OutOfRangeError:
      # tf.data tests use OutOfRangeError as the completion signal; logging
      # it would only make their output hard to read.
      raise
    except Exception as e:  # pylint: disable=broad-except
      logging.error(str(e))
      raise
1429
1430
def use_deterministic_cudnn(func=None):
  """Disable autotuning during the call to this function.

  Some tests want to base assertions on a graph being isomorphic with a copy.
  To ensure this, this decorator disables autotuning by setting the
  TF_CUDNN_DETERMINISTIC environment variable to "true" while the decorated
  test method runs.

  Afterwards the variable is restored to its previous value, even if the test
  raises. If it was not set before, it is removed again.

  Args:
    func: Function to run with CUDNN autotuning turned off. If None, a
      decorator is returned instead.

  Returns:
    Decorated function, or a decorator if `func` is None.
  """
  env_name = "TF_CUDNN_DETERMINISTIC"

  def decorator(f):

    def decorated(self, *args, **kwargs):
      original_var = os.environ.get(env_name)
      os.environ[env_name] = "true"
      try:
        return f(self, *args, **kwargs)
      finally:
        # Restore the exact prior state so later tests are not affected.
        if original_var is None:
          os.environ.pop(env_name, None)
        else:
          os.environ[env_name] = original_var

    return decorated

  if func is not None:
    return decorator(func)

  return decorator
1459
1460
# The description is just for documentation purposes.
def disable_xla(description):
  """Returns a decorator that turns a test method into a no-op under XLA.

  Args:
    description: why the test is disabled; it is not used at runtime.

  Returns:
    A function taking a test method (or None). Given a method, it returns the
    wrapped method. Given None, it returns a decorator.
  """

  def disable_xla_impl(func):
    """Execute the test method only if xla is not enabled."""

    def decorator(test_method):

      def decorated(self, *args, **kwargs):
        # NOTE: when XLA is enabled this returns without calling skipTest, so
        # the test is reported as passing rather than skipped.
        if not is_xla_enabled():
          return test_method(self, *args, **kwargs)
        return None

      return decorated

    return decorator(func) if func is not None else decorator

  return disable_xla_impl
1483
1484
# The description is just for documentation purposes.
def disable_all_xla(description):
  """Returns a class decorator that applies `disable_xla` to all test methods.

  Every callable attribute visible through `dir(cls)` (including inherited
  ones) whose name starts with "test" is wrapped, except `test_session`.

  Args:
    description: why the tests are disabled; it is not used at runtime.

  Returns:
    A decorator that modifies the class in place and returns it.
  """

  def disable_all_impl(cls):
    """Execute all test methods in this class only if xla is not enabled."""
    xla_guard = disable_xla(description)
    for attr_name in dir(cls):
      attr = getattr(cls, attr_name)
      is_test = attr_name.startswith("test") and attr_name != "test_session"
      if callable(attr) and is_test:
        setattr(cls, attr_name, xla_guard(attr))
    return cls

  return disable_all_impl
1499
1500
class EagerSessionWarner(object):
  """Stand-in yielded by `session()` while eager execution is enabled.

  Any attribute lookup that normal lookup cannot satisfy raises an
  `AttributeError` explaining why the session object cannot be used.
  """

  def __getattr__(self, attr):
    message = (
        "Trying to access properties or call methods on the result of "
        "self.session(), self.cached_session(), etc while eager execution "
        "is enabled. If you're porting this test case to TF 2.0, either "
        "adapt the test to work with eager execution or insert a call to "
        "tf.disable_eager_execution() in the main() function of this test "
        "file.")
    raise AttributeError(message)
1511
1512
1513@tf_export("test.TestCase")
1514class TensorFlowTestCase(googletest.TestCase):
1515  """Base class for tests that need to test TensorFlow."""
1516
1517  def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
1518    super(TensorFlowTestCase, self).__init__(methodName)
1519    if is_xla_enabled():
1520      os.putenv(
1521          "TF_XLA_FLAGS", "--tf_xla_auto_jit=2 --tf_xla_min_cluster_size=1 "
1522          "--tf_xla_enable_lazy_compilation=false " +
1523          os.getenv("TF_XLA_FLAGS", ""))
1524    self._threads = []
1525    self._tempdir = None
1526    self._cached_session = None
1527
1528  def setUp(self):
1529    self._ClearCachedSession()
1530    random.seed(random_seed.DEFAULT_GRAPH_SEED)
1531    np.random.seed(random_seed.DEFAULT_GRAPH_SEED)
1532    # Note: The following line is necessary because some test methods may error
1533    # out from within nested graph contexts (e.g., via assertRaises and
1534    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
1535    # under certain versions of Python. That would cause
1536    # ops.reset_default_graph() to throw an exception if the stack were not
1537    # cleared first.
1538    ops._default_graph_stack.reset()  # pylint: disable=protected-access
1539    ops.reset_default_graph()
1540    random_seed.set_random_seed(random_seed.DEFAULT_GRAPH_SEED)
1541
1542    # Avoiding calling setUp() for the poorly named test_session method.
1543    if self.id().endswith(".test_session"):
1544      self.skipTest("Not a test.")
1545
1546  def tearDown(self):
1547    for thread in self._threads:
1548      thread.check_termination()
1549
1550    self._ClearCachedSession()
1551
1552  def _ClearCachedSession(self):
1553    if self._cached_session is not None:
1554      self._cached_session.close()
1555      self._cached_session = None
1556
1557  def get_temp_dir(self):
1558    """Returns a unique temporary directory for the test to use.
1559
1560    If you call this method multiple times during in a test, it will return the
1561    same folder. However, across different runs the directories will be
1562    different. This will ensure that across different runs tests will not be
1563    able to pollute each others environment.
1564    If you need multiple unique directories within a single test, you should
1565    use tempfile.mkdtemp as follows:
1566      tempfile.mkdtemp(dir=self.get_temp_dir()):
1567
1568    Returns:
1569      string, the path to the unique temporary directory created for this test.
1570    """
1571    if not self._tempdir:
1572      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
1573    return self._tempdir
1574
1575  @contextlib.contextmanager
1576  def captureWritesToStream(self, stream):
1577    """A context manager that captures the writes to a given stream.
1578
1579    This context manager captures all writes to a given stream inside of a
1580    `CapturedWrites` object. When this context manager is created, it yields
1581    the `CapturedWrites` object. The captured contents can be accessed  by
1582    calling `.contents()` on the `CapturedWrites`.
1583
1584    For this function to work, the stream must have a file descriptor that
1585    can be modified using `os.dup` and `os.dup2`, and the stream must support
1586    a `.flush()` method. The default python sys.stdout and sys.stderr are
1587    examples of this. Note that this does not work in Colab or Jupyter
1588    notebooks, because those use alternate stdout streams.
1589
1590    Example:
1591    ```python
1592    class MyOperatorTest(test_util.TensorFlowTestCase):
1593      def testMyOperator(self):
1594        input = [1.0, 2.0, 3.0, 4.0, 5.0]
1595        with self.captureWritesToStream(sys.stdout) as captured:
1596          result = MyOperator(input).eval()
1597        self.assertStartsWith(captured.contents(), "This was printed.")
1598    ```
1599
1600    Args:
1601      stream: The stream whose writes should be captured. This stream must have
1602        a file descriptor, support writing via using that file descriptor, and
1603        must have a `.flush()` method.
1604
1605    Yields:
1606      A `CapturedWrites` object that contains all writes to the specified stream
1607      made during this context.
1608    """
1609    stream.flush()
1610    fd = stream.fileno()
1611    tmp_file_path = tempfile.mktemp(dir=self.get_temp_dir())
1612    tmp_file = open(tmp_file_path, "w")
1613    orig_fd = os.dup(fd)
1614    os.dup2(tmp_file.fileno(), fd)
1615    try:
1616      yield CapturedWrites(tmp_file_path)
1617    finally:
1618      tmp_file.close()
1619      os.dup2(orig_fd, fd)
1620
1621  def _AssertProtoEquals(self, a, b, msg=None):
1622    """Asserts that a and b are the same proto.
1623
1624    Uses ProtoEq() first, as it returns correct results
1625    for floating point attributes, and then use assertProtoEqual()
1626    in case of failure as it provides good error messages.
1627
1628    Args:
1629      a: a proto.
1630      b: another proto.
1631      msg: Optional message to report on failure.
1632    """
1633    if not compare.ProtoEq(a, b):
1634      compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)
1635
1636  def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
1637    """Asserts that message is same as parsed expected_message_ascii.
1638
1639    Creates another prototype of message, reads the ascii message into it and
1640    then compares them using self._AssertProtoEqual().
1641
1642    Args:
1643      expected_message_maybe_ascii: proto message in original or ascii form.
1644      message: the message to validate.
1645      msg: Optional message to report on failure.
1646    """
1647    msg = msg if msg else ""
1648    if isinstance(expected_message_maybe_ascii, type(message)):
1649      expected_message = expected_message_maybe_ascii
1650      self._AssertProtoEquals(expected_message, message)
1651    elif isinstance(expected_message_maybe_ascii, str):
1652      expected_message = type(message)()
1653      text_format.Merge(
1654          expected_message_maybe_ascii,
1655          expected_message,
1656          descriptor_pool=descriptor_pool.Default())
1657      self._AssertProtoEquals(expected_message, message, msg=msg)
1658    else:
1659      assert False, ("Can't compare protos of type %s and %s. %s" %
1660                     (type(expected_message_maybe_ascii), type(message), msg))
1661
1662  def assertProtoEqualsVersion(
1663      self,
1664      expected,
1665      actual,
1666      producer=versions.GRAPH_DEF_VERSION,
1667      min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
1668      msg=None):
1669    expected = "versions { producer: %d min_consumer: %d };\n%s" % (
1670        producer, min_consumer, expected)
1671    self.assertProtoEquals(expected, actual, msg=msg)
1672
1673  def assertStartsWith(self, actual, expected_start, msg=None):
1674    """Assert that actual.startswith(expected_start) is True.
1675
1676    Args:
1677      actual: str
1678      expected_start: str
1679      msg: Optional message to report on failure.
1680    """
1681    if not actual.startswith(expected_start):
1682      fail_msg = "%r does not start with %r" % (actual, expected_start)
1683      fail_msg += " : %r" % (msg) if msg else ""
1684      self.fail(fail_msg)
1685
1686  def _eval_tensor(self, tensor):
1687    if tensor is None:
1688      return None
1689    elif callable(tensor):
1690      return self._eval_helper(tensor())
1691    else:
1692      try:
1693        if sparse_tensor.is_sparse(tensor):
1694          return sparse_tensor.SparseTensorValue(tensor.indices.numpy(),
1695                                                 tensor.values.numpy(),
1696                                                 tensor.dense_shape.numpy())
1697        elif isinstance(tensor, ops.IndexedSlices):
1698          return ops.IndexedSlicesValue(values=tensor.values.numpy(),
1699                                        indices=tensor.indices.numpy(),
1700                                        dense_shape=tensor.dense_shape.numpy())
1701        return tensor.numpy()
1702      except AttributeError as e:
1703        six.raise_from(ValueError("Unsupported type %s." % type(tensor)), e)
1704
1705  def _eval_helper(self, tensors):
1706    if tensors is None:
1707      return None
1708    return nest.map_structure(self._eval_tensor, tensors)
1709
1710  def evaluate(self, tensors):
1711    """Evaluates tensors and returns numpy values.
1712
1713    Args:
1714      tensors: A Tensor or a nested list/tuple of Tensors.
1715
1716    Returns:
1717      tensors numpy values.
1718    """
1719    if context.executing_eagerly():
1720      return self._eval_helper(tensors)
1721    else:
1722      sess = ops.get_default_session()
1723      if sess is None:
1724        with self.test_session() as sess:
1725          return sess.run(tensors)
1726      else:
1727        return sess.run(tensors)
1728
1729  # pylint: disable=g-doc-return-or-yield
1730  @contextlib.contextmanager
1731  def session(self, graph=None, config=None, use_gpu=False, force_gpu=False):
1732    """Returns a TensorFlow Session for use in executing tests.
1733
1734    Note that this will set this session and the graph as global defaults.
1735
1736    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
1737    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
1738    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
1739    possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
1740    the CPU.
1741
1742    Example:
1743    ```python
1744    class MyOperatorTest(test_util.TensorFlowTestCase):
1745      def testMyOperator(self):
1746        with self.session(use_gpu=True):
1747          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
1748          result = MyOperator(valid_input).eval()
1749          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
1750          invalid_input = [-1.0, 2.0, 7.0]
1751          with self.assertRaisesOpError("negative input not supported"):
1752            MyOperator(invalid_input).eval()
1753    ```
1754
1755    Args:
1756      graph: Optional graph to use during the returned session.
1757      config: An optional config_pb2.ConfigProto to use to configure the
1758        session.
1759      use_gpu: If True, attempt to run as many ops as possible on GPU.
1760      force_gpu: If True, pin all ops to `/device:GPU:0`.
1761
1762    Yields:
1763      A Session object that should be used as a context manager to surround
1764      the graph building and execution code in a test case.
1765    """
1766    if context.executing_eagerly():
1767      yield EagerSessionWarner()
1768    else:
1769      with self._create_session(graph, config, force_gpu) as sess:
1770        with self._constrain_devices_and_set_default(sess, use_gpu, force_gpu):
1771          yield sess
1772
1773  @contextlib.contextmanager
1774  def cached_session(self,
1775                     graph=None,
1776                     config=None,
1777                     use_gpu=False,
1778                     force_gpu=False):
1779    """Returns a TensorFlow Session for use in executing tests.
1780
1781    This method behaves differently than self.session(): for performance reasons
1782    `cached_session` will by default reuse the same session within the same
1783    test. The session returned by this function will only be closed at the end
1784    of the test (in the TearDown function).
1785
1786    Use the `use_gpu` and `force_gpu` options to control where ops are run. If
1787    `force_gpu` is True, all ops are pinned to `/device:GPU:0`. Otherwise, if
1788    `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
1789    possible. If both `force_gpu and `use_gpu` are False, all ops are pinned to
1790    the CPU.
1791
1792    Example:
1793    ```python
1794    class MyOperatorTest(test_util.TensorFlowTestCase):
1795      def testMyOperator(self):
1796        with self.cached_session(use_gpu=True) as sess:
1797          valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
1798          result = MyOperator(valid_input).eval()
1799          self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
1800          invalid_input = [-1.0, 2.0, 7.0]
1801          with self.assertRaisesOpError("negative input not supported"):
1802            MyOperator(invalid_input).eval()
1803    ```
1804
1805    Args:
1806      graph: Optional graph to use during the returned session.
1807      config: An optional config_pb2.ConfigProto to use to configure the
1808        session.
1809      use_gpu: If True, attempt to run as many ops as possible on GPU.
1810      force_gpu: If True, pin all ops to `/device:GPU:0`.
1811
1812    Yields:
1813      A Session object that should be used as a context manager to surround
1814      the graph building and execution code in a test case.
1815    """
1816    if context.executing_eagerly():
1817      yield FakeEagerSession(self)
1818    else:
1819      sess = self._get_cached_session(
1820          graph, config, force_gpu, crash_if_inconsistent_args=True)
1821      with self._constrain_devices_and_set_default(sess, use_gpu,
1822                                                   force_gpu) as cached:
1823        yield cached
1824
1825  @contextlib.contextmanager
1826  @deprecation.deprecated(None, "Use `self.session()` or "
1827                          "`self.cached_session()` instead.")
1828  def test_session(self,
1829                   graph=None,
1830                   config=None,
1831                   use_gpu=False,
1832                   force_gpu=False):
1833    """Use cached_session instead."""
1834    if self.id().endswith(".test_session"):
1835      self.skipTest("Not a test.")
1836    if context.executing_eagerly():
1837      yield None
1838    else:
1839      if graph is None:
1840        sess = self._get_cached_session(
1841            graph, config, force_gpu, crash_if_inconsistent_args=False)
1842        with self._constrain_devices_and_set_default(sess, use_gpu,
1843                                                     force_gpu) as cached:
1844          yield cached
1845      else:
1846        with self.session(graph, config, use_gpu, force_gpu) as sess:
1847          yield sess
1848
1849  # pylint: enable=g-doc-return-or-yield
1850
1851  class _CheckedThread(object):
1852    """A wrapper class for Thread that asserts successful completion.
1853
1854    This class should be created using the TensorFlowTestCase.checkedThread()
1855    method.
1856    """
1857
1858    def __init__(self, testcase, target, args=None, kwargs=None):
1859      """Constructs a new instance of _CheckedThread.
1860
1861      Args:
1862        testcase: The TensorFlowTestCase for which this thread is being created.
1863        target: A callable object representing the code to be executed in the
1864          thread.
1865        args: A tuple of positional arguments that will be passed to target.
1866        kwargs: A dictionary of keyword arguments that will be passed to target.
1867      """
1868      self._testcase = testcase
1869      self._target = target
1870      self._args = () if args is None else args
1871      self._kwargs = {} if kwargs is None else kwargs
1872      self._thread = threading.Thread(target=self._protected_run)
1873      self._exception = None
1874
1875      self._is_thread_joined = False
1876
1877    def _protected_run(self):
1878      """Target for the wrapper thread. Sets self._exception on failure."""
1879      try:
1880        self._target(*self._args, **self._kwargs)
1881      except Exception as e:  # pylint: disable=broad-except
1882        self._exception = e
1883
1884    def start(self):
1885      """Starts the thread's activity.
1886
1887      This must be called at most once per _CheckedThread object. It arranges
1888      for the object's target to be invoked in a separate thread of control.
1889      """
1890      self._thread.start()
1891
1892    def join(self):
1893      """Blocks until the thread terminates.
1894
1895      Raises:
1896        self._testcase.failureException: If the thread terminates with due to
1897          an exception.
1898      """
1899      self._is_thread_joined = True
1900      self._thread.join()
1901      if self._exception is not None:
1902        self._testcase.fail("Error in checkedThread: %s" % str(self._exception))
1903
1904    def is_alive(self):
1905      """Returns whether the thread is alive.
1906
1907      This method returns True just before the run() method starts
1908      until just after the run() method terminates.
1909
1910      Returns:
1911        True if the thread is alive, otherwise False.
1912      """
1913      return self._thread.is_alive()
1914
1915    def check_termination(self):
1916      """Returns whether the checked thread was properly used and did terminate.
1917
1918      Every checked thread should be "join"ed after starting, and before the
1919      test tears down. If it is not joined, it is possible the thread will hang
1920      and cause flaky failures in tests.
1921
1922      Raises:
1923        self._testcase.failureException: If check_termination was called before
1924        thread was joined.
1925
1926        RuntimeError: If the thread is not terminated. This means thread was not
1927        joined with the main thread.
1928      """
1929      if self._is_thread_joined:
1930        if self.is_alive():
1931          raise RuntimeError(
1932              "Thread was not joined with main thread, and is still running "
1933              "when the test finished.")
1934      else:
1935        self._testcase.fail("A checked thread was not joined.")
1936
1937  def checkedThread(self, target, args=None, kwargs=None):
1938    """Returns a Thread wrapper that asserts 'target' completes successfully.
1939
1940    This method should be used to create all threads in test cases, as
1941    otherwise there is a risk that a thread will silently fail, and/or
1942    assertions made in the thread will not be respected.
1943
1944    Args:
1945      target: A callable object to be executed in the thread.
1946      args: The argument tuple for the target invocation. Defaults to ().
1947      kwargs: A dictionary of keyword arguments for the target invocation.
1948        Defaults to {}.
1949
1950    Returns:
1951      A wrapper for threading.Thread that supports start() and join() methods.
1952    """
1953    ret = TensorFlowTestCase._CheckedThread(self, target, args, kwargs)
1954    self._threads.append(ret)
1955    return ret
1956
1957  # pylint: enable=invalid-name
1958  @py_func_if_in_function
1959  def assertNear(self, f1, f2, err, msg=None):
1960    """Asserts that two floats are near each other.
1961
1962    Checks that |f1 - f2| < err and asserts a test failure
1963    if not.
1964
1965    Args:
1966      f1: A float value.
1967      f2: A float value.
1968      err: A float value.
1969      msg: An optional string message to append to the failure message.
1970    """
1971    # f1 == f2 is needed here as we might have: f1, f2 = inf, inf
1972    self.assertTrue(
1973        f1 == f2 or math.fabs(f1 - f2) <= err, "%f != %f +/- %f%s" %
1974        (f1, f2, err, " (%s)" % msg if msg is not None else ""))
1975
1976  @py_func_if_in_function
1977  def assertArrayNear(self, farray1, farray2, err, msg=None):
1978    """Asserts that two float arrays are near each other.
1979
1980    Checks that for all elements of farray1 and farray2
1981    |f1 - f2| < err.  Asserts a test failure if not.
1982
1983    Args:
1984      farray1: a list of float values.
1985      farray2: a list of float values.
1986      err: a float value.
1987      msg: Optional message to report on failure.
1988    """
1989    self.assertEqual(len(farray1), len(farray2), msg=msg)
1990    for f1, f2 in zip(farray1, farray2):
1991      self.assertNear(float(f1), float(f2), err, msg=msg)
1992
1993  def _NDArrayNear(self, ndarray1, ndarray2, err):
1994    return np.linalg.norm(ndarray1 - ndarray2) < err
1995
1996  @py_func_if_in_function
1997  def assertNDArrayNear(self, ndarray1, ndarray2, err, msg=None):
1998    """Asserts that two numpy arrays have near values.
1999
2000    Args:
2001      ndarray1: a numpy ndarray.
2002      ndarray2: a numpy ndarray.
2003      err: a float. The maximum absolute difference allowed.
2004      msg: Optional message to report on failure.
2005    """
2006    self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)
2007
2008  def _GetNdArray(self, a):
2009    # If a is a tensor then convert it to ndarray
2010    if isinstance(a, ops.Tensor):
2011      if isinstance(a, ops._EagerTensorBase):
2012        a = a.numpy()
2013      else:
2014        a = self.evaluate(a)
2015    if not isinstance(a, np.ndarray):
2016      return np.array(a)
2017    return a
2018
2019  def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
2020    a = self._GetNdArray(a)
2021    b = self._GetNdArray(b)
2022    # When the array rank is small, print its contents. Numpy array printing is
2023    # implemented using inefficient recursion so prints can cause tests to
2024    # time out.
2025    if a.shape != b.shape and (b.ndim <= 3 or b.size < 500):
2026      shape_mismatch_msg = ("Shape mismatch: expected %s, got %s with contents "
2027                            "%s.") % (a.shape, b.shape, b)
2028    else:
2029      shape_mismatch_msg = "Shape mismatch: expected %s, got %s." % (a.shape,
2030                                                                     b.shape)
2031    self.assertEqual(a.shape, b.shape, shape_mismatch_msg)
2032
2033    msgs = [msg]
2034    if not np.allclose(a, b, rtol=rtol, atol=atol):
2035      # Adds more details to np.testing.assert_allclose.
2036      #
2037      # NOTE: numpy.allclose (and numpy.testing.assert_allclose)
2038      # checks whether two arrays are element-wise equal within a
2039      # tolerance. The relative difference (rtol * abs(b)) and the
2040      # absolute difference atol are added together to compare against
2041      # the absolute difference between a and b.  Here, we want to
2042      # tell user which elements violate such conditions.
2043      cond = np.logical_or(
2044          np.abs(a - b) > atol + rtol * np.abs(b),
2045          np.isnan(a) != np.isnan(b))
2046      if a.ndim:
2047        x = a[np.where(cond)]
2048        y = b[np.where(cond)]
2049        msgs.append("not close where = {}".format(np.where(cond)))
2050      else:
2051        # np.where is broken for scalars
2052        x, y = a, b
2053      msgs.append("not close lhs = {}".format(x))
2054      msgs.append("not close rhs = {}".format(y))
2055      msgs.append("not close dif = {}".format(np.abs(x - y)))
2056      msgs.append("not close tol = {}".format(atol + rtol * np.abs(y)))
2057      msgs.append("dtype = {}, shape = {}".format(a.dtype, a.shape))
2058      # TODO(xpan): There seems to be a bug:
2059      # tensorflow/compiler/tests:binary_ops_test pass with float32
2060      # nan even though the equal_nan is False by default internally.
2061      np.testing.assert_allclose(
2062          a, b, rtol=rtol, atol=atol, err_msg="\n".join(msgs), equal_nan=True)
2063
2064  def _assertAllCloseRecursive(self,
2065                               a,
2066                               b,
2067                               rtol=1e-6,
2068                               atol=1e-6,
2069                               path=None,
2070                               msg=None):
2071    path = path or []
2072    path_str = (("[" + "][".join([str(p) for p in path]) + "]") if path else "")
2073    msg = msg if msg else ""
2074
2075    # Check if a and/or b are namedtuples.
2076    if hasattr(a, "_asdict"):
2077      a = a._asdict()
2078    if hasattr(b, "_asdict"):
2079      b = b._asdict()
2080    a_is_dict = isinstance(a, collections.Mapping)
2081    if a_is_dict != isinstance(b, collections.Mapping):
2082      raise ValueError("Can't compare dict to non-dict, a%s vs b%s. %s" %
2083                       (path_str, path_str, msg))
2084    if a_is_dict:
2085      self.assertItemsEqual(
2086          a.keys(),
2087          b.keys(),
2088          msg="mismatched keys: a%s has keys %s, but b%s has keys %s. %s" %
2089          (path_str, a.keys(), path_str, b.keys(), msg))
2090      for k in a:
2091        path.append(k)
2092        self._assertAllCloseRecursive(
2093            a[k], b[k], rtol=rtol, atol=atol, path=path, msg=msg)
2094        del path[-1]
2095    elif isinstance(a, (list, tuple)):
2096      # Try to directly compare a, b as ndarrays; if not work, then traverse
2097      # through the sequence, which is more expensive.
2098      try:
2099        a_as_ndarray = self._GetNdArray(a)
2100        b_as_ndarray = self._GetNdArray(b)
2101        self._assertArrayLikeAllClose(
2102            a_as_ndarray,
2103            b_as_ndarray,
2104            rtol=rtol,
2105            atol=atol,
2106            msg="Mismatched value: a%s is different from b%s. %s" %
2107            (path_str, path_str, msg))
2108      except (ValueError, TypeError) as e:
2109        if len(a) != len(b):
2110          raise ValueError(
2111              "Mismatched length: a%s has %d items, but b%s has %d items. %s" %
2112              (path_str, len(a), path_str, len(b), msg))
2113        for idx, (a_ele, b_ele) in enumerate(zip(a, b)):
2114          path.append(str(idx))
2115          self._assertAllCloseRecursive(
2116              a_ele, b_ele, rtol=rtol, atol=atol, path=path, msg=msg)
2117          del path[-1]
2118    # a and b are ndarray like objects
2119    else:
2120      try:
2121        self._assertArrayLikeAllClose(
2122            a,
2123            b,
2124            rtol=rtol,
2125            atol=atol,
2126            msg=("Mismatched value: a%s is different from b%s. %s" %
2127                 (path_str, path_str, msg)))
2128      except TypeError as e:
2129        msg = ("Error: a%s has %s, but b%s has %s. %s" %
2130               (path_str, type(a), path_str, type(b), msg))
2131        e.args = ((e.args[0] + " : " + msg,) + e.args[1:])
2132        raise
2133
2134  @py_func_if_in_function
2135  def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
2136    """Asserts that two structures of numpy arrays or Tensors, have near values.
2137
2138    `a` and `b` can be arbitrarily nested structures. A layer of a nested
2139    structure can be a `dict`, `namedtuple`, `tuple` or `list`.
2140
2141    Args:
2142      a: The expected numpy `ndarray`, or anything that can be converted into a
2143        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2144        structure of these.
2145      b: The actual numpy `ndarray`, or anything that can be converted into a
2146        numpy `ndarray` (including Tensor), or any arbitrarily nested of
2147        structure of these.
2148      rtol: relative tolerance.
2149      atol: absolute tolerance.
2150      msg: Optional message to report on failure.
2151
2152    Raises:
2153      ValueError: if only one of `a[p]` and `b[p]` is a dict or
2154          `a[p]` and `b[p]` have different length, where `[p]` denotes a path
2155          to the nested structure, e.g. given `a = [(1, 1), {'d': (6, 7)}]` and
2156          `[p] = [1]['d']`, then `a[p] = (6, 7)`.
2157    """
2158    self._assertAllCloseRecursive(a, b, rtol=rtol, atol=atol, msg=msg)
2159
2160  @py_func_if_in_function
2161  def assertAllCloseAccordingToType(self,
2162                                    a,
2163                                    b,
2164                                    rtol=1e-6,
2165                                    atol=1e-6,
2166                                    float_rtol=1e-6,
2167                                    float_atol=1e-6,
2168                                    half_rtol=1e-3,
2169                                    half_atol=1e-3,
2170                                    bfloat16_rtol=1e-2,
2171                                    bfloat16_atol=1e-2,
2172                                    msg=None):
2173    """Like assertAllClose, but also suitable for comparing fp16 arrays.
2174
2175    In particular, the tolerance is reduced to 1e-3 if at least
2176    one of the arguments is of type float16.
2177
2178    Args:
2179      a: the expected numpy ndarray or anything can be converted to one.
2180      b: the actual numpy ndarray or anything can be converted to one.
2181      rtol: relative tolerance.
2182      atol: absolute tolerance.
2183      float_rtol: relative tolerance for float32.
2184      float_atol: absolute tolerance for float32.
2185      half_rtol: relative tolerance for float16.
2186      half_atol: absolute tolerance for float16.
2187      bfloat16_rtol: relative tolerance for bfloat16.
2188      bfloat16_atol: absolute tolerance for bfloat16.
2189      msg: Optional message to report on failure.
2190    """
2191    a = self._GetNdArray(a)
2192    b = self._GetNdArray(b)
2193    # types with lower tol are put later to overwrite previous ones.
2194    if (a.dtype == np.float32 or b.dtype == np.float32 or
2195        a.dtype == np.complex64 or b.dtype == np.complex64):
2196      rtol = max(rtol, float_rtol)
2197      atol = max(atol, float_atol)
2198    if a.dtype == np.float16 or b.dtype == np.float16:
2199      rtol = max(rtol, half_rtol)
2200      atol = max(atol, half_atol)
2201    if (a.dtype == dtypes.bfloat16.as_numpy_dtype or
2202        b.dtype == dtypes.bfloat16.as_numpy_dtype):
2203      rtol = max(rtol, bfloat16_rtol)
2204      atol = max(atol, bfloat16_atol)
2205
2206    self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
2207
2208  @py_func_if_in_function
2209  def assertNotAllClose(self, a, b, **kwargs):
2210    """Assert that two numpy arrays, or Tensors, do not have near values.
2211
2212    Args:
2213      a: the first value to compare.
2214      b: the second value to compare.
2215      **kwargs: additional keyword arguments to be passed to the underlying
2216        `assertAllClose` call.
2217
2218    Raises:
2219      AssertionError: If `a` and `b` are unexpectedly close at all elements.
2220    """
2221    try:
2222      self.assertAllClose(a, b, **kwargs)
2223    except AssertionError:
2224      return
2225    raise AssertionError("The two values are close at all elements")
2226
2227  @py_func_if_in_function
2228  def assertAllEqual(self, a, b, msg=None):
2229    """Asserts that two numpy arrays or Tensors have the same values.
2230
2231    Args:
2232      a: the expected numpy ndarray or anything can be converted to one.
2233      b: the actual numpy ndarray or anything can be converted to one.
2234      msg: Optional message to report on failure.
2235    """
2236    msg = msg if msg else ""
2237    a = self._GetNdArray(a)
2238    b = self._GetNdArray(b)
2239    # Arbitrary bounds so that we don't print giant tensors.
2240    if (b.ndim <= 3 or b.size < 500):
2241      self.assertEqual(
2242          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
2243          " Contents: %s. \n%s." % (a.shape, b.shape, b, msg))
2244    else:
2245      self.assertEqual(
2246          a.shape, b.shape, "Shape mismatch: expected %s, got %s."
2247          " %s" % (a.shape, b.shape, msg))
2248
2249    same = (a == b)
2250
2251    if (a.dtype in [
2252        np.float16, np.float32, np.float64, dtypes.bfloat16.as_numpy_dtype
2253    ]):
2254      same = np.logical_or(same, np.logical_and(np.isnan(a), np.isnan(b)))
2255    msgs = [msg]
2256    if not np.all(same):
2257      # Adds more details to np.testing.assert_array_equal.
2258      diff = np.logical_not(same)
2259      if a.ndim:
2260        x = a[np.where(diff)]
2261        y = b[np.where(diff)]
2262        msgs.append("not equal where = {}".format(np.where(diff)))
2263      else:
2264        # np.where is broken for scalars
2265        x, y = a, b
2266      msgs.append("not equal lhs = {}".format(x))
2267      msgs.append("not equal rhs = {}".format(y))
2268      np.testing.assert_array_equal(a, b, err_msg="\n".join(msgs))
2269
2270  @py_func_if_in_function
2271  def assertAllGreater(self, a, comparison_target):
2272    """Assert element values are all greater than a target value.
2273
2274    Args:
2275      a: The numpy `ndarray`, or anything that can be converted into a numpy
2276        `ndarray` (including Tensor).
2277      comparison_target: The target value of comparison.
2278    """
2279    a = self._GetNdArray(a)
2280    self.assertGreater(np.min(a), comparison_target)
2281
2282  @py_func_if_in_function
2283  def assertAllLess(self, a, comparison_target):
2284    """Assert element values are all less than a target value.
2285
2286    Args:
2287      a: The numpy `ndarray`, or anything that can be converted into a numpy
2288        `ndarray` (including Tensor).
2289      comparison_target: The target value of comparison.
2290    """
2291    a = self._GetNdArray(a)
2292    self.assertLess(np.max(a), comparison_target)
2293
2294  @py_func_if_in_function
2295  def assertAllGreaterEqual(self, a, comparison_target):
2296    """Assert element values are all greater than or equal to a target value.
2297
2298    Args:
2299      a: The numpy `ndarray`, or anything that can be converted into a numpy
2300        `ndarray` (including Tensor).
2301      comparison_target: The target value of comparison.
2302    """
2303    a = self._GetNdArray(a)
2304    self.assertGreaterEqual(np.min(a), comparison_target)
2305
2306  @py_func_if_in_function
2307  def assertAllLessEqual(self, a, comparison_target):
2308    """Assert element values are all less than or equal to a target value.
2309
2310    Args:
2311      a: The numpy `ndarray`, or anything that can be converted into a numpy
2312        `ndarray` (including Tensor).
2313      comparison_target: The target value of comparison.
2314    """
2315    a = self._GetNdArray(a)
2316    self.assertLessEqual(np.max(a), comparison_target)
2317
2318  def _format_subscripts(self, subscripts, value, limit=10, indent=2):
2319    """Generate a summary of ndarray subscripts as a list of str.
2320
2321    If limit == N, this method will print up to the first N subscripts on
2322    separate
2323    lines. A line of ellipses (...) will be appended at the end if the number of
2324    subscripts exceeds N.
2325
2326    Args:
2327      subscripts: The tensor (np.ndarray) subscripts, of the same format as
2328        np.where()'s return value, i.e., a tuple of arrays with each array
2329        corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])).
2330      value: (np.ndarray) value of the tensor.
2331      limit: (int) The maximum number of indices to print.
2332      indent: (int) Number of characters to indent at the beginning of each
2333        line.
2334
2335    Returns:
2336      (list of str) the multi-line representation of the subscripts and values,
2337        potentially with omission at the end.
2338    """
2339    lines = []
2340    subscripts = np.transpose(subscripts)
2341    prefix = " " * indent
2342    for subscript in itertools.islice(subscripts, limit):
2343      lines.append(prefix + str(subscript) + " : " +
2344                   str(value[tuple(subscript)]))
2345    if len(subscripts) > limit:
2346      lines.append(prefix + "...")
2347    return lines
2348
2349  @py_func_if_in_function
2350  def assertAllInRange(self,
2351                       target,
2352                       lower_bound,
2353                       upper_bound,
2354                       open_lower_bound=False,
2355                       open_upper_bound=False):
2356    """Assert that elements in a Tensor are all in a given range.
2357
2358    Args:
2359      target: The numpy `ndarray`, or anything that can be converted into a
2360        numpy `ndarray` (including Tensor).
2361      lower_bound: lower bound of the range
2362      upper_bound: upper bound of the range
2363      open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather
2364        than the default >=)
2365      open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather
2366        than the default <=)
2367
2368    Raises:
2369      AssertionError:
2370        if the value tensor does not have an ordered numeric type (float* or
2371          int*), or
2372        if there are nan values, or
2373        if any of the elements do not fall in the specified range.
2374    """
2375    target = self._GetNdArray(target)
2376    if not (np.issubdtype(target.dtype, np.floating) or
2377            np.issubdtype(target.dtype, np.integer)):
2378      raise AssertionError(
2379          "The value of %s does not have an ordered numeric type, instead it "
2380          "has type: %s" % (target, target.dtype))
2381
2382    nan_subscripts = np.where(np.isnan(target))
2383    if np.size(nan_subscripts):
2384      raise AssertionError(
2385          "%d of the %d element(s) are NaN. "
2386          "Subscripts(s) and value(s) of the NaN element(s):\n" %
2387          (len(nan_subscripts[0]), np.size(target)) +
2388          "\n".join(self._format_subscripts(nan_subscripts, target)))
2389
2390    range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " +
2391                 str(upper_bound) + (")" if open_upper_bound else "]"))
2392
2393    violations = (
2394        np.less_equal(target, lower_bound) if open_lower_bound else np.less(
2395            target, lower_bound))
2396    violations = np.logical_or(
2397        violations,
2398        np.greater_equal(target, upper_bound)
2399        if open_upper_bound else np.greater(target, upper_bound))
2400    violation_subscripts = np.where(violations)
2401    if np.size(violation_subscripts):
2402      raise AssertionError(
2403          "%d of the %d element(s) are outside the range %s. " %
2404          (len(violation_subscripts[0]), np.size(target), range_str) +
2405          "Subscript(s) and value(s) of the offending elements:\n" +
2406          "\n".join(self._format_subscripts(violation_subscripts, target)))
2407
2408  @py_func_if_in_function
2409  def assertAllInSet(self, target, expected_set):
2410    """Assert that elements of a Tensor are all in a given closed set.
2411
2412    Args:
2413      target: The numpy `ndarray`, or anything that can be converted into a
2414        numpy `ndarray` (including Tensor).
2415      expected_set: (`list`, `tuple` or `set`) The closed set that the elements
2416        of the value of `target` are expected to fall into.
2417
2418    Raises:
2419      AssertionError:
2420        if any of the elements do not fall into `expected_set`.
2421    """
2422    target = self._GetNdArray(target)
2423
2424    # Elements in target that are not in expected_set.
2425    diff = np.setdiff1d(target.flatten(), list(expected_set))
2426    if np.size(diff):
2427      raise AssertionError("%d unique element(s) are not in the set %s: %s" %
2428                           (np.size(diff), expected_set, diff))
2429
2430  @py_func_if_in_function
2431  def assertDTypeEqual(self, target, expected_dtype):
2432    """Assert ndarray data type is equal to expected.
2433
2434    Args:
2435      target: The numpy `ndarray`, or anything that can be converted into a
2436        numpy `ndarray` (including Tensor).
2437      expected_dtype: Expected data type.
2438    """
2439    target = self._GetNdArray(target)
2440    if not isinstance(target, list):
2441      arrays = [target]
2442    for arr in arrays:
2443      self.assertEqual(arr.dtype, expected_dtype)
2444
2445  # pylint: disable=g-doc-return-or-yield
2446  @contextlib.contextmanager
2447  def assertRaisesWithPredicateMatch(self, exception_type,
2448                                     expected_err_re_or_predicate):
2449    """Returns a context manager to enclose code expected to raise an exception.
2450
2451    If the exception is an OpError, the op stack is also included in the message
2452    predicate search.
2453
2454    Args:
2455      exception_type: The expected type of exception that should be raised.
2456      expected_err_re_or_predicate: If this is callable, it should be a function
2457        of one argument that inspects the passed-in exception and returns True
2458        (success) or False (please fail the test). Otherwise, the error message
2459        is expected to match this regular expression partially.
2460
2461    Returns:
2462      A context manager to surround code that is expected to raise an
2463      exception.
2464    """
2465    if callable(expected_err_re_or_predicate):
2466      predicate = expected_err_re_or_predicate
2467    else:
2468
2469      def predicate(e):
2470        err_str = e.message if isinstance(e, errors.OpError) else str(e)
2471        op = e.op if isinstance(e, errors.OpError) else None
2472        while op is not None:
2473          err_str += "\nCaused by: " + op.name
2474          op = op._original_op  # pylint: disable=protected-access
2475        logging.info("Searching within error strings: '%s' within '%s'",
2476                     expected_err_re_or_predicate, err_str)
2477        return re.search(expected_err_re_or_predicate, err_str)
2478
2479    try:
2480      yield
2481      self.fail(exception_type.__name__ + " not raised")
2482    except Exception as e:  # pylint: disable=broad-except
2483      if not isinstance(e, exception_type) or not predicate(e):
2484        raise AssertionError(
2485            "Exception of type %s: %s" % (str(type(e)), str(e)))
2486
2487  # pylint: enable=g-doc-return-or-yield
2488
2489  def assertRaisesOpError(self, expected_err_re_or_predicate):
2490    return self.assertRaisesWithPredicateMatch(errors.OpError,
2491                                               expected_err_re_or_predicate)
2492
2493  def assertShapeEqual(self, np_array, tf_tensor, msg=None):
2494    """Asserts that a Numpy ndarray and a TensorFlow tensor have the same shape.
2495
2496    Args:
2497      np_array: A Numpy ndarray or Numpy scalar.
2498      tf_tensor: A Tensor.
2499      msg: Optional message to report on failure.
2500
2501    Raises:
2502      TypeError: If the arguments have the wrong type.
2503    """
2504    if not isinstance(np_array, (np.ndarray, np.generic)):
2505      raise TypeError("np_array must be a Numpy ndarray or Numpy scalar")
2506    if not isinstance(tf_tensor, ops.Tensor):
2507      raise TypeError("tf_tensor must be a Tensor")
2508    self.assertAllEqual(
2509        np_array.shape, tf_tensor.get_shape().as_list(), msg=msg)
2510
2511  def assertDeviceEqual(self, device1, device2, msg=None):
2512    """Asserts that the two given devices are the same.
2513
2514    Args:
2515      device1: A string device name or TensorFlow `DeviceSpec` object.
2516      device2: A string device name or TensorFlow `DeviceSpec` object.
2517      msg: Optional message to report on failure.
2518    """
2519    device1 = pydev.canonical_name(device1)
2520    device2 = pydev.canonical_name(device2)
2521    self.assertEqual(
2522        device1, device2,
2523        "Devices %s and %s are not equal. %s" % (device1, device2, msg))
2524
2525  # Fix Python 3 compatibility issues
2526  if six.PY3:
2527    # pylint: disable=invalid-name
2528
2529    # Silence a deprecation warning
2530    assertRaisesRegexp = googletest.TestCase.assertRaisesRegex
2531
2532    # assertItemsEqual is assertCountEqual as of 3.2.
2533    assertItemsEqual = googletest.TestCase.assertCountEqual
2534
2535    # pylint: enable=invalid-name
2536
2537  @contextlib.contextmanager
2538  def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
2539    """Set the session and its graph to global default and constrain devices."""
2540    if context.executing_eagerly():
2541      yield None
2542    else:
2543      with sess.graph.as_default(), sess.as_default():
2544        if force_gpu:
2545          # Use the name of an actual device if one is detected, or
2546          # '/device:GPU:0' otherwise
2547          gpu_name = gpu_device_name()
2548          if not gpu_name:
2549            gpu_name = "/device:GPU:0"
2550          with sess.graph.device(gpu_name):
2551            yield sess
2552        elif use_gpu:
2553          yield sess
2554        else:
2555          with sess.graph.device("/device:CPU:0"):
2556            yield sess
2557
2558  def _create_session(self, graph, config, force_gpu):
2559    """See session() for details."""
2560
2561    def prepare_config(config):
2562      """Returns a config for sessions.
2563
2564      Args:
2565        config: An optional config_pb2.ConfigProto to use to configure the
2566          session.
2567
2568      Returns:
2569        A config_pb2.ConfigProto object.
2570      """
2571      # TODO(b/114333779): Enforce allow_soft_placement=False when
2572      # use_gpu=False. Currently many tests rely on the fact that any device
2573      # will be used even when a specific device is supposed to be used.
2574      allow_soft_placement = not force_gpu
2575      if config is None:
2576        config = config_pb2.ConfigProto()
2577        config.allow_soft_placement = allow_soft_placement
2578        config.gpu_options.per_process_gpu_memory_fraction = 0.3
2579      elif not allow_soft_placement and config.allow_soft_placement:
2580        config_copy = config_pb2.ConfigProto()
2581        config_copy.CopyFrom(config)
2582        config = config_copy
2583        config.allow_soft_placement = False
2584      # Don't perform optimizations for tests so we don't inadvertently run
2585      # gpu ops on cpu
2586      config.graph_options.optimizer_options.opt_level = -1
2587      # Disable Grappler constant folding since some tests & benchmarks
2588      # use constant input and become meaningless after constant folding.
2589      # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
2590      # GRAPPLER TEAM.
2591      config.graph_options.rewrite_options.constant_folding = (
2592          rewriter_config_pb2.RewriterConfig.OFF)
2593      config.graph_options.rewrite_options.pin_to_host_optimization = (
2594          rewriter_config_pb2.RewriterConfig.OFF)
2595      return config
2596
2597    return ErrorLoggingSession(graph=graph, config=prepare_config(config))
2598
2599  def _get_cached_session(self,
2600                          graph=None,
2601                          config=None,
2602                          force_gpu=False,
2603                          crash_if_inconsistent_args=True):
2604    """See cached_session() for documentation."""
2605    if self._cached_session is None:
2606      sess = self._create_session(
2607          graph=graph, config=config, force_gpu=force_gpu)
2608      self._cached_session = sess
2609      self._cached_graph = graph
2610      self._cached_config = config
2611      self._cached_force_gpu = force_gpu
2612      return sess
2613    else:
2614      if crash_if_inconsistent_args and self._cached_graph is not graph:
2615        raise ValueError("The graph used to get the cached session is "
2616                         "different than the one that was used to create the "
2617                         "session. Maybe create a new session with "
2618                         "self.session()")
2619      if crash_if_inconsistent_args and self._cached_config is not config:
2620        raise ValueError("The config used to get the cached session is "
2621                         "different than the one that was used to create the "
2622                         "session. Maybe create a new session with "
2623                         "self.session()")
2624      if crash_if_inconsistent_args and (self._cached_force_gpu is
2625                                         not force_gpu):
2626        raise ValueError(
2627            "The force_gpu value used to get the cached session is "
2628            "different than the one that was used to create the "
2629            "session. Maybe create a new session with "
2630            "self.session()")
2631      return self._cached_session
2632
2633
@tf_export("test.create_local_cluster")
def create_local_cluster(num_workers,
                         num_ps,
                         protocol="grpc",
                         worker_config=None,
                         ps_config=None):
  """Create and start local servers and return the associated `Server` objects.

  Every server listens on `localhost`, on a port chosen by `portpicker`.
  Worker ports are picked first, then PS ports. All servers share one
  `ClusterSpec` with a "worker" job and a "ps" job.

  Example:
  ```python
  workers, _ = tf.test.create_local_cluster(num_workers=2, num_ps=2)

  worker_sessions = [tf.Session(w.target) for w in workers]

  with tf.device("/job:ps/task:0"):
    ...
  with tf.device("/job:ps/task:1"):
    ...
  with tf.device("/job:worker/task:0"):
    ...
  with tf.device("/job:worker/task:1"):
    ...

  worker_sessions[0].run(...)
  ```

  Args:
    num_workers: Number of worker servers to start.
    num_ps: Number of PS servers to start.
    protocol: Communication protocol. Allowed values are documented in the
      documentation of `tf.train.Server`.
    worker_config: (optional) ConfigProto to initialize workers. Can be used to
      instantiate multiple devices etc.
    ps_config: (optional) ConfigProto to initialize PS servers.

  Returns:
    A tuple `(worker_servers, ps_servers)`. `worker_servers` is a list
    of `num_workers` started `tf.train.Server` objects (all running locally),
    and `ps_servers` is a list of `num_ps` objects of the same type.

  Raises:
    ImportError: if portpicker module was not found at load time
  """
  if _portpicker_import_error:
    raise _portpicker_import_error  # pylint: disable=raising-bad-type

  def _pick_ports(count):
    # One free local port per task.
    return [portpicker.pick_unused_port() for _ in range(count)]

  worker_ports = _pick_ports(num_workers)
  ps_ports = _pick_ports(num_ps)
  cluster_spec = server_lib.ClusterSpec({
      "worker": ["localhost:%s" % port for port in worker_ports],
      "ps": ["localhost:%s" % port for port in ps_ports],
  })

  def _start_job(job_name, task_count, job_config):
    # Start one in-process server per task of `job_name`.
    return [
        server_lib.Server(
            cluster_spec,
            job_name=job_name,
            protocol=protocol,
            task_index=task,
            config=job_config,
            start=True) for task in range(task_count)
    ]

  workers = _start_job("worker", num_workers, worker_config)
  ps_servers = _start_job("ps", num_ps, ps_config)
  return workers, ps_servers
2707
2708
def get_node_def_from_graph(node_name, graph_def):
  """Looks up a `NodeDef` by name among the top-level nodes of a graph def.

  Only `graph_def.node` is scanned (function-library bodies are not). When
  several nodes share the name, the first one in order is returned.

  Args:
    node_name: Name of the NodeDef to search for.
    graph_def: An instance of `GraphDef` proto.

  Returns:
    The first `NodeDef` whose `name` equals `node_name`, or None if absent.
  """
  matches = (candidate for candidate in graph_def.node
             if candidate.name == node_name)
  return next(matches, None)
2725
2726
def set_producer_version(graph, producer_version):
  """Sets graph.graph_def_versions.producer to `producer_version`.

  Args:
    graph: The graph to update. It is entered with `graph.as_default()`, and
      afterwards its `graph_def_versions.producer` is checked.
    producer_version: The producer version to record on `graph`.
  """
  # The C API doesn't expose altering GraphDefVersions. We can indirectly set
  # it via import_graph_def though.
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.producer = producer_version
  with graph.as_default():
    importer.import_graph_def(graph_def)
  # Compare the values. The former `assert producer, producer_version` only
  # checked that the producer was non-zero (producer_version was merely the
  # assertion message), so a failed update went unnoticed.
  # NOTE(review): if import_graph_def does not apply the requested version
  # for some inputs, this now fails loudly instead of silently passing.
  actual_producer = graph.graph_def_versions.producer
  assert actual_producer == producer_version, (
      "Expected graph producer version %s, got %s" %
      (producer_version, actual_producer))
2736