1# -*- coding: utf-8 -*- 2# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Tests for py_func op.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import gc 23import re 24 25import numpy as np 26from six.moves import queue 27from six.moves import xrange # pylint: disable=redefined-builtin 28 29from tensorflow.python.client import session as session_lib 30from tensorflow.python.eager import backprop 31from tensorflow.python.eager import context 32from tensorflow.python.eager import def_function 33from tensorflow.python.eager import function 34from tensorflow.python.framework import constant_op 35from tensorflow.python.framework import dtypes 36from tensorflow.python.framework import errors 37from tensorflow.python.framework import ops 38from tensorflow.python.framework import test_util 39from tensorflow.python.ops import array_ops 40from tensorflow.python.ops import batch_ops 41from tensorflow.python.ops import gradients_impl 42from tensorflow.python.ops import math_ops 43from tensorflow.python.ops import resource_variable_ops 44from tensorflow.python.ops import script_ops 45from tensorflow.python.platform import test 46 47 48def np_func(x, y): 49 return np.sinh(x) + np.cosh(y) 50 51 52def matmul(x, y): 53 return math_ops.matmul(x, y) 54 55 
class PyFuncTestBase(test.TestCase):
  """Shared helpers for the py_func / eager_py_func test cases below."""

  def verifyExceptionHandling(self, py_exp, tf_exp, eager=False):
    """Checks that a Python exception raised in a py_func surfaces as `tf_exp`.

    A nested pair of Python functions raises `py_exp("blah")`; the resulting
    error must be of type `tf_exp` and its string form must match a regexp
    that looks for the message and for both function names.

    Args:
      py_exp: Exception class that is raised (with the message "blah") inside
        the wrapped Python function.
      tf_exp: Exception type expected by `assertRaisesWithPredicateMatch`.
      eager: If True, the function is wrapped with `script_ops.eager_py_func`
        instead of `script_ops.py_func`.
    """

    def inner_exception():
      raise py_exp("blah")  # pylint: disable=not-callable

    def raise_exception():
      inner_exception()

    expected_regexp = r": blah.*"  # Error at the top
    expected_regexp += r"in raise_exception.*"  # Stacktrace outer
    expected_regexp += r"in inner_exception.*"  # Stacktrace inner
    expected_regexp += r": blah"  # Stacktrace of raise
    def expected_error_check(exception):
      return re.search(expected_regexp, str(exception), re.DOTALL)

    if eager:
      if context.executing_eagerly():
        # When executing eagerly the error is expected from the call itself,
        # so there is nothing left to evaluate afterwards.
        with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
          f = script_ops.eager_py_func(raise_exception, [], [])
        return
      else:
        f = script_ops.eager_py_func(raise_exception, [], [])
    else:
      f = script_ops.py_func(raise_exception, [], [])

    with self.assertRaisesWithPredicateMatch(tf_exp, expected_error_check):
      self.evaluate(f)


