1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16# pylint: disable=redefined-outer-name
17# pylint: disable=redefined-builtin
18"""Keras backend API.
19"""
20from __future__ import absolute_import
21from __future__ import division
22from __future__ import print_function
23
24import collections
25import itertools
26import json
27import os
28import threading
29import weakref
30
31import numpy as np
32
33from tensorflow.core.protobuf import config_pb2
34from tensorflow.python.client import session as session_module
35from tensorflow.python.distribute import distribute_coordinator as dc
36from tensorflow.python.distribute import distribute_coordinator_context as dc_context
37from tensorflow.python.distribute import distribution_strategy_context
38from tensorflow.python.eager import context
39from tensorflow.python.eager import function as eager_function
40from tensorflow.python.eager import lift_to_graph
41from tensorflow.python.framework import constant_op
42from tensorflow.python.framework import dtypes as dtypes_module
43from tensorflow.python.framework import func_graph
44from tensorflow.python.framework import ops
45from tensorflow.python.framework import sparse_tensor
46from tensorflow.python.framework import tensor_util
47from tensorflow.python.keras import backend_config
48from tensorflow.python.ops import array_ops
49from tensorflow.python.ops import clip_ops
50from tensorflow.python.ops import control_flow_ops
51from tensorflow.python.ops import ctc_ops as ctc
52from tensorflow.python.ops import functional_ops
53from tensorflow.python.ops import gradients as gradients_module
54from tensorflow.python.ops import image_ops
55from tensorflow.python.ops import init_ops
56from tensorflow.python.ops import linalg_ops
57from tensorflow.python.ops import logging_ops
58from tensorflow.python.ops import map_fn as map_fn_lib
59from tensorflow.python.ops import math_ops
60from tensorflow.python.ops import nn
61from tensorflow.python.ops import random_ops
62from tensorflow.python.ops import resource_variable_ops
63from tensorflow.python.ops import sparse_ops
64from tensorflow.python.ops import state_ops
65from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
66from tensorflow.python.ops import tensor_array_ops
67from tensorflow.python.ops import variables as variables_module
68from tensorflow.python.training import server_lib
69from tensorflow.python.util import nest
70from tensorflow.python.util import tf_contextlib
71from tensorflow.python.util import tf_inspect
72from tensorflow.python.util.tf_export import keras_export
73
# Aliases for the `all` and `sum` builtins.
# NOTE(review): presumably kept because this module later defines backend
# functions with the same names (see the `redefined-builtin` pylint disable at
# the top of the file) -- confirm against the rest of the file.
py_all = all
py_sum = sum

# INTERNAL UTILS

# The internal graph maintained by Keras and used by the symbolic Keras APIs
# while executing eagerly (such as the functional API for model-building).
# It is created lazily by `get_graph()`.
_GRAPH = None

# A graph which is used for constructing functions in eager mode.
# It is set and reset by the `_scratch_graph()` context manager.
_CURRENT_SCRATCH_GRAPH = None

# This is a thread local object that will hold the default internal TF session
# used by Keras. It can be set manually via `set_session(sess)`.
_SESSION = threading.local()

# This dictionary holds a mapping {graph: learning_phase}.
# A learning phase is either the bool placeholder tensor created by
# `symbolic_learning_phase()` or a fixed integer set through
# `set_learning_phase(value)` (1 == train mode, 0 == test mode).
_GRAPH_LEARNING_PHASES = weakref.WeakKeyDictionary()


# _DUMMY_EAGER_GRAPH is used as a key in _GRAPH_LEARNING_PHASES; the entry
# stored under it is the learning phase used while executing eagerly.
# We keep a separate reference to it to make sure it does not get removed from
# _GRAPH_LEARNING_PHASES.
_DUMMY_EAGER_GRAPH = threading.local()

# This boolean flag can be set to True to leave variable initialization
# up to the user.
# Change its value via `manual_variable_initialization(value)`.
_MANUAL_VAR_INIT = False

# This list holds the available devices.
# It is populated when `_get_available_gpus()` is called for the first time
# outside of eager execution.
# We assume our devices don't change henceforth.
_LOCAL_DEVICES = None

# This dictionary holds a mapping between a graph and variables to initialize
# in the graph. Entries are added by `track_variable(v)`.
_GRAPH_VARIABLES = weakref.WeakKeyDictionary()

# This dictionary holds a mapping between a graph and TF optimizers created in
# the graph. Entries are added by `track_tf_optimizer(tf_optimizer)`.
_GRAPH_TF_OPTIMIZERS = weakref.WeakKeyDictionary()

# The functions below are re-exported from `backend_config` so that they stay
# accessible from the backend module for compatibility.
epsilon = backend_config.epsilon
floatx = backend_config.floatx
image_data_format = backend_config.image_data_format
set_epsilon = backend_config.set_epsilon
set_floatx = backend_config.set_floatx
set_image_data_format = backend_config.set_image_data_format
127
@keras_export('keras.backend.backend')
def backend():
  """Reports the name of the backend Keras is running on.

  This function is kept only so that code written against multi-backend
  Keras keeps working.

  Returns:
      The string "tensorflow".
  """
  backend_name = 'tensorflow'
  return backend_name
138
139
@keras_export('keras.backend.cast_to_floatx')
def cast_to_floatx(x):
  """Converts `x` to a Numpy array of the default Keras float type.

  Arguments:
      x: Numpy array (anything `np.asarray` accepts).

  Returns:
      A Numpy array whose dtype is `floatx()`.

  Example:
  ```python
      >>> from keras import backend as K
      >>> K.floatx()
      'float32'
      >>> arr = numpy.array([1.0, 2.0], dtype='float64')
      >>> new_arr = K.cast_to_floatx(arr)
      >>> new_arr
      array([ 1.,  2.], dtype=float32)
      >>> new_arr.dtype
      dtype('float32')
  ```
  """
  target_dtype = floatx()
  return np.asarray(x, dtype=target_dtype)
166
167
# A global dictionary mapping graph objects to an index of counters used
# for various layer names in each graph.
# It makes it possible to give layers unique autogenerated names in a
# graph-specific way. Graphs are held through weak references.
PER_GRAPH_LAYER_NAME_UIDS = weakref.WeakKeyDictionary()
172
173
@keras_export('keras.backend.get_uid')
def get_uid(prefix=''):
  """Returns the next integer counter value for `prefix` in the current graph.

  Counters are kept per graph (the graph returned by `get_graph()`) and per
  prefix, and start at 1.

  Arguments:
    prefix: String prefix to index.

  Returns:
    Unique integer ID.

  Example:

  ```
    >>> get_uid('dense')
    1
    >>> get_uid('dense')
    2
  ```
  """
  current_graph = get_graph()
  try:
    counters = PER_GRAPH_LAYER_NAME_UIDS[current_graph]
  except KeyError:
    # First request for this graph: start a fresh table of counters.
    counters = collections.defaultdict(int)
    PER_GRAPH_LAYER_NAME_UIDS[current_graph] = counters
  next_uid = counters[prefix] + 1
  counters[prefix] = next_uid
  return next_uid
199
200
@keras_export('keras.backend.reset_uids')
def reset_uids():
  """Resets graph identifiers.

  Drops the layer-name counters of every graph, so that `get_uid` starts
  counting from 1 again.
  """
  PER_GRAPH_LAYER_NAME_UIDS.clear()
209
210
@keras_export('keras.backend.clear_session')
def clear_session():
  """Destroys the current TF graph and creates a new one.

  Useful to avoid clutter from old models / layers.

  In order, this resets the default graph, resets the layer-name counters,
  drops the cached thread-local Keras session, registers a fresh
  `keras_learning_phase` placeholder as the only known learning phase (for
  the graph returned by `get_graph()`), and forgets the variables and
  optimizers tracked for that graph.
  """
  global _SESSION  # pylint: disable=global-variable-not-assigned
  global _GRAPH_LEARNING_PHASES  # pylint: disable=global-variable-not-assigned
  global _GRAPH_VARIABLES  # pylint: disable=global-variable-not-assigned
  global _GRAPH_TF_OPTIMIZERS  # pylint: disable=global-variable-not-assigned
  ops.reset_default_graph()
  reset_uids()
  _SESSION.session = None
  graph = get_graph()
  with graph.as_default():
    # The placeholder is created under the root name scope so that it is
    # always named 'keras_learning_phase'.
    with ops.name_scope(''):
      phase = array_ops.placeholder_with_default(
          False, shape=(), name='keras_learning_phase')
    # Empty the registry in place rather than rebinding the name to a plain
    # dict: rebinding would replace the module-level WeakKeyDictionary, so
    # every graph registered afterwards would be kept alive by a strong
    # reference.
    _GRAPH_LEARNING_PHASES.clear()
    _GRAPH_LEARNING_PHASES[graph] = phase
    _GRAPH_VARIABLES.pop(graph, None)
    _GRAPH_TF_OPTIMIZERS.pop(graph, None)
233
234
@keras_export('keras.backend.manual_variable_initialization')
def manual_variable_initialization(value):
  """Turns manual variable initialization on or off.

  While the flag is false (the default), `get_session()` initializes the
  tracked variables that are not initialized yet before handing out the
  session. While it is true, that step is skipped and initialization is left
  to the user (e.g. via `tf.initialize_all_variables()`).

  Arguments:
      value: Python boolean.
  """
  global _MANUAL_VAR_INIT
  _MANUAL_VAR_INIT = value
250
251
@keras_export('keras.backend.learning_phase')
def learning_phase():
  """Returns the learning phase flag.

  The learning phase flag (0 = test, 1 = train) is meant to be passed as
  input to any Keras function that behaves differently at train time and at
  test time.

  Returns:
      Learning phase: the value registered for the Keras graph (a symbolic
      tensor unless a fixed value was set), or, when executing eagerly, the
      value registered for eager execution, defaulting to the Python
      integer 0.
  """
  if ops.get_default_graph() is _GRAPH:
    # Eager execution may be enabled, but we are inside the Keras workspace
    # graph: do not enter an init_scope for the learning phase.
    return symbolic_learning_phase()
  # The lookup always happens inside the init_scope. Functions & defuns (both
  # in graph & in eager mode) execute non-eagerly using a function-specific
  # default subgraph, so without the init_scope the wrong default graph would
  # be used to look up the learning phase.
  with ops.init_scope():
    if not context.executing_eagerly():
      return symbolic_learning_phase()
    # Nothing registered for eager execution means inference mode.
    return _GRAPH_LEARNING_PHASES.get(_DUMMY_EAGER_GRAPH, 0)
281
282
def symbolic_learning_phase():
  """Returns the learning phase registered for the graph of `get_graph()`.

  If nothing is registered for that graph yet, a boolean
  `keras_learning_phase` placeholder defaulting to False is created in it,
  registered and returned.

  Returns:
      The registered learning phase of the graph.
  """
  keras_graph = get_graph()
  with keras_graph.as_default():
    if keras_graph in _GRAPH_LEARNING_PHASES:
      return _GRAPH_LEARNING_PHASES[keras_graph]
    # Created under the root name scope so the name carries no prefix.
    with ops.name_scope(''):
      default_phase = array_ops.placeholder_with_default(
          False, shape=(), name='keras_learning_phase')
    _GRAPH_LEARNING_PHASES[keras_graph] = default_phase
    return default_phase
292
293
@keras_export('keras.backend.set_learning_phase')
def set_learning_phase(value):
  """Fixes the learning phase to `value`.

  The value is registered for the graph returned by `get_graph()` and, when
  executing eagerly, for eager execution as well.

  Arguments:
      value: Learning phase value, either 0 or 1 (integers).

  Raises:
      ValueError: if `value` is neither `0` nor `1`.
  """
  if value not in {0, 1}:
    raise ValueError('Expected learning phase to be 0 or 1.')
  with ops.init_scope():
    running_eagerly = context.executing_eagerly()
    if running_eagerly:
      # In an eager context, the learning phase value applies to both the
      # eager context and the internal Keras graph.
      _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
    keras_graph = get_graph()
    _GRAPH_LEARNING_PHASES[keras_graph] = value
313
314
def set_eager_learning_phase(value):
  """Internal utility: registers `value` as the eager-only learning phase.

  Unlike `set_learning_phase`, the entry of the Keras graph is left alone.
  Both preconditions are checked with `assert` statements.

  Arguments:
      value: Learning phase value, either 0 or 1 (integers).
  """
  assert value in {0, 1}
  assert context.executing_eagerly()
  _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
325
326
@keras_export('keras.backend.learning_phase_scope')
@tf_contextlib.contextmanager
def learning_phase_scope(value):
  """Context manager that fixes the learning phase to `value` in its body.

  On exit the previously registered values are put back; entries that did
  not exist before the scope was entered are removed again.

  Arguments:
     value: Learning phase value, either 0 or 1 (integers).

  Yields:
    None.

  Raises:
     ValueError: if `value` is neither `0` nor `1`.
  """
  if value not in {0, 1}:
    raise ValueError('Expected learning phase to be 0 or 1.')

  # Remember what is registered right now (None means "no entry").
  with ops.init_scope():
    if context.executing_eagerly():
      saved_eager_phase = _GRAPH_LEARNING_PHASES.get(
          _DUMMY_EAGER_GRAPH, None)
    saved_graph_phase = _GRAPH_LEARNING_PHASES.get(get_graph(), None)

  try:
    set_learning_phase(value)
    yield
  finally:
    # Put the learning phase back to what it was on entry.
    with ops.init_scope():
      if context.executing_eagerly():
        if saved_eager_phase is None:
          _GRAPH_LEARNING_PHASES.pop(_DUMMY_EAGER_GRAPH, None)
        else:
          _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = saved_eager_phase

      keras_graph = get_graph()
      if saved_graph_phase is None:
        _GRAPH_LEARNING_PHASES.pop(keras_graph, None)
      else:
        _GRAPH_LEARNING_PHASES[keras_graph] = saved_graph_phase
370
@tf_contextlib.contextmanager
def eager_learning_phase_scope(value):
  """Internal scope that fixes the learning phase for eager execution only.

  The value returned by `learning_phase()` on entry is written back as the
  eager learning phase on exit.

  Arguments:
      value: Learning phase value, either 0 or 1 (integers).

  Yields:
    None.

  Raises:
     AssertionError: if `value` is neither `0` nor `1`, or if eager
       execution is not enabled (both are checked with `assert`).
  """
  assert value in {0, 1}
  assert context.executing_eagerly()
  phase_on_entry = learning_phase()
  try:
    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = value
    yield
  finally:
    # Put the learning phase back to what it was on entry.
    _GRAPH_LEARNING_PHASES[_DUMMY_EAGER_GRAPH] = phase_on_entry
394
395
def _current_graph(op_input_list):
  """Returns the graph of the members of `op_input_list`, or the current one.

  Thin wrapper that delegates to `ops._get_graph_from_inputs`.
  """
  resolved_graph = ops._get_graph_from_inputs(op_input_list)
  return resolved_graph
399
400
def _get_session(op_input_list=()):
  """Returns the session object for the current thread.

  A default TensorFlow session, if one is installed, always wins. Otherwise
  the session cached in the thread-local `_SESSION` is returned, after
  (re)creating it when it is missing or bound to a graph other than the one
  resolved from `op_input_list`.

  Arguments:
      op_input_list: Optional sequence of tensors or ops used to determine
        the current graph.

  Returns:
      A TensorFlow session.

  Raises:
      RuntimeError: if no default session exists and this is called inside a
        TensorFlow graph function.
  """
  active_session = ops.get_default_session()
  if active_session is not None:
    return active_session

  if ops.inside_function():
    raise RuntimeError('Cannot get session inside Tensorflow graph function.')

  cached_session = getattr(_SESSION, 'session', None)
  # Without a cached session, or with one that does not match the current
  # graph, create and cache a new session.
  if (cached_session is None or
      cached_session.graph is not _current_graph(op_input_list)):
    if distribution_strategy_context.has_strategy():
      # Inside a tf.distribute.Strategy scope, the strategy is asked for the
      # right session options to use.
      configure_and_create_distributed_session(
          distribution_strategy_context.get_strategy())
    else:
      _SESSION.session = session_module.Session(
          config=get_default_session_config())
  return _SESSION.session
424
425
@keras_export(v1=['keras.backend.get_session'])
def get_session(op_input_list=()):
  """Returns the TF session to be used by the backend.

  Resolution order:

  1. a default TensorFlow session, if one is available;
  2. the global Keras session, provided it matches the current graph;
  3. a newly created global Keras session.

  Unless manual variable initialization has been requested through
  `manual_variable_initialization(True)`, tracked variables that are not
  initialized yet get initialized before the session is returned.

  Note that you can manually set the global session
  via `K.set_session(sess)`.

  Arguments:
      op_input_list: An optional sequence of tensors or ops, which will be
        used to determine the current graph. Otherwise the default graph
        will be used.

  Returns:
      A TensorFlow session.
  """
  session = _get_session(op_input_list)
  if _MANUAL_VAR_INIT:
    return session
  with session.graph.as_default():
    _initialize_variables(session)
  return session
454
455
def get_graph():
  """Returns the graph Keras should build symbolic ops in.

  Outside of eager execution this is the default graph. While executing
  eagerly it is the module-level `keras_graph` FuncGraph, which is created
  on first use.
  """
  global _GRAPH
  if not context.executing_eagerly():
    return ops.get_default_graph()
  if _GRAPH is None:
    _GRAPH = func_graph.FuncGraph('keras_graph')
  return _GRAPH
464
465
@tf_contextlib.contextmanager
def _scratch_graph(graph=None):
  """Retrieve a shared and temporary func graph.

  The eager execution path lifts a subgraph from the keras global graph into
  a scratch graph in order to create a function. DistributionStrategies, in
  turn, construct multiple functions as well as a final combined function.
  For that logic to work correctly, all of the functions need to be created
  on the same scratch FuncGraph.

  If a scratch graph is already active it is yielded as is and left active
  on exit. Otherwise `graph` (or a new `keras_scratch_graph` FuncGraph)
  becomes the active scratch graph for the duration of the scope.

  Args:
    graph: A graph to be used as the current scratch graph. If not set then
      a scratch graph will either be retrieved or created.

  Yields:
    The current scratch graph.

  Raises:
    ValueError: if a scratch graph is already active and a different `graph`
      is passed.
  """
  global _CURRENT_SCRATCH_GRAPH
  active_graph = _CURRENT_SCRATCH_GRAPH
  conflicting = (active_graph is not None and graph is not None and
                 active_graph is not graph)
  if conflicting:
    raise ValueError('Multiple scratch graphs specified.')

  if active_graph:
    # Nested use: share the graph that the outer scope installed.
    yield active_graph
    return

  if not graph:
    graph = func_graph.FuncGraph('keras_scratch_graph')
  _CURRENT_SCRATCH_GRAPH = graph
  try:
    yield graph
  finally:
    _CURRENT_SCRATCH_GRAPH = None
498
499
@keras_export('keras.backend.set_session')
def set_session(session):
  """Sets the global TensorFlow session.

  The session is stored on the thread-local `_SESSION` object.

  Arguments:
      session: A TF Session.
  """
  setattr(_SESSION, 'session', session)
509
510
def get_default_session_config():
  """Builds the `ConfigProto` used for sessions created by Keras.

  Soft placement is always allowed. When the `OMP_NUM_THREADS` environment
  variable is set to a non-empty value, it is parsed as an integer and used
  for both the intra-op and the inter-op parallelism thread counts.

  Returns:
      A `config_pb2.ConfigProto` instance.
  """
  omp_threads = os.environ.get('OMP_NUM_THREADS')
  if omp_threads:
    thread_count = int(omp_threads)
    return config_pb2.ConfigProto(
        intra_op_parallelism_threads=thread_count,
        inter_op_parallelism_threads=thread_count,
        allow_soft_placement=True)
  return config_pb2.ConfigProto(allow_soft_placement=True)
521
522
523# DEVICE MANIPULATION
524
525
class _TfDeviceCaptureOp(object):
  """Stand-in op that records the device it is told to use.

  `_get_current_tf_device` passes an instance to the graph's device
  functions and afterwards reads the recorded value from `device`.
  """

  def __init__(self):
    # Stays None until `_set_device` is called.
    self.device = None

  def _set_device(self, device):
    """Records `device`, replacing any previously captured value."""
    self.device = device
535
536
def _get_current_tf_device():
  """Return explicit device of current context, otherwise returns `None`.

  A `_TfDeviceCaptureOp` is run through the device functions of the graph
  returned by `get_graph()`, and whatever device they set on it is returned.

  Returns:
      The captured device if the current device scope is explicitly set
      (`_is_current_explicit_device` reads its `device_type` attribute), or
      `None` if the scope is not explicitly set.
  """
  capture_op = _TfDeviceCaptureOp()
  get_graph()._apply_device_functions(capture_op)
  return capture_op.device
549
550
def _is_current_explicit_device(device_type):
  """Tells whether the current scope is explicitly pinned to `device_type`.

  Arguments:
      device_type: A string containing `GPU` or `CPU` (case-insensitive).

  Returns:
      A boolean indicating if the current device scope is explicitly set on
      the device type. False when no device is explicitly set.

  Raises:
      ValueError: If the `device_type` string indicates an unsupported device.
  """
  requested_type = device_type.upper()
  if requested_type not in ['CPU', 'GPU']:
    raise ValueError('`device_type` should be either "CPU" or "GPU".')
  current_device = _get_current_tf_device()
  if current_device is None:
    return False
  return current_device.device_type == requested_type
569
570
def _get_available_gpus():
  """Get a list of available gpu devices (formatted as strings).

  When executing eagerly outside of functions, the device names are queried
  from the eager context on every call. Otherwise the device list of the
  Keras session is fetched once and cached in `_LOCAL_DEVICES`.

  Returns:
      A list of available GPU devices.
  """
  global _LOCAL_DEVICES
  if ops.executing_eagerly_outside_functions():
    # `list_devices` already returns device names.
    gpu_names = []
    for device_name in context.list_devices():
      if 'GPU' in device_name:
        gpu_names.append(device_name)
    return gpu_names

  if _LOCAL_DEVICES is None:
    _LOCAL_DEVICES = get_session().list_devices()
  return [dev.name for dev in _LOCAL_DEVICES if dev.device_type == 'GPU']
585
586
def _has_nchw_support():
  """Check whether the current scope supports NCHW ops.

  TensorFlow does not support NCHW on CPU. Therefore we check that we are
  not explicitly put on CPU and that GPUs are available. In this case there
  will be soft-placing on the GPU device.

  Returns:
      bool: if the current scope device placement would support nchw
  """
  # Both checks are always performed, in this order.
  pinned_to_cpu = _is_current_explicit_device('CPU')
  has_gpu = bool(_get_available_gpus())
  if pinned_to_cpu:
    return False
  return has_gpu
601
602
603# VARIABLE MANIPULATION
604
605
def _to_tensor(x, dtype):
  """Converts `x` into a tensor of type `dtype`.

  Thin wrapper around `ops.convert_to_tensor`.

  Arguments:
      x: An object to be converted (numpy array, list, tensors).
      dtype: The destination type.

  Returns:
      A tensor.
  """
  converted = ops.convert_to_tensor(x, dtype=dtype)
  return converted
617
618
@keras_export('keras.backend.is_sparse')
def is_sparse(tensor):
  """Tells whether `tensor` is a `SparseTensor` instance.

  Arguments:
      tensor: A tensor instance.

  Returns:
      A boolean.

  Example:
  ```python
      >>> from keras import backend as K
      >>> a = K.placeholder((2, 2), sparse=False)
      >>> print(K.is_sparse(a))
      False
      >>> b = K.placeholder((2, 2), sparse=True)
      >>> print(K.is_sparse(b))
      True
  ```
  """
  sparse_type = sparse_tensor.SparseTensor
  return isinstance(tensor, sparse_type)
