1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""## Functions for working with arbitrarily nested sequences of elements.
17
18This module can perform operations on nested structures. A nested structure is a
19Python collection that can contain further collections as well as other objects
20called atoms. Note that numpy arrays are considered atoms.
21
nest recognizes the following types of collections:
  1. tuple
  2. namedtuple
  3. dict
  4. OrderedDict
  5. MutableMapping
  6. attr.s
  7. list
29
30attr.s decorated classes (http://www.attrs.org) are also supported, in the
31same way as `namedtuple`.
32
33The utilities here assume (and do not check) that the nested structures form a
34'tree', i.e., no references in the structure of the input of these functions
35should be recursive.
36
37Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0),
38  (np.array([3, 4]), tf.constant([3, 4])))`
39"""
40
41from __future__ import absolute_import
42from __future__ import division
43from __future__ import print_function
44
45import collections as _collections
46
47import six as _six
48import wrapt as _wrapt
49
50from tensorflow.python.platform import tf_logging
51from tensorflow.python.util import _pywrap_nest
52from tensorflow.python.util import _pywrap_utils
53from tensorflow.python.util.compat import collections_abc as _collections_abc
54from tensorflow.python.util.tf_export import tf_export
55
56
# Error-message templates, filled in with `str.format`. NOTE(review):
# presumably consumed by the shallow-structure validation helpers of this
# module — confirm.

# Positional `{}`: the shallow_tree keys that are missing from input_tree.
_SHALLOW_TREE_HAS_INVALID_KEYS = (
    "The shallow_tree's keys are not a subset of the input_tree's keys. The "
    "shallow_tree has the following keys that are not in the input_tree: {}.")

# Named fields: `input_type`, `shallow_type`.
_STRUCTURES_HAVE_MISMATCHING_TYPES = (
    "The two structures don't have the same sequence type. Input structure has "
    "type {input_type}, while shallow structure has type {shallow_type}.")

# Named fields: `input_length`, `shallow_length`.
_STRUCTURES_HAVE_MISMATCHING_LENGTHS = (
    "The two structures don't have the same sequence length. Input "
    "structure has length {input_length}, while shallow structure has length "
    "{shallow_length}."
)

# Named fields: `input_size`, `shallow_size`.
_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = (
    "The input_tree has fewer elements than the shallow_tree. Input structure "
    "has length {input_size}, while shallow structure has length "
    "{shallow_size}.")

# Positional `{}`: the type of the input value.
_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = (
    "If shallow structure is a sequence, input must also be a sequence. "
    "Input has type: {}.")
79
80
def _get_attrs_items(obj):
  """Returns a list of (name, value) pairs from an attrs instance.

  The pairs follow the order of the class's `__attrs_attrs__` (which attrs
  populates in attribute definition order); they are NOT sorted by name.
  Keeping this order is what allows `_sequence_like` to rebuild an attrs
  instance by passing the values positionally to its constructor.

  Args:
    obj: an instance of an `attr.s`-decorated class.

  Returns:
    A list of (attr_name, attr_value) pairs, in `__attrs_attrs__` order.
  """
  # Look the attributes up via `obj.__class__` rather than `type(obj)`; the
  # two can differ for proxy objects that forward `__class__`.
  attrs = obj.__class__.__attrs_attrs__
  return [(a.name, getattr(obj, a.name)) for a in attrs]
95
96
def _sorted(dict_):
  """Sorts the keys of `dict_`, failing loudly if they cannot be ordered.

  Args:
    dict_: a mapping whose keys should be sorted.

  Returns:
    A list holding the keys of `dict_` in ascending order.

  Raises:
    TypeError: if the keys are not mutually comparable.
  """
  try:
    ordered_keys = sorted(dict_.keys())
  except TypeError:
    raise TypeError("nest only supports dicts with sortable keys.")
  return ordered_keys
103
104
def _is_namedtuple(instance, strict=False):
  """Checks whether `instance` is a namedtuple.

  Args:
    instance: any Python object.
    strict: when True, only a "plain" namedtuple qualifies, so instances of a
      class that merely inherits from a namedtuple are rejected. When False
      (the default), such subclasses are accepted as well.

  Returns:
    Whether `instance` counts as a namedtuple under the chosen strictness.
  """
  check = _pywrap_utils.IsNamedtuple
  return check(instance, strict)
119
120
# Thin aliases for native type predicates exported by `_pywrap_utils`. Call
# sites in this module pass a single object and branch on the result.
# These were historically documented in the swig file (util.i); NOTE(review):
# confirm where the current bindings are documented.
_is_mapping_view = _pywrap_utils.IsMappingView
_is_attrs = _pywrap_utils.IsAttrs
_is_composite_tensor = _pywrap_utils.IsCompositeTensor
_is_type_spec = _pywrap_utils.IsTypeSpec
_is_mutable_mapping = _pywrap_utils.IsMutableMapping
_is_mapping = _pywrap_utils.IsMapping
128
129
@tf_export("__internal__.nest.is_attrs", v1=[])
def is_attrs(obj):
  """Returns True if `obj` is an instance of an `attr.s`-decorated class.

  Args:
    obj: any Python object.

  Returns:
    The result of the native attrs check applied to `obj`.
  """
  answer = _is_attrs(obj)
  return answer
134
135
@tf_export("__internal__.nest.is_mapping", v1=[])
def is_mapping(obj):
  """Returns True if `obj` is a `collections.abc.Mapping`.

  Args:
    obj: any Python object.

  Returns:
    The result of the native mapping check applied to `obj`.
  """
  answer = _is_mapping(obj)
  return answer
140
141
@tf_export("__internal__.nest.sequence_like", v1=[])
def _sequence_like(instance, args):
  """Converts the sequence `args` to the same type as `instance`.

  Args:
    instance: the container whose type is reproduced. Kinds are checked in
      this order: mutable mappings (e.g. `dict`, `collections.OrderedDict`,
      `collections.defaultdict`), other `Mapping`s, mapping views
      (`keys()`/`values()`/`items()`), namedtuples and `attr.s` instances,
      `CompositeTensor`s, `TypeSpec`s, `range`, `wrapt.ObjectProxy` wrappers,
      and finally any type whose constructor accepts a single iterable
      (e.g. `tuple` or `list`).
    args: elements to be converted to the `instance` type. For mappings they
      must be given in the sorted order of `instance`'s keys. For composite
      tensors and type specs `args` must hold exactly one element, which is
      passed to `_from_components`.

  Returns:
    `args` with the type of `instance`. Mapping views and `range` objects
    cannot be rebuilt directly and come back as a `list`. Proxied namedtuples
    and attrs instances come back as the bare wrapped type (not re-wrapped).

  Raises:
    TypeError: if the keys of a mapping are not sortable, or if a
      non-mutable `Mapping` type cannot be constructed from an iterable of
      key-value pairs.
  """
  if _is_mutable_mapping(instance):
    # Pack dictionaries in a deterministic order by sorting the keys.
    # Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    result = dict(zip(_sorted(instance), args))
    instance_type = type(instance)
    if instance_type == _collections.defaultdict:
      # A bare `defaultdict()` would drop the factory, so carry it over. Only
      # the exact `defaultdict` type is special-cased; subclasses go through
      # `instance_type()` below.
      d = _collections.defaultdict(instance.default_factory)
    else:
      d = instance_type()
    # Insert in `instance`'s own iteration order so the output keeps the
    # input's key order (e.g. for `OrderedDict`).
    for key in instance:
      d[key] = result[key]
    return d
  elif _is_mapping(instance):
    result = dict(zip(_sorted(instance), args))
    instance_type = type(instance)
    tf_logging.log_first_n(
        tf_logging.WARN, "Mapping types may not work well with tf.nest. Prefer"
        " using MutableMapping for {}".format(instance_type), 1)
    try:
      return instance_type((key, result[key]) for key in instance)
    except TypeError as err:
      raise TypeError("Error creating an object of type {} like {}. Note that "
                      "it must accept a single positional argument "
                      "representing an iterable of key-value pairs, in "
                      "addition to self. Cause: {}".format(
                          type(instance), instance, err))
  elif _is_mapping_view(instance):
    # We can't directly construct mapping views, so we create a list instead
    return list(args)
  elif _is_namedtuple(instance) or _is_attrs(instance):
    # For proxies, rebuild the wrapped type; the result is not re-wrapped.
    if isinstance(instance, _wrapt.ObjectProxy):
      instance_type = type(instance.__wrapped__)
    else:
      instance_type = type(instance)
    # Fields are passed positionally, in flattening order.
    return instance_type(*args)
  elif _is_composite_tensor(instance):
    assert len(args) == 1
    spec = instance._type_spec  # pylint: disable=protected-access
    return spec._from_components(args[0])  # pylint: disable=protected-access
  elif _is_type_spec(instance):
    # Pack a CompositeTensor's components according to a TypeSpec.
    assert len(args) == 1
    return instance._from_components(args[0])  # pylint: disable=protected-access
  elif isinstance(instance, _six.moves.range):
    # A `range` cannot hold arbitrary values; fall back to a list.
    return _sequence_like(list(instance), args)
  elif isinstance(instance, _wrapt.ObjectProxy):
    # For object proxies, first create the underlying type and then re-wrap it
    # in the proxy type.
    return type(instance)(_sequence_like(instance.__wrapped__, args))
  else:
    # Any other sequence type (e.g. tuple or list): rebuild from the iterable.
    return type(instance)(args)
210
211
def _yield_value(iterable):
  """Yields only the values of `iterable`, in `_yield_sorted_items` order."""
  pairs = _yield_sorted_items(iterable)
  for _key, value in pairs:
    yield value
215
216
def _yield_sorted_items(iterable):
  """Yield (key, value) pairs for `iterable` in a deterministic order.

  The key and the iteration order depend on the kind of `iterable`:

  * `list`, plain `tuple`, and any type not listed below: the key is the
    element's index, and pairs come in index order.
  * Mappings (including `dict` and `OrderedDict`): the key is the dictionary
    key, and pairs come in sorted key order regardless of insertion order.
  * `attr.s` instances: the key is the attribute name, and pairs come in the
    order of the class's `__attrs_attrs__` (not sorted by name).
  * namedtuples: the key is the field name, and pairs come in `_fields` order
    (not sorted by name).
  * `CompositeTensor`s and `TypeSpec`s: a single pair is yielded. Its key is
    the `__name__` of the spec's `value_type`, and its value is the
    components (or the component specs, respectively).

  Args:
    iterable: an iterable, or a composite tensor / type spec.

  Yields:
    The iterable's (key, value) pairs, in the order described above.

  Raises:
    TypeError: if `iterable` is a mapping whose keys are not sortable.
  """
  # Ordered to check common structure types (list, tuple, dict) first.
  if isinstance(iterable, list):
    for item in enumerate(iterable):
      yield item
  # namedtuples handled separately to avoid expensive namedtuple check.
  elif type(iterable) == tuple:  # pylint: disable=unidiomatic-typecheck
    for item in enumerate(iterable):
      yield item
  elif isinstance(iterable, (dict, _collections_abc.Mapping)):
    # Iterate through dictionaries in a deterministic order by sorting the
    # keys. Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    for key in _sorted(iterable):
      yield key, iterable[key]
  elif _is_attrs(iterable):
    # `__attrs_attrs__` order, matching the positional rebuild in
    # `_sequence_like`.
    for item in _get_attrs_items(iterable):
      yield item
  elif _is_namedtuple(iterable):
    # Field declaration order, matching the positional rebuild in
    # `_sequence_like`.
    for field in iterable._fields:
      yield field, getattr(iterable, field)
  elif _is_composite_tensor(iterable):
    type_spec = iterable._type_spec  # pylint: disable=protected-access
    yield type_spec.value_type.__name__, type_spec._to_components(iterable)  # pylint: disable=protected-access
  elif _is_type_spec(iterable):
    # Note: to allow CompositeTensors and their TypeSpecs to have matching
    # structures, we need to use the same key string here.
    yield iterable.value_type.__name__, iterable._component_specs  # pylint: disable=protected-access
  else:
    for item in enumerate(iterable):
      yield item
264
265
# Native "is this nested?" predicate that does not expand composite tensors;
# `is_nested` forwards to it. See the swig file (util.i) for documentation.
is_sequence = _pywrap_utils.IsSequence


# Variant of `is_sequence` that, judging by its name, also treats composite
# tensors as nested; `_pack_sequence_as` selects it when `expand_composites`
# is true. See the swig file (util.i) for documentation.
is_sequence_or_composite = _pywrap_utils.IsSequenceOrComposite
272
273
@tf_export("nest.is_nested")
def is_nested(seq):
  """Returns true if its input is a nested structure.

  Lists, tuples (including namedtuples), dicts and other mappings, mapping
  views such as `dict.keys()`, and `attr.s` instances are nested structures.
  Strings, sets, and tensors are not, even though some of them are iterable.

    >>> tf.nest.is_nested("1234")
    False

    >>> tf.nest.is_nested([1, 3, [4, 5]])
    True

    >>> tf.nest.is_nested(((7, 8), (5, 6)))
    True

    >>> tf.nest.is_nested([])
    True

    >>> tf.nest.is_nested({"a": 1, "b": 2})
    True

    >>> tf.nest.is_nested({"a": 1, "b": 2}.keys())
    True

    >>> tf.nest.is_nested({"a": 1, "b": 2}.values())
    True

    >>> tf.nest.is_nested({"a": 1, "b": 2}.items())
    True

    >>> tf.nest.is_nested(set([1, 2]))
    False

    >>> ones = tf.ones([2, 3])
    >>> tf.nest.is_nested(ones)
    False

  Args:
    seq: the value to test.

  Returns:
    True if `seq` is a nested structure as described above, False otherwise.
  """
  return is_sequence(seq)
317
318
@tf_export("nest.flatten")
def flatten(structure, expand_composites=False):
  """Returns a flat list from a given nested structure.

  If `structure` is not a nested structure (a list, tuple or namedtuple, dict
  or other mapping, `attr.s` instance, or, when `expand_composites` is true, a
  composite tensor), then this returns a single-element list: `[structure]`.
  In particular, `None` flattens to `[None]`.

  This is the inverse of the `nest.pack_sequence_as` method that takes in a
  flattened list and re-packs it into the nested structure.

  In the case of dict instances, the sequence consists of the values, sorted by
  key to ensure deterministic behavior. This is true also for OrderedDict
  instances: their sequence order is ignored, the sorting order of keys is used
  instead. The same convention is followed in `nest.pack_sequence_as`. This
  correctly repacks dicts and OrderedDicts after they have been flattened, and
  also allows flattening an OrderedDict and then repacking it back using a
  corresponding plain dict, or vice-versa. Dictionaries with non-sortable keys
  cannot be flattened.

  Users must not modify any collections used in nest while this function is
  running.

  Examples:

  1. Python dict (ordered by key):

    >>> dict = { "key3": "value3", "key1": "value1", "key2": "value2" }
    >>> tf.nest.flatten(dict)
    ['value1', 'value2', 'value3']

  2. For a nested python tuple:

    >>> tuple = ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)
    >>> tf.nest.flatten(tuple)
        [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]

  3. For a nested dictionary of dictionaries:

    >>> dict = { "key3": {"c": (1.0, 2.0), "a": (3.0)},
    ... "key1": {"m": "val1", "g": "val2"} }
    >>> tf.nest.flatten(dict)
    ['val2', 'val1', 3.0, 1.0, 2.0]

  4. Numpy array (will not flatten):

    >>> array = np.array([[1, 2], [3, 4]])
    >>> tf.nest.flatten(array)
        [array([[1, 2],
                [3, 4]])]

  5. `tf.Tensor` (will not flatten):

    >>> tensor = tf.constant([[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]])
    >>> tf.nest.flatten(tensor)
        [<tf.Tensor: shape=(3, 3), dtype=float32, numpy=
          array([[1., 2., 3.],
                 [4., 5., 6.],
                 [7., 8., 9.]], dtype=float32)>]

  6. `tf.RaggedTensor`: This is a composite tensor whose representation consists
  of a flattened list of 'values' and a list of 'row_splits' which indicate how
  to chop up the flattened list into different rows. For more details on
  `tf.RaggedTensor`, please visit
  https://www.tensorflow.org/api_docs/python/tf/RaggedTensor.

  With `expand_composites=False`, we just return the RaggedTensor as is.

    >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
    >>> tf.nest.flatten(tensor, expand_composites=False)
    [<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2]]>]

  With `expand_composites=True`, we return the component Tensors that make up
  the RaggedTensor representation (the values and row_splits tensors)

    >>> tensor = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2]])
    >>> tf.nest.flatten(tensor, expand_composites=True)
    [<tf.Tensor: shape=(7,), dtype=int32, numpy=array([3, 1, 4, 1, 5, 9, 2],
                                                      dtype=int32)>,
     <tf.Tensor: shape=(4,), dtype=int64, numpy=array([0, 4, 4, 7])>]

  Args:
    structure: an arbitrarily nested structure. Note, numpy arrays are
      considered atoms and are not flattened.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A Python list, the flattened version of the input.

  Raises:
    TypeError: The nest is or contains a dict with non-sortable keys.
  """
  # `None` is an atom; short-circuit it without calling into native code.
  if structure is None:
    return [None]
  # The native binding expects a real bool.
  expand_composites = bool(expand_composites)
  return _pywrap_utils.Flatten(structure, expand_composites)
417
418
# Native namedtuple comparison helper. NOTE(review): presumably reports whether
# two namedtuples share the same type name and fields (the rule described in
# `assert_same_structure`'s docstring) — confirm against the binding.
# See the swig file (util.i) for documentation.
_same_namedtuples = _pywrap_utils.SameNamedtuples
421
422
class _DotString(object):
  """Stand-in object whose `str()` and `repr()` are both a lone ".".

  `assert_same_structure` maps every atom to this object so that printing a
  structure shows only its shape, e.g. `(., [., .])`.
  """

  __slots__ = ()

  def __str__(self):
    return "."

  __repr__ = __str__


_DOT = _DotString()
435
436
@tf_export("nest.assert_same_structure")
def assert_same_structure(nest1, nest2, check_types=True,
                          expand_composites=False):
  """Asserts that two structures are nested in the same way.

  Note the method does not check the types of data inside the structures.

  On failure, the raised error's message is extended with a rendering of each
  entire structure in which every atom is shown as `.`.

  Examples:

  * These scalar vs. scalar comparisons will pass:

    >>> tf.nest.assert_same_structure(1.5, tf.Variable(1, tf.uint32))
    >>> tf.nest.assert_same_structure("abc", np.array([1, 2]))

  * These sequence vs. sequence comparisons will pass:

    >>> structure1 = (((1, 2), 3), 4, (5, 6))
    >>> structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
    >>> structure3 = [(("a", "b"), "c"), "d", ["e", "f"]]
    >>> tf.nest.assert_same_structure(structure1, structure2)
    >>> tf.nest.assert_same_structure(structure1, structure3, check_types=False)

    >>> import collections
    >>> tf.nest.assert_same_structure(
    ...     collections.namedtuple("bar", "a b")(1, 2),
    ...     collections.namedtuple("foo", "a b")(2, 3),
    ...     check_types=False)

    >>> tf.nest.assert_same_structure(
    ...     collections.namedtuple("bar", "a b")(1, 2),
    ...     { "a": 1, "b": 2 },
    ...     check_types=False)

    >>> tf.nest.assert_same_structure(
    ...     { "a": 1, "b": 2, "c": 3 },
    ...     { "c": 6, "b": 5, "a": 4 })

    >>> ragged_tensor1 = tf.RaggedTensor.from_row_splits(
    ...       values=[3, 1, 4, 1, 5, 9, 2, 6],
    ...       row_splits=[0, 4, 4, 7, 8, 8])
    >>> ragged_tensor2 = tf.RaggedTensor.from_row_splits(
    ...       values=[3, 1, 4],
    ...       row_splits=[0, 3])
    >>> tf.nest.assert_same_structure(
    ...       ragged_tensor1,
    ...       ragged_tensor2,
    ...       expand_composites=True)

  * These examples will raise exceptions:

    >>> tf.nest.assert_same_structure([0, 1], np.array([0, 1]))
    Traceback (most recent call last):
    ...
    ValueError: The two structures don't have the same nested structure

    >>> tf.nest.assert_same_structure(
    ...       collections.namedtuple('bar', 'a b')(1, 2),
    ...       collections.namedtuple('foo', 'a b')(2, 3))
    Traceback (most recent call last):
    ...
    TypeError: The two structures don't have the same nested structure

  Args:
    nest1: an arbitrarily nested structure.
    nest2: an arbitrarily nested structure.
    check_types: if `True` (default) types of sequences are checked as well,
      including the keys of dictionaries. If set to `False`, for example a
      list and a tuple of objects will look the same if they have the same
      size. Note that namedtuples with identical name and fields are always
      considered to have the same shallow structure. Two types will also be
      considered the same if they are both list subtypes (which allows "list"
      and "_ListWrapper" from trackable dependency tracking to compare
      equal).
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Raises:
    ValueError: If the two structures do not have the same number of elements or
      if the two structures are not nested in the same way.
    TypeError: If the two structures differ in the type of sequence in any of
      their substructures. Only possible if `check_types` is `True`.
  """
  # Convert to bool explicitly as otherwise pybind will not be able to handle
  # type mismatch message correctly. See GitHub issue 42329 for details.
  check_types = bool(check_types)
  expand_composites = bool(expand_composites)
  try:
    _pywrap_utils.AssertSameStructure(nest1, nest2, check_types,
                                      expand_composites)
  except (ValueError, TypeError) as e:
    # Re-raise the same exception type with both structures appended, drawn
    # with every atom replaced by "." so that only the nesting is shown.
    str1 = str(map_structure(lambda _: _DOT, nest1))
    str2 = str(map_structure(lambda _: _DOT, nest2))
    raise type(e)("%s\n"
                  "Entire first structure:\n%s\n"
                  "Entire second structure:\n%s"
                  % (str(e), str1, str2))
534
535
def flatten_dict_items(dictionary):
  """Flattens a dictionary whose keys and values are matching nested structures.

  Each key of `dictionary` may itself be a nested structure. It is paired
  element-wise with the identically shaped value stored under it, and all
  resulting atom-to-atom pairs are collected into one flat dictionary:

  ```python
  example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
  result = {4: "a", 5: "b", 6: "c", 8: "d"}
  flatten_dict_items(example_dictionary) == result
  ```

  Requirements on the input:

  1. Every key must have exactly the same nested structure as its value.
  2. No flattened key may appear more than once across the whole dictionary.

  Args:
    dictionary: the dictionary to flatten.

  Returns:
    A dictionary mapping each flattened key to its flattened value.

  Raises:
    TypeError: If the input is not a dictionary.
    ValueError: If any key and value do not have the same structure layout, or
      if keys are not unique.
  """
  native_flatten = _pywrap_nest.FlattenDictItems
  return native_flatten(dictionary)
567
568
def _packed_nest_with_indices(structure, flat, index, is_seq, sequence_fn=None):
  """Helper function for pack_sequence_as.

  Walks the immediate children of `structure` (via `_yield_value`). Each atom
  child consumes one element of `flat`, and each nested child is packed
  recursively and then rebuilt with `sequence_fn`.

  Args:
    structure: Substructure (list / tuple / dict / ...) to mimic.
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from flat.
    is_seq: Function used to test if a value should be treated as a sequence.
    sequence_fn: Function used to generate a new sequence instance; defaults
      to `_sequence_like`.

  Returns:
    The tuple (new_index, packed), where:
      * new_index - the updated index into `flat` having processed `structure`.
      * packed - a list with one entry per child of `structure`, taken from
                 `flat` starting at `index`, with nested children already
                 rebuilt by `sequence_fn`. Converting this list into the type
                 of `structure` itself is left to the caller.

  Raises:
    IndexError: if `structure` contains more atoms than `flat` has elements
      from `index` onward (raised by indexing into `flat`).
  """
  packed = []
  sequence_fn = sequence_fn or _sequence_like
  for s in _yield_value(structure):
    if is_seq(s):
      new_index, child = _packed_nest_with_indices(s, flat, index, is_seq,
                                                   sequence_fn)
      packed.append(sequence_fn(s, child))
      index = new_index
    else:
      packed.append(flat[index])
      index += 1
  return index, packed
602
603
def _pack_sequence_as(structure, flat_sequence, expand_composites,
                      sequence_fn=None):
  """Implements sequence packing, with the option to alter the structure.

  Args:
    structure: Nested structure to mimic.
    flat_sequence: Flat sequence of values to pack into `structure`.
    expand_composites: If true, composite tensors in `structure` are treated
      as sequences of their components (via `is_sequence_or_composite`).
    sequence_fn: Optional function `(instance, args) -> new_instance` used to
      rebuild each sequence; defaults to `_sequence_like`.

  Returns:
    `flat_sequence` packed into the nested layout of `structure`. If
    `structure` is an atom, the single element of `flat_sequence`.

  Raises:
    TypeError: If `flat_sequence` is not a sequence.
    ValueError: If `structure` is an atom and `flat_sequence` does not have
      exactly one element, or if the number of atoms in `structure` differs
      from the number of elements in `flat_sequence`.
  """
  is_seq = is_sequence_or_composite if expand_composites else is_sequence
  sequence_fn = sequence_fn or _sequence_like
  def truncate(value, length):
    value_str = str(value)
    return value_str[:length] + (value_str[length:] and "...")

  if not is_seq(flat_sequence):
    raise TypeError(
        "Attempted to pack value:\n  {}\ninto a sequence, but found "
        "incompatible type `{}` instead."
        .format(truncate(flat_sequence, 100), type(flat_sequence)))

  if not is_seq(structure):
    if len(flat_sequence) != 1:
      raise ValueError(
          "The target structure is of type `{}`\n  {}\nHowever the input "
          "structure is a sequence ({}) of length {}.\n  {}\nnest cannot "
          "guarantee that it is safe to map one to the other.".format(
              type(structure), truncate(structure, 100), type(flat_sequence),
              len(flat_sequence), truncate(flat_sequence, 100)))
    return flat_sequence[0]

  # Stays None if packing itself raised before producing a result.
  packed = None
  try:
    final_index, packed = _packed_nest_with_indices(structure, flat_sequence,
                                                    0, is_seq, sequence_fn)
    if final_index < len(flat_sequence):
      raise IndexError
  except IndexError:
    # Count atoms with the same composite-expansion rule that packing used;
    # otherwise the size check (and the reported counts) would be wrong
    # whenever `expand_composites` is true.
    flat_structure = flatten(structure, expand_composites=expand_composites)
    if len(flat_structure) != len(flat_sequence):
      raise ValueError(
          "Could not pack sequence. Structure had %d elements, but "
          "flat_sequence had %d elements.  Structure: %s, flat_sequence: %s." %
          (len(flat_structure), len(flat_sequence), structure, flat_sequence))
    if packed is None:
      # The sizes agree, so this IndexError was not a size mismatch. Nothing
      # was packed, so propagate it instead of failing later on an unbound
      # result.
      raise
  return sequence_fn(structure, packed)
642
643
@tf_export("nest.pack_sequence_as")
def pack_sequence_as(structure, flat_sequence, expand_composites=False):
  """Returns a given flattened sequence packed into a given structure.

  If `structure` is a scalar, `flat_sequence` must be a single-element list;
  in this case the return value is `flat_sequence[0]`.

  If `structure` is or contains a dict instance, the keys will be sorted to
  pack the flat sequence in deterministic order. This is true also for
  `OrderedDict` instances: their sequence order is ignored, the sorting order of
  keys is used instead. The same convention is followed in `flatten`.
  This correctly repacks dicts and `OrderedDict`s after they have been
  flattened, and also allows flattening an `OrderedDict` and then repacking it
  back using a corresponding plain dict, or vice-versa.
  Dictionaries with non-sortable keys cannot be flattened or packed.

  Examples:

  1. Python dict:

    >>> structure = { "key3": "", "key1": "", "key2": "" }
    >>> flat_sequence = ["value1", "value2", "value3"]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    {'key3': 'value3', 'key1': 'value1', 'key2': 'value2'}

  2. For a nested python tuple:

    >>> structure = (('a','b'), ('c','d','e'), 'f')
    >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)

  3. For a nested dictionary of dictionaries:

    >>> structure = { "key3": {"c": ('alpha', 'beta'), "a": ('gamma')},
    ...               "key1": {"e": "val1", "d": "val2"} }
    >>> flat_sequence = ['val2', 'val1', 3.0, 1.0, 2.0]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}}

  4. Numpy array (considered a scalar):

    >>> structure = ['a']
    >>> flat_sequence = [np.array([[1, 2], [3, 4]])]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    [array([[1, 2],
           [3, 4]])]

  5. tf.Tensor (considered a scalar):

    >>> structure = ['a']
    >>> flat_sequence = [tf.constant([[1., 2., 3.], [4., 5., 6.]])]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence)
    [<tf.Tensor: shape=(2, 3), dtype=float32,
     numpy= array([[1., 2., 3.], [4., 5., 6.]], dtype=float32)>]

  6. `tf.RaggedTensor`: This is a composite tensor whose representation consists
  of a flattened list of 'values' and a list of 'row_splits' which indicate how
  to chop up the flattened list into different rows. For more details on
  `tf.RaggedTensor`, please visit
  https://www.tensorflow.org/api_docs/python/tf/RaggedTensor.

  With `expand_composites=False`, we treat RaggedTensor as a scalar.

    >>> structure = { "foo": tf.ragged.constant([[1, 2], [3]]),
    ...               "bar": tf.constant([[5]]) }
    >>> flat_sequence = [ "one", "two" ]
    >>> tf.nest.pack_sequence_as(structure, flat_sequence,
    ... expand_composites=False)
    {'foo': 'two', 'bar': 'one'}

  With `expand_composites=True`, we expect that the flattened input contains
  the tensors making up the ragged tensor i.e. the values and row_splits
  tensors.

    >>> structure = { "foo": tf.ragged.constant([[1., 2.], [3.]]),
    ...               "bar": tf.constant([[5.]]) }
    >>> tensors = tf.nest.flatten(structure, expand_composites=True)
    >>> print(tensors)
    [<tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]],
     dtype=float32)>,
     <tf.Tensor: shape=(3,), dtype=float32, numpy=array([1., 2., 3.],
     dtype=float32)>,
     <tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 2, 3])>]
    >>> verified_tensors = [tf.debugging.check_numerics(t, 'invalid tensor: ')
    ...                     if t.dtype==tf.float32 else t
    ...                     for t in tensors]
    >>> tf.nest.pack_sequence_as(structure, verified_tensors,
    ...                          expand_composites=True)
    {'foo': <tf.RaggedTensor [[1.0, 2.0], [3.0]]>,
     'bar': <tf.Tensor: shape=(1, 1), dtype=float32, numpy=array([[5.]],
     dtype=float32)>}

  Args:
    structure: Nested structure, whose structure is given by nested lists,
      tuples, and dicts. Note: numpy arrays and strings are considered
      scalars.
    flat_sequence: flat sequence to pack.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    packed: `flat_sequence` converted to have the same recursive structure as
      `structure`.

  Raises:
    ValueError: If `flat_sequence` and `structure` have different
      element counts.
    TypeError: If `flat_sequence` is not a sequence, or if `structure` is or
      contains a dict with non-sortable keys.
  """
  return _pack_sequence_as(structure, flat_sequence, expand_composites)
756
757
758@tf_export("nest.map_structure")
def map_structure(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Calls `func(x[0], x[1], ...)` once per leaf position, where `x[i]` is the
  leaf of `structure[i]` at that position. All structures in `structure` must
  have the same arity, and the returned value has the same layout as
  `structure[0]`.

  Examples:

  * A single Python dict:

  >>> a = {"hello": 24, "world": 76}
  >>> tf.nest.map_structure(lambda p: p * 2, a)
  {'hello': 48, 'world': 152}

  * Multiple Python dictionaries:

  >>> d1 = {"hello": 24, "world": 76}
  >>> d2 = {"hello": 36, "world": 14}
  >>> tf.nest.map_structure(lambda p1, p2: p1 + p2, d1, d2)
  {'hello': 60, 'world': 90}

  * A single Python list:

  >>> a = [24, 76, "ab"]
  >>> tf.nest.map_structure(lambda p: p * 2, a)
  [48, 152, 'abab']

  * Scalars:

  >>> tf.nest.map_structure(lambda x, y: x + y, 3, 4)
  7

  * Empty structures:

  >>> tf.nest.map_structure(lambda x: x + 1, ())
  ()

  * Check the types of iterables:

  >>> s1 = (((1, 2), 3), 4, (5, 6))
  >>> s1_list = [[[1, 2], 3], 4, [5, 6]]
  >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list)
  Traceback (most recent call last):
  ...
  TypeError: The two structures don't have the same nested structure

  * Type check is set to False:

  >>> s1 = (((1, 2), 3), 4, (5, 6))
  >>> s1_list = [[[1, 2], 3], 4, [5, 6]]
  >>> tf.nest.map_structure(lambda x, y: None, s1, s1_list, check_types=False)
  (((None, None), None), None, (None, None))

  Args:
    func: A callable that accepts as many positional arguments as there are
      structures.
    *structure: One or more nested structures: scalars, or tuples, dicts,
      lists, etc. whose elements are scalars and/or further nested
      structures. Note: numpy arrays are treated as scalars (leaves).
    **kwargs: Valid keyword args are:

      * `check_types`: If set to `True` (default) the types of
        iterables within the structures have to be same (e.g.
        `map_structure(func, [1], (1,))` raises a `TypeError`
        exception). To allow this set this argument to `False`.
        Note that namedtuples with identical name and fields are always
        considered to have the same shallow structure.
      * `expand_composites`: If set to `True`, then composite tensors such
        as `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into
        their component tensors.  If `False` (the default), then composite
        tensors are not expanded.

  Returns:
    A new structure laid out like `structure[0]`, whose leaves are
    `func(x[0], x[1], ...)` with `x[i]` the leaf at the same location in
    `structure[i]`. If the structures use different sequence types and
    `check_types` is `False`, the sequence types of the first structure are
    used.

  Raises:
    TypeError: If `func` is not callable, or if `check_types` is `True` and
      the structures use different sequence types (see the example above).
    ValueError: If no structure is provided, or if keyword arguments other
      than `check_types` and `expand_composites` are passed.
    Other mismatches between the structures are detected by
    `assert_same_structure` and propagate from it.
  """
  # Validate the call itself before touching any of the structures.
  if not callable(func):
    raise TypeError("func must be callable, got: %s" % func)
  if not structure:
    raise ValueError("Must provide at least one structure")

  check_types = kwargs.pop("check_types", True)
  expand_composites = kwargs.pop("expand_composites", False)
  if kwargs:
    unexpected_names = "`, `".join(kwargs)
    raise ValueError(
        "Only valid keyword arguments are `check_types` and "
        "`expand_composites`, not: `%s`" % (unexpected_names))

  # Every structure must match the first one; the first one also provides
  # the layout of the result.
  template = structure[0]
  for candidate in structure[1:]:
    assert_same_structure(template, candidate, check_types=check_types,
                          expand_composites=expand_composites)

  flat_inputs = [flatten(s, expand_composites) for s in structure]
  mapped_leaves = [func(*leaves) for leaves in zip(*flat_inputs)]
  return pack_sequence_as(
      template, mapped_leaves, expand_composites=expand_composites)
869
870
def map_structure_with_paths(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in
  `structure[i]` and `path` is the common path to x[i] in the structures,
  rendered as a string by joining the `str()` of each path element with "/".
  All structures in `structure` must have the same arity, and the return
  value will contain the results with the same structure layout. Special
  kwarg `check_types` determines whether the types of iterables within the
  structure must be the same -- see **kwargs definition below.

  Args:
    func: A callable with the signature func(path, *values, **kwargs) that is
      evaluated on the leaves of the structure.
    *structure: A variable number of compatible structures to process.
    **kwargs: Optional kwargs to be passed through to func. Special kwarg
      `check_types` is not passed to func, but instead determines whether the
      types of iterables within the structures have to be same (e.g.,
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). By
      default, the types must match. To allow iteration over structures of
      different types (but common arity), set this kwarg to `False`. Special
      kwarg `expand_composites` is likewise consumed rather than passed to
      func.

  Returns:
    A structure of the same form as the input structures whose leaves are the
    result of evaluating func on corresponding leaves of the input structures.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    TypeError: If `check_types` is not `False` and the two structures differ in
      the type of sequence in any of their substructures.
    ValueError: If no structures are provided.
  """
  if not callable(func):
    raise TypeError("func must be callable, got: %s" % func)
  # Check explicitly: indexing `structure[0]` on an empty tuple would raise an
  # IndexError instead of the documented ValueError.
  if not structure:
    raise ValueError("Must provide at least one structure")

  def wrapper_func(tuple_path, *inputs, **kwargs):
    # Convert the tuple path into the "/"-joined string form `func` expects.
    string_path = "/".join(str(s) for s in tuple_path)
    return func(string_path, *inputs, **kwargs)

  return map_structure_with_tuple_paths_up_to(structure[0],
                                              wrapper_func,
                                              *structure,
                                              **kwargs)