class PyFuncTest(PyFuncTestBase):
  """Encapsulates tests for py_func only."""

  def testRealDataTypes(self):
    def sum_func(x, y):
      return x + y
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64,
                  dtypes.uint8, dtypes.int8, dtypes.uint16, dtypes.int16,
                  dtypes.int32, dtypes.int64]:
      with self.cached_session():
        x = constant_op.constant(1, dtype=dtype)
        y = constant_op.constant(2, dtype=dtype)
        z = self.evaluate(script_ops.py_func(sum_func, [x, y], dtype))
        self.assertEqual(z, 3)

  def testComplexDataTypes(self):
    def sub_func(x, y):
      return x - y
    for dtype in [dtypes.complex64, dtypes.complex128]:
      with self.cached_session():
        x = constant_op.constant(1 + 1j, dtype=dtype)
        y = constant_op.constant(2 - 2j, dtype=dtype)
        z = self.evaluate(script_ops.py_func(sub_func, [x, y], dtype))
        self.assertEqual(z, -1 + 3j)

  def testBoolDataTypes(self):
    def and_func(x, y):
      return x and y
    dtype = dtypes.bool
    with self.cached_session():
      x = constant_op.constant(True, dtype=dtype)
      y = constant_op.constant(False, dtype=dtype)
      z = self.evaluate(script_ops.py_func(and_func, [x, y], dtype))
      self.assertEqual(z, False)

  def testSingleType(self):
    with self.cached_session():
      x = constant_op.constant(1.0, dtypes.float32)
      y = constant_op.constant(2.0, dtypes.float32)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.float32))
      self.assertEqual(z, np_func(1.0, 2.0).astype(np.float32))

  def testScalar(self):
    with self.cached_session():
      x = constant_op.constant(1.0, dtypes.float32)
      y = constant_op.constant(2.0, dtypes.float32)
      z = self.evaluate(
          script_ops.eager_py_func(np_func, [x, y], [dtypes.float32]))
      self.assertEqual(z[0], np_func(1.0, 2.0).astype(np.float32))

  @test_util.run_v1_only("b/120545219")
  def testArray(self):
    with self.cached_session():
      x = constant_op.constant([1.0, 2.0], dtypes.float64)
      y = constant_op.constant([2.0, 3.0], dtypes.float64)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], [dtypes.float64]))
      self.assertAllEqual(z[0],
                          np_func([1.0, 2.0], [2.0, 3.0]).astype(np.float64))

  def testComplexType(self):
    with self.cached_session():
      x = constant_op.constant(1 + 2j, dtypes.complex64)
      y = constant_op.constant(3 + 4j, dtypes.complex64)
      z = self.evaluate(script_ops.py_func(np_func, [x, y], dtypes.complex64))
      self.assertAllClose(z, np_func(1 + 2j, 3 + 4j))

  def testRFFT(self):
    with self.cached_session():
      x = constant_op.constant([1., 2., 3., 4.], dtypes.float32)

      def rfft(x):
        return np.fft.rfft(x).astype(np.complex64)

      y = self.evaluate(script_ops.py_func(rfft, [x], dtypes.complex64))
      self.assertAllClose(y, np.fft.rfft([1., 2., 3., 4.]))

  def testPythonLiteral(self):
    with self.cached_session():

      def literal(x):
        return 1.0 if float(x) == 0.0 else 0.0

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(script_ops.py_func(literal, [x], dtypes.float64))
      self.assertAllClose(y, 1.0)

  def testList(self):
    with self.cached_session():

      def list_func(x):
        return [x, x + 1]

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(list_func, [x], [dtypes.float64] * 2))
      self.assertAllClose(y, [0.0, 1.0])

  def testTuple(self):
    # returns a tuple
    with self.cached_session():

      def tuple_func(x):
        return x, x + 1

      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(tuple_func, [x], [dtypes.float64] * 2))
      self.assertAllClose(y, [0.0, 1.0])

    # returns a tuple, Tout and inp a tuple
    with self.cached_session():
      x = constant_op.constant(0.0, dtypes.float64)
      y = self.evaluate(
          script_ops.py_func(tuple_func, (x,),
                             (dtypes.float64, dtypes.float64)))
      self.assertAllClose(y, [0.0, 1.0])

  @test_util.run_v1_only("b/120545219")
  def testStrings(self):

    def read_fixed_length_numpy_strings():
      return np.array([b" there"])

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant([b"hello", b"hi"], dtypes.string)
      y = self.evaluate(
          script_ops.py_func(read_fixed_length_numpy_strings, [],
                             dtypes.string))
      z = self.evaluate(
          script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
      self.assertAllEqual(z, [b"hello there", b"hi there"])

  @test_util.run_v1_only("b/120545219")
  def testStringsAreConvertedToBytes(self):

    def read_fixed_length_numpy_strings():
      return np.array([" there"])

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant(["hello", "hi"], dtypes.string)
      y = self.evaluate(
          script_ops.py_func(read_fixed_length_numpy_strings, [],
                             dtypes.string))
      z = self.evaluate(
          script_ops.py_func(read_and_return_strings, [x, y], dtypes.string))
      self.assertAllEqual(z, [b"hello there", b"hi there"])

  @test_util.run_v1_only("b/120545219")
  def testObjectArraysAreConvertedToBytes(self):

    def read_object_array():
      # The builtin `object` is used as the dtype: the `np.object` alias was
      # deprecated in NumPy 1.20 and removed in NumPy 1.24.
      return np.array([b" there", u" ya"], dtype=object)

    def read_and_return_strings(x, y):
      return x + y

    with self.cached_session():
      x = constant_op.constant(["hello", "hi"], dtypes.string)
      y, = script_ops.py_func(read_object_array, [],
                              [dtypes.string])
      z, = script_ops.py_func(read_and_return_strings, [x, y], [dtypes.string])
      self.assertListEqual(list(self.evaluate(z)), [b"hello there", b"hi ya"])

  @test_util.run_v1_only("b/120545219")
  def testStringPadding(self):
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [correct], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testStringPaddingAreConvertedToBytes(self):
    inp = ["this", "is", "a", "test"]
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testNulTerminatedStrings(self):
    inp = np.array(["this\0", "is\0\0", "a\0", "test\0\0"], dtype=np.str_)
    correct = [b"this", b"is", b"a", b"test"]
    with self.cached_session():
      s, = script_ops.py_func(lambda: [inp], [], [dtypes.string])
      self.assertAllEqual(s, correct)

  @test_util.run_v1_only("b/120545219")
  def testLarge(self):
    with self.cached_session() as sess:
      x = array_ops.zeros([1000000], dtype=np.float32)
      y = script_ops.py_func(lambda x: x + 1, [x], [dtypes.float32])
      z = script_ops.py_func(lambda x: x * 2, [x], [dtypes.float32])
      for _ in xrange(100):
        sess.run([y[0].op, z[0].op])

  def testNoInput(self):
    with self.cached_session():
      x = self.evaluate(script_ops.py_func(lambda: 42.0, [], dtypes.float64))
      self.assertAllClose(x, 42.0)

  @test_util.run_v1_only("b/120545219")
  def testAlias(self):
    with self.cached_session():
      np_array = np.array([1.0, 2.0], dtype=np.float32)
      tf_array = script_ops.py_func(lambda: np_array, [], [dtypes.float32])
      value = tf_array + constant_op.constant([2.0, 3.0], dtype=dtypes.float32)
      value.op.run()
      # The NumPy array returned by the py_func must be left unmodified.
      self.assertAllEqual(np_array, [1.0, 2.0])

  @test_util.run_v1_only("b/120545219")
  def testReturnUnicodeString(self):
    with self.cached_session():
      correct = u"你好 世界"

      def unicode_string():
        return correct

      z, = script_ops.py_func(unicode_string, [], [dtypes.string])
      self.assertEqual(self.evaluate(z), correct.encode("utf8"))

  @test_util.run_v1_only("b/120545219")
  def testBadNumpyReturnType(self):
    with self.cached_session():

      def bad():
        # Structured numpy arrays aren't supported.
        return np.array([], dtype=[("foo", np.float32)])

      y, = script_ops.py_func(bad, [], [dtypes.float32])

      with self.assertRaisesRegex(errors.InternalError,
                                  "Unsupported numpy data type"):
        self.evaluate(y)

  @test_util.run_v1_only("b/120545219")
  def testBadReturnType(self):
    with self.cached_session():

      def bad():
        # Non-string python objects aren't supported.
        return {"foo": dtypes.float32}

      z, = script_ops.py_func(bad, [], [dtypes.int64])

      with self.assertRaisesRegex(errors.InternalError,
                                  "Unsupported object type"):
        self.evaluate(z)

  @test_util.run_v1_only("b/120545219")
  def testReturnInput(self):
    with self.cached_session():

      def ident(x):
        return x[0]

      p = array_ops.placeholder(dtypes.float32)

      # Create a numpy array aliasing a tensor and a tensor aliasing this array
      z, = script_ops.py_func(ident, [p], [dtypes.float32])
      z += 0.0  # Makes sure we release the tensor aliasing the numpy array x[0]
      # above instead of using its memory as the return value of
      # session.run
      self.assertEqual(0.0, z.eval(feed_dict={p: [0.0]}))

  def testStateful(self):
    # Not using self.cached_session(), which disables optimization.
    with session_lib.Session():
      producer = iter(range(3))
      x, = script_ops.py_func(lambda: next(producer), [], [dtypes.int64])
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 1)
      self.assertEqual(self.evaluate(x), 2)

  @test_util.enable_tf_xla_constant_folding("b/134376434")
  def testStateless(self):
    # Not using self.cached_session(), which disables optimization.
    with session_lib.Session():
      producer = iter(range(3))
      x, = script_ops.py_func(
          lambda: next(producer), [], [dtypes.int64], stateful=False)
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 0)
      self.assertEqual(self.evaluate(x), 0)

  @test_util.run_v1_only("b/120545219")
  def testGradientFunction(self):
    # Input to tf.compat.v1.py_func is necessary,
    # otherwise get_gradient_function() returns None per default.
    a = constant_op.constant(0)
    x, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64])
    y, = script_ops.py_func(lambda a: 0, [a], [dtypes.int64], stateful=False)
    self.assertEqual(None, ops.get_gradient_function(x.op))
    self.assertEqual(None, ops.get_gradient_function(y.op))

  @test_util.run_v1_only("b/120545219")
  def testCOrder(self):
    with self.cached_session():
      val = [[1, 2], [3, 4]]
      x, = script_ops.py_func(lambda: np.array(val, order="F"), [],
                              [dtypes.int64])
      self.assertAllEqual(val, self.evaluate(x))

  @test_util.run_v1_only("b/120545219")
  def testParallel(self):
    # Tests that tf.compat.v1.py_func's can run in parallel if they release
    # the GIL.
    with self.cached_session() as session:
      q = queue.Queue(1)

      def blocking_put():
        q.put(42)
        q.join()  # Wait for task_done().
        return 42

      def blocking_get():
        v = q.get(block=True)  # Wait for put().
        q.task_done()
        return v

      x, = script_ops.py_func(blocking_put, [], [dtypes.int64])
      y, = script_ops.py_func(blocking_get, [], [dtypes.int64])

      # This will result in a deadlock if the py_func's don't run in parallel.
      session.run([x, y])

  def testNoReturnValueStateful(self):

    class State(object):

      def __init__(self):
        self._value = np.array([1], np.int64)

      def _increment(self, diff):
        self._value += diff

      def increment(self, diff):
        return script_ops.py_func(self._increment, [diff], [], stateful=True)

      @property
      def value(self):
        return self._value

    with self.cached_session():
      s = State()
      op = s.increment(constant_op.constant(2, dtypes.int64))
      ret = self.evaluate(op)
      self.assertIsNone(ret)
      self.assertAllEqual([3], s.value)

  @test_util.run_v1_only("b/120545219")
  def testNoReturnValueStateless(self):

    def do_nothing(unused_x):
      pass

    f = script_ops.py_func(
        do_nothing, [constant_op.constant(3, dtypes.int64)], [], stateful=False)
    with self.cached_session():
      self.assertEqual(self.evaluate(f), [])

  @test_util.run_v1_only("b/120545219")
  def testExceptionHandling(self):
    with self.cached_session():
      self.verifyExceptionHandling(ValueError, errors.InvalidArgumentError)
      self.verifyExceptionHandling(TypeError, errors.InvalidArgumentError)
      self.verifyExceptionHandling(StopIteration, errors.OutOfRangeError)
      self.verifyExceptionHandling(MemoryError, errors.ResourceExhaustedError)
      self.verifyExceptionHandling(NotImplementedError,
                                   errors.UnimplementedError)

      class WeirdError(Exception):
        pass

      self.verifyExceptionHandling(WeirdError, errors.UnknownError)

  def testFunctionReferencesAreKept(self):
    g = ops.Graph()
    with g.as_default():
      c = constant_op.constant([1.], dtypes.float32)
      @batch_ops.batch_function(1, 10, 100000)
      def fn(x):
        # Upon exiting this function, the py_func holds the sole reference
        # to this lambda, without which it would be garbage collected.
        return script_ops.py_func(lambda x: x, [x], [dtypes.float32])
      result = fn(c)
      gc.collect()
      self.evaluate(result)