641
642
@keras_export('keras.backend.to_dense')
def to_dense(tensor):
  """Returns a dense version of `tensor`.

  Sparse tensors are converted; anything else is returned unchanged.

  Arguments:
      tensor: A tensor instance (potentially sparse).

  Returns:
      A dense tensor.

  Examples:
  ```python
      >>> from keras import backend as K
      >>> b = K.placeholder((2, 2), sparse=True)
      >>> print(K.is_sparse(b))
      True
      >>> c = K.to_dense(b)
      >>> print(K.is_sparse(c))
      False
  ```
  """
  if not is_sparse(tensor):
    return tensor
  return sparse_ops.sparse_tensor_to_dense(tensor)
668
669
# `name_scope` is made available from the backend as a plain alias of
# `ops.name_scope`.
name_scope = ops.name_scope
671
672
@keras_export('keras.backend.variable')
def variable(value, dtype=None, name=None, constraint=None):
  """Instantiates a variable and returns it.

  Values exposing a `tocoo` method are treated as sparse matrices: they are
  returned as a `SparseTensor` built from their COO representation, and
  `dtype`, `name` and `constraint` are not applied. Every other value
  becomes a `ResourceVariable` that is tracked for initialization.

  Arguments:
      value: Numpy array, initial value of the tensor.
      dtype: Tensor type. Defaults to `floatx()`.
      name: Optional name string for the tensor.
      constraint: Optional projection function to be
          applied to the variable after an optimizer update.

  Returns:
      A variable instance (with Keras metadata included).

  Examples:
  ```python
      >>> import numpy as np
      >>> from keras import backend as K
      >>> val = np.array([[1, 2], [3, 4]])
      >>> kvar = K.variable(value=val, dtype='float64', name='example_var')
      >>> K.dtype(kvar)
      'float64'
      >>> print(kvar)
      example_var
      >>> kvar.eval()
      array([[ 1.,  2.],
             [ 3.,  4.]])
  ```
  """
  if dtype is None:
    dtype = floatx()

  if hasattr(value, 'tocoo'):
    coo = value.tocoo()
    # One (row, col) pair per stored element.
    row_column = np.expand_dims(coo.row, 1)
    col_column = np.expand_dims(coo.col, 1)
    coordinates = np.concatenate((row_column, col_column), 1)
    sparse_value = sparse_tensor.SparseTensor(
        indices=coordinates, values=coo.data, dense_shape=coo.shape)
    sparse_value._keras_shape = coo.shape
    return sparse_value

  new_variable = resource_variable_ops.ResourceVariable(
      value,
      dtype=dtypes_module.as_dtype(dtype),
      name=name,
      constraint=constraint)
  # Keras shape metadata is only attached when the value exposes a shape.
  if isinstance(value, np.ndarray):
    new_variable._keras_shape = value.shape
  elif hasattr(value, 'shape'):
    new_variable._keras_shape = int_shape(value)
  track_variable(new_variable)
  return new_variable
723
724
def track_tf_optimizer(tf_optimizer):
  """Tracks the given TF optimizer for initialization of its variables.

  Does nothing while executing eagerly. Otherwise the optimizer is added to
  the weak set kept for the graph returned by `get_graph()`.
  """
  if not context.executing_eagerly():
    tracked_optimizers = _GRAPH_TF_OPTIMIZERS.setdefault(
        get_graph(), weakref.WeakSet())
    tracked_optimizers.add(tf_optimizer)
732
733
def track_variable(v):
  """Tracks the given variable for initialization.

  Does nothing while executing eagerly. Otherwise the variable is added to
  the weak set kept for its own graph (`v.graph`) or, when it has no `graph`
  attribute, for the graph returned by `get_graph()`.
  """
  if context.executing_eagerly():
    return
  if hasattr(v, 'graph'):
    owner_graph = v.graph
  else:
    owner_graph = get_graph()
  _GRAPH_VARIABLES.setdefault(owner_graph, weakref.WeakSet()).add(v)
742
743
def _get_variables(graph=None):
  """Returns variables corresponding to the given graph for initialization.

  The result is the weak set of variables tracked for `graph`, extended in
  place with the variables of every TF optimizer tracked for that graph.

  Arguments:
      graph: The graph whose variables are requested. Defaults to the graph
        returned by `get_graph()`.

  Returns:
      A `weakref.WeakSet` of variables (the set stored in `_GRAPH_VARIABLES`,
      not a copy).
  """
  assert not context.executing_eagerly()
  if graph is None:
    # `None` cannot be weakly referenced, so it can never be used as a key of
    # the weak-key registries below; resolve the default to the current graph
    # instead of failing with a TypeError.
    graph = get_graph()
  variables = _GRAPH_VARIABLES.setdefault(graph, weakref.WeakSet())
  for opt in _GRAPH_TF_OPTIMIZERS.get(graph, ()):
    # Tracked optimizers expose the wrapped TF optimizer as `.optimizer`.
    variables.update(opt.optimizer.variables())
  return variables
751
752
def _initialize_variables(session):
  """Utility to initialize uninitialized variables on the fly.

  Only variables tracked for the graph returned by `get_graph()` and not yet
  flagged with `_keras_initialized` are considered. Every considered variable
  gets the flag, and those that `session` reports as uninitialized are
  initialized with a single `variables_initializer` run.

  Arguments:
      session: The TF session used to query and run the initializers.
  """
  tracked = _get_variables(get_graph())
  pending = [v for v in tracked
             if not getattr(v, '_keras_initialized', False)]
  if not pending:
    return
  # This step is expensive, so it only runs on variables not already marked
  # as initialized.
  init_flags = session.run(
      [variables_module.is_variable_initialized(v) for v in pending])
  needs_init = []
  for already_initialized, var in zip(init_flags, pending):
    if not already_initialized:
      needs_init.append(var)
    var._keras_initialized = True
  if needs_init:
    session.run(variables_module.variables_initializer(needs_init))
772
773
@keras_export('keras.backend.constant')
def constant(value, dtype=None, shape=None, name=None):
  """Creates a constant tensor.

  Arguments:
      value: A constant value (or list)
      dtype: The type of the elements of the resulting tensor. Defaults to
          `floatx()`.
      shape: Optional dimensions of resulting tensor.
      name: Optional name for the tensor.

  Returns:
      A Constant Tensor.
  """
  if dtype is None:
    dtype = floatx()

  building_keras_graph = (
      ops.executing_eagerly_outside_functions() and
      getattr(get_graph(), 'name', '') == 'keras_graph')
  if not building_keras_graph:
    return constant_op.constant(value, dtype=dtype, shape=shape, name=name)

  # The outer context is eager but we are executing under the keras
  # FuncGraph: create EagerTensors and use them as constants.
  with ops.init_scope():
    return constant_op.constant(value, dtype=dtype, shape=shape, name=name)
798
799
def is_keras_tensor(x):
  """Returns whether `x` is a Keras tensor.

  A "Keras tensor" is a tensor that was returned by a Keras layer,
  (`Layer` class) or by `Input`. Such tensors are recognized by their
  `_keras_history` attribute.

  Arguments:
      x: A candidate tensor.

  Returns:
      A boolean: Whether the argument is a Keras tensor.

  Raises:
      ValueError: In case `x` is not a symbolic tensor.

  Examples:
  ```python
      >>> import tensorflow as tf
      >>> import numpy
      >>> from keras import backend as K
      >>> from keras.layers import Input, Dense
      >>> np_var = numpy.array([1, 2])
      >>> K.is_keras_tensor(np_var)  # Not a symbolic tensor.
      ValueError
      >>> k_var = tf.placeholder('float32', shape=(1,1))
      >>> K.is_keras_tensor(k_var)  # Created outside of Keras.
      False
      >>> keras_var = K.variable(np_var)
      >>> K.is_keras_tensor(keras_var)  # A backend variable.
      False
      >>> keras_placeholder = K.placeholder(shape=(2, 4, 5))
      >>> K.is_keras_tensor(keras_placeholder)  # A backend placeholder.
      False
      >>> keras_input = Input([10])
      >>> K.is_keras_tensor(keras_input)  # An Input is a Keras tensor.
      True
      >>> keras_layer_output = Dense(10)(keras_input)
      >>> K.is_keras_tensor(keras_layer_output)  # So is any layer output.
      True
  ```
  """
  symbolic_types = (ops.Tensor,
                    variables_module.Variable,
                    sparse_tensor.SparseTensor)
  if isinstance(x, symbolic_types):
    return hasattr(x, '_keras_history')
  raise ValueError('Unexpectedly found an instance of type `' + str(type(x)) +
                   '`. Expected a symbolic tensor instance.')
851
852
@keras_export('keras.backend.placeholder')
def placeholder(shape=None, ndim=None, dtype=None, sparse=False, name=None):
  """Instantiates a placeholder tensor and returns it.

  The placeholder is created in the graph returned by `get_graph()`.

  Arguments:
      shape: Shape of the placeholder
          (integer tuple, may include `None` entries).
      ndim: Number of axes of the tensor. Only consulted when `shape` is
          `None` or empty, in which case a shape of `ndim` unknown
          dimensions is used.
      dtype: Placeholder type. Defaults to `floatx()`.
      sparse: Boolean, whether the placeholder should have a sparse type.
      name: Optional name string for the placeholder.

  Returns:
      Tensor instance (with Keras metadata included).

  Examples:
  ```python
      >>> from keras import backend as K
      >>> input_ph = K.placeholder(shape=(2, 4, 5))
      >>> K.int_shape(input_ph)
      (2, 4, 5)
  ```
  """
  if dtype is None:
    dtype = floatx()
  if not shape and ndim:
    shape = tuple(None for _ in range(ndim))
  with get_graph().as_default():
    make_placeholder = (
        array_ops.sparse_placeholder if sparse else array_ops.placeholder)
    x = make_placeholder(dtype, shape=shape, name=name)
  return x
892
893
def is_placeholder(x):
  """Returns whether `x` is a placeholder.

  `x` counts as a placeholder when the type of its op is 'Placeholder'.
  Objects without an `op` (or whose op has no `type`) are not placeholders.

  Arguments:
      x: A candidate placeholder.

  Returns:
      Boolean.
  """
  try:
    matches = x.op.type == 'Placeholder'
  except AttributeError:
    matches = False
  return matches
907
908
@keras_export('keras.backend.shape')
def shape(x):
  """Returns the symbolic shape of a tensor or variable.

  Thin wrapper around `array_ops.shape`.

  Arguments:
      x: A tensor or variable.

  Returns:
      A symbolic shape (which is itself a tensor).

  Examples:

  ```python
      # TensorFlow example
      >>> from keras import backend as K
      >>> tf_session = K.get_session()
      >>> val = np.array([[1, 2], [3, 4]])
      >>> kvar = K.variable(value=val)
      >>> input = keras.backend.placeholder(shape=(2, 4, 5))
      # To get integer shape (Instead, you can use K.int_shape(x))
      >>> K.shape(kvar).eval(session=tf_session)
      array([2, 2], dtype=int32)
      >>> K.shape(input).eval(session=tf_session)
      array([2, 4, 5], dtype=int32)
  ```
  """
  symbolic_shape = array_ops.shape(x)
  return symbolic_shape
940
941
@keras_export('keras.backend.int_shape')
def int_shape(x):
  """Returns the static shape of `x` as a tuple of int or None entries.

  Arguments:
      x: Tensor or variable.

  Returns:
      A tuple of integers (or None entries). `None` is returned instead of a
      tuple when reading the shape raises a `ValueError`.

  Examples:
  ```python
      >>> from keras import backend as K
      >>> input = K.placeholder(shape=(2, 4, 5))
      >>> K.int_shape(input)
      (2, 4, 5)
      >>> val = np.array([[1, 2], [3, 4]])
      >>> kvar = K.variable(value=val)
      >>> K.int_shape(kvar)
      (2, 2)
  ```
  """
  try:
    static_shape = x.shape
    # A plain tuple is already in the desired form.
    if isinstance(static_shape, tuple):
      return static_shape
    return tuple(static_shape.as_list())
  except ValueError:
    return None
971
972
@keras_export('keras.backend.ndim')
def ndim(x):
  """Returns the number of axes in a tensor, as an integer.

  Arguments:
      x: Tensor or variable.

  Returns:
      Integer (scalar), number of axes, or `None` when the shape of `x`
      carries no dimension information (`x.shape._dims` is `None`).

  Examples:
  ```python
      >>> from keras import backend as K
      >>> input = K.placeholder(shape=(2, 4, 5))
      >>> val = np.array([[1, 2], [3, 4]])
      >>> kvar = K.variable(value=val)
      >>> K.ndim(input)
      3
      >>> K.ndim(kvar)
      2
  ```
  """
  static_dims = x.shape._dims
  if static_dims is None:
    return None
  return len(static_dims)
999
1000
@keras_export('keras.backend.dtype')
def dtype(x):
  """Returns the name of the base dtype of a Keras tensor or variable.

  Arguments:
      x: Tensor or variable.

  Returns:
      String, dtype of `x`.

  Examples:
  ```python
      >>> from keras import backend as K
      >>> K.dtype(K.placeholder(shape=(2,4,5)))
      'float32'
      >>> K.dtype(K.placeholder(shape=(2,4,5), dtype='float32'))
      'float32'
      >>> K.dtype(K.placeholder(shape=(2,4,5), dtype='float64'))
      'float64'
      # Keras variable
      >>> kvar = K.variable(np.array([[1, 2], [3, 4]]))
      >>> K.dtype(kvar)
      'float32'
      >>> kvar = K.variable(np.array([[1, 2], [3, 4]]), dtype='float32')
      >>> K.dtype(kvar)
      'float32'
  ```
  """
  base = x.dtype.base_dtype
  return base.name
1030
1031
@keras_export('keras.backend.eval')
def eval(x):
  """Evaluates the value of a variable.

  The input is first passed through `to_dense` and the result is fetched
  with `get_value`.

  Arguments:
      x: A variable.

  Returns:
      A Numpy array.

  Examples:
  ```python
      >>> from keras import backend as K
      >>> kvar = K.variable(np.array([[1, 2], [3, 4]]), dtype='float32')
      >>> K.eval(kvar)
      array([[ 1.,  2.],
             [ 3.,  4.]], dtype=float32)
  ```
  """
  dense_x = to_dense(x)
  return get_value(dense_x)
1052
1053
@keras_export('keras.backend.zeros')
def zeros(shape, dtype=None, name=None):
  """Instantiates an all-zeros variable and returns it.

  Arguments:
      shape: Tuple of integers, shape of returned Keras variable
      dtype: String, data type of returned Keras variable. Defaults to
          `floatx()` when `None`.
      name: String, name of returned Keras variable

  Returns:
      A variable (including Keras metadata), filled with `0.0`.
      Note that if `shape` was symbolic, we cannot return a variable,
      and will return a dynamically-shaped tensor instead.

  Example:
  ```python
      >>> from keras import backend as K
      >>> kvar = K.zeros((3,4))
      >>> K.eval(kvar)
      array([[ 0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.],
             [ 0.,  0.,  0.,  0.]], dtype=float32)
  ```
  """
  with ops.init_scope():
    dtype = floatx() if dtype is None else dtype
    tf_dtype = dtypes_module.as_dtype(dtype)
    zeros_tensor = array_ops.zeros(shape=shape, dtype=tf_dtype, name=name)
    if not py_all(zeros_tensor.shape.as_list()):
      # Some static dimension is falsy (e.g. unknown): keep the plain tensor
      # and only register it for tracking.
      track_variable(zeros_tensor)
      return zeros_tensor
    return variable(zeros_tensor, dtype=dtype, name=name)
1087
1088
@keras_export('keras.backend.ones')
def ones(shape, dtype=None, name=None):
  """Instantiates an all-ones variable and returns it.

  Arguments:
      shape: Tuple of integers, shape of returned Keras variable.
      dtype: String, data type of returned Keras variable. Defaults to
          `floatx()` when `None`.
      name: String, name of returned Keras variable.

  Returns:
      A Keras variable, filled with `1.0`.
      Note that if `shape` was symbolic, we cannot return a variable,
      and will return a dynamically-shaped tensor instead.

  Example:
  ```python
      >>> from keras import backend as K
      >>> kvar = K.ones((3,4))
      >>> K.eval(kvar)
      array([[ 1.,  1.,  1.,  1.],
             [ 1.,  1.,  1.,  1.],
             [ 1.,  1.,  1.,  1.]], dtype=float32)
  ```
  """
  with ops.init_scope():
    dtype = floatx() if dtype is None else dtype
    tf_dtype = dtypes_module.as_dtype(dtype)
    ones_tensor = array_ops.ones(shape=shape, dtype=tf_dtype, name=name)
    if not py_all(ones_tensor.shape.as_list()):
      # Some static dimension is falsy (e.g. unknown): keep the plain tensor
      # and only register it for tracking.
      track_variable(ones_tensor)
      return ones_tensor
    return variable(ones_tensor, dtype=dtype, name=name)
1122
1123
@keras_export('keras.backend.eye')
def eye(size, dtype=None, name=None):
  """Instantiates an identity matrix and returns it.

  Arguments:
      size: Integer, number of rows/columns.
      dtype: String, data type of returned Keras variable. Defaults to
          `floatx()` when `None`.
      name: String, name of returned Keras variable.

  Returns:
      A Keras variable, an identity matrix.

  Example:
  ```python
      >>> from keras import backend as K
      >>> kvar = K.eye(3)
      >>> K.eval(kvar)
      array([[ 1.,  0.,  0.],
             [ 0.,  1.,  0.],
             [ 0.,  0.,  1.]], dtype=float32)
  ```

  """
  dtype = floatx() if dtype is None else dtype
  tf_dtype = dtypes_module.as_dtype(dtype)
  identity_matrix = linalg_ops.eye(size, dtype=tf_dtype)
  return variable(identity_matrix, dtype=dtype, name=name)
1151
1152
@keras_export('keras.backend.zeros_like')
def zeros_like(x, dtype=None, name=None):
  """Creates an all-zeros tensor with the same shape as another tensor.

  Arguments:
      x: Keras variable or Keras tensor.
      dtype: String, dtype of the result.
           None uses the dtype of x.
      name: String, name for the created op.

  Returns:
      The result of `array_ops.zeros_like`: zeros with the shape of x.

  Example:
  ```python
      >>> from keras import backend as K
      >>> kvar = K.variable(np.random.random((2,3)))
      >>> kvar_zeros = K.zeros_like(kvar)
      >>> K.eval(kvar_zeros)
      array([[ 0.,  0.,  0.],
             [ 0.,  0.,  0.]], dtype=float32)
  ```
  """
  zeros_tensor = array_ops.zeros_like(x, dtype=dtype, name=name)
  return zeros_tensor
1177
1178
@keras_export('keras.backend.ones_like')
def ones_like(x, dtype=None, name=None):
  """Creates an all-ones tensor with the same shape as another tensor.

  Arguments:
      x: Keras variable or tensor.
      dtype: String, dtype of the result.
           None uses the dtype of x.
      name: String, name for the created op.

  Returns:
      The result of `array_ops.ones_like`: ones with the shape of x.

  Example:
  ```python
      >>> from keras import backend as K
      >>> kvar = K.variable(np.random.random((2,3)))
      >>> kvar_ones = K.ones_like(kvar)
      >>> K.eval(kvar_ones)
      array([[ 1.,  1.,  1.],
             [ 1.,  1.,  1.]], dtype=float32)
  ```
  """
  ones_tensor = array_ops.ones_like(x, dtype=dtype, name=name)
  return ones_tensor
1203
1204
def identity(x, name=None):
  """Returns a tensor with the same content as the input tensor.

  Arguments:
      x: The input tensor.
      name: String, name for the created op.

  Returns:
      A tensor of the same shape, type and content.
  """
  same_as_x = array_ops.identity(x, name=name)
  return same_as_x
1216
1217
@keras_export('keras.backend.random_uniform_variable')
def random_uniform_variable(shape, low, high, dtype=None, name=None, seed=None):
  """Instantiates a variable with values drawn from a uniform distribution.

  Arguments:
      shape: Tuple of integers, shape of returned Keras variable.
      low: Float, lower boundary of the output interval.
      high: Float, upper boundary of the output interval.
      dtype: String, dtype of returned Keras variable. Defaults to
          `floatx()` when `None`.
      name: String, name of returned Keras variable.
      seed: Integer, random seed. When `None`, a seed is drawn from the
          Numpy RNG.

  Returns:
      A Keras variable, filled with drawn samples.

  Example:
  ```python
      # TensorFlow example
      >>> kvar = K.random_uniform_variable((2,3), 0, 1)
      >>> kvar
      <tensorflow.python.ops.variables.Variable object at 0x10ab40b10>
      >>> K.eval(kvar)
      array([[ 0.10940075,  0.10047495,  0.476143  ],
             [ 0.66137183,  0.00869417,  0.89220798]], dtype=float32)
  ```
  """
  dtype = floatx() if dtype is None else dtype
  tf_dtype = dtypes_module.as_dtype(dtype)
  if seed is None:
    # Tie the randomness to the Numpy RNG so seeding Numpy seeds this too.
    seed = np.random.randint(10e8)
  initializer = init_ops.random_uniform_initializer(
      low, high, dtype=tf_dtype, seed=seed)
  initial_value = initializer(shape)
  return variable(initial_value, dtype=dtype, name=name)
1253
1254
@keras_export('keras.backend.random_normal_variable')
def random_normal_variable(shape, mean, scale, dtype=None, name=None,
                           seed=None):
  """Instantiates a variable with values drawn from a normal distribution.

  Arguments:
      shape: Tuple of integers, shape of returned Keras variable.
      mean: Float, mean of the normal distribution.
      scale: Float, standard deviation of the normal distribution.
      dtype: String, dtype of returned Keras variable. Defaults to
          `floatx()` when `None`.
      name: String, name of returned Keras variable.
      seed: Integer, random seed. When `None`, a seed is drawn from the
          Numpy RNG.

  Returns:
      A Keras variable, filled with drawn samples.

  Example:
  ```python
      # TensorFlow example
      >>> kvar = K.random_normal_variable((2,3), 0, 1)
      >>> kvar
      <tensorflow.python.ops.variables.Variable object at 0x10ab12dd0>
      >>> K.eval(kvar)
      array([[ 1.19591331,  0.68685907, -0.63814116],
             [ 0.92629528,  0.28055015,  1.70484698]], dtype=float32)
  ```
  """
  dtype = floatx() if dtype is None else dtype
  tf_dtype = dtypes_module.as_dtype(dtype)
  if seed is None:
    # Tie the randomness to the Numpy RNG so seeding Numpy seeds this too.
    seed = np.random.randint(10e8)
  initializer = init_ops.random_normal_initializer(
      mean, scale, dtype=tf_dtype, seed=seed)
  initial_value = initializer(shape)
  return variable(initial_value, dtype=dtype, name=name)
1291
1292
@keras_export('keras.backend.count_params')
def count_params(x):
  """Returns the static number of elements in a variable or tensor.

  The count is the product of the entries of the static shape of `x`.

  Arguments:
      x: Variable or tensor.

  Returns:
      Integer, the number of scalars in `x`.

  Example:
  ```python
      >>> kvar = K.zeros((2,3))
      >>> K.count_params(kvar)
      6
      >>> K.eval(kvar)
      array([[ 0.,  0.,  0.],
             [ 0.,  0.,  0.]], dtype=float32)
  ```
  """
  static_dims = x.shape.as_list()
  return np.prod(static_dims)
1314
1315
@keras_export('keras.backend.cast')
def cast(x, dtype):
  """Casts a tensor to a different dtype and returns it.

  Casting a Keras variable is allowed, but the result is a Keras tensor.

  Arguments:
      x: Keras tensor (or variable).
      dtype: String, either (`'float16'`, `'float32'`, or `'float64'`).

  Returns:
      Keras tensor with dtype `dtype`.

  Example:
  ```python
      >>> from keras import backend as K
      >>> input = K.placeholder((2, 3), dtype='float32')
      >>> input
      <tf.Tensor 'Placeholder_2:0' shape=(2, 3) dtype=float32>
      # The cast is not in-place, as shown below.
      >>> K.cast(input, dtype='float16')
      <tf.Tensor 'Cast_1:0' shape=(2, 3) dtype=float16>
      >>> input
      <tf.Tensor 'Placeholder_2:0' shape=(2, 3) dtype=float32>
      # You need to assign the result.
      >>> input = K.cast(input, dtype='float16')
      >>> input
      <tf.Tensor 'Cast_2:0' shape=(2, 3) dtype=float16>
  ```
  """
  casted = math_ops.cast(x, dtype)
  return casted
