1# Lint as python3
2# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8#     http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15# ==============================================================================
16"""Structured Tensors."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import re
23from typing import Callable, Dict, List, Sequence, Tuple, Union
24
25import numpy as np
26
27from tensorflow.python.framework import composite_tensor
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import tensor_spec
33from tensorflow.python.framework import type_spec
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import check_ops
36from tensorflow.python.ops import control_flow_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops.ragged import ragged_factory_ops
39from tensorflow.python.ops.ragged import ragged_tensor
40from tensorflow.python.ops.ragged.row_partition import RowPartition
41from tensorflow.python.util import compat
42from tensorflow.python.util import nest
43
44
45class StructuredTensor(composite_tensor.CompositeTensor):
46  """A multidimensional collection of structures with the same schema.
47
48  A **`StructuredTensor`** is a multi-dimensional collection of ***structures***
49  with the same ***schema***, where:
50
51  * A ***schema*** is a collection of fields, each of which has a name and type.
52  * A ***structure*** maps each field in the schema to a tensor value (which
53    could be a nested StructuredTensor).
54
55  As an important special case, a 1D `StructuredTensor` encodes a 2D table,
56  where columns are heterogeneous `Tensor`s, and rows are the aligned elements
57  in each of those `Tensor`s.
58
59  Internally, StructuredTensors use a "field-major" encoding: for each leaf
60  field, there is a single tensor that stores the value of that field for all
61  structures in the `StructuredTensor`.
62
63  ### Examples
64
65  >>> # A scalar StructuredTensor describing a single person.
66  >>> s1 = StructuredTensor.from_pyval(
67  ...     {"age": 82, "nicknames": ["Bob", "Bobby"]})
68  >>> s1.shape
69  TensorShape([])
70  >>> s1["age"]
71  <tf.Tensor: shape=(), dtype=int32, numpy=82>
72
73  >>> # A vector StructuredTensor describing three people.
74  >>> s2 = StructuredTensor.from_pyval([
75  ...     {"age": 12, "nicknames": ["Josaphine"]},
76  ...     {"age": 82, "nicknames": ["Bob", "Bobby"]},
77  ...     {"age": 42, "nicknames": ["Elmo"]}])
78  >>> s2.shape
79  TensorShape([3])
80  >>> s2[0]["age"]
81  <tf.Tensor: shape=(), dtype=int32, numpy=12>
82
83
84  ### Field Paths
85
86  A *field path* is a tuple of field names, specifying the path to a nested
87  field.
88  """
89
90  #=============================================================================
91  # Common Types
92  #=============================================================================
  # pylint: disable=invalid-name
  # Type alias for the keys used to address fields: either a single field name
  # (a string), or a sequence of field names forming a path to a field inside
  # nested (embedded) `StructuredTensor`s.
  FieldName = Union[str, Sequence[str]]

  # Type alias for the values a field may hold: a `Tensor`, a `RaggedTensor`,
  # or a nested `StructuredTensor`.
  FieldValue = Union[ops.Tensor, ragged_tensor.RaggedTensor, 'StructuredTensor']

  # Type alias for a function that takes a FieldValue as input and returns the
  # transformed FieldValue.
  FieldFn = Callable[[FieldValue], FieldValue]

  # pylint: enable=invalid-name
106
107  #=============================================================================
108  # Constructor & Factory Methods
109  #=============================================================================
110
111  def __init__(self, fields, shape, nrows, row_partitions, internal=False):
112    """Private constructor -- use factory methods to create StructuredTensors.
113
114    This constructor builds a `StructuredTensor` from the given attributes,
115    performing minimal validation.
116
117    Args:
118      fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
119        `StructuredTensor`.  (This dict is not copied, so the caller must ensure
120        that it does not get mutated via leaked references.)
121      shape: `tf.TensorShape` with statically known rank.
122      nrows: scalar integer `tf.Tensor`, or `None` if `shape.rank==0`.
123      row_partitions: tuple of `RowPartition`s, with length `shape.rank-1`.
124      internal: Private key value, required to ensure that this private
125        constructor is *only* called from the factory methods.
126    """
127    if internal is not _structured_tensor_factory_key:
128      raise ValueError('StructuredTensor constructor is private; please use '
129                       'one of the factory methods instead (e.g., '
130                       'StructuredTensor.from_fields())')
131    assert isinstance(fields, dict), fields
132    assert isinstance(shape, tensor_shape.TensorShape), shape
133    assert nrows is None or isinstance(nrows, ops.Tensor), nrows
134    assert isinstance(row_partitions, tuple), row_partitions
135    self._fields = fields
136    self._shape = shape
137    self._nrows = nrows
138    self._row_partitions = row_partitions
139
140  @classmethod
141  def from_fields(cls,
142                  fields,
143                  shape=(),
144                  nrows=None,
145                  row_partitions=None,
146                  validate=False):
147    """Creates a `StructuredTensor` from a dictionary of fields.
148
149    Args:
150      fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or
151        `StructuredTensor`, providing the values for individual fields in each
152        structure.  If `shape.rank > 0`, then every tensor in `fields` must have
153        the same shape in the first `shape.rank` dimensions; and that shape must
154        be compatible with `shape`; and
155        `result[i1...iN][key] = fields[key][i1...iN]` (where `N==shape.rank`).
156      shape: A `TensorShape`: static information about the shape of the
157        `StructuredTensor`.  Must have a known `rank`.  Defaults to scalar
158        shape (i.e. `rank=0`).
159      nrows: scalar integer tensor containing the number of rows in this
160        `StructuredTensor`.  Should only be specified if `shape.rank > 0`.
161        Default value is inferred from the `fields` values.  If `fields` is
162        empty, then this must be specified.
163      row_partitions: A list of `RowPartition`s describing the (possibly ragged)
164        shape of this `StructuredTensor`.  Should only be specified if
165        `shape.rank > 1`.  Default value is inferred from the `fields` values.
166        If `fields` is empty, then this must be specified.
167      validate: If true, then add runtime validation ops that check that the
168        field values all have compatible shapes in the outer `shape.rank`
169        dimensions.
170
171    Returns:
172      A `StructuredTensor`.
173
174    Examples:
175
176      >>> StructuredTensor.from_fields({'x': 1, 'y': [1, 2, 3]})
177      <StructuredTensor(
178        fields={
179          "x": tf.Tensor(1, shape=(), dtype=int32),
180          "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)},
181        shape=())>
182
183      >>> StructuredTensor.from_fields({'foo': [1, 2], 'bar': [3, 4]},
184      ...                              shape=[2])
185      <StructuredTensor(
186        fields={
187          "bar": tf.Tensor([3 4], shape=(2,), dtype=int32),
188          "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)},
189        shape=(2,))>
190    """
191    shape = tensor_shape.as_shape(shape)
192    rank = shape.rank
193    if rank is None:
194      raise ValueError("StructuredTensor's shape must have known rank.")
195    if not isinstance(fields, dict):
196      raise TypeError('fields must be a dictionary, got %s' %
197                      type(fields).__name__)
198    if rank < 2 and row_partitions:
199      raise ValueError('row_partitions must be None or [] if shape.rank<2')
200    if rank == 0 and nrows is not None:
201      raise ValueError('nrows must be None if shape.rank==0')
202    if row_partitions is not None:
203      row_partitions = tuple(row_partitions)
204      if len(row_partitions) != max(0, rank - 1):
205        raise ValueError('len(row_partitions) must be shape.rank-1')
206    elif rank < 2:
207      row_partitions = ()
208
209    fields = dict(fields)  # Make a private copy.
210    with ops.name_scope(None, 'StructuredTensor', fields.values()):
211
212      # Validate keys and convert field values to tensors.
213      for key, value in fields.items():
214        if not isinstance(key, str):
215          raise TypeError('Unexpected type for key in `fields`: %r' % key)
216        if not _FIELD_NAME_RE.match(key):
217          raise ValueError('Field name %r is not currently allowed.' % key)
218        fields[key] = _convert_to_structured_field_value(value)
219
220      # Determine dtype for row_partitions and nrows.
221      shape_dtype = _find_shape_dtype(fields, nrows, row_partitions)
222      if nrows is not None:
223        nrows = ops.convert_to_tensor(nrows, shape_dtype)
224
225      # Get the static TensorShape for this StructuredTensor.
226      if rank > 0:
227        for key, value in fields.items():
228          if not shape.is_compatible_with(value.shape[:rank]):
229            raise ValueError('Field {} has shape {}, which is incompatible '
230                             'with the shape that was specified or inferred '
231                             'from other fields: {}'.format(
232                                 key, value.shape[:rank], shape))
233          shape = shape.merge_with(value.shape[:rank])
234
235      if rank == 1:
236        # Find a consistent value for `nrows`.
237        static_nrows = tensor_shape.dimension_at_index(shape, 0)
238        for value in fields.values():
239          nrows, static_nrows = _merge_nrows(nrows, static_nrows, value,
240                                             shape_dtype, validate)
241        if nrows is None:
242          if static_nrows.value is None:
243            raise ValueError('nrows must be specified if rank==1 '
244                             'and `fields` is empty.')
245          else:
246            nrows = constant_op.constant(static_nrows.value, shape_dtype)
247
248      if rank > 1:
249        # Find a consistent list of RowPartitions.
250        for value in fields.values():
251          row_partitions = _merge_row_partitions(row_partitions, value, rank,
252                                                 shape_dtype, validate)
253        if row_partitions is None:
254          if not shape.is_fully_defined():
255            raise ValueError('row_partitions must be specified if rank>1 '
256                             'and `fields` is empty.')
257          else:
258            row_partitions = _row_partitions_for_uniform_shape(
259                np.array(shape.as_list(), dtype=shape_dtype.as_numpy_dtype),
260                shape.rank)
261        assert len(row_partitions) == rank - 1
262        nrows = row_partitions[0].nrows()
263        # Update all field values to use the shared RowPartition objects.
264        fields = dict([(k, _replace_row_partitions(v, row_partitions))
265                       for (k, v) in fields.items()])
266
267    return cls(
268        fields,
269        shape,
270        nrows,
271        row_partitions,
272        internal=_structured_tensor_factory_key)
273
274  def with_updates(self,
275                   updates: Dict[FieldName, Union[FieldValue, FieldFn, None]],
276                   validate: bool = False) -> 'StructuredTensor':    # pylint: disable=bad-whitespace
277    """Creates a new `StructuredTensor` with the updated fields.
278
279    If this `StructuredTensor` is a scalar, and `k` is the `FieldName` being
280    updated and `v` the new value, then:
281
282    ```
283    result[k] = v              # If (k, v) is in updates and v is a FieldValue
284    result[k] = f(self[k])     # If (k, f) is in updates and f is a FieldFn
285    result[k] = self[k]        # If k is in self.field_names but not in updates
286    ```
287
288    If this `StructuredTensor` has rank `N` and shape `[D1...DN]`, then each
289    FieldValue `v` in `updates` must have shape `[D1...DN, ...]`, that is,
290    prefixed with the same shape as the `StructuredTensor`. Then the resulting
291    `StructuredTensor` will have:
292
293    ```
294    result[i1...iN][k] = v[i1...iN]                        # (k, v) in updates
295    result[i1...iN][k] = f(self.field_value(k))[i1...iN]   # (k, f) in updates
296    result[i1...iN][k] = self[i1...iN][k]                  # k not in updates
297    ```
298
299    Note that `result.shape` is always equal to `self.shape` (but the shapes
300    of nested StructuredTensors may be changed if they are updated with new
301    values).
302
303    Args:
304      updates: A dictionary mapping `FieldName` to either a `FieldValue` to be
305        used to update, or a `FieldFn` that will transform the value for the
306        given `FieldName`. `FieldName` can be a string for a direct field, or a
307        sequence of strings to refer to a nested sub-field. `FieldFn` is a
308        function that takes a `FieldValue` as input and should return a
309        `FieldValue`. All other fields are copied over to the new
310        `StructuredTensor`. New `FieldName` can be given (to add new fields),
311        but only to existing `StructuredTensor`, it won't automatically create
312        new nested structures -- but one can create a whole `StructureTensor`
313        sub-structure and set that into an existing structure. If the new value
314        is set to `None`, it is removed.
315      validate: If true, then add runtime validation ops that check that the
316        field values all have compatible shapes in the outer `shape.rank`
317        dimensions.
318
319    Returns:
320      A `StructuredTensor`.
321
322    Raises:
323      `ValueError`: If the any of the `FieldName` keys points to non-existent
324        sub-structures, if parent and child nodes are updated, if shapes
325        change, if a delete update is given for a non-existant field, or if a
326        `FieldFn` transforming function is given for a `FieldName` that doesn't
327        yet exist.
328
329    Examples:
330
331    >>> shoes_us = StructuredTensor.from_pyval([
332    ...    {"age": 12, "nicknames": ["Josaphine"],
333    ...       "shoes": {"sizes": [8.0, 7.5, 7.5]}},
334    ...    {"age": 82, "nicknames": ["Bob", "Bobby"],
335    ...        "shoes": {"sizes": [11.0, 11.5, 12.0]}},
336    ...    {"age": 42, "nicknames": ["Elmo"],
337    ...        "shoes": {"sizes": [9.0, 9.5, 10.0]}}])
338    >>> def us_to_europe(t):
339    ...   return tf.round(t * 2.54 + 17.0)  # Rough approximation.
340    >>> shoe_sizes_key = ("shoes", "sizes")
341    >>> shoes_eu = shoes_us.with_updates({shoe_sizes_key: us_to_europe})
342    >>> shoes_eu.field_value(shoe_sizes_key)
343    <tf.RaggedTensor [[37.0, 36.0, 36.0], [45.0, 46.0, 47.0],
344    [40.0, 41.0, 42.0]]>
345    """
346    updates_items = [(_normalize_field_name_to_tuple(name), value)
347                     for name, value in updates.items()]
348
349    # Sort by keys and check for updates of both parent and child nodes.
350    updates_items = sorted(updates_items)
351    for i in range(1, len(updates_items)):
352      # Parent of a node would precede node in the sorted order.
353      name = updates_items[i][0]  # item[0] is the name, item[1] is the value.
354      prev_name = updates_items[i - 1][0]
355      if name[:len(prev_name)] == prev_name:
356        raise ValueError(
357            '`StructuredTensor.with_updates` does not allow both parent and '
358            'child nodes to be updated: parent={}, child={}. If needed you can '
359            'update child nodes in the parent update value.'.format(
360                prev_name, name))
361    return self._with_updates_impl((), updates_items, validate)
362
363  def _with_updates_impl(self, error_prefix: Tuple[str],  # pylint: disable=invalid-sequence-index
364                         updates: List[Tuple[FieldName, Union[FieldValue,  # pylint: disable=invalid-sequence-index
365                                                              FieldFn]]],
366                         validate: bool) -> 'StructuredTensor':
367    """Recursive part of `with_updates` implementation."""
368    # Get current fields.
369    new_fields = dict(self._fields)
370
371    # Convert field name to string with full path for error messages.
372    def name_fullpath(name: Sequence[str]) -> str:
373      return str(error_prefix + (name,))
374
375    # Apply value if a function or the value itself.
376    def apply_value(name: str, value: Union['FieldValue',
377                                            'FieldFn']) -> 'FieldValue':
378      if callable(value):
379        # `value` is actually a transforming function.
380        if name not in new_fields:
381          raise ValueError(
382              '`StructuredTensor.with_updates` cannot update the field {} '
383              'because a transforming function was given, but that field '
384              'does not already exist.'.format(name_fullpath(name)))
385        value = value(new_fields[name])
386      return value
387
388    # Merge updates.
389    for name, value in updates:
390      if not name or not name[0]:
391        raise ValueError(
392            '`StructuredTensor.with_updates` does not allow empty names '
393            '{}.'.format(name_fullpath(name)))
394
395      if len(name) == 1:
396        name = name[0]
397        if value is None:
398          if name not in new_fields:
399            raise ValueError(
400                '`StructuredTensor.with_updates` cannot delete field '
401                '{} because it is not present.'.format(name_fullpath(name)))
402          new_fields.pop(name)
403        else:
404          new_fields[name] = apply_value(name, value)
405      else:
406        # Recursive
407        prefix = name[0]
408        suffix = name[1:]
409        if prefix not in new_fields:
410          raise ValueError(
411              '`StructuredTensor.with_updates` cannot create new sub-field '
412              '{} if parent field {} is not set.'.format(
413                  error_prefix + tuple(name), name_fullpath(prefix)))
414        current_value = new_fields[prefix]
415        if not isinstance(current_value, StructuredTensor):
416          raise ValueError(
417              '`StructuredTensor.with_updates` cannot create new sub-field '
418              '{} if parent structure {} is not a `StructuredTensor` that '
419              'can contain sub-structures -- it is a `{}`.'.format(
420                  error_prefix + tuple(name), name_fullpath(prefix),
421                  type(current_value)))
422        one_update = [(suffix, value)]
423
424        # Accessing protected member in recursion.
425        # FutureWork: optimize by aggregating the recursions, instead of
426        #   calling one at a time.
427        # pylint: disable=protected-access
428        value = current_value._with_updates_impl(error_prefix + (prefix,),
429                                                 one_update, validate)
430        # pylint: enable=protected-access
431        new_fields[prefix] = value
432
433    # TODO(edloper): When validate=True, only validate the modified fields.
434    try:
435      return StructuredTensor.from_fields(
436          new_fields,
437          shape=self.shape,
438          row_partitions=self._row_partitions,
439          nrows=self._nrows,
440          validate=validate)
441
442    except ValueError as e:
443      msg = '`StructuredTensor.with_updates` failed'
444      if error_prefix:
445        msg = '{} for field {}'.format(msg, error_prefix)
446      raise ValueError('{}: {}'.format(msg, e))
447
448  def _promote_helper(self, source_path, new_parent_path):
449    """Creates a promoted field without adding it to the structure.
450
451    Args:
452      source_path: the source path in the structured tensor.
453      new_parent_path: the new parent path. Must be a prefix of source_path.
454
455    Returns:
456      a composite tensor of source_path promoted.
457    Raises:
458      ValueError: if the shape of the field is unknown and the right strategy
459      cannot be determined.
460    """
461    current_field = self.field_value(source_path)
462    new_parent_rank = self.field_value(new_parent_path).rank
463    parent_rank = self.field_value(source_path[:-1]).rank
464    if new_parent_rank == parent_rank:
465      return current_field
466    current_field_rank = current_field.shape.rank
467    if current_field_rank is None:
468      raise ValueError('Cannot determine if dimensions should be merged.')
469    inner_dim = min(parent_rank, current_field_rank - 1)
470    if inner_dim <= new_parent_rank:
471      return current_field
472    return _merge_dims_generic(current_field, new_parent_rank, inner_dim)
473
474  def promote(self, source_path, new_name):
475    """Promotes a field, merging dimensions between grandparent and parent.
476
477    >>> d = [
478    ...  {'docs': [{'tokens':[1, 2]}, {'tokens':[3]}]},
479    ...  {'docs': [{'tokens':[7]}]}]
480    >>> st = StructuredTensor.from_pyval(d)
481    >>> st2 =st.promote(('docs','tokens'), 'docs_tokens')
482    >>> st2[0]['docs_tokens']
483    <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)>
484    >>> st2[1]['docs_tokens']
485    <tf.Tensor: shape=(1,), dtype=int32, numpy=array([7], dtype=int32)>
486
487    Args:
488      source_path: the path of the field or substructure to promote; must have
489        length at least 2.
490      new_name: the name of the new field (must be a string).
491
492    Returns:
493      a modified structured tensor with the new field as a child of the
494      grandparent of the source_path.
495
496    Raises:
497      ValueError: if source_path is not a list or a tuple or has a length
498        less than two, or new_name is not a string, or the rank
499        of source_path is unknown and it is needed.
500    """
501    if not isinstance(new_name, str):
502      raise ValueError('new_name is not a string')
503    if not isinstance(source_path, (list, tuple)):
504      raise ValueError('source_path must be a list or tuple')
505
506    if len(source_path) < 2:
507      raise ValueError('source_path must have length at least two')
508
509    grandparent_path = source_path[:-2]
510    new_field = self._promote_helper(source_path, grandparent_path)
511    new_path = grandparent_path + (new_name,)
512    return self.with_updates({new_path: new_field})
513
514  #=============================================================================
515  # Properties
516  #=============================================================================
517
518  @property
519  def rank(self):
520    """The rank of this StructuredTensor.  Guaranteed not to be `None`."""
521    return self._shape.rank
522
523  @property
524  def shape(self):
525    """The static shape of this StructuredTensor.
526
527    The returned `TensorShape` is guaranteed to have a known rank, but the
528    individual dimension sizes may be unknown.
529
530    Returns:
531      `tf.TensorShape`
532    """
533    return self._shape
534
535  # TODO(edloper): Make this a func instead of a property?  Or make nrows
536  # a property instead of a func?  Seems like these should be consistent.
537  @property
538  def row_partitions(self):
539    """A tuple of `RowPartition`s defining the shape of this `StructuredTensor`.
540
541    When `self.rank <= 1`, this tuple will be empty.
542
543    When `self.rank > 1`, these `RowPartitions` define the shape of the
544    `StructuredTensor` by describing how a flat (1D) list of structures can be
545    repeatedly partitioned to form a higher-dimensional object.  In particular,
546    the flat list is first partitioned into sublists using `row_partitions[-1]`,
547    and then those sublists are further partitioned using `row_partitions[-2]`,
548    etc.  The following examples show the row partitions used to describe
549    several different `StructuredTensor`, each of which contains 8 copies of
550    the same structure (`x`):
551
552    >>> x = {'a': 1, 'b': ['foo', 'bar', 'baz']}       # shape = [] (scalar)
553
554    >>> s1 = [[x, x, x, x], [x, x, x, x]]              # shape = [2, 4]
555    >>> StructuredTensor.from_pyval(s1).row_partitions
556    (tf.RowPartition(row_splits=tf.Tensor([0 4 8], shape=(3,),
557                                          dtype=int64)),)
558
559    >>> s2 = [[x, x], [x, x], [x, x], [x, x]]          # shape = [4, 2]
560    >>> StructuredTensor.from_pyval(s2).row_partitions
561    (tf.RowPartition(row_splits=tf.Tensor([0 2 4 6 8], shape=(5,),
562                                          dtype=int64)),)
563
564    >>> s3 = [[x, x, x], [], [x, x, x, x], [x]]        # shape = [2, None]
565    >>> StructuredTensor.from_pyval(s3).row_partitions
566    (tf.RowPartition(row_splits=tf.Tensor([0 3 3 7 8], shape=(5,),
567                                          dtype=int64)),)
568
569    >>> s4 = [[[x, x], [x, x]], [[x, x], [x, x]]]      # shape = [2, 2, 2]
570    >>> StructuredTensor.from_pyval(s4).row_partitions
571    (tf.RowPartition(row_splits=tf.Tensor([0 2 4], shape=(3,), dtype=int64)),
572     tf.RowPartition(row_splits=tf.Tensor([0 2 4 6 8], shape=(5,),
573                                          dtype=int64)))
574
575
576    >>> s5 = [[[x, x], [x]], [[x, x]], [[x, x], [x]]]  # shape = [3, None, None]
577    >>> StructuredTensor.from_pyval(s5).row_partitions
578    (tf.RowPartition(row_splits=tf.Tensor([0 2 3 5], shape=(4,), dtype=int64)),
579     tf.RowPartition(row_splits=tf.Tensor([0 2 3 5 7 8], shape=(6,),
580                                          dtype=int64)))
581
582    Note that shapes for nested fields (such as `x['b']` in the above example)
583    are not considered part of the shape of a `StructuredTensor`, and are not
584    included in `row_partitions`.
585
586    If this `StructuredTensor` has a ragged shape (i.e., if any of the
587    `row_partitions` is not uniform in size), then all fields will be encoded
588    as either `RaggedTensor`s or `StructuredTensor`s with these `RowPartition`s
589    used to define their outermost `self.rank` dimensions.
590
591    Returns:
592      A `tuple` of `RowPartition` objects with length `self.rank - 1`
593      (or `0` if `self.rank < 2`)
594
595    """
596    return self._row_partitions
597
598  def nrows(self):
599    """The number of rows in this StructuredTensor (if rank>0).
600
601    This means the length of the outer-most dimension of the StructuredTensor.
602
603    Notice that if `self.rank > 1`, then this equals the number of rows
604    of the first row partition. That is,
605    `self.nrows() == self.row_partitions[0].nrows()`.
606
607    Otherwise `self.nrows()` will be the first dimension of the field values.
608
609    Returns:
610      A scalar integer `Tensor` (or `None` if `self.rank == 0`).
611    """
612    return self._nrows
613
614  def _is_eager(self):
615    """True if all fields are composed of eager tensors."""
616    tensors = nest.flatten(self, expand_composites=True)
617    return all(isinstance(t, ops.EagerTensor) for t in tensors)
618
619  #=============================================================================
620  # Encoding
621  #=============================================================================
622
623  def field_names(self):
624    """Returns the string field names for this `StructuredTensor`."""
625    return tuple(self._fields.keys())
626
627  def field_value(self, field_name):
628    """Returns the tensor value for the specified field or path.
629
630    If `field_name` is a `string`, then it names a field directly owned by this
631    `StructuredTensor`.  If this `StructuredTensor` has shape `[D1...DN]`, then
632    the returned tensor will have shape `[D1...DN, V1...VM]`, where the slice
633    `result[d1...dN]` contains the field value for the structure at
634    `self[d1...dN]`.
635
636    If `field_name` is a `tuple` of `string`, then it specifies a path to a
637    field owned by nested `StructuredTensor`.  In particular,
638    `struct.field_value((f1, f2, ..., fN))` is equivalent to
639    `struct.field_value(f1).field_value(f2)....field_value(fN)`
640
641    Args:
642      field_name: `string` or `tuple` of `string`: The field whose values should
643        be returned.
644
645    Returns:
646      `Tensor`, `StructuredTensor`, or `RaggedTensor`.
647
648    Raises:
649      KeyError: If the given field_name is not found.
650    """
651    if isinstance(field_name, (list, tuple)):
652      value = self
653      for f in field_name:
654        if not isinstance(value, StructuredTensor):
655          raise KeyError('Field path {} not found in {}'.format(
656              field_name, self))
657        value = value.field_value(f)
658      return value
659    return self._fields[field_name]
660
661  #=============================================================================
662  # Operators
663  #=============================================================================
664
665  # TODO(edloper): Add support for ellipsis and/or newaxis?
666  def __getitem__(self, key):
667    """Returns the specified piece of this StructuredTensor.
668
669    * If `struct_tensor` is scalar (i.e., a single structure), then
670      `struct_tensor[f]` returns the value of field `f` (where `f` must be a
671      string).
672
673    * If `struct_tensor` is non-scalar (i.e., a vector or higher-dimensional
674      tensor of structures), `struct_tensor[i]` selects an element or slice of
675      the tensor using standard Python semantics (e.g., negative values index
676      from the end).  `i` may have any of the following types:
677
678      * `int` constant
679      * `string` constant
680      * scalar integer `Tensor`
681      * `slice` containing integer constants and/or scalar integer
682        `Tensor`s
683
684    #### Multidimensional indexing
685
686    `StructuredTensor` supports multidimensional indexing.  I.e., `key` may be a
687    `tuple` of values, indexing or slicing multiple dimensions at once.  For
688    example, if `people` is a vector of structures, each of which has a vector-
689    valued `names` field, then `people[3, 'names', 0]` is equivalent to
690    `people[3]['names'][0]`; and `people[:, 'names', :]` will return a (possibly
691    ragged) matrix of names, with shape `[num_people, num_names_per_person]`.
692
693    Args:
694      key: Indicates which piece of the StructuredTensor to return.
695    Returns:
696      A `Tensor`, `StructuredTensor`, or `RaggedTensor`.
697    """
698    if isinstance(key, list):
699      key = tuple(key)
700    elif not isinstance(key, tuple):
701      key = (key,)
702    if not key:
703      return self
704
705    if self._shape.rank == 0:
706      return self._scalar_getitem(key)
707    else:
708      return self._tensor_getitem(key)
709
710  def _scalar_getitem(self, key):
711    if (isinstance(key[0], slice) and key[0].start is None and
712        key[0].stop is None and key[0].step is None):
713      fields = dict((field_name, field_value.__getitem__(key[1:]))
714                    for (field_name, field_value) in self._fields.items())
715      return StructuredTensor.from_fields(fields, self._shape)
716
717    elif not isinstance(key[0], compat.bytes_or_text_types):
718      raise ValueError('Key for indexing a StructuredTensor must be a '
719                       "string or a full slice (':')")
720
721    return self._fields[key[0]].__getitem__(key[1:])
722
723  def _tensor_getitem(self, key):
724    rank = self._shape.rank
725    if len(key) <= rank:
726      new_fields = dict((field_name, field_value.__getitem__(key))
727                        for (field_name, field_value) in self._fields.items())
728      result_shape = self.shape.as_list()
729      for d, k in enumerate(key):
730        if isinstance(k, slice):
731          if not (k.start is None and k.stop is None and k.step is None):
732            # TODO(edloper): Better static shape analysis here.
733            result_shape[d] = None
734        elif isinstance(k, (int, ops.Tensor)):
735          result_shape[d] = -1  # mark for deletion
736        elif k is None:
737          raise ValueError('Slicing not supported for tf.newaxis')
738        else:
739          # Ellipsis, tf.newaxis:
740          raise ValueError('Slicing not supported for %r' % k)
741      result_shape = [d for d in result_shape if d != -1]
742      return StructuredTensor.from_fields(new_fields, result_shape)
743
744    else:
745      if not isinstance(key[rank], compat.bytes_or_text_types):
746        # TODO(edloper): Also support full slice here?
747        raise ValueError('Key for indexing a StructuredTensor must be a string')
748      return self._fields[key[rank]].__getitem__(key[:rank] + key[rank + 1:])
749
750  def __repr__(self):
751    fields = sorted(self._fields.items())
752    fields = ((k, str(v).replace('\n', '\n            ')) for k, v in fields)
753    fields = ('"{}": {}'.format(k, v) for k, v in fields)
754    dict_repr = ',\n        '.join(fields)
755    return (
756        '<StructuredTensor(\n'
757        '    fields={\n'
758        '        %s},\n'
759        '    shape=%s)>' % (dict_repr, self._shape))
760
761  #=============================================================================
762  # Conversion
763  #=============================================================================
764
765  def to_pyval(self):
766    """Returns this StructuredTensor as a nested Python dict or list of dicts.
767
768    Converts this `StructuredTensor` to a nested python value:
769
770    * `StructTensors` with `rank=0` are converted into a dictionary, with an
771      entry for each field.  Field names are used as keys and field values are
772      converted to python values.  In particular:
773
774      * Scalar Tensor fields are converted to simple values (such as
775        `int` or `float` or `string`)
776      * Non-scalar Tensor fields and RaggedTensor fields are converted to
777        nested lists of simple values.
778      * StructuredTensor fields are converted recursively using `to_pyval`.
779
780    * `StructTensors` with `rank>0` are converted to nested python `list`s,
781      containing one dictionary for each structure (where each structure's
782      dictionary is defined as described above).
783
784    Requires that all fields are Eager tensors.
785
786    >>> StructuredTensor.from_fields(
787    ...     {'a': [1, 2, 3]}, [3]).to_pyval()
788    [{'a': 1}, {'a': 2}, {'a': 3}]
789
790    Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.
791
792    Returns:
793      A nested Python dict or list of dicts.
794    """
795    if not self._is_eager():
796      raise ValueError(
797          'StructuredTensor.to_pyval() is only supported in eager mode.')
798
799    # Convert each field value to a nested list.
800    result = {}
801    for (key, value) in self._fields.items():
802      if isinstance(value, ops.EagerTensor):
803        value = value.numpy()
804      if isinstance(value, np.ndarray):
805        value = value.tolist()
806      elif isinstance(value, ragged_tensor.RaggedTensor):
807        value = value.to_list()
808      elif isinstance(value, StructuredTensor):
809        value = value.to_pyval()
810      # TODO(edloper): Throw an exception if value is an unexpected type.
811      result[key] = value
812
813    # If rank>0, then re-group each value from dict-of-list to list-of-dict.
814    if len(self._shape) > 0:  # pylint: disable=g-explicit-length-test
815      if not result:  # special-case for StructuredTensors w/ no fields.
816        return _empty_dict_pylist_from_row_partitions(self._row_partitions,
817                                                      self._nrows)
818      return _pyval_field_major_to_node_major(
819          list(result.keys()), list(result.values()), self._shape.rank)
820    else:
821      return result
822
823  @classmethod
824  def from_pyval(cls, pyval, typespec=None):
825    """Constructs a StructuredTensor from a nested Python structure.
826
827    >>> StructuredTensor.from_pyval(
828    ...     {'a': [1, 2, 3], 'b': [[4, 5], [6, 7]]})
829    <StructuredTensor(
830        fields={
831          "a": tf.Tensor([1 2 3], shape=(3,), dtype=int32),
832          "b": <tf.RaggedTensor [[4, 5], [6, 7]]>},
833        shape=())>
834
835    Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`.
836
837    Args:
838      pyval: The nested Python structure that should be used to create the new
839        `StructuredTensor`.
840      typespec: A `StructuredTensorSpec` specifying the expected type for each
841        field. If not specified, then all nested dictionaries are turned into
842        StructuredTensors, and all nested lists are turned into Tensors (if
843        rank<2) or RaggedTensors (if rank>=2).
844
845    Returns:
846      A `StructuredTensor`.
847    """
848    if isinstance(pyval, dict):
849      return cls._from_pydict(pyval, typespec)
850    elif isinstance(pyval, (list, tuple)):
851      keys = set()
852      rank = _pyval_find_struct_keys_and_depth(pyval, keys)
853      if rank is not None:
854        return cls._from_pylist_of_dict(pyval, keys, rank, typespec)
855      else:
856        return cls._from_pylist_of_value(pyval, typespec)
857    else:
858      return cls._from_pyscalar(pyval, typespec)
859
860  @classmethod
861  def _from_pydict(cls, pyval, typespec):
862    """Converts python dictionary `pyval` to a StructuredTensor with rank=0."""
863    if typespec is None:
864      fields = dict((k, cls.from_pyval(v)) for (k, v) in pyval.items())
865    else:
866      spec_shape = typespec._shape  # pylint: disable=protected-access
867      field_specs = typespec._field_specs  # pylint: disable=protected-access
868      if not (isinstance(typespec, StructuredTensorSpec) and
869              spec_shape.rank == 0 and set(pyval) == set(field_specs)):
870        raise ValueError('Value does not match typespec: %r vs %r' %
871                         (pyval, typespec))
872      fields = dict(
873          (k, cls.from_pyval(v, field_specs[k])) for (k, v) in pyval.items())
874    return StructuredTensor.from_fields(fields=fields, shape=(), validate=False)
875
876  @classmethod
877  def _from_pylist_of_dict(cls, pyval, keys, rank, typespec):
878    """Converts python list `pyval` to a StructuredTensor with rank>1."""
879    fields = dict((key, []) for key in keys)
880    for child in pyval:
881      _pyval_update_fields(child, fields, 1)
882    if typespec is None:
883      shape = tensor_shape.TensorShape([None] * rank)
884      for (key, target) in fields.items():
885        fields[key] = cls.from_pyval(target)
886    else:
887      field_specs = typespec._field_specs  # pylint: disable=protected-access
888      if ((not isinstance(typespec, StructuredTensorSpec)) or
889          (set(fields) - set(field_specs))):
890        raise ValueError('Value does not match typespec: %r vs %r' %
891                         (pyval, typespec))
892      shape = typespec._shape
893      if shape.rank < rank:
894        raise ValueError('Value does not match typespec (rank mismatch): '
895                         '%r vs %r' % (pyval, typespec))
896      for (key, spec) in field_specs.items():
897        fields[key] = cls.from_pyval(fields.get(key, []), spec)
898    return StructuredTensor.from_fields(
899        fields=fields, shape=shape, validate=False)
900
901  @classmethod
902  def _from_pylist_of_value(cls, pyval, typespec):
903    """Converts python list `pyval` to a Tensor or RaggedTensor with rank>1."""
904    if typespec is None:
905      return ragged_factory_ops.constant(pyval)
906    elif isinstance(typespec, tensor_spec.TensorSpec):
907      result = constant_op.constant(pyval, typespec.dtype)
908      if not typespec.shape.is_compatible_with(result.shape):
909        raise ValueError('Value does not match typespec: %r vs %r' %
910                         (typespec, pyval))
911      return result
912    elif isinstance(typespec, ragged_tensor.RaggedTensorSpec):
913      # pylint: disable=protected-access
914      return ragged_factory_ops.constant(
915          pyval,
916          dtype=typespec._dtype,
917          ragged_rank=typespec._ragged_rank,
918          row_splits_dtype=typespec._row_splits_dtype,
919          inner_shape=typespec._shape[typespec._ragged_rank + 1:])
920    elif isinstance(typespec, StructuredTensorSpec):
921      empty_rank = _pyval_empty_list_depth(pyval)
922      if empty_rank is None:
923        raise ValueError('Value does not match typespec: %r vs %r' %
924                         (typespec, pyval))
925      else:
926        return cls._from_pylist_of_dict(pyval, set(), empty_rank, typespec)
927    else:
928      raise ValueError('Value does not match typespec: %r vs %r' %
929                       (typespec, pyval))
930
931  @classmethod
932  def _from_pyscalar(cls, pyval, typespec):
933    """Converts python scalar value `pyval` to a Tensor."""
934    if typespec is None:
935      return constant_op.constant(pyval)
936    else:
937      if not (isinstance(typespec, tensor_spec.TensorSpec) and
938              typespec.shape.rank == 0):
939        raise ValueError('Value does not match typespec: %r vs %r' %
940                         (typespec, pyval))
941      # TODO(edloper): Check that typespec.shape matches.
942      return constant_op.constant(pyval, typespec.dtype)
943
944  #=============================================================================
945  # Transforms
946  #=============================================================================
947
948  # TODO(edloper): Add a 'validate' option here?
949  # TODO(edloper): Unify nomenclature with RaggedTensor.  Should RaggedTensor
950  # have a partition_outer_dimension method?
951  def partition_outer_dimension(self, row_partition):
952    """Partitions the outer dimension of this StructuredTensor.
953
954    Returns a new `StructuredTensor` with the same values as `self`, where
955    the outer dimension is partitioned into two (possibly ragged) dimensions.
956    Requires that this StructuredTensor have an outer dimension (i.e.,
957    `self.shape.rank > 0`).
958
959    >>> st = StructuredTensor.from_pyval(
960    ...     [{'foo': 12}, {'foo': 33}, {'foo': 99}])
961    >>> partition = RowPartition.from_row_lengths([2, 0, 1])
962    >>> st.partition_outer_dimension(partition)
963    <StructuredTensor(
964      fields={
965        "foo": <tf.RaggedTensor [[12, 33], [], [99]]>},
966      shape=(3, None))>
967
968    Args:
969      row_partition: A `RowPartition`.
970
971    Returns:
972      A `StructuredTensor` with rank `values.rank + 1`.
973    """
974    if not isinstance(row_partition, RowPartition):
975      raise TypeError('row_partition must be a RowPartition.')
976    if self.shape.rank == 0:
977      raise ValueError('Shape %s must have rank at least 1' % self.shape)
978    return _partition_outer_dimension(self, row_partition)
979
980  def merge_dims(self, outer_axis, inner_axis):
981    """Merges outer_axis...inner_axis into a single dimension.
982
983    Returns a copy of this RaggedTensor with the specified range of dimensions
984    flattened into a single dimension, with elements in row-major order.
985
986    >>> st = StructuredTensor.from_pyval(
987    ...     [[{'foo': 12}, {'foo': 33}], [], [{'foo': 99}]])
988    >>> st.merge_dims(0, 1)
989    <StructuredTensor(
990      fields={
991        "foo": tf.Tensor([12 33 99], shape=(3,), dtype=int32)},
992      shape=(3,))>
993
994    Args:
995      outer_axis: `int`: The first dimension in the range of dimensions to
996        merge. May be negative (to index from the last dimension).
997      inner_axis: `int`: The last dimension in the range of dimensions to merge.
998        May be negative (to index from the last dimension).
999
1000    Returns:
1001      A copy of this tensor, with the specified dimensions merged into a
1002      single dimension.  The shape of the returned tensor will be
1003      `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
1004      is the total number of slices in the merged dimensions.
1005    """
1006    outer_axis = array_ops.get_positive_axis(
1007        outer_axis,
1008        self.shape.rank,
1009        axis_name='outer_axis',
1010        ndims_name='rank(self)')
1011    inner_axis = array_ops.get_positive_axis(
1012        inner_axis,
1013        self.shape.rank,
1014        axis_name='inner_axis',
1015        ndims_name='rank(self)')
1016    if not outer_axis <= inner_axis:
1017      raise ValueError('Expected outer_axis (%d) to be less than or equal to '
1018                       'inner_axis (%d)' % (outer_axis, inner_axis))
1019    return _merge_dims(self, outer_axis, inner_axis)
1020
1021  #=============================================================================
1022  # Composite Tensor
1023  #=============================================================================
1024
1025  @property
1026  def _type_spec(self):
1027    return StructuredTensorSpec.from_value(self)
1028
1029
class StructuredTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for `StructuredTensor`s."""

  __slots__ = ['_shape', '_field_specs']

  def __init__(self, shape, field_specs):
    """Build a type specification for a StructuredTensor.

    Args:
      shape: The shape of the StructuredTensor.  shape.rank must not be None.
      field_specs: A dictionary mapping from field name to TypeSpec, specifying
        the tensor type used to encode each field. These TypeSpecs should
        specify the type of the entire field (including outer dimensions which
        correspond to `shape`).  For example, if `shape=[2, 3]`, and field 'x'
        contains an int32 vector of size `10` for each structure, then
        `field_specs['x']` should be `tf.TensorSpec([2, 3, 10], tf.int32)`.

    Raises:
      TypeError: If `shape` has unknown rank; if `field_specs` is not a
        dictionary; or if any of its keys is not a string or any of its values
        is not a `StructuredTensorSpec`, `TensorSpec`, or `RaggedTensorSpec`.
    """
    shape = tensor_shape.as_shape(shape)

    # Perform a few sanity checks on the inputs.
    if shape.rank is None:
      raise TypeError("StructuredTensor's shape must have known rank.")
    if not isinstance(field_specs, dict):
      raise TypeError('field_specs must be a dictionary.')
    allowed_spec_types = (StructuredTensorSpec, tensor_spec.TensorSpec,
                          ragged_tensor.RaggedTensorSpec)
    for name, spec in field_specs.items():
      if not isinstance(name, str):
        raise TypeError('field_specs must be a dictionary with string keys.')
      if not isinstance(spec, allowed_spec_types):
        raise TypeError('field_specs must be a dictionary with '
                        'TypeSpec values.')

    self._shape = shape
    # Store a copy, so later changes to the caller's dict do not affect us.
    self._field_specs = dict(field_specs)

  @property
  def value_type(self):
    """The python type of the values described by this spec."""
    return StructuredTensor

  def _to_components(self, value):
    """Returns the fields dictionary of `value`."""
    return value._fields

  def _from_components(self, components):
    """Builds a StructuredTensor (without validation) from a fields dict."""
    return StructuredTensor.from_fields(components, self._shape, validate=False)

  @property
  def _component_specs(self):
    """The dictionary mapping each field name to its TypeSpec."""
    return self._field_specs

  @classmethod
  def from_value(cls, value):
    """Returns the spec for `value`, using the type spec of each field."""
    field_specs = {
        name: type_spec.type_spec_from_value(field)
        for name, field in value._fields.items()
    }
    return cls(value.shape, field_specs)

  def _serialize(self):
    """Returns the `(shape, field_specs)` tuple describing this spec."""
    return (self._shape, self._field_specs)

  def _batch(self, batch_size):
    """Returns a spec with a new outer dimension of size `batch_size`."""
    # pylint: disable=protected-access
    batched_shape = tensor_shape.TensorShape([batch_size]).concatenate(
        self._shape)
    batched_specs = {
        name: spec._batch(batch_size)
        for name, spec in self._field_specs.items()
    }
    return StructuredTensorSpec(batched_shape, batched_specs)

  def _unbatch(self):
    """Returns a spec with the outermost dimension removed."""
    # pylint: disable=protected-access
    unbatched_specs = {
        name: spec._unbatch() for name, spec in self._field_specs.items()
    }
    return StructuredTensorSpec(self._shape[1:], unbatched_specs)

  @property
  def _flat_tensor_specs(self):
    """The flat tensor specs of all fields, in sorted field-name order."""
    # pylint: disable=protected-access
    flat_specs = []
    for name in sorted(self._field_specs):
      flat_specs.extend(self._field_specs[name]._flat_tensor_specs)
    return flat_specs

  def _to_tensor_list(self, value):
    """Returns the flat list of tensors encoding `value`."""
    return self._to_tensor_list_internal(value, batched=False)

  def _to_batched_tensor_list(self, value):
    """Returns the flat list of batched tensors encoding `value`."""
    return self._to_tensor_list_internal(value, batched=True)

  def _from_compatible_tensor_list(self, tensor_list):
    """Rebuilds a StructuredTensor from a flat list of tensors.

    Args:
      tensor_list: A flat list of tensors, holding the tensors for each field
        in sorted field-name order.

    Returns:
      A `StructuredTensor`.
    """
    # pylint: disable=protected-access
    fields = {}
    start = 0
    for name in sorted(self._field_specs):
      spec = self._field_specs[name]
      # Each field consumes as many tensors as it has flat tensor specs.
      stop = start + len(spec._flat_tensor_specs)
      fields[name] = spec._from_compatible_tensor_list(tensor_list[start:stop])
      start = stop
    return StructuredTensor.from_fields(fields, self._shape)

  def _to_tensor_list_internal(self, value, batched):
    """Returns a flat list containing each field's (batched) tensor_list.

    The tensor lists of the fields are concatenated in sorted field-name order.

    Args:
      value: A StructuredTensor (conforming to `self`).
      batched: A boolean. if True, produce `batched_tensor_list` for each field
        otherwise produce `tensor_list`.

    Returns:
      A list.
    """
    # pylint: disable=protected-access
    tensors = []
    for name in sorted(self._field_specs):
      spec = self._field_specs[name]
      field_value = value._fields[name]
      to_list = (
          spec._to_batched_tensor_list if batched else spec._to_tensor_list)
      tensors.extend(to_list(field_value))
    return tensors
1151
1152# Regular expression used to determine whether a string is a valid field name.
1153# Note: we plan to relax (or possibly eliminate) this in the future; you
1154# should not rely on the fact that some field names are currently disallowed.
1155_FIELD_NAME_RE = re.compile('^[a-zA-Z][a-zA-Z0-9_]*$')
1156
1157
1158#=============================================================================
# Helper functions
1160#=============================================================================
1161# TODO(edloper): Move some of these helpers to row_partition.py?
1162
1163
def _convert_to_structured_field_value(value):
  """Converts `value` to a Tensor, RaggedTensor, or StructuredTensor.

  Args:
    value: The value to convert.  `Tensor`s, `RaggedTensor`s, and
      `StructuredTensor`s are returned unchanged.

  Returns:
    A `Tensor`, `RaggedTensor`, or `StructuredTensor`.

  Raises:
    TypeError: If `value` can not be converted to a tensor.
  """
  if isinstance(value,
                (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)):
    return value
  elif ragged_tensor.is_ragged(value):
    return ragged_tensor.convert_to_tensor_or_ragged_tensor(value)
  else:
    try:
      return ops.convert_to_tensor(value)
    except (ValueError, TypeError):
      # `value` is wrapped in a 1-tuple so that a tuple-valued `value` is
      # formatted as a single value, rather than being unpacked by `%` (which
      # would fail with an unrelated string-formatting error).
      raise TypeError('Unexpected type for value in `fields`: %r' % (value,))
1176
1177
def _find_shape_dtype(fields, nrows, row_partitions):
  """Return a consistent dtype for fields, nrows, & row_partitions.

  Args:
    fields: A dictionary whose values are field values.
    nrows: The number of rows; its dtype is only considered if it is a
      `Tensor`.
    row_partitions: A sequence of row partitions, or `None`.

  Returns:
    The single dtype used by all of the inputs that carry one; or
    `dtypes.int64` if none of them do.

  Raises:
    ValueError: If more than one distinct dtype is found.
  """
  candidates = set()
  for field in fields.values():
    if isinstance(field, ragged_tensor.RaggedTensor):
      candidates.add(field.row_splits.dtype)
    elif isinstance(field, StructuredTensor) and field.rank > 0:
      candidates.add(field.nrows().dtype)
  if isinstance(nrows, ops.Tensor):
    candidates.add(nrows.dtype)
  if row_partitions is not None:
    candidates.update(partition.dtype for partition in row_partitions)

  if not candidates:
    return dtypes.int64
  if len(candidates) > 1:
    raise ValueError('field values have incompatible row_partition dtypes.')
  return candidates.pop()
1197
1198
def _merge_nrows(nrows, static_nrows, value, dtype, validate):
  """Merges `nrows` with `nrows(value)`.

  Checks that `value` has the expected number of rows (`nrows`).  If the row
  counts are both statically known, they are compared immediately; otherwise,
  if `validate` is true, a validation op is added that checks that the
  `nrows` values match.

  Args:
    nrows: scalar integer Tensor, or `None` if not yet known.
    static_nrows: tf.Dimension: static value of nrows, if known.
    value: Tensor or RaggedTensor or StructuredTensor
    dtype: dtype for `nrows`.
    validate: bool -- whether to add validation ops.

  Returns:
    A tuple `(nrows, static_nrows)`.

  Raises:
    ValueError: If the statically known row counts are incompatible.
  """
  value_static_nrows = tensor_shape.dimension_at_index(value.shape, 0)
  value_nrows = (
      array_ops.shape(value, out_type=dtype)[0]
      if isinstance(value, ops.Tensor) else value.nrows())

  if nrows is None:
    nrows = value_nrows
  else:
    both_static = (value_static_nrows.value is not None and
                   static_nrows.value is not None)
    if both_static:
      if not value_static_nrows.is_compatible_with(static_nrows):
        raise ValueError('fields have incompatible nrows')
      nrows = value_nrows  # No need to add an assertion op.
    elif validate:
      nrows_check = check_ops.assert_equal(
          nrows, value_nrows, message='fields have incompatible nrows')
      nrows = control_flow_ops.with_dependencies([nrows_check], nrows)

  # The static merge is done last, so the checks above get to run first.
  return nrows, static_nrows.merge_with(value_static_nrows)
1234
1235
def _merge_row_partitions(row_partitions, value, rank, dtype, validate):
  """Merges `row_partitions` with `row_partitions(value)`.

  Args:
    row_partitions: The row partitions gathered so far, or `None`.
    value: Tensor or RaggedTensor or StructuredTensor.
    rank: The rank of the shared shape; `rank - 1` partitions are used.
    dtype: dtype used when row partitions are computed from a shape.
    validate: Passed through to `merge_precomputed_encodings`.

  Returns:
    A tuple of row partitions.
  """
  if isinstance(value, ops.Tensor):
    partitions_of_value = _row_partitions_for_tensor(value, rank, dtype)
  elif isinstance(value, ragged_tensor.RaggedTensor):
    partitions_of_value = _row_partitions_for_ragged_tensor(value, rank, dtype)
  else:
    assert isinstance(value, StructuredTensor), type(value)
    partitions_of_value = value.row_partitions[:rank - 1]
  assert len(partitions_of_value) == rank - 1

  if row_partitions is None:
    return tuple(partitions_of_value)

  merged = []
  for existing, incoming in zip(row_partitions, partitions_of_value):
    merged.append(existing.merge_precomputed_encodings(incoming, validate))
  return tuple(merged)
1256
1257
def _row_partitions_for_tensor(value, rank, dtype):
  """Returns the row partitions for a tf.Tensor.

  The partitions are built from the dynamic shape of `value` (computed with
  `out_type=dtype`), using `_row_partitions_for_uniform_shape`.
  """
  return _row_partitions_for_uniform_shape(
      array_ops.shape(value, out_type=dtype), rank)
1262
1263
def _row_partitions_for_ragged_tensor(value, rank, dtype):
  """Returns the first `rank - 1` row partitions for a tf.RaggedTensor.

  If `value` has fewer nested row partitions than that, then the remaining
  ones are computed from `value.flat_values`.
  """
  assert rank > 1
  num_needed = rank - 1
  partitions = value._nested_row_partitions[:num_needed]  # pylint: disable=protected-access
  num_found = len(partitions)
  if num_found < num_needed:
    partitions += _row_partitions_for_tensor(
        value.flat_values, rank - num_found, dtype)
  assert len(partitions) == num_needed
  return partitions
1273
1274
def _row_partitions_for_uniform_shape(shape, rank):
  """Returns row partitions for the given shape Tensor.

  Args:
    shape: A vector describing a uniform shape.
    rank: The number of dimensions to generate row partitions for

  Returns:
    A tuple of (rank-1) `RowPartition`s with uniform row length.
  """
  # cumulative_sizes[i] is the product of the first (i + 1) dimension sizes.
  cumulative_sizes = math_ops.cumprod(shape[:rank])
  partitions = []
  for axis in range(1, rank):
    partitions.append(
        RowPartition.from_uniform_row_length(
            uniform_row_length=shape[axis],
            nvals=cumulative_sizes[axis],
            nrows=cumulative_sizes[axis - 1]))
  return tuple(partitions)
1293
1294
1295def _pyval_field_major_to_node_major(keys, values, depth):
1296  """Regroup each field (k, v) from dict-of-list to list-of-dict.
1297
1298  Given a "field-major" encoding of the StructuredTensor (which maps each key to
1299  a single nested list containing the values for all structs), return a
1300  corresponding "node-major" encoding, consisting of a nested list of dicts.
1301
1302  Args:
1303    keys: The field names (list of string).  Must not be empty.
1304    values: The field values (list of python values).  Must have the same length
1305      as `keys`.
1306    depth: The list depth at which dictionaries should be created.
1307
1308  Returns:
1309    A nested list of dict, with depth `depth`.
1310  """
1311  assert keys
1312  if depth == 0:
1313    return dict(zip(keys, values))
1314  nvals = len(values[0])
1315  assert all(nvals == len(values[i]) for i in range(1, len(values)))
1316  return [
1317      _pyval_field_major_to_node_major(keys, value_slice, depth - 1)
1318      for value_slice in zip(*values)
1319  ]
1320
1321
1322def _empty_dict_pylist_from_row_partitions(row_partitions, nrows):
1323  """Returns a python list of empty dicts from the given row partitions.
1324
1325  Args:
1326    row_partitions: The row-partitions describing the ragged shape of the
1327      result.
1328    nrows: The number of rows in the outermost row-partition.  (Or if
1329      `len(row_partitions)==0`, then the number of empty dicts to return.)
1330
1331  Returns:
1332    A nested python list whose leaves (if any) are empty python dicts.
1333  """
1334  if not row_partitions:
1335    return [{} for _ in range(nrows)]
1336  else:
1337    values = _empty_dict_pylist_from_row_partitions(
1338        row_partitions[1:], row_partitions[0].row_splits()[-1])
1339    splits = row_partitions[0].row_splits()
1340    return [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]
1341
1342
1343def _pyval_find_struct_keys_and_depth(pyval, keys):
1344  """Finds the keys & depth of nested dictionaries in `pyval`.
1345
1346  Args:
1347    pyval: A nested structure of lists, tuples, and dictionaries.
1348    keys: (output parameter) A set, which will be updated with any keys that are
1349      found in the nested dictionaries.
1350
1351  Returns:
1352    The nesting depth of dictionaries in `pyval`, or `None` if `pyval` does
1353    not contain any dictionaries.
1354  Raises:
1355    ValueError: If dictionaries have inconsistent depth.
1356  """
1357  if isinstance(pyval, dict):
1358    keys.update(pyval.keys())
1359    return 0
1360  elif isinstance(pyval, (list, tuple)):
1361    depth = None
1362    for child in pyval:
1363      child_depth = _pyval_find_struct_keys_and_depth(child, keys)
1364      if child_depth is not None:
1365        if depth is None:
1366          depth = child_depth + 1
1367        elif depth != child_depth + 1:
1368          raise ValueError('Inconsistent depth of dictionaries')
1369    return depth
1370  else:
1371    return None
1372
1373
1374def _pyval_update_fields(pyval, fields, depth):
1375  """Append the field values from `pyval` to `fields`.
1376
1377  Args:
1378    pyval: A python `dict`, or nested list/tuple of `dict`, whose value(s)
1379      should be appended to `fields`.
1380    fields: A dictionary mapping string keys to field values.  Field values
1381      extracted from `pyval` are appended to this dictionary's values.
1382    depth: The depth at which `pyval` should be appended to the field values.
1383  """
1384  if not isinstance(pyval, (dict, list, tuple)):
1385    raise ValueError('Expected dict or nested list/tuple of dict')
1386
1387  for (key, target) in fields.items():
1388    for _ in range(1, depth):
1389      target = target[-1]
1390    target.append(pyval[key] if isinstance(pyval, dict) else [])
1391
1392  if isinstance(pyval, (list, tuple)):
1393    for child in pyval:
1394      _pyval_update_fields(child, fields, depth + 1)
1395
1396
1397def _pyval_empty_list_depth(pyval):
1398  """Find the max depth for nested empty lists.
1399
1400  Args:
1401    pyval: A nested python list.
1402
1403  Returns:
1404    The maximum depth of empty lists in `pyval`, or None if `pyval` contains
1405    anything other than nested empty lists.
1406  """
1407  if isinstance(pyval, list):
1408    if not pyval:
1409      return 1
1410    depths = [_pyval_empty_list_depth(v) for v in pyval]
1411    if any(depth is None for depth in depths):
1412      return None
1413    else:
1414      return max(depths) + 1
1415  else:
1416    return None
1417
1418
def _replace_row_partitions(value, new_partitions):
  """Updates `value` to use `new_partitions` as its (outer) row partitions.

  This is used to ensure that all fields in a `StructuredTensor` use identical
  `RowPartition` objects for the shared dimensions.  In particular,
  `StructuredTensor.from_fields` first merges all of the row partitions from
  any fields, and then replaces the outer row partitions of all fields with
  the merged row partitions (using this function).

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    new_partitions: A list of row-partitions that should be used by `value`.
      Must be equivalent to `value`'s current row partitions.

  Returns:
    A value that is equivalent to `value`, where outer row partitions have been
    replaced by `new_partitions`.
  """
  # Dense tensors carry no row partitions, and an empty `new_partitions` means
  # there is nothing to replace.
  if isinstance(value, ops.Tensor) or not new_partitions:
    return value

  if isinstance(value, ragged_tensor.RaggedTensor):
    # Replace the outermost partition here, and the rest recursively.
    inner_values = _replace_row_partitions(value.values, new_partitions[1:])
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        values=inner_values, row_partition=new_partitions[0])

  assert isinstance(value, StructuredTensor)
  # pylint: disable=protected-access
  updated_fields = {
      name: _replace_row_partitions(field, new_partitions)
      for name, field in value._fields.items()
  }
  # pylint: enable=protected-access
  # Keep any of `value`'s own partitions beyond the ones being replaced.
  num_replaced = len(new_partitions)
  all_partitions = new_partitions + value.row_partitions[num_replaced:]
  return StructuredTensor(
      fields=updated_fields,
      shape=value.shape,
      nrows=value.nrows(),
      row_partitions=all_partitions,
      internal=_structured_tensor_factory_key)
