1# -*- coding: utf-8 -*-
2# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for py_func op."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import gc
23import re
24
25import numpy as np
26from six.moves import queue
27from six.moves import xrange  # pylint: disable=redefined-builtin
28
29from tensorflow.python.client import session as session_lib
30from tensorflow.python.eager import backprop
31from tensorflow.python.eager import context
32from tensorflow.python.eager import def_function
33from tensorflow.python.eager import function
34from tensorflow.python.framework import constant_op
35from tensorflow.python.framework import dtypes
36from tensorflow.python.framework import errors
37from tensorflow.python.framework import ops
38from tensorflow.python.framework import test_util
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import batch_ops
41from tensorflow.python.ops import gradients_impl
42from tensorflow.python.ops import math_ops
43from tensorflow.python.ops import resource_variable_ops
44from tensorflow.python.ops import script_ops
45from tensorflow.python.platform import test
46
47
def np_func(x, y):
  """Returns sinh(x) + cosh(y), computed with numpy.

  Args:
    x: scalar or array-like argument of `np.sinh`.
    y: scalar or array-like argument of `np.cosh`.

  Returns:
    The sum of the two numpy results.
  """
  hyperbolic_sine = np.sinh(x)
  hyperbolic_cosine = np.cosh(y)
  return hyperbolic_sine + hyperbolic_cosine
50
51
def matmul(x, y):
  """Returns the TensorFlow matrix product of `x` and `y` (via math_ops)."""
  product = math_ops.matmul(x, y)
  return product
54
55
class PyFuncTestBase(test.TestCase):
  """Base test case providing shared helpers for the py_func test suites."""

  def verifyExceptionHandling(self, py_exp, tf_exp, eager=False):
    """Checks that a Python exception raised in a py_func surfaces as `tf_exp`.

    The wrapped function raises `py_exp("blah")` from a nested helper. The
    resulting TF error must match `tf_exp` and its message must contain the
    error text plus a stack trace mentioning both nested functions.

    Args:
      py_exp: Python exception class raised inside the wrapped function.
      tf_exp: TensorFlow error class the failure is expected to map to.
      eager: if True, wrap with `script_ops.eager_py_func`; otherwise with
        `script_ops.py_func`.
    """

    # NOTE: the names of these two helpers are matched by the regex below.
    def inner_exception():
      raise py_exp("blah")  # pylint: disable=not-callable

    def raise_exception():
      inner_exception()

    pattern = "".join([
        r": blah.*",               # Error at the top
        r"in raise_exception.*",   # Stacktrace outer
        r"in inner_exception.*",   # Stacktrace inner
        r": blah",                 # Stacktrace of raise
    ])

    def expected_error_check(exception):
      return re.search(pattern, str(exception), re.DOTALL)

    if not eager:
      f = script_ops.py_func(raise_exception, [], [])
    elif context.executing_eagerly():
      # When executing eagerly the call itself is expected to raise, so the
      # error is checked at call time and there is nothing left to evaluate.
      with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
        script_ops.eager_py_func(raise_exception, [], [])
      return
    else:
      f = script_ops.eager_py_func(raise_exception, [], [])

    with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
      self.evaluate(f)
