1# -*- coding: utf-8 -*-
2# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Tests for py_func op."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import gc
23import re
24
25import numpy as np
26from six.moves import queue
27from six.moves import xrange  # pylint: disable=redefined-builtin
28
29from tensorflow.python.client import session as session_lib
30from tensorflow.python.eager import backprop
31from tensorflow.python.eager import context
32from tensorflow.python.eager import function
33from tensorflow.python.framework import constant_op
34from tensorflow.python.framework import dtypes
35from tensorflow.python.framework import errors
36from tensorflow.python.framework import ops
37from tensorflow.python.framework import test_util
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import gradients_impl
40from tensorflow.python.ops import math_ops
41from tensorflow.python.ops import resource_variable_ops
42from tensorflow.python.ops import script_ops
43from tensorflow.python.platform import test
44
45
def np_func(x, y):
  """Reference numpy computation shared by several py_func tests.

  Args:
    x: Scalar, list or array; the hyperbolic sine is taken elementwise.
    y: Scalar, list or array; the hyperbolic cosine is taken elementwise.

  Returns:
    sinh(x) + cosh(y), combined with numpy's usual broadcasting rules.
  """
  sinh_x = np.sinh(x)
  cosh_y = np.cosh(y)
  return sinh_x + cosh_y
48
49
def matmul(x, y):
  """Thin wrapper around `math_ops.matmul`, used as an eager_py_func body.

  Args:
    x: Left-hand matrix operand.
    y: Right-hand matrix operand.

  Returns:
    The matrix product computed by `math_ops.matmul(x, y)`.
  """
  product = math_ops.matmul(x, y)
  return product