class PyFuncAndEagerPyFuncTest(PyFuncTestBase):
  """Encapsulates tests shared between py_func and eager_py_func."""

  def verifyPyFuncsNoIncrease(self, make_graph):
    """Asserts that repeated graph construction does not grow `_py_funcs`.

    The size of `script_ops._py_funcs` is recorded after resetting the default
    graph and running the garbage collector, `make_graph` is called 1000
    times, and the size is compared again after another reset and collection.

    Args:
      make_graph: Zero-argument callable that builds the ops under test.
    """
    ops.reset_default_graph()
    gc.collect()
    initial_size = script_ops._py_funcs.size()

    for _ in xrange(1000):
      make_graph()

    ops.reset_default_graph()
    gc.collect()
    self.assertEqual(initial_size, script_ops._py_funcs.size())

  def testCleanup(self):

    def make_graph():
      g = ops.Graph()
      with g.as_default():
        c = constant_op.constant([1.], dtypes.float32)
        _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
        _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
        # These ops have a reference to 'c' which has a reference to the
        # graph.
        # Checks if the functions are being deleted though the graph is
        # referenced from them (see #18292).
        script_ops.py_func(
            lambda x: x + c.shape[0], [c], [dtypes.float32])
        script_ops.eager_py_func(
            lambda x: x + c.shape[0], [c], [dtypes.float32])

    self.verifyPyFuncsNoIncrease(make_graph)

  def testCleanupInTfFunction(self):

    self.skipTest("b/144098211")

    def make_graph():
      g = ops.Graph()
      with g.as_default():
        @def_function.function
        def fn():
          c = constant_op.constant([1.], dtypes.float32)
          _ = script_ops.py_func(lambda x: x + 1, [c], [dtypes.float32])
          _ = script_ops.eager_py_func(lambda x: x + 1, [c], [dtypes.float32])
          # These ops have a reference to 'c' which has a reference to the
          # graph.
          # Checks if the functions are being deleted though the graph is
          # referenced from them (see #18292).
          script_ops.py_func(
              lambda x: x + c.shape[0], [c], [dtypes.float32])
          script_ops.eager_py_func(
              lambda x: x + c.shape[0], [c], [dtypes.float32])
        fn()

    self.verifyPyFuncsNoIncrease(make_graph)


