1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Utilities related to layer/model functionality.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import functools
23import weakref
24
25import numpy as np
26import six
27
28from tensorflow.python.util import nest
29from tensorflow.python.util.tf_export import keras_export
30
31
@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  Output will always be a list of tensors
  (potentially with 1 element), with one exception: if `tensor` has no
  `_keras_history` attribute, it is returned unchanged.

  Args:
      tensor: The tensor to start from.
      layer: Origin layer of the tensor. Will be
          determined via tensor._keras_history if not provided.
      node_index: Origin node index of the tensor. Will be determined via
          tensor._keras_history if not provided. The explicit `layer` and
          `node_index` are only used when both are given.

  Returns:
      List of input tensors.
  """
  if not hasattr(tensor, '_keras_history'):
    # Nothing to trace back through; hand the value back as-is.
    return tensor

  # Fall back to the recorded history unless the caller supplied both parts.
  # (A truthiness test on `node_index` would ignore an explicit non-zero index
  # and would crash on `node_index=None` combined with an explicit `layer`.)
  if layer is None or node_index is None:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]

  node = layer._inbound_nodes[node_index]
  if node.is_input:
    # Reached an Input layer, stop recursion.
    return nest.flatten(node.input_tensors)

  source_tensors = []
  # ids of tensors already collected. `source_tensors` keeps every one of them
  # alive, so their ids cannot be reused while this call runs.
  seen_ids = set()
  for inbound_layer, inbound_node_index, _, inbound_tensor in (
      node.iterate_inbound()):
    for x in get_source_inputs(inbound_tensor, inbound_layer,
                               inbound_node_index):
      # Avoid input redundancy (identity-based).
      if id(x) not in seen_ids:
        seen_ids.add(id(x))
        source_tensors.append(x)
  return source_tensors
69
70
def validate_string_arg(input_data,
                        allowable_strings,
                        layer_name,
                        arg_name,
                        allow_none=False,
                        allow_callables=False):
  """Checks that a string-valued layer argument holds an accepted value.

  The value is accepted when it is `None` and `allow_none` is set, when it is
  callable and `allow_callables` is set, or when it is a string contained in
  `allowable_strings`.

  Args:
    input_data: The value received for the argument.
    allowable_strings: Collection of accepted string values.
    layer_name: Name of the layer, used in the error message.
    arg_name: Name of the argument, used in the error message.
    allow_none: Whether `None` is an accepted value.
    allow_callables: Whether any callable is an accepted value.

  Raises:
    ValueError: If `input_data` matches none of the accepted forms.
  """
  if ((allow_none and input_data is None) or
      (allow_callables and callable(input_data)) or
      (isinstance(input_data, six.string_types) and
       input_data in allowable_strings)):
    return

  # Build the human-readable list of what would have been accepted.
  pieces = []
  if allow_none:
    pieces.append('`None`, ')
  if allow_callables:
    pieces.append('a `Callable`, ')
  pieces.append('or one of the following values: %s' % (allowable_strings,))
  allowed_args = ''.join(pieces)
  raise ValueError(('The %s argument of layer %s received an invalid '
                    'value %s. Allowed values are: %s.') %
                   (arg_name, layer_name, input_data, allowed_args))
92
93
def count_params(weights):
  """Counts the scalar entries across a collection of weights.

  Each distinct weight object (compared by identity) is counted once. Unknown
  dimensions (`None` in `shape.as_list()`) contribute zero.

  Args:
      weights: An iterable containing the weights on which to compute params

  Returns:
      The total number of scalars composing the weights, as a Python int.
  """
  # Map id -> object; holding the object keeps its id from being reused.
  distinct = {}
  for weight in weights:
    distinct.setdefault(id(weight), weight)

  total = 0
  for weight in distinct.values():
    dims = weight.shape.as_list()
    total += np.prod([0 if dim is None else dim for dim in dims])
  return int(total)
109
110
def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Some models are printed as a three-column table (name, output shape,
  params): `Sequential` models, non-graph-network models, and graph networks
  whose layers form a single unshared chain. Other graph networks get a
  fourth "Connected to" column listing each layer's inbound connections.

  Args:
      model: Keras model instance.
      line_length: Total length of printed lines
          (e.g. set this to adapt the display to different
          terminal window sizes). Defaults to 65 for the three-column layout
          and 98 for the four-column layout.
      positions: Relative or absolute positions of log elements in each line.
          If the last entry is <= 1, the values are treated as fractions of
          `line_length`. If not provided, defaults to `[.45, .85, 1.]` for
          the three-column layout and `[.33, .55, .67, 1.]` for the
          four-column layout.
      print_fn: Print function to use.
          It will be called on each line of the summary.
          You can set it to a custom function
          in order to capture the string summary.
          It defaults to `print` (prints to stdout).
  """
  if print_fn is None:
    print_fn = print

  sequential_like = _is_sequential_like(model)

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    # Nodes that belong to this model; a layer's inbound nodes outside this
    # list are skipped when listing connections.
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v
  if positions[-1] <= 1:
    # Relative positions: convert them to absolute column offsets.
    positions = [int(line_length * p) for p in positions]

  def print_row(fields, positions):
    """Prints `fields` on one line, each cut/padded to end at its position."""
    line = ''
    for i, field in enumerate(fields):
      if i > 0:
        # Replace the previous column's last character with a space so
        # adjacent columns never run together.
        line = line[:-1] + ' '
      line += str(field)
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Args:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    name = layer.name
    cls_name = layer.__class__.__name__
    if not layer.built and not getattr(layer, '_is_graph_network', False):
      # If a subclassed model has a layer that is not called in Model.call, the
      # layer will not be built and we cannot call layer.count_params().
      params = '0 (unused)'
    else:
      params = layer.count_params()
    fields = [name + ' (' + cls_name + ')', output_shape, params]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Args:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown; same fallback as above.
      output_shape = '?'
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue

      for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    name = layer.name
    cls_name = layer.__class__.__name__
    first_connection = connections[0] if connections else ''
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    # Any further connections get their own rows under "Connected to".
    for connection in connections[1:]:
      print_row(['', '', '', connection], positions)

  layers = model.layers
  last_index = len(layers) - 1
  for i, layer in enumerate(layers):
    if sequential_like:
      print_layer_summary(layer)
    else:
      print_layer_summary_with_connections(layer)
    # The last layer row is closed with '=' to set off the totals below.
    if i == last_index:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)