85
86
class PyFuncTest(PyFuncTestBase):
  """Encapsulates tests for py_func only."""

  def testRealDataTypes(self):
    """py_func adds two scalars for each listed real dtype."""
    def sum_func(x, y):
      return x + y
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
                  dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16,
                  dtypes.int32, dtypes.int64]:
      with self.cached_session():
        x = constant_op.constant(1, dtype=dtype)
        y = constant_op.constant(2, dtype=dtype)
        z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
        self.assertEqual(z, 3)

  def testComplexDataTypes(self):
    """py_func subtracts two scalars for each complex dtype."""
    def sub_func(x, y):
      return x - y
    for dtype in [dtypes.complex64, dtypes.complex128]:
      with self.cached_session():
        x = constant_op.constant(1 + 1j, dtype=dtype)
        y = constant_op.constant(2 - 2j, dtype=dtype)
        z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
        self.assertEqual(z, -1 + 3j)

  def testBoolDataTypes(self):
    """py_func accepts and returns bool tensors."""
    def and_func(x, y):
      return x and y
    dtype = dtypes.bool
    with self.cached_session():
      x = constant_op.constant(True, dtype=dtype)
      y = constant_op.constant(False, dtype=dtype)
      z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype))
      self.assertEqual(z, False)

  def testSingleType(self):
    """A single (non-list) Tout yields a single output tensor."""
    with self.cached_session():
      x = constant_op.constant(1.0, dtypes.float32)
      y = constant_op.constant(2.0, dtypes.float32)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
      self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))

  def testScalar(self):
    """Scalar inputs and a one-element Tout list round-trip via eager_py_func."""
    with self.cached_session():
      x = constant_op.constant(1.0, dtypes.float32)
      y = constant_op.constant(2.0, dtypes.float32)
      z = self.evaluate(
          script_ops.eager_py_func(np_func, [x, y], [dtypes.float32]))
      self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))

  @test_util.run_v1_only("b/120545219")
  def testArray(self):
    """Array inputs produce the element-wise numpy result."""
    with self.cached_session():
      x = constant_op.constant([1.0, 2.0], dtypes.float64)
      y = constant_op.constant([2.0, 3.0], dtypes.float64)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
      self.assertAllEqual(z[0],
                          np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))

  def testComplexType(self):
    """np_func evaluates correctly on complex64 inputs."""
    with self.cached_session():
      x = constant_op.constant(1 + 2j, dtypes.complex64)
      y = constant_op.constant(3 + 4j, dtypes.complex64)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
      self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))

  def testRFFT(self):
    """A complex64 array produced by np.fft.rfft is returned intact."""
    with self.cached_session():
      x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)

      def rfft(x):
        return np.fft.rfft(x).astype(np.complex64)

      y = self.evaluate(script_ops.py_func(rfft, [x], dtypes.complex64))
      self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))

  def testPythonLiteral(self):
    """A plain Python float return value is converted to the declared dtype."""
    with self.cached_session():

      def literal(x):
        return 1.0 if float(x) == 0.0 else 0.0

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(script_ops.py_func(literal, [x], dtypes.float64))
      self.assertAllClose(y, 1.0)

  def testList(self):
    """A returned list maps onto multiple outputs."""
    with self.cached_session():

      def list_func(x):
        return [x, x + 1]

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
      self.assertAllClose(y, [0.0, 1.0])

  def testTuple(self):
    """Tuple return values, and tuple `inp`/`Tout` arguments, are accepted."""
    # returns a tuple
    with self.cached_session():

      def tuple_func(x):
        return x, x + 1

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2))
      self.assertAllClose(y, [0.0, 1.0])

    # returns a tuple, Tout and inp a tuple
    with self.cached_session():
      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(tuple_func, (x,),
                             (dtypes.float64, dtypes.float64)))
      self.assertAllClose(y, [0.0, 1.0])

  @test_util.run_v1_only("b/120545219")
  def testStrings(self):
    """Fixed-length numpy byte strings can be returned and concatenated."""

    def read_fixed_length_numpy_strings():
      return np.array([b" there"])

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant([b"hello", b"hi"], dtypes.string)
      y = self.evaluate(
          script_ops.py_func(read_fixed_length_numpy_strings, [],
                             dtypes.string))
      z = self.evaluate(
          script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
      self.assertAllEqual(z, [b"hello there", b"hi there"])

  @test_util.run_v1_only("b/120545219")
  def testStringsAreConvertedToBytes(self):
    """numpy unicode (str) arrays are returned as bytes."""

    def read_fixed_length_numpy_strings():
      return np.array([" there"])

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant(["hello", "hi"], dtypes.string)
      y = self.evaluate(
          script_ops.py_func(read_fixed_length_numpy_strings, [],
                             dtypes.string))
      z = self.evaluate(
          script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
      self.assertAllEqual(z, [b"hello there", b"hi there"])

  @test_util.run_v1_only("b/120545219")
  def testObjectArraysAreConvertedToBytes(self):
    """Object arrays mixing bytes and str elements are returned as bytes."""

    def read_object_array():
      # Builtin `object`, not the `np.object` alias (deprecated in NumPy 1.20
      # and removed in 1.24, where it raises AttributeError).
      return np.array([b" there", u" ya"], dtype=object)

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant(["hello", "hi"], dtypes.string)
      y, = script_ops.py_func(read_object_array, [],
                              [dtypes.string])
      z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
      self.assertListEqual(list(self.evaluate(z)), [b"hello there", b"hi ya"])

  @test_util.run_v1_only("b/120545219")
  def testStringPadding(self):
    """Byte strings of different lengths are returned unchanged."""
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [correct], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testStringPaddingAreConvertedToBytes(self):
    """str values of different lengths are returned as the matching bytes."""
    inp = ["this", "is", "a", "test"]
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testNulTerminatedStrings(self):
    """Trailing NULs in a numpy str_ array do not appear in the output."""
    inp = np.array(["this\0", "is\0\0", "a\0", "test\0\0"], dtype=np.str_)
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testLarge(self):
    """Two py_funcs over a 1,000,000-element tensor can be run 100 times."""
    with self.cached_session() as sess:
      x = array_ops.zeros([1000000], dtype=np.float32)
      y = script_ops.py_func(lambda x: x + 1, [x], [dtypes.float32])
      z = script_ops.py_func(lambda x: x * 2, [x], [dtypes.float32])
      for _ in xrange(100):
        sess.run([y[0].op, z[0].op])

  def testNoInput(self):
    """A py_func with no inputs returns the function's value."""
    with self.cached_session():
      x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
      self.assertAllClose(x, 42.0)

  @test_util.run_v1_only("b/120545219")
  def testAlias(self):
    """Consuming py_func output does not modify the returned numpy array."""
    with self.cached_session():
      np_array = np.array([1.0, 2.0], dtype=np.float32)
      tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32])
      value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32)
      value.op.run()
      self.assertAllEqual(np_array, [1.0, 2.0])

  @test_util.run_v1_only("b/120545219")
  def testReturnUnicodeString(self):
    """A unicode str return value is emitted as its UTF-8 encoding."""
    with self.cached_session():
      correct = u"你好 世界"

      def unicode_string():
        return correct

      z, = script_ops.py_func(unicode_string, [], [dtypes.string])
      self.assertEqual(self.evaluate(z), correct.encode("utf8"))

  @test_util.run_v1_only("b/120545219")
  def testBadNumpyReturnType(self):
    """A structured numpy array return value raises InternalError."""
    with self.cached_session():

      def bad():
        # Structured numpy arrays aren't supported.
        return np.array([], dtype=[("foo", np.float32)])

      y, = script_ops.py_func(bad, [], [dtypes.float32])

      with self.assertRaisesRegex(errors.InternalError,
                                  "Unsupported numpy data type"):
        self.evaluate(y)

  @test_util.run_v1_only("b/120545219")
  def testBadReturnType(self):
    """A dict return value raises InternalError."""
    with self.cached_session():

      def bad():
        # Non-string python objects aren't supported.
        return {"foo": dtypes.float32}

      z, = script_ops.py_func(bad, [], [dtypes.int64])

      with self.assertRaisesRegex(errors.InternalError,
                                  "Unsupported object type"):
        self.evaluate(z)

  @test_util.run_v1_only("b/120545219")
  def testReturnInput(self):
    """Returning an element taken directly from an input array works."""
    with self.cached_session():

      def ident(x):
        return x[0]

      p = array_ops.placeholder(dtypes.float32)

      # Create a numpy array aliasing a tensor and a tensor aliasing this array
      z, = script_ops.py_func(ident, [p], [dtypes.float32])
      z += 0.0  # Makes sure we release the tensor aliasing the numpy array x[0]
      # above instead of using its memory as the return value of
      # session.run
      self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))

  def testStateful(self):
    """A stateful py_func is re-executed on each evaluation."""
    # Not using self.cached_session(), which disables optimization.
    with session_lib.Session():
      producer = iter(range(3))
      x, = script_ops.py_func(lambda: next(producer), [], [dtypes.int64])
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 1)
      self.assertEqual(self.evaluate(x), 2)

  @test_util.enable_tf_xla_constant_folding("b/134376434")
  def testStateless(self):
    """With stateful=False, repeated evaluations return the first value."""
    # Not using self.cached_session(), which disables optimization.
    with session_lib.Session():
      producer = iter(range(3))
      x, = script_ops.py_func(
          lambda: next(producer), [], [dtypes.int64], stateful=False)
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 0)

  @test_util.run_v1_only("b/120545219")
  def testGradientFunction(self):
    """py_func ops, stateful or not, have no gradient function."""
    # Input to tf.compat.v1.py_func is necessary,
    # otherwise get_gradient_function() returns None per default.
    a = constant_op.constant(0)
    x, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64])
    y, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64], stateful=False)
    self.assertEqual(None, ops.get_gradient_function(x.op))
    self.assertEqual(None, ops.get_gradient_function(y.op))

  @test_util.run_v1_only("b/120545219")
  def testCOrder(self):
    """A Fortran-ordered array is returned with the correct element order."""
    with self.cached_session():
      val = [[1, 2], [3, 4]]
      x, = script_ops.py_func(lambda: np.array(val, order="F"), [],
                              [dtypes.int64])
      self.assertAllEqual(val, self.evaluate(x))

  @test_util.run_v1_only("b/120545219")
  def testParallel(self):
    """Two py_funcs that block on each other run concurrently."""
    # Tests that tf.compat.v1.py_func's can run in parallel if they release
    # the GIL.
    with self.cached_session() as session:
      q = queue.Queue(1)

      def blocking_put():
        q.put(42)
        q.join()  # Wait for task_done().
        return 42

      def blocking_get():
        v = q.get(block=True)  # Wait for put().
        q.task_done()
        return v

      x, = script_ops.py_func(blocking_put, [], [dtypes.int64])
      y, = script_ops.py_func(blocking_get, [], [dtypes.int64])

      # This will result in a deadlock if the py_func's don't run in parallel.
      session.run([x, y])

  def testNoReturnValueStateful(self):
    """An output-less stateful py_func runs for its side effect only."""

    class State(object):

      def __init__(self):
        self._value = np.array([1], np.int64)

      def _increment(self, diff):
        self._value += diff

      def increment(self, diff):
        return script_ops.py_func(self._increment, [diff], [], stateful=True)

      @property
      def value(self):
        return self._value

    with self.cached_session():
      s = State()
      op = s.increment(constant_op.constant(2, dtypes.int64))
      ret = self.evaluate(op)
      self.assertIsNone(ret)
      self.assertAllEqual([3], s.value)

  @test_util.run_v1_only("b/120545219")
  def testNoReturnValueStateless(self):
    """An output-less stateless py_func evaluates to an empty list."""

    def do_nothing(unused_x):
      pass

    f = script_ops.py_func(
        do_nothing, [constant_op.constant(3, dtypes.int64)], [], stateful=False)
    with self.cached_session():
      self.assertEqual(self.evaluate(f), [])

  @test_util.run_v1_only("b/120545219")
  def testExceptionHandling(self):
    """Python exception types map to the expected TF error classes."""
    with self.cached_session():
      self.verifyExceptionHandling(ValueError, errors.InvalidArgumentError)
      self.verifyExceptionHandling(TypeError, errors.InvalidArgumentError)
      self.verifyExceptionHandling(StopIteration, errors.OutOfRangeError)
      self.verifyExceptionHandling(MemoryError, errors.ResourceExhaustedError)
      self.verifyExceptionHandling(NotImplementedError,
                                   errors.UnimplementedError)

      class WeirdError(Exception):
        pass

      self.verifyExceptionHandling(WeirdError, errors.UnknownError)

  def testFunctionReferencesAreKept(self):
    """A py_func built inside a batch_function keeps its callable alive."""
    g = ops.Graph()
    with g.as_default():
      c = constant_op.constant([1.], dtypes.float32)
      @batch_ops.batch_function(1, 10, 100000)
      def fn(x):
        # Upon exiting this function, the py_func holds the sole reference
        # to this lambda, without which it would be garbage collected.
        return script_ops.py_func(lambda x: x, [x], [dtypes.float32])
      result = fn(c)
      gc.collect()
      self.evaluate(result)
