# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Data Flow Operations."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import hashlib
import threading

import six

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
from tensorflow.python.util import deprecation
from tensorflow.python.util.compat import collections_abc
from tensorflow.python.util.tf_export import tf_export

# pylint: enable=wildcard-import


def _as_type_list(dtypes):
  """Convert dtypes to a list of types."""
  assert dtypes is not None
  if not (isinstance(dtypes, list) or isinstance(dtypes, tuple)):
    # We have a single type.
    return [dtypes]
  else:
    # We have a list or tuple of types.
    return list(dtypes)


def _as_shape_list(shapes,
                   dtypes,
                   unknown_dim_allowed=False,
                   unknown_rank_allowed=False):
  """Convert shapes to a list of tuples of int (or None)."""
  del dtypes
  if unknown_dim_allowed:
    if (not isinstance(shapes, collections_abc.Sequence) or not shapes or
        any(shape is None or isinstance(shape, int) for shape in shapes)):
      raise ValueError(
          "When providing partial shapes, a list of shapes must be provided.")
  if shapes is None:
    return None
  if isinstance(shapes, tensor_shape.TensorShape):
    shapes = [shapes]
  if not isinstance(shapes, (tuple, list)):
    raise TypeError(
        "shapes must be a TensorShape or a list or tuple of TensorShapes.")
  if all(shape is None or isinstance(shape, int) for shape in shapes):
    # We have a single shape.
    shapes = [shapes]
  shapes = [tensor_shape.as_shape(shape) for shape in shapes]
  if not unknown_dim_allowed:
    if any(not shape.is_fully_defined() for shape in shapes):
      raise ValueError("All shapes must be fully defined: %s" % shapes)
  if not unknown_rank_allowed:
    if any(shape.dims is None for shape in shapes):
      raise ValueError("All shapes must have a defined rank: %s" % shapes)

  return shapes


def _as_name_list(names, dtypes):
  if names is None:
    return None
  if not isinstance(names, (list, tuple)):
    names = [names]
  if len(names) != len(dtypes):
    raise ValueError("List of names must have the same length as the list "
                     "of dtypes")
  return list(names)


def _shape_common(s1, s2):
  """The greatest lower bound (ordered by specificity) TensorShape."""
  s1 = tensor_shape.TensorShape(s1)
  s2 = tensor_shape.TensorShape(s2)
  if s1.ndims is None or s2.ndims is None or s1.ndims != s2.ndims:
    return tensor_shape.unknown_shape()
  d = [
      d1 if d1 is not None and d1 == d2 else None
      for (d1, d2) in zip(s1.as_list(), s2.as_list())
  ]
  return tensor_shape.TensorShape(d)


# pylint: disable=protected-access
@tf_export("queue.QueueBase",
           v1=["queue.QueueBase", "io.QueueBase", "QueueBase"])
@deprecation.deprecated_endpoints(["io.QueueBase", "QueueBase"])
class QueueBase(object):
  """Base class for queue implementations.

  A queue is a TensorFlow data structure that stores tensors across
  multiple steps, and exposes operations that enqueue and dequeue
  tensors.

  Each queue element is a tuple of one or more tensors, where each
  tuple component has a static dtype, and may have a static shape. The
  queue implementations support versions of enqueue and dequeue that
  handle single elements, and versions that support enqueuing and
  dequeuing a batch of elements at once.

  See `tf.queue.FIFOQueue` and
  `tf.queue.RandomShuffleQueue` for concrete
  implementations of this class, and instructions on how to create
  them.
  """

  def __init__(self, dtypes, shapes, names, queue_ref):
    """Constructs a queue object from a queue reference.

    The two optional lists, `shapes` and `names`, must be of the same length
    as `dtypes` if provided. The values at a given index `i` indicate the
    shape and name to use for the corresponding queue component in `dtypes`.

    Args:
      dtypes: A list of types. The length of dtypes must equal the number
        of tensors in each element.
      shapes: Constraints on the shapes of tensors in an element:
        A list of shape tuples or None. This list is the same length
        as dtypes. If the shape of any tensor in the element is constrained,
        all must be; shapes can be None if the shapes should not be
        constrained.
      names: Optional list of names. If provided, the `enqueue()` and
        `dequeue()` methods will use dictionaries with these names as keys.
        Must be None or a list or tuple of the same length as `dtypes`.
      queue_ref: The queue reference, i.e. the output of the queue op.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    self._dtypes = dtypes
    if shapes is not None:
      if len(shapes) != len(dtypes):
        raise ValueError("Queue shapes must have the same length as dtypes")
      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
    else:
      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
    if names is not None:
      if len(names) != len(dtypes):
        raise ValueError("Queue names must have the same length as dtypes")
      self._names = names
    else:
      self._names = None
    self._queue_ref = queue_ref
    if isinstance(queue_ref, ops.EagerTensor):
      if context.context().scope_name:
        self._name = context.context().scope_name
      else:
        self._name = "Empty"
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          queue_ref, None)
    else:
      self._name = self._queue_ref.op.name.split("/")[-1]

  @staticmethod
  def from_list(index, queues):
    """Create a queue using the queue reference from `queues[index]`.

    Args:
      index: An integer scalar tensor that determines the input that gets
        selected.
      queues: A list of `QueueBase` objects.

    Returns:
      A `QueueBase` object.

    Raises:
      TypeError: When `queues` is not a list of `QueueBase` objects,
        or when the data types of `queues` are not all the same.
    """
    if ((not queues) or (not isinstance(queues, list)) or
        (not all(isinstance(x, QueueBase) for x in queues))):
      raise TypeError("A list of queues expected")

    dtypes = queues[0].dtypes
    if not all(dtypes == q.dtypes for q in queues[1:]):
      raise TypeError("Queues do not have matching component dtypes.")

    names = queues[0].names
    if not all(names == q.names for q in queues[1:]):
      raise TypeError("Queues do not have matching component names.")

    queue_shapes = [q.shapes for q in queues]
    reduced_shapes = [
        six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)
    ]

    queue_refs = array_ops.stack([x.queue_ref for x in queues])
    selected_queue = array_ops.gather(queue_refs, index)
    return QueueBase(
        dtypes=dtypes,
        shapes=reduced_shapes,
        names=names,
        queue_ref=selected_queue)

  @property
  def queue_ref(self):
    """The underlying queue reference."""
    return self._queue_ref

  @property
  def name(self):
    """The name of the underlying queue."""
    if context.executing_eagerly():
      return self._name
    return self._queue_ref.op.name

  @property
  def dtypes(self):
    """The list of dtypes for each component of a queue element."""
    return self._dtypes

  @property
  def shapes(self):
    """The list of shapes for each component of a queue element."""
    return self._shapes

  @property
  def names(self):
    """The list of names for each component of a queue element."""
    return self._names

  def _check_enqueue_dtypes(self, vals):
    """Validate and convert `vals` to a list of `Tensor`s.

    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
    dictionary with tensor values.

    If it is a dictionary, the queue must have been constructed with a
    `names` attribute and the dictionary keys must match the queue names.
    If the queue was constructed with a `names` attribute, `vals` must
    be a dictionary.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.

    Returns:
      A list of `Tensor` objects.

    Raises:
      ValueError: If `vals` is invalid.
    """
    if isinstance(vals, dict):
      if not self._names:
        raise ValueError("Queue must have names to enqueue a dictionary")
      if sorted(self._names, key=str) != sorted(vals.keys(), key=str):
        raise ValueError("Keys in dictionary to enqueue do not match "
                         "names of Queue. Dictionary: (%s), Queue: (%s)" %
                         (sorted(vals.keys()), sorted(self._names)))
      # The order of values in `self._names` indicates the order in which the
      # tensors in the dictionary `vals` must be listed.
      vals = [vals[k] for k in self._names]
    else:
      if self._names:
        raise ValueError("You must enqueue a dictionary in a Queue with names")
      if not isinstance(vals, (list, tuple)):
        vals = [vals]

    tensors = []
    for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
      tensors.append(
          ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i))

    return tensors

  def _scope_vals(self, vals):
    """Return a list of values to pass to `name_scope()`.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.

    Returns:
      The values in vals as a list.
    """
    if isinstance(vals, (list, tuple)):
      return vals
    elif isinstance(vals, dict):
      return vals.values()
    else:
      return [vals]

  def enqueue(self, vals, name=None):
    """Enqueues one element to this queue.

    If the queue is full when this operation executes, it will block
    until the element has been enqueued.

    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
    queue is closed before this operation runs,
    `tf.errors.CancelledError` will be raised. If this operation is
    blocked, and either (i) the queue is closed by a close operation
    with `cancel_pending_enqueues=True`, or (ii) the session is
    closed (see `tf.Session.close`),
    `tf.errors.CancelledError` will be raised.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary containing
        the values to enqueue.
      name: A name for the operation (optional).

    Returns:
      The operation that enqueues a new tuple of tensors to the queue.
    """
    with ops.name_scope(name, "%s_enqueue" % self._name,
                        self._scope_vals(vals)) as scope:
      vals = self._check_enqueue_dtypes(vals)

      # NOTE(mrry): Not using a shape function because we need access to
      # the `QueueBase` object.
      for val, shape in zip(vals, self._shapes):
        val.get_shape().assert_is_compatible_with(shape)

      if self._queue_ref.dtype == _dtypes.resource:
        return gen_data_flow_ops.queue_enqueue_v2(
            self._queue_ref, vals, name=scope)
      else:
        return gen_data_flow_ops.queue_enqueue(
            self._queue_ref, vals, name=scope)

  def enqueue_many(self, vals, name=None):
    """Enqueues zero or more elements to this queue.

    This operation slices each component tensor along the 0th dimension to
    make multiple queue elements. All of the tensors in `vals` must have the
    same size in the 0th dimension.

    If the queue is full when this operation executes, it will block
    until all of the elements have been enqueued.

    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
    queue is closed before this operation runs,
    `tf.errors.CancelledError` will be raised. If this operation is
    blocked, and either (i) the queue is closed by a close operation
    with `cancel_pending_enqueues=True`, or (ii) the session is
    closed (see `tf.Session.close`),
    `tf.errors.CancelledError` will be raised.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary
        from which the queue elements are taken.
      name: A name for the operation (optional).

    Returns:
      The operation that enqueues a batch of tuples of tensors to the queue.
    """
    with ops.name_scope(name, "%s_EnqueueMany" % self._name,
                        self._scope_vals(vals)) as scope:
      vals = self._check_enqueue_dtypes(vals)

      # NOTE(mrry): Not using a shape function because we need access to
      # the `QueueBase` object.
      # NOTE(fchollet): the code that follows is verbose because it needs to
      # be compatible with both TF v1 TensorShape behavior and TF v2 behavior.
      batch_dim = tensor_shape.dimension_value(
          vals[0].get_shape().with_rank_at_least(1)[0])
      batch_dim = tensor_shape.Dimension(batch_dim)
      for val, shape in zip(vals, self._shapes):
        val_batch_dim = tensor_shape.dimension_value(
            val.get_shape().with_rank_at_least(1)[0])
        val_batch_dim = tensor_shape.Dimension(val_batch_dim)
        batch_dim = batch_dim.merge_with(val_batch_dim)
        val.get_shape()[1:].assert_is_compatible_with(shape)

      return gen_data_flow_ops.queue_enqueue_many_v2(
          self._queue_ref, vals, name=scope)

  def _dequeue_return_value(self, tensors):
    """Return the value to return from a dequeue op.

    If the queue has names, return a dictionary with the
    names as keys. Otherwise return either a single tensor
    or a list of tensors depending on the length of `tensors`.

    Args:
      tensors: List of tensors from the dequeue op.

    Returns:
      A single tensor, a list of tensors, or a dictionary
      of tensors.
    """
    if self._names:
      # The returned values in `tensors` are in the same order as
      # the names in `self._names`.
      return {n: tensors[i] for i, n in enumerate(self._names)}
    elif len(tensors) == 1:
      return tensors[0]
    else:
      return tensors

  def dequeue(self, name=None):
    """Dequeues one element from this queue.

    If the queue is empty when this operation executes, it will block
    until there is an element to dequeue.

    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
    queue is closed, the queue is empty, and there are no pending
    enqueue operations that can fulfill this request,
    `tf.errors.OutOfRangeError` will be raised. If the session is
    closed (see `tf.Session.close`),
    `tf.errors.CancelledError` will be raised.

    Args:
      name: A name for the operation (optional).

    Returns:
      The tuple of tensors that was dequeued.
    """
    if name is None:
      name = "%s_Dequeue" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      ret = gen_data_flow_ops.queue_dequeue_v2(
          self._queue_ref, self._dtypes, name=name)
    else:
      ret = gen_data_flow_ops.queue_dequeue(
          self._queue_ref, self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the `QueueBase` object.
    if not context.executing_eagerly():
      op = ret[0].op
      for output, shape in zip(op.values(), self._shapes):
        output.set_shape(shape)

    return self._dequeue_return_value(ret)

  def dequeue_many(self, n, name=None):
    """Dequeues and concatenates `n` elements from this queue.

    This operation concatenates queue-element component tensors along
    the 0th dimension to make a single component tensor. All of the
    components in the dequeued tuple will have size `n` in the 0th dimension.

    If the queue is closed and there are fewer than `n` elements left, then an
    `OutOfRange` exception is raised.

    At runtime, this operation may raise an error if the queue is
    closed (see `tf.QueueBase.close`) before or during its execution. If the
    queue is closed, the queue contains fewer than `n` elements, and
    there are no pending enqueue operations that can fulfill this
    request, `tf.errors.OutOfRangeError` will be raised. If the
    session is closed (see `tf.Session.close`),
    `tf.errors.CancelledError` will be raised.

    Args:
      n: A scalar `Tensor` containing the number of elements to dequeue.
      name: A name for the operation (optional).

    Returns:
      The list of concatenated tensors that was dequeued.
    """
    if name is None:
      name = "%s_DequeueMany" % self._name

    ret = gen_data_flow_ops.queue_dequeue_many_v2(
        self._queue_ref, n=n, component_types=self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Queue object.
    if not context.executing_eagerly():
      op = ret[0].op
      batch_dim = tensor_shape.Dimension(
          tensor_util.constant_value(op.inputs[1]))
      for output, shape in zip(op.values(), self._shapes):
        output.set_shape(
            tensor_shape.TensorShape([batch_dim]).concatenate(shape))

    return self._dequeue_return_value(ret)

  def dequeue_up_to(self, n, name=None):
    """Dequeues and concatenates `n` elements from this queue.

    **Note** This operation is not supported by all queues. If a queue does not
    support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised.

    This operation concatenates queue-element component tensors along
    the 0th dimension to make a single component tensor. If the queue
    has not been closed, all of the components in the dequeued tuple
    will have size `n` in the 0th dimension.

    If the queue is closed and there are more than `0` but fewer than
    `n` elements remaining, then instead of raising a
    `tf.errors.OutOfRangeError` like `tf.QueueBase.dequeue_many`,
    fewer than `n` elements are returned immediately. If the queue is
    closed and there are `0` elements left in the queue, then a
    `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
    Otherwise the behavior is identical to `dequeue_many`.

    Args:
      n: A scalar `Tensor` containing the number of elements to dequeue.
      name: A name for the operation (optional).

    Returns:
      The tuple of concatenated tensors that was dequeued.
    """
    if name is None:
      name = "%s_DequeueUpTo" % self._name

    ret = gen_data_flow_ops.queue_dequeue_up_to_v2(
        self._queue_ref, n=n, component_types=self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Queue object.
    if not context.executing_eagerly():
      op = ret[0].op
      for output, shape in zip(op.values(), self._shapes):
        output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))

    return self._dequeue_return_value(ret)

  def close(self, cancel_pending_enqueues=False, name=None):
    """Closes this queue.

    This operation signals that no more elements will be enqueued in
    the given queue. Subsequent `enqueue` and `enqueue_many`
    operations will fail. Subsequent `dequeue` and `dequeue_many`
    operations will continue to succeed if sufficient elements remain
    in the queue. Subsequent `dequeue` and `dequeue_many` operations
    that would otherwise block waiting for more elements (if close
    hadn't been called) will now fail immediately.

    If `cancel_pending_enqueues` is `True`, all pending requests will also
    be canceled.

    Args:
      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
        `False` (described above).
      name: A name for the operation (optional).

    Returns:
      The operation that closes the queue.
    """
    if name is None:
      name = "%s_Close" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      return gen_data_flow_ops.queue_close_v2(
          self._queue_ref,
          cancel_pending_enqueues=cancel_pending_enqueues,
          name=name)
    else:
      return gen_data_flow_ops.queue_close(
          self._queue_ref,
          cancel_pending_enqueues=cancel_pending_enqueues,
          name=name)

  def is_closed(self, name=None):
    """Returns true if queue is closed.

    This operation returns true if the queue is closed and false if the queue
    is open.

    Args:
      name: A name for the operation (optional).

    Returns:
      True if the queue is closed and false if the queue is open.
    """
    if name is None:
      name = "%s_Is_Closed" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      return gen_data_flow_ops.queue_is_closed_v2(self._queue_ref, name=name)
    else:
      return gen_data_flow_ops.queue_is_closed_(self._queue_ref, name=name)

  def size(self, name=None):
    """Compute the number of elements in this queue.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this queue.
    """
    if name is None:
      name = "%s_Size" % self._name
    if self._queue_ref.dtype == _dtypes.resource:
      return gen_data_flow_ops.queue_size_v2(self._queue_ref, name=name)
    else:
      return gen_data_flow_ops.queue_size(self._queue_ref, name=name)


def _shared_name(shared_name):
  if context.executing_eagerly():
    return str(ops.uid())
  return shared_name


@tf_export(
    "queue.RandomShuffleQueue",
    v1=["queue.RandomShuffleQueue",
        "io.RandomShuffleQueue", "RandomShuffleQueue"])
@deprecation.deprecated_endpoints(
    ["io.RandomShuffleQueue", "RandomShuffleQueue"])
class RandomShuffleQueue(QueueBase):
  """A queue implementation that dequeues elements in a random order.

  See `tf.queue.QueueBase` for a description of the methods on
  this class.
  """

  def __init__(self,
               capacity,
               min_after_dequeue,
               dtypes,
               shapes=None,
               names=None,
               seed=None,
               shared_name=None,
               name="random_shuffle_queue"):
    """Create a queue that dequeues elements in a random order.

    A `RandomShuffleQueue` has bounded capacity; supports multiple
    concurrent producers and consumers; and provides exactly-once
    delivery.
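
    For example, a minimal graph-mode sketch (illustrative only; it assumes
    eager execution is disabled and a `tf.compat.v1.Session` is available)
    might look like:

      q = tf.queue.RandomShuffleQueue(
          capacity=10, min_after_dequeue=2, dtypes=[tf.int32], shapes=[()])
      enqueue = q.enqueue_many([list(range(10))])
      dequeue = q.dequeue()
      with tf.compat.v1.Session() as sess:
        sess.run(enqueue)
        print(sess.run(dequeue))  # One of 0..9, chosen at random.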

    A `RandomShuffleQueue` holds a list of up to `capacity`
    elements. Each element is a fixed-length tuple of tensors whose
    dtypes are described by `dtypes`, and whose shapes are optionally
    described by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    The `min_after_dequeue` argument allows the caller to specify a
    minimum number of elements that will remain in the queue after a
    `dequeue` or `dequeue_many` operation completes, to ensure a
    minimum level of mixing of elements. This invariant is maintained
    by blocking those operations until sufficient elements have been
    enqueued. The `min_after_dequeue` argument is ignored after the
    queue has been closed.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      min_after_dequeue: An integer (described above).
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects
        with the same length as `dtypes`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `dtypes`, or `None`. If specified, the dequeue
        methods return a dictionary with the names as keys.
      seed: A Python integer. Used to create a random seed. See
        `tf.compat.v1.set_random_seed`
        for behavior.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    names = _as_name_list(names, dtypes)
    seed1, seed2 = random_seed.get_seed(seed)
    if seed1 is None and seed2 is None:
      seed1, seed2 = 0, 0
    elif seed is None and shared_name is not None:
      # This means that graph seed is provided but op seed is not provided.
      # If shared_name is also provided, make seed2 depend only on the graph
      # seed and shared_name. (seed2 from get_seed() is generally dependent on
      # the id of the last op created.)
      string = (str(seed1) + shared_name).encode("utf-8")
      seed2 = int(hashlib.md5(string).hexdigest()[:8], 16) & 0x7FFFFFFF
    queue_ref = gen_data_flow_ops.random_shuffle_queue_v2(
        component_types=dtypes,
        shapes=shapes,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        seed=seed1,
        seed2=seed2,
        shared_name=_shared_name(shared_name),
        name=name)

    super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)


@tf_export("queue.FIFOQueue", v1=["queue.FIFOQueue", "FIFOQueue"])
@deprecation.deprecated_endpoints("FIFOQueue")
class FIFOQueue(QueueBase):
  """A queue implementation that dequeues elements in first-in first-out order.

  See `tf.queue.QueueBase` for a description of the methods on
  this class.
  """

  def __init__(self,
               capacity,
               dtypes,
               shapes=None,
               names=None,
               shared_name=None,
               name="fifo_queue"):
    """Creates a queue that dequeues elements in a first-in first-out order.

    A `FIFOQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.
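
    For example, a minimal graph-mode sketch (illustrative only; it assumes
    eager execution is disabled and a `tf.compat.v1.Session` is available)
    might look like:

      q = tf.queue.FIFOQueue(capacity=3, dtypes=[tf.float32], shapes=[()])
      enqueue = q.enqueue_many([[1.0, 2.0, 3.0]])
      dequeue = q.dequeue()
      with tf.compat.v1.Session() as sess:
        sess.run(enqueue)
        print(sess.run(dequeue))  # 1.0 on the first run, then 2.0, then 3.0.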

    A `FIFOQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `dtypes`, and whose shapes are optionally described
    by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects
        with the same length as `dtypes`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `dtypes`, or `None`. If specified, the dequeue
        methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    names = _as_name_list(names, dtypes)
    with ops.init_scope(), ops.device("CPU"):
      queue_ref = gen_data_flow_ops.fifo_queue_v2(
          component_types=dtypes,
          shapes=shapes,
          capacity=capacity,
          shared_name=_shared_name(shared_name),
          name=name)

    super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)


# TODO(allenl): If GPU-compatible queues turn out to be useful, we should
# implement GPU kernels for EnqueueMany and DequeueMany so we can make the
# public FIFOQueue GPU-compatible and remove this internal version.
class GPUCompatibleFIFOQueue(QueueBase):
  """A queue implementation that dequeues elements in first-in first-out order.

  GPUCompatibleFIFOQueue is like FIFOQueue, but the queue resource may be
  placed either on a CPU or on a GPU. It is not cross-device: enqueues and
  dequeues will be colocated with the queue resource. GPUCompatibleFIFOQueue
  only supports enqueue and dequeue at the moment, not enqueue_many or
  dequeue_many.

  See `tf.queue.QueueBase` for a description of the methods on this class.
  """

  def __init__(self,
               capacity,
               dtypes,
               shapes=None,
               names=None,
               shared_name=None,
               name="fifo_queue"):
    """Creates a queue that dequeues elements in a first-in first-out order.

    A `FIFOQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `FIFOQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `dtypes`, and whose shapes are optionally described
    by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects
        with the same length as `dtypes`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `dtypes`, or `None`. If specified, the dequeue
        methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    names = _as_name_list(names, dtypes)
    with ops.init_scope():
      queue_ref = gen_data_flow_ops.fifo_queue_v2(
          component_types=dtypes,
          shapes=shapes,
          capacity=capacity,
          shared_name=_shared_name(shared_name),
          name=name)

    super(GPUCompatibleFIFOQueue, self).__init__(
        dtypes, shapes, names, queue_ref)

  def enqueue_many(self, vals, name=None):
    """enqueue_many is not supported on GPUCompatibleFIFOQueue."""
    raise NotImplementedError(
        "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, "
        "only enqueue and dequeue.")

  def dequeue_many(self, n, name=None):
    """dequeue_many is not supported on GPUCompatibleFIFOQueue."""
    raise NotImplementedError(
        "GPUCompatibleFIFOQueue does not support enqueue_many or dequeue_many, "
        "only enqueue and dequeue.")


@tf_export(
    "queue.PaddingFIFOQueue",
    v1=["queue.PaddingFIFOQueue", "io.PaddingFIFOQueue", "PaddingFIFOQueue"])
@deprecation.deprecated_endpoints(["io.PaddingFIFOQueue", "PaddingFIFOQueue"])
class PaddingFIFOQueue(QueueBase):
  """A FIFOQueue that supports batching variable-sized tensors by padding.

  A `PaddingFIFOQueue` may contain components with dynamic shape, while also
  supporting `dequeue_many`. See the constructor for more details.

  See `tf.queue.QueueBase` for a description of the methods on
  this class.
  """

  def __init__(self,
               capacity,
               dtypes,
               shapes,
               names=None,
               shared_name=None,
               name="padding_fifo_queue"):
    """Creates a queue that dequeues elements in a first-in first-out order.

    A `PaddingFIFOQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `PaddingFIFOQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `dtypes`, and whose shapes are described by the `shapes`
    argument.

    The `shapes` argument must be specified; each component of a queue
    element must have the respective shape. Shapes of fixed
    rank but variable size are allowed by setting any shape dimension to None.
    In this case, the inputs' shape may vary along the given dimension, and
    `dequeue_many` will pad the given dimension with zeros up to the maximum
    shape of all elements in the given batch.

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      dtypes: A list of `DType` objects. The length of `dtypes` must equal
        the number of tensors in each queue element.
      shapes: A list of `TensorShape` objects, with the same length as
        `dtypes`. Any dimension in the `TensorShape` containing value
        `None` is dynamic and allows values to be enqueued with
        variable size in that dimension.
      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `dtypes`, or `None`. If specified, the dequeue
        methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.

    Raises:
      ValueError: If shapes is not a list of shapes, or the lengths of dtypes
        and shapes do not match, or if names is specified and the lengths of
        dtypes and names do not match.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes, unknown_dim_allowed=True)
    names = _as_name_list(names, dtypes)
    if len(dtypes) != len(shapes):
      raise ValueError("Shapes must be provided for all components, "
                       "but received %d dtypes and %d shapes." % (len(dtypes),
                                                                  len(shapes)))

    queue_ref = gen_data_flow_ops.padding_fifo_queue_v2(
        component_types=dtypes,
        shapes=shapes,
        capacity=capacity,
        shared_name=_shared_name(shared_name),
        name=name)

    super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)


@tf_export("queue.PriorityQueue",
           v1=["queue.PriorityQueue", "io.PriorityQueue", "PriorityQueue"])
@deprecation.deprecated_endpoints(["io.PriorityQueue", "PriorityQueue"])
class PriorityQueue(QueueBase):
  """A queue implementation that dequeues elements in prioritized order.

  See `tf.queue.QueueBase` for a description of the methods on
  this class.
  """

  def __init__(self,
               capacity,
               types,
               shapes=None,
               names=None,
               shared_name=None,
               name="priority_queue"):
    """Creates a queue that dequeues elements in a prioritized order.

    A `PriorityQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `PriorityQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `types`, and whose shapes are optionally described
    by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    Enqueues and Dequeues to the `PriorityQueue` must include an additional
    tuple entry at the beginning: the `priority`. The priority must be
    an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`).

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      types: A list of `DType` objects. The length of `types` must equal
        the number of tensors in each queue element, excluding the first
        priority element. The first tensor in each element is the priority,
        which must be of type int64.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects,
        with the same length as `types`, or `None`.
      names: (Optional.) A list of strings naming the components in the queue
        with the same length as `dtypes`, or `None`. If specified, the dequeue
        methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.
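
    For example, the following sketch (illustrative only; it assumes graph
    mode and a `tf.compat.v1.Session`) enqueues two elements out of order and
    dequeues them by ascending priority:

      q = tf.queue.PriorityQueue(capacity=10, types=[tf.string], shapes=[()])
      with tf.compat.v1.Session() as sess:
        sess.run(q.enqueue((tf.constant(2, tf.int64), "second")))
        sess.run(q.enqueue((tf.constant(1, tf.int64), "first")))
        print(sess.run(q.dequeue()))  # [1, b'first']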
    """
    types = _as_type_list(types)
    shapes = _as_shape_list(shapes, types)

    queue_ref = gen_data_flow_ops.priority_queue_v2(
        component_types=types,
        shapes=shapes,
        capacity=capacity,
        shared_name=_shared_name(shared_name),
        name=name)

    priority_dtypes = [_dtypes.int64] + types
    priority_shapes = [()] + shapes if shapes else shapes

    super(PriorityQueue, self).__init__(priority_dtypes, priority_shapes, names,
                                        queue_ref)


# TODO(josh11b): class BatchQueue(QueueBase):


class Barrier(object):
  """Represents a key-value map that persists across graph executions."""

  def __init__(self, types, shapes=None, shared_name=None, name="barrier"):
    """Creates a barrier that persists across different graph executions.

    A barrier represents a key-value map, where each key is a string, and
    each value is a tuple of tensors.

    At runtime, the barrier contains 'complete' and 'incomplete'
    elements. A complete element has defined tensors for all
    components of its value tuple, and may be accessed using
    take_many. An incomplete element has some undefined components in
    its value tuple, and may be updated using insert_many.

    The barrier call `take_many` outputs values in a particular order.
    First, it only outputs completed values. Second, the order in which
    completed values are returned matches the order in which their very
    first component was inserted into the barrier. So, for example, for this
    sequence of insertions and removals:

      barrier = Barrier((tf.string, tf.int32), shapes=((), ()))
      barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run()
      barrier.insert_many(1, keys=["k1"], values=[1]).run()
      barrier.insert_many(0, keys=["k3"], values=["c"]).run()
      barrier.insert_many(1, keys=["k3"], values=[3]).run()
      barrier.insert_many(1, keys=["k2"], values=[2]).run()

      (indices, keys, values) = barrier.take_many(2)
      (indices_val, keys_val, values0_val, values1_val) = (
          session.run([indices, keys, values[0], values[1]]))

    The output will be (up to permutation of "k1" and "k2"):

      indices_val == (-2**63, -2**63)
      keys_val == ("k1", "k2")
      values0_val == ("a", "b")
      values1_val == (1, 2)

    Note the key "k2" was inserted into the barrier before "k3". Even though
    "k3" was completed first, both are complete by the time
    take_many is called. As a result, "k2" is prioritized and "k1" and "k2"
    are returned first. "k3" remains in the barrier until the next execution
    of `take_many`. Since "k1" and "k2" had their first insertions into
    the barrier together, their indices are the same (-2**63). The index
    of "k3" will be -2**63 + 1, because it was the next new inserted key.

    Args:
      types: A single dtype or a tuple of dtypes, corresponding to the
        dtypes of the tensor elements that comprise a value in this barrier.
      shapes: Optional. Constraints on the shapes of tensors in the values:
        a single tensor shape tuple; a tuple of tensor shape tuples
        for each barrier-element tuple component; or None if the shape should
        not be constrained.
      shared_name: Optional. If non-empty, this barrier will be shared under
        the given name across multiple sessions.
      name: Optional name for the barrier op.

    Raises:
      ValueError: If one of the `shapes` indicates no elements.
    """
    self._types = _as_type_list(types)

    if shapes is not None:
      shapes = _as_shape_list(shapes, self._types)
      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
      for i, shape in enumerate(self._shapes):
        if shape.num_elements() == 0:
          raise ValueError("Empty tensors are not supported, but received "
                           "shape '%s' at index %d" % (shape, i))
    else:
      self._shapes = [tensor_shape.unknown_shape() for _ in self._types]

    self._barrier_ref = gen_data_flow_ops.barrier(
        component_types=self._types,
        shapes=self._shapes,
        shared_name=shared_name,
        name=name)
    if context.executing_eagerly():
      self._name = context.context().scope_name
    else:
      self._name = self._barrier_ref.op.name.split("/")[-1]

  @property
  def barrier_ref(self):
    """Get the underlying barrier reference."""
    return self._barrier_ref

  @property
  def name(self):
    """The name of the underlying barrier."""
    if context.executing_eagerly():
      return self._name
    return self._barrier_ref.op.name

  def insert_many(self, component_index, keys, values, name=None):
    """For each key, assigns the respective value to the specified component.

    This operation updates each element at component_index.

    Args:
      component_index: The component of the value that is being assigned.
      keys: A vector of keys, with length n.
      values: An any-dimensional tensor of values, which are associated with
        the respective keys. The first dimension must have length n.
      name: Optional name for the op.

    Returns:
      The operation that performs the insertion.

    Raises:
      InvalidArgumentsError: If inserting keys and values without elements.
    """
    if name is None:
      name = "%s_BarrierInsertMany" % self._name
    return gen_data_flow_ops.barrier_insert_many(
        self._barrier_ref, keys, values, component_index, name=name)

  def take_many(self,
                num_elements,
                allow_small_batch=False,
                timeout=None,
                name=None):
    """Takes the given number of completed elements from this barrier.

    This operation concatenates completed-element component tensors along
    the 0th dimension to make a single component tensor.

    If the barrier has no completed elements, this operation will block
    until there are 'num_elements' elements to take.

    TODO(b/25743580): the semantics of `allow_small_batch` are experimental
    and may be extended to other cases in the future.

    TODO(ebrevdo): If a take_many(allow_small_batch=True) is blocking
    already when the barrier is closed, it will block forever. Fix this
    by using asynchronous operations.

    Args:
      num_elements: The number of elements to take.
      allow_small_batch: If the barrier is closed, don't block if there are
        fewer completed elements than requested, but instead return all
        available completed elements.
      timeout: This specifies the number of milliseconds to block
        before returning with DEADLINE_EXCEEDED. (This option is not
        supported yet.)
      name: A name for the operation (optional).

    Returns:
      A tuple of (index, key, value_list).
      "index" is an int64 tensor of length num_elements containing the
      index of the insert_many call for which the very first component of
      the given element was inserted into the Barrier, starting with
      the value -2**63. Note, this value is different from the
      index of the insert_many call for which the element was completed.
      "key" is a string tensor of length num_elements containing the keys.
      "value_list" is a tuple of tensors, each one with size num_elements
      in the 0th dimension for each component in the barrier's values.

    """
    if name is None:
      name = "%s_BarrierTakeMany" % self._name
    ret = gen_data_flow_ops.barrier_take_many(
        self._barrier_ref,
        num_elements,
        self._types,
        allow_small_batch,
        timeout,
        name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Barrier object.
    if not context.executing_eagerly():
      op = ret[0].op
      if allow_small_batch:
        batch_dim = None
      else:
        batch_dim = tensor_shape.Dimension(
            tensor_util.constant_value(op.inputs[1]))
      op.outputs[0].set_shape(tensor_shape.TensorShape([batch_dim]))  # indices
      op.outputs[1].set_shape(tensor_shape.TensorShape([batch_dim]))  # keys
      for output, shape in zip(op.outputs[2:], self._shapes):  # value_list
        output.set_shape(
            tensor_shape.TensorShape([batch_dim]).concatenate(shape))

    return ret

  def close(self, cancel_pending_enqueues=False, name=None):
    """Closes this barrier.

    This operation signals that no more new key values will be inserted in the
    given barrier. Subsequent InsertMany operations with new keys will fail.
    InsertMany operations that just complement already existing keys with
    other components will continue to succeed. Subsequent TakeMany operations
    will continue to succeed if sufficient elements remain in the barrier.
    Subsequent TakeMany operations that would block will fail immediately.

    If `cancel_pending_enqueues` is `True`, all pending requests to the
    underlying queue will also be canceled, and completing already started
    values is no longer accepted.

    Args:
      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
        `False` (described above).
      name: Optional name for the op.

    Returns:
      The operation that closes the barrier.
    """
    if name is None:
      name = "%s_BarrierClose" % self._name
    return gen_data_flow_ops.barrier_close(
        self._barrier_ref,
        cancel_pending_enqueues=cancel_pending_enqueues,
        name=name)

  def ready_size(self, name=None):
    """Compute the number of complete elements in the given barrier.

    Args:
      name: A name for the operation (optional).

    Returns:
      A single-element tensor containing the number of complete elements in
      the given barrier.
    """
    if name is None:
      name = "%s_BarrierReadySize" % self._name
    return gen_data_flow_ops.barrier_ready_size(self._barrier_ref, name=name)

  def incomplete_size(self, name=None):
    """Compute the number of incomplete elements in the given barrier.

    Args:
      name: A name for the operation (optional).

    Returns:
      A single-element tensor containing the number of incomplete elements in
      the given barrier.
    """
    if name is None:
      name = "%s_BarrierIncompleteSize" % self._name
    return gen_data_flow_ops.barrier_incomplete_size(
        self._barrier_ref, name=name)


@tf_export(v1=["ConditionalAccumulatorBase"])
class ConditionalAccumulatorBase(object):
  """A conditional accumulator for aggregating gradients.

  Up-to-date gradients (i.e., time step at which gradient was computed is
  equal to the accumulator's time step) are added to the accumulator.

  Extraction of the average gradient is blocked until the required number of
  gradients has been accumulated.
  """

  def __init__(self, dtype, shape, accumulator_ref):
    """Creates a new ConditionalAccumulator.

    Args:
      dtype: Datatype of the accumulated gradients.
      shape: Shape of the accumulated gradients.
      accumulator_ref: A handle to the conditional accumulator, created by
        subclasses.
    """
    self._dtype = dtype
    if shape is not None:
      self._shape = tensor_shape.TensorShape(shape)
    else:
      self._shape = tensor_shape.unknown_shape()
    self._accumulator_ref = accumulator_ref
    if context.executing_eagerly():
      self._name = context.context().scope_name
    else:
      self._name = self._accumulator_ref.op.name.split("/")[-1]

  @property
  def accumulator_ref(self):
    """The underlying accumulator reference."""
    return self._accumulator_ref

  @property
  def name(self):
    """The name of the underlying accumulator."""
    return self._name

  @property
  def dtype(self):
    """The datatype of the gradients accumulated by this accumulator."""
    return self._dtype

  def num_accumulated(self, name=None):
    """Number of gradients that have currently been aggregated in accumulator.

    Args:
      name: Optional name for the operation.

    Returns:
      Number of accumulated gradients currently in accumulator.
    """
    if name is None:
      name = "%s_NumAccumulated" % self._name

    return gen_data_flow_ops.resource_accumulator_num_accumulated(
        self._accumulator_ref, name=name)

  def set_global_step(self, new_global_step, name=None):
    """Sets the global time step of the accumulator.

    The operation logs a warning if we attempt to set to a time step that is
    lower than the accumulator's own time step.

    Args:
      new_global_step: Value of new time step. Can be a variable or a constant.
      name: Optional name for the operation.

    Returns:
      Operation that sets the accumulator's time step.
    """
    return gen_data_flow_ops.resource_accumulator_set_global_step(
        self._accumulator_ref,
        math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64),
        name=name)


@tf_export(v1=["ConditionalAccumulator"])
class ConditionalAccumulator(ConditionalAccumulatorBase):
  """A conditional accumulator for aggregating gradients.

  Up-to-date gradients (i.e., time step at which gradient was computed is
  equal to the accumulator's time step) are added to the accumulator.

  Extraction of the average gradient is blocked until the required number of
  gradients has been accumulated.
  """

  def __init__(self,
               dtype,
               shape=None,
               shared_name=None,
               name="conditional_accumulator",
               reduction_type="MEAN"):
    """Creates a new ConditionalAccumulator.

    Args:
      dtype: Datatype of the accumulated gradients.
      shape: Shape of the accumulated gradients.
      shared_name: Optional. If non-empty, this accumulator will be shared
        under the given name across multiple sessions.
      name: Optional name for the accumulator.
      reduction_type: Reduction type to use when taking the gradient.
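
    For example, a minimal graph-mode sketch (illustrative only; it assumes a
    `tf.compat.v1.Session`) that averages two applied gradients:

      acc = tf.compat.v1.ConditionalAccumulator(dtype=tf.float32, shape=())
      with tf.compat.v1.Session() as sess:
        sess.run(acc.apply_grad(2.0, local_step=0))
        sess.run(acc.apply_grad(4.0, local_step=0))
        print(sess.run(acc.take_grad(num_required=2)))  # 3.0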
    """
    accumulator_ref = gen_data_flow_ops.resource_conditional_accumulator(
        dtype=dtype,
        shape=shape,
        shared_name=shared_name,
        name=name,
        reduction_type=reduction_type)
    if context.executing_eagerly():
      self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
          handle=accumulator_ref, handle_device=context.context().device_name)

    super(ConditionalAccumulator, self).__init__(dtype, shape, accumulator_ref)

  def apply_grad(self, grad, local_step=0, name=None):
    """Attempts to apply a gradient to the accumulator.

    The attempt is silently dropped if the gradient is stale, i.e., local_step
    is less than the accumulator's global time step.

    Args:
      grad: The gradient tensor to be applied.
      local_step: Time step at which the gradient was computed.
      name: Optional name for the operation.

    Returns:
      The operation that (conditionally) applies a gradient to the accumulator.

    Raises:
      ValueError: If grad is of the wrong shape
    """
    grad = ops.convert_to_tensor(grad, self._dtype)
    grad.get_shape().assert_is_compatible_with(self._shape)
    local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)

    return gen_data_flow_ops.resource_accumulator_apply_gradient(
        self._accumulator_ref, local_step=local_step, gradient=grad, name=name)

  def take_grad(self, num_required, name=None):
    """Attempts to extract the average gradient from the accumulator.

    The operation blocks until a sufficient number of gradients have been
    successfully applied to the accumulator.

    Once successful, the following actions are also triggered:

    - Counter of accumulated gradients is reset to 0.
    - Aggregated gradient is reset to 0 tensor.
    - Accumulator's internal time step is incremented by 1.

    Args:
      num_required: Number of gradients that need to have been aggregated.
      name: Optional name for the operation.

    Returns:
      A tensor holding the value of the average gradient.

    Raises:
      InvalidArgumentError: If num_required < 1
    """
    out = gen_data_flow_ops.resource_accumulator_take_gradient(
        self._accumulator_ref, num_required, dtype=self._dtype, name=name)
    out.set_shape(self._shape)
    return out


@tf_export(
    v1=["sparse.SparseConditionalAccumulator", "SparseConditionalAccumulator"])
class SparseConditionalAccumulator(ConditionalAccumulatorBase):
  """A conditional accumulator for aggregating sparse gradients.

  Sparse gradients are represented by `IndexedSlices`.

  Up-to-date gradients (i.e., time step at which gradient was computed is
  equal to the accumulator's time step) are added to the accumulator.

  Extraction of the average gradient is blocked until the required number of
  gradients has been accumulated.

  Args:
    dtype: Datatype of the accumulated gradients.
    shape: Shape of the accumulated gradients.
    shared_name: Optional. If non-empty, this accumulator will be shared under
      the given name across multiple sessions.
    name: Optional name for the accumulator.
    reduction_type: Reduction type to use when taking the gradient.
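
  For example, a minimal graph-mode sketch (illustrative only; it assumes a
  `tf.compat.v1.Session`) that applies one sparse gradient and extracts it:

    acc = tf.compat.v1.SparseConditionalAccumulator(
        dtype=tf.float32, shape=tf.TensorShape([2, 2]))
    with tf.compat.v1.Session() as sess:
      sess.run(acc.apply_grad(grad_indices=[0],
                              grad_values=[[1.0, 2.0]],
                              local_step=0))
      indices, values, shape = sess.run(acc.take_grad(num_required=1))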
  """

  def __init__(self,
               dtype,
               shape=None,
               shared_name=None,
               name="sparse_conditional_accumulator",
               reduction_type="MEAN"):
    accumulator_ref = gen_data_flow_ops.sparse_conditional_accumulator(
        dtype=dtype,
        shape=shape,
        shared_name=shared_name,
        name=name,
        reduction_type=reduction_type)
    super(SparseConditionalAccumulator, self).__init__(dtype, shape,
                                                       accumulator_ref)

  def apply_indexed_slices_grad(self, grad, local_step=0, name=None):
    """Attempts to apply a gradient to the accumulator.

    The attempt is silently dropped if the gradient is stale, i.e., `local_step`
    is less than the accumulator's global time step.

    Args:
      grad: The gradient `IndexedSlices` to be applied.
      local_step: Time step at which the gradient was computed.
      name: Optional name for the operation.

    Returns:
      The operation that (conditionally) applies a gradient to the accumulator.

    Raises:
      InvalidArgumentError: If grad is of the wrong shape
    """
    return self.apply_grad(
        grad_indices=grad.indices,
        grad_values=grad.values,
        grad_shape=grad.dense_shape,
        local_step=local_step,
        name=name)

  def apply_grad(self,
                 grad_indices,
                 grad_values,
                 grad_shape=None,
                 local_step=0,
                 name=None):
    """Attempts to apply a sparse gradient to the accumulator.

    The attempt is silently dropped if the gradient is stale, i.e., `local_step`
    is less than the accumulator's global time step.

    A sparse gradient is represented by its indices, values and possibly empty
    or None shape. Indices must be a vector representing the locations of
    non-zero entries in the tensor. Values are the non-zero slices of the
    gradient, and must have the same first dimension as indices, i.e., the nnz
    represented by indices and values must be consistent. Shape, if not empty
    or None, must be consistent with the accumulator's shape (if also
    provided).

    Example:
      A tensor [[0, 0], [0, 1], [2, 3]] can be represented
        indices: [1,2]
        values: [[0,1],[2,3]]
        shape: [3, 2]

    Args:
      grad_indices: Indices of the sparse gradient to be applied.
      grad_values: Values of the sparse gradient to be applied.
      grad_shape: Shape of the sparse gradient to be applied.
      local_step: Time step at which the gradient was computed.
      name: Optional name for the operation.

    Returns:
      The operation that (conditionally) applies a gradient to the accumulator.

    Raises:
      InvalidArgumentError: If grad is of the wrong shape
    """
    local_step = math_ops.cast(ops.convert_to_tensor(local_step), _dtypes.int64)
    return gen_data_flow_ops.sparse_accumulator_apply_gradient(
        self._accumulator_ref,
        local_step=local_step,
        gradient_indices=math_ops.cast(grad_indices, _dtypes.int64),
        gradient_values=grad_values,
        gradient_shape=math_ops.cast(
            [] if grad_shape is None else grad_shape, _dtypes.int64),
        has_known_shape=(grad_shape is not None),
        name=name)

  def take_grad(self, num_required, name=None):
    """Attempts to extract the average gradient from the accumulator.

    The operation blocks until a sufficient number of gradients have been
    successfully applied to the accumulator.

    Once successful, the following actions are also triggered:
    - Counter of accumulated gradients is reset to 0.
1527 - Aggregated gradient is reset to 0 tensor. 1528 - Accumulator's internal time step is incremented by 1. 1529 1530 Args: 1531 num_required: Number of gradients that needs to have been aggregated 1532 name: Optional name for the operation 1533 1534 Returns: 1535 A tuple of indices, values, and shape representing the average gradient. 1536 1537 Raises: 1538 InvalidArgumentError: If `num_required` < 1 1539 """ 1540 return gen_data_flow_ops.sparse_accumulator_take_gradient( 1541 self._accumulator_ref, num_required, dtype=self._dtype, name=name) 1542 1543 def take_indexed_slices_grad(self, num_required, name=None): 1544 """Attempts to extract the average gradient from the accumulator. 1545 1546 The operation blocks until sufficient number of gradients have been 1547 successfully applied to the accumulator. 1548 1549 Once successful, the following actions are also triggered: 1550 - Counter of accumulated gradients is reset to 0. 1551 - Aggregated gradient is reset to 0 tensor. 1552 - Accumulator's internal time step is incremented by 1. 1553 1554 Args: 1555 num_required: Number of gradients that needs to have been aggregated 1556 name: Optional name for the operation 1557 1558 Returns: 1559 An `IndexedSlices` holding the value of the average gradient. 1560 1561 Raises: 1562 InvalidArgumentError: If `num_required` < 1 1563 """ 1564 return_val = gen_data_flow_ops.sparse_accumulator_take_gradient( 1565 self._accumulator_ref, num_required, dtype=self._dtype, name=name) 1566 return ops.IndexedSlices( 1567 indices=return_val.indices, 1568 values=return_val.values, 1569 dense_shape=return_val.shape) 1570 1571 # SparseConditionalAccumulator is not switched to resource. Use old kernels. 1572 def num_accumulated(self, name=None): 1573 """Number of gradients that have currently been aggregated in accumulator. 1574 1575 Args: 1576 name: Optional name for the operation. 1577 1578 Returns: 1579 Number of accumulated gradients currently in accumulator. 1580 """ 1581 if name is None: 1582 name = "%s_NumAccumulated" % self._name 1583 1584 return gen_data_flow_ops.accumulator_num_accumulated( 1585 self._accumulator_ref, name=name) 1586 1587 def set_global_step(self, new_global_step, name=None): 1588 """Sets the global time step of the accumulator. 1589 1590 The operation logs a warning if we attempt to set to a time step that is 1591 lower than the accumulator's own time step. 1592 1593 Args: 1594 new_global_step: Value of new time step. Can be a variable or a constant 1595 name: Optional name for the operation. 1596 1597 Returns: 1598 Operation that sets the accumulator's time step. 
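
    For example (an illustrative sketch; assumes TF1-style graph execution,
    and the accumulator and variable names are made up):

      import tensorflow as tf

      acc = tf.compat.v1.SparseConditionalAccumulator(
          dtype=tf.float32, shape=[2])
      global_step = tf.compat.v1.train.get_or_create_global_step()

      # Gradients applied with a local_step below the accumulator's time step
      # are dropped, so keeping the time step in sync with the training
      # global_step invalidates gradients from stale workers.
      sync_op = acc.set_global_step(global_step)

      with tf.compat.v1.Session() as sess:
        sess.run(tf.compat.v1.global_variables_initializer())
        sess.run(sync_op)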
1599 """ 1600 return gen_data_flow_ops.accumulator_set_global_step( 1601 self._accumulator_ref, 1602 math_ops.cast(ops.convert_to_tensor(new_global_step), _dtypes.int64), 1603 name=name) 1604 1605 1606class BaseStagingArea(object): 1607 """Base class for Staging Areas.""" 1608 _identifier = 0 1609 _lock = threading.Lock() 1610 1611 def __init__(self, 1612 dtypes, 1613 shapes=None, 1614 names=None, 1615 shared_name=None, 1616 capacity=0, 1617 memory_limit=0): 1618 if shared_name is None: 1619 self._name = ( 1620 ops.get_default_graph().unique_name(self.__class__.__name__)) 1621 elif isinstance(shared_name, six.string_types): 1622 self._name = shared_name 1623 else: 1624 raise ValueError("shared_name must be a string") 1625 1626 self._dtypes = dtypes 1627 1628 if shapes is not None: 1629 if len(shapes) != len(dtypes): 1630 raise ValueError("StagingArea shapes must be the same length as dtypes") 1631 self._shapes = [tensor_shape.TensorShape(s) for s in shapes] 1632 else: 1633 self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes] 1634 1635 if names is not None: 1636 if len(names) != len(dtypes): 1637 raise ValueError("StagingArea names must be the same length as dtypes") 1638 self._names = names 1639 else: 1640 self._names = None 1641 1642 self._capacity = capacity 1643 self._memory_limit = memory_limit 1644 1645 # all get and put ops must colocate with this op 1646 with ops.name_scope("%s_root" % self._name): 1647 self._coloc_op = control_flow_ops.no_op() 1648 1649 @property 1650 def name(self): 1651 """The name of the staging area.""" 1652 return self._name 1653 1654 @property 1655 def dtypes(self): 1656 """The list of dtypes for each component of a staging area element.""" 1657 return self._dtypes 1658 1659 @property 1660 def shapes(self): 1661 """The list of shapes for each component of a staging area element.""" 1662 return self._shapes 1663 1664 @property 1665 def names(self): 1666 """The list of names for each component of a staging area element.""" 1667 return self._names 1668 1669 @property 1670 def capacity(self): 1671 """The maximum number of elements of this staging area.""" 1672 return self._capacity 1673 1674 @property 1675 def memory_limit(self): 1676 """The maximum number of bytes of this staging area.""" 1677 return self._memory_limit 1678 1679 def _check_put_dtypes(self, vals, indices=None): 1680 """Validate and convert `vals` to a list of `Tensor`s. 1681 1682 The `vals` argument can be a Tensor, a list or tuple of tensors, or a 1683 dictionary with tensor values. 1684 1685 If `vals` is a list, then the appropriate indices associated with the 1686 values must be provided. 1687 1688 If it is a dictionary, the staging area must have been constructed with a 1689 `names` attribute and the dictionary keys must match the staging area names. 1690 `indices` will be inferred from the dictionary keys. 1691 If the staging area was constructed with a `names` attribute, `vals` must 1692 be a dictionary. 1693 1694 Checks that the dtype and shape of each value matches that 1695 of the staging area. 1696 1697 Args: 1698 vals: A tensor, a list or tuple of tensors, or a dictionary. 1699 1700 Returns: 1701 A (tensors, indices) tuple where `tensors` is a list of `Tensor` objects 1702 and `indices` is a list of indices associated with the tensors. 1703 1704 Raises: 1705 ValueError: If `vals` or `indices` is invalid. 
1706 """ 1707 if isinstance(vals, dict): 1708 if not self._names: 1709 raise ValueError( 1710 "Staging areas must have names to enqueue a dictionary") 1711 if not set(vals.keys()).issubset(self._names): 1712 raise ValueError("Keys in dictionary to put do not match names " 1713 "of staging area. Dictionary: (%s), Queue: (%s)" % 1714 (sorted(vals.keys()), sorted(self._names))) 1715 # The order of values in `self._names` indicates the order in which the 1716 # tensors in the dictionary `vals` must be listed. 1717 vals, indices, _ = zip(*[(vals[k], i, k) 1718 for i, k in enumerate(self._names) 1719 if k in vals]) 1720 else: 1721 if self._names: 1722 raise ValueError("You must enqueue a dictionary in a staging area " 1723 "with names") 1724 1725 if indices is None: 1726 raise ValueError("Indices must be supplied when inserting a list " 1727 "of tensors") 1728 1729 if len(indices) != len(vals): 1730 raise ValueError("Number of indices '%s' doesn't match " 1731 "number of values '%s'") 1732 1733 if not isinstance(vals, (list, tuple)): 1734 vals = [vals] 1735 indices = [0] 1736 1737 # Sanity check number of values 1738 if not len(vals) <= len(self._dtypes): 1739 raise ValueError("Unexpected number of inputs '%s' vs '%s'" % 1740 (len(vals), len(self._dtypes))) 1741 1742 tensors = [] 1743 1744 for val, i in zip(vals, indices): 1745 dtype, shape = self._dtypes[i], self._shapes[i] 1746 # Check dtype 1747 if val.dtype != dtype: 1748 raise ValueError("Datatypes do not match. '%s' != '%s'" % 1749 (str(val.dtype), str(dtype))) 1750 1751 # Check shape 1752 val.get_shape().assert_is_compatible_with(shape) 1753 1754 tensors.append( 1755 ops.convert_to_tensor(val, dtype=dtype, name="component_%d" % i)) 1756 1757 return tensors, indices 1758 1759 def _create_device_transfers(self, tensors): 1760 """Encode inter-device transfers if the current device 1761 is not the same as the Staging Area's device. 1762 """ 1763 1764 if not isinstance(tensors, (tuple, list)): 1765 tensors = [tensors] 1766 1767 curr_device_scope = control_flow_ops.no_op().device 1768 1769 if curr_device_scope != self._coloc_op.device: 1770 tensors = [array_ops.identity(t) for t in tensors] 1771 1772 return tensors 1773 1774 def _get_return_value(self, tensors, indices): 1775 """Return the value to return from a get op. 1776 1777 If the staging area has names, return a dictionary with the 1778 names as keys. Otherwise return either a single tensor 1779 or a list of tensors depending on the length of `tensors`. 1780 1781 Args: 1782 tensors: List of tensors from the get op. 1783 indices: Indices of associated names and shapes 1784 1785 Returns: 1786 A single tensor, a list of tensors, or a dictionary 1787 of tensors. 1788 """ 1789 1790 tensors = self._create_device_transfers(tensors) 1791 1792 # Sets shape 1793 for output, i in zip(tensors, indices): 1794 output.set_shape(self._shapes[i]) 1795 1796 if self._names: 1797 # The returned values in `tensors` are in the same order as 1798 # the names in `self._names`. 1799 return {self._names[i]: t for t, i in zip(tensors, indices)} 1800 return tensors 1801 1802 def _scope_vals(self, vals): 1803 """Return a list of values to pass to `name_scope()`. 1804 1805 Args: 1806 vals: A tensor, a list or tuple of tensors, or a dictionary. 1807 1808 Returns: 1809 The values in vals as a list. 
1810 """ 1811 if isinstance(vals, (list, tuple)): 1812 return vals 1813 elif isinstance(vals, dict): 1814 return vals.values() 1815 else: 1816 return [vals] 1817 1818 1819class StagingArea(BaseStagingArea): 1820 """Class for staging inputs. No ordering guarantees. 1821 1822 A `StagingArea` is a TensorFlow data structure that stores tensors across 1823 multiple steps, and exposes operations that can put and get tensors. 1824 1825 Each `StagingArea` element is a tuple of one or more tensors, where each 1826 tuple component has a static dtype, and may have a static shape. 1827 1828 The capacity of a `StagingArea` may be bounded or unbounded. 1829 It supports multiple concurrent producers and consumers; and 1830 provides exactly-once delivery. 1831 1832 Each element of a `StagingArea` is a fixed-length tuple of tensors whose 1833 dtypes are described by `dtypes`, and whose shapes are optionally described 1834 by the `shapes` argument. 1835 1836 If the `shapes` argument is specified, each component of a staging area 1837 element must have the respective fixed shape. If it is 1838 unspecified, different elements may have different shapes, 1839 1840 It can be configured with a capacity in which case 1841 put(values) will block until space becomes available. 1842 1843 Similarly, it can be configured with a memory limit which 1844 will block put(values) until space is available. 1845 This is mostly useful for limiting the number of tensors on 1846 devices such as GPUs. 1847 1848 All get() and peek() commands block if the requested data 1849 is not present in the Staging Area. 1850 1851 """ 1852 1853 def __init__(self, 1854 dtypes, 1855 shapes=None, 1856 names=None, 1857 shared_name=None, 1858 capacity=0, 1859 memory_limit=0): 1860 """Constructs a staging area object. 1861 1862 The two optional lists, `shapes` and `names`, must be of the same length 1863 as `dtypes` if provided. The values at a given index `i` indicate the 1864 shape and name to use for the corresponding queue component in `dtypes`. 1865 1866 The device scope at the time of object creation determines where the 1867 storage for the `StagingArea` will reside. Calls to `put` will incur a copy 1868 to this memory space, if necessary. Tensors returned by `get` will be 1869 placed according to the device scope when `get` is called. 1870 1871 Args: 1872 dtypes: A list of types. The length of dtypes must equal the number 1873 of tensors in each element. 1874 shapes: (Optional.) Constraints on the shapes of tensors in an element. 1875 A list of shape tuples or None. This list is the same length 1876 as dtypes. If the shape of any tensors in the element are constrained, 1877 all must be; shapes can be None if the shapes should not be constrained. 1878 names: (Optional.) If provided, the `get()` and 1879 `put()` methods will use dictionaries with these names as keys. 1880 Must be None or a list or tuple of the same length as `dtypes`. 1881 shared_name: (Optional.) A name to be used for the shared object. By 1882 passing the same name to two different python objects they will share 1883 the underlying staging area. Must be a string. 1884 capacity: (Optional.) Maximum number of elements. 1885 An integer. If zero, the Staging Area is unbounded 1886 memory_limit: (Optional.) Maximum number of bytes of all tensors 1887 in the Staging Area. 1888 An integer. If zero, the Staging Area is unbounded 1889 1890 Raises: 1891 ValueError: If one of the arguments is invalid. 
1892 """ 1893 1894 super(StagingArea, self).__init__(dtypes, shapes, names, shared_name, 1895 capacity, memory_limit) 1896 1897 def put(self, values, name=None): 1898 """Create an op that places a value into the staging area. 1899 1900 This operation will block if the `StagingArea` has reached 1901 its capacity. 1902 1903 Args: 1904 values: A single tensor, a list or tuple of tensors, or a dictionary with 1905 tensor values. The number of elements must match the length of the 1906 list provided to the dtypes argument when creating the StagingArea. 1907 name: A name for the operation (optional). 1908 1909 Returns: 1910 The created op. 1911 1912 Raises: 1913 ValueError: If the number or type of inputs don't match the staging area. 1914 """ 1915 with ops.name_scope(name, "%s_put" % self._name, 1916 self._scope_vals(values)) as scope: 1917 1918 if not isinstance(values, (list, tuple, dict)): 1919 values = [values] 1920 1921 # Hard-code indices for this staging area 1922 indices = list(six.moves.range(len(values))) 1923 vals, _ = self._check_put_dtypes(values, indices) 1924 1925 with ops.colocate_with(self._coloc_op): 1926 op = gen_data_flow_ops.stage( 1927 values=vals, 1928 shared_name=self._name, 1929 name=scope, 1930 capacity=self._capacity, 1931 memory_limit=self._memory_limit) 1932 1933 return op 1934 1935 def __internal_get(self, get_fn, name): 1936 with ops.colocate_with(self._coloc_op): 1937 ret = get_fn() 1938 1939 indices = list(six.moves.range(len(self._dtypes))) # Hard coded 1940 return self._get_return_value(ret, indices) 1941 1942 def get(self, name=None): 1943 """Gets one element from this staging area. 1944 1945 If the staging area is empty when this operation executes, it will block 1946 until there is an element to dequeue. 1947 1948 Note that unlike others ops that can block, like the queue Dequeue 1949 operations, this can stop other work from happening. To avoid this, the 1950 intended use is for this to be called only when there will be an element 1951 already available. One method for doing this in a training loop would be to 1952 run a `put()` call during a warmup session.run call, and then call both 1953 `get()` and `put()` in each subsequent step. 1954 1955 The placement of the returned tensor will be determined by the current 1956 device scope when this function is called. 1957 1958 Args: 1959 name: A name for the operation (optional). 1960 1961 Returns: 1962 The tuple of tensors that was gotten. 1963 """ 1964 if name is None: 1965 name = "%s_get" % self._name 1966 1967 # pylint: disable=bad-continuation 1968 fn = lambda: gen_data_flow_ops.unstage(dtypes=self._dtypes, 1969 shared_name=self._name, name=name, 1970 capacity=self._capacity, 1971 memory_limit=self._memory_limit) 1972 # pylint: enable=bad-continuation 1973 1974 return self.__internal_get(fn, name) 1975 1976 def peek(self, index, name=None): 1977 """Peeks at an element in the staging area. 1978 1979 If the staging area is too small to contain the element at 1980 the specified index, it will block until enough elements 1981 are inserted to complete the operation. 1982 1983 The placement of the returned tensor will be determined by 1984 the current device scope when this function is called. 1985 1986 Args: 1987 index: The index of the tensor within the staging area 1988 to look up. 1989 name: A name for the operation (optional). 1990 1991 Returns: 1992 The tuple of tensors that was gotten. 
1993 """ 1994 if name is None: 1995 name = "%s_peek" % self._name 1996 1997 # pylint: disable=bad-continuation 1998 fn = lambda: gen_data_flow_ops.stage_peek(index, 1999 dtypes=self._dtypes, shared_name=self._name, 2000 name=name, capacity=self._capacity, 2001 memory_limit=self._memory_limit) 2002 # pylint: enable=bad-continuation 2003 2004 return self.__internal_get(fn, name) 2005 2006 def size(self, name=None): 2007 """Returns the number of elements in the staging area. 2008 2009 Args: 2010 name: A name for the operation (optional) 2011 2012 Returns: 2013 The created op 2014 """ 2015 if name is None: 2016 name = "%s_size" % self._name 2017 2018 return gen_data_flow_ops.stage_size( 2019 name=name, 2020 shared_name=self._name, 2021 dtypes=self._dtypes, 2022 capacity=self._capacity, 2023 memory_limit=self._memory_limit) 2024 2025 def clear(self, name=None): 2026 """Clears the staging area. 2027 2028 Args: 2029 name: A name for the operation (optional) 2030 2031 Returns: 2032 The created op 2033 """ 2034 if name is None: 2035 name = "%s_clear" % self._name 2036 2037 return gen_data_flow_ops.stage_clear( 2038 name=name, 2039 shared_name=self._name, 2040 dtypes=self._dtypes, 2041 capacity=self._capacity, 2042 memory_limit=self._memory_limit) 2043 2044 2045class MapStagingArea(BaseStagingArea): 2046 """A `MapStagingArea` is a TensorFlow data structure that stores tensors 2047 across multiple steps, and exposes operations that can put and get tensors. 2048 2049 Each `MapStagingArea` element is a (key, value) pair. 2050 Only int64 keys are supported, other types should be 2051 hashed to produce a key. 2052 Values are a tuple of one or more tensors. 2053 Each tuple component has a static dtype, 2054 and may have a static shape. 2055 2056 The capacity of a `MapStagingArea` may be bounded or unbounded. 2057 It supports multiple concurrent producers and consumers; and 2058 provides exactly-once delivery. 2059 2060 Each value tuple of a `MapStagingArea` is a fixed-length tuple of tensors 2061 whose 2062 dtypes are described by `dtypes`, and whose shapes are optionally described 2063 by the `shapes` argument. 2064 2065 If the `shapes` argument is specified, each component of a staging area 2066 element must have the respective fixed shape. If it is 2067 unspecified, different elements may have different shapes, 2068 2069 It behaves like an associative container with support for: 2070 2071 - put(key, values) 2072 - peek(key) like dict.get(key) 2073 - get(key) like dict.pop(key) 2074 - get(key=None) like dict.popitem() 2075 - size() 2076 - clear() 2077 2078 If ordered a tree structure ordered by key will be used and 2079 get(key=None) will remove (key, value) pairs in increasing key order. 2080 Otherwise a hashtable 2081 2082 It can be configured with a capacity in which case 2083 put(key, values) will block until space becomes available. 2084 2085 Similarly, it can be configured with a memory limit which 2086 will block put(key, values) until space is available. 2087 This is mostly useful for limiting the number of tensors on 2088 devices such as GPUs. 2089 2090 All get() and peek() commands block if the requested 2091 (key, value) pair is not present in the staging area. 2092 2093 Partial puts are supported and will be placed in an incomplete 2094 map until such time as all values associated with the key have 2095 been inserted. Once completed, this (key, value) pair will be 2096 inserted into the map. Data in the incomplete map 2097 counts towards the memory limit, but not towards capacity limit. 
2098 2099 Partial gets from the map are also supported. 2100 This removes the partially requested tensors from the entry, 2101 but the entry is only removed from the map once all tensors 2102 associated with it are removed. 2103 """ 2104 2105 def __init__(self, 2106 dtypes, 2107 shapes=None, 2108 names=None, 2109 shared_name=None, 2110 ordered=False, 2111 capacity=0, 2112 memory_limit=0): 2113 """Args: 2114 2115 dtypes: A list of types. The length of dtypes must equal the number 2116 of tensors in each element. 2117 capacity: (Optional.) Maximum number of elements. 2118 An integer. If zero, the Staging Area is unbounded 2119 memory_limit: (Optional.) Maximum number of bytes of all tensors 2120 in the Staging Area (excluding keys). 2121 An integer. If zero, the Staging Area is unbounded 2122 ordered: (Optional.) If True the underlying data structure 2123 is a tree ordered on key. Otherwise assume a hashtable. 2124 shapes: (Optional.) Constraints on the shapes of tensors in an element. 2125 A list of shape tuples or None. This list is the same length 2126 as dtypes. If the shape of any tensors in the element are constrained, 2127 all must be; shapes can be None if the shapes should not be constrained. 2128 names: (Optional.) If provided, the `get()` and 2129 `put()` methods will use dictionaries with these names as keys. 2130 Must be None or a list or tuple of the same length as `dtypes`. 2131 shared_name: (Optional.) A name to be used for the shared object. By 2132 passing the same name to two different python objects they will share 2133 the underlying staging area. Must be a string. 2134 2135 Raises: 2136 ValueError: If one of the arguments is invalid. 2137 2138 """ 2139 2140 super(MapStagingArea, self).__init__(dtypes, shapes, names, shared_name, 2141 capacity, memory_limit) 2142 2143 # Defer to different methods depending if the map is ordered 2144 self._ordered = ordered 2145 2146 if ordered: 2147 self._put_fn = gen_data_flow_ops.ordered_map_stage 2148 self._pop_fn = gen_data_flow_ops.ordered_map_unstage 2149 self._popitem_fn = gen_data_flow_ops.ordered_map_unstage_no_key 2150 self._peek_fn = gen_data_flow_ops.ordered_map_peek 2151 self._size_fn = gen_data_flow_ops.ordered_map_size 2152 self._incomplete_size_fn = gen_data_flow_ops.ordered_map_incomplete_size 2153 self._clear_fn = gen_data_flow_ops.ordered_map_clear 2154 else: 2155 self._put_fn = gen_data_flow_ops.map_stage 2156 self._pop_fn = gen_data_flow_ops.map_unstage 2157 self._popitem_fn = gen_data_flow_ops.map_unstage_no_key 2158 self._peek_fn = gen_data_flow_ops.map_peek 2159 self._size_fn = gen_data_flow_ops.map_size 2160 self._incomplete_size_fn = gen_data_flow_ops.map_incomplete_size 2161 self._clear_fn = gen_data_flow_ops.map_clear 2162 2163 def put(self, key, vals, indices=None, name=None): 2164 """Create an op that stores the (key, vals) pair in the staging area. 2165 2166 Incomplete puts are possible, preferably using a dictionary for vals 2167 as the appropriate dtypes and shapes can be inferred from the value names 2168 dictionary key values. If vals is a list or tuple, indices must 2169 also be specified so that the op knows at which element position 2170 to perform the insert. 2171 2172 This operation will block if the capacity or memory limit of this 2173 container is reached. 2174 2175 Args: 2176 key: Key associated with the data 2177 vals: Tensor (or a dict/tuple of Tensors) to place 2178 into the staging area. 2179 indices: (Optional) if vals is a tuple/list, this is required. 
2180 name: A name for the operation (optional)
2181
2182 Returns:
2183 The created op
2184
2185 Raises:
2186 ValueError: If the number or type of inputs don't match the staging
2187 area.
2188 """
2189
2190 with ops.name_scope(name, "%s_put" % self._name,
2191 self._scope_vals(vals)) as scope:
2192
2193 vals, indices = self._check_put_dtypes(vals, indices)
2194
2195 with ops.colocate_with(self._coloc_op):
2196 op = self._put_fn(
2197 key,
2198 indices,
2199 vals,
2200 dtypes=self._dtypes,
2201 shared_name=self._name,
2202 name=scope,
2203 capacity=self._capacity,
2204 memory_limit=self._memory_limit)
2205 return op
2206
2207 def _get_indices_and_dtypes(self, indices=None):
2208 if indices is None:
2209 indices = list(six.moves.range(len(self._dtypes)))
2210
2211 if not isinstance(indices, (tuple, list)):
2212 raise TypeError("Invalid indices type '%s'" % type(indices))
2213
2214 if len(indices) == 0:
2215 raise ValueError("Empty indices")
2216
2217 if all(isinstance(i, str) for i in indices):
2218 if self._names is None:
2219 raise ValueError("String indices provided '%s', but this Staging Area "
2220 "was not created with names." % indices)
2221
2222 try:
2223 indices = [self._names.index(n) for n in indices]
2224 except ValueError:
2225 raise ValueError("One of the named indices '%s' is not in "
2226 "Staging Area names '%s'" % (indices, self._names))
2227 elif all(isinstance(i, int) for i in indices):
2228 pass
2229 else:
2230 raise TypeError("Mixed types in indices '%s'. "
2231 "May only be str or int" % indices)
2232
2233 dtypes = [self._dtypes[i] for i in indices]
2234
2235 return indices, dtypes
2236
2237 def peek(self, key, indices=None, name=None):
2238 """Peeks at staging area data associated with the key.
2239
2240 If the key is not in the staging area, it will block
2241 until the associated (key, value) is inserted.
2242
2243 Args:
2244 key: Key associated with the required data
2245 indices: Partial list of tensors to retrieve (optional).
2246 A list of integer or string indices.
2247 String indices are only valid if the Staging Area
2248 has names associated with it.
2249 name: A name for the operation (optional)
2250
2251 Returns:
2252 The tensors (or dictionary of tensors, if names were provided) at `key`.
2253 """
2254
2255 if name is None:
2256 name = "%s_pop" % self._name
2257
2258 indices, dtypes = self._get_indices_and_dtypes(indices)
2259
2260 with ops.colocate_with(self._coloc_op):
2261 result = self._peek_fn(
2262 key,
2263 shared_name=self._name,
2264 indices=indices,
2265 dtypes=dtypes,
2266 name=name,
2267 capacity=self._capacity,
2268 memory_limit=self._memory_limit)
2269
2270 return self._get_return_value(result, indices)
2271
2272 def get(self, key=None, indices=None, name=None):
2273 """If the key is provided, the associated (key, value) is returned from the staging area.
2274
2275 If the key is not in the staging area, this method will block until
2276 the associated (key, value) is inserted.
2277 If no key is provided and the staging area is ordered,
2278 the (key, value) with the smallest key will be returned.
2279 Otherwise, a random (key, value) will be returned.
2280
2281 If the staging area is empty when this operation executes,
2282 it will block until there is an element to dequeue.
2283
2284 Args:
2285 key: Key associated with the required data (Optional)
2286 indices: Partial list of tensors to retrieve (optional).
2287 A list of integer or string indices.
2288 String indices are only valid if the Staging Area
2289 has names associated with it.
2290 name: A name for the operation (optional) 2291 2292 Returns: 2293 The created op 2294 """ 2295 if key is None: 2296 return self._popitem(indices=indices, name=name) 2297 else: 2298 return self._pop(key, indices=indices, name=name) 2299 2300 def _pop(self, key, indices=None, name=None): 2301 """Remove and return the associated (key, value) is returned from the staging area. 2302 2303 If the key is not in the staging area, this method will block until 2304 the associated (key, value) is inserted. 2305 Args: 2306 key: Key associated with the required data 2307 indices: Partial list of tensors to retrieve (optional). 2308 A list of integer or string indices. 2309 String indices are only valid if the Staging Area 2310 has names associated with it. 2311 name: A name for the operation (optional) 2312 2313 Returns: 2314 The created op 2315 """ 2316 if name is None: 2317 name = "%s_get" % self._name 2318 2319 indices, dtypes = self._get_indices_and_dtypes(indices) 2320 2321 with ops.colocate_with(self._coloc_op): 2322 result = self._pop_fn( 2323 key, 2324 shared_name=self._name, 2325 indices=indices, 2326 dtypes=dtypes, 2327 name=name, 2328 capacity=self._capacity, 2329 memory_limit=self._memory_limit) 2330 2331 return key, self._get_return_value(result, indices) 2332 2333 def _popitem(self, indices=None, name=None): 2334 """If the staging area is ordered, the (key, value) with the smallest key will be returned. 2335 2336 Otherwise, a random (key, value) will be returned. 2337 If the staging area is empty when this operation executes, 2338 it will block until there is an element to dequeue. 2339 2340 Args: 2341 key: Key associated with the required data 2342 indices: Partial list of tensors to retrieve (optional). 2343 A list of integer or string indices. 2344 String indices are only valid if the Staging Area 2345 has names associated with it. 2346 name: A name for the operation (optional) 2347 2348 Returns: 2349 The created op 2350 """ 2351 if name is None: 2352 name = "%s_get_nokey" % self._name 2353 2354 indices, dtypes = self._get_indices_and_dtypes(indices) 2355 2356 with ops.colocate_with(self._coloc_op): 2357 key, result = self._popitem_fn( 2358 shared_name=self._name, 2359 indices=indices, 2360 dtypes=dtypes, 2361 name=name, 2362 capacity=self._capacity, 2363 memory_limit=self._memory_limit) 2364 2365 # Separate keys and results out from 2366 # underlying namedtuple 2367 key = self._create_device_transfers(key)[0] 2368 result = self._get_return_value(result, indices) 2369 2370 return key, result 2371 2372 def size(self, name=None): 2373 """Returns the number of elements in the staging area. 2374 2375 Args: 2376 name: A name for the operation (optional) 2377 2378 Returns: 2379 The created op 2380 """ 2381 if name is None: 2382 name = "%s_size" % self._name 2383 2384 return self._size_fn( 2385 shared_name=self._name, 2386 name=name, 2387 dtypes=self._dtypes, 2388 capacity=self._capacity, 2389 memory_limit=self._memory_limit) 2390 2391 def incomplete_size(self, name=None): 2392 """Returns the number of incomplete elements in the staging area. 2393 2394 Args: 2395 name: A name for the operation (optional) 2396 2397 Returns: 2398 The created op 2399 """ 2400 if name is None: 2401 name = "%s_incomplete_size" % self._name 2402 2403 return self._incomplete_size_fn( 2404 shared_name=self._name, 2405 name=name, 2406 dtypes=self._dtypes, 2407 capacity=self._capacity, 2408 memory_limit=self._memory_limit) 2409 2410 def clear(self, name=None): 2411 """Clears the staging area. 
2412
2413 Args:
2414 name: A name for the operation (optional)
2415
2416 Returns:
2417 The created op
2418 """
2419 if name is None:
2420 name = "%s_clear" % self._name
2421
2422 return self._clear_fn(
2423 shared_name=self._name,
2424 name=name,
2425 dtypes=self._dtypes,
2426 capacity=self._capacity,
2427 memory_limit=self._memory_limit)
2428
2429
2430 class RecordInput(object):
2431 """RecordInput asynchronously reads and randomly yields TFRecords.
2432
2433 A RecordInput Op will continuously read a batch of records asynchronously
2434 into a buffer of some fixed capacity. It can also asynchronously yield
2435 random records from this buffer.
2436
2437 It will not start yielding until at least `buffer_size / 2` elements have been
2438 placed into the buffer so that sufficient randomization can take place.
2439
2440 The order in which the files are read is shifted each epoch by `shift_ratio`
2441 so that the data is presented in a different order every epoch.
2442 """
2443
2444 def __init__(self,
2445 file_pattern,
2446 batch_size=1,
2447 buffer_size=1,
2448 parallelism=1,
2449 shift_ratio=0,
2450 seed=0,
2451 name=None,
2452 batches=None,
2453 compression_type=None):
2454 """Constructs a RecordInput Op.
2455
2456 Args:
2457 file_pattern: File path to the dataset, possibly containing wildcards.
2458 All matching files will be iterated over each epoch.
2459 batch_size: How many records to return at a time.
2460 buffer_size: The maximum number of records the buffer will contain.
2461 parallelism: How many reader threads to use for reading from files.
2462 shift_ratio: What percentage of the total number of files to move the
2463 start file forward by each epoch.
2464 seed: The random number seed used by the generator that randomizes
2465 records.
2466 name: Optional name for the operation.
2467 batches: None by default, creating a single batch op. Otherwise specifies
2468 how many batches to create, which are returned as a list when
2469 `get_yield_op()` is called. An example use case is to split processing
2470 between devices on one computer.
2471 compression_type: The type of compression for the file. Currently ZLIB
2472 and GZIP are supported. Defaults to no compression.
2473
2474 Raises:
2475 ValueError: If one of the arguments is invalid.
2476 """
2477 self._batch_size = batch_size
2478 if batches is not None:
2479 self._batch_size *= batches
2480 self._batches = batches
2481 self._file_pattern = file_pattern
2482 self._buffer_size = buffer_size
2483 self._parallelism = parallelism
2484 self._shift_ratio = shift_ratio
2485 self._seed = seed
2486 self._name = name
2487 self._compression_type = python_io.TFRecordCompressionType.NONE
2488 if compression_type is not None:
2489 self._compression_type = compression_type
2490
2491 def get_yield_op(self):
2492 """Adds a node that yields a group of records every time it is executed.
2493 If the `batches` parameter of `RecordInput` is not None, it yields a list
2494 of record batches with the specified `batch_size`.
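
    For example (an illustrative sketch; the file pattern and sizes are made
    up, graph mode is assumed, and the class is imported directly from this
    module):

      import tensorflow as tf
      from tensorflow.python.ops import data_flow_ops

      record_input = data_flow_ops.RecordInput(
          file_pattern="/tmp/train-*.tfrecord",
          batch_size=32,
          buffer_size=10000,
          parallelism=4)
      records = record_input.get_yield_op()  # string tensor of shape [32]

      with tf.compat.v1.Session() as sess:
        serialized = sess.run(records)  # 32 serialized records per run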
2495 """ 2496 compression_type = python_io.TFRecordOptions.get_compression_type_string( 2497 python_io.TFRecordOptions(self._compression_type)) 2498 records = gen_data_flow_ops.record_input( 2499 file_pattern=self._file_pattern, 2500 file_buffer_size=self._buffer_size, 2501 file_parallelism=self._parallelism, 2502 file_shuffle_shift_ratio=self._shift_ratio, 2503 batch_size=self._batch_size, 2504 file_random_seed=self._seed, 2505 compression_type=compression_type, 2506 name=self._name) 2507 if self._batches is None: 2508 return records 2509 else: 2510 with ops.name_scope(self._name): 2511 batch_list = [[] for _ in six.moves.range(self._batches)] 2512 records = array_ops.split(records, self._batch_size, 0) 2513 for index, protobuf in enumerate(records): 2514 batch_index = index % self._batches 2515 batch_list[batch_index].append(array_ops.reshape(protobuf, [])) 2516 return batch_list 2517