1456
1457
def _partition_outer_dimension(value, row_partition):
  """Partitions the outer dimension of `value` using `row_partitions`.

  Examples:

    >>> partition = RowPartition.from_row_lengths([2, 0, 1])
    >>> _partition_outer_dimension(tf.constant([1, 2, 3]), partition)
    <tf.RaggedTensor [[1, 2], [], [3]]>

    >>> struct_value = StructuredTensor.from_pyval(
    ...     [{'x': 1}, {'x': 2}, {'x': 3}])
    >>> _partition_outer_dimension(struct_value, partition)
    <StructuredTensor(
      fields={
        "x": <tf.RaggedTensor [[1, 2], [], [3]]>},
      shape=(3, None))>

  Args:
    value: Tensor, RaggedTensor, or StructuredTensor
    row_partition: RowPartition

  Returns:
    A value with the same type as `value`, where
    `result.rank = value.rank + 1`.
  """
  uniform_row_length = row_partition.uniform_row_length()
  value_is_tensor = isinstance(value, ops.Tensor)

  if value_is_tensor and uniform_row_length is not None:
    # A uniform partition of a dense tensor is just a reshape.
    outer_dims = [row_partition.nrows(), uniform_row_length]
    inner_dims = array_ops.shape(value, out_type=row_partition.dtype)[1:]
    new_shape = array_ops.concat([outer_dims, inner_dims], axis=0)
    return array_ops.reshape(value, new_shape)

  if value_is_tensor or isinstance(value, ragged_tensor.RaggedTensor):
    return ragged_tensor.RaggedTensor._from_row_partition(  # pylint: disable=protected-access
        value, row_partition)

  assert isinstance(value, StructuredTensor)
  outer_shape = tensor_shape.TensorShape(
      [row_partition.static_nrows, row_partition.static_uniform_row_length])
  new_shape = outer_shape.concatenate(value.shape[1:])
  # pylint: disable=protected-access
  partitioned_fields = {
      name: _partition_outer_dimension(field, row_partition)
      for name, field in value._fields.items()
  }
  # pylint: enable=protected-access
  return StructuredTensor(
      partitioned_fields,
      new_shape,
      row_partition.nrows(), (row_partition,) + value.row_partitions,
      internal=_structured_tensor_factory_key)
