1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""## Functions for working with arbitrarily nested sequences of elements.
17
18This module can perform operations on nested structures. A nested structure is a
19Python sequence, tuple (including `namedtuple`), or dict that can contain
20further sequences, tuples, and dicts.
21
22attr.s decorated classes (http://www.attrs.org) are also supported, in the
23same way as `namedtuple`.
24
25The utilities here assume (and do not check) that the nested structures form a
26'tree', i.e., no references in the structure of the input of these functions
27should be recursive.
28
29Example structures: `((3, 4), 5, (6, 7, (9, 10), 8))`, `(np.array(0),
30  (np.array([3, 4]), tf.constant([3, 4])))`
31"""
32
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as _collections

try:
  # Python 3.3+: the container ABCs live in `collections.abc`. The aliases on
  # `collections` itself were removed in Python 3.10.
  from collections import abc as _collections_abc
except ImportError:  # Python 2: the ABCs are only available on `collections`.
  _collections_abc = _collections

import six as _six

from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
from tensorflow.python.util.tf_export import tf_export
43
44
# Error-message templates used when comparing a shallow tree against an input
# tree. Each is filled in with `str.format`.

# Positional field: the sorted list of offending keys.
_SHALLOW_TREE_HAS_INVALID_KEYS = (
    "The shallow_tree's keys are not a subset of the input_tree's keys. The "
    "shallow_tree has the following keys that are not in the input_tree: {}.")

# Keyword fields: `input_type` and `shallow_type`. Each placeholder must sit
# next to the structure it describes, otherwise the reported types are swapped.
_STRUCTURES_HAVE_MISMATCHING_TYPES = (
    "The two structures don't have the same sequence type. Input structure has "
    "type {input_type}, while shallow structure has type {shallow_type}.")

# Keyword fields: `input_size` and `shallow_size`.
_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE = (
    "The input_tree has fewer elements than the shallow_tree. Input structure "
    "has length {input_size}, while shallow structure has length "
    "{shallow_size}.")

# Positional field: the type of the offending input.
_IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ = (
    "If shallow structure is a sequence, input must also be a sequence. "
    "Input has type: {}.")
61
62
def _get_attrs_items(obj):
  """Collects the fields of an attrs instance as (name, value) pairs.

  The field names are read from the `__attrs_attrs__` attribute of the
  object's class and are returned in sorted (not declaration) order.

  Args:
    obj: an object whose class defines `__attrs_attrs__`.

  Returns:
    A list of (attr_name, attr_value) tuples ordered by attr_name.
  """
  field_names = [field.name for field in obj.__class__.__attrs_attrs__]
  field_names.sort()
  items = []
  for field_name in field_names:
    items.append((field_name, getattr(obj, field_name)))
  return items
77
78
def _sorted(dict_):
  """Returns the keys of `dict_` as a sorted list.

  Args:
    dict_: a dict (or any iterable of keys).

  Returns:
    A new list holding the keys in ascending order.

  Raises:
    TypeError: if the keys cannot be ordered relative to each other.
  """
  try:
    keys = list(dict_)
    keys.sort()
  except TypeError:
    raise TypeError("nest only supports dicts with sortable keys.")
  return keys
85
86
def _is_namedtuple(instance, strict=False):
  """Returns True iff `instance` is a `namedtuple`.

  This is a thin Python wrapper that forwards both arguments to the native
  `_pywrap_tensorflow.IsNamedtuple` check.

  Args:
    instance: An instance of a Python object.
    strict: If True, `instance` is considered to be a `namedtuple` only if
        it is a "plain" namedtuple. For instance, a class inheriting
        from a `namedtuple` will be considered to be a `namedtuple`
        iff `strict=False`.

  Returns:
    True if `instance` is a `namedtuple`.
  """
  return _pywrap_tensorflow.IsNamedtuple(instance, strict)
101
102
# See the swig file (util.i) for documentation.
# Type predicates exported by the native `_pywrap_tensorflow` module. In this
# file they select how a container is iterated (`_yield_sorted_items`) and how
# it is rebuilt from a flat list of children (`_sequence_like`).
_is_mapping = _pywrap_tensorflow.IsMapping
_is_attrs = _pywrap_tensorflow.IsAttrs
_is_composite_tensor = _pywrap_tensorflow.IsCompositeTensor
107
108
def _sequence_like(instance, args):
  """Converts the sequence `args` to the same type as `instance`.

  Args:
    instance: an instance of `tuple`, `list`, `namedtuple`, `dict`,
        `collections.OrderedDict`, `collections.defaultdict`, an attrs
        decorated class, or a composite tensor.
    args: elements to be converted to the `instance` type. For mappings the
        elements must be ordered by sorted key, because they are matched
        against `_sorted(instance)`.

  Returns:
    `args` with the type of `instance`.
  """
  if _is_mapping(instance):
    # Pack dictionaries in a deterministic order by sorting the keys.
    # Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    result = dict(zip(_sorted(instance), args))
    instance_type = type(instance)
    if instance_type is _collections.defaultdict:
      # `defaultdict` takes `default_factory` as its first positional
      # argument, so it cannot be built from an iterable of pairs alone.
      # Carry the factory over and fill the keys in the original key order.
      packed = _collections.defaultdict(instance.default_factory)
      for key in instance:
        packed[key] = result[key]
      return packed
    # Re-insert in the iteration order of `instance` so that ordered mappings
    # keep their key order in the result.
    return instance_type((key, result[key]) for key in instance)
  elif _is_namedtuple(instance) or _is_attrs(instance):
    # These constructors take one positional argument per field.
    return type(instance)(*args)
  elif _is_composite_tensor(instance):
    return instance._from_components(args)  # pylint: disable=protected-access
  else:
    # Not a namedtuple: the constructor accepts a single iterable.
    return type(instance)(args)
135
136
def _yield_value(iterable):
  """Yields only the values of `iterable`, in `_yield_sorted_items` order.

  Args:
    iterable: an iterable accepted by `_yield_sorted_items`.

  Yields:
    The value of each (key, value) pair, with the key discarded.
  """
  for _, value in _yield_sorted_items(iterable):
    yield value
140
141
def _yield_sorted_items(iterable):
  """Yield (key, value) pairs for `iterable` in a deterministic order.

  For Sequences, the key will be an int, the array index of a value.
  For Mappings, the key will be the dictionary key.
  For objects (e.g. namedtuples), the key will be the attribute name.

  Mapping keys and attrs attribute names are yielded in sorted order.
  namedtuple fields are yielded in `_fields` order, and all other iterables
  in their own iteration order.

  Args:
    iterable: an iterable.

  Yields:
    The iterable's (key, value) pairs, in the order described above.

  Raises:
    TypeError: if `iterable` is a mapping whose keys are not sortable.
  """
  # The ABC is looked up through `_collections_abc` because the
  # `collections.Mapping` alias was removed in Python 3.10.
  if isinstance(iterable, _collections_abc.Mapping):
    # Iterate through dictionaries in a deterministic order by sorting the
    # keys. Notice this means that we ignore the original order of `OrderedDict`
    # instances. This is intentional, to avoid potential bugs caused by mixing
    # ordered and plain dicts (e.g., flattening a dict but using a
    # corresponding `OrderedDict` to pack it back).
    for key in _sorted(iterable):
      yield key, iterable[key]
  elif _is_attrs(iterable):
    for item in _get_attrs_items(iterable):
      yield item
  elif _is_namedtuple(iterable):
    for field in iterable._fields:
      yield field, getattr(iterable, field)
  elif _is_composite_tensor(iterable):
    for item in enumerate(iterable._to_components()):  # pylint: disable=protected-access
      yield item
  else:
    for item in enumerate(iterable):
      yield item
177
178
# See the swig file (util.i) for documentation.
# Native predicate used throughout this module to decide whether a value is a
# nested structure or a leaf. `is_nested` forwards to it directly.
is_sequence = _pywrap_tensorflow.IsSequence


# See the swig file (util.i) for documentation.
# Native predicate that `pack_sequence_as` uses in place of `is_sequence` when
# it is called with `expand_composites=True`.
is_sequence_or_composite = _pywrap_tensorflow.IsSequenceOrComposite
185
186
@tf_export("nest.is_nested")
def is_nested(seq):
  """Returns true if its input is a collections.Sequence (except strings).

  This is the exported name for `is_sequence`; the call is forwarded
  unchanged.

  Args:
    seq: an input sequence.

  Returns:
    True if the sequence is not a string and is a collections.Sequence or a
    dict.
  """
  return is_sequence(seq)
199
200
@tf_export("nest.flatten")
def flatten(structure, expand_composites=False):
  """Returns a flat list from a given nested structure.

  If `structure` is not a sequence, tuple, or dict, then returns a
  single-element list: `[structure]`.

  In the case of dict instances, the sequence consists of the values, sorted by
  key to ensure deterministic behavior. This is true also for OrderedDict
  instances: their sequence order is ignored, the sorting order of keys is used
  instead. The same convention is followed in pack_sequence_as. This correctly
  repacks dicts and OrderedDicts after they have been flattened, and also allows
  flattening an OrderedDict and then repacking it back using a corresponding
  plain dict, or vice-versa. Dictionaries with non-sortable keys cannot be
  flattened.

  Users must not modify any collections used in `structure` while this function
  is running.

  The traversal itself is performed by the native
  `_pywrap_tensorflow.Flatten` function.

  Args:
    structure: an arbitrarily nested structure or a scalar object. Note, numpy
      arrays are considered scalars.
    expand_composites: If true, then composite tensors such as tf.SparseTensor
       and tf.RaggedTensor are expanded into their component tensors.

  Returns:
    A Python list, the flattened version of the input.

  Raises:
    TypeError: The structure is or contains a dict with non-sortable keys.
  """
  return _pywrap_tensorflow.Flatten(structure, expand_composites)
233
234
# See the swig file (util.i) for documentation.
# Native comparison used by `assert_shallow_structure` when two namedtuple
# instances are not related by `isinstance`.
_same_namedtuples = _pywrap_tensorflow.SameNamedtuples
237
238
class _DotString(object):
  """Placeholder leaf whose `str()` and `repr()` are both a single dot.

  Used to print the bare shape of a structure in error messages.
  """

  def __repr__(self):
    return "."

  # `str()` renders exactly like `repr()`.
  __str__ = __repr__


# Shared placeholder instance.
_DOT = _DotString()
249
250
@tf_export("nest.assert_same_structure")
def assert_same_structure(nest1, nest2, check_types=True,
                          expand_composites=False):
  """Asserts that two structures are nested in the same way.

  The function returns `None` when the structures match and raises otherwise.

  Note that namedtuples with identical name and fields are always considered
  to have the same shallow structure (even with `check_types=True`).
  For instance, this code does not raise, even though every call to `nt`
  creates a new namedtuple class:

  ```python
  def nt(a, b):
    return collections.namedtuple('foo', 'a b')(a, b)
  assert_same_structure(nt(0, 1), nt(2, 3))
  ```

  Args:
    nest1: an arbitrarily nested structure.
    nest2: an arbitrarily nested structure.
    check_types: if `True` (default) types of sequences are checked as well,
        including the keys of dictionaries. If set to `False`, for example a
        list and a tuple of objects will look the same if they have the same
        size. Note that namedtuples with identical name and fields are always
        considered to have the same shallow structure. Two types will also be
        considered the same if they are both list subtypes (which allows "list"
        and "_ListWrapper" from trackable dependency tracking to compare
        equal).
    expand_composites: If true, then composite tensors such as `tf.SparseTensor`
        and `tf.RaggedTensor` are expanded into their component tensors.

  Raises:
    ValueError: If the two structures do not have the same number of elements or
      if the two structures are not nested in the same way.
    TypeError: If the two structures differ in the type of sequence in any of
      their substructures. Only possible if `check_types` is `True`.
  """
  try:
    _pywrap_tensorflow.AssertSameStructure(nest1, nest2, check_types,
                                           expand_composites)
  except (ValueError, TypeError) as e:
    # Re-raise the same exception type, with both structures appended to the
    # message and every leaf rendered as a dot.
    def _outline(nest):
      return str(map_structure(lambda _: _DOT, nest))

    first_outline = _outline(nest1)
    second_outline = _outline(nest2)
    message = ("%s\n"
               "Entire first structure:\n%s\n"
               "Entire second structure:\n%s"
               % (str(e), first_outline, second_outline))
    raise type(e)(message)
296
297
def flatten_dict_items(dictionary):
  """Returns a dictionary with flattened keys and values.

  This function flattens the keys and values of a dictionary, which can be
  arbitrarily nested structures, and returns the flattened version of such
  structures:

  ```python
  example_dictionary = {(4, 5, (6, 8)): ("a", "b", ("c", "d"))}
  result = {4: "a", 5: "b", 6: "c", 8: "d"}
  flatten_dict_items(example_dictionary) == result
  ```

  The input dictionary must satisfy two properties:

  1. Its keys and values should have the same exact nested structure.
  2. The set of all flattened keys of the dictionary must not contain repeated
     keys.

  Args:
    dictionary: the mapping whose keys and values are flattened pairwise.

  Returns:
    A new plain `dict` mapping every flattened key to the value at the
    matching position.

  Raises:
    TypeError: If the input is not a dictionary.
    ValueError: If any key and value do not have the same structure layout, or
    if keys are not unique.
  """
  # `dict` is itself a `Mapping`, so one ABC check covers both cases. The ABC
  # is looked up through `_collections_abc` because the `collections.Mapping`
  # alias was removed in Python 3.10.
  if not isinstance(dictionary, _collections_abc.Mapping):
    raise TypeError("input must be a dictionary")
  flat_dictionary = {}
  for i, v in _six.iteritems(dictionary):
    if not is_sequence(i):
      # Scalar key: copy the entry through unchanged.
      if i in flat_dictionary:
        raise ValueError(
            "Could not flatten dictionary: key %s is not unique." % i)
      flat_dictionary[i] = v
    else:
      # Structured key: flatten key and value and pair them by position.
      flat_i = flatten(i)
      flat_v = flatten(v)
      if len(flat_i) != len(flat_v):
        raise ValueError(
            "Could not flatten dictionary. Key had %d elements, but value had "
            "%d elements. Key: %s, value: %s."
            % (len(flat_i), len(flat_v), flat_i, flat_v))
      for new_i, new_v in zip(flat_i, flat_v):
        if new_i in flat_dictionary:
          raise ValueError(
              "Could not flatten dictionary: key %s is not unique."
              % (new_i))
        flat_dictionary[new_i] = new_v
  return flat_dictionary
352
353
def _packed_nest_with_indices(structure, flat, index, is_seq):
  """Recursive worker for `pack_sequence_as`.

  Walks the children of `structure` in `_yield_value` order. Each leaf takes
  the next unread element of `flat`; each nested child is packed recursively.

  Args:
    structure: Substructure (list / tuple / dict) to mimic.
    flat: Flattened values to output substructure for.
    index: Index at which to start reading from flat.
    is_seq: Function used to test if a value should be treated as a sequence.

  Returns:
    The tuple (new_index, packed), where:
      * new_index - the updated index into `flat` having processed `structure`.
      * packed - a list with one entry per child of `structure`, holding the
                 values of `flat` read from `index` onwards. Nested children
                 are already converted to their original container type.

  Raises:
    ValueError: if `structure` contains more elements than `flat`
      (assuming indexing starts from `index`).
  """
  children = []
  cursor = index
  for element in _yield_value(structure):
    if not is_seq(element):
      # Leaf: consume exactly one flat value.
      children.append(flat[cursor])
      cursor += 1
      continue
    # Nested container: the recursion reports how far it read into `flat`.
    next_cursor, grandchildren = _packed_nest_with_indices(
        element, flat, cursor, is_seq)
    children.append(_sequence_like(element, grandchildren))
    cursor = next_cursor
  return cursor, children
384
385
@tf_export("nest.pack_sequence_as")
def pack_sequence_as(structure, flat_sequence, expand_composites=False):
  """Returns a given flattened sequence packed into a given structure.

  If `structure` is a scalar, `flat_sequence` must be a single-element list;
  in this case the return value is `flat_sequence[0]`.

  If `structure` is or contains a dict instance, the keys will be sorted to
  pack the flat sequence in deterministic order. This is true also for
  `OrderedDict` instances: their sequence order is ignored, the sorting order of
  keys is used instead. The same convention is followed in `flatten`.
  This correctly repacks dicts and `OrderedDict`s after they have been
  flattened, and also allows flattening an `OrderedDict` and then repacking it
  back using a corresponding plain dict, or vice-versa.
  Dictionaries with non-sortable keys cannot be flattened.

  Args:
    structure: Nested structure, whose structure is given by nested lists,
        tuples, and dicts. Note: numpy arrays and strings are considered
        scalars.
    flat_sequence: flat sequence to pack.
    expand_composites: If true, then composite tensors such as `tf.SparseTensor`
        and `tf.RaggedTensor` are expanded into their component tensors.

  Returns:
    packed: `flat_sequence` converted to have the same recursive structure as
      `structure`.

  Raises:
    ValueError: If `flat_sequence` and `structure` have different
      element counts.
    TypeError: If `flat_sequence` is not a sequence, or `structure` is or
      contains a dict with non-sortable keys.
  """
  is_seq = is_sequence_or_composite if expand_composites else is_sequence
  if not is_seq(flat_sequence):
    raise TypeError("flat_sequence must be a sequence")

  if not is_seq(structure):
    if len(flat_sequence) != 1:
      raise ValueError("Structure is a scalar but len(flat_sequence) == %d > 1"
                       % len(flat_sequence))
    return flat_sequence[0]

  try:
    final_index, packed = _packed_nest_with_indices(structure, flat_sequence,
                                                    0, is_seq)
    # Leftover flat values are reported through the same path as running out
    # of values while packing.
    if final_index < len(flat_sequence):
      raise IndexError
  except IndexError:
    # Count the leaves with the same `expand_composites` setting that was used
    # for packing. Otherwise a structure containing composite tensors is
    # counted before expansion and compared against the expanded sequence.
    flat_structure = flatten(structure, expand_composites)
    if len(flat_structure) != len(flat_sequence):
      raise ValueError(
          "Could not pack sequence. Structure had %d elements, but "
          "flat_sequence had %d elements.  Structure: %s, flat_sequence: %s." %
          (len(flat_structure), len(flat_sequence), structure, flat_sequence))
  return _sequence_like(structure, packed)
442
443
@tf_export("nest.map_structure")
def map_structure(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(x[0], x[1], ...)` where x[i] is an entry in
  `structure[i]`.  All structures in `structure` must have the same arity,
  and the return value will contain results with the same structure layout.

  Args:
    func: A callable that accepts as many arguments as there are structures.
    *structure: scalar, or tuple or list of constructed scalars and/or other
      tuples/lists, or scalars.  Note: numpy arrays are considered as scalars.
    **kwargs: Valid keyword args are:

      * `check_types`: If set to `True` (default) the types of
        iterables within the structures have to be same (e.g.
        `map_structure(func, [1], (1,))` raises a `TypeError`
        exception). To allow this set this argument to `False`.
        Note that namedtuples with identical name and fields are always
        considered to have the same shallow structure.
      * `expand_composites`: If set to `True`, then composite tensors such
        as `tf.SparseTensor` and `tf.RaggedTensor` are expanded into their
        component tensors.  If `False` (the default), then composite tensors
        are not expanded.

  Returns:
    A new structure with the same arity as `structure`, whose values correspond
    to `func(x[0], x[1], ...)` where `x[i]` is a value in the corresponding
    location in `structure[i]`. If there are different sequence types and
    `check_types` is `False` the sequence types of the first structure will be
    used.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    ValueError: If no structure is provided or if the structures do not match
      each other by type.
    ValueError: If wrong keyword arguments are provided.
  """
  if not callable(func):
    raise TypeError("func must be callable, got: %s" % func)

  if not structure:
    raise ValueError("Must provide at least one structure")

  # Pull out the two supported options; anything left over is an error.
  check_types = kwargs.pop("check_types", True)
  expand_composites = kwargs.pop("expand_composites", False)
  if kwargs:
    raise ValueError("Only valid keyword arguments are check_types "
                     "and expand_composites")

  # Every structure must match the first one, which is also the template for
  # the result.
  template = structure[0]
  for candidate in structure[1:]:
    assert_same_structure(template, candidate, check_types=check_types,
                          expand_composites=expand_composites)

  flattened = [flatten(item, expand_composites) for item in structure]
  # Group the i-th leaf of every structure and apply `func` to each group.
  mapped_leaves = [func(*leaves) for leaves in zip(*flattened)]

  return pack_sequence_as(
      template, mapped_leaves, expand_composites=expand_composites)
508
509
def map_structure_with_paths(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(path, x[0], x[1], ..., **kwargs)` where x[i] is an entry in
  `structure[i]` and `path` is the common path to x[i] in the structures.  All
  structures in `structure` must have the same arity, and the return value will
  contain the results with the same structure layout. Special kwarg
  `check_types` determines whether the types of iterables within the structure
  must be the same-- see **kwargs definition below.

  The path handed to `func` is a string: the elements of the tuple path joined
  with "/".

  Args:
    func: A callable with the signature func(path, *values, **kwargs) that is
      evaluated on the leaves of the structure.
    *structure: A variable number of compatible structures to process.
    **kwargs: Optional kwargs to be passed through to func. Special kwarg
      `check_types` is not passed to func, but instead determines whether the
      types of iterables within the structures have to be same (e.g.,
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). By
      default, the types must match. To allow iteration over structures of
      different types (but common arity), set this kwarg to `False`.

  Returns:
    A structure of the same form as the input structures whose leaves are the
    result of evaluating func on corresponding leaves of the input structures.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    TypeError: If `check_types` is not `False` and the two structures differ in
      the type of sequence in any of their substructures.
    ValueError: If no structures are provided.
  """
  # Without this guard `structure[0]` below raises an IndexError instead of
  # the documented ValueError.
  if not structure:
    raise ValueError("Must provide at least one structure")

  def wrapper_func(tuple_path, *inputs, **kwargs):
    # Convert the tuple path into the "/"-joined string form `func` expects.
    string_path = "/".join(str(s) for s in tuple_path)
    return func(string_path, *inputs, **kwargs)

  return map_structure_with_tuple_paths_up_to(structure[0],
                                              wrapper_func,
                                              *structure,
                                              **kwargs)
550
551
def map_structure_with_tuple_paths(func, *structure, **kwargs):
  """Applies `func` to each entry in `structure` and returns a new structure.

  Applies `func(tuple_path, x[0], x[1], ..., **kwargs)` where `x[i]` is an entry
  in `structure[i]` and `tuple_path` is a tuple of indices and/or dictionary
  keys (as returned by `nest.yield_flat_paths`), which uniquely specifies the
  common path to x[i] in the structures. All structures in `structure` must have
  the same arity, and the return value will contain the results in the same
  structure. Special kwarg `check_types` determines whether the types of
  iterables within the structure must be the same-- see **kwargs definition
  below.

  Args:
    func: A callable with the signature `func(tuple_path, *values, **kwargs)`
      that is evaluated on the leaves of the structure.
    *structure: A variable number of compatible structures to process.
    **kwargs: Optional kwargs to be passed through to func. Special kwarg
      `check_types` is not passed to func, but instead determines whether the
      types of iterables within the structures have to be same (e.g.
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
      this set this argument to `False`.

  Returns:
    A structure of the same form as the input structures whose leaves are the
    result of evaluating func on corresponding leaves of the input structures.

  Raises:
    TypeError: If `func` is not callable or if the structures do not match
      each other by depth tree.
    TypeError: If `check_types` is not `False` and the two structures differ in
      the type of sequence in any of their substructures.
    ValueError: If no structures are provided.
  """
  # Without this guard `structure[0]` below raises an IndexError instead of
  # the documented ValueError.
  if not structure:
    raise ValueError("Must provide at least one structure")

  # The first structure doubles as the shallow tree, so the mapping descends
  # to its leaves.
  return map_structure_with_tuple_paths_up_to(structure[0],
                                              func,
                                              *structure,
                                              **kwargs)
589
590
def _yield_flat_up_to(shallow_tree, input_tree, path=()):
  """Yields (path, value) pairs of input_tree flattened up to shallow_tree.

  Args:
    shallow_tree: Nested structure. Traverse no further than its leaf nodes.
    input_tree: Nested structure. Return the paths and values from this tree.
      Must have the same upper structure as shallow_tree.
    path: Tuple. Optional argument, only used when recursing. The path from the
      root of the original shallow_tree, down to the root of the shallow_tree
      arg of this recursive call.

  Yields:
    Pairs of (path, value), where path is the tuple path of a leaf node in
    shallow_tree, and value is the value of the corresponding node in
    input_tree.
  """
  # Strings are Sequences but are treated as leaves. The ABCs are looked up
  # through `_collections_abc` because the aliases on `collections` were
  # removed in Python 3.10.
  shallow_is_leaf = (
      isinstance(shallow_tree, _six.string_types) or
      not (isinstance(shallow_tree, (_collections_abc.Sequence,
                                     _collections_abc.Mapping)) or
           _is_namedtuple(shallow_tree) or
           _is_attrs(shallow_tree)))
  if shallow_is_leaf:
    yield (path, input_tree)
  else:
    # Index the input by key so each shallow key finds its own subtree, even
    # when the input holds additional entries.
    input_tree = dict(_yield_sorted_items(input_tree))
    for shallow_key, shallow_subtree in _yield_sorted_items(shallow_tree):
      subpath = path + (shallow_key,)
      input_subtree = input_tree[shallow_key]
      for leaf_path, leaf_value in _yield_flat_up_to(shallow_subtree,
                                                     input_subtree,
                                                     path=subpath):
        yield (leaf_path, leaf_value)
622
623
def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
  """Asserts that `shallow_tree` is a shallow structure of `input_tree`.

  That is, this function tests if the `input_tree` structure can be created from
  the `shallow_tree` structure by replacing its leaf nodes with deeper
  tree structures.

  Examples:

  The following code will raise an exception:
  ```python
    shallow_tree = {"a": "A", "b": "B"}
    input_tree = {"a": 1, "c": 2}
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  The following code will not raise an exception:
  ```python
    shallow_tree = ["a", "b"]
    input_tree = ["c", ["d", "e"], "f"]
    assert_shallow_structure(shallow_tree, input_tree)
  ```

  Args:
    shallow_tree: an arbitrarily nested structure.
    input_tree: an arbitrarily nested structure.
    check_types: if `True` (default) the sequence types of `shallow_tree` and
      `input_tree` have to be the same. Note that even with check_types==True,
      this function will consider two different namedtuple classes with the same
      name and _fields attribute to be the same class.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`. Only raised if `check_types` is `True`.
    ValueError: If `input_tree` has fewer elements than `shallow_tree`, or if
      `shallow_tree` is a mapping with keys that are missing from `input_tree`.
  """
  if not is_sequence(shallow_tree):
    # A leaf in the shallow tree matches anything in the input tree.
    return

  if not is_sequence(input_tree):
    raise TypeError(
        _IF_SHALLOW_IS_SEQ_INPUT_MUST_BE_SEQ.format(type(input_tree)))

  # The ABC is looked up through `_collections_abc` because the
  # `collections.Mapping` alias was removed in Python 3.10.
  shallow_is_mapping = isinstance(shallow_tree, _collections_abc.Mapping)
  input_is_mapping = isinstance(input_tree, _collections_abc.Mapping)

  if check_types and not isinstance(input_tree, type(shallow_tree)):
    # Duck-typing means that nest should be fine with two different
    # namedtuples with identical name and fields.
    shallow_is_namedtuple = _is_namedtuple(shallow_tree, False)
    input_is_namedtuple = _is_namedtuple(input_tree, False)
    if shallow_is_namedtuple and input_is_namedtuple:
      if not _same_namedtuples(shallow_tree, input_tree):
        raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
            input_type=type(input_tree),
            shallow_type=type(shallow_tree)))

    # Any two mappings (e.g. dict and OrderedDict) are accepted as the same
    # kind of container.
    elif not (shallow_is_mapping and input_is_mapping):
      raise TypeError(_STRUCTURES_HAVE_MISMATCHING_TYPES.format(
          input_type=type(input_tree),
          shallow_type=type(shallow_tree)))

  if len(input_tree) < len(shallow_tree):
    raise ValueError(_INPUT_TREE_SMALLER_THAN_SHALLOW_TREE.format(
        input_size=len(input_tree),
        shallow_size=len(shallow_tree)))

  if shallow_is_mapping:
    absent_keys = set(shallow_tree) - set(input_tree)
    if absent_keys:
      raise ValueError(_SHALLOW_TREE_HAS_INVALID_KEYS
                       .format(sorted(absent_keys)))

  if shallow_is_mapping and input_is_mapping:
    # Pair branches by key. The input may hold extra keys, so zipping the two
    # sorted value lists would compare unrelated subtrees.
    branch_pairs = [(shallow_tree[key], input_tree[key])
                    for key in _sorted(shallow_tree)]
  else:
    branch_pairs = zip(_yield_value(shallow_tree), _yield_value(input_tree))

  for shallow_branch, input_branch in branch_pairs:
    assert_shallow_structure(shallow_branch, input_branch,
                             check_types=check_types)
700
701
def flatten_up_to(shallow_tree, input_tree, check_types=True):
  """Flattens `input_tree` up to `shallow_tree`.

  Any further depth in structure in `input_tree` is retained as elements in the
  partially flattened output.

  If `shallow_tree` and `input_tree` are not sequences, this returns a
  single-element list: `[input_tree]`.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure layout
  as `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples:

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_up_to(shallow_tree, input_tree)
  flattened_shallow_tree = flatten_up_to(shallow_tree, shallow_tree)

  # Output is:
  # [[2, 2], [3, 3], [4, 9], [5, 5]]
  # [True, True, False, True]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_up_to(shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [('a', 1), ('b', 2), ('c', 3), ('d', 4)]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_up_to(0, 0)  # Output: [0]
  flatten_up_to(0, [0, 1, 2])  # Output: [[0, 1, 2]]
  flatten_up_to([0, 1, 2], 0)  # Output: TypeError
  flatten_up_to([0, 1, 2], [0, 1, 2])  # Output: [0, 1, 2]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays are considered scalars.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree.

  Returns:
    A Python list, the partially flattened version of `input_tree` according to
    the structure of `shallow_tree`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  # Validate first so that a mismatch surfaces as a descriptive error.
  assert_shallow_structure(shallow_tree, input_tree, check_types)
  # Keep only the values; the paths are not part of this function's result.
  return [leaf for _, leaf in _yield_flat_up_to(shallow_tree, input_tree)]
776
777
def flatten_with_tuple_paths_up_to(shallow_tree, input_tree, check_types=True):
  """Flattens `input_tree` up to `shallow_tree`, keeping each leaf's path.

  Any further depth in structure in `input_tree` is retained as elements in the
  partially flattened output.

  Returns a list of (path, value) pairs, where value is a leaf node in the
  flattened tree, and path is the tuple path of that leaf in input_tree.

  If `shallow_tree` and `input_tree` are not sequences, this returns a
  single-element list: `[((), input_tree)]`.

  Use Case:

  Sometimes we may wish to partially flatten a nested sequence, retaining some
  of the nested structure. We achieve this by specifying a shallow structure,
  `shallow_tree`, we wish to flatten up to.

  The input, `input_tree`, can be thought of as having the same structure layout
  as `shallow_tree`, but with leaf nodes that are themselves tree structures.

  Examples:

  ```python
  input_tree = [[[2, 2], [3, 3]], [[4, 9], [5, 5]]]
  shallow_tree = [[True, True], [False, True]]

  flattened_input_tree = flatten_with_tuple_paths_up_to(shallow_tree,
                                                        input_tree)
  flattened_shallow_tree = flatten_with_tuple_paths_up_to(shallow_tree,
                                                          shallow_tree)

  # Output is:
  # [((0, 0), [2, 2]),
  #  ((0, 1), [3, 3]),
  #  ((1, 0), [4, 9]),
  #  ((1, 1), [5, 5])]
  #
  # [((0, 0), True),
  #  ((0, 1), True),
  #  ((1, 0), False),
  #  ((1, 1), True)]
  ```

  ```python
  input_tree = [[('a', 1), [('b', 2), [('c', 3), [('d', 4)]]]]]
  shallow_tree = [['level_1', ['level_2', ['level_3', ['level_4']]]]]

  input_tree_flattened_as_shallow_tree = flatten_with_tuple_paths_up_to(
      shallow_tree, input_tree)
  input_tree_flattened = flatten(input_tree)

  # Output is:
  # [((0, 0), ('a', 1)),
  #  ((0, 1, 0), ('b', 2)),
  #  ((0, 1, 1, 0), ('c', 3)),
  #  ((0, 1, 1, 1, 0), ('d', 4))]
  # ['a', 1, 'b', 2, 'c', 3, 'd', 4]
  ```

  Non-Sequence Edge Cases:

  ```python
  flatten_with_tuple_paths_up_to(0, 0)  # Output: [((), 0)]

  flatten_with_tuple_paths_up_to(0, [0, 1, 2])  # Output: [((), [0, 1, 2])]

  flatten_with_tuple_paths_up_to([0, 1, 2], 0)  # Output: TypeError

  flatten_with_tuple_paths_up_to([0, 1, 2], [0, 1, 2])
  # Output: [((0,), 0), ((1,), 1), ((2,), 2)]
  ```

  Args:
    shallow_tree: a possibly pruned structure of input_tree.
    input_tree: an arbitrarily nested structure or a scalar object.
      Note, numpy arrays are considered scalars.
    check_types: bool. If True, check that each node in shallow_tree has the
      same type as the corresponding node in input_tree.

  Returns:
    A Python list of `(path, value)` pairs: the partially flattened version of
    `input_tree` according to the structure of `shallow_tree`, with each value
    accompanied by its tuple path.

  Raises:
    TypeError: If `shallow_tree` is a sequence but `input_tree` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      `input_tree`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      `input_tree`.
  """
  # Check structural compatibility before walking the trees.
  assert_shallow_structure(shallow_tree, input_tree, check_types=check_types)
  path_value_pairs = _yield_flat_up_to(shallow_tree, input_tree)
  return list(path_value_pairs)
870
871
def map_structure_up_to(shallow_tree, func, *inputs, **kwargs):
  """Applies a function or op to a number of partially flattened inputs.

  The `inputs` are flattened up to `shallow_tree` before being mapped.

  Use Case:

  Sometimes we wish to apply a function to a partially flattened
  sequence (for example when the function itself takes sequence inputs). We
  achieve this by specifying a shallow structure, `shallow_tree` we wish to
  flatten up to.

  The `inputs`, can be thought of as having the same structure layout as
  `shallow_tree`, but with leaf nodes that are themselves tree structures.

  This function therefore will return something with the same base structure as
  `shallow_tree`.

  Examples:

  ```python
  shallow_tree = [None, None]
  inp_val = [1, 2, 3]
  out = map_structure_up_to(shallow_tree, lambda x: 2 * x, inp_val)

  # Output is: [2, 4]
  ```

  ```python
  ab_tuple = collections.namedtuple("ab_tuple", "a, b")
  op_tuple = collections.namedtuple("op_tuple", "add, mul")
  inp_val = ab_tuple(a=2, b=3)
  inp_ops = ab_tuple(a=op_tuple(add=1, mul=2), b=op_tuple(add=2, mul=3))
  out = map_structure_up_to(inp_val, lambda val, ops: (val + ops.add) * ops.mul,
                            inp_val, inp_ops)

  # Output is: ab_tuple(a=6, b=15)
  ```

  ```python
  data_list = [[2, 4, 6, 8], [[1, 3, 5, 7, 9], [3, 5, 7]]]
  name_list = ['evens', ['odds', 'primes']]
  out = map_structure_up_to(
      name_list,
      lambda name, sec: "first_{}_{}".format(len(sec), name),
      name_list, data_list)

  # Output is: ['first_4_evens', ['first_5_odds', 'first_3_primes']]
  ```

  Args:
    shallow_tree: a shallow tree, common to all the inputs.
    func: callable which will be applied to each input individually.
    *inputs: arbitrarily nested combination of objects that are compatible with
        shallow_tree. The function `func` is applied to corresponding
        partially flattened elements of each input, so the function must support
        arity of `len(inputs)`.
    **kwargs: kwargs to feed to func(). Special kwarg
      `check_types` is not passed to func, but instead determines whether the
      types of iterables within the structures have to be same (e.g.
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
      this set this argument to `False`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but one of `*inputs` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      those of one of `*inputs`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      those of one of `*inputs`.

  Returns:
    result of repeatedly applying `func`, with the same structure layout as
    `shallow_tree`.
  """
  def _discard_path(*path_and_values, **func_kwargs):
    # The path-aware mapper passes the tuple path as the first positional
    # argument; `func` only wants the values. Keyword arguments must be
    # accepted and forwarded here, because the path-aware mapper invokes this
    # wrapper with every kwarg except `check_types`.
    return func(*path_and_values[1:], **func_kwargs)

  return map_structure_with_tuple_paths_up_to(
      shallow_tree, _discard_path, *inputs, **kwargs)
951
952
def map_structure_with_tuple_paths_up_to(shallow_tree, func, *inputs, **kwargs):
  """Applies a function or op to a number of partially flattened inputs.

  Like map_structure_up_to(), except that the 'func' argument takes a path
  tuple as its first argument, followed by the corresponding values from
  *inputs.

  Example:

  lowercase = {'a': 'a', 'b': ('b0', 'b1')}
  uppercase = {'a': 'A', 'b': ('B0', 'B1')}

  def print_path_and_values(path, *values):
    print("path: {}, values: {}".format(path, values))

  shallow_tree = {'a': None}
  map_structure_with_tuple_paths_up_to(shallow_tree,
                                       print_path_and_values,
                                       lowercase,
                                       uppercase)
  >>> path: ('a',), values: ('a', 'A')
  >>> path: ('b', 0), values: ('b0', 'B0')
  >>> path: ('b', 1), values: ('b1', 'B1')

  shallow_tree = {'b': None}
  map_structure_with_tuple_paths_up_to(shallow_tree,
                                       print_path_and_values,
                                       lowercase,
                                       uppercase,
                                       check_types=False)
  >>> path: ('b', 1), values: (('b0', 'b1'), ('B0', 'B1'))

  shallow_tree = {'a': None, 'b': {1: None}}
  map_structure_with_tuple_paths_up_to(shallow_tree,
                                       print_path_and_values,
                                       lowercase,
                                       uppercase,
                                       check_types=False)
  >>> path: ('a',), values: ('a', 'A')
  >>> path: ('b', 1), values: ('b1', 'B1')

  Args:
    shallow_tree: a shallow tree, common to all the inputs.
    func: callable that takes args (path, inputs_0_value, ... , inputs_N_value),
      where path is a tuple path to a leaf node in shallow_tree, and
      inputs_i_value is the corresponding value from inputs[i].
    *inputs: nested structures that are all structurally compatible with
        shallow_tree.
    **kwargs: kwargs to feed to func(). Special kwarg
      `check_types` is not passed to func, but instead determines whether the
      types of iterables within the structures have to be same (e.g.
      `map_structure(func, [1], (1,))` raises a `TypeError` exception). To allow
      this set this argument to `False`.

  Raises:
    TypeError: If `shallow_tree` is a sequence but one of `*inputs` is not.
    TypeError: If the sequence types of `shallow_tree` are different from
      those of one of `*inputs`.
    ValueError: If the sequence lengths of `shallow_tree` are different from
      those of one of `*inputs`.
    ValueError: If no `*inputs` are given.

  Returns:
    Result of repeatedly applying `func`. Has the same structure layout as
    `shallow_tree`.
  """
  if not inputs:
    raise ValueError("Cannot map over no sequences")

  # `check_types` is consumed here; the remaining kwargs are handed to `func`.
  check_types = kwargs.pop("check_types", True)

  for input_tree in inputs:
    assert_shallow_structure(shallow_tree, input_tree, check_types=check_types)

  # Every input has been validated above, so walk each one directly rather than
  # through flatten_up_to(), which would repeat the same validation per input.
  flat_value_lists = [
      [value for _, value in _yield_flat_up_to(shallow_tree, input_tree)]
      for input_tree in inputs]
  flat_path_list = [path for path, _
                    in _yield_flat_up_to(shallow_tree, inputs[0])]

  # Apply the function to corresponding elements, then repack the results
  # based on the structure of `shallow_tree`.
  results = [func(*args, **kwargs)
             for args in zip(flat_path_list, *flat_value_lists)]
  return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
1035
1036
def get_traverse_shallow_structure(traverse_fn, structure):
  """Builds a shallow structure of bools describing what to traverse.

  `traverse_fn` is called on `structure` and, recursively, on each subtree
  selected for traversal. For each call it must return either a scalar `bool`
  ("traverse all / none of the subtrees") or a depth=1 structure of `bool`s
  saying which of the top-level subtrees may be traversed.

  Examples are available in the unit tests (nest_test.py).

  Args:
    traverse_fn: Function taking a substructure and returning either a scalar
      `bool` (whether to traverse that substructure or not) or a depth=1
      shallow structure of the same type, describing which parts of the
      substructure to traverse.
    structure: The structure to traverse.

  Returns:
    A shallow structure containing python bools, which can be passed to
    `map_structure_up_to` and `flatten_up_to`.

  Raises:
    TypeError: if `traverse_fn` returns a sequence for a non-sequence input,
      or a structure with depth higher than 1 for a sequence input,
      or if any leaf values in the returned structure or scalar are not type
      `bool`.
  """
  decision = traverse_fn(structure)
  decision_is_bool = isinstance(decision, bool)

  # Leaf: the only valid answer is a plain bool, which is returned unchanged.
  if not is_sequence(structure):
    if decision_is_bool:
      return decision
    raise TypeError("traverse_fn returned structure: %s for non-structure: %s"
                    % (decision, structure))

  if decision_is_bool:
    if not decision:
      # Do not traverse this substructure at all.
      return False
    # Traverse the entire substructure.
    children = [get_traverse_shallow_structure(traverse_fn, child)
                for child in _yield_value(structure)]
    return _sequence_like(structure, children)

  if not is_sequence(decision):
    raise TypeError("traverse_fn returned a non-bool scalar: %s for input: %s"
                    % (decision, structure))

  # Traverse only the subset of children flagged True by traverse_fn.
  assert_shallow_structure(decision, structure)
  children = []
  for flag, child in zip(_yield_value(decision), _yield_value(structure)):
    if not isinstance(flag, bool):
      raise TypeError(
          "traverse_fn didn't return a depth=1 structure of bools.  saw: %s "
          " for structure: %s" % (decision, structure))
    children.append(
        get_traverse_shallow_structure(traverse_fn, child) if flag else False)
  return _sequence_like(structure, children)
1097
1098
def yield_flat_paths(nest):
  """Yields the path to each leaf of a nested structure.

  Paths are tuples of objects which can be str-converted, which may include
  integers or other types which are used as indices in a dict.

  The paths are produced in the same order as the leaves returned by
  `nest.flatten` on the structure. This is handy for naming Tensors such that
  the TF scope structure matches the tuple structure.

  E.g. if we have a tuple `value = Foo(a=3, b=Bar(c=23, d=42))`

  ```shell
  >>> nest.flatten(value)
  [3, 23, 42]
  >>> list(nest.yield_flat_paths(value))
  [('a',), ('b', 'c'), ('b', 'd')]
  ```

  ```shell
  >>> list(nest.yield_flat_paths({'a': [3]}))
  [('a', 0)]
  >>> list(nest.yield_flat_paths({'a': 3}))
  [('a',)]
  ```

  Args:
    nest: the value to produce a flattened paths list for.

  Yields:
    Tuples containing index or key values which form the path to a specific
      leaf value in the nested structure.
  """
  # Using `nest` as its own shallow tree walks all the way down to the leaves;
  # only the path half of each (path, value) pair is of interest here.
  for path, _ in _yield_flat_up_to(nest, nest):
    yield path
1134
1135
def flatten_with_joined_string_paths(structure, separator="/"):
  """Flattens `structure` into (joined string path, data element) tuples.

  The order of tuples produced matches that of `nest.flatten`. This allows you
  to flatten a nested structure while keeping information about where in the
  structure each data element was located. See `nest.yield_flat_paths`
  for more information.

  Args:
    structure: the nested structure to flatten.
    separator: string to separate levels of hierarchy in the results, defaults
      to '/'.

  Returns:
    A list of (string, data element) tuples.
  """
  # Each path element is str-converted before being joined with `separator`.
  joined_paths = [separator.join(map(str, path))
                  for path in yield_flat_paths(structure)]
  leaves = flatten(structure)
  return list(zip(joined_paths, leaves))
1157
1158
def flatten_with_tuple_paths(structure):
  """Flattens `structure` into `(tuple_path, leaf_element)` tuples.

  The order of pairs produced matches that of `nest.flatten`. This allows you
  to flatten a nested structure while keeping information about where in the
  structure each data element was located. See `nest.yield_flat_paths`
  for more information about tuple paths.

  Args:
    structure: the nested structure to flatten.

  Returns:
    A list of `(tuple_path, leaf_element)` tuples. Each `tuple_path` is a tuple
    of indices and/or dictionary keys that uniquely specify the path to
    `leaf_element` within `structure`.
  """
  tuple_paths = yield_flat_paths(structure)
  leaves = flatten(structure)
  # Pair each path with the leaf at the same position in the flattened list.
  return list(zip(tuple_paths, leaves))
1176
1177
# The `collections.Mapping` / `collections.Sequence` aliases were deprecated in
# Python 3.3 and removed in Python 3.10; the ABCs live in `collections.abc`.
# Python 2 has no `collections.abc`, so fall back to `collections` there.
try:
  from collections import abc as _collections_abc  # pylint: disable=g-import-not-at-top
except ImportError:  # Python 2.
  _collections_abc = _collections

# Register the Mapping and Sequence ABCs with the native extension module under
# these string names.
_pywrap_tensorflow.RegisterType("Mapping", _collections_abc.Mapping)
_pywrap_tensorflow.RegisterType("Sequence", _collections_abc.Sequence)
1180