1347
1348
1349# UPDATES OPS
1350
1351
@keras_export('keras.backend.update')
def update(x, new_x):
  """Update the value of `x` to `new_x`.

  Arguments:
      x: A Variable.
      new_x: The value to assign to `x`.

  Returns:
      The result of `state_ops.assign(x, new_x)`.
  """
  assigned = state_ops.assign(x, new_x)
  return assigned
1355
1356
@keras_export('keras.backend.update_add')
def update_add(x, increment):
  """Adds `increment` to the value of `x`.

  Arguments:
      x: A Variable.
      increment: A tensor of same shape as `x`.

  Returns:
      The variable `x` updated.
  """
  incremented = state_ops.assign_add(x, increment)
  return incremented
1369
1370
@keras_export('keras.backend.update_sub')
def update_sub(x, decrement):
  """Subtracts `decrement` from the value of `x`.

  Arguments:
      x: A Variable.
      decrement: A tensor of same shape as `x`.

  Returns:
      The variable `x` updated.
  """
  decremented = state_ops.assign_sub(x, decrement)
  return decremented
1383
1384
@keras_export('keras.backend.moving_average_update')
def moving_average_update(x, value, momentum):
  """Compute the moving average of a variable.

  Delegates to `moving_averages.assign_moving_average` with
  `zero_debias=True`.

  Arguments:
      x: A Variable.
      value: A tensor with the same shape as `x`.
      momentum: The moving average momentum.

  Returns:
      An Operation to update the variable.
  """
  # The import is deferred to call time: `training` sits above the Keras
  # backend in the abstraction hierarchy (it depends on layers, and thus on
  # Keras). moving_averages, being low-level ops, should not be part of the
  # training module.
  from tensorflow.python.training import moving_averages  # pylint: disable=g-import-not-at-top
  update_op = moving_averages.assign_moving_average(
      x, value, momentum, zero_debias=True)
  return update_op
1404
1405
1406# LINEAR ALGEBRA
1407
1408
@keras_export('keras.backend.dot')
def dot(x, y):
  """Multiplies 2 tensors (and/or variables) and returns a *tensor*.

  When attempting to multiply a nD tensor
  with a nD tensor, it reproduces the Theano behavior.
  (e.g. `(2, 3) * (4, 3, 5) -> (2, 4, 5)`)

  Arguments:
      x: Tensor or variable.
      y: Tensor or variable.

  Returns:
      A tensor, dot product of `x` and `y`.

  Examples:
  ```python
      # dot product between tensors
      >>> x = K.placeholder(shape=(2, 3))
      >>> y = K.placeholder(shape=(3, 4))
      >>> xy = K.dot(x, y)
      >>> xy
      <tf.Tensor 'MatMul_9:0' shape=(2, 4) dtype=float32>
  ```

  ```python
      # dot product between tensors
      >>> x = K.placeholder(shape=(32, 28, 3))
      >>> y = K.placeholder(shape=(3, 4))
      >>> xy = K.dot(x, y)
      >>> xy
      <tf.Tensor 'MatMul_9:0' shape=(32, 28, 4) dtype=float32>
  ```

  ```python
      # Theano-like behavior example
      >>> x = K.random_uniform_variable(shape=(2, 3), low=0, high=1)
      >>> y = K.ones((4, 3, 5))
      >>> xy = K.dot(x, y)
      >>> K.int_shape(xy)
      (2, 4, 5)
  ```
  """
  x_rank = ndim(x)
  if x_rank is not None and (x_rank > 2 or ndim(y) > 2):
    # Per-axis sizes: the static size where known, otherwise the matching
    # entry of the dynamic shape tensor.
    x_dims = tuple(
        static if static is not None else dynamic
        for static, dynamic in zip(
            int_shape(x), array_ops.unstack(array_ops.shape(x))))
    y_dims = tuple(
        static if static is not None else dynamic
        for static, dynamic in zip(
            int_shape(y), array_ops.unstack(array_ops.shape(y))))
    # Move the second-to-last axis of `y` (the contracted one) to the front.
    y_perm = list(range(ndim(y)))
    y_perm.insert(0, y_perm.pop(-2))
    # Flatten both operands to 2-D, multiply, then restore the outer axes.
    x_flat = array_ops.reshape(x, [-1, x_dims[-1]])
    y_flat = array_ops.reshape(
        array_ops.transpose(y, perm=y_perm), [y_dims[-2], -1])
    product = math_ops.matmul(x_flat, y_flat)
    out_dims = x_dims[:-1] + y_dims[:-2] + y_dims[-1:]
    return array_ops.reshape(product, out_dims)
  if is_sparse(x):
    return sparse_ops.sparse_tensor_dense_matmul(x, y)
  return math_ops.matmul(x, y)
1479
1480
@keras_export('keras.backend.batch_dot')
def batch_dot(x, y, axes=None):
  """Batchwise dot product.

  `batch_dot` is used to compute dot product of `x` and `y` when
  `x` and `y` are data in batch, i.e. in a shape of
  `(batch_size, :)`.
  `batch_dot` results in a tensor or variable with fewer dimensions
  than the input. If the number of dimensions is reduced to 1,
  we use `expand_dims` to make sure that ndim is at least 2.

  Arguments:
      x: Keras tensor or variable with `ndim >= 2`.
      y: Keras tensor or variable with `ndim >= 2`.
      axes: list of (or single) int with target dimensions.
          A single int `a` is treated as `(a, a)`. When `None`, it defaults
          to `[ndim(x) - 1, ndim(y) - 2]`.
          The lengths of `axes[0]` and `axes[1]` should be the same.

  Returns:
      A tensor with shape equal to the concatenation of `x`'s shape
      (less the dimension that was summed over) and `y`'s shape
      (less the batch dimension and the dimension that was summed over).
      If the final rank is 1, we reshape it to `(batch_size, 1)`.

  Examples:
      Assume `x = [[1, 2], [3, 4]]` and `y = [[5, 6], [7, 8]]`
      `batch_dot(x, y, axes=1) = [[17], [53]]` which is the main diagonal
      of `x.dot(y.T)`, although we never have to calculate the off-diagonal
      elements.

      Shape inference:
      Let `x`'s shape be `(100, 20)` and `y`'s shape be `(100, 30, 20)`.
      If `axes` is (1, 2), to find the output shape of resultant tensor,
          loop through each dimension in `x`'s shape and `y`'s shape:

      * `x.shape[0]` : 100 : append to output shape
      * `x.shape[1]` : 20 : do not append to output shape,
          dimension 1 of `x` has been summed over. (`dot_axes[0]` = 1)
      * `y.shape[0]` : 100 : do not append to output shape,
          always ignore first dimension of `y`
      * `y.shape[1]` : 30 : append to output shape
      * `y.shape[2]` : 20 : do not append to output shape,
          dimension 2 of `y` has been summed over. (`dot_axes[1]` = 2)
      `output_shape` = `(100, 30)`

  ```python
      >>> x_batch = K.ones(shape=(32, 20, 1))
      >>> y_batch = K.ones(shape=(32, 30, 20))
      >>> xy_batch_dot = K.batch_dot(x_batch, y_batch, axes=[1, 2])
      >>> K.int_shape(xy_batch_dot)
      (32, 1, 30)
  ```
  """
  # A single int means "contract the same axis on both operands".
  if isinstance(axes, int):
    axes = (axes, axes)
  # Remember the original ranks; `x`/`y` may be reshaped below.
  x_ndim = ndim(x)
  y_ndim = ndim(y)
  if axes is None:
    # behaves like tf.batch_matmul as default
    axes = [x_ndim - 1, y_ndim - 2]
  # Equalize ranks by appending `diff` trailing size-1 axes to the
  # lower-rank operand.
  if x_ndim > y_ndim:
    diff = x_ndim - y_ndim
    y = array_ops.reshape(y,
                          array_ops.concat(
                              [array_ops.shape(y), [1] * (diff)], axis=0))
  elif y_ndim > x_ndim:
    diff = y_ndim - x_ndim
    x = array_ops.reshape(x,
                          array_ops.concat(
                              [array_ops.shape(x), [1] * (diff)], axis=0))
  else:
    diff = 0
  if ndim(x) == 2 and ndim(y) == 2:
    # 2-D operands: element-wise multiply, then sum over the contracted axis
    # (transposing `x` first when the two axes differ).
    if axes[0] == axes[1]:
      out = math_ops.reduce_sum(math_ops.multiply(x, y), axes[0])
    else:
      out = math_ops.reduce_sum(
          math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
  else:
    # Higher-rank operands: use matmul, taking the adjoint of `x` unless its
    # contracted axis is the last one, and the adjoint of `y` when its
    # contracted axis is the last one.
    adj_x = None if axes[0] == ndim(x) - 1 else True
    adj_y = True if axes[1] == ndim(y) - 1 else None
    out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
  if diff:
    # Squeeze out `diff` axes of the result, starting at `idx`, to undo the
    # rank padding applied above.
    if x_ndim > y_ndim:
      idx = x_ndim + y_ndim - 3
    else:
      idx = x_ndim - 1
    out = array_ops.squeeze(out, list(range(idx, idx + diff)))
  # Never return a rank-1 result: add an axis at position 1.
  if ndim(out) == 1:
    out = expand_dims(out, 1)
  return out
1571
1572
@keras_export('keras.backend.transpose')
def transpose(x):
  """Transposes a tensor and returns it.

  Arguments:
      x: Tensor or variable.

  Returns:
      A tensor.

  Examples:
  ```python
      >>> var = K.variable([[1, 2, 3], [4, 5, 6]])
      >>> K.eval(var)
      array([[ 1.,  2.,  3.],
             [ 4.,  5.,  6.]], dtype=float32)
      >>> var_transposed = K.transpose(var)
      >>> K.eval(var_transposed)
      array([[ 1.,  4.],
             [ 2.,  5.],
             [ 3.,  6.]], dtype=float32)
  ```

  ```python
      >>> input = K.placeholder((2, 3))
      >>> input
      <tf.Tensor 'Placeholder_11:0' shape=(2, 3) dtype=float32>
      >>> input_transposed = K.transpose(input)
      >>> input_transposed
      <tf.Tensor 'transpose_4:0' shape=(3, 2) dtype=float32>

  ```
  """
  transposed = array_ops.transpose(x)
  return transposed
1607
1608
@keras_export('keras.backend.gather')
def gather(reference, indices):
  """Retrieves the elements of indices `indices` in the tensor `reference`.

  Arguments:
      reference: A tensor.
      indices: An integer tensor of indices.

  Returns:
      A tensor of same type as `reference`.
  """
  gathered = array_ops.gather(reference, indices)
  return gathered
1621
1622
1623# ELEMENT-WISE OPERATIONS
1624
1625
@keras_export('keras.backend.max')
def max(x, axis=None, keepdims=False):
  """Maximum value in a tensor.

  Arguments:
      x: A tensor or variable.
      axis: An integer, the axis to find maximum values.
      keepdims: A boolean, whether to keep the dimensions or not.
          If `keepdims` is `False`, the rank of the tensor is reduced
          by 1. If `keepdims` is `True`,
          the reduced dimension is retained with length 1.

  Returns:
      A tensor with maximum values of `x`.
  """
  return math_ops.reduce_max(x, axis=axis, keepdims=keepdims)
1642
1643
@keras_export('keras.backend.min')
def min(x, axis=None, keepdims=False):
  """Minimum value in a tensor.

  Arguments:
      x: A tensor or variable.
      axis: An integer, the axis to find minimum values.
      keepdims: A boolean, whether to keep the dimensions or not.
          If `keepdims` is `False`, the rank of the tensor is reduced
          by 1. If `keepdims` is `True`,
          the reduced dimension is retained with length 1.

  Returns:
      A tensor with minimum values of `x`.
  """
  return math_ops.reduce_min(x, axis=axis, keepdims=keepdims)
1660
1661
@keras_export('keras.backend.sum')
def sum(x, axis=None, keepdims=False):
  """Sum of the values in a tensor, alongside the specified axis.

  Arguments:
      x: A tensor or variable.
      axis: An integer, the axis to sum over.
      keepdims: A boolean, whether to keep the dimensions or not.
          If `keepdims` is `False`, the rank of the tensor is reduced
          by 1. If `keepdims` is `True`,
          the reduced dimension is retained with length 1.

  Returns:
      A tensor with sum of `x`.
  """
  return math_ops.reduce_sum(x, axis=axis, keepdims=keepdims)
1678
1679
@keras_export('keras.backend.prod')
def prod(x, axis=None, keepdims=False):
  """Multiplies the values in a tensor, alongside the specified axis.

  Arguments:
      x: A tensor or variable.
      axis: An integer, the axis to compute the product.
      keepdims: A boolean, whether to keep the dimensions or not.
          If `keepdims` is `False`, the rank of the tensor is reduced
          by 1. If `keepdims` is `True`,
          the reduced dimension is retained with length 1.

  Returns:
      A tensor with the product of elements of `x`.
  """
  return math_ops.reduce_prod(x, axis=axis, keepdims=keepdims)
1696
1697
@keras_export('keras.backend.cumsum')
def cumsum(x, axis=0):
  """Cumulative sum of the values in a tensor, alongside the specified axis.

  Arguments:
      x: A tensor or variable.
      axis: An integer, the axis to compute the sum. Defaults to 0.

  Returns:
      A tensor of the cumulative sum of values of `x` along `axis`.
  """
  running_sum = math_ops.cumsum(x, axis=axis)
  return running_sum
1710
1711
@keras_export('keras.backend.cumprod')
def cumprod(x, axis=0):
  """Cumulative product of the values in a tensor, alongside the specified axis.

  Arguments:
      x: A tensor or variable.
      axis: An integer, the axis to compute the product. Defaults to 0.

  Returns:
      A tensor of the cumulative product of values of `x` along `axis`.
  """
  running_product = math_ops.cumprod(x, axis=axis)
  return running_product
1724
1725
@keras_export('keras.backend.var')
def var(x, axis=None, keepdims=False):
  """Variance of a tensor, alongside the specified axis.

  Boolean inputs are cast to `floatx()` before the reduction.

  Arguments:
      x: A tensor or variable.
      axis: An integer, the axis to compute the variance.
      keepdims: A boolean, whether to keep the dimensions or not.
          If `keepdims` is `False`, the rank of the tensor is reduced
          by 1. If `keepdims` is `True`,
          the reduced dimension is retained with length 1.

  Returns:
      A tensor with the variance of elements of `x`.
  """
  is_bool = x.dtype.base_dtype == dtypes_module.bool
  if is_bool:
    x = math_ops.cast(x, floatx())
  return math_ops.reduce_variance(x, axis=axis, keepdims=keepdims)
1744
1745
@keras_export('keras.backend.std')
def std(x, axis=None, keepdims=False):
  """Standard deviation of a tensor, alongside the specified axis.

  Boolean inputs are cast to `floatx()` before the reduction.

  Arguments:
      x: A tensor or variable.
      axis: An integer, the axis to compute the standard deviation.
      keepdims: A boolean, whether to keep the dimensions or not.
          If `keepdims` is `False`, the rank of the tensor is reduced
          by 1. If `keepdims` is `True`,
          the reduced dimension is retained with length 1.

  Returns:
      A tensor with the standard deviation of elements of `x`.
  """
  is_bool = x.dtype.base_dtype == dtypes_module.bool
  if is_bool:
    x = math_ops.cast(x, floatx())
  return math_ops.reduce_std(x, axis=axis, keepdims=keepdims)
1764
1765
@keras_export('keras.backend.mean')
def mean(x, axis=None, keepdims=False):
  """Mean of a tensor, alongside the specified axis.

  Boolean inputs are cast to `floatx()` before the reduction.

  Arguments:
      x: A tensor or variable.
      axis: A list of integer. Axes to compute the mean.
      keepdims: A boolean, whether to keep the dimensions or not.
          If `keepdims` is `False`, the rank of the tensor is reduced
          by 1 for each entry in `axis`. If `keepdims` is `True`,
          the reduced dimensions are retained with length 1.

  Returns:
      A tensor with the mean of elements of `x`.
  """
  is_bool = x.dtype.base_dtype == dtypes_module.bool
  if is_bool:
    x = math_ops.cast(x, floatx())
  return math_ops.reduce_mean(x, axis=axis, keepdims=keepdims)
1784
1785
@keras_export('keras.backend.any')
def any(x, axis=None, keepdims=False):
  """Logical OR reduction.

  `x` is cast to bool and then reduced with `math_ops.reduce_any`.

  Arguments:
      x: Tensor or variable.
      axis: axis along which to perform the reduction.
      keepdims: whether to drop or broadcast the reduction axes.

  Returns:
      The tensor produced by `math_ops.reduce_any` on the bool-cast input.
  """
  as_bool = math_ops.cast(x, dtypes_module.bool)
  return math_ops.reduce_any(as_bool, axis=axis, keepdims=keepdims)
1800
1801
@keras_export('keras.backend.all')
def all(x, axis=None, keepdims=False):
  """Logical AND reduction.

  `x` is cast to bool and then reduced with `math_ops.reduce_all`.

  Arguments:
      x: Tensor or variable.
      axis: axis along which to perform the reduction.
      keepdims: whether to drop or broadcast the reduction axes.

  Returns:
      The tensor produced by `math_ops.reduce_all` on the bool-cast input.
  """
  as_bool = math_ops.cast(x, dtypes_module.bool)
  return math_ops.reduce_all(as_bool, axis=axis, keepdims=keepdims)
1816
1817
@keras_export('keras.backend.argmax')
def argmax(x, axis=-1):
  """Returns the index of the maximum value along an axis.

  Arguments:
      x: Tensor or variable.
      axis: axis along which to perform the reduction. Defaults to -1.

  Returns:
      A tensor.
  """
  max_index = math_ops.argmax(x, axis)
  return max_index
1830
1831
@keras_export('keras.backend.argmin')
def argmin(x, axis=-1):
  """Returns the index of the minimum value along an axis.

  Arguments:
      x: Tensor or variable.
      axis: axis along which to perform the reduction. Defaults to -1.

  Returns:
      A tensor.
  """
  min_index = math_ops.argmin(x, axis)
  return min_index
1844
1845
@keras_export('keras.backend.square')
def square(x):
  """Squares each element of `x`.

  Arguments:
      x: Tensor or variable.

  Returns:
      A tensor.
  """
  squared = math_ops.square(x)
  return squared
1857
1858
@keras_export('keras.backend.abs')
def abs(x):
  """Takes the absolute value of each element of `x`.

  Arguments:
      x: Tensor or variable.

  Returns:
      A tensor.
  """
  magnitude = math_ops.abs(x)
  return magnitude
1870
1871
@keras_export('keras.backend.sqrt')
def sqrt(x):
  """Element-wise square root.

  The input is first clipped to the range `[0, inf]`, then the square root
  is taken.

  Arguments:
      x: Tensor or variable.

  Returns:
      A tensor.
  """
  lower = _to_tensor(0., x.dtype.base_dtype)
  upper = _to_tensor(np.inf, x.dtype.base_dtype)
  clipped = clip_ops.clip_by_value(x, lower, upper)
  return math_ops.sqrt(clipped)
1886
1887
@keras_export('keras.backend.exp')
def exp(x):
  """Computes the exponential of each element of `x`.

  Arguments:
      x: Tensor or variable.

  Returns:
      A tensor.
  """
  exponential = math_ops.exp(x)
  return exponential
1899
1900
@keras_export('keras.backend.log')
def log(x):
  """Computes the log of each element of `x`.

  Arguments:
      x: Tensor or variable.

  Returns:
      A tensor.
  """
  logarithm = math_ops.log(x)
  return logarithm
1912
1913
def logsumexp(x, axis=None, keepdims=False):
  """Computes log(sum(exp(elements across dimensions of a tensor))).

  This function is more numerically stable than log(sum(exp(x))).
  It avoids overflows caused by taking the exp of large inputs and
  underflows caused by taking the log of small inputs.

  Arguments:
      x: A tensor or variable.
      axis: An integer, the axis to reduce over.
      keepdims: A boolean, whether to keep the dimensions or not.
          If `keepdims` is `False`, the rank of the tensor is reduced
          by 1. If `keepdims` is `True`, the reduced dimension is
          retained with length 1.

  Returns:
      The reduced tensor.
  """
  return math_ops.reduce_logsumexp(x, axis=axis, keepdims=keepdims)
1933
1934
@keras_export('keras.backend.round')
def round(x):
  """Rounds each element of `x` to the closest integer.

  In case of tie, the rounding mode used is "half to even".

  Arguments:
      x: Tensor or variable.

  Returns:
      A tensor.
  """
  rounded = math_ops.round(x)
  return rounded
1948
1949
@keras_export('keras.backend.sign')
def sign(x):
  """Computes the sign of `x` element-wise.

  Arguments:
      x: Tensor or variable.

  Returns:
      A tensor.
  """
  signs = math_ops.sign(x)
  return signs
1961
1962
@keras_export('keras.backend.pow')
def pow(x, a):
  """Raises each element of `x` to the power `a`.

  Arguments:
      x: Tensor or variable.
      a: Python integer.

  Returns:
      A tensor.
  """
  powered = math_ops.pow(x, a)
  return powered
1975
1976
@keras_export('keras.backend.clip')
def clip(x, min_value, max_value):
  """Clips the values of `x` element-wise to `[min_value, max_value]`.

  A `max_value` of `None` means "no upper bound" (infinity is used). A
  `max_value` smaller than `min_value` is raised to `min_value`.

  Arguments:
      x: Tensor or variable.
      min_value: Python float or integer.
      max_value: Python float or integer.

  Returns:
      A tensor.
  """
  if max_value is None:
    max_value = np.inf
  elif max_value < min_value:
    max_value = min_value
  # Both bounds are converted to the base dtype of `x` before clipping.
  lower = _to_tensor(min_value, x.dtype.base_dtype)
  upper = _to_tensor(max_value, x.dtype.base_dtype)
  return clip_ops.clip_by_value(x, lower, upper)
1996
1997
@keras_export('keras.backend.equal')
def equal(x, y):
  """Computes the truth value of `(x == y)` element-wise.

  Arguments:
      x: Tensor or variable.
      y: Tensor or variable.

  Returns:
      A bool tensor.
  """
  mask = math_ops.equal(x, y)
  return mask
2010
2011
@keras_export('keras.backend.not_equal')
def not_equal(x, y):
  """Computes the truth value of `(x != y)` element-wise.

  Arguments:
      x: Tensor or variable.
      y: Tensor or variable.

  Returns:
      A bool tensor.
  """
  mask = math_ops.not_equal(x, y)
  return mask
2024
2025
@keras_export('keras.backend.greater')
def greater(x, y):
  """Computes the truth value of `(x > y)` element-wise.

  Arguments:
      x: Tensor or variable.
      y: Tensor or variable.

  Returns:
      A bool tensor.
  """
  mask = math_ops.greater(x, y)
  return mask
2038
2039
@keras_export('keras.backend.greater_equal')
def greater_equal(x, y):
  """Computes the truth value of `(x >= y)` element-wise.

  Arguments:
      x: Tensor or variable.
      y: Tensor or variable.

  Returns:
      A bool tensor.
  """
  mask = math_ops.greater_equal(x, y)
  return mask
2052
2053
@keras_export('keras.backend.less')
def less(x, y):
  """Computes the truth value of `(x < y)` element-wise.

  Arguments:
      x: Tensor or variable.
      y: Tensor or variable.

  Returns:
      A bool tensor.
  """
  mask = math_ops.less(x, y)
  return mask
2066
2067
@keras_export('keras.backend.less_equal')
def less_equal(x, y):
  """Computes the truth value of `(x <= y)` element-wise.

  Arguments:
      x: Tensor or variable.
      y: Tensor or variable.

  Returns:
      A bool tensor.
  """
  mask = math_ops.less_equal(x, y)
  return mask
2080
2081
@keras_export('keras.backend.maximum')
def maximum(x, y):
  """Computes the element-wise maximum of two tensors.

  Arguments:
      x: Tensor or variable.
      y: Tensor or variable.

  Returns:
      A tensor.
  """
  larger = math_ops.maximum(x, y)
  return larger
2094
2095
@keras_export('keras.backend.minimum')
def minimum(x, y):
  """Computes the element-wise minimum of two tensors.

  Arguments:
      x: Tensor or variable.
      y: Tensor or variable.

  Returns:
      A tensor.
  """
  smaller = math_ops.minimum(x, y)
  return smaller
2108
2109
@keras_export('keras.backend.sin')
def sin(x):
  """Computes the sine of `x` element-wise.

  Arguments:
      x: Tensor or variable.

  Returns:
      A tensor.
  """
  sine = math_ops.sin(x)
  return sine
2121
2122
@keras_export('keras.backend.cos')
def cos(x):
  """Computes the cosine of `x` element-wise.

  Arguments:
      x: Tensor or variable.

  Returns:
      A tensor.
  """
  cosine = math_ops.cos(x)
  return cosine
2134
2135
def _regular_normalize_batch_in_training(x,
                                         gamma,
                                         beta,
                                         reduction_axes,
                                         epsilon=1e-3):
  """Non-fused version of `normalize_batch_in_training`.

  Computes the moments of `x` over `reduction_axes` and feeds them straight
  into `nn.batch_normalization`, without reshaping them.

  Arguments:
      x: Input tensor or variable.
      gamma: Tensor by which to scale the input.
      beta: Tensor with which to center the input.
      reduction_axes: iterable of integers,
          axes over which to normalize.
      epsilon: Fuzz factor.

  Returns:
      A tuple length of 3, `(normalized_tensor, mean, variance)`.
  """
  batch_mean, batch_var = nn.moments(x, reduction_axes, None, None, False)
  normalized = nn.batch_normalization(
      x, batch_mean, batch_var, beta, gamma, epsilon)
  return normalized, batch_mean, batch_var
2157
2158
def _broadcast_normalize_batch_in_training(x,
                                           gamma,
                                           beta,
                                           reduction_axes,
                                           epsilon=1e-3):
  """Non-fused, broadcast version of `normalize_batch_in_training`.

  The moments (and `gamma` / `beta`, when given) are reshaped to a shape that
  has size 1 on every reduced axis and the dynamic size of `x` on every other
  axis, so that they broadcast against `x` during normalization.

  Arguments:
      x: Input tensor or variable.
      gamma: Tensor by which to scale the input.
      beta: Tensor with which to center the input.
      reduction_axes: iterable of integers,
          axes over which to normalize.
      epsilon: Fuzz factor.

  Returns:
      A tuple length of 3, `(normalized_tensor, mean, variance)`. The returned
      mean and variance are the un-reshaped moments.
  """
  mean, var = nn.moments(x, reduction_axes, None, None, False)

  # Size 1 on reduced axes, dynamic size of `x` everywhere else.
  broadcast_shape = array_ops.stack([
      1 if axis in reduction_axes else array_ops.shape(x)[axis]
      for axis in range(ndim(x))
  ])

  mean_b = array_ops.reshape(mean, broadcast_shape)
  var_b = array_ops.reshape(var, broadcast_shape)
  gamma_b = (None if gamma is None
             else array_ops.reshape(gamma, broadcast_shape))
  beta_b = (None if beta is None
            else array_ops.reshape(beta, broadcast_shape))

  normalized = nn.batch_normalization(x, mean_b, var_b, beta_b, gamma_b,
                                      epsilon)
  return normalized, mean, var
2200
2201
def _fused_normalize_batch_in_training(x,
                                       gamma,
                                       beta,
                                       reduction_axes,
                                       epsilon=1e-3):
  """Fused version of `normalize_batch_in_training`.

  `reduction_axes` equal to `[0, 1, 2]` selects the 'NHWC' data format
  (channels on axis 3); any other value is handled as 'NCHW' (channels on
  axis 1). A missing `gamma` defaults to ones and a missing `beta` to zeros,
  each with one entry per channel.

  Arguments:
      x: Input tensor or variable.
      gamma: Tensor by which to scale the input.
      beta: Tensor with which to center the input.
      reduction_axes: iterable of integers,
          axes over which to normalize.
      epsilon: Fuzz factor.

  Returns:
      A tuple length of 3, `(normalized_tensor, mean, variance)`.
  """
  channels_last = list(reduction_axes) == [0, 1, 2]
  normalization_axis = 3 if channels_last else 1
  tf_data_format = 'NHWC' if channels_last else 'NCHW'

  if gamma is None:
    # No scale supplied: use an identity scale, one entry per channel.
    gamma = constant_op.constant(
        1.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])
  if beta is None:
    # No offset supplied: use a zero offset, one entry per channel.
    beta = constant_op.constant(
        0.0, dtype=x.dtype, shape=[x.shape[normalization_axis]])

  fused_result = nn.fused_batch_norm(
      x, gamma, beta, epsilon=epsilon, data_format=tf_data_format)
  return fused_result
