1# Lint as python3 2# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 3# 4# Licensed under the Apache License, Version 2.0 (the "License"); 5# you may not use this file except in compliance with the License. 6# You may obtain a copy of the License at 7# 8# http://www.apache.org/licenses/LICENSE-2.0 9# 10# Unless required by applicable law or agreed to in writing, software 11# distributed under the License is distributed on an "AS IS" BASIS, 12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13# See the License for the specific language governing permissions and 14# limitations under the License. 15# ============================================================================== 16"""Structured Tensors.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import re 23from typing import Callable, Dict, List, Sequence, Tuple, Union 24 25import numpy as np 26 27from tensorflow.python.framework import composite_tensor 28from tensorflow.python.framework import constant_op 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.framework import tensor_shape 32from tensorflow.python.framework import tensor_spec 33from tensorflow.python.framework import type_spec 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import check_ops 36from tensorflow.python.ops import control_flow_ops 37from tensorflow.python.ops import math_ops 38from tensorflow.python.ops.ragged import ragged_factory_ops 39from tensorflow.python.ops.ragged import ragged_tensor 40from tensorflow.python.ops.ragged.row_partition import RowPartition 41from tensorflow.python.util import compat 42from tensorflow.python.util import nest 43 44 45class StructuredTensor(composite_tensor.CompositeTensor): 46 """A multidimensional collection of structures with the same schema. 47 48 A **`StructuredTensor`** is a multi-dimensional collection of ***structures*** 49 with the same ***schema***, where: 50 51 * A ***schema*** is a collection of fields, each of which has a name and type. 52 * A ***structure*** maps each field in the schema to a tensor value (which 53 could be a nested StructuredTensor). 54 55 As an important special case, a 1D `StructuredTensor` encodes a 2D table, 56 where columns are heterogeneous `Tensor`s, and rows are the aligned elements 57 in each of those `Tensor`s. 58 59 Internally, StructuredTensors use a "field-major" encoding: for each leaf 60 field, there is a single tensor that stores the value of that field for all 61 structures in the `StructuredTensor`. 62 63 ### Examples 64 65 >>> # A scalar StructuredTensor describing a single person. 66 >>> s1 = StructuredTensor.from_pyval( 67 ... {"age": 82, "nicknames": ["Bob", "Bobby"]}) 68 >>> s1.shape 69 TensorShape([]) 70 >>> s1["age"] 71 <tf.Tensor: shape=(), dtype=int32, numpy=82> 72 73 >>> # A vector StructuredTensor describing three people. 74 >>> s2 = StructuredTensor.from_pyval([ 75 ... {"age": 12, "nicknames": ["Josaphine"]}, 76 ... {"age": 82, "nicknames": ["Bob", "Bobby"]}, 77 ... {"age": 42, "nicknames": ["Elmo"]}]) 78 >>> s2.shape 79 TensorShape([3]) 80 >>> s2[0]["age"] 81 <tf.Tensor: shape=(), dtype=int32, numpy=12> 82 83 84 ### Field Paths 85 86 A *field path* is a tuple of field names, specifying the path to a nested 87 field. 88 """ 89 90 #============================================================================= 91 # Common Types 92 #============================================================================= 93 # pylint: disable=invalid-name 94 # Field names work as key, and they can be a sequence to refer to the 95 # sub-levels (embedded) StructuredTensor's. 96 FieldName = Union[str, Sequence[str]] 97 98 # Each field may contain one of the following types of Tensors. 99 FieldValue = Union[ops.Tensor, ragged_tensor.RaggedTensor, 'StructuredTensor'] 100 101 # Function that takes a FieldValue as input and returns the transformed 102 # FieldValue. 103 FieldFn = Callable[[FieldValue], FieldValue] 104 105 # pylint: enable=invalid-name 106 107 #============================================================================= 108 # Constructor & Factory Methods 109 #============================================================================= 110 111 def __init__(self, fields, shape, nrows, row_partitions, internal=False): 112 """Private constructor -- use factory methods to create StructuredTensors. 113 114 This constructor builds a `StructuredTensor` from the given attributes, 115 performing minimal validation. 116 117 Args: 118 fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or 119 `StructuredTensor`. (This dict is not copied, so the caller must ensure 120 that it does not get mutated via leaked references.) 121 shape: `tf.TensorShape` with statically known rank. 122 nrows: scalar integer `tf.Tensor`, or `None` if `shape.rank==0`. 123 row_partitions: tuple of `RowPartition`s, with length `shape.rank-1`. 124 internal: Private key value, required to ensure that this private 125 constructor is *only* called from the factory methods. 126 """ 127 if internal is not _structured_tensor_factory_key: 128 raise ValueError('StructuredTensor constructor is private; please use ' 129 'one of the factory methods instead (e.g., ' 130 'StructuredTensor.from_fields())') 131 assert isinstance(fields, dict), fields 132 assert isinstance(shape, tensor_shape.TensorShape), shape 133 assert nrows is None or isinstance(nrows, ops.Tensor), nrows 134 assert isinstance(row_partitions, tuple), row_partitions 135 self._fields = fields 136 self._shape = shape 137 self._nrows = nrows 138 self._row_partitions = row_partitions 139 140 @classmethod 141 def from_fields(cls, 142 fields, 143 shape=(), 144 nrows=None, 145 row_partitions=None, 146 validate=False): 147 """Creates a `StructuredTensor` from a dictionary of fields. 148 149 Args: 150 fields: A dictionary mapping from string to `Tensor`, `RaggedTensor`, or 151 `StructuredTensor`, providing the values for individual fields in each 152 structure. If `shape.rank > 0`, then every tensor in `fields` must have 153 the same shape in the first `shape.rank` dimensions; and that shape must 154 be compatible with `shape`; and 155 `result[i1...iN][key] = fields[key][i1...iN]` (where `N==shape.rank`). 156 shape: A `TensorShape`: static information about the shape of the 157 `StructuredTensor`. Must have a known `rank`. Defaults to scalar 158 shape (i.e. `rank=0`). 159 nrows: scalar integer tensor containing the number of rows in this 160 `StructuredTensor`. Should only be specified if `shape.rank > 0`. 161 Default value is inferred from the `fields` values. If `fields` is 162 empty, then this must be specified. 163 row_partitions: A list of `RowPartition`s describing the (possibly ragged) 164 shape of this `StructuredTensor`. Should only be specified if 165 `shape.rank > 1`. Default value is inferred from the `fields` values. 166 If `fields` is empty, then this must be specified. 167 validate: If true, then add runtime validation ops that check that the 168 field values all have compatible shapes in the outer `shape.rank` 169 dimensions. 170 171 Returns: 172 A `StructuredTensor`. 173 174 Examples: 175 176 >>> StructuredTensor.from_fields({'x': 1, 'y': [1, 2, 3]}) 177 <StructuredTensor( 178 fields={ 179 "x": tf.Tensor(1, shape=(), dtype=int32), 180 "y": tf.Tensor([1 2 3], shape=(3,), dtype=int32)}, 181 shape=())> 182 183 >>> StructuredTensor.from_fields({'foo': [1, 2], 'bar': [3, 4]}, 184 ... shape=[2]) 185 <StructuredTensor( 186 fields={ 187 "bar": tf.Tensor([3 4], shape=(2,), dtype=int32), 188 "foo": tf.Tensor([1 2], shape=(2,), dtype=int32)}, 189 shape=(2,))> 190 """ 191 shape = tensor_shape.as_shape(shape) 192 rank = shape.rank 193 if rank is None: 194 raise ValueError("StructuredTensor's shape must have known rank.") 195 if not isinstance(fields, dict): 196 raise TypeError('fields must be a dictionary, got %s' % 197 type(fields).__name__) 198 if rank < 2 and row_partitions: 199 raise ValueError('row_partitions must be None or [] if shape.rank<2') 200 if rank == 0 and nrows is not None: 201 raise ValueError('nrows must be None if shape.rank==0') 202 if row_partitions is not None: 203 row_partitions = tuple(row_partitions) 204 if len(row_partitions) != max(0, rank - 1): 205 raise ValueError('len(row_partitions) must be shape.rank-1') 206 elif rank < 2: 207 row_partitions = () 208 209 fields = dict(fields) # Make a private copy. 210 with ops.name_scope(None, 'StructuredTensor', fields.values()): 211 212 # Validate keys and convert field values to tensors. 213 for key, value in fields.items(): 214 if not isinstance(key, str): 215 raise TypeError('Unexpected type for key in `fields`: %r' % key) 216 if not _FIELD_NAME_RE.match(key): 217 raise ValueError('Field name %r is not currently allowed.' % key) 218 fields[key] = _convert_to_structured_field_value(value) 219 220 # Determine dtype for row_partitions and nrows. 221 shape_dtype = _find_shape_dtype(fields, nrows, row_partitions) 222 if nrows is not None: 223 nrows = ops.convert_to_tensor(nrows, shape_dtype) 224 225 # Get the static TensorShape for this StructuredTensor. 226 if rank > 0: 227 for key, value in fields.items(): 228 if not shape.is_compatible_with(value.shape[:rank]): 229 raise ValueError('Field {} has shape {}, which is incompatible ' 230 'with the shape that was specified or inferred ' 231 'from other fields: {}'.format( 232 key, value.shape[:rank], shape)) 233 shape = shape.merge_with(value.shape[:rank]) 234 235 if rank == 1: 236 # Find a consistent value for `nrows`. 237 static_nrows = tensor_shape.dimension_at_index(shape, 0) 238 for value in fields.values(): 239 nrows, static_nrows = _merge_nrows(nrows, static_nrows, value, 240 shape_dtype, validate) 241 if nrows is None: 242 if static_nrows.value is None: 243 raise ValueError('nrows must be specified if rank==1 ' 244 'and `fields` is empty.') 245 else: 246 nrows = constant_op.constant(static_nrows.value, shape_dtype) 247 248 if rank > 1: 249 # Find a consistent list of RowPartitions. 250 for value in fields.values(): 251 row_partitions = _merge_row_partitions(row_partitions, value, rank, 252 shape_dtype, validate) 253 if row_partitions is None: 254 if not shape.is_fully_defined(): 255 raise ValueError('row_partitions must be specified if rank>1 ' 256 'and `fields` is empty.') 257 else: 258 row_partitions = _row_partitions_for_uniform_shape( 259 np.array(shape.as_list(), dtype=shape_dtype.as_numpy_dtype), 260 shape.rank) 261 assert len(row_partitions) == rank - 1 262 nrows = row_partitions[0].nrows() 263 # Update all field values to use the shared RowPartition objects. 264 fields = dict([(k, _replace_row_partitions(v, row_partitions)) 265 for (k, v) in fields.items()]) 266 267 return cls( 268 fields, 269 shape, 270 nrows, 271 row_partitions, 272 internal=_structured_tensor_factory_key) 273 274 def with_updates(self, 275 updates: Dict[FieldName, Union[FieldValue, FieldFn, None]], 276 validate: bool = False) -> 'StructuredTensor': # pylint: disable=bad-whitespace 277 """Creates a new `StructuredTensor` with the updated fields. 278 279 If this `StructuredTensor` is a scalar, and `k` is the `FieldName` being 280 updated and `v` the new value, then: 281 282 ``` 283 result[k] = v # If (k, v) is in updates and v is a FieldValue 284 result[k] = f(self[k]) # If (k, f) is in updates and f is a FieldFn 285 result[k] = self[k] # If k is in self.field_names but not in updates 286 ``` 287 288 If this `StructuredTensor` has rank `N` and shape `[D1...DN]`, then each 289 FieldValue `v` in `updates` must have shape `[D1...DN, ...]`, that is, 290 prefixed with the same shape as the `StructuredTensor`. Then the resulting 291 `StructuredTensor` will have: 292 293 ``` 294 result[i1...iN][k] = v[i1...iN] # (k, v) in updates 295 result[i1...iN][k] = f(self.field_value(k))[i1...iN] # (k, f) in updates 296 result[i1...iN][k] = self[i1...iN][k] # k not in updates 297 ``` 298 299 Note that `result.shape` is always equal to `self.shape` (but the shapes 300 of nested StructuredTensors may be changed if they are updated with new 301 values). 302 303 Args: 304 updates: A dictionary mapping `FieldName` to either a `FieldValue` to be 305 used to update, or a `FieldFn` that will transform the value for the 306 given `FieldName`. `FieldName` can be a string for a direct field, or a 307 sequence of strings to refer to a nested sub-field. `FieldFn` is a 308 function that takes a `FieldValue` as input and should return a 309 `FieldValue`. All other fields are copied over to the new 310 `StructuredTensor`. New `FieldName` can be given (to add new fields), 311 but only to existing `StructuredTensor`, it won't automatically create 312 new nested structures -- but one can create a whole `StructureTensor` 313 sub-structure and set that into an existing structure. If the new value 314 is set to `None`, it is removed. 315 validate: If true, then add runtime validation ops that check that the 316 field values all have compatible shapes in the outer `shape.rank` 317 dimensions. 318 319 Returns: 320 A `StructuredTensor`. 321 322 Raises: 323 `ValueError`: If the any of the `FieldName` keys points to non-existent 324 sub-structures, if parent and child nodes are updated, if shapes 325 change, if a delete update is given for a non-existant field, or if a 326 `FieldFn` transforming function is given for a `FieldName` that doesn't 327 yet exist. 328 329 Examples: 330 331 >>> shoes_us = StructuredTensor.from_pyval([ 332 ... {"age": 12, "nicknames": ["Josaphine"], 333 ... "shoes": {"sizes": [8.0, 7.5, 7.5]}}, 334 ... {"age": 82, "nicknames": ["Bob", "Bobby"], 335 ... "shoes": {"sizes": [11.0, 11.5, 12.0]}}, 336 ... {"age": 42, "nicknames": ["Elmo"], 337 ... "shoes": {"sizes": [9.0, 9.5, 10.0]}}]) 338 >>> def us_to_europe(t): 339 ... return tf.round(t * 2.54 + 17.0) # Rough approximation. 340 >>> shoe_sizes_key = ("shoes", "sizes") 341 >>> shoes_eu = shoes_us.with_updates({shoe_sizes_key: us_to_europe}) 342 >>> shoes_eu.field_value(shoe_sizes_key) 343 <tf.RaggedTensor [[37.0, 36.0, 36.0], [45.0, 46.0, 47.0], 344 [40.0, 41.0, 42.0]]> 345 """ 346 updates_items = [(_normalize_field_name_to_tuple(name), value) 347 for name, value in updates.items()] 348 349 # Sort by keys and check for updates of both parent and child nodes. 350 updates_items = sorted(updates_items) 351 for i in range(1, len(updates_items)): 352 # Parent of a node would precede node in the sorted order. 353 name = updates_items[i][0] # item[0] is the name, item[1] is the value. 354 prev_name = updates_items[i - 1][0] 355 if name[:len(prev_name)] == prev_name: 356 raise ValueError( 357 '`StructuredTensor.with_updates` does not allow both parent and ' 358 'child nodes to be updated: parent={}, child={}. If needed you can ' 359 'update child nodes in the parent update value.'.format( 360 prev_name, name)) 361 return self._with_updates_impl((), updates_items, validate) 362 363 def _with_updates_impl(self, error_prefix: Tuple[str], # pylint: disable=invalid-sequence-index 364 updates: List[Tuple[FieldName, Union[FieldValue, # pylint: disable=invalid-sequence-index 365 FieldFn]]], 366 validate: bool) -> 'StructuredTensor': 367 """Recursive part of `with_updates` implementation.""" 368 # Get current fields. 369 new_fields = dict(self._fields) 370 371 # Convert field name to string with full path for error messages. 372 def name_fullpath(name: Sequence[str]) -> str: 373 return str(error_prefix + (name,)) 374 375 # Apply value if a function or the value itself. 376 def apply_value(name: str, value: Union['FieldValue', 377 'FieldFn']) -> 'FieldValue': 378 if callable(value): 379 # `value` is actually a transforming function. 380 if name not in new_fields: 381 raise ValueError( 382 '`StructuredTensor.with_updates` cannot update the field {} ' 383 'because a transforming function was given, but that field ' 384 'does not already exist.'.format(name_fullpath(name))) 385 value = value(new_fields[name]) 386 return value 387 388 # Merge updates. 389 for name, value in updates: 390 if not name or not name[0]: 391 raise ValueError( 392 '`StructuredTensor.with_updates` does not allow empty names ' 393 '{}.'.format(name_fullpath(name))) 394 395 if len(name) == 1: 396 name = name[0] 397 if value is None: 398 if name not in new_fields: 399 raise ValueError( 400 '`StructuredTensor.with_updates` cannot delete field ' 401 '{} because it is not present.'.format(name_fullpath(name))) 402 new_fields.pop(name) 403 else: 404 new_fields[name] = apply_value(name, value) 405 else: 406 # Recursive 407 prefix = name[0] 408 suffix = name[1:] 409 if prefix not in new_fields: 410 raise ValueError( 411 '`StructuredTensor.with_updates` cannot create new sub-field ' 412 '{} if parent field {} is not set.'.format( 413 error_prefix + tuple(name), name_fullpath(prefix))) 414 current_value = new_fields[prefix] 415 if not isinstance(current_value, StructuredTensor): 416 raise ValueError( 417 '`StructuredTensor.with_updates` cannot create new sub-field ' 418 '{} if parent structure {} is not a `StructuredTensor` that ' 419 'can contain sub-structures -- it is a `{}`.'.format( 420 error_prefix + tuple(name), name_fullpath(prefix), 421 type(current_value))) 422 one_update = [(suffix, value)] 423 424 # Accessing protected member in recursion. 425 # FutureWork: optimize by aggregating the recursions, instead of 426 # calling one at a time. 427 # pylint: disable=protected-access 428 value = current_value._with_updates_impl(error_prefix + (prefix,), 429 one_update, validate) 430 # pylint: enable=protected-access 431 new_fields[prefix] = value 432 433 # TODO(edloper): When validate=True, only validate the modified fields. 434 try: 435 return StructuredTensor.from_fields( 436 new_fields, 437 shape=self.shape, 438 row_partitions=self._row_partitions, 439 nrows=self._nrows, 440 validate=validate) 441 442 except ValueError as e: 443 msg = '`StructuredTensor.with_updates` failed' 444 if error_prefix: 445 msg = '{} for field {}'.format(msg, error_prefix) 446 raise ValueError('{}: {}'.format(msg, e)) 447 448 def _promote_helper(self, source_path, new_parent_path): 449 """Creates a promoted field without adding it to the structure. 450 451 Args: 452 source_path: the source path in the structured tensor. 453 new_parent_path: the new parent path. Must be a prefix of source_path. 454 455 Returns: 456 a composite tensor of source_path promoted. 457 Raises: 458 ValueError: if the shape of the field is unknown and the right strategy 459 cannot be determined. 460 """ 461 current_field = self.field_value(source_path) 462 new_parent_rank = self.field_value(new_parent_path).rank 463 parent_rank = self.field_value(source_path[:-1]).rank 464 if new_parent_rank == parent_rank: 465 return current_field 466 current_field_rank = current_field.shape.rank 467 if current_field_rank is None: 468 raise ValueError('Cannot determine if dimensions should be merged.') 469 inner_dim = min(parent_rank, current_field_rank - 1) 470 if inner_dim <= new_parent_rank: 471 return current_field 472 return _merge_dims_generic(current_field, new_parent_rank, inner_dim) 473 474 def promote(self, source_path, new_name): 475 """Promotes a field, merging dimensions between grandparent and parent. 476 477 >>> d = [ 478 ... {'docs': [{'tokens':[1, 2]}, {'tokens':[3]}]}, 479 ... {'docs': [{'tokens':[7]}]}] 480 >>> st = StructuredTensor.from_pyval(d) 481 >>> st2 =st.promote(('docs','tokens'), 'docs_tokens') 482 >>> st2[0]['docs_tokens'] 483 <tf.Tensor: shape=(3,), dtype=int32, numpy=array([1, 2, 3], dtype=int32)> 484 >>> st2[1]['docs_tokens'] 485 <tf.Tensor: shape=(1,), dtype=int32, numpy=array([7], dtype=int32)> 486 487 Args: 488 source_path: the path of the field or substructure to promote; must have 489 length at least 2. 490 new_name: the name of the new field (must be a string). 491 492 Returns: 493 a modified structured tensor with the new field as a child of the 494 grandparent of the source_path. 495 496 Raises: 497 ValueError: if source_path is not a list or a tuple or has a length 498 less than two, or new_name is not a string, or the rank 499 of source_path is unknown and it is needed. 500 """ 501 if not isinstance(new_name, str): 502 raise ValueError('new_name is not a string') 503 if not isinstance(source_path, (list, tuple)): 504 raise ValueError('source_path must be a list or tuple') 505 506 if len(source_path) < 2: 507 raise ValueError('source_path must have length at least two') 508 509 grandparent_path = source_path[:-2] 510 new_field = self._promote_helper(source_path, grandparent_path) 511 new_path = grandparent_path + (new_name,) 512 return self.with_updates({new_path: new_field}) 513 514 #============================================================================= 515 # Properties 516 #============================================================================= 517 518 @property 519 def rank(self): 520 """The rank of this StructuredTensor. Guaranteed not to be `None`.""" 521 return self._shape.rank 522 523 @property 524 def shape(self): 525 """The static shape of this StructuredTensor. 526 527 The returned `TensorShape` is guaranteed to have a known rank, but the 528 individual dimension sizes may be unknown. 529 530 Returns: 531 `tf.TensorShape` 532 """ 533 return self._shape 534 535 # TODO(edloper): Make this a func instead of a property? Or make nrows 536 # a property instead of a func? Seems like these should be consistent. 537 @property 538 def row_partitions(self): 539 """A tuple of `RowPartition`s defining the shape of this `StructuredTensor`. 540 541 When `self.rank <= 1`, this tuple will be empty. 542 543 When `self.rank > 1`, these `RowPartitions` define the shape of the 544 `StructuredTensor` by describing how a flat (1D) list of structures can be 545 repeatedly partitioned to form a higher-dimensional object. In particular, 546 the flat list is first partitioned into sublists using `row_partitions[-1]`, 547 and then those sublists are further partitioned using `row_partitions[-2]`, 548 etc. The following examples show the row partitions used to describe 549 several different `StructuredTensor`, each of which contains 8 copies of 550 the same structure (`x`): 551 552 >>> x = {'a': 1, 'b': ['foo', 'bar', 'baz']} # shape = [] (scalar) 553 554 >>> s1 = [[x, x, x, x], [x, x, x, x]] # shape = [2, 4] 555 >>> StructuredTensor.from_pyval(s1).row_partitions 556 (tf.RowPartition(row_splits=tf.Tensor([0 4 8], shape=(3,), 557 dtype=int64)),) 558 559 >>> s2 = [[x, x], [x, x], [x, x], [x, x]] # shape = [4, 2] 560 >>> StructuredTensor.from_pyval(s2).row_partitions 561 (tf.RowPartition(row_splits=tf.Tensor([0 2 4 6 8], shape=(5,), 562 dtype=int64)),) 563 564 >>> s3 = [[x, x, x], [], [x, x, x, x], [x]] # shape = [2, None] 565 >>> StructuredTensor.from_pyval(s3).row_partitions 566 (tf.RowPartition(row_splits=tf.Tensor([0 3 3 7 8], shape=(5,), 567 dtype=int64)),) 568 569 >>> s4 = [[[x, x], [x, x]], [[x, x], [x, x]]] # shape = [2, 2, 2] 570 >>> StructuredTensor.from_pyval(s4).row_partitions 571 (tf.RowPartition(row_splits=tf.Tensor([0 2 4], shape=(3,), dtype=int64)), 572 tf.RowPartition(row_splits=tf.Tensor([0 2 4 6 8], shape=(5,), 573 dtype=int64))) 574 575 576 >>> s5 = [[[x, x], [x]], [[x, x]], [[x, x], [x]]] # shape = [3, None, None] 577 >>> StructuredTensor.from_pyval(s5).row_partitions 578 (tf.RowPartition(row_splits=tf.Tensor([0 2 3 5], shape=(4,), dtype=int64)), 579 tf.RowPartition(row_splits=tf.Tensor([0 2 3 5 7 8], shape=(6,), 580 dtype=int64))) 581 582 Note that shapes for nested fields (such as `x['b']` in the above example) 583 are not considered part of the shape of a `StructuredTensor`, and are not 584 included in `row_partitions`. 585 586 If this `StructuredTensor` has a ragged shape (i.e., if any of the 587 `row_partitions` is not uniform in size), then all fields will be encoded 588 as either `RaggedTensor`s or `StructuredTensor`s with these `RowPartition`s 589 used to define their outermost `self.rank` dimensions. 590 591 Returns: 592 A `tuple` of `RowPartition` objects with length `self.rank - 1` 593 (or `0` if `self.rank < 2`) 594 595 """ 596 return self._row_partitions 597 598 def nrows(self): 599 """The number of rows in this StructuredTensor (if rank>0). 600 601 This means the length of the outer-most dimension of the StructuredTensor. 602 603 Notice that if `self.rank > 1`, then this equals the number of rows 604 of the first row partition. That is, 605 `self.nrows() == self.row_partitions[0].nrows()`. 606 607 Otherwise `self.nrows()` will be the first dimension of the field values. 608 609 Returns: 610 A scalar integer `Tensor` (or `None` if `self.rank == 0`). 611 """ 612 return self._nrows 613 614 def _is_eager(self): 615 """True if all fields are composed of eager tensors.""" 616 tensors = nest.flatten(self, expand_composites=True) 617 return all(isinstance(t, ops.EagerTensor) for t in tensors) 618 619 #============================================================================= 620 # Encoding 621 #============================================================================= 622 623 def field_names(self): 624 """Returns the string field names for this `StructuredTensor`.""" 625 return tuple(self._fields.keys()) 626 627 def field_value(self, field_name): 628 """Returns the tensor value for the specified field or path. 629 630 If `field_name` is a `string`, then it names a field directly owned by this 631 `StructuredTensor`. If this `StructuredTensor` has shape `[D1...DN]`, then 632 the returned tensor will have shape `[D1...DN, V1...VM]`, where the slice 633 `result[d1...dN]` contains the field value for the structure at 634 `self[d1...dN]`. 635 636 If `field_name` is a `tuple` of `string`, then it specifies a path to a 637 field owned by nested `StructuredTensor`. In particular, 638 `struct.field_value((f1, f2, ..., fN))` is equivalent to 639 `struct.field_value(f1).field_value(f2)....field_value(fN)` 640 641 Args: 642 field_name: `string` or `tuple` of `string`: The field whose values should 643 be returned. 644 645 Returns: 646 `Tensor`, `StructuredTensor`, or `RaggedTensor`. 647 648 Raises: 649 KeyError: If the given field_name is not found. 650 """ 651 if isinstance(field_name, (list, tuple)): 652 value = self 653 for f in field_name: 654 if not isinstance(value, StructuredTensor): 655 raise KeyError('Field path {} not found in {}'.format( 656 field_name, self)) 657 value = value.field_value(f) 658 return value 659 return self._fields[field_name] 660 661 #============================================================================= 662 # Operators 663 #============================================================================= 664 665 # TODO(edloper): Add support for ellipsis and/or newaxis? 666 def __getitem__(self, key): 667 """Returns the specified piece of this StructuredTensor. 668 669 * If `struct_tensor` is scalar (i.e., a single structure), then 670 `struct_tensor[f]` returns the value of field `f` (where `f` must be a 671 string). 672 673 * If `struct_tensor` is non-scalar (i.e., a vector or higher-dimensional 674 tensor of structures), `struct_tensor[i]` selects an element or slice of 675 the tensor using standard Python semantics (e.g., negative values index 676 from the end). `i` may have any of the following types: 677 678 * `int` constant 679 * `string` constant 680 * scalar integer `Tensor` 681 * `slice` containing integer constants and/or scalar integer 682 `Tensor`s 683 684 #### Multidimensional indexing 685 686 `StructuredTensor` supports multidimensional indexing. I.e., `key` may be a 687 `tuple` of values, indexing or slicing multiple dimensions at once. For 688 example, if `people` is a vector of structures, each of which has a vector- 689 valued `names` field, then `people[3, 'names', 0]` is equivalent to 690 `people[3]['names'][0]`; and `people[:, 'names', :]` will return a (possibly 691 ragged) matrix of names, with shape `[num_people, num_names_per_person]`. 692 693 Args: 694 key: Indicates which piece of the StructuredTensor to return. 695 Returns: 696 A `Tensor`, `StructuredTensor`, or `RaggedTensor`. 697 """ 698 if isinstance(key, list): 699 key = tuple(key) 700 elif not isinstance(key, tuple): 701 key = (key,) 702 if not key: 703 return self 704 705 if self._shape.rank == 0: 706 return self._scalar_getitem(key) 707 else: 708 return self._tensor_getitem(key) 709 710 def _scalar_getitem(self, key): 711 if (isinstance(key[0], slice) and key[0].start is None and 712 key[0].stop is None and key[0].step is None): 713 fields = dict((field_name, field_value.__getitem__(key[1:])) 714 for (field_name, field_value) in self._fields.items()) 715 return StructuredTensor.from_fields(fields, self._shape) 716 717 elif not isinstance(key[0], compat.bytes_or_text_types): 718 raise ValueError('Key for indexing a StructuredTensor must be a ' 719 "string or a full slice (':')") 720 721 return self._fields[key[0]].__getitem__(key[1:]) 722 723 def _tensor_getitem(self, key): 724 rank = self._shape.rank 725 if len(key) <= rank: 726 new_fields = dict((field_name, field_value.__getitem__(key)) 727 for (field_name, field_value) in self._fields.items()) 728 result_shape = self.shape.as_list() 729 for d, k in enumerate(key): 730 if isinstance(k, slice): 731 if not (k.start is None and k.stop is None and k.step is None): 732 # TODO(edloper): Better static shape analysis here. 733 result_shape[d] = None 734 elif isinstance(k, (int, ops.Tensor)): 735 result_shape[d] = -1 # mark for deletion 736 elif k is None: 737 raise ValueError('Slicing not supported for tf.newaxis') 738 else: 739 # Ellipsis, tf.newaxis: 740 raise ValueError('Slicing not supported for %r' % k) 741 result_shape = [d for d in result_shape if d != -1] 742 return StructuredTensor.from_fields(new_fields, result_shape) 743 744 else: 745 if not isinstance(key[rank], compat.bytes_or_text_types): 746 # TODO(edloper): Also support full slice here? 747 raise ValueError('Key for indexing a StructuredTensor must be a string') 748 return self._fields[key[rank]].__getitem__(key[:rank] + key[rank + 1:]) 749 750 def __repr__(self): 751 fields = sorted(self._fields.items()) 752 fields = ((k, str(v).replace('\n', '\n ')) for k, v in fields) 753 fields = ('"{}": {}'.format(k, v) for k, v in fields) 754 dict_repr = ',\n '.join(fields) 755 return ( 756 '<StructuredTensor(\n' 757 ' fields={\n' 758 ' %s},\n' 759 ' shape=%s)>' % (dict_repr, self._shape)) 760 761 #============================================================================= 762 # Conversion 763 #============================================================================= 764 765 def to_pyval(self): 766 """Returns this StructuredTensor as a nested Python dict or list of dicts. 767 768 Converts this `StructuredTensor` to a nested python value: 769 770 * `StructTensors` with `rank=0` are converted into a dictionary, with an 771 entry for each field. Field names are used as keys and field values are 772 converted to python values. In particular: 773 774 * Scalar Tensor fields are converted to simple values (such as 775 `int` or `float` or `string`) 776 * Non-scalar Tensor fields and RaggedTensor fields are converted to 777 nested lists of simple values. 778 * StructuredTensor fields are converted recursively using `to_pyval`. 779 780 * `StructTensors` with `rank>0` are converted to nested python `list`s, 781 containing one dictionary for each structure (where each structure's 782 dictionary is defined as described above). 783 784 Requires that all fields are Eager tensors. 785 786 >>> StructuredTensor.from_fields( 787 ... {'a': [1, 2, 3]}, [3]).to_pyval() 788 [{'a': 1}, {'a': 2}, {'a': 3}] 789 790 Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`. 791 792 Returns: 793 A nested Python dict or list of dicts. 794 """ 795 if not self._is_eager(): 796 raise ValueError( 797 'StructuredTensor.to_pyval() is only supported in eager mode.') 798 799 # Convert each field value to a nested list. 800 result = {} 801 for (key, value) in self._fields.items(): 802 if isinstance(value, ops.EagerTensor): 803 value = value.numpy() 804 if isinstance(value, np.ndarray): 805 value = value.tolist() 806 elif isinstance(value, ragged_tensor.RaggedTensor): 807 value = value.to_list() 808 elif isinstance(value, StructuredTensor): 809 value = value.to_pyval() 810 # TODO(edloper): Throw an exception if value is an unexpected type. 811 result[key] = value 812 813 # If rank>0, then re-group each value from dict-of-list to list-of-dict. 814 if len(self._shape) > 0: # pylint: disable=g-explicit-length-test 815 if not result: # special-case for StructuredTensors w/ no fields. 816 return _empty_dict_pylist_from_row_partitions(self._row_partitions, 817 self._nrows) 818 return _pyval_field_major_to_node_major( 819 list(result.keys()), list(result.values()), self._shape.rank) 820 else: 821 return result 822 823 @classmethod 824 def from_pyval(cls, pyval, typespec=None): 825 """Constructs a StructuredTensor from a nested Python structure. 826 827 >>> StructuredTensor.from_pyval( 828 ... {'a': [1, 2, 3], 'b': [[4, 5], [6, 7]]}) 829 <StructuredTensor( 830 fields={ 831 "a": tf.Tensor([1 2 3], shape=(3,), dtype=int32), 832 "b": <tf.RaggedTensor [[4, 5], [6, 7]]>}, 833 shape=())> 834 835 Note that `StructuredTensor.from_pyval(pyval).to_pyval() == pyval`. 836 837 Args: 838 pyval: The nested Python structure that should be used to create the new 839 `StructuredTensor`. 840 typespec: A `StructuredTensorSpec` specifying the expected type for each 841 field. If not specified, then all nested dictionaries are turned into 842 StructuredTensors, and all nested lists are turned into Tensors (if 843 rank<2) or RaggedTensors (if rank>=2). 844 845 Returns: 846 A `StructuredTensor`. 847 """ 848 if isinstance(pyval, dict): 849 return cls._from_pydict(pyval, typespec) 850 elif isinstance(pyval, (list, tuple)): 851 keys = set() 852 rank = _pyval_find_struct_keys_and_depth(pyval, keys) 853 if rank is not None: 854 return cls._from_pylist_of_dict(pyval, keys, rank, typespec) 855 else: 856 return cls._from_pylist_of_value(pyval, typespec) 857 else: 858 return cls._from_pyscalar(pyval, typespec) 859 860 @classmethod 861 def _from_pydict(cls, pyval, typespec): 862 """Converts python dictionary `pyval` to a StructuredTensor with rank=0.""" 863 if typespec is None: 864 fields = dict((k, cls.from_pyval(v)) for (k, v) in pyval.items()) 865 else: 866 spec_shape = typespec._shape # pylint: disable=protected-access 867 field_specs = typespec._field_specs # pylint: disable=protected-access 868 if not (isinstance(typespec, StructuredTensorSpec) and 869 spec_shape.rank == 0 and set(pyval) == set(field_specs)): 870 raise ValueError('Value does not match typespec: %r vs %r' % 871 (pyval, typespec)) 872 fields = dict( 873 (k, cls.from_pyval(v, field_specs[k])) for (k, v) in pyval.items()) 874 return StructuredTensor.from_fields(fields=fields, shape=(), validate=False) 875 876 @classmethod 877 def _from_pylist_of_dict(cls, pyval, keys, rank, typespec): 878 """Converts python list `pyval` to a StructuredTensor with rank>1.""" 879 fields = dict((key, []) for key in keys) 880 for child in pyval: 881 _pyval_update_fields(child, fields, 1) 882 if typespec is None: 883 shape = tensor_shape.TensorShape([None] * rank) 884 for (key, target) in fields.items(): 885 fields[key] = cls.from_pyval(target) 886 else: 887 field_specs = typespec._field_specs # pylint: disable=protected-access 888 if ((not isinstance(typespec, StructuredTensorSpec)) or 889 (set(fields) - set(field_specs))): 890 raise ValueError('Value does not match typespec: %r vs %r' % 891 (pyval, typespec)) 892 shape = typespec._shape 893 if shape.rank < rank: 894 raise ValueError('Value does not match typespec (rank mismatch): ' 895 '%r vs %r' % (pyval, typespec)) 896 for (key, spec) in field_specs.items(): 897 fields[key] = cls.from_pyval(fields.get(key, []), spec) 898 return StructuredTensor.from_fields( 899 fields=fields, shape=shape, validate=False) 900 901 @classmethod 902 def _from_pylist_of_value(cls, pyval, typespec): 903 """Converts python list `pyval` to a Tensor or RaggedTensor with rank>1.""" 904 if typespec is None: 905 return ragged_factory_ops.constant(pyval) 906 elif isinstance(typespec, tensor_spec.TensorSpec): 907 result = constant_op.constant(pyval, typespec.dtype) 908 if not typespec.shape.is_compatible_with(result.shape): 909 raise ValueError('Value does not match typespec: %r vs %r' % 910 (typespec, pyval)) 911 return result 912 elif isinstance(typespec, ragged_tensor.RaggedTensorSpec): 913 # pylint: disable=protected-access 914 return ragged_factory_ops.constant( 915 pyval, 916 dtype=typespec._dtype, 917 ragged_rank=typespec._ragged_rank, 918 row_splits_dtype=typespec._row_splits_dtype, 919 inner_shape=typespec._shape[typespec._ragged_rank + 1:]) 920 elif isinstance(typespec, StructuredTensorSpec): 921 empty_rank = _pyval_empty_list_depth(pyval) 922 if empty_rank is None: 923 raise ValueError('Value does not match typespec: %r vs %r' % 924 (typespec, pyval)) 925 else: 926 return cls._from_pylist_of_dict(pyval, set(), empty_rank, typespec) 927 else: 928 raise ValueError('Value does not match typespec: %r vs %r' % 929 (typespec, pyval)) 930 931 @classmethod 932 def _from_pyscalar(cls, pyval, typespec): 933 """Converts python scalar value `pyval` to a Tensor.""" 934 if typespec is None: 935 return constant_op.constant(pyval) 936 else: 937 if not (isinstance(typespec, tensor_spec.TensorSpec) and 938 typespec.shape.rank == 0): 939 raise ValueError('Value does not match typespec: %r vs %r' % 940 (typespec, pyval)) 941 # TODO(edloper): Check that typespec.shape matches. 942 return constant_op.constant(pyval, typespec.dtype) 943 944 #============================================================================= 945 # Transforms 946 #============================================================================= 947 948 # TODO(edloper): Add a 'validate' option here? 949 # TODO(edloper): Unify nomenclature with RaggedTensor. Should RaggedTensor 950 # have a partition_outer_dimension method? 951 def partition_outer_dimension(self, row_partition): 952 """Partitions the outer dimension of this StructuredTensor. 953 954 Returns a new `StructuredTensor` with the same values as `self`, where 955 the outer dimension is partitioned into two (possibly ragged) dimensions. 956 Requires that this StructuredTensor have an outer dimension (i.e., 957 `self.shape.rank > 0`). 958 959 >>> st = StructuredTensor.from_pyval( 960 ... [{'foo': 12}, {'foo': 33}, {'foo': 99}]) 961 >>> partition = RowPartition.from_row_lengths([2, 0, 1]) 962 >>> st.partition_outer_dimension(partition) 963 <StructuredTensor( 964 fields={ 965 "foo": <tf.RaggedTensor [[12, 33], [], [99]]>}, 966 shape=(3, None))> 967 968 Args: 969 row_partition: A `RowPartition`. 970 971 Returns: 972 A `StructuredTensor` with rank `values.rank + 1`. 973 """ 974 if not isinstance(row_partition, RowPartition): 975 raise TypeError('row_partition must be a RowPartition.') 976 if self.shape.rank == 0: 977 raise ValueError('Shape %s must have rank at least 1' % self.shape) 978 return _partition_outer_dimension(self, row_partition) 979 980 def merge_dims(self, outer_axis, inner_axis): 981 """Merges outer_axis...inner_axis into a single dimension. 982 983 Returns a copy of this RaggedTensor with the specified range of dimensions 984 flattened into a single dimension, with elements in row-major order. 985 986 >>> st = StructuredTensor.from_pyval( 987 ... [[{'foo': 12}, {'foo': 33}], [], [{'foo': 99}]]) 988 >>> st.merge_dims(0, 1) 989 <StructuredTensor( 990 fields={ 991 "foo": tf.Tensor([12 33 99], shape=(3,), dtype=int32)}, 992 shape=(3,))> 993 994 Args: 995 outer_axis: `int`: The first dimension in the range of dimensions to 996 merge. May be negative (to index from the last dimension). 997 inner_axis: `int`: The last dimension in the range of dimensions to merge. 998 May be negative (to index from the last dimension). 999 1000 Returns: 1001 A copy of this tensor, with the specified dimensions merged into a 1002 single dimension. The shape of the returned tensor will be 1003 `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N` 1004 is the total number of slices in the merged dimensions. 1005 """ 1006 outer_axis = array_ops.get_positive_axis( 1007 outer_axis, 1008 self.shape.rank, 1009 axis_name='outer_axis', 1010 ndims_name='rank(self)') 1011 inner_axis = array_ops.get_positive_axis( 1012 inner_axis, 1013 self.shape.rank, 1014 axis_name='inner_axis', 1015 ndims_name='rank(self)') 1016 if not outer_axis <= inner_axis: 1017 raise ValueError('Expected outer_axis (%d) to be less than or equal to ' 1018 'inner_axis (%d)' % (outer_axis, inner_axis)) 1019 return _merge_dims(self, outer_axis, inner_axis) 1020 1021 #============================================================================= 1022 # Composite Tensor 1023 #============================================================================= 1024 1025 @property 1026 def _type_spec(self): 1027 return StructuredTensorSpec.from_value(self) 1028 1029 1030class StructuredTensorSpec(type_spec.BatchableTypeSpec): 1031 """Type specification for `StructuredTensor`s.""" 1032 1033 __slots__ = ['_shape', '_field_specs'] 1034 1035 def __init__(self, shape, field_specs): 1036 """Build a type specification for a StructuredTensor. 1037 1038 Args: 1039 shape: The shape of the StructuredTensor. shape.rank must not be None. 1040 field_specs: A dictionary mapping from field name to TypeSpec, specifying 1041 the tensor type used to encode each field. These TypeSpecs should 1042 specify the type of the entire field (including outer dimensions which 1043 correspond to `shape`). For example, if `shape=[2, 3]`, and field 'x' 1044 contains an int32 vector of size `10` for each structure, then 1045 `field_specs['x']` should be `tf.TensorSpec([2, 3, 10], tf.int32)`. 1046 """ 1047 shape = tensor_shape.as_shape(shape) 1048 1049 # Perform a few sanity checks on the inputs. 1050 if shape.rank is None: 1051 raise TypeError("StructuredTensor's shape must have known rank.") 1052 if not isinstance(field_specs, dict): 1053 raise TypeError('field_specs must be a dictionary.') 1054 for key, value in field_specs.items(): 1055 if not isinstance(key, str): 1056 raise TypeError('field_specs must be a dictionary with string keys.') 1057 if not isinstance(value, (StructuredTensorSpec, tensor_spec.TensorSpec, 1058 ragged_tensor.RaggedTensorSpec)): 1059 raise TypeError('field_specs must be a dictionary with ' 1060 'TypeSpec values.') 1061 1062 self._shape = shape 1063 self._field_specs = dict(field_specs) 1064 1065 @property 1066 def value_type(self): 1067 return StructuredTensor 1068 1069 def _to_components(self, value): 1070 return value._fields 1071 1072 def _from_components(self, components): 1073 return StructuredTensor.from_fields(components, self._shape, validate=False) 1074 1075 @property 1076 def _component_specs(self): 1077 return self._field_specs 1078 1079 @classmethod 1080 def from_value(cls, value): 1081 field_specs = dict((k, type_spec.type_spec_from_value(v)) 1082 for (k, v) in value._fields.items()) 1083 return cls(value.shape, field_specs) 1084 1085 def _serialize(self): 1086 return (self._shape, self._field_specs) 1087 1088 def _batch(self, batch_size): 1089 # pylint: disable=protected-access 1090 return StructuredTensorSpec( 1091 tensor_shape.TensorShape([batch_size]).concatenate(self._shape), 1092 dict((k, v._batch(batch_size)) for (k, v) in self._field_specs.items())) 1093 1094 def _unbatch(self): 1095 # pylint: disable=protected-access 1096 return StructuredTensorSpec( 1097 self._shape[1:], 1098 dict((k, v._unbatch()) for (k, v) in self._field_specs.items())) 1099 1100 @property 1101 def _flat_tensor_specs(self): 1102 # pylint: disable=protected-access 1103 result = [] 1104 for _, field_spec in sorted(self._field_specs.items(), key=lambda t: t[0]): 1105 result.extend(field_spec._flat_tensor_specs) 1106 return result 1107 1108 def _to_tensor_list(self, value): 1109 return self._to_tensor_list_internal(value, batched=False) 1110 1111 def _to_batched_tensor_list(self, value): 1112 return self._to_tensor_list_internal(value, batched=True) 1113 1114 def _from_compatible_tensor_list(self, tensor_list): 1115 # pylint: disable=protected-access 1116 fields = {} 1117 pos = 0 1118 for field_name, field_spec in sorted( 1119 self._field_specs.items(), key=lambda t: t[0]): 1120 num_tensors_for_field = len(field_spec._flat_tensor_specs) 1121 field_tensors = tensor_list[pos:pos + num_tensors_for_field] 1122 fields[field_name] = field_spec._from_compatible_tensor_list( 1123 field_tensors) 1124 pos += num_tensors_for_field 1125 return StructuredTensor.from_fields(fields, self._shape) 1126 1127 def _to_tensor_list_internal(self, value, batched): 1128 """Returns a dict whose entries are each field's (batched) tensor_list. 1129 1130 If a field is a StructuredTensor, then its entry will be a dict, 1131 recursively. 1132 1133 Args: 1134 value: A StructuredTensor (conforming to `self`). 1135 batched: A boolean. if True, produce `batched_tensor_list` for each field 1136 otherwise produce `tensor_list`. 1137 Returns: 1138 A dict. 1139 """ 1140 result = [] 1141 for field_name, field_spec in sorted( 1142 self._field_specs.items(), key=lambda t: t[0]): 1143 # pylint: disable=protected-access 1144 field_value = value._fields[field_name] 1145 if batched: 1146 result.extend(field_spec._to_batched_tensor_list(field_value)) 1147 else: 1148 result.extend(field_spec._to_tensor_list(field_value)) 1149 1150 return result 1151 1152# Regular expression used to determine whether a string is a valid field name. 1153# Note: we plan to relax (or possibly eliminate) this in the future; you 1154# should not rely on the fact that some field names are currently disallowed. 1155_FIELD_NAME_RE = re.compile('^[a-zA-Z][a-zA-Z0-9_]*$') 1156 1157 1158#============================================================================= 1159# Helper funtions 1160#============================================================================= 1161# TODO(edloper): Move some of these helpers to row_partition.py? 1162 1163 1164def _convert_to_structured_field_value(value): 1165 """Converts `value` to a Tensor, RaggedTensor, or StructuredTensor.""" 1166 if isinstance(value, 1167 (ops.Tensor, ragged_tensor.RaggedTensor, StructuredTensor)): 1168 return value 1169 elif ragged_tensor.is_ragged(value): 1170 return ragged_tensor.convert_to_tensor_or_ragged_tensor(value) 1171 else: 1172 try: 1173 return ops.convert_to_tensor(value) 1174 except (ValueError, TypeError): 1175 raise TypeError('Unexpected type for value in `fields`: %r' % value) 1176 1177 1178def _find_shape_dtype(fields, nrows, row_partitions): 1179 """Return a consistent dtype for fields, nrows, & row_partitions.""" 1180 shape_dtypes = set() 1181 for value in fields.values(): 1182 if isinstance(value, ragged_tensor.RaggedTensor): 1183 shape_dtypes.add(value.row_splits.dtype) 1184 elif isinstance(value, StructuredTensor) and value.rank > 0: 1185 shape_dtypes.add(value.nrows().dtype) 1186 if isinstance(nrows, ops.Tensor): 1187 shape_dtypes.add(nrows.dtype) 1188 if row_partitions is not None: 1189 for partition in row_partitions: 1190 shape_dtypes.add(partition.dtype) 1191 if len(shape_dtypes) > 1: 1192 raise ValueError('field values have incompatible row_partition dtypes.') 1193 elif shape_dtypes: 1194 return shape_dtypes.pop() 1195 else: 1196 return dtypes.int64 1197 1198 1199def _merge_nrows(nrows, static_nrows, value, dtype, validate): 1200 """Merges `nrows` with `nrows(value)`. 1201 1202 Checks that `value` has the expected number of rows (`nrows`), and returns 1203 `nrows`. If `validate` is true, then add validation ops that check that 1204 the `nrows` values match. 1205 1206 Args: 1207 nrows: scalar integer Tensor. 1208 static_nrows: tf.Dimension: static value of nrows, if known. 1209 value: Tensor or RaggedTensor or StructuredTensor 1210 dtype: dtype for `nrows`. 1211 validate: bool -- whether to add validation ops. 1212 1213 Returns: 1214 A tuple `(nrows, static_nrows)`. 1215 """ 1216 static_value_nrows = tensor_shape.dimension_at_index(value.shape, 0) 1217 if isinstance(value, ops.Tensor): 1218 value_nrows = array_ops.shape(value, out_type=dtype)[0] 1219 else: 1220 value_nrows = value.nrows() 1221 if nrows is None: 1222 nrows = value_nrows 1223 elif (static_value_nrows.value is not None and 1224 static_nrows.value is not None): 1225 if not static_value_nrows.is_compatible_with(static_nrows): 1226 raise ValueError('fields have incompatible nrows') 1227 nrows = value_nrows # No need to add an assertion op. 1228 elif validate: 1229 nrows = control_flow_ops.with_dependencies([ 1230 check_ops.assert_equal(nrows, value_nrows, 1231 message='fields have incompatible nrows') 1232 ], nrows) 1233 return nrows, static_nrows.merge_with(static_value_nrows) 1234 1235 1236def _merge_row_partitions(row_partitions, value, rank, dtype, validate): 1237 """Merges `row_partitions` with `row_partitions(value)`.""" 1238 if isinstance(value, ops.Tensor): 1239 value_row_partitions = _row_partitions_for_tensor(value, rank, dtype) 1240 1241 elif isinstance(value, ragged_tensor.RaggedTensor): 1242 value_row_partitions = _row_partitions_for_ragged_tensor(value, rank, dtype) 1243 1244 else: 1245 assert isinstance(value, StructuredTensor), type(value) 1246 value_row_partitions = value.row_partitions[:rank - 1] 1247 1248 assert len(value_row_partitions) == rank - 1 1249 if row_partitions is None: 1250 return tuple(value_row_partitions) 1251 else: 1252 return tuple([ 1253 p1.merge_precomputed_encodings(p2, validate) 1254 for (p1, p2) in zip(row_partitions, value_row_partitions) 1255 ]) 1256 1257 1258def _row_partitions_for_tensor(value, rank, dtype): 1259 """Returns the row partitions for a tf.Tensor.""" 1260 shape = array_ops.shape(value, out_type=dtype) 1261 return _row_partitions_for_uniform_shape(shape, rank) 1262 1263 1264def _row_partitions_for_ragged_tensor(value, rank, dtype): 1265 """Returns the row partitions for a tf.RaggedTensor.""" 1266 assert rank > 1 1267 value_row_partitions = value._nested_row_partitions[:rank - 1] # pylint: disable=protected-access 1268 if len(value_row_partitions) < (rank - 1): 1269 value_row_partitions += _row_partitions_for_tensor( 1270 value.flat_values, rank - len(value_row_partitions), dtype) 1271 assert len(value_row_partitions) == rank - 1 1272 return value_row_partitions 1273 1274 1275def _row_partitions_for_uniform_shape(shape, rank): 1276 """Returns row partitions for the given shape Tensor. 1277 1278 Args: 1279 shape: A vector describing a uniform shape. 1280 rank: The number of dimensions to generate row partitions for 1281 1282 Returns: 1283 A list of (rank-1) `RowPartition`s with uniform row length. 1284 """ 1285 shape_cumprod = math_ops.cumprod(shape[:rank]) 1286 # pylint: disable=g-complex-comprehension 1287 return tuple([ 1288 RowPartition.from_uniform_row_length( 1289 uniform_row_length=shape[i + 1], 1290 nvals=shape_cumprod[i + 1], 1291 nrows=shape_cumprod[i]) for i in range(rank - 1) 1292 ]) 1293 1294 1295def _pyval_field_major_to_node_major(keys, values, depth): 1296 """Regroup each field (k, v) from dict-of-list to list-of-dict. 1297 1298 Given a "field-major" encoding of the StructuredTensor (which maps each key to 1299 a single nested list containing the values for all structs), return a 1300 corresponding "node-major" encoding, consisting of a nested list of dicts. 1301 1302 Args: 1303 keys: The field names (list of string). Must not be empty. 1304 values: The field values (list of python values). Must have the same length 1305 as `keys`. 1306 depth: The list depth at which dictionaries should be created. 1307 1308 Returns: 1309 A nested list of dict, with depth `depth`. 1310 """ 1311 assert keys 1312 if depth == 0: 1313 return dict(zip(keys, values)) 1314 nvals = len(values[0]) 1315 assert all(nvals == len(values[i]) for i in range(1, len(values))) 1316 return [ 1317 _pyval_field_major_to_node_major(keys, value_slice, depth - 1) 1318 for value_slice in zip(*values) 1319 ] 1320 1321 1322def _empty_dict_pylist_from_row_partitions(row_partitions, nrows): 1323 """Returns a python list of empty dicts from the given row partitions. 1324 1325 Args: 1326 row_partitions: The row-partitions describing the ragged shape of the 1327 result. 1328 nrows: The number of rows in the outermost row-partition. (Or if 1329 `len(row_partitions)==0`, then the number of empty dicts to return.) 1330 1331 Returns: 1332 A nested python list whose leaves (if any) are empty python dicts. 1333 """ 1334 if not row_partitions: 1335 return [{} for _ in range(nrows)] 1336 else: 1337 values = _empty_dict_pylist_from_row_partitions( 1338 row_partitions[1:], row_partitions[0].row_splits()[-1]) 1339 splits = row_partitions[0].row_splits() 1340 return [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)] 1341 1342 1343def _pyval_find_struct_keys_and_depth(pyval, keys): 1344 """Finds the keys & depth of nested dictionaries in `pyval`. 1345 1346 Args: 1347 pyval: A nested structure of lists, tuples, and dictionaries. 1348 keys: (output parameter) A set, which will be updated with any keys that are 1349 found in the nested dictionaries. 1350 1351 Returns: 1352 The nesting depth of dictionaries in `pyval`, or `None` if `pyval` does 1353 not contain any dictionaries. 1354 Raises: 1355 ValueError: If dictionaries have inconsistent depth. 1356 """ 1357 if isinstance(pyval, dict): 1358 keys.update(pyval.keys()) 1359 return 0 1360 elif isinstance(pyval, (list, tuple)): 1361 depth = None 1362 for child in pyval: 1363 child_depth = _pyval_find_struct_keys_and_depth(child, keys) 1364 if child_depth is not None: 1365 if depth is None: 1366 depth = child_depth + 1 1367 elif depth != child_depth + 1: 1368 raise ValueError('Inconsistent depth of dictionaries') 1369 return depth 1370 else: 1371 return None 1372 1373 1374def _pyval_update_fields(pyval, fields, depth): 1375 """Append the field values from `pyval` to `fields`. 1376 1377 Args: 1378 pyval: A python `dict`, or nested list/tuple of `dict`, whose value(s) 1379 should be appended to `fields`. 1380 fields: A dictionary mapping string keys to field values. Field values 1381 extracted from `pyval` are appended to this dictionary's values. 1382 depth: The depth at which `pyval` should be appended to the field values. 1383 """ 1384 if not isinstance(pyval, (dict, list, tuple)): 1385 raise ValueError('Expected dict or nested list/tuple of dict') 1386 1387 for (key, target) in fields.items(): 1388 for _ in range(1, depth): 1389 target = target[-1] 1390 target.append(pyval[key] if isinstance(pyval, dict) else []) 1391 1392 if isinstance(pyval, (list, tuple)): 1393 for child in pyval: 1394 _pyval_update_fields(child, fields, depth + 1) 1395 1396 1397def _pyval_empty_list_depth(pyval): 1398 """Find the max depth for nested empty lists. 1399 1400 Args: 1401 pyval: A nested python list. 1402 1403 Returns: 1404 The maximum depth of empty lists in `pyval`, or None if `pyval` contains 1405 anything other than nested empty lists. 1406 """ 1407 if isinstance(pyval, list): 1408 if not pyval: 1409 return 1 1410 depths = [_pyval_empty_list_depth(v) for v in pyval] 1411 if any(depth is None for depth in depths): 1412 return None 1413 else: 1414 return max(depths) + 1 1415 else: 1416 return None 1417 1418 1419def _replace_row_partitions(value, new_partitions): 1420 """Updates `value` to use `new_partitions` as its (outer) row partitions. 1421 1422 This is used to ensure that all fields in a `StructuredTensor` use identical 1423 `RowPartition` objects for the shared dimensions. In particular, 1424 `StructuredTensor.from_fields` first merges all of the row partitions from 1425 any fields, and then replaces the outer row partitions of all fields with 1426 the merged row partitions (using this function). 1427 1428 Args: 1429 value: A `Tensor`, `RaggedTensor`, or `StructuredTensor`. 1430 new_partitions: A list of row-partitions that should be used by `value`. 1431 Must be equivalent to `value`'s current row partitions. 1432 1433 Returns: 1434 A value that is equivalent to `value`, where outer row partitions have been 1435 replaced by `new_partitions`. 1436 """ 1437 if isinstance(value, ops.Tensor) or not new_partitions: 1438 return value 1439 1440 elif isinstance(value, ragged_tensor.RaggedTensor): 1441 return ragged_tensor.RaggedTensor._from_row_partition( # pylint: disable=protected-access 1442 values=_replace_row_partitions(value.values, new_partitions[1:]), 1443 row_partition=new_partitions[0]) 1444 1445 else: 1446 assert isinstance(value, StructuredTensor) 1447 new_fields = dict((k, _replace_row_partitions(v, new_partitions)) 1448 for (k, v) in value._fields.items()) 1449 return StructuredTensor( 1450 fields=new_fields, 1451 shape=value.shape, 1452 nrows=value.nrows(), 1453 row_partitions=new_partitions + 1454 value.row_partitions[len(new_partitions):], 1455 internal=_structured_tensor_factory_key) 1456 1457 1458def _partition_outer_dimension(value, row_partition): 1459 """Partitions the outer dimension of `value` using `row_partitions`. 1460 1461 Examples: 1462 1463 >>> partition = RowPartition.from_row_lengths([2, 0, 1]) 1464 >>> _partition_outer_dimension(tf.constant([1, 2, 3]), partition) 1465 <tf.RaggedTensor [[1, 2], [], [3]]> 1466 1467 >>> struct_value = StructuredTensor.from_pyval( 1468 ... [{'x': 1}, {'x': 2}, {'x': 3}]) 1469 >>> _partition_outer_dimension(struct_value, partition) 1470 <StructuredTensor( 1471 fields={ 1472 "x": <tf.RaggedTensor [[1, 2], [], [3]]>}, 1473 shape=(3, None))> 1474 1475 Args: 1476 value: Tensor, RaggedTensor, or StructuredTensor 1477 row_partition: RowPartition 1478 1479 Returns: 1480 A value with the same type as `value`, where 1481 `result.rank = value.rank + 1`. 1482 """ 1483 is_ragged = row_partition.uniform_row_length() is None 1484 if isinstance(value, ops.Tensor) and not is_ragged: 1485 new_shape = array_ops.concat( 1486 [[row_partition.nrows(), 1487 row_partition.uniform_row_length()], 1488 array_ops.shape(value, out_type=row_partition.dtype)[1:]], 1489 axis=0) 1490 return array_ops.reshape(value, new_shape) 1491 elif isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)): 1492 return ragged_tensor.RaggedTensor._from_row_partition( # pylint: disable=protected-access 1493 value, row_partition) 1494 else: 1495 assert isinstance(value, StructuredTensor) 1496 nrows = row_partition.static_nrows 1497 ncols = row_partition.static_uniform_row_length 1498 shape = tensor_shape.TensorShape([nrows, ncols]).concatenate( 1499 value.shape[1:]) 1500 fields = dict((k, _partition_outer_dimension(v, row_partition)) 1501 for (k, v) in value._fields.items()) 1502 return StructuredTensor( 1503 fields, 1504 shape, 1505 row_partition.nrows(), (row_partition,) + value.row_partitions, 1506 internal=_structured_tensor_factory_key) 1507 1508 1509def _merge_dims(value, outer_axis, inner_axis): 1510 """Merges `outer_axis...inner_axis` of `value` into a single dimension.""" 1511 assert outer_axis < inner_axis 1512 if isinstance(value, (ops.Tensor, ragged_tensor.RaggedTensor)): 1513 return ragged_tensor.merge_dims(value, outer_axis, inner_axis) 1514 else: 1515 assert isinstance(value, StructuredTensor) 1516 1517 # Build the new fields. 1518 fields = dict((k, _merge_dims(v, outer_axis, inner_axis)) 1519 for (k, v) in value._fields.items()) 1520 1521 # Build the new shape. 1522 value_shape = value.shape 1523 shape = ( 1524 value_shape[:outer_axis] + 1525 [value_shape[outer_axis:inner_axis].num_elements()] + 1526 value_shape[inner_axis + 1:]) 1527 1528 # Build the new row_partitions & nrows 1529 if outer_axis == 0: 1530 if inner_axis == value.shape.rank - 1: 1531 partitions = () 1532 nrows = value.row_partitions[-1].nvals() 1533 else: 1534 partitions = value.row_partitions[inner_axis:] 1535 nrows = partitions[0].nrows() 1536 else: 1537 # Use tf.gather to merge row_splits from the merged row partitions. 1538 merged_splits = value.row_partitions[outer_axis - 1].row_splits() 1539 for dim in range(outer_axis, inner_axis): 1540 merged_splits = array_ops.gather(value.row_partitions[dim].row_splits(), 1541 merged_splits) 1542 1543 partitions = ( 1544 value.row_partitions[:outer_axis - 1] + 1545 (RowPartition.from_row_splits(merged_splits),) + 1546 value.row_partitions[inner_axis:]) 1547 nrows = partitions[0].nrows() 1548 1549 return StructuredTensor( 1550 fields, 1551 shape, 1552 nrows, 1553 partitions, 1554 internal=_structured_tensor_factory_key) 1555 1556 1557_structured_tensor_factory_key = object() # unique private object 1558 1559 1560def _normalize_field_name_to_tuple(name: 'FieldName') -> Sequence[str]: 1561 """FieldName can be given also as string, this normalizes it to a tuple.""" 1562 if isinstance(name, str): 1563 return (name,) 1564 if isinstance(name, list): 1565 return tuple(name) 1566 assert isinstance(name, tuple) 1567 return name 1568 1569 1570def _merge_dims_generic(source, outer, inner): 1571 """Merges outer_axis...inner_axis into a single dimension. 1572 1573 If outer == inner, this is a NOOP. If inner < outer, then this fials. 1574 If inner >= source.shape.rank, then the behavior is undefined. 1575 1576 Args: 1577 source: a tensor, ragged tensor, or structured tensor. 1578 outer: a python int, indicating the first dimension to compress 1579 (must be nonnegative). 1580 inner: a python int, indicating the first dimension to keep (of the tail) 1581 (must be nonnegative). 1582 1583 Returns: 1584 source with outer_axis...inner_axis merged into a single dimension. 1585 1586 """ 1587 if isinstance(source, StructuredTensor): 1588 return source.merge_dims(outer, inner) 1589 else: 1590 return ragged_tensor.merge_dims(source, outer, inner) 1591