1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=unidiomatic-typecheck 16"""Prototype decorator for defining graph functions with eager semantics.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import functools 23import weakref 24 25from tensorflow.python.eager import context 26from tensorflow.python.eager import function as function_lib 27from tensorflow.python.eager import lift_to_graph 28from tensorflow.python.framework import func_graph as func_graph_module 29from tensorflow.python.framework import ops 30from tensorflow.python.ops import control_flow_ops 31from tensorflow.python.ops import math_ops 32from tensorflow.python.ops import resource_variable_ops 33from tensorflow.python.platform import tf_logging as logging 34from tensorflow.python.training.tracking import base as trackable 35from tensorflow.python.util import nest 36from tensorflow.python.util import tf_decorator 37from tensorflow.python.util.tf_export import tf_export 38 39 40class UnliftedInitializerVariable(resource_variable_ops.ResourceVariable): 41 """Variable which does not lift its initializer out of function context. 42 43 Instances of this variable, when created, build a graph which runs their 44 initializer inside a tf.cond(is_initialized) block. 45 46 This can only be created inside a defun called from (eventually) eager 47 mode. That is, non-function-building graphs are not supported. 48 """ 49 50 def __init__(self, # pylint: disable=super-init-not-called 51 initial_value=None, 52 trainable=None, 53 caching_device=None, 54 name=None, 55 dtype=None, 56 constraint=None, 57 add_initializers_to=None, 58 lifted_initializer_graph=None, 59 **unused_kwargs): 60 """Creates a variable. 61 62 Args: 63 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 64 which is the initial value for the Variable. The initial value must have 65 a shape specified unless `validate_shape` is set to False. Can also be a 66 callable with no argument that returns the initial value when called. 67 (Note that initializer functions from init_ops.py must first be bound 68 to a shape before being used here.) 69 trainable: If `True`, GradientTapes automatically watch uses of this 70 Variable. 71 caching_device: Optional device string or function describing where the 72 Variable should be cached for reading. Defaults to the Variable's 73 device. If not `None`, caches on another device. Typical use is to 74 cache on the device where the Ops using the Variable reside, to 75 deduplicate copying through `Switch` and other conditional statements. 76 name: Optional name for the variable. Defaults to `'Variable'` and gets 77 uniquified automatically. 78 dtype: If set, initial_value will be converted to the given type. 79 If None, either the datatype will be kept (if initial_value is 80 a Tensor) or float32 will be used (if it is a Python object convertible 81 to a Tensor). 82 constraint: An optional projection function to be applied to the variable 83 after being updated by an `Optimizer` (e.g. used to implement norm 84 constraints or value constraints for layer weights). The function must 85 take as input the unprojected Tensor representing the value of the 86 variable and return the Tensor for the projected value 87 (which must have the same shape). Constraints are not safe to 88 use when doing asynchronous distributed training. 89 add_initializers_to: if not None and not in legacy graph mode, the 90 initializer tensor will be added to this map in addition to adding the 91 assignment to the function. 92 lifted_initializer_graph: FuncGraph to try to lift initializers to. 93 94 Raises: 95 ValueError: If the initial value is not specified, or does not have a 96 shape and `validate_shape` is `True`. 97 RuntimeError: If called outside of a function definition. 98 """ 99 if not ops.inside_function(): 100 # If we've been init_scope()d out of the function definition nothing to do 101 # here; we can't really do the capturing or conditional logic. 102 resource_variable_ops.ResourceVariable.__init__( 103 self, initial_value=initial_value, trainable=trainable, 104 caching_device=caching_device, name=name, dtype=dtype, 105 constraint=constraint) 106 return 107 with ops.init_scope(): 108 self._in_graph_mode = not context.executing_eagerly() 109 if initial_value is None: 110 raise ValueError("initial_value must be specified.") 111 init_from_fn = callable(initial_value) 112 113 if constraint is not None and not callable(constraint): 114 raise ValueError("The `constraint` argument must be a callable.") 115 116 if isinstance(initial_value, trackable.CheckpointInitialValue): 117 self._maybe_initialize_trackable() 118 self._update_uid = initial_value.checkpoint_position.restore_uid 119 initial_value = initial_value.wrapped_value 120 121 if trainable is None: 122 trainable = True 123 self._trainable = trainable 124 self._save_slice_info = None 125 self._initial_value = None 126 self._initializer_op = None 127 self._is_initialized_op = None 128 self._graph_element = None 129 self._cached_value = None 130 # Store the graph key so optimizers know how to only retrieve variables from 131 # this graph. Guaranteed to be the same as the eager graph_key. 132 self._graph_key = ops.get_default_graph()._graph_key # pylint: disable=protected-access 133 with ops.name_scope(name, "Variable", [] 134 if init_from_fn else [initial_value]) as name: 135 # pylint: disable=protected-access 136 with ops.init_scope(): 137 handle_name = ops._name_from_scope_name(name) 138 unique_id = "%s_%d" % (handle_name, ops.uid()) 139 shared_name = context.shared_name(unique_id) 140 with ops.name_scope("Initializer"), ops.device(None): 141 initial_value = ops.convert_to_tensor( 142 initial_value() if init_from_fn else initial_value, 143 name="initial_value", dtype=dtype) 144 with ops.init_scope(): 145 self._handle = resource_variable_ops.eager_safe_variable_handle( 146 initial_value=initial_value, 147 shared_name=shared_name, 148 name=name, 149 graph_mode=self._in_graph_mode) 150 self._shape = initial_value.shape 151 self._unique_id = unique_id 152 self._handle_name = handle_name + ":0" 153 self._dtype = initial_value.dtype.base_dtype 154 self._constraint = constraint 155 assert initial_value is not None 156 if self._in_graph_mode: 157 with ops.init_scope(): 158 outer_graph = ops.get_default_graph() 159 func_graph = ops.get_default_graph() 160 function_placeholders = ( 161 func_graph.inputs + func_graph.internal_captures) 162 placeholder_ops = set( 163 [tensor.op for tensor in function_placeholders]) 164 lifted_initializer = lift_to_graph.lift_to_graph( 165 [initial_value], outer_graph, 166 disallowed_placeholders=placeholder_ops)[initial_value] 167 with ops.init_scope(): 168 self._initial_value = lifted_initializer 169 with ops.name_scope("IsInitialized"): 170 self._is_initialized_op = ( 171 resource_variable_ops.var_is_initialized_op(self._handle)) 172 if initial_value is not None: 173 with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): 174 self._initializer_op = resource_variable_ops.assign_variable_op( 175 self._handle, lifted_initializer, name=n) 176 with ops.name_scope("Read"), ops.colocate_with(self._handle): 177 # Manually assign reads to the handle's device to avoid log 178 # messages. 179 with ops.device(self._handle.device): 180 value = self._read_variable_op() 181 self._graph_element = value 182 ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES, self) 183 else: 184 if add_initializers_to is not None: 185 add_initializers_to[self] = initial_value 186 def assign_fn(): 187 with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): 188 resource_variable_ops.assign_variable_op( 189 self._handle, 190 initial_value, 191 name=n) 192 # Returning values to keep tf.cond happy. 193 return ops.convert_to_tensor(1) 194 def not_assign_fn(): 195 return ops.convert_to_tensor(0) 196 # Note: this cond is always guaranteed to run because we're inside a 197 # defun which will insert automatic control dependencies. 198 control_flow_ops.cond( 199 resource_variable_ops.var_is_initialized_op(self._handle), 200 not_assign_fn, assign_fn) 201 202 # After the handle has been created, set up a way to clean it up when 203 # executing eagerly. We'll hold the only reference to the deleter, so that 204 # when this object is garbage collected the deleter will be too. This 205 # means ResourceVariables can be part of reference cycles without those 206 # cycles being uncollectable. 207 if not self._in_graph_mode: 208 self._handle_deleter = resource_variable_ops.EagerResourceDeleter( 209 handle=self._handle, handle_device=self._handle.device) 210 self._cached_shape_as_list = None 211 212 213RUN_FUNCTIONS_EAGERLY = False 214 215 216@tf_export("config.experimental_run_functions_eagerly") 217def run_functions_eagerly(run_eagerly): 218 """Enables / disables eager execution of `tf.function`s. 219 220 After calling `tf.config.experimental_run_functions_eagerly(True)` all 221 invocations of tf.function will run eagerly instead of running through a graph 222 function. 223 224 This can be useful for debugging or profiling. 225 226 Similarly, calling `tf.config.experimental_run_functions_eagerly(False)` will 227 revert the behavior of all functions to graph functions. 228 229 Args: 230 run_eagerly: Boolean. Whether to run functions eagerly. 231 """ 232 global RUN_FUNCTIONS_EAGERLY 233 RUN_FUNCTIONS_EAGERLY = bool(run_eagerly) 234 235 236class FunctionDeleter(object): 237 238 def __init__(self, func_graph): 239 self.func_graph = func_graph 240 241 def __del__(self): 242 try: 243 func_graph_module.dismantle_func_graph(self.func_graph) 244 except: # pylint: disable=bare-except 245 # Note: bare except here because this can be noisy at shutdown time. 246 pass 247 248 249class Function(object): 250 """Wrapper class for the graph functions defined for a Python function. 251 252 See the documentation for `tf.function` for more information on the semantics 253 of defined functions. 254 255 `Function` is thread-compatible. 256 """ 257 258 def __init__(self, 259 python_function, 260 name, 261 input_signature=None, 262 autograph=True, 263 experimental_autograph_options=None): 264 """Initializes a `Function`. 265 266 Args: 267 python_function: the function to be wrapped. 268 name: the name given to it. 269 input_signature: a possibly nested sequence of `TensorSpec` objects 270 specifying the input signature of this function. If `None`, a separate 271 function is instantiated for each inferred input signature. 272 autograph: whether `python_function` should be converted to graph mode. 273 See https://www.tensorflow.org/guide/autograph for more information. 274 experimental_autograph_options: optional tuple of 275 tensorflow.autograph.Feature values. Allows enabling additional 276 conversion options when autograph is set to True. 277 278 Raises: 279 ValueError: if `input_signature` is not None and the `python_function`'s 280 argspec has keyword arguments. 281 """ 282 self._python_function = python_function 283 # TODO(vbardiovsky): Both _stateful_fn and _stateless_fn are populating the 284 # same FunctionSpec. Consider removing it from both and passing in instead. 285 self._function_spec = function_lib.FunctionSpec.from_function_and_signature( 286 python_function, input_signature) 287 self._autograph = autograph 288 self._experimental_autograph_options = experimental_autograph_options 289 self._created_variables = None 290 self._stateful_fn = None 291 self._stateless_fn = None 292 self._descriptor_cache = weakref.WeakKeyDictionary() 293 self._name = name 294 295 def _defun_with_scope(self, scope): 296 """Creates a defun wrapped inside a variable creator scope.""" 297 298 weak_wrapped_fn = None 299 def wrapped_fn(*args, **kwds): 300 """Wraps `self._python_function` in a variable creator scope.""" 301 # We register a variable creator with reduced priority. If an outer 302 # variable creator is just modifying keyword arguments to the variable 303 # constructor, this will work harmoniously. Since the `scope` registered 304 # here actually creates the variable, it taking priority would otherwise 305 # ignore the outer creator. 306 # 307 # If an outer variable creator calls the variable constructor manually, 308 # for example creating a MirroredVariable, then they won't call our 309 # creator. This means we won't be able to trace the initialization graph, 310 # and so variable initializers can't depend on function arguments. This is 311 # better than the alternative, tracing the initialization graph but giving 312 # the user a variable type they didn't want. 313 with ops.get_default_graph()._variable_creator_scope(scope, priority=50): # pylint: disable=protected-access 314 # __wrapped__ allows AutoGraph to swap in a converted function. We give 315 # the function a weak reference to itself to avoid a reference cycle. 316 return weak_wrapped_fn().__wrapped__(*args, **kwds) 317 weak_wrapped_fn = weakref.ref(wrapped_fn) 318 319 # TODO(mdan): Pipe self._experimental_autograph_options through. 320 return function_lib.defun( 321 tf_decorator.make_decorator( 322 self._python_function, 323 wrapped_fn, 324 decorator_argspec=self._function_spec.fullargspec), 325 input_signature=self.input_signature, 326 autograph=self._autograph, 327 experimental_autograph_options=self._experimental_autograph_options) 328 329 def _initialize(self, args, kwds, add_initializers_to=None): 330 """Initializes, on the first call. 331 332 Creates two `Function`s, one that will allow creation of variables 333 and one that won't. 334 335 Additionally runs a trace for the `Function` that allows creation 336 of variables. 337 338 Args: 339 args: Arguments to the underlying python callable. 340 kwds: Keyword arguments to the python callable. 341 add_initializers_to: Where to collect variable initializers, if not None. 342 """ 343 344 created_variables = [] 345 lifted_initializer_graph = func_graph_module.FuncGraph("initializer") 346 347 def variable_capturing_scope(unused_next_creator, **kwds): 348 """Creates UnliftedInitializerVariables and saves references to them.""" 349 v = UnliftedInitializerVariable( 350 add_initializers_to=add_initializers_to, 351 lifted_initializer_graph=lifted_initializer_graph, **kwds) 352 created_variables.append(weakref.ref(v)) 353 return v 354 355 self._created_variables = created_variables 356 self._stateful_fn = self._defun_with_scope(variable_capturing_scope) 357 self._stateful_fn._name = self._name # pylint: disable=protected-access 358 # Force the definition of the function for these arguments 359 self._lifted_initializer_graph = lifted_initializer_graph 360 self._graph_deleter = FunctionDeleter(self._lifted_initializer_graph) 361 self._concrete_stateful_fn = ( 362 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access 363 *args, **kwds)) 364 365 def invalid_creator_scope(*unused_args, **unused_kwds): 366 """Disables variable creation.""" 367 raise ValueError( 368 "tf.function-decorated function tried to create " 369 "variables on non-first call.") 370 371 self._stateless_fn = self._defun_with_scope(invalid_creator_scope) 372 self._stateless_fn._name = self._name # pylint: disable=protected-access 373 374 def _decorate(self, decorator): 375 """Allows the captured Python function to be decorated in place. 376 377 This method is only safe to call when the Function has not been called by a 378 user. It makes sense to use this method to push a decorator into the 379 function rather than wrapping the function in the decorator. 380 381 We use this in tf.Module to allow user annotated `tf.functions` to remain as 382 `Function` objects but still automatically enter the Module name_scope 383 when they are evaluated like all other methods. 384 385 Args: 386 decorator: A callable accepting a single argument which is the function 387 to decorate and returning a callable result. 388 389 Raises: 390 ValueError: If the function has been called a ValueError is raised. 391 """ 392 if self._stateful_fn is not None or self._stateless_fn is not None: 393 raise ValueError( 394 "Functions cannot be decorated after they have been traced.") 395 396 self._python_function = decorator(self._python_function) 397 self._function_spec = function_lib.FunctionSpec.from_function_and_signature( 398 self._python_function, self.input_signature) 399 400 def __call__(self, *args, **kwds): 401 """Calls the graph function.""" 402 if RUN_FUNCTIONS_EAGERLY: 403 return self._python_function(*args, **kwds) 404 if self._created_variables: 405 # In this case we have created variables on the first call, so we run the 406 # defunned version which is guaranteed to never create variables. 407 return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable 408 elif self._stateful_fn is not None: 409 # In this case we have not created variables on the first call. So we can 410 # run the first trace but we should fail if variables are created. 411 results = self._stateful_fn(*args, **kwds) 412 if self._created_variables: 413 raise ValueError("Creating variables on a non-first call to a function" 414 " decorated with tf.function.") 415 return results 416 417 # This is the first call of __call__, so we have to initialize. 418 initializer_map = {} 419 self._initialize(args, kwds, add_initializers_to=initializer_map) 420 if self._created_variables: 421 try: 422 # Attempt to initialize variables eagerly and without conds by lifting 423 # out initialization graphs. This is the only initialization strategy 424 # compatible with XLA at the moment. 425 self._initialize_uninitialized_variables(initializer_map) 426 except lift_to_graph.UnliftableError: 427 pass # Fall through to cond-based initialization. 428 else: 429 # Lifting succeeded, so variables are initialized and we can run the 430 # stateless function. 431 return self._stateless_fn(*args, **kwds) 432 else: 433 canon_args, canon_kwds = \ 434 self._stateful_fn._function_spec.canonicalize_function_inputs( # pylint: disable=protected-access 435 *args, **kwds) 436 # If we did not create any variables the trace we have is good enough. 437 return self._concrete_stateful_fn._filtered_call(canon_args, canon_kwds) # pylint: disable=protected-access 438 439 def fn_with_cond(*inner_args, **inner_kwds): 440 """Conditionally runs initialization if it's needed.""" 441 condition = True 442 for wr in self._created_variables: 443 variable = wr() 444 if variable is None: 445 raise ValueError( 446 "A tf.Variable created inside your tf.function has been" 447 " garbage-collected. Your code needs to keep Python references" 448 " to variables created inside `tf.function`s.\n" 449 "\n" 450 "A common way to raise this error is to create and return a" 451 " variable only referenced inside your function:\n" 452 "\n" 453 "@tf.function\n" 454 "def f():\n" 455 " v = tf.Variable(1.0)\n" 456 " return v\n" 457 "\n" 458 "v = f() # Crashes with this error message!\n" 459 "\n" 460 "The reason this crashes is that @tf.function annotated" 461 " function returns a **`tf.Tensor`** with the **value** of the" 462 " variable when the function is called rather than the" 463 " variable instance itself. As such there is no code holding a" 464 " reference to the `v` created inside the function and Python" 465 " garbage collects it.\n" 466 "\n" 467 "The simplest way to fix this issue is to create variables" 468 " outside the function and capture them:\n" 469 "\n" 470 "v = tf.Variable(1.0)\n" 471 "\n" 472 "@tf.function\n" 473 "def f():\n" 474 " return v\n" 475 "\n" 476 "f() # <tf.Tensor: ... numpy=1.>\n" 477 "v.assign_add(1.)\n" 478 "f() # <tf.Tensor: ... numpy=2.>") 479 condition = math_ops.logical_and( 480 condition, resource_variable_ops.var_is_initialized_op( 481 variable.handle)) 482 # We want to call stateless_fn if possible because it avoids recomputing 483 # potentially expensive initializers. 484 return control_flow_ops.cond( 485 condition, 486 lambda: self._stateless_fn(*inner_args, **inner_kwds), 487 functools.partial(self._concrete_stateful_fn._filtered_call, # pylint: disable=protected-access 488 inner_args, inner_kwds)) 489 490 # We've created variables and are unable to lift the initialization graphs, 491 # so we fall back to initializing with conds while running the function. 492 canon_args, canon_kwds = \ 493 self._stateful_fn._function_spec.canonicalize_function_inputs( # pylint: disable=protected-access 494 *args, **kwds) 495 return function_lib.defun(fn_with_cond)(*canon_args, **canon_kwds) 496 497 @property 498 def python_function(self): 499 """The python function wrapped in this tf.function.""" 500 return self._python_function 501 502 @property 503 def input_signature(self): 504 return self._function_spec.input_signature 505 506 @property 507 def function_spec(self): 508 return self._function_spec 509 510 def _initialize_uninitialized_variables(self, initializer_map): 511 """Make and call a `ConcreteFunction` which initializes variables.""" 512 513 # Note: using defun here avoids an infinite recursion. 514 # Note: there is no reason not to autograph once the overhead is negligible. 515 @function_lib.defun(autograph=False) # tf.function internal, pure graph 516 def initialize_variables(): 517 for v, init in initializer_map.items(): 518 with ops.init_scope(): 519 if resource_variable_ops.var_is_initialized_op(v.handle): 520 # Ignore variables which are already initialized at trace time. 521 continue 522 v.assign(lift_to_graph.lift_to_graph( 523 [init], ops.get_default_graph())[init]) 524 525 with ops.init_scope(): 526 return initialize_variables.get_concrete_function()() 527 528 def get_initialization_function(self, *args, **kwargs): 529 """Returns a `ConcreteFunction` which initializes this function's variables. 530 531 Requires that this function hasn't been accessed yet through either calling 532 it or calling get_concrete_function. Fails if we cannot build an initializer 533 function which does not depend on the concrete values of the inputs to this 534 function. 535 536 Note that running this function will overwrite any values currently assigned 537 to variables, for example restores from a checkpoint. 538 539 Args: 540 *args: arguments to the underlying python callable. 541 **kwargs: keyword arguments to the python callable. 542 543 Returns: 544 A `ConcreteFunction` object which initializes the variables of this 545 function. 546 547 Raises: 548 RuntimeError: if called after the variables have been initialized. 549 """ 550 if self._stateful_fn is not None: 551 raise RuntimeError( 552 "get_initialization_function cannot be called after the function " 553 "has been used") 554 # Here we trace the function, collect the initializers, and attempt to 555 # extract them and run them eagerly. Fail only if we cannot do so. 556 initializer_map = {} 557 self._initialize(args, kwargs, add_initializers_to=initializer_map) 558 559 # Note: using defun here avoids an infinite recursion. 560 @function_lib.defun 561 def initialize_variables(): 562 for v, init in initializer_map.items(): 563 v.assign(lift_to_graph.lift_to_graph( 564 [init], ops.get_default_graph())[init]) 565 566 return initialize_variables.get_concrete_function() 567 568 def _list_all_concrete_functions_for_serialization(self): 569 """Returns all concrete functions for serialization. 570 571 Returns: 572 A list of instances of `Function`. 573 """ 574 if self.input_signature is not None: 575 self.get_concrete_function() 576 concrete_functions = [] 577 # pylint: disable=protected-access 578 if self._stateful_fn: 579 concrete_functions.extend( 580 self._stateful_fn._function_cache.all_values()) 581 if self._stateless_fn: 582 concrete_functions.extend( 583 self._stateless_fn._function_cache.all_values()) 584 # pylint: enable=protected-access 585 deduplicated_concrete_functions = list() 586 seen_signatures = list() 587 # We are using a list so that: 588 # - the returned collection is deterministic, and 589 # - we can use a custom equality operator (is_same_structure). 590 # This is run only at serialization time on likely very small inputs so we 591 # are not concerned about O(n^2) runtime. 592 for concrete_function in concrete_functions: 593 signature, _ = concrete_function.structured_input_signature 594 flattened = nest.flatten(signature) 595 if any( 596 isinstance(arg, func_graph_module.UnknownArgument) 597 for arg in flattened): 598 logging.info("Unsupported signature for serialization: %s.", signature) 599 continue 600 equal_to_signature = functools.partial( 601 function_lib.is_same_structure, signature, check_values=True) 602 if not any(equal_to_signature(s) for s in seen_signatures): 603 deduplicated_concrete_functions.append(concrete_function) 604 seen_signatures.append(signature) 605 return deduplicated_concrete_functions 606 607 def get_concrete_function(self, *args, **kwargs): 608 """Returns a `ConcreteFunction` specialized to inputs and execution context. 609 610 If this `Function` was created with an `input_signature`, `args` and 611 `kwargs` may be omitted. With an input signature there is only one 612 concrete function associated with this `Function`. 613 614 If there is no fixed `input_signature` associated with this 615 `Function`, positional and keyword arguments to `get_concrete_function` 616 follow the same rules as input signature specification, with `tf.TensorSpec` 617 objects describing `tf.Tensor`s which will be passed to the concrete 618 function. 619 620 Each `tf.Tensor` argument to the concrete function must have a unique name, 621 either because it is the only one associated with a named argument of the 622 Python function or because an explicit `name=` was passed to its 623 `tf.TensorSpec` object. These names become the argument names for the 624 concrete function. 625 626 Arguments to the concrete function may always be specified as keyword 627 arguments, naming the Tensor input. Positional arguments may be used instead 628 when each preceding argument to the Python function is a Tensor. 629 630 ```python 631 @tf.function 632 def f(x): 633 return x 634 635 f_concrete = f.get_concrete_function(tf.TensorSpec([], tf.float64)) 636 f_concrete(tf.constant(1.)) 637 f_concrete(x=tf.constant(1.)) 638 ``` 639 640 Nested structures containing Tensors may be specified when retrieving 641 concrete functions. Structures with multiple Tensors are expanded into 642 multiple arguments of the concrete function. Since multiple concrete 643 function arguments are associated with one argument to the original 644 function, these Tensors must be named explicitly. Tensors in nested 645 structures may not be passed using positional arguments when calling the 646 concrete function. 647 648 ```python 649 f_concrete2 = f.get_concrete_function( 650 (tf.TensorSpec(None, tf.float64, name="first"), 651 tf.TensorSpec([], tf.float32, name="second"))) 652 # Keyword arguments are required when identifying Tensors in nested 653 # structures. 654 f_concrete2(first=tf.constant([1.]), second=tf.constant(0.)) 655 ``` 656 657 Functions with fixed input signatures have only one concrete function 658 associated with them, which can be retrieved without specifying any 659 arguments. As before Tensors must have unique names, either inferred from 660 the argument names in the original Python function or specified 661 explicitly. 662 663 ```python 664 @tf.function(input_signature=(tf.TensorSpec(None, tf.float32))) 665 def f_sig(y): 666 return y 667 668 f_sig_concrete = f.get_concrete_function() 669 f_sig_concrete(tf.constant(1.)) 670 f_sig_concrete(y=tf.constant(1.)) 671 ``` 672 673 Args: 674 *args: inputs to specialize on. 675 **kwargs: inputs to specialize on. 676 677 Returns: 678 A TensorFlow function which takes exactly one `tf.Tensor` per argument. 679 680 Raises: 681 ValueError: if this object has not yet been called on concrete values. 682 """ 683 if self._stateful_fn is None: 684 initializer_map = {} 685 self._initialize(args, kwargs, add_initializers_to=initializer_map) 686 self._initialize_uninitialized_variables(initializer_map) 687 688 if self._created_variables: 689 # In this case we have created variables on the first call, so we run the 690 # defunned version which is guaranteed to never create variables. 691 return self._stateless_fn.get_concrete_function(*args, **kwargs) 692 elif self._stateful_fn is not None: 693 # In this case we have not created variables on the first call. So we can 694 # run the first trace but we should fail if variables are created. 695 concrete = self._stateful_fn.get_concrete_function(*args, **kwargs) 696 if self._created_variables: 697 raise ValueError("Creating variables on a non-first call to a function" 698 " decorated with tf.function.") 699 return concrete 700 701 def __get__(self, instance, owner): 702 """Makes it possible to defun instance methods.""" 703 del owner 704 # `instance` here is the instance that this `Function` was accessed through 705 # e.g., for 706 # 707 # class Foo(object): 708 # 709 # @function.defun 710 # def bar(self): 711 # ... 712 # 713 # foo = Foo() 714 # foo.bar() # `foo.bar` is a `Function` instance 715 # 716 # then `instance` will be `foo` (and `owner` will be `Foo`). We create a 717 # new instance of `Function` here to allow different instances each 718 # to create variables once, thereby allowing methods to be decorated with 719 # tf.function. Keeps a cache to avoid retracing the function every time the 720 # descriptor is accessed. 721 if instance not in self._descriptor_cache: 722 if instance is None: 723 return self 724 self._descriptor_cache[instance] = ( 725 function_lib.class_method_to_instance_method(self, instance)) 726 return self._descriptor_cache[instance] 727 728 729@tf_export("function") 730def function(func=None, 731 input_signature=None, 732 autograph=True, 733 experimental_autograph_options=None): 734 """Creates a callable TensorFlow graph from a Python function. 735 736 `function` constructs a callable that executes a TensorFlow graph 737 (`tf.Graph`) created by tracing the TensorFlow operations in `func`. 738 This allows the TensorFlow runtime to apply optimizations and exploit 739 parallelism in the computation defined by `func`. 740 741 _Example Usage_ 742 743 ```python 744 def f(x, y): 745 return tf.reduce_mean(tf.multiply(x ** 2, 3) + y) 746 747 g = tf.function(f) 748 749 x = tf.constant([[2.0, 3.0]]) 750 y = tf.constant([[3.0, -2.0]]) 751 752 # `f` and `g` will return the same value, but `g` will be executed as a 753 # TensorFlow graph. 754 assert f(x, y).numpy() == g(x, y).numpy() 755 756 # Tensors and tf.Variables used by the Python function are captured in the 757 # graph. 758 @tf.function 759 def h(): 760 return f(x, y) 761 762 assert (h().numpy() == f(x, y).numpy()).all() 763 764 # Data-dependent control flow is also captured in the graph. Supported 765 # control flow statements include `if`, `for`, `break`, `continue`, `return`. 766 @tf.function 767 def g(x): 768 if tf.reduce_sum(x) > 0: 769 return x * x 770 else: 771 return -x // 2 772 773 # print and TensorFlow side effects are supported, but exercise caution when 774 # using Python side effects like mutating objects, saving to files, etc. 775 l = [] 776 777 @tf.function 778 def g(x): 779 for i in x: 780 print(i) # Works 781 tf.assign(v, i) # Works 782 tf.py_func(lambda i: l.append(i))(i) # Works 783 l.append(i) # Caution! Doesn't work. 784 ``` 785 786 Note that unlike other TensorFlow operations, we don't convert python 787 numerical inputs to tensors. 788 789 _Referencing `tf.Variable`s_ 790 791 The Python function `func` may reference stateful objects (such as 792 `tf.Variable`). 793 These are captured as implicit inputs to the callable returned by `function`. 794 For example: 795 796 ```python 797 c = tf.Variable(0) 798 799 @tf.function 800 def f(x): 801 c.assign_add(1) 802 return x + tf.to_float(c) 803 804 assert int(c) == 0 805 assert f(1.0) == 2.0 806 assert int(c) == 1 807 assert f(1.0) == 3.0 808 assert int(c) == 2 809 ``` 810 811 `function` can be applied to methods of an object. For example: 812 813 ```python 814 class Dense(object): 815 def __init__(self): 816 self.W = tf.Variable(tf.glorot_uniform_initializer()((10, 10))) 817 self.b = tf.Variable(tf.zeros(10)) 818 819 @tf.function 820 def compute(self, x): 821 return tf.matmul(x, self.W) + self.b 822 823 d1 = Dense() 824 d2 = Dense() 825 x = tf.random_uniform((10, 10)) 826 # d1 and d2 are using distinct variables 827 assert not (d1.compute(x).numpy() == d2.compute(x).numpy()).all() 828 ``` 829 830 _Usage with `tf.keras`_ 831 832 The `call` methods of a `tf.keras.Model` subclass can be decorated with 833 `function` in order to apply graph execution optimizations on it. 834 For example: 835 836 ```python 837 class MyModel(tf.keras.Model): 838 def __init__(self, keep_probability=0.2): 839 super(MyModel, self).__init__() 840 self.dense1 = tf.keras.layers.Dense(4) 841 self.dense2 = tf.keras.layers.Dense(5) 842 self.keep_probability = keep_probability 843 844 @tf.function 845 def call(self, inputs, training=True): 846 y = self.dense2(self.dense1(inputs)) 847 if training: 848 return tf.nn.dropout(y, self.keep_probability) 849 else: 850 return y 851 852 model = MyModel() 853 model(x, training=True) # executes a graph, with dropout 854 model(x, training=False) # executes a graph, without dropout 855 ``` 856 857 _Input Signatures_ 858 859 `function` instantiates a separate graph for every unique set of input 860 shapes and datatypes. For example, the following code snippet will result 861 in three distinct graphs being traced, as each input has a different 862 shape. 863 864 ```python 865 @tf.function 866 def f(x): return tf.add(x, 1.) 867 868 scalar = tf.constant(1.0) 869 vector = tf.constant([1.0, 1.0]) 870 matrix = tf.constant([[3.0]]) 871 872 f(scalar) 873 f(vector) 874 f(matrix) 875 ``` 876 877 An "input signature" can be optionally provided to `function` to control 878 the graphs traced. The input signature specifies the shape and type of each 879 `Tensor` argument to the function using a `tf.TensorSpec` object. For example, 880 the following code snippet ensures that a single graph is created where the 881 input `Tensor` is required to be a floating point tensor with no restrictions 882 on shape. 883 884 ```python 885 @tf.function(input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)]) 886 def f(x): return tf.add(x, 1.) 887 ``` 888 889 When an `input_signature` is specified, the callable will convert the inputs 890 to the specified TensorSpecs. 891 892 _Tracing and staging_ 893 894 When `autograph` is `True`, all Python code that depends on `Tensor` values is 895 staged into a TensorFlow graph. When `autograph` is `False`, the function is 896 traced and control flow is not allowed to depend on data. 897 898 Note that `function` only stages TensorFlow operations, all Python code that 899 `func` executes and does not depend on data will shape the _construction_ of 900 the graph. 901 For example, consider the following: 902 903 ```python 904 import numpy as np 905 906 def add_noise(): 907 return tf.eye(5) + np.random.randn(5, 5) 908 909 traced = tf.function(add_noise) 910 ``` 911 912 `add_noise()` will return a different output every time it is invoked. 913 However, `traced()` will return the same value every time it is called, 914 since a particular random value generated by the `np.random.randn` call will 915 be inserted in the traced/staged TensorFlow graph as a constant. In this 916 particular example, replacing `np.random.randn(5, 5)` with 917 `tf.random_normal((5, 5))` will result in the same behavior for `add_noise()` 918 and `traced()`. 919 920 _Python Side-Effects_ 921 922 A corollary of the previous discussion on tracing is the following: If a 923 Python function `func` has Python side-effects, then executing `func` multiple 924 times may not be semantically equivalent to executing `F = tf.function(func)` 925 multiple times; this difference is due to the fact that `function` only 926 captures the subgraph of TensorFlow operations that is constructed when `func` 927 is invoked to trace a graph. 928 929 The same is true if code with Python side effects is used inside control flow, 930 such as a loop. If your code uses side effects that are not intended to 931 control graph construction, wrap them inside `tf.py_func`. 932 933 Args: 934 func: function to be compiled. If `func` is None, returns a decorator that 935 can be invoked with a single argument - `func`. The end result is 936 equivalent to providing all the arguments up front. In other words, 937 `tf.function(input_signature=...)(func)` is equivalent to 938 `tf.function(func, input_signature=...)`. The former can be used to 939 decorate Python functions, for example: 940 @tf.function(input_signature=...) 941 def foo(...): ... 942 input_signature: A possibly nested sequence of `tf.TensorSpec` objects 943 specifying the shapes and dtypes of the Tensors that will be supplied to 944 this function. If `None`, a separate function is instantiated for each 945 inferred input signature. If input_signature is specified, every input to 946 `func` must be a `Tensor`, and `func` cannot accept `**kwargs`. 947 autograph: Whether autograph should be applied on `func` before tracing a 948 graph. This allows for dynamic control flow (Python if's, loops etc.) 949 in the traced graph. See https://www.tensorflow.org/guide/autograph for 950 more information. 951 experimental_autograph_options: Experimental knobs (in the form of a tuple 952 of tensorflow.autograph.Feature values) to control behavior when 953 autograph=True. 954 955 Returns: 956 If `func` is not None, returns a callable that will execute the compiled 957 function (and return zero or more `tf.Tensor` objects). 958 If `func` is None, returns a decorator that, when invoked with a single 959 `func` argument, returns a callable equivalent to the case above. 960 961 Raises: 962 TypeError: If `input_signature` is neither `None` nor a sequence of 963 `TensorSpec` objects. 964 """ 965 if input_signature is not None: 966 function_lib.validate_signature(input_signature) 967 968 def decorated(inner_function): 969 try: 970 name = inner_function.__name__ 971 except AttributeError: 972 name = "function" 973 return tf_decorator.make_decorator( 974 inner_function, 975 Function( 976 inner_function, 977 name, 978 input_signature=input_signature, 979 autograph=autograph, 980 experimental_autograph_options=experimental_autograph_options)) 981 982 # This code path is for the `foo = tf.function(foo, ...)` use case 983 if func is not None: 984 return decorated(func) 985 986 # This code path is for the 987 # 988 # @tf.function(...) 989 # def foo(...): 990 # ... 991 # 992 # use case, which is equivalent to `foo = tf.function(...)(foo)` 993 return decorated 994