480
481
class PyFuncAndEagerPyFuncTest(PyFuncTestBase):
  """Encapsulates tests shared between py_func and eager_py_func."""

  def verifyPyFuncsNoIncrease(self, make_graph):
    """Asserts that 1000 calls to `make_graph` leave no py_func entries behind.

    The size of `script_ops._py_funcs` is measured before and after the
    calls. Each measurement follows a reset of the default graph and a
    garbage-collection pass.

    Args:
      make_graph: zero-argument callable that builds ops wrapping Python
        functions.
    """

    def registry_size_after_cleanup():
      ops.reset_default_graph()
      gc.collect()
      return script_ops._py_funcs.size()

    size_before = registry_size_after_cleanup()
    for _ in xrange(1000):
      make_graph()
    size_after = registry_size_after_cleanup()
    self.assertEqual(size_before, size_after)

  def testCleanup(self):
    """Functions registered while building a discarded graph are released."""
    wrappers = (script_ops.py_func, script_ops.eager_py_func)

    def build_graph():
      graph = ops.Graph()
      with graph.as_default():
        const = constant_op.constant([1.], dtypes.float32)
        for wrap in wrappers:
          wrap(lambda x: x + 1, [const], [dtypes.float32])
        # The next ops' functions close over 'const', which references the
        # graph. Checks that the functions are deleted even though the graph
        # is referenced from them (see #18292).
        for wrap in wrappers:
          wrap(lambda x: x + const.shape[0], [const], [dtypes.float32])

    self.verifyPyFuncsNoIncrease(build_graph)

  def testCleanupInTfFunction(self):
    """Like testCleanup, but the ops are created inside a tf.function."""

    self.skipTest("b/144098211")

    def build_graph():
      graph = ops.Graph()
      with graph.as_default():

        @def_function.function
        def fn():
          const = constant_op.constant([1.], dtypes.float32)
          script_ops.py_func(lambda x: x + 1, [const], [dtypes.float32])
          script_ops.eager_py_func(lambda x: x + 1, [const], [dtypes.float32])
          # The next ops' functions close over 'const', which references the
          # graph. Checks that the functions are deleted even though the
          # graph is referenced from them (see #18292).
          script_ops.py_func(
              lambda x: x + const.shape[0], [const], [dtypes.float32])
          script_ops.eager_py_func(
              lambda x: x + const.shape[0], [const], [dtypes.float32])

        fn()

    self.verifyPyFuncsNoIncrease(build_graph)