2236
2237
@keras_export('keras.backend.normalize_batch_in_training')
def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
  """Computes mean and std for batch then apply batch_normalization on batch.

  Dispatches to one of three implementations:

  * 4D input reduced over `[0, 1, 2]` or `[0, 2, 3]`: the fused
    implementation, except that `[0, 2, 3]` without NCHW support uses the
    broadcast implementation.
  * Reduction over every axis but the last: the regular implementation.
  * Anything else: the broadcast implementation.

  Arguments:
      x: Input tensor or variable.
      gamma: Tensor by which to scale the input.
      beta: Tensor with which to center the input.
      reduction_axes: iterable of integers,
          axes over which to normalize.
      epsilon: Fuzz factor.

  Returns:
      A tuple length of 3, `(normalized_tensor, mean, variance)`.
  """
  if ndim(x) == 4 and list(reduction_axes) in [[0, 1, 2], [0, 2, 3]]:
    if not _has_nchw_support() and list(reduction_axes) == [0, 2, 3]:
      normalize_fn = _broadcast_normalize_batch_in_training
    else:
      normalize_fn = _fused_normalize_batch_in_training
  elif sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
    normalize_fn = _regular_normalize_batch_in_training
  else:
    normalize_fn = _broadcast_normalize_batch_in_training
  return normalize_fn(x, gamma, beta, reduction_axes, epsilon=epsilon)
2266
2267
@keras_export('keras.backend.batch_normalization')
def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
  """Applies batch normalization on x given mean, var, beta and gamma.

  I.e. returns:
  `output = (x - mean) / sqrt(var + epsilon) * gamma + beta`

  For 4D inputs normalized along a channel axis in a supported layout
  (`axis` 3 / -1 for 'NHWC', or `axis` 1 / -3 for 'NCHW' when NCHW is
  supported), the fused kernel is used. Every other case falls back to
  `nn.batch_normalization`.

  Arguments:
      x: Input tensor or variable.
      mean: Mean of batch.
      var: Variance of batch.
      beta: Tensor with which to center the input.
      gamma: Tensor by which to scale the input.
      axis: Integer, the axis that should be normalized.
          (typically the features axis).
      epsilon: Fuzz factor.

  Returns:
      A tensor.
  """
  if ndim(x) == 4:
    # The CPU implementation of `fused_batch_norm` only supports NHWC
    if axis == 1 or axis == -3:
      tf_data_format = 'NCHW'
    elif axis == 3 or axis == -1:
      tf_data_format = 'NHWC'
    else:
      tf_data_format = None

    if (tf_data_format == 'NHWC' or
        tf_data_format == 'NCHW' and _has_nchw_support()):
      # The mean / var / beta / gamma tensors may be broadcasted
      # so they may have extra axes of size 1, which should be squeezed.
      if ndim(mean) > 1:
        mean = array_ops.reshape(mean, [-1])
      if ndim(var) > 1:
        var = array_ops.reshape(var, [-1])
      if beta is None:
        beta = zeros_like(mean)
      elif ndim(beta) > 1:
        beta = array_ops.reshape(beta, [-1])
      if gamma is None:
        gamma = ones_like(mean)
      elif ndim(gamma) > 1:
        gamma = array_ops.reshape(gamma, [-1])
      # The fused call must stay inside this guard: it is only valid when a
      # supported data format was selected and the parameters were flattened.
      y, _, _ = nn.fused_batch_norm(
          x,
          gamma,
          beta,
          epsilon=epsilon,
          mean=mean,
          variance=var,
          data_format=tf_data_format,
          is_training=False
      )
      return y
  # Non-4D inputs, other axes, or NCHW without support: generic path.
  return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
2325
2326
2327# SHAPE OPERATIONS
2328
2329
@keras_export('keras.backend.concatenate')
def concatenate(tensors, axis=-1):
  """Concatenates a list of tensors alongside the specified axis.

  A negative `axis` is resolved against the rank of the first tensor (or set
  to 0 when that rank is falsy). If every input is sparse the result is a
  sparse concatenation; otherwise all inputs are densified first.

  Arguments:
      tensors: list of tensors to concatenate.
      axis: concatenation axis.

  Returns:
      A tensor.
  """
  if axis < 0:
    rank = ndim(tensors[0])
    axis = axis % rank if rank else 0

  if py_all(is_sparse(x) for x in tensors):
    return sparse_ops.sparse_concat(axis, tensors)
  dense_tensors = [to_dense(x) for x in tensors]
  return array_ops.concat(dense_tensors, axis)
2352
2353
@keras_export('keras.backend.reshape')
def reshape(x, shape):
  """Gives `x` a new shape.

  Arguments:
      x: Tensor or variable.
      shape: Target shape tuple.

  Returns:
      A tensor.
  """
  reshaped = array_ops.reshape(x, shape)
  return reshaped
2366
2367
@keras_export('keras.backend.permute_dimensions')
def permute_dimensions(x, pattern):
  """Reorders the axes of a tensor according to `pattern`.

  Arguments:
      x: Tensor or variable.
      pattern: A tuple of
          dimension indices, e.g. `(0, 2, 1)`.

  Returns:
      A tensor.
  """
  transposed = array_ops.transpose(x, perm=pattern)
  return transposed
2381
2382
@keras_export('keras.backend.resize_images')
def resize_images(x, height_factor, width_factor, data_format,
                  interpolation='nearest'):
  """Resizes the images contained in a 4D tensor.

  The spatial dimensions are scaled by `height_factor` and `width_factor`.
  Channels-first input is transposed to channels-last for the resize and
  transposed back afterwards.

  Arguments:
      x: Tensor or variable to resize.
      height_factor: Positive integer.
      width_factor: Positive integer.
      data_format: One of `"channels_first"`, `"channels_last"`.
      interpolation: A string, one of `nearest` or `bilinear`.

  Returns:
      A tensor.

  Raises:
      ValueError: in case of incorrect value for
        `data_format` or `interpolation`.
  """
  if data_format not in ('channels_first', 'channels_last'):
    raise ValueError('Invalid `data_format` argument: %s' % (data_format,))
  channels_first = data_format == 'channels_first'
  rows, cols = (2, 3) if channels_first else (1, 2)

  original_shape = int_shape(x)
  # Dynamic (height, width) of the input, scaled by the resize factors.
  new_shape = array_ops.shape(x)[rows:cols + 1] * constant_op.constant(
      np.array([height_factor, width_factor], dtype='int32'))

  if channels_first:
    x = permute_dimensions(x, [0, 2, 3, 1])
  if interpolation == 'nearest':
    x = image_ops.resize_nearest_neighbor(x, new_shape)
  elif interpolation == 'bilinear':
    x = image_ops.resize_bilinear(x, new_shape)
  else:
    raise ValueError('interpolation should be one '
                     'of "nearest" or "bilinear".')
  if channels_first:
    x = permute_dimensions(x, [0, 3, 1, 2])

  # Restore whatever static spatial information is known about the output.
  static_height = original_shape[rows]
  static_width = original_shape[cols]
  new_height = None if static_height is None else static_height * height_factor
  new_width = None if static_width is None else static_width * width_factor

  if channels_first:
    x.set_shape((None, None, new_height, new_width))
  else:
    x.set_shape((None, new_height, new_width, None))
  return x
2442
2443
@keras_export('keras.backend.resize_volumes')
def resize_volumes(x, depth_factor, height_factor, width_factor, data_format):
  """Resizes the volume contained in a 5D tensor.

  Each of the three spatial axes is enlarged by repeating its elements
  `depth_factor`, `height_factor` and `width_factor` times respectively.

  Arguments:
      x: Tensor or variable to resize.
      depth_factor: Positive integer.
      height_factor: Positive integer.
      width_factor: Positive integer.
      data_format: One of `"channels_first"`, `"channels_last"`.

  Returns:
      A tensor.

  Raises:
      ValueError: if `data_format` is neither
          `channels_last` or `channels_first`.
  """
  if data_format == 'channels_first':
    first_spatial_axis = 2
  elif data_format == 'channels_last':
    first_spatial_axis = 1
  else:
    raise ValueError('Invalid data_format: ' + str(data_format))

  output = x
  factors = (depth_factor, height_factor, width_factor)
  for offset, factor in enumerate(factors):
    output = repeat_elements(output, factor, axis=first_spatial_axis + offset)
  return output
2474
2475
@keras_export('keras.backend.repeat_elements')
def repeat_elements(x, rep, axis):
  """Repeats the elements of a tensor along an axis, like `np.repeat`.

  If `x` has shape `(s1, s2, s3)` and `axis` is `1`, the output
  will have shape `(s1, s2 * rep, s3)`.

  Two strategies are used. When the size of `axis` is statically known, `x`
  is split into that many slices, each slice is listed `rep` times, and the
  slices are concatenated back. When it is not known, the repetition is done
  with `tile` along a temporary extra axis followed by a reshape.

  Arguments:
      x: Tensor or variable.
      rep: Python integer, number of times to repeat.
      axis: Axis along which to repeat.

  Returns:
      A tensor.
  """
  x_shape = x.shape.as_list()
  # For static axis
  if x_shape[axis] is not None:
    # slices along the repeat axis
    splits = array_ops.split(value=x,
                             num_or_size_splits=x_shape[axis],
                             axis=axis)
    # repeat each slice the given number of reps
    x_rep = [s for s in splits for _ in range(rep)]
    return concatenate(x_rep, axis)

  # Here we use tf.tile to mimic behavior of np.repeat so that
  # we can handle dynamic shapes (that include None).
  # To do that, we need an auxiliary axis to repeat elements along
  # it and then merge them along the desired axis.

  # Repeating
  # NOTE(review): `axis + 1` only lands right after `axis` for non-negative
  # `axis`; presumably callers never pass a negative axis here -- confirm.
  auxiliary_axis = axis + 1
  # From here on `x_shape` is the dynamic shape tensor, not the static list.
  x_shape = array_ops.shape(x)
  x_rep = array_ops.expand_dims(x, axis=auxiliary_axis)
  # Tile multiples: 1 everywhere except `rep` on the auxiliary axis.
  reps = np.ones(len(x.shape) + 1)
  reps[auxiliary_axis] = rep
  x_rep = array_ops.tile(x_rep, reps)

  # Merging
  # Drop the auxiliary entry and scale the dynamic size of `axis` by `rep`,
  # which folds the auxiliary axis back into `axis` on reshape.
  reps = np.delete(reps, auxiliary_axis)
  reps[axis] = rep
  reps = array_ops.constant(reps, dtype='int32')
  x_shape *= reps
  x_rep = array_ops.reshape(x_rep, x_shape)

  # Fix shape representation
  # The static shape of the input is reused as-is; its `axis` entry is None
  # on this code path, so the repeated axis stays unknown.
  x_shape = x.shape.as_list()
  x_rep.set_shape(x_shape)
  x_rep._keras_shape = tuple(x_shape)
  return x_rep
2527
2528
@keras_export('keras.backend.repeat')
def repeat(x, n):
  """Repeats a 2D tensor along a new middle axis.

  if `x` has shape (samples, dim) and `n` is `2`,
  the output will have shape `(samples, 2, dim)`.

  Arguments:
      x: Tensor or variable.
      n: Python integer, number of times to repeat.

  Returns:
      A tensor.
  """
  assert ndim(x) == 2
  expanded = array_ops.expand_dims(x, 1)
  multiples = array_ops.stack([1, n, 1])
  return array_ops.tile(expanded, multiples)
2547
2548
@keras_export('keras.backend.arange')
def arange(start, stop=None, step=1, dtype='int32'):
  """Creates a 1D tensor containing a sequence of integers.

  The arguments follow the same convention as Theano's arange: when only
  one argument is given, it is actually the "stop" value and the sequence
  starts at 0.

  The returned tensor defaults to `'int32'`, matching TensorFlow's default.

  Arguments:
      start: Start value.
      stop: Stop value.
      step: Difference between two successive values.
      dtype: Integer dtype to use.

  Returns:
      An integer tensor.

  """
  # Match the behavior of numpy and Theano by returning an empty sequence.
  if stop is None and start < 0:
    start = 0
  sequence = math_ops.range(start, limit=stop, delta=step, name='arange')
  if dtype != 'int32':
    return cast(sequence, dtype)
  return sequence
2577
2578
@keras_export('keras.backend.tile')
def tile(x, n):
  """Creates a tensor by tiling `x` by `n`.

  A bare integer `n` is wrapped into a one-element list.

  Arguments:
      x: A tensor or variable
      n: A list of integer. The length must be the same as the number of
          dimensions in `x`.

  Returns:
      A tiled tensor.
  """
  multiples = [n] if isinstance(n, int) else n
  return array_ops.tile(x, multiples)
2594
2595
@keras_export('keras.backend.flatten')
def flatten(x):
  """Flattens a tensor into a single dimension.

  Arguments:
      x: A tensor or variable.

  Returns:
      A tensor, reshaped into 1-D
  """
  flat_shape = [-1]
  return array_ops.reshape(x, flat_shape)
2607
2608
@keras_export('keras.backend.batch_flatten')
def batch_flatten(x):
  """Turn a nD tensor into a 2D tensor with same 0th dimension.

  In other words, every sample of the batch is flattened individually.

  Arguments:
      x: A tensor or variable.

  Returns:
      A tensor.
  """
  # Product of all dimensions except the first one.
  per_sample_size = prod(shape(x)[1:])
  target_shape = array_ops.stack([-1, per_sample_size])
  return array_ops.reshape(x, target_shape)
2623
2624
@keras_export('keras.backend.expand_dims')
def expand_dims(x, axis=-1):
  """Inserts a dimension of size 1 at index "axis".

  Arguments:
      x: A tensor or variable.
      axis: Position where to add a new axis.

  Returns:
      A tensor with expanded dimensions.
  """
  expanded = array_ops.expand_dims(x, axis)
  return expanded
2637
2638
@keras_export('keras.backend.squeeze')
def squeeze(x, axis):
  """Drops the size-1 dimension at index "axis" from the tensor.

  Arguments:
      x: A tensor or variable.
      axis: Axis to drop.

  Returns:
      A tensor with the same data as `x` but reduced dimensions.
  """
  axes_to_drop = [axis]
  return array_ops.squeeze(x, axes_to_drop)
2651
2652
@keras_export('keras.backend.temporal_padding')
def temporal_padding(x, padding=(1, 1)):
  """Pads the middle dimension of a 3D tensor.

  Dimensions 0 and 2 are left untouched.

  Arguments:
      x: Tensor or variable.
      padding: Tuple of 2 integers, how many zeros to
          add at the start and end of dim 1.

  Returns:
      A padded 3D tensor.
  """
  assert len(padding) == 2
  pad_start = padding[0]
  pad_end = padding[1]
  return array_ops.pad(x, [[0, 0], [pad_start, pad_end], [0, 0]])
2668
2669
@keras_export('keras.backend.spatial_2d_padding')
def spatial_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
  """Pads the two spatial dimensions of a 4D tensor.

  With `channels_last` these are dimensions 1 and 2; with `channels_first`
  they are dimensions 2 and 3. When `data_format` is `None`, the value
  returned by `image_data_format()` is used.

  Arguments:
      x: Tensor or variable.
      padding: Tuple of 2 tuples, padding pattern.
      data_format: One of `channels_last` or `channels_first`.

  Returns:
      A padded 4D tensor.

  Raises:
      ValueError: if `data_format` is neither
          `channels_last` or `channels_first`.
  """
  assert len(padding) == 2
  assert len(padding[0]) == 2
  assert len(padding[1]) == 2
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))

  first_pad = list(padding[0])
  second_pad = list(padding[1])
  unpadded = [0, 0]
  if data_format == 'channels_first':
    pattern = [unpadded, [0, 0], first_pad, second_pad]
  else:
    pattern = [unpadded, first_pad, second_pad, [0, 0]]
  return array_ops.pad(x, pattern)
2699
2700
@keras_export('keras.backend.spatial_3d_padding')
def spatial_3d_padding(x, padding=((1, 1), (1, 1), (1, 1)), data_format=None):
  """Pads 5D tensor with zeros along the depth, height, width dimensions.

  The three spatial dimensions are padded with "padding[0]", "padding[1]"
  and "padding[2]" zeros (left and right) respectively.

  For 'channels_last' data_format,
  the 2nd, 3rd and 4th dimension will be padded.
  For 'channels_first' data_format,
  the 3rd, 4th and 5th dimension will be padded.

  Arguments:
      x: Tensor or variable.
      padding: Tuple of 3 tuples, padding pattern.
      data_format: One of `channels_last` or `channels_first`.

  Returns:
      A padded 5D tensor.

  Raises:
      ValueError: if `data_format` is neither
          `channels_last` or `channels_first`.

  """
  assert len(padding) == 3
  assert len(padding[0]) == 2
  assert len(padding[1]) == 2
  assert len(padding[2]) == 2
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))

  # One [before, after] pair per spatial dimension.
  spatial_pads = [[padding[i][0], padding[i][1]] for i in range(3)]
  if data_format == 'channels_first':
    pattern = [[0, 0], [0, 0]] + spatial_pads
  else:
    pattern = [[0, 0]] + spatial_pads + [[0, 0]]
  return array_ops.pad(x, pattern)
2743
2744
@keras_export('keras.backend.stack')
def stack(x, axis=0):
  """Combines a list of rank `R` tensors into one rank `R+1` tensor.

  Arguments:
      x: List of tensors.
      axis: Axis along which to perform stacking.

  Returns:
      A tensor.
  """
  stacked = array_ops.stack(x, axis=axis)
  return stacked
2757
2758
@keras_export('keras.backend.one_hot')
def one_hot(indices, num_classes):
  """Computes the one-hot representation of an integer tensor.

  The one-hot dimension is appended as the last axis.

  Arguments:
      indices: nD integer tensor of shape
          `(batch_size, dim1, dim2, ... dim(n-1))`
      num_classes: Integer, number of classes to consider.

  Returns:
      The one-hot tensor: the (n + 1)D one hot representation of the input
      with shape `(batch_size, dim1, dim2, ... dim(n-1), num_classes)`
  """
  encoded = array_ops.one_hot(indices, depth=num_classes, axis=-1)
  return encoded
2776
2777
@keras_export('keras.backend.reverse')
def reverse(x, axes):
  """Reverses a tensor along the given axes.

  A single integer is accepted and treated as a one-element list of axes.

  Arguments:
      x: Tensor to reverse.
      axes: Integer or iterable of integers.
          Axes to reverse.

  Returns:
      A tensor.
  """
  axis_list = [axes] if isinstance(axes, int) else axes
  return array_ops.reverse(x, axis_list)