class EagerPyFuncTest(PyFuncTestBase):
  """Encapsulates tests for eager_py_func only."""

  @test_util.run_in_graph_and_eager_modes
  def testEagerSingleOutputInt32(self):
    a = array_ops.ones((3, 3), dtype=dtypes.int32)
    x = array_ops.ones((3, 1), dtype=dtypes.int32)
    output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.int32)
    ret = self.evaluate(output)
    self.assertAllEqual(ret, [[3], [3], [3]])

  @test_util.run_in_graph_and_eager_modes
  def testRenamedDeviceInTestClusterCorrectlyIdentifiedAsLocalhost(self):
    if context.executing_eagerly():
      self.skipTest("b/126565353: We don't test eager's remote execution.")

    workers, _ = test_util.create_local_cluster(num_workers=1, num_ps=0)
    worker = workers[0]
    session = session_lib.Session(worker.target)
    # Close the session when the test finishes, whether it passes or fails,
    # so it is not leaked.
    self.addCleanup(session.close)
    with ops.device("/job:worker/task:0/cpu:0"):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
    ret = session.run(output)
    self.assertAllClose(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerSingleOutputFloat32(self):
    with test_util.device(use_gpu=True):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)
      ret = self.evaluate(output)
      self.assertAllClose(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerArrayOutput(self):
    with test_util.device(use_gpu=True):
      a = array_ops.ones((3, 3), dtype=dtypes.float32)
      x = array_ops.ones((3, 1), dtype=dtypes.float32)
      output = script_ops.eager_py_func(
          lambda a, x: [matmul(a, x)], inp=[a, x], Tout=[dtypes.float32])
      ret = self.evaluate(output)
      self.assertAllEqual(ret, [[[3.0], [3.0], [3.0]]])

  @test_util.run_in_graph_and_eager_modes
  def testEagerReturnNone(self):
    with test_util.device(use_gpu=True):
      def no_return_value():
        return

      output = script_ops.eager_py_func(no_return_value, inp=[], Tout=[])
      ret = self.evaluate(output)
      if context.executing_eagerly():
        self.assertEqual(len(ret), 0)
      else:
        self.assertIsNone(ret)

  @test_util.run_in_graph_and_eager_modes
  @test_util.disable_tfrt("b/180469928")
  def testEagerPyFuncInDefun(self):
    with test_util.device(use_gpu=True):
      def wrapper():
        a = array_ops.ones((3, 3), dtype=dtypes.float32)
        x = array_ops.ones((3, 1), dtype=dtypes.float32)
        return script_ops.eager_py_func(matmul, inp=[a, x], Tout=dtypes.float32)

      wrapped = function.defun(wrapper)
      ret = self.evaluate(wrapped())
      self.assertAllEqual(ret, [[3.0], [3.0], [3.0]])

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testEagerExceptionHandling(self):
    with test_util.device(use_gpu=True):
      self.verifyExceptionHandling(
          ValueError, errors.InvalidArgumentError, eager=True)
      self.verifyExceptionHandling(
          TypeError, errors.InvalidArgumentError, eager=True)
      self.verifyExceptionHandling(
          StopIteration, errors.OutOfRangeError, eager=True)
      self.verifyExceptionHandling(
          MemoryError, errors.ResourceExhaustedError, eager=True)
      self.verifyExceptionHandling(
          NotImplementedError, errors.UnimplementedError, eager=True)

      class WeirdError(Exception):
        pass

      self.verifyExceptionHandling(WeirdError, errors.UnknownError, eager=True)

  @test_util.run_in_graph_and_eager_modes
  @test_util.run_v1_only("b/120545219")
  def testEagerReturningVariableRaisesError(self):
    def return_variable():
      return resource_variable_ops.ResourceVariable(0.0)

    with self.assertRaisesRegex(errors.UnknownError,
                                "Attempting to return a variable"):
      output = script_ops.eager_py_func(
          return_variable, inp=[], Tout=dtypes.float32)
      self.evaluate(output)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGradientTape(self):

    def f(x):
      return x**2

    x = constant_op.constant(3.0)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
    dy_dx = tape.gradient(y, x)
    self.assertAllClose(self.evaluate(dy_dx), 6.0)

    # Test complex values
    x = constant_op.constant(3.0 + 3.0j)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.complex128)
    dy_dx = tape.gradient(y, x)
    # Gradient of complex will be the conj
    self.assertAllClose(self.evaluate(dy_dx), 6.0 - 6.0j)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraph(self):

    def f(x):
      return x**2

    x = constant_op.constant(3.0)
    y = script_ops.eager_py_func(f, inp=[x], Tout=dtypes.float32)
    dy_dx = gradients_impl.gradients(y, x)[0]
    self.assertEqual(self.evaluate(dy_dx), 6.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphTwoOutputs(self):

    def f(x, y):
      return x * y, x / y

    x = constant_op.constant(3.0)
    y = constant_op.constant(2.0)
    fa, fb = script_ops.eager_py_func(f, inp=[x, y],
                                      Tout=[dtypes.float32, dtypes.float32])
    dy_dx = gradients_impl.gradients(fa + fb, x)[0]
    self.assertEqual(self.evaluate(dy_dx), 2.5)

  @test_util.run_in_graph_and_eager_modes
  def testEagerGradientTapeMultipleArgs(self):

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    with backprop.GradientTape() as tape:
      tape.watch(x)
      tape.watch(y)
      z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    dz_dx, dz_dy = tape.gradient(z, [x, y])
    self.assertEqual(self.evaluate(dz_dx), 6.0)
    self.assertEqual(self.evaluate(dz_dy), 8.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphMultipleArgs(self):

    def f(x, y):
      return x**2 + y**2

    x = constant_op.constant(3.0)
    y = constant_op.constant(4.0)
    z = script_ops.eager_py_func(f, inp=[x, y], Tout=dtypes.float32)

    dz_dx, dz_dy = gradients_impl.gradients(z, [x, y])
    self.assertEqual(self.evaluate(dz_dx), 6.0)
    self.assertEqual(self.evaluate(dz_dy), 8.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerGradientGraphLogHuber(self):

    def log_huber(x, m):
      if math_ops.abs(x) <= m:
        return x**2
      else:
        return m**2 * (1 - 2 * math_ops.log(m) + math_ops.log(x**2))

    x = array_ops.placeholder(dtypes.float32)
    m = array_ops.placeholder(dtypes.float32)

    y = script_ops.eager_py_func(
        func=log_huber, inp=[x, m], Tout=dtypes.float32)
    dy_dx = gradients_impl.gradients(y, x)[0]

    with self.cached_session() as sess:
      # Takes the first branch of log_huber.
      y, dy_dx = sess.run([y, dy_dx], feed_dict={x: 1.0, m: 2.0})
      self.assertEqual(y, 1.0)
      self.assertEqual(dy_dx, 2.0)

  @test_util.run_v1_only("b/120545219")
  def testEagerRespectsDevicePlacementOfOp(self):

    def f(x):
      return math_ops.square(x)

    def g(x):
      return math_ops.add(x, x)

    with ops.device("/CPU:0"):
      # Explicitly ask for the py_funcs to execute on CPU, even if
      # a GPU is available.
      x = array_ops.placeholder(dtypes.float32)
      y = script_ops.eager_py_func(func=f, inp=[x], Tout=dtypes.float32)
      z = script_ops.eager_py_func(func=g, inp=[y], Tout=dtypes.float32)

    with self.session() as sess:
      output = sess.run(z, feed_dict={x: 3.0})
      self.assertEqual(output, 18.0)

  @test_util.run_in_graph_and_eager_modes
  def testEagerPyFuncOnGPUWithStrings(self):

    def fn(a):
      return str(a.dtype)

    x = constant_op.constant("x", dtype=dtypes.string)
    output = script_ops.eager_py_func(fn, inp=[x], Tout=dtypes.string)
    self.assertEqual(self.evaluate(output), "<dtype: 'string'>".encode("utf8"))

  @test_util.run_in_graph_and_eager_modes
  def testEagerPyFuncNotACallable(self):
    x = constant_op.constant("x", dtype=dtypes.string)

    with self.assertRaisesRegex(ValueError, "callable"):
      _ = script_ops.eager_py_func(x, inp=[x], Tout=dtypes.string)


if __name__ == "__main__":
  test.main()