539
540
class EagerPyFuncTest(PyFuncTestBase):
  """Encapsulates tests for eager_py_func only."""

  @test_util.run_in_graph_and_eager_modes
  def testEagerSingleOutputInt32(self):
    """eager_py_func runs a TF matmul on int32 inputs with a single output."""
    a = array_ops.ones((3, 3), dtype=dtypes.int32)
    x = array_ops.ones((3, 1), dtype=dtypes.int32)
    output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
    ret = self.evaluate(output)
    self.assertAllEqual(ret, [[3], [3], [3]])

  @test_util.run_in_graph_and_eager_modes
  def testRenamedDeviceInTestClusterCorrectlyIdentifiedAsLocalhost(self):
    """eager_py_func pinned to a local cluster's worker device runs there.

    Graph mode only; the eager variant is skipped.
    """
    if context.executing_eagerly():
      self.skipTest("b/126565353: We don't test eager's remote execution.")

    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
    worker = workers[0]
    # Use the session as a context manager so it is closed when the test is
    # done; the original code never closed it.
    with session_lib.Session(worker.target) as session:
      with ops.device("/job:worker/task:0/cpu:0"):
        a = array_ops.ones((3, 3), dtype=dtypes.float32)
        x = array_ops.ones((3, 1), dtype=dtypes.float32)
        output = script_ops.eager_py_func(
            matmul, inp=[a, x], Tout=dtypes.float32)
      ret = session.run(output)
    self.assertAllClose(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerSingleOutputFloat32(self):
    """eager_py_func runs a TF matmul on float32 inputs with a single output."""
    with test_util.device(use_gpu=True):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
      ret = self.evaluate(output)
      self.assertAllClose(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerArrayOutput(self):
    """A list `Tout` yields a list of outputs."""
    with test_util.device(use_gpu=True):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(
          lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.float32])
      ret = self.evaluate(output)
      self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerReturnNone(self):
    """With Tout=[], the result is empty eagerly and None in graph mode."""
    with test_util.device(use_gpu=True):
      def no_return_value():
        return

      output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
      ret = self.evaluate(output)
      if context.executing_eagerly():
        self.assertEqual(len(ret), 0)
      else:
        self.assertIsNone(ret)

  @test_util.run_in_graph_and_eager_modes
  @test_util.disable_tfrt("b/180469928")
  def testEagerPyFuncInDefun(self):
    """eager_py_func works inside a function wrapped by function.defun."""
    with test_util.device(use_gpu=True):
      def wrapper():
        a = array_ops.ones((3, 3), dtype=dtypes.float32)
        x = array_ops.ones((3, 1), dtype=dtypes.float32)
        return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)

      wrapped = function.defun(wrapper)
      ret = self.evaluate(wrapped())
      self.assertAllEqual(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testEagerExceptionHandling(self):
    """Python exception types map to the expected TF error classes."""
    with test_util.device(use_gpu=True):
      self.verifyExceptionHandling(
          ValueError, errors.InvalidArgumentError, eager=True)
      self.verifyExceptionHandling(
          TypeError, errors.InvalidArgumentError, eager=True)
      self.verifyExceptionHandling(
          StopIteration, errors.OutOfRangeError, eager=True)
      self.verifyExceptionHandling(
          MemoryError, errors.ResourceExhaustedError, eager=True)
      self.verifyExceptionHandling(
          NotImplementedError, errors.UnimplementedError, eager=True)

      class WeirdError(Exception):
        pass

      self.verifyExceptionHandling(WeirdError, errors.UnknownError, eager=True)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testEagerReturningVariableRaisesError(self):
    """Returning a ResourceVariable from the function raises UnknownError."""
    def return_variable():
      return resource_variable_ops.ResourceVariable(0.0)

    with self.assertRaisesRegex(errors.UnknownError,
                                "Attempting to return a variable"):
      output = script_ops.eager_py_func(
          return_variable, inp=[], Tout=dtypes.float32)
      self.evaluate(output)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGradientTape(self):
    """GradientTape differentiates through eager_py_func (real and complex)."""

    def f(x):
      return x**2

    x = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
    dy_dx = tape.gradient(y, x)
    self.assertAllClose(self.evaluate(dy_dx), 6.0)

    # Test complex values
    x = constant_op.constant(3.0 + 3.0j)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.complex128)
    dy_dx = tape.gradient(y, x)
    # Gradient of complex will be the conj
    self.assertAllClose(self.evaluate(dy_dx), 6.0 - 6.0j)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraph(self):
    """gradients_impl.gradients differentiates through eager_py_func."""

    def f(x):
      return x**2

    x = constant_op.constant(3.0)
    y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
    dy_dx = gradients_impl.gradients(y, x)[0]
    self.assertEqual(self.evaluate(dy_dx), 6.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphTwoOutputs(self):
    """Gradients flow through both outputs of a two-output eager_py_func."""

    def f(x, y):
      return x * y, x / y

    x = constant_op.constant(3.0)
    y = constant_op.constant(2.0)
    fa, fb = script_ops.eager_py_func(f, inp=[x, y],
                                      Tout=[dtypes.float32, dtypes.float32])
    # d/dx (x*y + x/y) = y + 1/y = 2.5 at y = 2.
    dy_dx = gradients_impl.gradients(fa + fb, x)[0]
    self.assertEqual(self.evaluate(dy_dx), 2.5)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGradientTapeMultipleArgs(self):
    """GradientTape returns one gradient per eager_py_func input."""

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      tape.watch(y)
      z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    dz_dx, dz_dy = tape.gradient(z, [x, y])
    self.assertEqual(self.evaluate(dz_dx), 6.0)
    self.assertEqual(self.evaluate(dz_dy), 8.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphMultipleArgs(self):
    """gradients_impl.gradients returns one gradient per input."""

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    dz_dx, dz_dy = gradients_impl.gradients(z, [x, y])
    self.assertEqual(self.evaluate(dz_dx), 6.0)
    self.assertEqual(self.evaluate(dz_dy), 8.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphLogHuber(self):
    """Gradients work through value-dependent Python branching in the func."""

    def log_huber(x, m):
      if math_ops.abs(x) <= m:
        return x**2
      else:
        return m**2 * (1 - 2 * math_ops.log(m) + math_ops.log(x**2))

    x = array_ops.placeholder(dtypes.float32)
    m = array_ops.placeholder(dtypes.float32)

    y = script_ops.eager_py_func(
        func=log_huber, inp=[x, m], Tout=dtypes.float32)
    dy_dx = gradients_impl.gradients(y, x)[0]

    with self.cached_session() as sess:
      # Takes the first branch of log_huber.
      y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
      self.assertEqual(y, 1.0)
      self.assertEqual(dy_dx, 2.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerRespectsDevicePlacementOfOp(self):
    """Chained eager_py_funcs under an explicit CPU device scope compute
    square(3) then doubling, giving 18."""

    def f(x):
      return math_ops.square(x)

    def g(x):
      return math_ops.add(x, x)

    with ops.device("/CPU:0"):
      # Explicitly ask for the py_funcs to execute on CPU, even if
      # a GPU is available.
      x = array_ops.placeholder(dtypes.float32)
      y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32)
      z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32)

    with self.session() as sess:
      output = sess.run(z, feed_dict={x: 3.0})
      self.assertEqual(output, 18.0)

  @test_util.run_in_graph_and_eager_modes
  def testEagerPyFuncOnGPUWithStrings(self):
    """A string tensor reaches the function with dtype string.

    NOTE(review): despite the name, no GPU device scope is set here.
    """

    def fn(a):
      return str(a.dtype)

    x = constant_op.constant("x", dtype=dtypes.string)
    output = script_ops.eager_py_func(fn, inp=[x], Tout=dtypes.string)
    self.assertEqual(self.evaluate(output), "<dtype: 'string'>".encode("utf8"))

  @test_util.run_in_graph_and_eager_modes
  def testEagerPyFuncNotACallable(self):
    """Passing a non-callable as the function raises ValueError."""
    x = constant_op.constant("x", dtype=dtypes.string)

    with self.assertRaisesRegex(ValueError, "callable"):
      _ = script_ops.eager_py_func(x, inp=[x], Tout=dtypes.string)
779
780
if __name__ == "__main__":
  # Script entry point: run this module's test cases via TensorFlow's test
  # runner (tensorflow.python.platform.test).
  test.main()
783