2793
2794
2795# VALUE MANIPULATION
2796
2797
@keras_export('keras.backend.get_value')
def get_value(x):
  """Returns the value of a variable.

  When executing eagerly the value is read directly. A variable created in
  an eager context but accessed from a graph is read under a temporary eager
  scope. Otherwise the variable is evaluated in the Keras session.

  Arguments:
      x: input variable.

  Returns:
      A Numpy array.

  Raises:
      RuntimeError: If this method is called inside defun.
  """
  if context.executing_eagerly():
    return x.numpy()
  if not getattr(x, '_in_graph_mode', True):
    # This is a variable which was created in an eager context, but is being
    # evaluated from a Graph.
    with context.eager_mode():
      return x.numpy()
  if ops.inside_function():
    raise RuntimeError('Cannot get value inside Tensorflow graph function.')
  session = get_session((x,))
  return x.eval(session=session)
2821
2822
@keras_export('keras.backend.batch_get_value')
def batch_get_value(tensors):
  """Returns the value of more than one tensor variable.

  In graph mode all tensors are fetched with a single session run; an empty
  input yields an empty list without touching the session.

  Arguments:
      tensors: list of ops to run.

  Returns:
      A list of Numpy arrays.

  Raises:
      RuntimeError: If this method is called inside defun.
  """
  if context.executing_eagerly():
    return [x.numpy() for x in tensors]
  if ops.inside_function():
    raise RuntimeError('Cannot get value inside Tensorflow graph function.')
  if not tensors:
    return []
  return get_session(tensors).run(tensors)
2844
2845
@keras_export('keras.backend.set_value')
def set_value(x, value):
  """Sets the value of a variable, from a Numpy array.

  In graph mode the assignment goes through a placeholder and an assign op
  that are created on first use and cached on `x` for later calls.

  Arguments:
      x: Tensor to set to a new value.
      value: Value to set the tensor to, as a Numpy array
          (of the same shape).
  """
  value = np.asarray(value, dtype=dtype(x))
  if ops.executing_eagerly_outside_functions():
    x.assign(value)
    return

  with get_graph().as_default():
    tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
    if not hasattr(x, '_assign_placeholder'):
      # First assignment to this variable: build and cache the feed/assign
      # pair so later calls do not add new ops to the graph.
      placeholder = array_ops.placeholder(tf_dtype, shape=value.shape)
      update_op = x.assign(placeholder)
      x._assign_placeholder = placeholder
      x._assign_op = update_op
    else:
      placeholder = x._assign_placeholder
      update_op = x._assign_op
    get_session().run(update_op, feed_dict={placeholder: value})
2870
2871
@keras_export('keras.backend.batch_set_value')
def batch_set_value(tuples):
  """Sets the values of many tensor variables at once.

  In graph mode every assignment is fed through a per-variable placeholder
  (created on first use and cached on the variable) and all assign ops are
  executed in a single session run.

  Arguments:
      tuples: a list of tuples `(tensor, value)`.
          `value` should be a Numpy array.
  """
  if ops.executing_eagerly_outside_functions():
    for x, value in tuples:
      x.assign(np.asarray(value, dtype=dtype(x)))
    return

  with get_graph().as_default():
    if not tuples:
      return
    update_ops = []
    feeds = {}
    for x, value in tuples:
      value = np.asarray(value, dtype=dtype(x))
      tf_dtype = dtypes_module.as_dtype(x.dtype.name.split('_')[0])
      if not hasattr(x, '_assign_placeholder'):
        # First assignment to this variable: build and cache the
        # feed/assign pair.
        placeholder = array_ops.placeholder(tf_dtype, shape=value.shape)
        update_op = x.assign(placeholder)
        x._assign_placeholder = placeholder
        x._assign_op = update_op
      else:
        placeholder = x._assign_placeholder
        update_op = x._assign_op
      update_ops.append(update_op)
      feeds[placeholder] = value
    get_session().run(update_ops, feed_dict=feeds)
2903
2904
@keras_export('keras.backend.print_tensor')
def print_tensor(x, message=''):
  """Prints `message` and the tensor value when evaluated.

  `print_tensor` returns a new tensor identical to `x`; that returned tensor
  is the one to use in subsequent code. Otherwise the print operation is not
  taken into account during evaluation.

  Example:

  ```python
     >>> x = K.print_tensor(x, message="x is: ")
  ```

  Arguments:
      x: Tensor to print.
      message: Message to print jointly with the tensor.

  Returns:
      The same tensor `x`, unchanged.
  """
  printed_values = [x]
  return logging_ops.Print(x, printed_values, message)
2927
2928
2929# GRAPH MANIPULATION
2930
2931
class GraphExecutionFunction(object):
  """Callable wrapper that executes part of a TF graph through a session.

  Extra arguments for `tf.Session.run()` may be supplied through
  `session_kwargs`: additional operations via `fetches` and additional
  tensor substitutions via `feed_dict`. Substitutions given in `feed_dict`
  are combined with the ones coming from `inputs`. The `feed_dict` object is
  handed over only once, in the constructor (called in `model.compile()`),
  but its values are read again on every call, so they can be modified
  afterwards to feed different data besides the Keras inputs.

  Arguments:
      inputs: Feed placeholders to the computation graph.
      outputs: Output tensors to fetch.
      updates: Additional update ops to be run at function call.
      name: A name to help users identify what this function does.
      session_kwargs: Arguments to `tf.Session.run()`:
                      `fetches`, `feed_dict`, `options`, `run_metadata`.
  """

  def __init__(self, inputs, outputs, updates=None, name=None,
               **session_kwargs):
    updates = updates or []
    if not isinstance(updates, (list, tuple)):
      raise TypeError('`updates` in a Keras backend function '
                      'should be a list or tuple.')

    def _to_update_op(update):
      # A `(variable, new_value)` tuple is turned into an assign op; anything
      # else is assumed to already be an op.
      if not isinstance(update, tuple):
        return update
      variable, new_value = update
      return state_ops.assign(variable, new_value)

    self.name = name
    self.inputs = nest.flatten(inputs)
    self._outputs_structure = outputs
    self.outputs = cast_variables_to_tensor(nest.flatten(outputs))
    # TODO(b/127668432): Consider using autograph to generate these
    # dependencies in call.
    # Index 0 = total loss or model output for `predict`.
    with ops.control_dependencies([self.outputs[0]]):
      self.updates_op = control_flow_ops.group(
          *[_to_update_op(update) for update in updates])

    # Pull the supported `Session.run` arguments out of `session_kwargs`.
    # Additional tensor substitutions:
    self.feed_dict = session_kwargs.pop('feed_dict', None)
    # Additional operations (a single fetch is wrapped into a list):
    extra_fetches = session_kwargs.pop('fetches', [])
    if not isinstance(extra_fetches, list):
      extra_fetches = [extra_fetches]
    self.run_options = session_kwargs.pop('options', None)
    self.run_metadata = session_kwargs.pop('run_metadata', None)
    # The main use case of `fetches` being passed to a model is the ability
    # to run custom updates. This requires us to wrap fetches in `identity`
    # ops.
    self.fetches = [array_ops.identity(fetch) for fetch in extra_fetches]
    self.session_kwargs = session_kwargs
    # Maps a fetch from `fetches` to the function that should receive its
    # output: { fetch: function(fetch_output) }. A Callback can use this to
    # register a function with access to the output values of a fetch it
    # added.
    self.fetch_callbacks = {}

    # Anything still left in `session_kwargs` is not understood.
    if session_kwargs:
      raise ValueError('Some keys in session_kwargs are not supported at this '
                       'time: %s' % (session_kwargs.keys(),))

    # Cached callable plus the parameters it was generated from.
    self._callable_fn = None
    self._feed_arrays = None
    self._feed_symbols = None
    self._symbol_vals = None
    self._fetches = None
    self._session = None

  def _make_callable(self, feed_arrays, feed_symbols, symbol_vals, session):
    """Builds the session callable and caches it on the instance.

    The callable is stored in `self._callable_fn`, together with the
    arguments it was built from, so that `__call__` can tell when it has to
    be rebuilt.

    Arguments:
      feed_arrays: List of input tensors to be fed Numpy arrays at runtime.
      feed_symbols: List of input tensors to be fed symbolic tensors at runtime.
      symbol_vals: List of symbolic tensors to be fed to `feed_symbols`.
      session: Session to use to generate the callable.
    """
    options = config_pb2.CallableOptions()

    # External-data feeds: Numpy-fed placeholders first, then the keys of the
    # extra `feed_dict` in sorted order (`__call__` appends values in the
    # same order).
    options.feed.extend([placeholder.name for placeholder in feed_arrays])
    if self.feed_dict:
      options.feed.extend(
          [key.name for key in sorted(self.feed_dict.keys())])

    # Symbolic feeds: connect each data tensor to its placeholder.
    for placeholder, source in zip(feed_symbols, symbol_vals):
      link = options.tensor_connection.add()
      if placeholder.dtype != source.dtype:
        source = math_ops.cast(source, dtype=placeholder.dtype)
      graph_element = ops._as_graph_element(source)
      if graph_element is None:
        graph_element = source
      link.from_tensor = graph_element.name  # Data tensor
      link.to_tensor = placeholder.name  # Placeholder

    # Fetches: the outputs followed by the extra fetches.
    options.fetch.extend(
        [tensor.name for tensor in self.outputs + self.fetches])
    # Updates run as a target.
    options.target.append(self.updates_op.name)
    if self.run_options:
      options.run_options.CopyFrom(self.run_options)

    # Remember the callable and everything it was derived from, so that
    # future mismatches can be detected and the callable refreshed.
    self._callable_fn = session._make_callable_from_options(options)
    self._feed_arrays = feed_arrays
    self._feed_symbols = feed_symbols
    self._symbol_vals = symbol_vals
    self._fetches = list(self.fetches)
    self._session = session

  def _call_fetch_callbacks(self, fetches_output):
    """Hands each fetched value to the callback registered for its fetch."""
    for fetch, value in zip(self._fetches, fetches_output):
      if fetch not in self.fetch_callbacks:
        continue
      self.fetch_callbacks[fetch](value)

  def __call__(self, inputs):
    """Runs the graph on `inputs` and returns the fetched outputs.

    Arguments:
      inputs: Values for `self.inputs` (flattened before use). `None` entries
        are skipped; tensors are fed symbolically; everything else is
        converted to a Numpy array of the placeholder's dtype.

    Returns:
      The output values, packed into the structure of the `outputs` given to
      the constructor.
    """
    inputs = nest.flatten(inputs)
    session = get_session(inputs)

    feed_arrays, array_vals = [], []
    feed_symbols, symbol_vals = [], []
    for placeholder, value in zip(self.inputs, inputs):
      if value is None:
        continue
      if is_sparse(placeholder):
        # Turn the sparse value into an (indices, values, shape) triple.
        coo = value.tocoo()
        rows = np.expand_dims(coo.row, 1)
        cols = np.expand_dims(coo.col, 1)
        value = (np.concatenate((rows, cols), 1), coo.data, coo.shape)
      if tensor_util.is_tensor(value):
        # Case: feeding symbolic tensor.
        feed_symbols.append(placeholder)
        symbol_vals.append(value)
        continue
      # Case: feeding Numpy array. Array conversion and type casting have to
      # happen here, since `callable_fn` only supports exact matches.
      feed_arrays.append(placeholder)
      numpy_dtype = dtypes_module.as_dtype(placeholder.dtype).as_numpy_dtype
      array_vals.append(np.asarray(value, dtype=numpy_dtype))

    if self.feed_dict:
      for key in sorted(self.feed_dict.keys()):
        extra_value = np.asarray(
            self.feed_dict[key], dtype=key.dtype.base_dtype.name)
        array_vals.append(extra_value)

    # Refresh callable if anything has changed.
    needs_refresh = (
        self._callable_fn is None or
        feed_arrays != self._feed_arrays or
        symbol_vals != self._symbol_vals or
        feed_symbols != self._feed_symbols or
        self.fetches != self._fetches or
        session != self._session)
    if needs_refresh:
      self._make_callable(feed_arrays, feed_symbols, symbol_vals, session)

    fetched = self._callable_fn(*array_vals,
                                run_metadata=self.run_metadata)
    # The extra fetches sit at the end of `fetched`, the outputs at the start.
    self._call_fetch_callbacks(fetched[-len(self._fetches):])
    return nest.pack_sequence_as(self._outputs_structure,
                                 fetched[:len(self.outputs)])
3105
3106
class EagerExecutionFunction(object):
  """Helper class for constructing a TF graph function from the Keras graph.

  The constructor builds an `eager_function.ConcreteFunction` out of the
  given inputs, outputs and updates. Calling the instance runs that function
  and returns the outputs converted with `.numpy()`.

  Arguments:
    inputs: Feed placeholders to the computation graph.
    outputs: Output tensors to fetch.
    updates: Additional update ops to be run at function call.
    name: A name to help users identify what this function does.
    session_kwargs: Unsupported (the constructor takes no such argument).
  """

  def __init__(self, inputs, outputs, updates=None, name=None):
    """Builds the graph function.

    Arguments:
      inputs: Feed placeholders to the computation graph.
      outputs: Output tensors to fetch.
      updates: Optional list or tuple of updates. Each entry is either a
        `(variable, new_value)` tuple or something treated as an op.
      name: A name to help users identify what this function does.

    Raises:
      TypeError: if `updates` is neither a list nor a tuple.
      ValueError: if there are updates but no outputs, if the elements come
        from more than one graph, or if their graph is neither the scratch
        graph nor the graph returned by `get_graph()`.
    """
    self.name = name
    # Keep the caller's (possibly nested) structure to re-pack results later.
    self._outputs_structure = outputs
    inputs = nest.flatten(inputs)
    outputs = nest.flatten(outputs)

    updates = updates or []
    if not isinstance(updates, (list, tuple)):
      raise TypeError('`updates` in a Keras backend function '
                      'should be a list or tuple.')

    if updates and not outputs:
      # Edge case; never happens in practice
      raise ValueError('Cannot create a Keras backend function with updates'
                       ' but no outputs during eager execution.')

    # Distinct graphs of every element that exposes a `graph` attribute.
    graphs = {i.graph for i in nest.flatten([inputs, outputs, updates])
              if hasattr(i, 'graph')}
    if len(graphs) > 1:
      raise ValueError('Cannot create an execution function which is comprised '
                       'of elements from multiple graphs.')

    # NOTE(review): `set.pop()` raises KeyError when no element had a `graph`
    # attribute -- confirm callers never hit that case.
    source_graph = graphs.pop()
    # This value is rebound inside the scratch-graph scope below before it is
    # ever read.
    global_graph = get_graph()

    # Split the updates into ops and legacy `(variable, new_value)` pairs.
    updates_ops = []
    legacy_update_ops = []
    for update in updates:
      # For legacy reasons it is allowed to pass an update as a tuple
      # `(variable, new_value)` (this maps to an assign op). Otherwise it
      # is assumed to already be an op -- we cannot control its execution
      # order.
      if isinstance(update, tuple):
        legacy_update_ops.append(update)
      else:
        if hasattr(update, 'op'):
          update = update.op
        updates_ops.append(update)

    with _scratch_graph() as exec_graph:
      global_graph = get_graph()
      if source_graph not in (exec_graph, global_graph):
        raise ValueError('Unknown graph. Aborting.')

      # Only when the elements live in the global graph and the scratch graph
      # is a different one do they have to be copied ("lifted") over.
      if source_graph is global_graph and exec_graph is not global_graph:
        # Everything that must exist in `exec_graph`: outputs, update ops,
        # the variables of legacy updates, and those new values that are
        # tensors.
        init_tensors = (
            outputs + updates_ops + [p for [p, _] in legacy_update_ops] +
            [p_new for [_, p_new] in legacy_update_ops
             if isinstance(p_new, ops.Tensor)])
        lifted_map = lift_to_graph.lift_to_graph(
            init_tensors=init_tensors, graph=exec_graph, sources=inputs,
            add_sources=True, handle_captures=True, base_graph=source_graph)

        # Swap every element for its lifted counterpart. New values that were
        # not lifted (non-tensors) are kept as they are.
        inputs = [lifted_map[i] for i in inputs]
        outputs = [lifted_map[i] for i in outputs]
        updates_ops = [lifted_map[i] for i in updates_ops]
        legacy_update_ops = [(lifted_map[p], lifted_map.get(p_new, p_new))
                             for p, p_new in legacy_update_ops]

    # Consolidate updates
    with exec_graph.as_default():
      outputs = cast_variables_to_tensor(outputs)
      # Legacy assign ops are created under a control dependency on the
      # outputs.
      with ops.control_dependencies(outputs):
        for p, p_new in legacy_update_ops:
          updates_ops.append(state_ops.assign(p, p_new))

      self.inputs, self.outputs = inputs, outputs
      # The first output is replaced by an identity op created under a
      # control dependency on all update ops.
      with ops.control_dependencies(updates_ops):
        self.outputs[0] = array_ops.identity(self.outputs[0])

      # The function takes the placeholders plus whatever the graph captured.
      exec_graph.inputs = self.inputs + list(exec_graph.captures.values())
      exec_graph.outputs = self.outputs
      graph_fn = eager_function.ConcreteFunction(exec_graph)

    # Private `ConcreteFunction` attributes: only the placeholders are passed
    # positionally and no keyword arguments are declared.
    graph_fn._num_positional_args = len(self.inputs)
    graph_fn._arg_keywords = []
    self._graph_fn = graph_fn

    # Handle placeholders with default
    # (treated as required placeholder by graph functions)
    self._placeholder_default_values = {}
    with exec_graph.as_default():
      for x in self.inputs:
        if x.op.type == 'PlaceholderWithDefault':
          # Remember the constant value of the op's first input; it is used
          # in `__call__` when no value is fed for this placeholder.
          self._placeholder_default_values[x] = tensor_util.constant_value(
              x.op.inputs[0])

  def __call__(self, inputs):
    """Runs the graph function on `inputs`.

    Arguments:
      inputs: Values for `self.inputs` (flattened before use). A `None` entry
        falls back to the default value recorded for that placeholder.

    Returns:
      The outputs, converted with `.numpy()` and packed into the structure of
      the `outputs` given to the constructor.

    Raises:
      ValueError: if a value is `None` and no default value is recorded for
        its placeholder.
    """
    inputs = nest.flatten(inputs)
    converted_inputs = []
    for tensor, value in zip(self.inputs, inputs):
      if value is None:
        # Assume `value` is a placeholder with default
        value = self._placeholder_default_values.get(tensor, None)
        if value is None:
          raise ValueError(
              'You must feed a value for placeholder %s' % (tensor,))
      if not isinstance(value, ops.Tensor):
        value = ops.convert_to_tensor(value, dtype=tensor.dtype)
      if value.dtype != tensor.dtype:
        # Temporary workaround due to `convert_to_tensor` not casting floats.
        # See b/119637405
        value = math_ops.cast(value, tensor.dtype)
      converted_inputs.append(value)
    outputs = self._graph_fn(*converted_inputs)
    return nest.pack_sequence_as(self._outputs_structure,
                                 [x.numpy() for x in outputs])
3225
3226
@keras_export('keras.backend.function')
def function(inputs, outputs, updates=None, name=None, **kwargs):
  """Instantiates a Keras function.

  Arguments:
      inputs: List of placeholder tensors.
      outputs: List of output tensors.
      updates: List of update ops.
      name: String, name of function.
      **kwargs: Passed to `tf.Session.run`.

  Returns:
      A callable: an `EagerExecutionFunction` when executing eagerly outside
      of functions, a `GraphExecutionFunction` otherwise. Calling it with
      input values runs the computation and returns the output values.

  Raises:
      ValueError: if invalid kwargs are passed in, or if any kwargs are passed
        in while executing eagerly.
  """
  if ops.executing_eagerly_outside_functions():
    if kwargs:
      raise ValueError('Session keyword arguments are not support during '
                       'eager execution. You passed: %s' % (kwargs,))
    return EagerExecutionFunction(inputs, outputs, updates=updates, name=name)

  if kwargs:
    # The accepted argument names do not depend on the key being checked, so
    # look them up once.
    session_run_args = tf_inspect.getfullargspec(session_module.Session.run)[0]
    for key in kwargs:
      if (key not in session_run_args
          and key not in ['inputs', 'outputs', 'updates', 'name']):
        msg = ('Invalid argument "%s" passed to K.function with TensorFlow '
               'backend') % key
        raise ValueError(msg)
  # Forward `name` as well, so that the graph-mode function carries the name
  # the caller asked for, just like the eager one above.
  return GraphExecutionFunction(
      inputs, outputs, updates=updates, name=name, **kwargs)
3258
3259
@keras_export('keras.backend.gradients')
def gradients(loss, variables):
  """Computes the gradients of `loss` with respect to `variables`.

  Arguments:
      loss: Scalar tensor to minimize.
      variables: List of variables.

  Returns:
      A gradients tensor.
  """
  # Delegate to the gradients module, with `colocate_gradients_with_ops`
  # switched on.
  grads = gradients_module.gradients(
      loss, variables, colocate_gradients_with_ops=True)
  return grads
3273
3274
@keras_export('keras.backend.stop_gradient')
def stop_gradient(variables):
  """Returns `variables` but with zero gradient w.r.t. every other variable.

  Arguments:
      variables: Tensor or list of tensors to consider constant with respect
        to any other variable.


  Returns:
      A single tensor or a list of tensors (depending on the passed argument)
      that has no gradient with respect to any other variable. A list is
      returned for both list and tuple arguments.
  """
  if isinstance(variables, (list, tuple)):
    # Build a real list: on Python 3 `map` would return a lazy, single-use
    # iterator and defer the creation of the ops until it is consumed.
    return [array_ops.stop_gradient(variable) for variable in variables]
  return array_ops.stop_gradient(variables)