1507
1508
def _merge_dims(value, outer_axis, inner_axis):
  """Merges `outer_axis...inner_axis` of `value` into a single dimension.

  Args:
    value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`.
    outer_axis: Python int. First axis of the range to merge.
    inner_axis: Python int. Last axis of the range to merge (inclusive); must
      be strictly greater than `outer_axis`.

  Returns:
    A value of the same kind as `value` whose axes `outer_axis...inner_axis`
    have been collapsed into one, so that
    `result.rank = value.rank - (inner_axis - outer_axis)`.
  """
  assert outer_axis < inner_axis
  if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)):
    return ragged_tensor.merge_dims(value, outer_axis, inner_axis)
  else:
    assert isinstance(value, StructuredTensor)

    # Build the new fields.
    fields = {k: _merge_dims(v, outer_axis, inner_axis)
              for (k, v) in value._fields.items()}

    # Build the new shape.  The merged dimension is the product of *all* merged
    # axes, i.e. `outer_axis` through `inner_axis` inclusive (the retained tail
    # starts at `inner_axis + 1`).  `num_elements()` yields `None` when any of
    # those axes is not statically known.
    value_shape = value.shape
    shape = (
        value_shape[:outer_axis] +
        [value_shape[outer_axis:inner_axis + 1].num_elements()] +
        value_shape[inner_axis + 1:])

    # Build the new row_partitions & nrows
    if outer_axis == 0:
      if inner_axis == value_shape.rank - 1:
        # Every axis was merged: the result is a vector with no partitions,
        # and its length is the value count of the innermost partition.
        partitions = ()
        nrows = value.row_partitions[-1].nvals()
      else:
        # The merged axis becomes the new outermost axis; only the partitions
        # from `inner_axis` onwards remain.
        partitions = value.row_partitions[inner_axis:]
        nrows = partitions[0].nrows()
    else:
      # Use tf.gather to merge row_splits from the merged row partitions: each
      # step looks up the current split points in the next partition's
      # row_splits, composing the partitions into a single one.
      merged_splits = value.row_partitions[outer_axis - 1].row_splits()
      for dim in range(outer_axis, inner_axis):
        merged_splits = array_ops.gather(value.row_partitions[dim].row_splits(),
                                         merged_splits)

      partitions = (
          value.row_partitions[:outer_axis - 1] +
          (RowPartition.from_row_splits(merged_splits),) +
          value.row_partitions[inner_axis:])
      nrows = partitions[0].nrows()

    return StructuredTensor(
        fields,
        shape,
        nrows,
        partitions,
        internal=_structured_tensor_factory_key)