911
912
def map_structure_with_tuple_paths(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(tuple_path, x[0], x[1], ..., **kwargs)` where `x[i]` is an entry
  in `structure[i]` and `tuple_path` is a tuple of indices and/or dictionary
  keys (as returned by `nest.yield_flat_paths`), which uniquely specifies the
  common path to x[i] in the structures. All structures in `structure` must have
  the same arity, and the return value will contain the results in the same
  structure. Special kwarg `check_types` determines whether the types of
  iterables within the structure must be the same -- see **kwargs definition
  below.

  Args:
    func: A callable with the signature `func(tuple_path, *values, **kwargs)`
      that is evaluated on the leaves of the structure.
    *structure: A variable number of compatible structures to process.
    **kwargs: Optional kwargs to be passed through to func. Special kwarg
      `check_types` is not passed to func, but instead determines whether the
      types of iterables within the structures have to be same (e.g.
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
      this set this argument to `False`. Special kwarg `expand_composites` is
      likewise consumed rather than passed to func.

  Returns:
    A structure of the same form as the input structures whose leaves are the
    result of evaluating func on corresponding leaves of the input structures.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    TypeError: If `check_types` is not `False` and the two structures differ in
      the type of sequence in any of their substructures.
    ValueError: If no structures are provided.
  """
  if not callable(func):
    raise TypeError("func must be callable, got: %s" % func)
  # Check explicitly: indexing `structure[0]` on an empty tuple would raise an
  # IndexError instead of the documented ValueError.
  if not structure:
    raise ValueError("Must provide at least one structure")
  # The first structure doubles as the shallow tree, so the mapping descends
  # all the way to its leaves.
  return map_structure_with_tuple_paths_up_to(structure[0],
                                              func,
                                              *structure,
                                              **kwargs)
950
951
952def _yield_flat_up_to(shallow_tree, input_tree, is_seq, path=()):
953  """Yields (path, value) pairs of input_tree flattened up to shallow_tree.
954
955  Args:
956    shallow_tree: Nested structure. Traverse no further than its leaf nodes.
957    input_tree: Nested structure. Return the paths and values from this tree.
958      Must have the same upper structure as shallow_tree.
959    is_seq: Function used to test if a value should be treated as a sequence.
960    path: Tuple. Optional argument, only used when recursing. The path from the
961      root of the original shallow_tree, down to the root of the shallow_tree
962      arg of this recursive call.
963
964  Yields:
965    Pairs of (path, value), where path the tuple path of a leaf node in
966    shallow_tree, and value is the value of the corresponding node in
967    input_tree.
968  """
969  if not is_seq(shallow_tree):
970    yield (path, input_tree)
971  else:
972    input_tree = dict(_yield_sorted_items(input_tree))
973    for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree):
974      subpath = path + (shallow_key,)
975      input_subtree = input_tree[shallow_key]
976      for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree,
977                                                     input_subtree, is_seq,
978                                                     path=subpath):
979        yield (leaf_path, leaf_value)
980
981
def assert_shallow_structure(shallow_tree,
                             input_tree,
                             check_types=True,
                             expand_composites=False):
  """Asserts that `shallow_tree` is a shallow structure of `input_tree`.

  That is, this function tests if the `input_tree` structure can be created from
  the `shallow_tree` structure by replacing its leaf nodes with deeper
  tree structures.

  Examples:

  The following code will raise a `ValueError` (key "b" of `shallow_tree` is
  missing from `input_tree`):
  ```python
    shallow_tree = {"a": "A", "b": "B"}
    input_tree = {"a": 1, "c": 2}
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  The following code will raise a `ValueError` (the two lists have different
  lengths):
  ```python
    shallow_tree = ["a", "b"]
    input_tree = ["c", ["d", "e"], "f"]
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: if `True` (default) the sequence types of `shallow_tree` and
      `input_tree` have to be the same. Note that even with check_types==True,
      this function will consider two different namedtuple classes with the same
      name and _fields attribute to be the same class.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.
  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`. Only raised if `check_types` is `True`.
    TypeError: If one side is a composite tensor and the other is neither a
      composite tensor nor a TypeSpec, or if `shallow_tree` is a TypeSpec and
      `input_tree` is not.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
    ValueError: If `shallow_tree` is a mapping with keys absent from
      `input_tree`.
    ValueError: If the TypeSpecs of two composite tensors are incompatible.
  """
  is_seq = is_sequence_or_composite if expand_composites else is_sequence
  if is_seq(shallow_tree):
    if not is_seq(input_tree):
      raise TypeError(
          "If shallow structure is a sequence, input must also be a sequence. "
          "Input has type: %s." % type(input_tree))

    # Compare against the wrapped object's type for wrapt proxies.
    if isinstance(shallow_tree, _wrapt.ObjectProxy):
      shallow_type = type(shallow_tree.__wrapped__)
    else:
      shallow_type = type(shallow_tree)

    if check_types and not isinstance(input_tree, shallow_type):
      # Duck-typing means that nest should be fine with two different
      # namedtuples with identical name and fields.
      shallow_is_namedtuple = _is_namedtuple(shallow_tree, False)
      input_is_namedtuple = _is_namedtuple(input_tree, False)
      if shallow_is_namedtuple and input_is_namedtuple:
        if not _same_namedtuples(shallow_tree, input_tree):
          raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
              input_type=type(input_tree),
              shallow_type=type(shallow_tree)))

      elif isinstance(shallow_tree, list) and isinstance(input_tree, list):
        # List subclasses are considered the same,
        # e.g. python list vs. _ListWrapper.
        pass

      elif ((_is_composite_tensor(shallow_tree) or
             _is_composite_tensor(input_tree)) and
            (_is_type_spec(shallow_tree) or _is_type_spec(input_tree))):
        pass  # Compatibility will be checked below.

      elif not (isinstance(shallow_tree, _collections_abc.Mapping) and
                isinstance(input_tree, _collections_abc.Mapping)):
        raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
            input_type=type(input_tree),
            shallow_type=type(shallow_tree)))

    if _is_composite_tensor(shallow_tree) or _is_composite_tensor(input_tree):
      if not (
          (_is_composite_tensor(input_tree) or _is_type_spec(input_tree)) and
          (_is_composite_tensor(shallow_tree) or _is_type_spec(shallow_tree))):
        raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
            input_type=type(input_tree),
            shallow_type=type(shallow_tree)))
      type_spec_1 = (shallow_tree if _is_type_spec(shallow_tree) else
                     shallow_tree._type_spec)  # pylint: disable=protected-access
      type_spec_2 = (input_tree if _is_type_spec(input_tree) else
                     input_tree._type_spec)  # pylint: disable=protected-access
      try:
        _ = type_spec_1.most_specific_compatible_type(type_spec_2)
      except (TypeError, ValueError) as e:
        raise ValueError(
            "Incompatible CompositeTensor TypeSpecs: %s vs. %s -- %s" %
            (type_spec_1, type_spec_2, e))

    elif _is_type_spec(shallow_tree):
      if not _is_type_spec(input_tree):
        raise TypeError("If shallow structure is a TypeSpec, input must also "
                        "be a TypeSpec.  Input has type: %s."
                        % type(input_tree))

    elif len(input_tree) != len(shallow_tree):
      # Plain sequences and mappings must have exactly the same length. (A
      # separate "input smaller than shallow" check after this one could never
      # fire, since any length difference is already caught here.)
      raise ValueError(
          _STRUCTURES_HAVE_MISMATCHING_LENGTHS.format(
              input_length=len(input_tree), shallow_length=len(shallow_tree)))

    if isinstance(shallow_tree, _collections_abc.Mapping):
      absent_keys = set(shallow_tree) - set(input_tree)
      if absent_keys:
        raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS
                         .format(sorted(absent_keys)))

    # Recurse into corresponding children of the two trees.
    for shallow_branch, input_branch in zip(_yield_value(shallow_tree),
                                            _yield_value(input_tree)):
      assert_shallow_structure(shallow_branch, input_branch,
                               check_types=check_types,
                               expand_composites=expand_composites)
1108
1109
@tf_export("__internal__.nest.flatten_up_to", v1=[])
def flatten_up_to(shallow_tree, input_tree, check_types=True,
                  expand_composites=False):
  """Flattens `input_tree` up to `shallow_tree`.

  Any structure in `input_tree` that lies deeper than the leaves of
  `shallow_tree` is kept intact as elements of the partially flattened output.

  If `shallow_tree` and `input_tree` are not sequences, this returns a
  single-element list: `[input_tree]`.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure layout
  as `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples:

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

  # Output is:
  # [[2, 2], [3, 3], [4, 9], [5, 5]]
  # [True, True, False, True]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_up_to(0, 0)  # Output: [0]
  flatten_up_to(0, [0, 1, 2])  # Output: [[0, 1, 2]]
  flatten_up_to([0, 1, 2], 0)  # Output: TypeError
  flatten_up_to([0, 1, 2], [0, 1, 2])  # Output: [0, 1, 2]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays are considered scalars.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A Python list, the partially flattened version of `input_tree` according to
    the structure of `shallow_tree`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  assert_shallow_structure(shallow_tree,
                           input_tree,
                           check_types=check_types,
                           expand_composites=expand_composites)
  if expand_composites:
    is_seq = is_sequence_or_composite
  else:
    is_seq = is_sequence
  path_value_pairs = _yield_flat_up_to(shallow_tree, input_tree, is_seq)
  # Only the values are returned; the paths are dropped.
  return [value for _unused_path, value in path_value_pairs]
1193
1194
def flatten_with_tuple_paths_up_to(shallow_tree,
                                   input_tree,
                                   check_types=True,
                                   expand_composites=False):
  """Flattens `input_tree` up to `shallow_tree`, keeping the leaf paths.

  Any structure in `input_tree` that lies deeper than the leaves of
  `shallow_tree` is kept intact as elements of the partially flattened output.

  Returns a list of (path, value) pairs, where `path` is the tuple path of a
  leaf of `shallow_tree` and `value` is the node of `input_tree` at that same
  path.

  If `shallow_tree` and `input_tree` are not sequences, this returns a
  single-element list: `[((), input_tree)]`.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure layout
  as `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples:

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_with_tuple_paths_up_to(shallow_tree,
                                                        input_tree)
  flattened_shallow_tree = flatten_with_tuple_paths_up_to(shallow_tree,
                                                          shallow_tree)

  # Output is:
  # [((0, 0), [2, 2]),
  #  ((0, 1), [3, 3]),
  #  ((1, 0), [4, 9]),
  #  ((1, 1), [5, 5])]
  #
  # [((0, 0), True),
  #  ((0, 1), True),
  #  ((1, 0), False),
  #  ((1, 1), True)]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_with_tuple_paths_up_to(
      shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [((0, 0), ('a', 1)),
  #  ((0, 1, 0), ('b', 2)),
  #  ((0, 1, 1, 0), ('c', 3)),
  #  ((0, 1, 1, 1, 0), ('d', 4))]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_with_tuple_paths_up_to(0, 0)  # Output: [((), 0)]

  flatten_with_tuple_paths_up_to(0, [0, 1, 2])  # Output: [((), [0, 1, 2])]

  flatten_with_tuple_paths_up_to([0, 1, 2], 0)  # Output: TypeError

  flatten_with_tuple_paths_up_to([0, 1, 2], [0, 1, 2])
  # Output: [((0,), 0), ((1,), 1), ((2,), 2)]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays are considered scalars.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A Python list of `(path, value)` pairs: the partially flattened version of
    `input_tree` according to the structure of `shallow_tree`, with each value
    paired with its tuple path.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  if expand_composites:
    is_seq = is_sequence_or_composite
  else:
    is_seq = is_sequence
  assert_shallow_structure(shallow_tree, input_tree,
                           check_types=check_types,
                           expand_composites=expand_composites)
  path_value_pairs = _yield_flat_up_to(shallow_tree, input_tree, is_seq)
  return list(path_value_pairs)
1297
1298
@tf_export("__internal__.nest.map_structure_up_to", v1=[])
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
  """Applies a function or op to a number of partially flattened inputs.

  The `inputs` are flattened up to `shallow_tree` before being mapped.

  Use Case:

  Sometimes we wish to apply a function to a partially flattened
  sequence (for example when the function itself takes sequence inputs). We
  achieve this by specifying a shallow structure, `shallow_tree` we wish to
  flatten up to.

  The `inputs`, can be thought of as having the same structure layout as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  This function therefore will return something with the same base structure as
  `shallow_tree`.

  Examples:

  ```python
  shallow_tree = [None, None]
  inp_val = [[1, 2], 3]
  out = map_structure_up_to(shallow_tree, lambda x: 2 * x, inp_val)

  # The nested list [1, 2] is passed to the function as a whole.
  # Output is: [[1, 2, 1, 2], 6]
  ```

  ```python
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
                            inp_val, inp_ops)

  # Output is: ab_tuple(a=6, b=15)
  ```

  ```python
  data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  name_list = ['evens', ['odds', 'primes']]
  out = map_structure_up_to(
      name_list,
      lambda name, sec: "first_{}_{}".format(len(sec), name),
      name_list, data_list)

  # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
  ```

  Args:
    shallow_tree: a shallow tree, common to all the inputs.
    func: callable which will be applied to each input individually.
    *inputs: arbitrarily nested combination of objects that are compatible with
      shallow_tree. The function `func` is applied to corresponding
      partially flattened elements of each input, so the function must support
      arity of `len(inputs)`.
    **kwargs: kwargs to feed to func(). Special kwargs `check_types` and
      `expand_composites` are not passed to func. `check_types` instead
      determines whether the types of iterables within the structures have to
      be same (e.g. `map_structure(func, [1], (1,))` raises a `TypeError`
      exception). To allow this set this argument to `False`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but one of `inputs` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      those of one of `inputs`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      those of one of `inputs`.

  Returns:
    result of repeatedly applying `func`, with the same structure layout as
    `shallow_tree`.
  """

  def _call_without_path(unused_path, *values, **func_kwargs):
    # The path-aware variant passes the leaf path first, followed by the
    # remaining kwargs; drop the path but forward the kwargs to `func` as
    # documented (a bare `lambda _, *values` would reject them).
    return func(*values, **func_kwargs)

  return map_structure_with_tuple_paths_up_to(
      shallow_tree,
      _call_without_path,
      *inputs,
      **kwargs)
1379
1380
def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
  """Applies a function or op to a number of partially flattened inputs.

  Like map_structure_up_to(), except that the 'func' argument takes a path
  tuple as its first argument, followed by the corresponding values from
  *inputs.

  Example:

  ```python
  lowercase = {'a': 'a', 'b': ('b0', 'b1')}
  uppercase = {'a': 'A', 'b': ('B0', 'B1')}

  def print_path_and_values(path, *values):
    print("path: {}, values: {}".format(path, values))

  shallow_tree = {'a': None, 'b': None}
  map_structure_with_tuple_paths_up_to(shallow_tree,
                                       print_path_and_values,
                                       lowercase,
                                       uppercase)
  path: ('a',), values: ('a', 'A')
  path: ('b',), values: (('b0', 'b1'), ('B0', 'B1'))

  shallow_tree = {'a': None, 'b': (None, None)}
  map_structure_with_tuple_paths_up_to(shallow_tree,
                                       print_path_and_values,
                                       lowercase,
                                       uppercase)
  path: ('a',), values: ('a', 'A')
  path: ('b', 0), values: ('b0', 'B0')
  path: ('b', 1), values: ('b1', 'B1')

  # A list in `shallow_tree` only matches the tuples in the inputs when type
  # checking is disabled.
  shallow_tree = {'a': None, 'b': [None, None]}
  map_structure_with_tuple_paths_up_to(shallow_tree,
                                       print_path_and_values,
                                       lowercase,
                                       uppercase,
                                       check_types=False)
  path: ('a',), values: ('a', 'A')
  path: ('b', 0), values: ('b0', 'B0')
  path: ('b', 1), values: ('b1', 'B1')
  ```

  Args:
    shallow_tree: a shallow tree, common to all the inputs.
    func: callable that takes args (path, inputs_0_value, ... , inputs_N_value),
      where path is a tuple path to a leaf node in shallow_tree, and
      inputs_i_value is the corresponding value from inputs[i].
    *inputs: nested structures that are all structurally compatible with
      shallow_tree.
    **kwargs: kwargs to feed to func(). Special kwarg
      `check_types` is not passed to func, but instead determines whether the
      types of iterables within the structures have to be same (e.g.
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
      this set this argument to `False`. Special kwarg `expand_composites`
      (default `False`) is not passed to func either; if true, composite
      tensors are expanded into their component tensors.

  Raises:
    TypeError: If `shallow_tree` is a sequence but one of `*inputs` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      those of one of `*inputs`.
    ValueError: If no inputs are provided.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      those of one of `*inputs`.

  Returns:
    Result of repeatedly applying `func`. Has the same structure layout as
    `shallow_tree`.
  """
  if not inputs:
    raise ValueError("Cannot map over no sequences")

  # Consume the special kwargs; everything left over is forwarded to `func`.
  check_types = kwargs.pop("check_types", True)
  expand_composites = kwargs.pop("expand_composites", False)
  is_seq = is_sequence_or_composite if expand_composites else is_sequence

  for input_tree in inputs:
    assert_shallow_structure(
        shallow_tree,
        input_tree,
        check_types=check_types,
        expand_composites=expand_composites)

  def _flat_values(input_tree):
    # Values of `input_tree` at the leaves of `shallow_tree`. Calls
    # `_yield_flat_up_to` directly because every input was already validated
    # above (`flatten_up_to` would re-run the same validation).
    return [value for _, value in
            _yield_flat_up_to(shallow_tree, input_tree, is_seq)]

  # Flatten each input separately, apply the function to corresponding
  # elements, then repack based on the structure of `shallow_tree`.
  flat_values = [_flat_values(input_tree) for input_tree in inputs]
  flat_paths = [
      path for path, _ in _yield_flat_up_to(shallow_tree, inputs[0], is_seq)]
  results = [
      func(*args, **kwargs) for args in zip(flat_paths, *flat_values)
  ]
  return pack_sequence_as(structure=shallow_tree, flat_sequence=results,
                          expand_composites=expand_composites)
1477
1478
@tf_export("__internal__.nest.get_traverse_shallow_structure", v1=[])
def get_traverse_shallow_structure(traverse_fn, structure,
                                   expand_composites=False):
  """Generates a shallow structure from a `traverse_fn` and `structure`.

  `traverse_fn` must accept any possible subtree of `structure` and return
  a depth=1 structure containing `True` or `False` values, describing which
  of the top-level subtrees may be traversed.  It may also
  return scalar `True` or `False` "traversal is OK / not OK for all subtrees."

  Examples are available in the unit tests (nest_test.py).

  Args:
    traverse_fn: Function taking a substructure and returning either a scalar
      `bool` (whether to traverse that substructure or not) or a depth=1
      shallow structure of the same type, describing which parts of the
      substructure to traverse.
    structure: The structure to traverse.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors. The setting applies at every level of the recursion.

  Returns:
    A shallow structure containing python bools, which can be passed to
    `map_structure_up_to` and `flatten_up_to`.

  Raises:
    TypeError: if `traverse_fn` returns a sequence for a non-sequence input,
      or a structure with depth higher than 1 for a sequence input,
      or if any leaf values in the returned structure or scalar are not type
      `bool`.
  """
  is_seq = is_sequence_or_composite if expand_composites else is_sequence
  to_traverse = traverse_fn(structure)
  if not is_seq(structure):
    if not isinstance(to_traverse, bool):
      raise TypeError("traverse_fn returned structure: %s for non-structure: %s"
                      % (to_traverse, structure))
    return to_traverse
  level_traverse = []
  if isinstance(to_traverse, bool):
    if not to_traverse:
      # Do not traverse this substructure at all.  Exit early.
      return False
    else:
      # Traverse the entire substructure.
      for branch in _yield_value(structure):
        level_traverse.append(
            get_traverse_shallow_structure(traverse_fn, branch,
                                           expand_composites=expand_composites))
  elif not is_seq(to_traverse):
    raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s"
                    % (to_traverse, structure))
  else:
    # Traverse some subset of this substructure.
    assert_shallow_structure(to_traverse, structure,
                             expand_composites=expand_composites)
    for t, branch in zip(_yield_value(to_traverse),
                         _yield_value(structure)):
      if not isinstance(t, bool):
        raise TypeError(
            "traverse_fn didn't return a depth=1 structure of bools.  saw: %s "
            " for structure: %s" % (to_traverse, structure))
      if t:
        # Propagate `expand_composites` so that nested levels use the same
        # notion of "sequence" as this one (and as the full-traversal branch
        # above).
        level_traverse.append(
            get_traverse_shallow_structure(traverse_fn, branch,
                                           expand_composites=expand_composites))
      else:
        level_traverse.append(False)
  return _sequence_like(structure, level_traverse)
1548
1549
@tf_export("__internal__.nest.yield_flat_paths", v1=[])
def yield_flat_paths(nest, expand_composites=False):
  """Yields the path to every leaf of a nested structure.

  A path is a tuple of objects that can be converted with `str`, such as
  integer indices into sequences or the keys used to index a dict. It leads
  from the top of `nest` down to one leaf.

  Paths are produced in the same order as the leaves returned by
  `nest.flatten`. This is handy for naming Tensors so that the TF scope
  structure matches the tuple structure.

  For example, given the namedtuple `value = Foo(a=3, b=Bar(c=23, d=42))`:

  ```shell
  nest.flatten(value)
  [3, 23, 42]
  list(nest.yield_flat_paths(value))
  [('a',), ('b', 'c'), ('b', 'd')]
  ```

  ```shell
  list(nest.yield_flat_paths({'a': [3]}))
  [('a', 0)]
  list(nest.yield_flat_paths({'a': 3}))
  [('a',)]
  ```

  Args:
    nest: the nested structure whose leaf paths should be produced.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Yields:
    For each leaf, a tuple of the index or key values that form the path to
    that leaf in the nested structure.
  """
  # Choose the predicate that decides what counts as a (non-leaf) structure.
  if expand_composites:
    is_seq = is_sequence_or_composite
  else:
    is_seq = is_sequence
  # `nest` serves as its own shallow structure, so the walk reaches every
  # leaf. Only the path half of each (path, value) pair is kept.
  for path, _ in _yield_flat_up_to(nest, nest, is_seq):
    yield path
1590
1591
def flatten_with_joined_string_paths(structure, separator="/",
                                     expand_composites=False):
  """Flattens `structure` into `(string_path, leaf)` pairs.

  The pairs appear in the same order as the leaves from `nest.flatten`. Each
  string path is built from the tuple path produced by
  `nest.yield_flat_paths`. Every path element is converted with `str`, and
  the pieces are joined with `separator`. For example, `{'a': [3]}` with the
  default separator gives `[('a/0', 3)]`. This keeps track of where each
  data element sat inside the nested structure.

  Args:
    structure: the nested structure to flatten.
    separator: string placed between levels of hierarchy in the results;
      defaults to '/'.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A list of (string, data element) tuples.
  """
  joined_paths = (
      separator.join(str(part) for part in tuple_path)
      for tuple_path in yield_flat_paths(
          structure, expand_composites=expand_composites))
  leaves = flatten(structure, expand_composites=expand_composites)
  return list(zip(joined_paths, leaves))
1619
1620
def flatten_with_tuple_paths(structure, expand_composites=False):
  """Flattens `structure` into `(tuple_path, leaf_element)` pairs.

  The pairs appear in the same order as the leaves from `nest.flatten`, so
  the location of each data element inside the nested structure is kept.
  See `nest.yield_flat_paths` for details on how tuple paths are formed.

  Args:
    structure: the nested structure to flatten.
    expand_composites: If true, then composite tensors such as
      `tf.sparse.SparseTensor` and `tf.RaggedTensor` are expanded into their
      component tensors.

  Returns:
    A list of `(tuple_path, leaf_element)` tuples. Each `tuple_path` is a tuple
    of indices and/or dictionary keys that uniquely specify the path to
    `leaf_element` within `structure`.
  """
  tuple_paths = yield_flat_paths(structure,
                                 expand_composites=expand_composites)
  leaves = flatten(structure, expand_composites=expand_composites)
  return list(zip(tuple_paths, leaves))
1643
1644
@tf_export("__internal__.nest.list_to_tuple", v1=[])
def list_to_tuple(structure):
  """Deeply converts every list in `structure` into a tuple.

  tf.nest recurses into lists as structures. The fork of nest used by
  tf.data instead treats a list as a single element. Keras follows the
  tf.nest convention, so it must replace all lists with tuples before
  passing a structure to `Dataset.from_generator`.

  Args:
    structure: A nested structure to be remapped.

  Returns:
    `structure` rebuilt with every list replaced by a tuple holding the same
    elements.
  """

  def _lists_become_tuples(instance, args):
    # Lists are rebuilt as tuples; any other collection keeps its own type.
    if not isinstance(instance, list):
      return _sequence_like(instance, args)
    return tuple(args)

  leaves = flatten(structure)
  return _pack_sequence_as(structure, leaves, False,
                           sequence_fn=_lists_become_tuples)
1667
1668
# Import-time registration: each abstract collection type from
# `_collections_abc`, plus `wrapt.ObjectProxy`, is handed to
# `_pywrap_utils.RegisterType` under a string name.
# NOTE(review): presumably this lets the native helpers behind `_pywrap_utils`
# recognize these Python types by name; confirm against the extension's
# implementation before relying on it.
_pywrap_utils.RegisterType("Mapping", _collections_abc.Mapping)
_pywrap_utils.RegisterType("MutableMapping", _collections_abc.MutableMapping)
_pywrap_utils.RegisterType("Sequence", _collections_abc.Sequence)
_pywrap_utils.RegisterType("MappingView", _collections_abc.MappingView)
_pywrap_utils.RegisterType("ObjectProxy", _wrapt.ObjectProxy)
1674