3291
3292
3293# CONTROL FLOW
3294
3295
@keras_export('keras.backend.rnn')
def rnn(step_function,
        inputs,
        initial_states,
        go_backwards=False,
        mask=None,
        constants=None,
        unroll=False,
        input_length=None,
        time_major=False,
        zero_output_for_mask=False):
  """Iterates over the time dimension of a tensor.

  Arguments:
      step_function: RNN step function.
          Args;
              input; Tensor with shape `(samples, ...)` (no time dimension),
                  representing input for the batch of samples at a certain
                  time step.
              states; List of tensors.
          Returns;
              output; Tensor with shape `(samples, output_dim)`
                  (no time dimension).
              new_states; List of tensors, same length and shapes
                  as 'states'. The first state in the list must be the
                  output tensor at the previous timestep.
      inputs: Tensor of temporal data of shape `(samples, time, ...)`
          (at least 3D), or nested tensors, and each of which has shape
          `(samples, time, ...)`.
      initial_states: Tensor with shape `(samples, state_size)`
          (no time dimension), containing the initial values for the states used
          in the step function. In the case that state_size is in a nested
          shape, the shape of initial_states will also follow the nested
          structure.
      go_backwards: Boolean. If True, do the iteration over the time
          dimension in reverse order and return the reversed sequence.
      mask: Binary tensor with shape `(samples, time, 1)`,
          with a zero for every element that is masked. A rank-2 mask is
          passed through `expand_dims` first, and a non-boolean mask is cast
          to bool.
      constants: List of constant values passed at each step.
      unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
      input_length: If specified, assume time dimension is of this length.
          It is passed as `maximum_iterations` to the symbolic `while_loop`
          and is not used when unrolling.
      time_major: Boolean. If true, the inputs and outputs will be in shape
          `(timesteps, batch, ...)`, whereas in the False case, it will be
          `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
          efficient because it avoids transposes at the beginning and end of the
          RNN calculation. However, most TensorFlow data is batch-major, so by
          default this function accepts input and emits output in batch-major
          form.
      zero_output_for_mask: Boolean. If True, the output for masked timestep
          will be zeros, whereas in the False case, output from previous
          timestep is returned.
  Returns:
      A tuple, `(last_output, outputs, new_states)`.
          last_output: the latest output of the rnn, of shape `(samples, ...)`
          outputs: tensor with shape `(samples, time, ...)` where each
              entry `outputs[s, t]` is the output of the step function
              at time `t` for sample `s`.
          new_states: list of tensors, latest states returned by
              the step function, of shape `(samples, ...)`.

  Raises:
      ValueError: if input dimension is less than 3.
      ValueError: if `unroll` is `True` but input timestep is not a fixed
      number.
      ValueError: if `mask` is provided (not `None`) but states is not provided
          (`len(states)` == 0). This is checked on the non-unrolled path.
  """

  def swap_batch_timestep(input_t):
    # Swap the batch and timestep dim for the incoming tensor.
    axes = list(range(len(input_t.shape)))
    axes[0], axes[1] = 1, 0
    return array_ops.transpose(input_t, axes)

  # From here on everything is handled in time-major layout.
  if not time_major:
    inputs = nest.map_structure(swap_batch_timestep, inputs)

  flatted_inputs = nest.flatten(inputs)
  # Static sizes (may be unknown) and the dynamic number of time steps.
  time_steps = flatted_inputs[0].shape[0]
  batch = flatted_inputs[0].shape[1]
  time_steps_t = array_ops.shape(flatted_inputs[0])[0]

  for input_ in flatted_inputs:
    input_.get_shape().with_rank_at_least(3)

  if mask is not None:
    if mask.dtype != dtypes_module.bool:
      mask = math_ops.cast(mask, dtypes_module.bool)
    if len(mask.shape) == 2:
      mask = expand_dims(mask)
    if not time_major:
      mask = swap_batch_timestep(mask)

  if constants is None:
    constants = []

  # tf.where needs its condition tensor to be the same shape as its two
  # result tensors, but in our case the condition (mask) tensor is
  # (nsamples, 1), and inputs are (nsamples, ndimensions) or even more.
  # So we need to broadcast the mask to match the shape of inputs.
  # That's what the tile call does, it just repeats the mask along its
  # second dimension n times.
  def _expand_mask(mask_t, input_t, fixed_dim=1):
    assert not nest.is_sequence(mask_t)
    assert not nest.is_sequence(input_t)
    # Append trailing dimensions until the mask has the rank of `input_t`.
    rank_diff = len(input_t.shape) - len(mask_t.shape)
    for _ in range(rank_diff):
      mask_t = array_ops.expand_dims(mask_t, -1)
    # Leave the first `fixed_dim` dimensions alone, tile the remaining ones
    # up to the static shape of `input_t`.
    multiples = [1] * fixed_dim + input_t.shape.as_list()[fixed_dim:]
    return array_ops.tile(mask_t, multiples)

  if unroll:
    if not time_steps:
      raise ValueError('Unrolling requires a fixed number of timesteps.')
    states = tuple(initial_states)
    successive_states = []
    successive_outputs = []

    # Process the input tensors. The input tensor need to be split on the
    # time_step dim, and reverse if go_backwards is True. In the case of nested
    # input, the input is flattened and then transformed individually.
    # The result of this will be a tuple of lists, each of the item in tuple is
    # list of the tensor with shape (batch, feature)
    def _process_single_input_t(input_t):
      input_t = array_ops.unstack(input_t)  # unstack for time_step dim
      if go_backwards:
        input_t.reverse()
      return input_t

    if nest.is_sequence(inputs):
      processed_input = nest.map_structure(_process_single_input_t, inputs)
    else:
      processed_input = (_process_single_input_t(inputs),)

    def _get_input_tensor(time):
      # Pick the slice for step `time` from every processed input and restore
      # the structure of `inputs`.
      inp = [t_[time] for t_ in processed_input]
      return nest.pack_sequence_as(inputs, inp)

    if mask is not None:
      mask_list = array_ops.unstack(mask)
      if go_backwards:
        mask_list.reverse()

      for i in range(time_steps):
        inp = _get_input_tensor(i)
        mask_t = mask_list[i]
        output, new_states = step_function(inp,
                                           tuple(states) + tuple(constants))
        tiled_mask_t = _expand_mask(mask_t, output)

        # Masked positions repeat the previous output (zeros at step 0).
        if not successive_outputs:
          prev_output = zeros_like(output)
        else:
          prev_output = successive_outputs[-1]

        output = array_ops.where(tiled_mask_t, output, prev_output)

        # Masked positions keep their previous state.
        return_states = []
        for state, new_state in zip(states, new_states):
          # (see earlier comment for tile explanation)
          tiled_mask_t = _expand_mask(mask_t, new_state)
          return_states.append(array_ops.where(tiled_mask_t, new_state, state))
        states = return_states
        successive_outputs.append(output)
        successive_states.append(states)
      last_output = successive_outputs[-1]
      new_states = successive_states[-1]
      outputs = array_ops.stack(successive_outputs)

      if zero_output_for_mask:
        last_output = array_ops.where(
            _expand_mask(mask_list[-1], last_output),
            last_output,
            zeros_like(last_output))
        outputs = array_ops.where(
            _expand_mask(mask, outputs, fixed_dim=2),
            outputs,
            zeros_like(outputs))

    else:
      for i in range(time_steps):
        inp = _get_input_tensor(i)
        output, states = step_function(inp, tuple(states) + tuple(constants))
        successive_outputs.append(output)
        successive_states.append(states)
      last_output = successive_outputs[-1]
      new_states = successive_states[-1]
      outputs = array_ops.stack(successive_outputs)

  else:
    states = tuple(initial_states)

    # Create input tensor array, if the inputs is nested tensors, then it will
    # be flattened first, and tensor array will be created one per flattened
    # tensor.
    input_ta = tuple(
        tensor_array_ops.TensorArray(
            dtype=inp.dtype,
            size=time_steps_t,
            tensor_array_name='input_ta_%s' % i)
        for i, inp in enumerate(flatted_inputs))
    input_ta = tuple(
        ta.unstack(input_) if not go_backwards else ta
        .unstack(reverse(input_, 0))
        for ta, input_ in zip(input_ta, flatted_inputs))

    # Get the time(0) input and compute the output for that, the output will be
    # used to determine the dtype of output tensor array. Don't read from
    # input_ta due to TensorArray clear_after_read default to True.
    input_time_zero = nest.pack_sequence_as(inputs,
                                            [inp[0] for inp in flatted_inputs])
    # output_time_zero is used to determine the cell output shape and its dtype.
    # the value is discarded.
    # Both operands are converted to tuples, as in every other call of
    # `step_function` here: `initial_states + constants` fails when one is a
    # tuple and the other a list (e.g. tuple states with `constants=None`,
    # which becomes `[]` above).
    output_time_zero, _ = step_function(
        input_time_zero, tuple(initial_states) + tuple(constants))
    output_ta = tuple(
        tensor_array_ops.TensorArray(
            dtype=out.dtype,
            size=time_steps_t,
            tensor_array_name='output_ta_%s' % i)
        for i, out in enumerate(nest.flatten(output_time_zero)))

    time = constant_op.constant(0, dtype='int32', name='time')

    while_loop_kwargs = {
        'cond': lambda time, *_: time < time_steps_t,
        'maximum_iterations': input_length,
        'parallel_iterations': 32,
        'swap_memory': True,
    }

    if mask is not None:
      if not states:
        raise ValueError('No initial states provided! '
                         'When using masking in an RNN, you should '
                         'provide initial states '
                         '(and your step function should return '
                         'as its first state at time `t` '
                         'the output at time `t-1`).')
      if go_backwards:
        mask = reverse(mask, 0)

      mask_ta = tensor_array_ops.TensorArray(
          dtype=dtypes_module.bool,
          size=time_steps_t,
          tensor_array_name='mask_ta')
      mask_ta = mask_ta.unstack(mask)

      # Mask for the T output will be base on the output of T - 1. In the case
      # T = 0, a zero filled tensor will be used.
      flat_zero_output = tuple(array_ops.zeros_like(o)
                               for o in nest.flatten(output_time_zero))
      def _step(time, output_ta_t, prev_output, *states):
        """RNN step function.

        Arguments:
            time: Current timestep value.
            output_ta_t: TensorArray.
            prev_output: tuple of outputs from time - 1.
            *states: List of states.

        Returns:
            Tuple: `(time + 1, output_ta_t, output) + tuple(new_states)`
        """
        current_input = tuple(ta.read(time) for ta in input_ta)
        # maybe set shape.
        current_input = nest.pack_sequence_as(inputs, current_input)
        mask_t = mask_ta.read(time)
        output, new_states = step_function(current_input,
                                           tuple(states) + tuple(constants))
        # mask output
        flat_output = nest.flatten(output)
        flat_mask_output = (flat_zero_output if zero_output_for_mask
                            else nest.flatten(prev_output))
        tiled_mask_t = tuple(_expand_mask(mask_t, o) for o in flat_output)
        flat_new_output = tuple(
            array_ops.where(m, o, zo) for m, o, zo in zip(
                tiled_mask_t, flat_output, flat_mask_output))

        # mask states
        flat_state = nest.flatten(states)
        flat_new_state = nest.flatten(new_states)
        for state, new_state in zip(flat_state, flat_new_state):
          new_state.set_shape(state.shape)
        tiled_mask_t = tuple(_expand_mask(mask_t, s) for s in flat_state)
        flat_final_state = tuple(
            array_ops.where(m, s, ps)
            for m, s, ps in zip(tiled_mask_t, flat_new_state, flat_state))
        new_states = nest.pack_sequence_as(new_states, flat_final_state)

        output_ta_t = tuple(
            ta.write(time, out)
            for ta, out in zip(output_ta_t, flat_new_output))
        return (time + 1, output_ta_t,
                tuple(flat_new_output)) + tuple(new_states)

      final_outputs = control_flow_ops.while_loop(
          body=_step,
          loop_vars=(time, output_ta, flat_zero_output) + states,
          **while_loop_kwargs)
      # Skip final_outputs[2] which is the output for final timestep.
      new_states = final_outputs[3:]
    else:
      def _step(time, output_ta_t, *states):
        """RNN step function.

        Arguments:
            time: Current timestep value.
            output_ta_t: TensorArray.
            *states: List of states.

        Returns:
            Tuple: `(time + 1,output_ta_t) + tuple(new_states)`
        """
        current_input = tuple(ta.read(time) for ta in input_ta)
        current_input = nest.pack_sequence_as(inputs, current_input)
        output, new_states = step_function(current_input,
                                           tuple(states) + tuple(constants))
        flat_state = nest.flatten(states)
        flat_new_state = nest.flatten(new_states)
        for state, new_state in zip(flat_state, flat_new_state):
          new_state.set_shape(state.shape)

        flat_output = nest.flatten(output)
        output_ta_t = tuple(
            ta.write(time, out) for ta, out in zip(output_ta_t, flat_output))
        new_states = nest.pack_sequence_as(initial_states, flat_new_state)
        return (time + 1, output_ta_t) + tuple(new_states)

      final_outputs = control_flow_ops.while_loop(
          body=_step,
          loop_vars=(time, output_ta) + states,
          **while_loop_kwargs)
      new_states = final_outputs[2:]

    output_ta = final_outputs[1]

    # Stack the per-step outputs and take the last step of each.
    outputs = tuple(o.stack() for o in output_ta)
    last_output = tuple(o[-1] for o in outputs)

    outputs = nest.pack_sequence_as(output_time_zero, outputs)
    last_output = nest.pack_sequence_as(output_time_zero, last_output)

  # static shape inference
  def set_shape(output_):
    shape = output_.shape.as_list()
    shape[0] = time_steps
    shape[1] = batch
    output_.set_shape(shape)
    return output_

  outputs = nest.map_structure(set_shape, outputs)

  # Back to batch-major layout if that is what the caller passed in.
  if not time_major:
    outputs = nest.map_structure(swap_batch_timestep, outputs)

  return last_output, outputs, new_states
3653
3654
@keras_export('keras.backend.switch')
def switch(condition, then_expression, else_expression):
  """Selects between two expressions according to `condition`.

  Note that both `then_expression` and `else_expression`
  should be symbolic tensors of the *same shape*.

  Arguments:
      condition: tensor (`int` or `bool`).
      then_expression: either a tensor, or a callable that returns a tensor.
      else_expression: either a tensor, or a callable that returns a tensor.

  Returns:
      The selected tensor.

  Raises:
      ValueError: If rank of `condition` is greater than rank of expressions.
  """
  if condition.dtype != dtypes_module.bool:
    condition = math_ops.cast(condition, 'bool')
  cond_rank = ndim(condition)

  if not cond_rank:
    # No (or zero) rank for the condition: hand both branches over as
    # callables, wrapping plain tensors.
    if callable(then_expression):
      then_fn = then_expression
    else:
      then_fn = lambda: then_expression
    if callable(else_expression):
      else_fn = else_expression
    else:
      else_fn = lambda: else_expression
    return control_flow_ops.cond(condition, then_fn, else_fn)

  # Element-wise selection. tf.where needs its condition tensor to be the
  # same shape as its two result tensors, so both branches are evaluated
  # first.
  if callable(then_expression):
    then_expression = then_expression()
  if callable(else_expression):
    else_expression = else_expression()
  expr_rank = ndim(then_expression)
  if cond_rank > expr_rank:
    raise ValueError(
        'Rank of `condition` should be less than or'
        ' equal to rank of `then_expression` and '
        '`else_expression`. ndim(condition)=%s'
        ', ndim(then_expression)'
        '=%s' % (str(cond_rank), str(expr_rank)))
  if cond_rank > 1:
    # Pad the condition's shape with trailing 1s up to the rank of the
    # expression, then tile along every axis where the expression is larger.
    missing_dims = expr_rank - cond_rank
    padded_shape = array_ops.concat(
        [array_ops.shape(condition), [1] * missing_dims], axis=0)
    condition = array_ops.reshape(condition, padded_shape)
    target_shape = array_ops.shape(then_expression)
    multiples = array_ops.where(target_shape - padded_shape > 0,
                                target_shape,
                                array_ops.ones_like(target_shape))
    condition = array_ops.tile(condition, multiples)
  return array_ops.where(condition, then_expression, else_expression)
3717
3718
@keras_export('keras.backend.in_train_phase')
def in_train_phase(x, alt, training=None):
  """Returns `x` during training and `alt` otherwise.

  Note that `alt` should have the *same shape* as `x`.

  Arguments:
      x: What to return in train phase
          (tensor or callable that returns a tensor).
      alt: What to return otherwise
          (tensor or callable that returns a tensor).
      training: Optional scalar tensor
          (or Python boolean, or Python integer)
          specifying the learning phase. When `None`, the value of
          `learning_phase()` is used.

  Returns:
      Either `x` or `alt` based on the `training` flag.
  """
  if training is None:
    training = learning_phase()

  # A statically known phase is resolved in Python without adding graph ops.
  if training == 1 or training is True:
    return x() if callable(x) else x

  if training == 0 or training is False:
    return alt() if callable(alt) else alt

  # Otherwise the phase is assumed to be a placeholder tensor; defer the
  # choice to `switch`.
  return switch(training, x, alt)
3756
3757
@keras_export('keras.backend.in_test_phase')
def in_test_phase(x, alt, training=None):
  """Returns `x` during testing and `alt` otherwise.

  Note that `alt` should have the *same shape* as `x`.

  Arguments:
      x: What to return in test phase
          (tensor or callable that returns a tensor).
      alt: What to return otherwise
          (tensor or callable that returns a tensor).
      training: Optional scalar tensor
          (or Python boolean, or Python integer)
          specifying the learning phase.

  Returns:
      Either `x` or `alt`, as chosen by `in_train_phase` with the two
      branches swapped.
  """
  # Test phase is the complement of train phase, so swap the branches.
  selected = in_train_phase(alt, x, training=training)
  return selected
3777
3778
3779# NN OPERATIONS
3780
3781
@keras_export('keras.backend.relu')
def relu(x, alpha=0., max_value=None, threshold=0):
  """Rectified linear unit.

  With default values, it returns element-wise `max(x, 0)`.

  Otherwise, it follows:
  `f(x) = max_value` for `x >= max_value`,
  `f(x) = x` for `threshold <= x < max_value`,
  `f(x) = alpha * (x - threshold)` otherwise.

  Arguments:
      x: A tensor or variable.
      alpha: A scalar, slope of negative section (default=`0.`).
      max_value: float. Saturation threshold.
      threshold: float. Threshold value for thresholded activation.

  Returns:
      A tensor.
  """

  if alpha != 0.:
    if max_value is None and threshold == 0:
      # Plain leaky ReLU has a dedicated op.
      return nn.leaky_relu(x, alpha=alpha)

    # Magnitude of the part of `x` lying below the threshold; it is scaled
    # by `alpha` and subtracted at the end.
    if threshold != 0:
      negative_part = nn.relu(-x + threshold)
    else:
      negative_part = nn.relu(-x)

  clip_max = max_value is not None

  if threshold != 0:
    # computes x for x > threshold else 0
    # The mask is cast to the dtype of `x` (falling back to `floatx()` when
    # `x` has no dtype) so the product does not fail on a dtype mismatch
    # for inputs that are not of type `floatx()`.
    mask_dtype = getattr(x, 'dtype', floatx())
    x = x * math_ops.cast(math_ops.greater(x, threshold), mask_dtype)
  elif max_value == 6:
    # if no threshold, then can use nn.relu6 native TF op for performance
    x = nn.relu6(x)
    clip_max = False
  else:
    x = nn.relu(x)

  if clip_max:
    max_value = _to_tensor(max_value, x.dtype.base_dtype)
    zero = _to_tensor(0., x.dtype.base_dtype)
    x = clip_ops.clip_by_value(x, zero, max_value)

  if alpha != 0.:
    alpha = _to_tensor(alpha, x.dtype.base_dtype)
    x -= alpha * negative_part
  return x
3833
3834
@keras_export('keras.backend.elu')
def elu(x, alpha=1.):
  """Exponential linear unit.

  Arguments:
      x: A tensor or variable to compute the activation function for.
      alpha: A scalar, slope of negative section.

  Returns:
      A tensor.
  """
  activated = nn.elu(x)
  if alpha == 1:
    return activated
  # `nn.elu` takes no slope argument here, so the result is rescaled by
  # `alpha` wherever `x` is not positive.
  return array_ops.where(x > 0, activated, alpha * activated)
3851
3852
@keras_export('keras.backend.softmax')
def softmax(x, axis=-1):
  """Computes the softmax of a tensor along one axis.

  Arguments:
      x: A tensor or variable.
      axis: The dimension along which the softmax is taken.
          Defaults to -1, i.e. the last dimension.

  Returns:
      A tensor, the result of `nn.softmax(x, axis=axis)`.
  """
  probabilities = nn.softmax(x, axis=axis)
  return probabilities
3866
3867
@keras_export('keras.backend.softplus')
def softplus(x):
  """Applies the softplus function to a tensor.

  Arguments:
      x: A tensor or variable.

  Returns:
      A tensor, the result of `nn.softplus(x)`.
  """
  activated = nn.softplus(x)
  return activated
3879
3880
@keras_export('keras.backend.softsign')
def softsign(x):
  """Applies the softsign function to a tensor.

  Arguments:
      x: A tensor or variable.

  Returns:
      A tensor, the result of `nn.softsign(x)`.
  """
  activated = nn.softsign(x)
  return activated
3892
3893
@keras_export('keras.backend.categorical_crossentropy')
def categorical_crossentropy(target, output, from_logits=False, axis=-1):
  """Categorical crossentropy between an output tensor and a target tensor.

  Arguments:
      target: A tensor of the same shape as `output`.
      output: A tensor resulting from a softmax
          (unless `from_logits` is True, in which
          case `output` is expected to be the logits).
      from_logits: Boolean, whether `output` is the
          result of a softmax, or is a tensor of logits.
      axis: Int specifying the channels axis. `axis=-1` corresponds to data
          format `channels_last`, and `axis=1` corresponds to data format
          `channels_first`. On the probability path the value is taken
          modulo the rank of `output`; on the logits path it is forwarded
          unchanged to `nn.softmax_cross_entropy_with_logits_v2`.

  Returns:
      Output tensor.
  """
  if not from_logits:
    if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
        output.op.type != 'Softmax'):
      axis = axis % len(output.shape)
      # scale preds so that the class probas of each sample sum to 1
      output = output / math_ops.reduce_sum(output, axis, True)
      # Compute cross entropy from probabilities.
      epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
      output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
      return -math_ops.reduce_sum(target * math_ops.log(output), axis)
    else:
      # When softmax activation function is used for output operation, we
      # use logits from the softmax function directly to compute loss in order
      # to prevent collapsing zero when training.
      # See b/117284466
      assert len(output.op.inputs) == 1
      output = output.op.inputs[0]
  # Forward `axis` so the class dimension is honoured on the logits path as
  # well; previously it was dropped and the last dimension was always used.
  return nn.softmax_cross_entropy_with_logits_v2(
      labels=target, logits=output, axis=axis)
3933
3934
@keras_export('keras.backend.sparse_categorical_crossentropy')
def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1):
  """Categorical crossentropy with integer targets.

  Arguments:
      target: An integer tensor.
      output: A tensor resulting from a softmax
          (unless `from_logits` is True, in which
          case `output` is expected to be the logits).
      from_logits: Boolean, whether `output` is the
          result of a softmax, or is a tensor of logits.
      axis: Int specifying the channels axis. `axis=-1` corresponds to data
          format `channels_last`, and `axis=1` corresponds to data format
          `channels_first`. The value is taken modulo the rank of `output`.

  Returns:
      Output tensor.
  """
  if not from_logits:
    is_plain_value = isinstance(
        output, (ops.EagerTensor, variables_module.Variable))
    if is_plain_value or output.op.type != 'Softmax':
      # Turn probabilities into log-probabilities, clipping first so the
      # log stays finite.
      epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
      output = math_ops.log(
          clip_ops.clip_by_value(output, epsilon_, 1 - epsilon_))
    else:
      # When softmax activation function is used for output operation, we
      # use logits from the softmax function directly to compute loss in order
      # to prevent collapsing zero when training.
      # See b/117284466
      assert len(output.op.inputs) == 1
      output = output.op.inputs[0]

  rank = len(output.shape)
  axis = axis % rank
  if axis != rank - 1:
    # Move the class axis to the end, keeping the other axes in order.
    permutation = [dim for dim in range(rank) if dim != axis] + [axis]
    output = array_ops.transpose(output, perm=permutation)

  output_shape = output.shape
  num_classes = int(output_shape[-1])
  flat_labels = cast(flatten(target), 'int64')
  flat_logits = array_ops.reshape(output, [-1, num_classes])
  losses = nn.sparse_softmax_cross_entropy_with_logits(
      labels=flat_labels, logits=flat_logits)
  if len(output_shape) < 3:
    return losses
  # If our output includes timesteps or spatial dimensions we need to reshape
  return array_ops.reshape(losses, array_ops.shape(output)[:-1])
3986
3987
@keras_export('keras.backend.binary_crossentropy')
def binary_crossentropy(target, output, from_logits=False):
  """Binary crossentropy between an output tensor and a target tensor.

  Arguments:
      target: A tensor with the same shape as `output`.
      output: A tensor.
      from_logits: Whether `output` is expected to be a logits tensor.
          By default, we consider that `output`
          encodes a probability distribution.

  Returns:
      A tensor.
  """
  if not from_logits:
    is_plain_value = isinstance(
        output, (ops.EagerTensor, variables_module.Variable))
    if is_plain_value or output.op.type != 'Sigmoid':
      epsilon_ = _to_tensor(epsilon(), output.dtype.base_dtype)
      output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)

      # Compute cross entropy from probabilities.
      positive_term = target * math_ops.log(output + epsilon())
      negative_term = (1 - target) * math_ops.log(1 - output + epsilon())
      return -(positive_term + negative_term)

    # When sigmoid activation function is used for output operation, we
    # use logits from the sigmoid function directly to compute loss in order
    # to prevent collapsing zero when training.
    assert len(output.op.inputs) == 1
    output = output.op.inputs[0]
  return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output)
4019
4020
@keras_export('keras.backend.sigmoid')
def sigmoid(x):
  """Applies the sigmoid function element-wise.

  Arguments:
      x: A tensor or variable.

  Returns:
      A tensor, the result of `nn.sigmoid(x)`.
  """
  activated = nn.sigmoid(x)
  return activated
