1# -*- coding: utf-8 -*- 2# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Tests for py_func op.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import gc 23import re 24 25import numpy as np 26from six.moves import queue 27from six.moves import xrange # pylint: disable=redefined-builtin 28 29from tensorflow.python.client import session as session_lib 30from tensorflow.python.eager import backprop 31from tensorflow.python.eager import context 32from tensorflow.python.eager import function 33from tensorflow.python.framework import constant_op 34from tensorflow.python.framework import dtypes 35from tensorflow.python.framework import errors 36from tensorflow.python.framework import ops 37from tensorflow.python.framework import test_util 38from tensorflow.python.ops import array_ops 39from tensorflow.python.ops import gradients_impl 40from tensorflow.python.ops import math_ops 41from tensorflow.python.ops import resource_variable_ops 42from tensorflow.python.ops import script_ops 43from tensorflow.python.platform import test 44 45 46def np_func(x, y): 47 return np.sinh(x) + np.cosh(y) 48 49 50def matmul(x, y): 51 return math_ops.matmul(x, y) 52 53 54class PyFuncTest(test.TestCase): 55 """Encapsulates tests for py_func and eager_py_func.""" 56 57 # ----- Tests for py_func ----- 58 def testRealDataTypes(self): 59 def sum_func(x, y): 60 return x + y 61 for dtype in [dtypes.float16, dtypes.float32, dtypes.float64, 62 dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16, 63 dtypes.int32, dtypes.int64]: 64 with self.cached_session(): 65 x = constant_op.constant(1, dtype=dtype) 66 y = constant_op.constant(2, dtype=dtype) 67 z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype)) 68 self.assertEqual(z, 3) 69 70 def testComplexDataTypes(self): 71 def sub_func(x, y): 72 return x - y 73 for dtype in [dtypes.complex64, dtypes.complex128]: 74 with self.cached_session(): 75 x = constant_op.constant(1 + 1j, dtype=dtype) 76 y = constant_op.constant(2 - 2j, dtype=dtype) 77 z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype)) 78 self.assertEqual(z, -1 + 3j) 79 80 def testBoolDataTypes(self): 81 def and_func(x, y): 82 return x and y 83 dtype = dtypes.bool 84 with self.cached_session(): 85 x = constant_op.constant(True, dtype=dtype) 86 y = constant_op.constant(False, dtype=dtype) 87 z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype)) 88 self.assertEqual(z, False) 89 90 def testSingleType(self): 91 with self.cached_session(): 92 x = constant_op.constant(1.0, dtypes.float32) 93 y = constant_op.constant(2.0, dtypes.float32) 94 z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32)) 95 self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32)) 96 97 def testScalar(self): 98 with self.cached_session(): 99 x = constant_op.constant(1.0, dtypes.float32) 100 y = constant_op.constant(2.0, dtypes.float32) 101 z = self.evaluate( 102 script_ops.eager_py_func(np_func, [x, y], [dtypes.float32])) 103 self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32)) 104 105 @test_util.run_v1_only("b/120545219") 106 def testArray(self): 107 with self.cached_session(): 108 x = constant_op.constant([1.0, 2.0], dtypes.float64) 109 y = constant_op.constant([2.0, 3.0], dtypes.float64) 110 z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64])) 111 self.assertAllEqual(z[0], 112 np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64)) 113 114 def testComplexType(self): 115 with self.cached_session(): 116 x = constant_op.constant(1 + 2j, dtypes.complex64) 117 y = constant_op.constant(3 + 4j, dtypes.complex64) 118 z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64)) 119 self.assertAllClose(z, np_func(1 + 2j, 3 + 4j)) 120 121 def testRFFT(self): 122 with self.cached_session(): 123 x = constant_op.constant([1., 2., 3., 4.], dtypes.float32) 124 125 def rfft(x): 126 return np.fft.rfft(x).astype(np.complex64) 127 128 y = self.evaluate(script_ops.py_func(rfft, [x], dtypes.complex64)) 129 self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.])) 130 131 def testPythonLiteral(self): 132 with self.cached_session(): 133 134 def literal(x): 135 return 1.0 if float(x) == 0.0 else 0.0 136 137 x = constant_op.constant(0.0, dtypes.float64) 138 y = self.evaluate(script_ops.py_func(literal, [x], dtypes.float64)) 139 self.assertAllClose(y, 1.0) 140 141 def testList(self): 142 with self.cached_session(): 143 144 def list_func(x): 145 return [x, x + 1] 146 147 x = constant_op.constant(0.0, dtypes.float64) 148 y = self.evaluate( 149 script_ops.py_func(list_func, [x], [dtypes.float64] * 2)) 150 self.assertAllClose(y, [0.0, 1.0]) 151 152 def testTuple(self): 153 # returns a tuple 154 with self.cached_session(): 155 156 def tuple_func(x): 157 return x, x + 1 158 159 x = constant_op.constant(0.0, dtypes.float64) 160 y = self.evaluate( 161 script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2)) 162 self.assertAllClose(y, [0.0, 1.0]) 163 164 # returns a tuple, Tout and inp a tuple 165 with self.cached_session(): 166 x = constant_op.constant(0.0, dtypes.float64) 167 y = self.evaluate( 168 script_ops.py_func(tuple_func, (x,), 169 (dtypes.float64, dtypes.float64))) 170 self.assertAllClose(y, [0.0, 1.0]) 171 172 @test_util.run_v1_only("b/120545219") 173 def testStrings(self): 174 175 def read_fixed_length_numpy_strings(): 176 return np.array([b" there"]) 177 178 def read_and_return_strings(x, y): 179 return x + y 180 181 with self.cached_session(): 182 x = constant_op.constant([b"hello", b"hi"], dtypes.string) 183 y = self.evaluate( 184 script_ops.py_func(read_fixed_length_numpy_strings, [], 185 dtypes.string)) 186 z = self.evaluate( 187 script_ops.py_func(read_and_return_strings, [x, y], dtypes.string)) 188 self.assertAllEqual(z, [b"hello there", b"hi there"]) 189 190 @test_util.run_v1_only("b/120545219") 191 def testStringsAreConvertedToBytes(self): 192 193 def read_fixed_length_numpy_strings(): 194 return np.array([" there"]) 195 196 def read_and_return_strings(x, y): 197 return x + y 198 199 with self.cached_session(): 200 x = constant_op.constant(["hello", "hi"], dtypes.string) 201 y = self.evaluate( 202 script_ops.py_func(read_fixed_length_numpy_strings, [], 203 dtypes.string)) 204 z = self.evaluate( 205 script_ops.py_func(read_and_return_strings, [x, y], dtypes.string)) 206 self.assertAllEqual(z, [b"hello there", b"hi there"]) 207 208 @test_util.run_v1_only("b/120545219") 209 def testObjectArraysAreConvertedToBytes(self): 210 211 def read_object_array(): 212 return np.array([b" there", u" ya"], dtype=np.object) 213 214 def read_and_return_strings(x, y): 215 return x + y 216 217 with self.cached_session(): 218 x = constant_op.constant(["hello", "hi"], dtypes.string) 219 y, = script_ops.py_func(read_object_array, [], 220 [dtypes.string]) 221 z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string]) 222 self.assertListEqual(list(z.eval()), [b"hello there", b"hi ya"]) 223 224 @test_util.run_v1_only("b/120545219") 225 def testStringPadding(self): 226 correct = [b"this", b"is", b"a", b"test"] 227 with self.cached_session(): 228 s, = script_ops.py_func(lambda: [correct], [], [dtypes.string]) 229 self.assertAllEqual(s.eval(), correct) 230 231 @test_util.run_v1_only("b/120545219") 232 def testStringPaddingAreConvertedToBytes(self): 233 inp = ["this", "is", "a", "test"] 234 correct = [b"this", b"is", b"a", b"test"] 235 with self.cached_session(): 236 s, = script_ops.py_func(lambda: [inp], [], [dtypes.string]) 237 self.assertAllEqual(s.eval(), correct) 238 239 @test_util.run_v1_only("b/120545219") 240 def testLarge(self): 241 with self.cached_session() as sess: 242 x = array_ops.zeros([1000000], dtype=np.float32) 243 y = script_ops.py_func(lambda x: x + 1, [x], [dtypes.float32]) 244 z = script_ops.py_func(lambda x: x * 2, [x], [dtypes.float32]) 245 for _ in xrange(100): 246 sess.run([y[0].op, z[0].op]) 247 248 def testNoInput(self): 249 with self.cached_session(): 250 x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64)) 251 self.assertAllClose(x, 42.0) 252 253 @test_util.run_v1_only("b/120545219") 254 def testAlias(self): 255 with self.cached_session(): 256 np_array = np.array([1.0, 2.0], dtype=np.float32) 257 tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32]) 258 value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32) 259 value.op.run() 260 self.assertAllEqual(np_array, [1.0, 2.0]) 261 262 @test_util.run_v1_only("b/120545219") 263 def testReturnUnicodeString(self): 264 with self.cached_session(): 265 correct = u"你好 世界" 266 267 def unicode_string(): 268 return correct 269 270 z, = script_ops.py_func(unicode_string, [], [dtypes.string]) 271 self.assertEqual(z.eval(), correct.encode("utf8")) 272 273 @test_util.run_v1_only("b/120545219") 274 def testBadNumpyReturnType(self): 275 with self.cached_session(): 276 277 def bad(): 278 # Structured numpy arrays aren't supported. 279 return np.array([], dtype=[("foo", np.float32)]) 280 281 y, = script_ops.py_func(bad, [], [dtypes.float32]) 282 283 with self.assertRaisesRegexp(errors.UnimplementedError, 284 "Unsupported numpy type"): 285 self.evaluate(y) 286 287 @test_util.run_v1_only("b/120545219") 288 def testBadReturnType(self): 289 with self.cached_session(): 290 291 def bad(): 292 # Non-string python objects aren't supported. 293 return {"foo": dtypes.float32} 294 295 z, = script_ops.py_func(bad, [], [dtypes.int64]) 296 297 with self.assertRaisesRegexp(errors.UnimplementedError, 298 "Unsupported object type"): 299 self.evaluate(z) 300 301 @test_util.run_v1_only("b/120545219") 302 def testReturnInput(self): 303 with self.cached_session(): 304 305 def ident(x): 306 return x[0] 307 308 p = array_ops.placeholder(dtypes.float32) 309 310 # Create a numpy array aliasing a tensor and a tensor aliasing this array 311 z, = script_ops.py_func(ident, [p], [dtypes.float32]) 312 z += 0.0 # Makes sure we release the tensor aliasing the numpy array x[0] 313 # above instead of using its memory as the return value of 314 # session.run 315 self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]})) 316 317 def testStateful(self): 318 # Not using self.cached_session(), which disables optimization. 319 with session_lib.Session() as sess: 320 producer = iter(range(3)) 321 x, = script_ops.py_func(lambda: next(producer), [], [dtypes.int64]) 322 self.assertEqual(self.evaluate(x), 0) 323 self.assertEqual(self.evaluate(x), 1) 324 self.assertEqual(self.evaluate(x), 2) 325 326 def testStateless(self): 327 # Not using self.cached_session(), which disables optimization. 328 with session_lib.Session() as sess: 329 producer = iter(range(3)) 330 x, = script_ops.py_func( 331 lambda: next(producer), [], [dtypes.int64], stateful=False) 332 self.assertEqual(self.evaluate(x), 0) 333 self.assertEqual(self.evaluate(x), 0) 334 self.assertEqual(self.evaluate(x), 0) 335 336 @test_util.run_v1_only("b/120545219") 337 def testGradientFunction(self): 338 # Input to tf.py_func is necessary, otherwise get_gradient_function() 339 # returns None per default. 340 a = constant_op.constant(0) 341 x, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64]) 342 y, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64], stateful=False) 343 self.assertEqual(None, ops.get_gradient_function(x.op)) 344 self.assertEqual(None, ops.get_gradient_function(y.op)) 345 346 @test_util.run_v1_only("b/120545219") 347 def testCOrder(self): 348 with self.cached_session(): 349 val = [[1, 2], [3, 4]] 350 x, = script_ops.py_func(lambda: np.array(val, order="F"), [], 351 [dtypes.int64]) 352 self.assertAllEqual(val, self.evaluate(x)) 353 354 @test_util.run_v1_only("b/120545219") 355 def testParallel(self): 356 # Tests that tf.py_func's can run in parallel if they release the GIL. 357 with self.cached_session() as session: 358 q = queue.Queue(1) 359 360 def blocking_put(): 361 q.put(42) 362 q.join() # Wait for task_done(). 363 return 42 364 365 def blocking_get(): 366 v = q.get(block=True) # Wait for put(). 367 q.task_done() 368 return v 369 370 x, = script_ops.py_func(blocking_put, [], [dtypes.int64]) 371 y, = script_ops.py_func(blocking_get, [], [dtypes.int64]) 372 373 # This will result in a deadlock if the py_func's don't run in parallel. 374 session.run([x, y]) 375 376 def testNoReturnValueStateful(self): 377 378 class State(object): 379 380 def __init__(self): 381 self._value = np.array([1], np.int64) 382 383 def _increment(self, diff): 384 self._value += diff 385 386 def increment(self, diff): 387 return script_ops.py_func(self._increment, [diff], [], stateful=True) 388 389 @property 390 def value(self): 391 return self._value 392 393 with self.cached_session(): 394 s = State() 395 op = s.increment(constant_op.constant(2, dtypes.int64)) 396 ret = self.evaluate(op) 397 self.assertIsNone(ret) 398 self.assertAllEqual([3], s.value) 399 400 @test_util.run_v1_only("b/120545219") 401 def testNoReturnValueStateless(self): 402 403 def do_nothing(unused_x): 404 pass 405 406 f = script_ops.py_func( 407 do_nothing, [constant_op.constant(3, dtypes.int64)], [], stateful=False) 408 with self.cached_session() as sess: 409 self.assertEqual(self.evaluate(f), []) 410 411 def _testExceptionHandling(self, py_exp, tf_exp, eager=False): 412 413 def inner_exception(): 414 raise py_exp("blah") # pylint: disable=not-callable 415 416 def raise_exception(): 417 inner_exception() 418 419 expected_regexp = r": blah.*" # Error at the top 420 expected_regexp += r"in raise_exception.*" # Stacktrace outer 421 expected_regexp += r"in inner_exception.*" # Stacktrace inner 422 expected_regexp += r": blah" # Stacktrace of raise 423 def expected_error_check(exception): 424 return re.search(expected_regexp, str(exception), re.DOTALL) 425 426 if eager: 427 if context.executing_eagerly(): 428 with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check): 429 f = script_ops.eager_py_func(raise_exception, [], []) 430 return 431 else: 432 f = script_ops.eager_py_func(raise_exception, [], []) 433 else: 434 f = script_ops.py_func(raise_exception, [], []) 435 436 with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check): 437 self.evaluate(f) 438 439 @test_util.run_v1_only("b/120545219") 440 def testExceptionHandling(self): 441 with self.cached_session(): 442 self._testExceptionHandling(ValueError, errors.InvalidArgumentError) 443 self._testExceptionHandling(TypeError, errors.InvalidArgumentError) 444 self._testExceptionHandling(StopIteration, errors.OutOfRangeError) 445 self._testExceptionHandling(MemoryError, errors.ResourceExhaustedError) 446 self._testExceptionHandling(NotImplementedError, 447 errors.UnimplementedError) 448 449 class WeirdError(Exception): 450 pass 451 452 self._testExceptionHandling(WeirdError, errors.UnknownError) 453 454 # ----- Tests shared by py_func and eager_py_func ----- 455 def testCleanup(self): 456 # Delete everything created by previous tests to avoid side effects. 457 ops.reset_default_graph() 458 gc.collect() 459 initial_size = script_ops._py_funcs.size() 460 # Encapsulate the graph generation, so locals can be deleted. 461 def make_graphs(): 462 for _ in xrange(1000): 463 g = ops.Graph() 464 with g.as_default(): 465 c = constant_op.constant([1.], dtypes.float32) 466 _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32]) 467 _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32]) 468 # These ops have a reference to 'c' which has a reference to the graph. 469 # Checks if the functions are being deleted though the graph is referenced from them. 470 # (see #18292) 471 _ = script_ops.py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) 472 _ = script_ops.eager_py_func(lambda x: x + c.shape[0], [c], [dtypes.float32]) 473 474 # Call garbage collector to enforce deletion. 475 make_graphs() 476 ops.reset_default_graph() 477 gc.collect() 478 self.assertEqual(initial_size, script_ops._py_funcs.size()) 479 480 # ----- Tests for eager_py_func ----- 481 @test_util.run_in_graph_and_eager_modes 482 def testEagerSingleOutputInt32(self): 483 a = array_ops.ones((3, 3), dtype=dtypes.int32) 484 x = array_ops.ones((3, 1), dtype=dtypes.int32) 485 output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32) 486 ret = self.evaluate(output) 487 self.assertAllEqual(ret, [[3], [3], [3]]) 488 489 @test_util.run_in_graph_and_eager_modes 490 def testRenamedDeviceInTestClusterCorrectlyIdentifiedAsLocalhost(self): 491 if context.executing_eagerly(): 492 self.skipTest("b/126565353: We don't test eager's remote execution.") 493 494 workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0) 495 worker = workers[0] 496 session = session_lib.Session(worker.target) 497 with ops.device("/job:worker/task:0/cpu:0"): 498 a = array_ops.ones((3, 3), dtype=dtypes.float32) 499 x = array_ops.ones((3, 1), dtype=dtypes.float32) 500 output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) 501 ret = session.run(output) 502 self.assertAllClose(ret, [[3.0], [3.0], [3.0]]) 503 504 @test_util.run_in_graph_and_eager_modes 505 def testEagerSingleOutputFloat32(self): 506 with test_util.device(use_gpu=True): 507 a = array_ops.ones((3, 3), dtype=dtypes.float32) 508 x = array_ops.ones((3, 1), dtype=dtypes.float32) 509 output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) 510 ret = self.evaluate(output) 511 self.assertAllClose(ret, [[3.0], [3.0], [3.0]]) 512 513 @test_util.run_in_graph_and_eager_modes 514 def testEagerArrayOutput(self): 515 with test_util.device(use_gpu=True): 516 a = array_ops.ones((3, 3), dtype=dtypes.float32) 517 x = array_ops.ones((3, 1), dtype=dtypes.float32) 518 output = script_ops.eager_py_func( 519 lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.float32]) 520 ret = self.evaluate(output) 521 self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]]) 522 523 @test_util.run_in_graph_and_eager_modes 524 def testEagerReturnNone(self): 525 with test_util.device(use_gpu=True): 526 def no_return_value(): 527 return 528 529 output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[]) 530 ret = self.evaluate(output) 531 if context.executing_eagerly(): 532 self.assertEquals(len(ret), 0) 533 else: 534 self.assertIsNone(ret) 535 536 @test_util.run_in_graph_and_eager_modes 537 @test_util.disable_xla("XLA cannot compile functions containing py_func") 538 def testEagerPyFuncInDefun(self): 539 with test_util.device(use_gpu=True): 540 def wrapper(): 541 a = array_ops.ones((3, 3), dtype=dtypes.float32) 542 x = array_ops.ones((3, 1), dtype=dtypes.float32) 543 return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32) 544 545 wrapped = function.defun(wrapper) 546 ret = self.evaluate(wrapped()) 547 self.assertAllEqual(ret, [[3.0], [3.0], [3.0]]) 548 549 @test_util.run_in_graph_and_eager_modes 550 @test_util.run_v1_only("b/120545219") 551 def testEagerExceptionHandling(self): 552 with test_util.device(use_gpu=True): 553 self._testExceptionHandling( 554 ValueError, errors.InvalidArgumentError, eager=True) 555 self._testExceptionHandling( 556 TypeError, errors.InvalidArgumentError, eager=True) 557 self._testExceptionHandling( 558 StopIteration, errors.OutOfRangeError, eager=True) 559 self._testExceptionHandling( 560 MemoryError, errors.ResourceExhaustedError, eager=True) 561 self._testExceptionHandling( 562 NotImplementedError, errors.UnimplementedError, eager=True) 563 564 class WeirdError(Exception): 565 pass 566 567 self._testExceptionHandling(WeirdError, errors.UnknownError, eager=True) 568 569 @test_util.run_in_graph_and_eager_modes 570 @test_util.run_v1_only("b/120545219") 571 def testEagerReturningVariableRaisesError(self): 572 def return_variable(): 573 return resource_variable_ops.ResourceVariable(0.0) 574 575 with self.assertRaisesRegexp(errors.UnknownError, 576 "Attempting to return a variable"): 577 output = script_ops.eager_py_func( 578 return_variable, inp=[], Tout=dtypes.float32) 579 self.evaluate(output) 580 581 @test_util.run_in_graph_and_eager_modes 582 def testEagerGradientTape(self): 583 584 def f(x): 585 return x**2 586 587 x = constant_op.constant(3.0) 588 with backprop.GradientTape() as tape: 589 tape.watch(x) 590 y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32) 591 dy_dx = tape.gradient(y, x) 592 self.assertEqual(self.evaluate(dy_dx), 6.0) 593 594 @test_util.run_v1_only("b/120545219") 595 def testEagerGradientGraph(self): 596 597 def f(x): 598 return x**2 599 600 x = constant_op.constant(3.0) 601 y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32) 602 dy_dx = gradients_impl.gradients(y, x)[0] 603 self.assertEqual(self.evaluate(dy_dx), 6.0) 604 605 @test_util.run_v1_only("b/120545219") 606 def testEagerGradientGraphTwoOutputs(self): 607 608 def f(x, y): 609 return x * y, x / y 610 611 x = constant_op.constant(3.0) 612 y = constant_op.constant(2.0) 613 fa, fb = script_ops.eager_py_func(f, inp=[x, y], 614 Tout=[dtypes.float32, dtypes.float32]) 615 dy_dx = gradients_impl.gradients(fa + fb, x)[0] 616 self.assertEqual(self.evaluate(dy_dx), 2.5) 617 618 @test_util.run_in_graph_and_eager_modes 619 def testEagerGradientTapeMultipleArgs(self): 620 621 def f(x, y): 622 return x**2 + y**2 623 624 x = constant_op.constant(3.0) 625 y = constant_op.constant(4.0) 626 with backprop.GradientTape() as tape: 627 tape.watch(x) 628 tape.watch(y) 629 z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32) 630 631 dz_dx, dz_dy = tape.gradient(z, [x, y]) 632 self.assertEqual(self.evaluate(dz_dx), 6.0) 633 self.assertEqual(self.evaluate(dz_dy), 8.0) 634 635 @test_util.run_v1_only("b/120545219") 636 def testEagerGradientGraphMultipleArgs(self): 637 638 def f(x, y): 639 return x**2 + y**2 640 641 x = constant_op.constant(3.0) 642 y = constant_op.constant(4.0) 643 z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32) 644 645 dz_dx, dz_dy = gradients_impl.gradients(z, [x, y]) 646 self.assertEqual(self.evaluate(dz_dx), 6.0) 647 self.assertEqual(self.evaluate(dz_dy), 8.0) 648 649 @test_util.run_v1_only("b/120545219") 650 def testEagerGradientGraphLogHuber(self): 651 652 def log_huber(x, m): 653 if math_ops.abs(x) <= m: 654 return x**2 655 else: 656 return m**2 * (1 - 2 * math_ops.log(m) + math_ops.log(x**2)) 657 658 x = array_ops.placeholder(dtypes.float32) 659 m = array_ops.placeholder(dtypes.float32) 660 661 y = script_ops.eager_py_func( 662 func=log_huber, inp=[x, m], Tout=dtypes.float32) 663 dy_dx = gradients_impl.gradients(y, x)[0] 664 665 with self.cached_session() as sess: 666 # Takes the first branch of log_huber. 667 y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0}) 668 self.assertEqual(y, 1.0) 669 self.assertEqual(dy_dx, 2.0) 670 671 @test_util.run_v1_only("b/120545219") 672 def testEagerRespectsDevicePlacmentOfOp(self): 673 674 def f(x): 675 return math_ops.square(x) 676 677 def g(x): 678 return math_ops.add(x, x) 679 680 with ops.device("/CPU:0"): 681 # Explicitly ask for the py_funcs to execute on CPU, even if 682 # a GPU is available. 683 x = array_ops.placeholder(dtypes.float32) 684 y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32) 685 z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32) 686 687 with self.session(use_gpu=True) as sess: 688 output = sess.run(z, feed_dict={x: 3.0}) 689 self.assertEqual(output, 18.0) 690 691 692if __name__ == "__main__": 693 test.main() 694