1555
1556
# Sentinel object passed as the `internal` argument when helpers in this module
# call the `StructuredTensor` constructor directly.
# NOTE(review): presumably the constructor compares against this object to
# reject calls from outside the module -- confirm in `StructuredTensor.__init__`.
_structured_tensor_factory_key = object()  # unique private object
1558
1559
1560def _normalize_field_name_to_tuple(name: 'FieldName') -> Sequence[str]:
1561  """FieldName can be given also as string, this normalizes it to a tuple."""
1562  if isinstance(name, str):
1563    return (name,)
1564  if isinstance(name, list):
1565    return tuple(name)
1566  assert isinstance(name, tuple)
1567  return name
1568
1569
def _merge_dims_generic(source, outer, inner):
  """Merges outer_axis...inner_axis into a single dimension.

  Dispatches to `StructuredTensor.merge_dims` when `source` is a
  `StructuredTensor`, and to `ragged_tensor.merge_dims` otherwise.

  If outer == inner, this is a NOOP. If inner < outer, then this fails.
  If inner >= source.shape.rank, then the behavior is undefined.

  Args:
    source: a tensor, ragged tensor, or structured tensor.
    outer: a python int, indicating the first dimension to compress
      (must be nonnegative).
    inner: a python int, indicating the first dimension to keep (of the tail)
      (must be nonnegative).
      NOTE(review): the summary line and the `outer == inner` no-op case
      suggest `inner` is itself part of the merged range -- confirm against
      the `merge_dims` implementations.

  Returns:
    source with outer_axis...inner_axis merged into a single dimension.

  """
  if not isinstance(source, StructuredTensor):
    return ragged_tensor.merge_dims(source, outer, inner)
  return source.merge_dims(outer, inner)
1591