4032
4033
@keras_export('keras.backend.hard_sigmoid')
def hard_sigmoid(x):
  """Piecewise-linear approximation of the sigmoid.

  Faster than sigmoid.
  Returns `0.` if `x < -2.5`, `1.` if `x > 2.5`.
  In `-2.5 <= x <= 2.5`, returns `0.2 * x + 0.5`.

  Arguments:
      x: A tensor or variable.

  Returns:
      A tensor.
  """
  linear = (0.2 * x) + 0.5
  dtype = linear.dtype.base_dtype
  lower = _to_tensor(0., dtype)
  upper = _to_tensor(1., dtype)
  # Clamp the linear segment into the [0, 1] range.
  return clip_ops.clip_by_value(linear, lower, upper)
4053
4054
@keras_export('keras.backend.tanh')
def tanh(x):
  """Applies the hyperbolic tangent element-wise.

  Arguments:
      x: A tensor or variable.

  Returns:
      A tensor, the result of `nn.tanh(x)`.
  """
  activated = nn.tanh(x)
  return activated
4066
4067
@keras_export('keras.backend.dropout')
def dropout(x, level, noise_shape=None, seed=None):
  """Sets entries in `x` to zero at random, while scaling the entire tensor.

  Arguments:
      x: tensor
      level: fraction of the entries in the tensor
          that will be set to 0.
      noise_shape: shape for randomly generated keep/drop flags,
          must be broadcastable to the shape of `x`
      seed: random seed to ensure determinism. When `None`, a seed is
          drawn from NumPy's global random state.

  Returns:
      A tensor.
  """
  op_seed = np.random.randint(10e6) if seed is None else seed
  return nn.dropout_v2(x, rate=level, noise_shape=noise_shape, seed=op_seed)
4086
4087
@keras_export('keras.backend.l2_normalize')
def l2_normalize(x, axis=None):
  """Normalizes a tensor with respect to the L2 norm along the given axis.

  Arguments:
      x: Tensor or variable.
      axis: axis along which to perform normalization.

  Returns:
      A tensor, the result of `nn.l2_normalize(x, axis=axis)`.
  """
  normalized = nn.l2_normalize(x, axis=axis)
  return normalized
4100
4101
@keras_export('keras.backend.in_top_k')
def in_top_k(predictions, targets, k):
  """Tells whether each target is among the top `k` predictions.

  Arguments:
      predictions: A tensor of shape `(batch_size, classes)` and type `float32`.
      targets: A 1D tensor of length `batch_size` and type `int32` or `int64`.
      k: An `int`, number of top elements to consider.

  Returns:
      A 1D tensor of length `batch_size` and type `bool`.
      `output[i]` is `True` if `predictions[i, targets[i]]` is within top-`k`
      values of `predictions[i]`.
  """
  hits = nn.in_top_k(predictions, targets, k)
  return hits
4117
4118
4119# CONVOLUTIONS
4120
4121
4122def _preprocess_conv1d_input(x, data_format):
4123  """Transpose and cast the input before the conv1d.
4124
4125  Arguments:
4126      x: input tensor.
4127      data_format: string, `"channels_last"` or `"channels_first"`.
4128
4129  Returns:
4130      A tensor.
4131  """
4132  tf_data_format = 'NWC'  # to pass TF Conv2dNative operations
4133  if data_format == 'channels_first':
4134    if not _has_nchw_support():
4135      x = array_ops.transpose(x, (0, 2, 1))  # NCW -> NWC
4136    else:
4137      tf_data_format = 'NCW'
4138  return x, tf_data_format
4139
4140
4141def _preprocess_conv2d_input(x, data_format, force_transpose=False):
4142  """Transpose and cast the input before the conv2d.
4143
4144  Arguments:
4145      x: input tensor.
4146      data_format: string, `"channels_last"` or `"channels_first"`.
4147      force_transpose: Boolean. If True, the input will always be transposed
4148          from NCHW to NHWC if `data_format` is `"channels_first"`.
4149          If False, the transposition only occurs on CPU (GPU ops are
4150          assumed to support NCHW).
4151
4152  Returns:
4153      A tensor.
4154  """
4155  tf_data_format = 'NHWC'
4156  if data_format == 'channels_first':
4157    if not _has_nchw_support() or force_transpose:
4158      x = array_ops.transpose(x, (0, 2, 3, 1))  # NCHW -> NHWC
4159    else:
4160      tf_data_format = 'NCHW'
4161  return x, tf_data_format
4162
4163
4164def _preprocess_conv3d_input(x, data_format):
4165  """Transpose and cast the input before the conv3d.
4166
4167  Arguments:
4168      x: input tensor.
4169      data_format: string, `"channels_last"` or `"channels_first"`.
4170
4171  Returns:
4172      A tensor.
4173  """
4174  tf_data_format = 'NDHWC'
4175  if data_format == 'channels_first':
4176    if not _has_nchw_support():
4177      x = array_ops.transpose(x, (0, 2, 3, 4, 1))
4178    else:
4179      tf_data_format = 'NCDHW'
4180  return x, tf_data_format
4181
4182
4183def _preprocess_padding(padding):
4184  """Convert keras' padding to TensorFlow's padding.
4185
4186  Arguments:
4187      padding: string, one of 'same' , 'valid'
4188
4189  Returns:
4190      a string, one of 'SAME', 'VALID'.
4191
4192  Raises:
4193      ValueError: if invalid `padding'`
4194  """
4195  if padding == 'same':
4196    padding = 'SAME'
4197  elif padding == 'valid':
4198    padding = 'VALID'
4199  else:
4200    raise ValueError('Invalid padding: ' + str(padding))
4201  return padding
4202
4203
@keras_export('keras.backend.conv1d')
def conv1d(x,
           kernel,
           strides=1,
           padding='valid',
           data_format=None,
           dilation_rate=1):
  """1D convolution.

  Arguments:
      x: Tensor or variable.
      kernel: kernel tensor.
      strides: stride integer.
      padding: string, `"same"`, `"causal"` or `"valid"`.
      data_format: string, one of "channels_last", "channels_first".
          Defaults to `image_data_format()` when `None`.
      dilation_rate: integer dilate rate.

  Returns:
      A tensor, result of 1D convolution.

  Raises:
      ValueError: if `data_format` is neither `channels_last` or
      `channels_first`.
  """
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))

  kernel_dims = kernel.shape.as_list()
  if padding == 'causal':
    # Causal (dilated) convolution: pad on the left only, then run a
    # 'valid' convolution over the padded input.
    x = temporal_padding(x, (dilation_rate * (kernel_dims[0] - 1), 0))
    padding = 'valid'
  tf_padding = _preprocess_padding(padding)

  x, tf_data_format = _preprocess_conv1d_input(x, data_format)
  outputs = nn.convolution(
      input=x,
      filter=kernel,
      dilation_rate=dilation_rate,
      strides=strides,
      padding=tf_padding,
      data_format=tf_data_format)

  # Undo the transpose applied by the preprocessing step, if any.
  if data_format == 'channels_first' and tf_data_format == 'NWC':
    outputs = array_ops.transpose(outputs, (0, 2, 1))  # NWC -> NCW
  return outputs
4252
4253
@keras_export('keras.backend.conv2d')
def conv2d(x,
           kernel,
           strides=(1, 1),
           padding='valid',
           data_format=None,
           dilation_rate=(1, 1)):
  """2D convolution.

  Arguments:
      x: Tensor or variable.
      kernel: kernel tensor.
      strides: strides tuple.
      padding: string, `"same"` or `"valid"`.
      data_format: `"channels_last"` or `"channels_first"`.
          Defaults to `image_data_format()` when `None`.
      dilation_rate: tuple of 2 integers.

  Returns:
      A tensor, result of 2D convolution.

  Raises:
      ValueError: if `data_format` is neither `channels_last` or
      `channels_first`.
  """
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))

  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
  tf_padding = _preprocess_padding(padding)
  outputs = nn.convolution(
      input=x,
      filter=kernel,
      dilation_rate=dilation_rate,
      strides=strides,
      padding=tf_padding,
      data_format=tf_data_format)

  # Undo the transpose applied by the preprocessing step, if any.
  if data_format == 'channels_first' and tf_data_format == 'NHWC':
    outputs = array_ops.transpose(outputs, (0, 3, 1, 2))  # NHWC -> NCHW
  return outputs
4297
4298
@keras_export('keras.backend.conv2d_transpose')
def conv2d_transpose(x,
                     kernel,
                     output_shape,
                     strides=(1, 1),
                     padding='valid',
                     data_format=None,
                     dilation_rate=(1, 1)):
  """2D deconvolution (i.e. transposed convolution).

  Arguments:
      x: Tensor or variable.
      kernel: kernel tensor.
      output_shape: 1D int tensor for the output shape, or a tuple/list of
          its four entries. In a tuple/list the first (batch) entry may be
          `None`, in which case the batch size of `x` is used.
      strides: strides tuple.
      padding: string, `"same"` or `"valid"`.
      data_format: string, `"channels_last"` or `"channels_first"`.
          Defaults to `image_data_format()` when `None`.
      dilation_rate: Tuple of 2 integers. Both entries must be equal when
          they differ from `(1, 1)`.

  Returns:
      A tensor, result of transposed 2D convolution.

  Raises:
      ValueError: if `data_format` is neither `channels_last` or
      `channels_first`.
  """
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))

  # Accept lists as well as tuples: the code below concatenates `strides`
  # with tuples and compares `dilation_rate` against `(1, 1)`.
  strides = tuple(strides)
  dilation_rate = tuple(dilation_rate)

  # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
  force_transpose = (
      data_format == 'channels_first' and dilation_rate != (1, 1))

  x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)

  if data_format == 'channels_first' and tf_data_format == 'NHWC':
    # The input was transposed, so reorder the requested shape to match.
    output_shape = (output_shape[0], output_shape[2], output_shape[3],
                    output_shape[1])
  if output_shape[0] is None:
    # Unknown batch size: take it from the input at run time.
    output_shape = (array_ops.shape(x)[0],) + tuple(output_shape[1:])
  # Stack only after the `None` batch entry has been replaced; stacking
  # first would fail on `None` and make the check above unreachable.
  if isinstance(output_shape, (tuple, list)):
    output_shape = array_ops.stack(list(output_shape))

  padding = _preprocess_padding(padding)
  if tf_data_format == 'NHWC':
    strides = (1,) + strides + (1,)
  else:
    strides = (1, 1) + strides

  if dilation_rate == (1, 1):
    x = nn.conv2d_transpose(x, kernel, output_shape, strides,
                            padding=padding,
                            data_format=tf_data_format)
  else:
    assert dilation_rate[0] == dilation_rate[1]
    x = nn.atrous_conv2d_transpose(
        x,
        kernel,
        output_shape,
        rate=dilation_rate[0],
        padding=padding)
  if data_format == 'channels_first' and tf_data_format == 'NHWC':
    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
  return x
4372
4373
def separable_conv1d(x,
                     depthwise_kernel,
                     pointwise_kernel,
                     strides=1,
                     padding='valid',
                     data_format=None,
                     dilation_rate=1):
  """1D convolution with separable filters.

  The input and both kernels get an extra spatial axis of size 1 so the
  computation can be delegated to `nn.separable_conv2d`; that axis is
  squeezed out of the result.

  Arguments:
      x: input tensor
      depthwise_kernel: convolution kernel for the depthwise convolution.
      pointwise_kernel: kernel for the 1x1 convolution.
      strides: stride integer, or a sequence holding one integer.
      padding: string, `"same"` or `"valid"`.
      data_format: string, `"channels_last"` or `"channels_first"`.
          Defaults to `image_data_format()` when `None`.
      dilation_rate: integer dilation rate, or a sequence holding one
          integer.

  Returns:
      Output tensor.

  Raises:
      ValueError: if `data_format` is neither `channels_last` or
      `channels_first`.
  """
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))

  if isinstance(strides, int):
    strides = (strides,)
  if isinstance(dilation_rate, int):
    dilation_rate = (dilation_rate,)

  x, tf_data_format = _preprocess_conv1d_input(x, data_format)
  padding = _preprocess_padding(padding)
  # Both values are concatenated with tuples below, so lists are converted
  # here (previously only `strides` was, and a list `dilation_rate` failed).
  if not isinstance(strides, tuple):
    strides = tuple(strides)
  if not isinstance(dilation_rate, tuple):
    dilation_rate = tuple(dilation_rate)
  if tf_data_format == 'NWC':
    spatial_start_dim = 1
    strides = (1,) + strides * 2 + (1,)
  else:
    spatial_start_dim = 2
    strides = (1, 1) + strides * 2
  # Insert the dummy spatial axis into the input and both kernels.
  x = array_ops.expand_dims(x, spatial_start_dim)
  depthwise_kernel = array_ops.expand_dims(depthwise_kernel, 0)
  pointwise_kernel = array_ops.expand_dims(pointwise_kernel, 0)
  dilation_rate = (1,) + dilation_rate

  x = nn.separable_conv2d(
      x,
      depthwise_kernel,
      pointwise_kernel,
      strides=strides,
      padding=padding,
      rate=dilation_rate,
      data_format=tf_data_format)

  # Drop the dummy spatial axis again.
  x = array_ops.squeeze(x, [spatial_start_dim])

  if data_format == 'channels_first' and tf_data_format == 'NWC':
    x = array_ops.transpose(x, (0, 2, 1))  # NWC -> NCW

  return x
4439
4440
@keras_export('keras.backend.separable_conv2d')
def separable_conv2d(x,
                     depthwise_kernel,
                     pointwise_kernel,
                     strides=(1, 1),
                     padding='valid',
                     data_format=None,
                     dilation_rate=(1, 1)):
  """2D convolution with separable filters.

  Arguments:
      x: input tensor
      depthwise_kernel: convolution kernel for the depthwise convolution.
      pointwise_kernel: kernel for the 1x1 convolution.
      strides: strides tuple (length 2).
      padding: string, `"same"` or `"valid"`.
      data_format: string, `"channels_last"` or `"channels_first"`.
          Defaults to `image_data_format()` when `None`.
      dilation_rate: tuple of integers,
          dilation rates for the separable convolution.

  Returns:
      Output tensor.

  Raises:
      ValueError: if `data_format` is neither `channels_last` or
      `channels_first`.
      ValueError: if `strides` is not a tuple of 2 integers.
  """
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))
  if len(strides) != 2:
    raise ValueError('`strides` must be a tuple of 2 integers.')

  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
  tf_padding = _preprocess_padding(padding)

  # Expand the two spatial strides to the 4-entry form, with 1 for the
  # batch and channel axes.
  spatial_strides = strides if isinstance(strides, tuple) else tuple(strides)
  if tf_data_format == 'NHWC':
    full_strides = (1,) + spatial_strides + (1,)
  else:
    full_strides = (1, 1) + spatial_strides

  outputs = nn.separable_conv2d(
      x,
      depthwise_kernel,
      pointwise_kernel,
      strides=full_strides,
      padding=tf_padding,
      rate=dilation_rate,
      data_format=tf_data_format)

  # Undo the transpose applied by the preprocessing step, if any.
  if data_format == 'channels_first' and tf_data_format == 'NHWC':
    outputs = array_ops.transpose(outputs, (0, 3, 1, 2))  # NHWC -> NCHW
  return outputs
4496
4497
def depthwise_conv2d(x,
                     depthwise_kernel,
                     strides=(1, 1),
                     padding='valid',
                     data_format=None,
                     dilation_rate=(1, 1)):
  """2D depthwise convolution.

  Arguments:
      x: input tensor
      depthwise_kernel: convolution kernel for the depthwise convolution.
      strides: strides tuple (length 2).
      padding: string, `"same"` or `"valid"`.
      data_format: string, `"channels_last"` or `"channels_first"`.
          Defaults to `image_data_format()` when `None`.
      dilation_rate: tuple of integers,
          dilation rates for the depthwise convolution.

  Returns:
      Output tensor.

  Raises:
      ValueError: if `data_format` is neither `channels_last` or
      `channels_first`.
  """
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))

  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
  padding = _preprocess_padding(padding)
  # `strides` is concatenated with tuples below; convert lists first, as
  # `separable_conv2d` does.
  if not isinstance(strides, tuple):
    strides = tuple(strides)
  if tf_data_format == 'NHWC':
    strides = (1,) + strides + (1,)
  else:
    strides = (1, 1) + strides

  x = nn.depthwise_conv2d(
      x,
      depthwise_kernel,
      strides=strides,
      padding=padding,
      rate=dilation_rate,
      data_format=tf_data_format)
  if data_format == 'channels_first' and tf_data_format == 'NHWC':
    x = array_ops.transpose(x, (0, 3, 1, 2))  # NHWC -> NCHW
  return x
4544
4545
@keras_export('keras.backend.conv3d')
def conv3d(x,
           kernel,
           strides=(1, 1, 1),
           padding='valid',
           data_format=None,
           dilation_rate=(1, 1, 1)):
  """3D convolution.

  Arguments:
      x: Tensor or variable.
      kernel: kernel tensor.
      strides: strides tuple.
      padding: string, `"same"` or `"valid"`.
      data_format: string, `"channels_last"` or `"channels_first"`.
          Defaults to `image_data_format()` when `None`.
      dilation_rate: tuple of 3 integers.

  Returns:
      A tensor, result of 3D convolution.

  Raises:
      ValueError: if `data_format` is neither `channels_last` or
      `channels_first`.
  """
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))

  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
  tf_padding = _preprocess_padding(padding)
  outputs = nn.convolution(
      input=x,
      filter=kernel,
      dilation_rate=dilation_rate,
      strides=strides,
      padding=tf_padding,
      data_format=tf_data_format)

  # Undo the transpose applied by the preprocessing step, if any.
  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
    outputs = array_ops.transpose(outputs, (0, 4, 1, 2, 3))
  return outputs
4589
4590
def conv3d_transpose(x,
                     kernel,
                     output_shape,
                     strides=(1, 1, 1),
                     padding='valid',
                     data_format=None):
  """3D deconvolution (i.e. transposed convolution).

  Arguments:
      x: input tensor.
      kernel: kernel tensor.
      output_shape: 1D int tensor for the output shape, or a tuple/list of
          its five entries. In a tuple/list the first (batch) entry may be
          `None`, in which case the batch size of `x` is used.
      strides: strides tuple.
      padding: string, "same" or "valid".
      data_format: string, `"channels_last"` or `"channels_first"`.
          Defaults to `image_data_format()` when `None`.

  Returns:
      A tensor, result of transposed 3D convolution.

  Raises:
      ValueError: if `data_format` is neither `channels_last` or
      `channels_first`.
  """
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))

  x, tf_data_format = _preprocess_conv3d_input(x, data_format)

  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
    # The input was transposed, so reorder the requested shape to match.
    output_shape = (output_shape[0], output_shape[2], output_shape[3],
                    output_shape[4], output_shape[1])
  if output_shape[0] is None:
    # Unknown batch size: take it from the input at run time.
    output_shape = (array_ops.shape(x)[0],) + tuple(output_shape[1:])
  # Stack only after the `None` batch entry has been replaced; stacking
  # first would fail on `None` and make the check above unreachable.
  if isinstance(output_shape, (tuple, list)):
    output_shape = array_ops.stack(list(output_shape))

  padding = _preprocess_padding(padding)
  # `strides` is concatenated with tuples below; convert lists first.
  if not isinstance(strides, tuple):
    strides = tuple(strides)
  if tf_data_format == 'NDHWC':
    strides = (1,) + strides + (1,)
  else:
    strides = (1, 1) + strides

  x = nn.conv3d_transpose(
      x,
      kernel,
      output_shape,
      strides,
      padding=padding,
      data_format=tf_data_format)
  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
  return x
4650
4651
@keras_export('keras.backend.pool2d')
def pool2d(x,
           pool_size,
           strides=(1, 1),
           padding='valid',
           data_format=None,
           pool_mode='max'):
  """2D Pooling.

  Arguments:
      x: Tensor or variable.
      pool_size: tuple of 2 integers.
      strides: tuple of 2 integers.
      padding: string, `"same"` or `"valid"`.
      data_format: string, `"channels_last"` or `"channels_first"`.
          `None` selects the value returned by `image_data_format()`.
      pool_mode: string, `"max"` or `"avg"`.

  Returns:
      A tensor, result of 2D pooling.

  Raises:
      ValueError: if `data_format` is neither `"channels_last"` or
      `"channels_first"`.
      ValueError: if `pool_size` does not have exactly 2 entries.
      ValueError: if `strides` does not have exactly 2 entries.
      ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
  """
  data_format = image_data_format() if data_format is None else data_format
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))
  if len(pool_size) != 2:
    raise ValueError('`pool_size` must be a tuple of 2 integers.')
  if len(strides) != 2:
    raise ValueError('`strides` must be a tuple of 2 integers.')

  x, tf_data_format = _preprocess_conv2d_input(x, data_format)
  padding = _preprocess_padding(padding)

  # Extend the 2 spatial entries to the 4 entries the pooling ops take, with
  # ones at the non-spatial positions of the layout in use.
  if tf_data_format == 'NHWC':
    full_strides = (1,) + strides + (1,)
    window = (1,) + pool_size + (1,)
  else:
    full_strides = (1, 1) + strides
    window = (1, 1) + pool_size

  if pool_mode == 'max':
    pool_op = nn.max_pool
  elif pool_mode == 'avg':
    pool_op = nn.avg_pool
  else:
    raise ValueError('Invalid pooling mode: ' + str(pool_mode))
  pooled = pool_op(
      x, window, full_strides, padding=padding, data_format=tf_data_format)

  if data_format == 'channels_first' and tf_data_format == 'NHWC':
    return array_ops.transpose(pooled, (0, 3, 1, 2))  # NHWC -> NCHW
  return pooled
4709
4710
@keras_export('keras.backend.pool3d')
def pool3d(x,
           pool_size,
           strides=(1, 1, 1),
           padding='valid',
           data_format=None,
           pool_mode='max'):
  """3D Pooling.

  Arguments:
      x: Tensor or variable.
      pool_size: tuple of 3 integers.
      strides: tuple of 3 integers.
      padding: string, `"same"` or `"valid"`.
      data_format: string, `"channels_last"` or `"channels_first"`.
          `None` selects the value returned by `image_data_format()`.
      pool_mode: string, `"max"` or `"avg"`.

  Returns:
      A tensor, result of 3D pooling.

  Raises:
      ValueError: if `data_format` is neither `"channels_last"` or
      `"channels_first"`.
      ValueError: if `pool_size` is not a tuple of 3 integers.
      ValueError: if `strides` is not a tuple of 3 integers.
      ValueError: if `pool_mode` is neither `"max"` or `"avg"`.
  """
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))
  # Same up-front argument checks as `pool2d`, so a wrongly sized argument is
  # reported here rather than by the underlying pooling op.
  if len(pool_size) != 3:
    raise ValueError('`pool_size` must be a tuple of 3 integers.')
  if len(strides) != 3:
    raise ValueError('`strides` must be a tuple of 3 integers.')

  x, tf_data_format = _preprocess_conv3d_input(x, data_format)
  padding = _preprocess_padding(padding)
  # Extend the 3 spatial entries to 5 entries, with ones at the non-spatial
  # positions of the layout in use.
  if tf_data_format == 'NDHWC':
    strides = (1,) + strides + (1,)
    pool_size = (1,) + pool_size + (1,)
  else:
    strides = (1, 1) + strides
    pool_size = (1, 1) + pool_size

  if pool_mode == 'max':
    x = nn.max_pool3d(
        x, pool_size, strides, padding=padding, data_format=tf_data_format)
  elif pool_mode == 'avg':
    x = nn.avg_pool3d(
        x, pool_size, strides, padding=padding, data_format=tf_data_format)
  else:
    raise ValueError('Invalid pooling mode: ' + str(pool_mode))

  if data_format == 'channels_first' and tf_data_format == 'NDHWC':
    # Return the result in the layout the caller asked for.
    x = array_ops.transpose(x, (0, 4, 1, 2, 3))
  return x