52
53
class PyFuncTest(test.TestCase):
  """Encapsulates tests for py_func and eager_py_func."""

  # ----- Tests for py_func -----
  def testRealDataTypes(self):
    """py_func computes x + y for every real numeric dtype."""
    def sum_func(x, y):
      return x + y
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
                  dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16,
                  dtypes.int32, dtypes.int64]:
      with self.cached_session():
        x = constant_op.constant(1, dtype=dtype)
        y = constant_op.constant(2, dtype=dtype)
        z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
        self.assertEqual(z, 3)

  def testComplexDataTypes(self):
    """py_func computes x - y for complex64 and complex128."""
    def sub_func(x, y):
      return x - y
    for dtype in [dtypes.complex64, dtypes.complex128]:
      with self.cached_session():
        x = constant_op.constant(1 + 1j, dtype=dtype)
        y = constant_op.constant(2 - 2j, dtype=dtype)
        z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
        self.assertEqual(z, -1 + 3j)

  def testBoolDataTypes(self):
    """py_func accepts and returns bool tensors."""
    def and_func(x, y):
      return x and y
    dtype = dtypes.bool
    with self.cached_session():
      x = constant_op.constant(True, dtype=dtype)
      y = constant_op.constant(False, dtype=dtype)
      z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype))
      self.assertEqual(z, False)

  def testSingleType(self):
    """A non-list Tout makes py_func return a single output."""
    with self.cached_session():
      x = constant_op.constant(1.0, dtypes.float32)
      y = constant_op.constant(2.0, dtypes.float32)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
      self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))

  def testScalar(self):
    """eager_py_func on scalar inputs with a one-element Tout list."""
    with self.cached_session():
      x = constant_op.constant(1.0, dtypes.float32)
      y = constant_op.constant(2.0, dtypes.float32)
      z = self.evaluate(
          script_ops.eager_py_func(np_func, [x, y], [dtypes.float32]))
      self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))

  @test_util.run_v1_only("b/120545219")
  def testArray(self):
    """py_func applied elementwise to 1-D float64 tensors."""
    with self.cached_session():
      x = constant_op.constant([1.0, 2.0], dtypes.float64)
      y = constant_op.constant([2.0, 3.0], dtypes.float64)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
      self.assertAllEqual(z[0],
                          np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))

  def testComplexType(self):
    """py_func applied to complex64 scalars."""
    with self.cached_session():
      x = constant_op.constant(1 + 2j, dtypes.complex64)
      y = constant_op.constant(3 + 4j, dtypes.complex64)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
      self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))

  def testRFFT(self):
    """py_func wrapping numpy's rfft returns complex64 results."""
    with self.cached_session():
      x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)

      def rfft(x):
        return np.fft.rfft(x).astype(np.complex64)

      y = self.evaluate(script_ops.py_func(rfft, [x], dtypes.complex64))
      self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))

  def testPythonLiteral(self):
    """A py_func may return a plain Python float."""
    with self.cached_session():

      def literal(x):
        return 1.0 if float(x) == 0.0 else 0.0

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(script_ops.py_func(literal, [x], dtypes.float64))
      self.assertAllClose(y, 1.0)

  def testList(self):
    """A py_func may return a Python list of outputs."""
    with self.cached_session():

      def list_func(x):
        return [x, x + 1]

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
      self.assertAllClose(y, [0.0, 1.0])

  def testTuple(self):
    """A py_func may return a tuple, and inp/Tout may be given as tuples."""
    # returns a tuple
    with self.cached_session():

      def tuple_func(x):
        return x, x + 1

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2))
      self.assertAllClose(y, [0.0, 1.0])

    # returns a tuple, Tout and inp a tuple
    with self.cached_session():
      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(tuple_func, (x,),
                             (dtypes.float64, dtypes.float64)))
      self.assertAllClose(y, [0.0, 1.0])

  @test_util.run_v1_only("b/120545219")
  def testStrings(self):
    """Fixed-length numpy byte strings flow into and out of py_func."""

    def read_fixed_length_numpy_strings():
      return np.array([b" there"])

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant([b"hello", b"hi"], dtypes.string)
      y = self.evaluate(
          script_ops.py_func(read_fixed_length_numpy_strings, [],
                             dtypes.string))
      z = self.evaluate(
          script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
      self.assertAllEqual(z, [b"hello there", b"hi there"])

  @test_util.run_v1_only("b/120545219")
  def testStringsAreConvertedToBytes(self):
    """Numpy unicode string results come back as bytes."""

    def read_fixed_length_numpy_strings():
      return np.array([" there"])

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant(["hello", "hi"], dtypes.string)
      y = self.evaluate(
          script_ops.py_func(read_fixed_length_numpy_strings, [],
                             dtypes.string))
      z = self.evaluate(
          script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
      self.assertAllEqual(z, [b"hello there", b"hi there"])

  @test_util.run_v1_only("b/120545219")
  def testObjectArraysAreConvertedToBytes(self):
    """Object arrays mixing bytes and unicode come back as bytes."""

    def read_object_array():
      # Use the builtin `object`: the `np.object` alias was deprecated in
      # NumPy 1.20 and removed in 1.24; both name the same dtype.
      return np.array([b" there", u" ya"], dtype=object)

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant(["hello", "hi"], dtypes.string)
      y, = script_ops.py_func(read_object_array, [],
                              [dtypes.string])
      z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
      self.assertListEqual(list(z.eval()), [b"hello there", b"hi ya"])

  @test_util.run_v1_only("b/120545219")
  def testStringPadding(self):
    """A list of variable-length byte strings is returned unchanged."""
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [correct], [], [dtypes.string])
      self.assertAllEqual(s.eval(), correct)

  @test_util.run_v1_only("b/120545219")
  def testStringPaddingAreConvertedToBytes(self):
    """A list of variable-length unicode strings is returned as bytes."""
    inp = ["this", "is", "a", "test"]
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
      self.assertAllEqual(s.eval(), correct)

  @test_util.run_v1_only("b/120545219")
  def testLarge(self):
    """Repeatedly runs two py_funcs over a 1,000,000-element tensor."""
    with self.cached_session() as sess:
      x = array_ops.zeros([1000000], dtype=np.float32)
      y = script_ops.py_func(lambda x: x + 1, [x], [dtypes.float32])
      z = script_ops.py_func(lambda x: x * 2, [x], [dtypes.float32])
      for _ in xrange(100):
        sess.run([y[0].op, z[0].op])

  def testNoInput(self):
    """A py_func with no inputs can still produce a value."""
    with self.cached_session():
      x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
      self.assertAllClose(x, 42.0)

  @test_util.run_v1_only("b/120545219")
  def testAlias(self):
    """Consuming a py_func output must not mutate the numpy array returned."""
    with self.cached_session():
      np_array = np.array([1.0, 2.0], dtype=np.float32)
      tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32])
      value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32)
      value.op.run()
      self.assertAllEqual(np_array, [1.0, 2.0])

  @test_util.run_v1_only("b/120545219")
  def testReturnUnicodeString(self):
    """A unicode str result is encoded as UTF-8 bytes."""
    with self.cached_session():
      correct = u"你好 世界"

      def unicode_string():
        return correct

      z, = script_ops.py_func(unicode_string, [], [dtypes.string])
      self.assertEqual(z.eval(), correct.encode("utf8"))

  @test_util.run_v1_only("b/120545219")
  def testBadNumpyReturnType(self):
    """Returning a structured numpy array raises UnimplementedError."""
    with self.cached_session():

      def bad():
        # Structured numpy arrays aren't supported.
        return np.array([], dtype=[("foo", np.float32)])

      y, = script_ops.py_func(bad, [], [dtypes.float32])

      with self.assertRaisesRegexp(errors.UnimplementedError,
                                   "Unsupported numpy type"):
        self.evaluate(y)

  @test_util.run_v1_only("b/120545219")
  def testBadReturnType(self):
    """Returning a non-string Python object raises UnimplementedError."""
    with self.cached_session():

      def bad():
        # Non-string python objects aren't supported.
        return {"foo": dtypes.float32}

      z, = script_ops.py_func(bad, [], [dtypes.int64])

      with self.assertRaisesRegexp(errors.UnimplementedError,
                                   "Unsupported object type"):
        self.evaluate(z)

  @test_util.run_v1_only("b/120545219")
  def testReturnInput(self):
    """A py_func may return data taken directly from its input."""
    with self.cached_session():

      def ident(x):
        return x[0]

      p = array_ops.placeholder(dtypes.float32)

      # Create a numpy array aliasing a tensor and a tensor aliasing this array
      z, = script_ops.py_func(ident, [p], [dtypes.float32])
      z += 0.0  # Makes sure we release the tensor aliasing the numpy array x[0]
      # above instead of using its memory as the return value of
      # session.run
      self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))

  def testStateful(self):
    """A stateful py_func is re-executed on every fetch."""
    # Not using self.cached_session(), which disables optimization.
    # Entering the Session makes it the default, so self.evaluate() uses it.
    with session_lib.Session():
      producer = iter(range(3))
      x, = script_ops.py_func(lambda: next(producer), [], [dtypes.int64])
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 1)
      self.assertEqual(self.evaluate(x), 2)

  def testStateless(self):
    """With stateful=False, repeated fetches here all yield the first value."""
    # Not using self.cached_session(), which disables optimization.
    # Entering the Session makes it the default, so self.evaluate() uses it.
    with session_lib.Session():
      producer = iter(range(3))
      x, = script_ops.py_func(
          lambda: next(producer), [], [dtypes.int64], stateful=False)
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 0)

  @test_util.run_v1_only("b/120545219")
  def testGradientFunction(self):
    """py_func ops, stateful or not, have no gradient function."""
    # Input to tf.py_func is necessary, otherwise get_gradient_function()
    # returns None per default.
    a = constant_op.constant(0)
    x, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64])
    y, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64], stateful=False)
    self.assertEqual(None, ops.get_gradient_function(x.op))
    self.assertEqual(None, ops.get_gradient_function(y.op))

  @test_util.run_v1_only("b/120545219")
  def testCOrder(self):
    """A Fortran-ordered numpy result is returned with the correct values."""
    with self.cached_session():
      val = [[1, 2], [3, 4]]
      x, = script_ops.py_func(lambda: np.array(val, order="F"), [],
                              [dtypes.int64])
      self.assertAllEqual(val, self.evaluate(x))

  @test_util.run_v1_only("b/120545219")
  def testParallel(self):
    """Two py_funcs that block on each other can run concurrently."""
    # Tests that tf.py_func's can run in parallel if they release the GIL.
    with self.cached_session() as session:
      q = queue.Queue(1)

      def blocking_put():
        q.put(42)
        q.join()  # Wait for task_done().
        return 42

      def blocking_get():
        v = q.get(block=True)  # Wait for put().
        q.task_done()
        return v

      x, = script_ops.py_func(blocking_put, [], [dtypes.int64])
      y, = script_ops.py_func(blocking_get, [], [dtypes.int64])

      # This will result in a deadlock if the py_func's don't run in parallel.
      session.run([x, y])

  def testNoReturnValueStateful(self):
    """A stateful py_func with empty Tout runs for its side effect only."""

    class State(object):

      def __init__(self):
        self._value = np.array([1], np.int64)

      def _increment(self, diff):
        self._value += diff

      def increment(self, diff):
        return script_ops.py_func(self._increment, [diff], [], stateful=True)

      @property
      def value(self):
        return self._value

    with self.cached_session():
      s = State()
      op = s.increment(constant_op.constant(2, dtypes.int64))
      ret = self.evaluate(op)
      self.assertIsNone(ret)
      self.assertAllEqual([3], s.value)

  @test_util.run_v1_only("b/120545219")
  def testNoReturnValueStateless(self):
    """A stateless py_func with empty Tout evaluates to an empty list."""

    def do_nothing(unused_x):
      pass

    f = script_ops.py_func(
        do_nothing, [constant_op.constant(3, dtypes.int64)], [], stateful=False)
    with self.cached_session():
      self.assertEqual(self.evaluate(f), [])

  def _testExceptionHandling(self, py_exp, tf_exp, eager=False):
    """Checks that a Python exception raised inside a py_func maps to tf_exp.

    The raised TensorFlow error must also carry the original message and the
    Python stack frames of the function that raised it.

    Args:
      py_exp: Python exception class raised from inside the wrapped function.
      tf_exp: TensorFlow error class the op is expected to raise.
      eager: If True, wrap the function with eager_py_func; otherwise use
        py_func.
    """

    def inner_exception():
      raise py_exp("blah")  # pylint: disable=not-callable

    def raise_exception():
      inner_exception()

    expected_regexp = r": blah.*"               # Error at the top
    expected_regexp += r"in raise_exception.*"  # Stacktrace outer
    expected_regexp += r"in inner_exception.*"  # Stacktrace inner
    expected_regexp += r": blah"                # Stacktrace of raise
    def expected_error_check(exception):
      return re.search(expected_regexp, str(exception), re.DOTALL)

    if eager:
      if context.executing_eagerly():
        # When executing eagerly the function runs as soon as the op is
        # created, so the error surfaces at construction time.
        with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
          script_ops.eager_py_func(raise_exception, [], [])
        return
      f = script_ops.eager_py_func(raise_exception, [], [])
    else:
      f = script_ops.py_func(raise_exception, [], [])

    with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
      self.evaluate(f)

  @test_util.run_v1_only("b/120545219")
  def testExceptionHandling(self):
    """Python exceptions in py_func map to the expected TF error types."""
    with self.cached_session():
      self._testExceptionHandling(ValueError, errors.InvalidArgumentError)
      self._testExceptionHandling(TypeError, errors.InvalidArgumentError)
      self._testExceptionHandling(StopIteration, errors.OutOfRangeError)
      self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError)
      self._testExceptionHandling(NotImplementedError,
                                  errors.UnimplementedError)

      class WeirdError(Exception):
        pass

      self._testExceptionHandling(WeirdError, errors.UnknownError)

  # ----- Tests shared by py_func and eager_py_func -----
  def testCleanup(self):
    """Registered py_funcs are released once their graphs are deleted."""
    # Delete everything created by previous tests to avoid side effects.
    ops.reset_default_graph()
    gc.collect()
    initial_size = script_ops._py_funcs.size()

    # Encapsulate the graph generation, so locals can be deleted.
    def make_graphs():
      for _ in xrange(1000):
        g = ops.Graph()
        with g.as_default():
          c = constant_op.constant([1.], dtypes.float32)
          _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
          _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
          # These ops have a reference to 'c' which has a reference to the
          # graph. Checks if the functions are being deleted though the graph
          # is referenced from them (see #18292).
          _ = script_ops.py_func(
              lambda x: x + c.shape[0], [c], [dtypes.float32])
          _ = script_ops.eager_py_func(
              lambda x: x + c.shape[0], [c], [dtypes.float32])

    make_graphs()
    # Drop the default graph and call the garbage collector to enforce
    # deletion of the now-unreferenced graphs and their py_funcs.
    ops.reset_default_graph()
    gc.collect()
    self.assertEqual(initial_size, script_ops._py_funcs.size())

  # ----- Tests for eager_py_func -----
  @test_util.run_in_graph_and_eager_modes
  def testEagerSingleOutputInt32(self):
    """eager_py_func computing an int32 matmul with a single Tout."""
    a = array_ops.ones((3, 3), dtype=dtypes.int32)
    x = array_ops.ones((3, 1), dtype=dtypes.int32)
    output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
    ret = self.evaluate(output)
    self.assertAllEqual(ret, [[3], [3], [3]])

  @test_util.run_in_graph_and_eager_modes
  def testRenamedDeviceInTestClusterCorrectlyIdentifiedAsLocalhost(self):
    """eager_py_func pinned to a local-cluster worker device runs correctly."""
    if context.executing_eagerly():
      self.skipTest("b/126565353: We don't test eager's remote execution.")

    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
    worker = workers[0]
    with ops.device("/job:worker/task:0/cpu:0"):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
    # Manage the session with `with` so it is closed instead of leaked.
    with session_lib.Session(worker.target) as session:
      ret = session.run(output)
    self.assertAllClose(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerSingleOutputFloat32(self):
    """eager_py_func computing a float32 matmul, with GPU placement allowed."""
    with test_util.device(use_gpu=True):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
      ret = self.evaluate(output)
      self.assertAllClose(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerArrayOutput(self):
    """eager_py_func with a list Tout returns a list of outputs."""
    with test_util.device(use_gpu=True):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(
          lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.float32])
      ret = self.evaluate(output)
      self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerReturnNone(self):
    """eager_py_func whose function returns None and whose Tout is empty."""
    with test_util.device(use_gpu=True):
      def no_return_value():
        return

      output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
      ret = self.evaluate(output)
      if context.executing_eagerly():
        # assertEqual rather than the deprecated assertEquals alias, which
        # was removed in Python 3.12.
        self.assertEqual(len(ret), 0)
      else:
        self.assertIsNone(ret)

  @test_util.run_in_graph_and_eager_modes
  @test_util.disable_xla("XLA cannot compile functions containing py_func")
  def testEagerPyFuncInDefun(self):
    """eager_py_func can be called inside a defun-wrapped function."""
    with test_util.device(use_gpu=True):
      def wrapper():
        a = array_ops.ones((3, 3), dtype=dtypes.float32)
        x = array_ops.ones((3, 1), dtype=dtypes.float32)
        return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)

      wrapped = function.defun(wrapper)
      ret = self.evaluate(wrapped())
      self.assertAllEqual(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testEagerExceptionHandling(self):
    """Python exceptions in eager_py_func map to the expected TF errors."""
    with test_util.device(use_gpu=True):
      self._testExceptionHandling(
          ValueError, errors.InvalidArgumentError, eager=True)
      self._testExceptionHandling(
          TypeError, errors.InvalidArgumentError, eager=True)
      self._testExceptionHandling(
          StopIteration, errors.OutOfRangeError, eager=True)
      self._testExceptionHandling(
          MemoryError, errors.ResourceExhaustedError, eager=True)
      self._testExceptionHandling(
          NotImplementedError, errors.UnimplementedError, eager=True)

      class WeirdError(Exception):
        pass

      self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testEagerReturningVariableRaisesError(self):
    """Returning a ResourceVariable from eager_py_func raises UnknownError."""
    def return_variable():
      return resource_variable_ops.ResourceVariable(0.0)

    with self.assertRaisesRegexp(errors.UnknownError,
                                 "Attempting to return a variable"):
      output = script_ops.eager_py_func(
          return_variable, inp=[], Tout=dtypes.float32)
      self.evaluate(output)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGradientTape(self):
    """GradientTape differentiates through eager_py_func."""

    def f(x):
      return x**2

    x = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
    dy_dx = tape.gradient(y, x)
    self.assertEqual(self.evaluate(dy_dx), 6.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraph(self):
    """Graph-mode gradients flow through eager_py_func."""

    def f(x):
      return x**2

    x = constant_op.constant(3.0)
    y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
    dy_dx = gradients_impl.gradients(y, x)[0]
    self.assertEqual(self.evaluate(dy_dx), 6.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphTwoOutputs(self):
    """Graph-mode gradients combine both outputs of an eager_py_func."""

    def f(x, y):
      return x * y, x / y

    x = constant_op.constant(3.0)
    y = constant_op.constant(2.0)
    fa, fb = script_ops.eager_py_func(f, inp=[x, y],
                                      Tout=[dtypes.float32, dtypes.float32])
    # d/dx (x * y + x / y) = y + 1 / y = 2.5 at y = 2.
    dy_dx = gradients_impl.gradients(fa + fb, x)[0]
    self.assertEqual(self.evaluate(dy_dx), 2.5)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGradientTapeMultipleArgs(self):
    """GradientTape computes gradients w.r.t. several eager_py_func inputs."""

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      tape.watch(y)
      z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    dz_dx, dz_dy = tape.gradient(z, [x, y])
    self.assertEqual(self.evaluate(dz_dx), 6.0)
    self.assertEqual(self.evaluate(dz_dy), 8.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphMultipleArgs(self):
    """Graph-mode gradients w.r.t. several eager_py_func inputs."""

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    dz_dx, dz_dy = gradients_impl.gradients(z, [x, y])
    self.assertEqual(self.evaluate(dz_dx), 6.0)
    self.assertEqual(self.evaluate(dz_dy), 8.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphLogHuber(self):
    """Graph gradients through an eager_py_func with Python control flow."""

    def log_huber(x, m):
      if math_ops.abs(x) <= m:
        return x**2
      else:
        return m**2 * (1 - 2 * math_ops.log(m) + math_ops.log(x**2))

    x = array_ops.placeholder(dtypes.float32)
    m = array_ops.placeholder(dtypes.float32)

    y = script_ops.eager_py_func(
        func=log_huber, inp=[x, m], Tout=dtypes.float32)
    dy_dx = gradients_impl.gradients(y, x)[0]

    with self.cached_session() as sess:
      # Takes the first branch of log_huber.
      # Fetch into new names so the tensors y and dy_dx are not shadowed.
      y_val, dy_dx_val = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
      self.assertEqual(y_val, 1.0)
      self.assertEqual(dy_dx_val, 2.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerRespectsDevicePlacmentOfOp(self):
    """eager_py_func ops honor an explicit CPU device scope."""

    def f(x):
      return math_ops.square(x)

    def g(x):
      return math_ops.add(x, x)

    with ops.device("/CPU:0"):
      # Explicitly ask for the py_funcs to execute on CPU, even if
      # a GPU is available.
      x = array_ops.placeholder(dtypes.float32)
      y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32)
      z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32)

    with self.session(use_gpu=True) as sess:
      output = sess.run(z, feed_dict={x: 3.0})
      self.assertEqual(output, 18.0)
690
691
if __name__ == "__main__":
  # Script entry point: hand control to TensorFlow's test runner (test.main).
  test.main()
694