1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Contains the `Node` class."""
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import collections
22import copy
23import json
24import numpy as np
25
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_util
28from tensorflow.python.keras import backend
29from tensorflow.python.keras.engine import base_layer_utils
30from tensorflow.python.keras.engine import keras_tensor
31from tensorflow.python.keras.saving.saved_model import json_utils
32from tensorflow.python.keras.utils import tf_utils
33from tensorflow.python.util import nest
34
# Placeholder written in the layer-name slot of serialized first-argument
# node data when an element of the first call argument is not a Keras tensor
# (emitted by `Node.serialize` as `[_CONSTANT_VALUE, -1, value, kwargs]`).
_CONSTANT_VALUE = '_CONSTANT_VALUE'
36
37
class Node(object):
  """A `Node` describes the connectivity between two layers.

  Each time a layer is connected to some new input,
  a node is added to `layer._inbound_nodes`.
  Each time the output of a layer is used by another layer,
  a node is added to `layer._outbound_nodes`.

  Args:
      layer: The Layer for the Layer.__call__ this node represents.
      call_args: The positional arguments the Layer was called with.
      call_kwargs: The keyword arguments the Layer was called with.
      outputs: The outputs of the Layer.__call__
  """

  def __init__(self,
               layer,
               call_args=None,
               call_kwargs=None,
               outputs=None):
    call_args = [] if call_args is None else call_args
    call_kwargs = {} if call_kwargs is None else call_kwargs
    outputs = [] if outputs is None else outputs

    self.layer = layer
    # A node with neither positional nor keyword arguments is an input node.
    self.is_input = not call_args and not call_kwargs

    # These arguments are user-provided. Copy the structures here so that
    # future user modifications do not affect the node's metadata.
    # We copy using map_structure rather than python's shallow or deep copy,
    # because the args can be data structures (so shallow copy is
    # insufficient), but individual values might not support copy.copy
    # or be too expensive to deep copy.
    call_args = nest.map_structure(lambda t: t, call_args)
    call_kwargs = nest.map_structure(lambda t: t, call_kwargs)
    self.outputs = nest.map_structure(lambda t: t, outputs)
    self.call_args = call_args
    self.call_kwargs = call_kwargs

    # Cached for performance.
    self._flat_arguments = nest.flatten((self.call_args, self.call_kwargs))
    # Used to avoid expensive `nest` operations in the most common case.
    self._single_positional_tensor_passed = (not self.call_kwargs and len(
        self.call_args) == 1 and tensor_util.is_tf_type(self.call_args[0]))

    if not keras_tensor.keras_tensors_enabled():
      # Create TensorFlowOpLayers if needed.
      for obj in self._flat_arguments:
        if (isinstance(obj, ops.Tensor) and
            base_layer_utils.needs_keras_history(
                obj, ignore_call_context=True)):
          base_layer_utils.create_keras_history(obj)

    # Record every Keras tensor among the flattened arguments, together with
    # its `id` (the key later used into `tensor_dict`) and its flat position.
    self._keras_inputs = []
    self._keras_inputs_ids_and_indices = []
    for i, ele in enumerate(self._flat_arguments):
      if is_keras_tensor(ele):
        self._keras_inputs.append(ele)
        self._keras_inputs_ids_and_indices.append((str(id(ele)), i))

    # Wire up Node to Layers.
    self.layer._inbound_nodes.append(self)
    for kt in self.keras_inputs:
      inbound_layer = kt._keras_history.layer
      if inbound_layer is not None:  # `None` for `Input` tensors.
        inbound_layer._outbound_nodes.append(self)

    # Set metadata on outputs.
    node_index = len(self.layer._inbound_nodes) - 1
    for i, tensor in enumerate(nest.flatten(outputs)):
      tensor._keras_history = KerasHistory(
          layer=layer, node_index=node_index, tensor_index=i)

    # Cached for performance.
    self.flat_input_ids = [str(id(t)) for t in self._keras_inputs]
    self.flat_output_ids = [str(id(t)) for t in nest.flatten(self.outputs)]

  @property
  def keras_inputs(self):
    """Tensors input to this node that can be traced back to a `keras.Input`."""
    return self._keras_inputs

  @property
  def parent_nodes(self):
    """Returns all the `Node`s whose output this node immediately depends on."""
    node_deps = []
    for kt in self.keras_inputs:
      layer = kt._keras_history.layer
      node_index = kt._keras_history.node_index
      if layer is not None:  # `None` for `Input` tensors.
        node_deps.append(layer._inbound_nodes[node_index])
    return node_deps

  def iterate_inbound(self):
    """Yields tuples representing the data inbound from other nodes.

    Yields:
      tuples like: (inbound_layer, node_index, tensor_index, tensor).
    """
    for kt in self.keras_inputs:
      keras_history = kt._keras_history
      layer = keras_history.layer
      node_index = keras_history.node_index
      tensor_index = keras_history.tensor_index
      yield layer, node_index, tensor_index, kt

  def map_arguments(self, tensor_dict):
    """Maps Keras Tensors to computed Tensors using `tensor_dict`.

    Args:
      tensor_dict: Mapping from `str(id(keras_tensor))` to a list of computed
        values; one value is popped from the list for each use.

    Returns:
      An `(args, kwargs)` pair with the same structure as the call arguments,
      with every Keras tensor replaced by its computed value.
    """
    # The fast path additionally requires that the single positional tensor
    # was recorded as a Keras input; otherwise indexing `[0]` would raise
    # `IndexError`. The general path handles that case by returning the
    # arguments unchanged.
    if (self._single_positional_tensor_passed and
        self._keras_inputs_ids_and_indices):
      # Performance optimization for most common case.
      kt_id, _ = self._keras_inputs_ids_and_indices[0]
      return (tensor_dict[kt_id].pop(),), {}

    flat_arguments = copy.copy(self._flat_arguments)
    for kt_id, kt_index in self._keras_inputs_ids_and_indices:
      flat_arguments[kt_index] = tensor_dict[kt_id].pop()

    args, kwargs = nest.pack_sequence_as((self.call_args, self.call_kwargs),
                                         flat_arguments)
    return args, kwargs

  def serialize(self, make_node_key, node_conversion_map):
    """Serializes `Node` for Functional API's `get_config`.

    Args:
      make_node_key: Callable taking `(layer_name, node_index)` and returning
        the key used to look up `node_conversion_map`.
      node_conversion_map: Mapping from node keys to the node index written
        into the config. Keys missing from the map serialize as index 0.

    Returns:
      The serialized first call argument, passed through
      `tf_utils.convert_inner_node_data`.

    Raises:
      TypeError: If the arguments other than the first one cannot be
        serialized to JSON.
    """
    # Serialization still special-cases first argument.
    args, kwargs = self.call_args, self.call_kwargs
    inputs, args, kwargs = self.layer._split_out_first_arg(args, kwargs)

    # Treat everything other than first argument as a kwarg.
    arguments = dict(zip(self.layer._call_fn_args[1:], args))
    arguments.update(kwargs)
    kwargs = arguments

    def _serialize_history(kh):
      """Returns `[layer_name, new_node_index, tensor_index]` for `kh`."""
      node_key = make_node_key(kh.layer.name, kh.node_index)
      new_node_index = node_conversion_map.get(node_key, 0)
      return [kh.layer.name, new_node_index, kh.tensor_index]

    def _serialize_keras_tensor(t):
      """Serializes a single Tensor passed to `call`."""
      if is_keras_tensor(t):
        return _serialize_history(t._keras_history)

      if isinstance(t, np.ndarray):
        return t.tolist()

      if isinstance(t, ops.Tensor):
        return backend.get_value(t).tolist()

      return t

    kwargs = nest.map_structure(_serialize_keras_tensor, kwargs)
    try:
      json.dumps(kwargs, default=json_utils.get_json_type)
    except TypeError:
      kwarg_types = nest.map_structure(type, kwargs)
      raise TypeError('Layer ' + self.layer.name +
                      ' was passed non-JSON-serializable arguments. ' +
                      'Arguments had types: ' +
                      str(kwarg_types) + '. They cannot be serialized out '
                      'when saving the model.')

    # `kwargs` is added to each Tensor in the first arg. This should be
    # changed in a future version of the serialization format.
    def serialize_first_arg_tensor(t):
      if is_keras_tensor(t):
        data = _serialize_history(t._keras_history) + [kwargs]
      else:
        # If an element in the first call argument did not originate as a
        # keras tensor and is a constant value, we save it using the format
        # ['_CONSTANT_VALUE', -1, serialized_tensor_or_python_constant]
        # (potentially including serialized kwargs in an optional 4th
        # argument).
        data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
      return tf_utils.ListWrapper(data)

    data = nest.map_structure(serialize_first_arg_tensor, inputs)
    if (not nest.is_nested(data) and
        not self.layer._preserve_input_structure_in_config):
      data = [data]
    data = tf_utils.convert_inner_node_data(data)
    return data

  #############################################################
  # Properties for Backwards compatibility.
  # These only check the first input argument
  # As nodes are internal, they may be removed in the future.
  #############################################################

  @property
  def input_tensors(self):
    """The first positional call argument (`[self.outputs]` for input nodes)."""
    if self.is_input:
      return [self.outputs]  # Used in `Layer.input`.
    return self.call_args[0]

  @property
  def output_tensors(self):
    """The node's outputs (wrapped in a list for input nodes)."""
    if self.is_input:
      return [self.outputs]  # Used in `Layer.input`.
    return self.outputs

  @property
  def input_shapes(self):
    """Shapes of `input_tensors`; a one-element list/tuple is unwrapped."""
    input_tensors = self.input_tensors
    input_shapes = nest.map_structure(backend.int_shape, input_tensors)
    # Decide whether to unwrap from the *input* structure rather than the
    # mapped shapes. For a single tensor the mapped result is its shape
    # tuple, so `len(input_shapes) == 1` would wrongly unwrap a rank-1 shape
    # (and fail on an unknown-rank `None` shape). A single-entry dict cannot
    # be indexed with `[0]`.
    if (not self.is_input and isinstance(input_tensors, (list, tuple)) and
        len(input_tensors) == 1):
      return input_shapes[0]
    return input_shapes

  @property
  def output_shapes(self):
    """Shapes of `output_tensors`, in the same nested structure."""
    return nest.map_structure(backend.int_shape, self.output_tensors)

  @property
  def outbound_layer(self):
    """The layer whose call this node represents."""
    return self.layer

  @property
  def inbound_layers(self):
    """Layers that produced the first call argument (`[]` for input nodes)."""
    if self.is_input:
      return []
    inbound_layers = nest.map_structure(lambda t: t._keras_history.layer,
                                        self.call_args[0])
    return inbound_layers
264
265
class KerasHistory(
    collections.namedtuple('KerasHistory', 'layer node_index tensor_index')):
  """Records which Layer call produced a given Tensor.

  While a Keras Graph Network is being built, every Tensor emitted by a Layer
  (beginning with an `InputLayer`) is tagged with one of these records. The
  `keras.engine.Network` class later walks these records to reconstruct how
  the network was wired together.

  Attributes:
    layer: The Layer whose call produced the Tensor.
    node_index: Which call of `layer` produced the Tensor. A Layer may be
      called several times to share weights, and each call creates a new
      node.
    tensor_index: Position of the Tensor among the call's outputs. This is 0
      when the Layer produces a single output. Nested output structures are
      numbered deterministically in `nest.flatten` order.
  """
  # An empty `__slots__` keeps instances free of a per-instance `__dict__`,
  # so the subclass stays as compact and fast as the plain namedtuple.
  __slots__ = ()
289
290
def is_keras_tensor(obj):
  """Returns True if `obj` has a `_keras_history` attribute, else False."""
  try:
    obj._keras_history  # pylint: disable=pointless-statement
  except AttributeError:
    return False
  return True
293