4762
4763
def local_conv(inputs,
               kernel,
               kernel_size,
               strides,
               output_shape,
               data_format=None):
  """Apply N-D convolution with un-shared weights.

  One input patch is cut out per output position, the patches are stacked,
  and a batched dot product with `kernel` applies a distinct weight matrix to
  each position.

  Arguments:
      inputs: (N+2)-D tensor with shape
          (batch_size, channels_in, d_in1, ..., d_inN)
          if data_format='channels_first', or
          (batch_size, d_in1, ..., d_inN, channels_in)
          if data_format='channels_last'.
      kernel: the unshared weight for N-D convolution,
          with shape (output_items, feature_dim, channels_out), where
          feature_dim = np.prod(kernel_size) * channels_in,
          output_items = np.prod(output_shape).
      kernel_size: a tuple of N integers, specifying the
          spatial dimensions of the N-D convolution window.
      strides: a tuple of N integers, specifying the strides
          of the convolution along the spatial dimensions.
      output_shape: a tuple of (d_out1, ..., d_outN) specifying the spatial
          dimensionality of the output.
      data_format: string, "channels_first" or "channels_last".
          `None` selects the value returned by `image_data_format()`.

  Returns:
      An (N+2)-D tensor with shape:
      (batch_size, channels_out) + output_shape
      if data_format='channels_first', or:
      (batch_size,) + output_shape + (channels_out,)
      if data_format='channels_last'.

  Raises:
      ValueError: if `data_format` is neither
      `channels_last` nor `channels_first`.
  """
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))

  kernel_dims = int_shape(kernel)
  feature_dim = kernel_dims[1]
  channels_out = kernel_dims[-1]
  ndims = len(output_shape)
  spatial_axes = list(range(ndims))
  channels_first = data_format == 'channels_first'

  def _patch_index(position):
    """Index selecting the input window that feeds output `position`."""
    window = [
        slice(position[d] * strides[d],
              position[d] * strides[d] + kernel_size[d])
        for d in spatial_axes
    ]
    # Batch (and channel) axes are always taken in full.
    if channels_first:
      return [slice(None), slice(None)] + window
    return [slice(None)] + window + [slice(None)]

  axis_ticks = [range(extent) for extent in output_shape]
  patches = [
      reshape(inputs[_patch_index(position)], (1, -1, feature_dim))
      for position in itertools.product(*axis_ticks)
  ]

  stacked = concatenate(patches, axis=0)
  result = batch_dot(stacked, kernel)
  result = reshape(result, output_shape + (-1, channels_out))

  # `result` is laid out as output_shape + (batch, channels_out); move the
  # batch axis to the front and place channels according to `data_format`.
  if channels_first:
    permutation = [ndims, ndims + 1] + spatial_axes
  else:
    permutation = [ndims] + spatial_axes + [ndims + 1]

  return permute_dimensions(result, permutation)
4839
4840
@keras_export('keras.backend.local_conv1d')
def local_conv1d(inputs, kernel, kernel_size, strides, data_format=None):
  """Apply 1D conv with un-shared weights.

  Thin wrapper around `local_conv` with a single spatial dimension whose
  output length is taken from the first dimension of `kernel`.

  Arguments:
      inputs: 3D tensor with shape:
          (batch_size, steps, input_dim)
          if data_format is "channels_last" or
          (batch_size, input_dim, steps)
          if data_format is "channels_first".
      kernel: the unshared weight for convolution,
          with shape (output_length, feature_dim, filters).
      kernel_size: a tuple of a single integer,
          specifying the length of the 1D convolution window.
      strides: a tuple of a single integer,
          specifying the stride length of the convolution.
      data_format: the data format, channels_first or channels_last.

  Returns:
      A 3D tensor with shape:
      (batch_size, filters, output_length)
      if data_format='channels_first'
      or 3D tensor with shape:
      (batch_size, output_length, filters)
      if data_format='channels_last'.
  """
  # The kernel holds one weight matrix per output position.
  output_shape = (kernel.shape[0],)
  return local_conv(inputs,
                    kernel,
                    kernel_size,
                    strides,
                    output_shape,
                    data_format)
4874
4875
@keras_export('keras.backend.local_conv2d')
def local_conv2d(inputs,
                 kernel,
                 kernel_size,
                 strides,
                 output_shape,
                 data_format=None):
  """Apply 2D conv with un-shared weights.

  All arguments are forwarded unchanged to `local_conv`.

  Arguments:
      inputs: 4D tensor with shape:
          (batch_size, filters, new_rows, new_cols)
          if data_format='channels_first'
          or 4D tensor with shape:
          (batch_size, new_rows, new_cols, filters)
          if data_format='channels_last'.
      kernel: the unshared weight for convolution,
          with shape (output_items, feature_dim, filters).
      kernel_size: a tuple of 2 integers, specifying the
          width and height of the 2D convolution window.
      strides: a tuple of 2 integers, specifying the strides
          of the convolution along the width and height.
      output_shape: a tuple with (output_row, output_col).
      data_format: the data format, channels_first or channels_last.

  Returns:
      A 4D tensor with shape:
      (batch_size, filters, new_rows, new_cols)
      if data_format='channels_first'
      or 4D tensor with shape:
      (batch_size, new_rows, new_cols, filters)
      if data_format='channels_last'.
  """
  return local_conv(
      inputs=inputs,
      kernel=kernel,
      kernel_size=kernel_size,
      strides=strides,
      output_shape=output_shape,
      data_format=data_format)
4915
4916
@keras_export('keras.backend.bias_add')
def bias_add(x, bias, data_format=None):
  """Adds a bias vector to a tensor.

  Arguments:
      x: Tensor or variable.
      bias: Bias tensor to add.
      data_format: string, `"channels_last"` or `"channels_first"`.
          `None` selects the value returned by `image_data_format()`.

  Returns:
      Output tensor.

  Raises:
      ValueError: In one of the two cases below:
                  1. invalid `data_format` argument.
                  2. invalid bias shape.
                     the bias should be either a vector or
                     a tensor with ndim(x) - 1 dimension
  """
  if data_format is None:
    data_format = image_data_format()
  if data_format not in {'channels_first', 'channels_last'}:
    raise ValueError('Unknown data_format: ' + str(data_format))
  bias_shape = int_shape(bias)
  x_rank = ndim(x)
  bias_rank = len(bias_shape)
  if bias_rank != 1 and bias_rank != x_rank - 1:
    # Report the rank that is actually accepted (ndim(x) - 1, not ndim(x)).
    raise ValueError(
        'Unexpected bias dimensions %d, expect to be 1 or %d dimensions' %
        (bias_rank, x_rank - 1))

  # Only 3D, 4D and 5D inputs get layout-aware handling.
  if x_rank not in (3, 4, 5):
    return nn.bias_add(x, bias)

  channels_first = data_format == 'channels_first'
  if bias_rank == 1:
    # 4D inputs use the native op whenever the layout is supported.
    if x_rank == 4 and (not channels_first or _has_nchw_support()):
      tf_data_format = 'NCHW' if channels_first else 'NHWC'
      return nn.bias_add(x, bias, data_format=tf_data_format)
    # Otherwise broadcast the vector along every non-channel axis, using a
    # shape of the same rank as `x`.
    if channels_first:
      broadcast_shape = (1, bias_shape[0]) + (1,) * (x_rank - 2)
    else:
      broadcast_shape = (1,) * (x_rank - 1) + (bias_shape[0],)
  elif channels_first:
    # Bias has ndim(x) - 1 dims: add a leading unit axis and list the last
    # bias dimension first.
    broadcast_shape = (1, bias_shape[-1]) + bias_shape[:-1]
  else:
    broadcast_shape = (1,) + bias_shape

  return x + reshape(bias, broadcast_shape)
4986
4987
4988# RANDOMNESS
4989
4990
@keras_export('keras.backend.random_normal')
def random_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
  """Returns a tensor with normal distribution of values.

  Arguments:
      shape: A tuple of integers, the shape of tensor to create.
      mean: A float, mean of the normal distribution to draw samples.
      stddev: A float, standard deviation of the normal distribution
          to draw samples.
      dtype: String, dtype of returned tensor. `None` selects `floatx()`.
      seed: Integer, random seed. `None` draws one with
          `np.random.randint(10e6)`.

  Returns:
      A tensor.
  """
  dtype = floatx() if dtype is None else dtype
  seed = np.random.randint(10e6) if seed is None else seed
  samples = random_ops.random_normal(
      shape, mean=mean, stddev=stddev, dtype=dtype, seed=seed)
  return samples
5012
5013
@keras_export('keras.backend.random_uniform')
def random_uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
  """Returns a tensor with uniform distribution of values.

  Arguments:
      shape: A tuple of integers, the shape of tensor to create.
      minval: A float, lower boundary of the uniform distribution
          to draw samples.
      maxval: A float, upper boundary of the uniform distribution
          to draw samples.
      dtype: String, dtype of returned tensor. `None` selects `floatx()`.
      seed: Integer, random seed. `None` draws one with
          `np.random.randint(10e6)`.

  Returns:
      A tensor.
  """
  dtype = floatx() if dtype is None else dtype
  seed = np.random.randint(10e6) if seed is None else seed
  samples = random_ops.random_uniform(
      shape, minval=minval, maxval=maxval, dtype=dtype, seed=seed)
  return samples
5036
5037
@keras_export('keras.backend.random_binomial')
def random_binomial(shape, p=0.0, dtype=None, seed=None):
  """Returns a tensor with random binomial distribution of values.

  Each entry is 1 where a uniform draw is `<= p` and 0 elsewhere.

  Arguments:
      shape: A tuple of integers, the shape of tensor to create.
      p: A float, `0. <= p <= 1`, probability of binomial distribution.
      dtype: String, dtype of returned tensor. `None` selects `floatx()`.
      seed: Integer, random seed. `None` draws one with
          `np.random.randint(10e6)`.

  Returns:
      A tensor.
  """
  dtype = floatx() if dtype is None else dtype
  seed = np.random.randint(10e6) if seed is None else seed
  draws = random_ops.random_uniform(shape, dtype=dtype, seed=seed)
  success = draws <= p
  ones = array_ops.ones(shape, dtype=dtype)
  zeros = array_ops.zeros(shape, dtype=dtype)
  return array_ops.where(success, ones, zeros)
5058
5059
@keras_export('keras.backend.truncated_normal')
def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
  """Returns a tensor with truncated random normal distribution of values.

  The generated values follow a normal distribution
  with specified mean and standard deviation,
  except that values whose magnitude is more than
  two standard deviations from the mean are dropped and re-picked.

  Arguments:
      shape: A tuple of integers, the shape of tensor to create.
      mean: Mean of the values.
      stddev: Standard deviation of the values.
      dtype: String, dtype of returned tensor. `None` selects `floatx()`.
      seed: Integer, random seed. `None` draws one with
          `np.random.randint(10e6)`.

  Returns:
      A tensor.
  """
  dtype = floatx() if dtype is None else dtype
  seed = np.random.randint(10e6) if seed is None else seed
  samples = random_ops.truncated_normal(
      shape, mean, stddev, dtype=dtype, seed=seed)
  return samples
5085
5086
5087# CTC
# TensorFlow has a native implementation, but it uses sparse tensors
# and therefore requires a wrapper for Keras. The functions below convert
# dense to sparse tensors and also wrap the beam search code that is
# in TensorFlow's CTC implementation.
5092
5093
@keras_export('keras.backend.ctc_label_dense_to_sparse')
def ctc_label_dense_to_sparse(labels, label_lengths):
  """Converts CTC labels from dense to sparse.

  A boolean mask is built that keeps, for each entry of `label_lengths`, the
  column positions smaller than that length. The (row, column) pairs that
  survive the mask become the indices of the sparse result, and the values
  are gathered from `labels` at those indices.

  Arguments:
      labels: dense CTC labels.
      label_lengths: length of the labels.

  Returns:
      A sparse tensor representation of the labels.
  """
  label_shape = array_ops.shape(labels)
  num_batches_tns = array_ops.stack([label_shape[0]])
  max_num_labels_tns = array_ops.stack([label_shape[1]])

  def _columns_below(_, length):
    # Row vector of column positions, compared against a vector filled with
    # the current length.
    columns = array_ops.expand_dims(math_ops.range(label_shape[1]), 0)
    threshold = array_ops.fill(max_num_labels_tns, length)
    return columns < threshold

  all_false = math_ops.cast(
      array_ops.fill([1, label_shape[1]], 0), dtypes_module.bool)
  keep = functional_ops.scan(
      _columns_below,
      label_lengths,
      initializer=all_false,
      parallel_iterations=1)
  # Drop the singleton axis contributed by each scan step.
  keep = keep[:, 0, :]

  # Column index of every position, then only the kept ones.
  tiled_columns = array_ops.tile(
      math_ops.range(0, label_shape[1]), num_batches_tns)
  column_grid = array_ops.reshape(tiled_columns, label_shape)
  kept_columns = array_ops.boolean_mask(column_grid, keep)

  # Row index of every position, then only the kept ones.
  tiled_rows = array_ops.tile(
      math_ops.range(0, label_shape[0]), max_num_labels_tns)
  row_grid = array_ops.transpose(
      array_ops.reshape(tiled_rows, reverse(label_shape, 0)))
  kept_rows = array_ops.boolean_mask(row_grid, keep)

  # Pair the row and column indices up as rows of a two-column matrix.
  paired = concatenate([kept_rows, kept_columns], axis=0)
  indices = array_ops.transpose(array_ops.reshape(paired, [2, -1]))

  values = array_ops.gather_nd(labels, indices)

  return sparse_tensor.SparseTensor(
      math_ops.cast(indices, dtypes_module.int64), values,
      math_ops.cast(label_shape, dtypes_module.int64))
5138
5139
@keras_export('keras.backend.ctc_batch_cost')
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
  """Runs CTC loss algorithm on each batch element.

  Arguments:
      y_true: tensor `(samples, max_string_length)`
          containing the truth labels.
      y_pred: tensor `(samples, time_steps, num_categories)`
          containing the prediction, or output of the softmax.
      input_length: tensor `(samples, 1)` containing the sequence length for
          each batch item in `y_pred`.
      label_length: tensor `(samples, 1)` containing the sequence length for
          each batch item in `y_true`.

  Returns:
      Tensor with shape (samples,1) containing the
          CTC loss of each element.
  """
  # Both length tensors lose their trailing axis and are cast to int32.
  squeezed_label_length = array_ops.squeeze(label_length, axis=-1)
  label_length = math_ops.cast(squeezed_label_length, dtypes_module.int32)
  squeezed_input_length = array_ops.squeeze(input_length, axis=-1)
  input_length = math_ops.cast(squeezed_input_length, dtypes_module.int32)

  sparse_labels = ctc_label_dense_to_sparse(y_true, label_length)
  sparse_labels = math_ops.cast(sparse_labels, dtypes_module.int32)

  # Swap the first two axes and take the log of the predictions; `epsilon()`
  # is added before the log.
  time_major = array_ops.transpose(y_pred, perm=[1, 0, 2])
  log_probs = math_ops.log(time_major + epsilon())

  loss = ctc.ctc_loss(
      inputs=log_probs, labels=sparse_labels, sequence_length=input_length)
  return array_ops.expand_dims(loss, 1)
5170
5171
@keras_export('keras.backend.ctc_decode')
def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
  """Decodes the output of a softmax.

  Can use either greedy search (also known as best path)
  or a constrained dictionary search.

  Arguments:
      y_pred: tensor `(samples, time_steps, num_categories)`
          containing the prediction, or output of the softmax.
      input_length: tensor `(samples, )` containing the sequence length for
          each batch item in `y_pred`.
      greedy: perform much faster best-path search if `true`.
          This does not use a dictionary.
      beam_width: if `greedy` is `false`: a beam search decoder will be used
          with a beam of this width.
      top_paths: if `greedy` is `false`,
          how many of the most probable paths will be returned.

  Returns:
      Tuple:
          List: if `greedy` is `true`, returns a list of one element that
              contains the decoded sequence.
              If `false`, returns the `top_paths` most probable
              decoded sequences.
              Important: blank labels are returned as `-1`.
          Tensor `(top_paths, )` that contains
              the log probability of each decoded sequence.
  """
  # Swap the first two axes and take the log of the predictions; `epsilon()`
  # is added before the log.
  time_major = array_ops.transpose(y_pred, perm=[1, 0, 2])
  y_pred = math_ops.log(time_major + epsilon())
  input_length = math_ops.cast(input_length, dtypes_module.int32)

  if not greedy:
    decoded, log_prob = ctc.ctc_beam_search_decoder(
        inputs=y_pred,
        sequence_length=input_length,
        beam_width=beam_width,
        top_paths=top_paths)
  else:
    decoded, log_prob = ctc.ctc_greedy_decoder(
        inputs=y_pred, sequence_length=input_length)

  # Densify every sparse result, filling unset positions with -1.
  decoded_dense = []
  for sparse_path in decoded:
    decoded_dense.append(
        sparse_ops.sparse_to_dense(
            sparse_path.indices,
            sparse_path.dense_shape,
            sparse_path.values,
            default_value=-1))
  return (decoded_dense, log_prob)
5219
5220
5221# HIGH ORDER FUNCTIONS
5222
5223
@keras_export('keras.backend.map_fn')
def map_fn(fn, elems, name=None, dtype=None):
  """Map the function fn over the elements elems and return the outputs.

  Delegates directly to `map_fn_lib.map_fn`.

  Arguments:
      fn: Callable that will be called upon each element in elems
      elems: tensor
      name: A string name for the map node in the graph
      dtype: Output data type.

  Returns:
      Tensor with dtype `dtype`.
  """
  mapped = map_fn_lib.map_fn(fn, elems, dtype=dtype, name=name)
  return mapped
5238
5239
@keras_export('keras.backend.foldl')
def foldl(fn, elems, initializer=None, name=None):
  """Reduce elems using fn to combine them from left to right.

  Delegates directly to `functional_ops.foldl`.

  Arguments:
      fn: Callable that will be called upon each element in elems and an
          accumulator, for instance `lambda acc, x: acc + x`
      elems: tensor
      initializer: The first value used (`elems[0]` in case of None)
      name: A string name for the foldl node in the graph

  Returns:
      Tensor with same type and shape as `initializer`.
  """
  folded = functional_ops.foldl(fn, elems, name=name, initializer=initializer)
  return folded
5255
5256
@keras_export('keras.backend.foldr')
def foldr(fn, elems, initializer=None, name=None):
  """Reduce elems using fn to combine them from right to left.

  Delegates directly to `functional_ops.foldr`.

  Arguments:
      fn: Callable that will be called upon each element in elems and an
          accumulator, for instance `lambda acc, x: acc + x`
      elems: tensor
      initializer: The first value used (`elems[-1]` in case of None)
      name: A string name for the foldr node in the graph

  Returns:
      Same type and shape as initializer
  """
  folded = functional_ops.foldr(fn, elems, name=name, initializer=initializer)
  return folded
5272
# Load Keras default configuration from config file if present.
# The Keras directory is taken from the KERAS_HOME env variable when it is
# set; otherwise it is `.keras` inside the user's home directory.
if 'KERAS_HOME' in os.environ:
  _keras_dir = os.environ.get('KERAS_HOME')
else:
  _keras_base_dir = os.path.expanduser('~')
  _keras_dir = os.path.join(_keras_base_dir, '.keras')
_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
  try:
    # Context manager so the file handle is closed even when parsing fails.
    with open(_config_path) as f:
      _config = json.load(f)
  except ValueError:
    # Unparseable config file: fall back to the current defaults below.
    _config = {}
  _floatx = _config.get('floatx', floatx())
  assert _floatx in {'float16', 'float32', 'float64'}
  _epsilon = _config.get('epsilon', epsilon())
  assert isinstance(_epsilon, float)
  _image_data_format = _config.get('image_data_format', image_data_format())
  assert _image_data_format in {'channels_last', 'channels_first'}
  set_floatx(_floatx)
  set_epsilon(_epsilon)
  set_image_data_format(_image_data_format)

# Save config file.
if not os.path.exists(_keras_dir):
  try:
    os.makedirs(_keras_dir)
  except OSError:
    # Except permission denied and potential race conditions
    # in multi-threaded environments.
    pass

if not os.path.exists(_config_path):
  _config = {
      'floatx': floatx(),
      'epsilon': epsilon(),
      'backend': 'tensorflow',
      'image_data_format': image_data_format()
  }
  try:
    with open(_config_path, 'w') as f:
      f.write(json.dumps(_config, indent=4))
  except IOError:
    # Except permission denied.
    pass
5319
5320
def in_multi_worker_mode():
  """Whether we are operating in a Multi-Worker setting.

  Reads the `TF_CONFIG` environment variable as JSON.

  Returns:
      A bool: `True` when `TF_CONFIG` is non-empty and the cluster it
      describes has no job named `'master'`, `False` otherwise.
  """
  tf_config = json.loads(os.environ.get('TF_CONFIG', '{}'))
  if not tf_config:
    # Return a real bool rather than leaking the empty config object.
    return False
  cluster_spec = server_lib.ClusterSpec(tf_config.get('cluster', {}))
  return 'master' not in cluster_spec.jobs
5326
5327
def configure_and_create_distributed_session(distribution_strategy):
  """Configure session config and create a session with it.

  The session is built by a local helper and installed with `set_session`.
  In multi-worker mode the helper is run through the distribute coordinator;
  otherwise it is called directly.

  Arguments:
      distribution_strategy: the strategy used to configure the session.
  """

  def _create_session(distribution_strategy):
    """Create the Distributed Strategy session."""
    config = get_default_session_config()

    # If a session already exists, merge in its config; in the case there is a
    # conflict, take values of the existing config.
    global _SESSION
    existing_session = getattr(_SESSION, 'session', None)
    if existing_session and existing_session._config:
      config.MergeFrom(existing_session._config)

    if is_tpu_strategy(distribution_strategy):
      # TODO(priyag, yuefengz): Remove this workaround when Distribute
      # Coordinator is integrated with keras and we can create a session from
      # there.
      distribution_strategy.configure(config)
      tpu_master = distribution_strategy.extended._tpu_cluster_resolver.master()  # pylint: disable=protected-access
      new_session = session_module.Session(config=config, target=tpu_master)
    else:
      worker_context = dc_context.get_current_worker_context()
      if not worker_context:
        distribution_strategy.configure(config)
        new_session = session_module.Session(config=config)
      else:
        coordinator_config = worker_context.session_config
        # Merge the default session config to the one from distribute
        # coordinator, which is fine for now since they don't have
        # conflicting configurations.
        coordinator_config.MergeFrom(config)
        new_session = session_module.Session(
            config=coordinator_config, target=worker_context.master_target)

    set_session(new_session)

  if not in_multi_worker_mode():
    _create_session(distribution_strategy)
    return

  dc.run_distribute_coordinator(
      _create_session,
      distribution_strategy,
      mode=dc.CoordinatorMode.INDEPENDENT_WORKER)
5371
5372
def is_tpu_strategy(strategy):
  """We're executing TPU Strategy.

  Arguments:
      strategy: a strategy object, or `None`.

  Returns:
      `True` when `strategy` is not `None` and its class is named
      `'TPUStrategy'`; `False` otherwise.
  """
  if strategy is None:
    return False
  # Compared by class name, so no TPU-specific import is needed here.
  return strategy.__class__.__name__ == 'TPUStrategy'
5376
5377
def cast_variables_to_tensor(tensors):
  """Replaces the variables in a nested structure with identity tensors.

  Arguments:
      tensors: a (possibly nested) structure of values.

  Returns:
      The same structure, where every `variables_module.Variable` instance
      has been passed through `array_ops.identity` and every other value is
      left untouched.
  """

  def _variable_to_tensor(item):
    if not isinstance(item, variables_module.Variable):
      return item
    return array_ops.identity(item)

  return nest.map_structure(_variable_to_tensor, tensors)
5386