def _is_sequential_like(model):
  """Returns True if `model` can be summarized as a plain stack of layers."""
  if model.__class__.__name__ == 'Sequential':
    return True
  if not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    return True

  nodes = []
  for v in model._nodes_by_depth.values():
    if (len(v) > 1) or (len(v) == 1 and
                        len(nest.flatten(v[0].keras_inputs)) > 1):
      # if the model has multiple nodes
      # or if the nodes have multiple inbound_layers
      # the model is no longer sequential
      return False
    nodes += v

  # search for shared layers: a layer with two inbound nodes in this model
  # is shared, which also rules out the sequential layout.
  for layer in model.layers:
    seen_one = False
    for node in layer._inbound_nodes:
      if node in nodes:
        if seen_one:
          return False
        seen_one = True
  return True
277
278
def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Only the kernel rows are reordered; any other weights of `dense` (such as
  the bias, when the layer has one) are written back unchanged via
  `dense.set_weights`.

  Args:
      dense: The target `Dense` layer. Its first weight is taken to be the
          kernel, of shape `(prod(previous_feature_map_shape), units)`.
      previous_feature_map_shape: A shape tuple of 3 integers,
          e.g. `(512, 7, 7)`. The shape of the convolutional
          feature map right before the `Flatten` layer that
          came before the target `Dense` layer, given in the *target*
          layout: `(channels, height, width)` when converting to
          "channels_first", `(height, width, channels)` when converting to
          "channels_last".
      target_data_format: One of "channels_last", "channels_first".
          Set it "channels_last"
          if converting a "channels_first" model to "channels_last",
          or reciprocally.

  Raises:
      ValueError: If `target_data_format` is not one of the supported values.
  """
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if target_data_format not in {'channels_last', 'channels_first'}:
    raise ValueError('`target_data_format` must be "channels_first" or '
                     '"channels_last", got: %r' % (target_data_format,))
  weights = dense.get_weights()
  kernel = weights[0]
  units = kernel.shape[1]
  if target_data_format == 'channels_first':
    c, h, w = previous_feature_map_shape
    # Kernel rows are currently flattened in (h, w, c) order; move the channel
    # axis to the front. The trailing axis indexes the output units.
    kernel = np.transpose(kernel.reshape((h, w, c, units)), (2, 0, 1, 3))
  else:
    h, w, c = previous_feature_map_shape
    # Kernel rows are currently flattened in (c, h, w) order; move the channel
    # axis to the back.
    kernel = np.transpose(kernel.reshape((c, h, w, units)), (1, 2, 0, 3))
  kernel = np.reshape(kernel, (h * w * c, units))
  # Write back every remaining weight untouched; a layer without a bias only
  # has the kernel, so do not assume exactly two weights.
  dense.set_weights([kernel] + list(weights[1:]))
316
317
def is_builtin_layer(layer):
  """Returns whether `layer` looks like an instance of an exported layer class.

  A layer counts as built-in when it carries a non-empty `_keras_api_names`
  attribute and neither that nor `_keras_api_names_v1` equals the generic
  `('keras.layers.Layer',)` export name.

  Args:
    layer: The object to inspect.

  Returns:
    A bool.
  """
  api_names = getattr(layer, '_keras_api_names', None)
  if not api_names:
    return False

  generic_name = ('keras.layers.Layer',)
  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  return not (api_names == generic_name or
              layer._keras_api_names_v1 == generic_name)
326
327
def cached_per_instance(f):
  """Lightweight decorator for caching lazily constructed properties.

  When to use:
  This decorator provides simple caching with minimal overhead. It is designed
  for properties which are expensive to compute and static over the life of a
  class instance, and provides no mechanism for cache invalidation. Thus it is
  best suited for lazily exposing derived properties of other static data.

  For classes with custom getattr / setattr behavior (such as trackable
  objects), storing cache results as object attributes is not performant.
  Instead, a specialized cache can significantly reduce property lookup
  overhead. (While still allowing the decorated property to be lazily computed.)
  Consider the following class:

  ```
  class MyClass(object):
    def __setattr__(self, key, value):
      # Some expensive class specific code
      # ...
      # ...

      super(MyClass, self).__setattr__(key, value)

    @property
    def thing(self):
      # `thing` is expensive to compute (and may not even be requested), so we
      # want to lazily compute it and then cache it.
      output = getattr(self, '_thing', None)
      if output is None:
        self._thing = output = compute_thing(self)
      return output
  ```

  It's also worth noting that ANY overriding of __setattr__, even something as
  simple as:
  ```
    def __setattr__(self, key, value):
      super(MyClass, self).__setattr__(key, value)
  ```

  slows down attribute assignment by nearly 10x.

  By contrast, replacing the definition of `thing` with the following sidesteps
  the expensive __setattr__ altogether:

  ```
  @property
  @cached_per_instance
  def thing(self):
    # `thing` is expensive to compute (and may not even be requested), so we
    # want to lazily compute it and then cache it.
    return compute_thing(self)
  ```

  Performance:
  The overhead for this decorator is ~0.4 us / call. A much lower overhead
  implementation (~0.085 us / call) can be achieved by using a custom dict type:

  ```
  def dict_based_cache(f):
    class Cache(dict):
      __slots__ = ()
      def __missing__(self, key):
        self[key] = output = f(key)
        return output

    return property(Cache().__getitem__)
  ```

  However, that implementation holds class instances as keys, and as a result
  blocks garbage collection. (And modifying it to use weakrefs as keys raises
  the lookup overhead to ~0.4 us) As a result, the WeakKeyDictionary
  implementation below turns out to be more prudent.

  Args:
    f: The function to cache. It is called with a single argument (the
      instance), which is used as a key in a `weakref.WeakKeyDictionary` and
      so must be hashable and weak-referenceable.

  Returns:
    f decorated with simple caching behavior. After the first successful call
    for an instance, its result -- including `None` -- is returned from the
    cache. The underlying dictionary is exposed as the `cache` attribute.
  """

  cache = weakref.WeakKeyDictionary()
  # Private sentinel marking a cache miss, so that a legitimately cached
  # `None` result is not recomputed on every access.
  missing = object()

  @functools.wraps(f)
  def wrapped(item):
    output = cache.get(item, missing)
    if output is missing:
      output = f(item)
      cache[item] = output
    return output

  wrapped.cache = cache
  return wrapped
421
422
def filter_empty_layer_containers(layer_list):
  """Filter out empty Layer-like containers and uniquify.

  Walks `layer_list` depth-first, in order. Objects that have a `_is_layer`
  attribute (and are not classes) are yielded, each at most once (compared by
  identity). Any other object is treated as a container: the contents of its
  `layers` attribute, if present and non-empty, are visited in its place. A
  container holding no layers therefore contributes nothing.

  Args:
    layer_list: An iterable (list, tuple, generator, ...) of layers and/or
      layer containers.

  Yields:
    The unique layers found, in depth-first order.
  """
  # TODO(b/130381733): Make this an attribute in base_layer.Layer.
  # id -> object for everything visited. Holding the object keeps it alive
  # so its id cannot be reused by a different object during the walk.
  visited = {}
  # Stack of objects still to inspect, reversed so that pop() returns them in
  # their original order. list() lets tuples and other iterables work too
  # (slicing a tuple would produce a tuple, which has no pop()).
  to_visit = list(layer_list)[::-1]
  while to_visit:
    obj = to_visit.pop()
    if id(obj) in visited:
      continue
    visited[id(obj)] = obj
    if hasattr(obj, '_is_layer') and not isinstance(obj, type):
      yield obj
    else:
      sub_layers = getattr(obj, 'layers', None) or []

      # Trackable data structures will not show up in ".layers" lists, but
      # the layers they contain will.
      to_visit.extend(list(sub_layers)[::-1])
441