/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
#define TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_

#include <limits.h>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/aggregate_ops.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace tensor_array {

// Full implementations are in tensor_array.cc
template <typename Device, typename T>
Status AddToTensor(OpKernelContext* ctx, Tensor* sum, const Tensor* current,
                   const Tensor* add) {
  return errors::InvalidArgument(
      "tensor_array::AddToTensor type not supported: ",
      DataTypeString(DataTypeToEnum<T>::value));
};

#define TENSOR_ARRAY_WRITE_OR_ADD(Device, T)                         \
  template <>                                                        \
  Status AddToTensor<Device, T>(OpKernelContext * ctx, Tensor * sum, \
                                const Tensor* current, const Tensor* add);
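
// For illustration only: an invocation such as
// TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, float) expands to the declaration
//
//   template <>
//   Status AddToTensor<CPUDevice, float>(OpKernelContext* ctx, Tensor* sum,
//                                        const Tensor* current,
//                                        const Tensor* add);
//
// i.e. it declares a per-device, per-type specialization whose definition
// lives in tensor_array.cc; the generic template above is only a fallback
// that reports an unsupported type.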

#define TENSOR_ARRAY_WRITE_OR_ADD_CPU(T) TENSOR_ARRAY_WRITE_OR_ADD(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_CPU)
#undef TENSOR_ARRAY_WRITE_OR_ADD_CPU

#if GOOGLE_CUDA

#define TENSOR_ARRAY_WRITE_OR_ADD_GPU(T) TENSOR_ARRAY_WRITE_OR_ADD(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
TF_CALL_complex64(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
#undef TENSOR_ARRAY_WRITE_OR_ADD_GPU

#endif  // GOOGLE_CUDA

#undef TENSOR_ARRAY_WRITE_OR_ADD

template <typename Device, typename T>
Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
  return errors::InvalidArgument(
      "tensor_array::TensorSetZero type not supported: ",
      DataTypeString(DataTypeToEnum<T>::value));
};

#define TENSOR_ARRAY_SET_ZERO(Device, T) \
  template <>                            \
  Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);

#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU

#if GOOGLE_CUDA

#define TENSOR_ARRAY_SET_ZERO_GPU(T) TENSOR_ARRAY_SET_ZERO(GPUDevice, T)
TF_CALL_GPU_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_GPU);
TF_CALL_complex64(TENSOR_ARRAY_SET_ZERO_GPU);
TF_CALL_complex128(TENSOR_ARRAY_SET_ZERO_GPU);
#undef TENSOR_ARRAY_SET_ZERO_GPU

#endif  // GOOGLE_CUDA

#undef TENSOR_ARRAY_SET_ZERO

}  // namespace tensor_array

// The TensorArray object keeps an array of PersistentTensors.  It
// allows reading from the array and writing to the array.
//
// Important properties:
//   * Usually, writing to a particular index in the TensorArray is allowed at
//     most once per index.  In a special case, writes with the flag
//     multiple_writes_aggregate allow multiple writes to the same
//     index.  In this case, the writes are summed.
//   * Multiple reads are supported.
//   * Deep copies of PersistentTensors are rarely made.  The only
//     time they are made is when WriteOrAggregate is called at least twice
//     on the same index with the flag multiple_writes_aggregate = True.
//   * Reading from and writing to the array are protected by a mutex.
//     All operations on a TensorArray are thread-safe.
//   * A TensorArray may be preemptively closed, which releases all
//     memory associated with it.
//
// These properties together allow the TensorArray to work as a
// functional object and make gradient computation easy.  For
// example:
//   * Write-Once semantics mean the gradient of a TensorArray Read never has
//     to worry about which of multiple writes to that index the gradient
//     value is meant for.
//   * Read-Many semantics (when using clear_after_read=false) allow the
//     TensorArray to be read, packed, or concatenated multiple times;
//     and the gradient operations use the multiple_writes_aggregate
//     flag to aggregate the backprop writes.  Multiple backprop writes to
//     the same index are partial gradients corresponding to the
//     multiple reads of that index in the forward phase.
//
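// A minimal usage sketch (illustrative only; 'ctx', 'handle', and 'value'
// are assumed to be provided by the calling kernel, and the constructor
// arguments shown are hypothetical -- real kernels allocate the object with
// 'new' and register it in a resource manager):
//
//   TensorArray ta("ta_key", DT_FLOAT, handle, /*N=*/2, PartialTensorShape(),
//                  /*identical_element_shapes=*/false, /*dynamic_size=*/false,
//                  /*multiple_writes_aggregate=*/false, /*is_grad=*/false,
//                  /*marked_size=*/-1, /*clear_after_read=*/true);
//   Status ws = ta.WriteOrAggregate<CPUDevice, float>(ctx, 0, &value);
//   PersistentTensor read_value;
//   Status rs = ta.Read<CPUDevice, float>(ctx, 0, &read_value);
//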
class TensorArray : public ResourceBase {
 public:
  static std::atomic<int64> tensor_array_counter;

  // Construct a TensorArray for holding Tensors of type 'dtype' with
  // 'N' elements.  While the underlying storage is a std::vector and
  // can hold more than INT_MAX entries, in practice we do not expect
  // users to construct this many Tensors for storage in a TensorArray.
  TensorArray(const string& key, const DataType& dtype, const Tensor& handle,
              int32 N, const PartialTensorShape& element_shape,
              bool identical_element_shapes, bool dynamic_size,
              bool multiple_writes_aggregate, bool is_grad, int32 marked_size,
              bool clear_after_read)
      : key_(key),
        dtype_(dtype),
        handle_(handle),
        closed_(false),
        dynamic_size_(dynamic_size),
        multiple_writes_aggregate_(multiple_writes_aggregate),
        gradients_disallowed_(false),
        clear_after_read_(clear_after_read),
        is_grad_(is_grad),
        marked_size_(marked_size),
        element_shape_(element_shape),
        identical_element_shapes_(identical_element_shapes),
        tensors_(N) {}

  // Write PersistentTensor 'value' to index 'index'.
  //
  // Preconditions:
  //  * The TensorArray is not closed
  //  * If the array has dynamic size:
  //      The index is >= 0
  //    Otherwise:
  //      The index is in [0, N) where N == Size()
  //  * The dtype of the Tensor in 'value' matches the TensorArray's dtype.
  //  * If multiple_writes_aggregate is false:
  //    The Tensor at 'index' has not yet been written to.
  //  * If multiple_writes_aggregate is true:
  //    The Tensor at 'index' has the same shape as value.
  //
  // Side effects:
  //  * On the first write to 'index':
  //    - The underlying Tensor in 'value' has a new reference to it.
  //    - The index 'index' is marked as written.
  //  * If multiple_writes_aggregate is false, subsequent writes to 'index'
  //    raise an InvalidArgument error.
  //  * If multiple_writes_aggregate is true, subsequent writes to 'index':
  //    - The underlying Tensors in 'value' and from the first write
  //      are released and a local PersistentTensor is created.
  //    - Index 'index' is also marked as local_copy.
  //    - The gradients_disallowed flag is set true (GradientsAllowed()
  //      will now return false).
  //
  // Note: 'value' is passed as a pointer because its underlying Tensor's
  // shape is accessed; it is not otherwise modified.
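  //
  // For example (sketch): with multiple_writes_aggregate = true, writing
  // same-shaped tensors x and then y to index 3 leaves index 3 holding the
  // sum x + y in a locally allocated tensor, after which GradientsAllowed()
  // returns false.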
  template <typename Device, typename T>
  Status WriteOrAggregate(OpKernelContext* ctx, const int32 index,
                          PersistentTensor* value) {
    mutex_lock l(mu_);
    return LockedWriteOrAggregate<Device, T>(ctx, index, value);
  }

  template <typename Device, typename T>
  Status WriteOrAggregateMany(OpKernelContext* ctx,
                              const std::vector<int32>& indices,
                              std::vector<PersistentTensor>* values) {
    mutex_lock l(mu_);
    int32 i = 0;
    for (const int32 ix : indices) {
      Status s = LockedWriteOrAggregate<Device, T>(ctx, ix, &(*values)[i]);
      ++i;
      TF_RETURN_IF_ERROR(s);
    }
    return Status::OK();
  }

  // Read from index 'index' into PersistentTensor 'value'.
  //
  // Preconditions:
  //  * The TensorArray is not closed
  //  * The index is in [0, N)
  //  * The Tensor at 'index' has been written to.
  //  * The Tensor at 'index' has not been read from with flag
  //    clear_after_read = true.
  //
  // Side effects:
  //  * If clear_after_read is true, the reference to the underlying
  //    Tensor is deleted.
  //  * The reference to the underlying Tensor at 'index' is copied to
  //    the returned '*value'.
  //  * The index is marked as read (it cannot be rewritten to).
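  //
  // For example (sketch): with clear_after_read = true, a first
  // Read<Device, T>(ctx, 2, &v) hands back the tensor written at index 2 and
  // drops the array's own reference; a second read of index 2 then fails
  // with an InvalidArgument error.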
  template <typename Device, typename T>
  Status Read(OpKernelContext* ctx, const int32 index,
              PersistentTensor* value) {
    mutex_lock l(mu_);
    return LockedRead<Device, T>(ctx, index, value);
  }

  template <typename Device, typename T>
  Status ReadMany(OpKernelContext* ctx, const std::vector<int32>& indices,
                  std::vector<PersistentTensor>* values) {
    mutex_lock l(mu_);
    values->clear();
    values->resize(indices.size());
    int32 i = 0;
    for (const int32 ix : indices) {
      Status s = LockedRead<Device, T>(ctx, ix, &(*values)[i]);
      ++i;
      if (!s.ok()) return s;
    }
    return Status::OK();
  }

  DataType ElemType() const { return dtype_; }

  PartialTensorShape ElemShape() {
    mutex_lock l(mu_);
    return element_shape_;
  }

  Status SetElemShape(const PartialTensorShape& candidate) {
    mutex_lock l(mu_);
    PartialTensorShape new_element_shape_;
    Status s = element_shape_.MergeWith(candidate, &new_element_shape_);
    if (!s.ok()) {
      return s;
    }
    element_shape_ = new_element_shape_;
    return Status::OK();
  }

  string DebugString() const override {
    mutex_lock l(mu_);
    CHECK(!closed_);
    return strings::StrCat("TensorArray[", tensors_.size(), "]");
  }

  bool IsClosed() {
    mutex_lock l(mu_);
    return closed_;
  }

  // Return the size of the TensorArray.
  Status Size(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = tensors_.size();
    return Status::OK();
  }

  // Record the size of the TensorArray after an unpack or split.
  Status SetMarkedSize(int32 size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    if (!is_grad_) {
      marked_size_ = size;
    }
    return Status::OK();
  }

  // Return the marked size of the TensorArray.
  Status MarkedSize(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = marked_size_;
    return Status::OK();
  }

  // Return the size that should be used by pack or concat op.
  Status PackOrConcatSize(int32* size) {
    mutex_lock l(mu_);
    TF_RETURN_IF_ERROR(LockedReturnIfClosed());
    *size = is_grad_ ? marked_size_ : tensors_.size();
    return Status::OK();
  }

  // Once a TensorArray is being used for gradient calculations, it
  // should be marked as no longer resizeable.
  void DisableDynamicSize() {
    mutex_lock l(mu_);
    dynamic_size_ = false;
  }

  bool HasDynamicSize() {
    mutex_lock l(mu_);
    return dynamic_size_;
  }

  bool GradientsAllowed() {
    mutex_lock l(mu_);
    return !gradients_disallowed_;
  }

  bool HasIdenticalElementShapes() const { return identical_element_shapes_; }

  // Copy the TensorShapes from another TensorArray into this one.
  // If `shape_to_prepend` is set, expands the rank of the copied shape by
  // prepending the passed in shape prefix to the shape values in `rhs`.
  // The sizes of the two TensorArrays must match and this one
  // may not have any entries filled in.  This performs a "soft copy",
  // essentially filling the current TensorArray with virtual
  // zero-tensors, which will be replaced by future aggregate writes,
  // or instantiated by future reads.  Requires a non-const pointer
  // to the rhs to access its mutex.
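  //
  // For example (sketch): if an entry of `rhs` has shape [d0, d1] and
  // `shape_to_prepend` is [t], the corresponding entry here records shape
  // [t, d0, d1] as a virtual zero-tensor, without copying any tensor data.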
  Status CopyShapesFrom(TensorArray* rhs, const TensorShape* shape_to_prepend);

  // Clear the TensorArray, including any Tensor references, and mark as closed.
  void ClearAndMarkClosed() {
    mutex_lock l(mu_);
    tensors_.clear();
    closed_ = true;
  }

  mutex* mu() { return &mu_; }
  Tensor* handle() { return &handle_; }

  ResourceHandle resource_handle(OpKernelContext* ctx) {
    return MakePerStepResourceHandle<TensorArray>(ctx, key_);
  }

 private:
  Status LockedWrite(OpKernelContext* ctx, const int32 index,
                     PersistentTensor* value) EXCLUSIVE_LOCKS_REQUIRED(mu_);

  template <typename Device, typename T>
  Status LockedWriteOrAggregate(OpKernelContext* ctx, const int32 index,
                                PersistentTensor* value)
      EXCLUSIVE_LOCKS_REQUIRED(mu_);

  template <typename Device, typename T>
  Status LockedRead(OpKernelContext* ctx, const int32 index,
                    PersistentTensor* value) EXCLUSIVE_LOCKS_REQUIRED(mu_);

  Status LockedReturnIfClosed() const EXCLUSIVE_LOCKS_REQUIRED(mu_) {
    if (closed_) {
      return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                     " has already been closed.");
    }
    return Status::OK();
  }

  const string key_;

  const DataType dtype_;
  Tensor handle_;

  mutable mutex mu_;

  // Marks that the TensorArray has been cleared and closed.
  bool closed_ GUARDED_BY(mu_);

  // Writes are allowed to grow the array.
  bool dynamic_size_;

  // Multiple writes to the same index will result in summation of the
  // values (used by backprop).
  const bool multiple_writes_aggregate_;

  // If multiple writes were attempted (e.g. via attribute
  // multiple_writes_aggregate), then gradients are disallowed.
  bool gradients_disallowed_ GUARDED_BY(mu_);

  // After a read at an index, clear away its PersistentTensor to
  // release memory.
  const bool clear_after_read_;

  // True iff this is a gradient tensor array.
  const bool is_grad_;

  // The size of the TensorArray after a (legacy) unpack or split is performed.
  // -1 if there has been no unpack or split performed on the TensorArray.
  int32 marked_size_;

  // The shape of each element in the TensorArray; may be partially known or
  // not known at all.
  PartialTensorShape element_shape_ GUARDED_BY(mu_);

  // Whether all elements in the TensorArray have identical shapes.
  // This allows certain behaviors, like dynamically checking for
  // consistent shapes on write, and being able to fill in properly
  // shaped zero tensors on stack -- even if the initial element_shape
  // was not fully defined.
  const bool identical_element_shapes_;

  // TensorAndState is used to keep track of the PersistentTensors
  // stored in the TensorArray, along with their shapes, and a boolean
  // that determines whether they have already been read or not.
  struct TensorAndState {
    TensorAndState()
        : written(false), read(false), cleared(false), local_copy(false) {}
    PersistentTensor tensor;
    TensorShape shape;
    bool written;  // True if a Tensor has been written to the index.
    bool read;  // True if a Tensor has been written to and read from the index.
    bool cleared;  // True if a tensor has been read with
                   // clear_after_read = true;

    // Used by writes when multiple_writes_aggregate is true.  In this
    // case, the first time a value is written, it is a shallow copy.
    // The second time a value is written, it is aggregated.  However,
    // in this case a new Tensor must be constructed to hold the
    // aggregated value.  This flag marks that such a Tensor is being
    // used.  All future writes will aggregate to the existing local Tensor.
    bool local_copy;
  };
  // The list of underlying PersistentTensors and states.
  std::vector<TensorAndState> tensors_ GUARDED_BY(mu_);
};

template <typename Device, typename T>
Status TensorArray::LockedWriteOrAggregate(OpKernelContext* ctx,
                                           const int32 index,
                                           PersistentTensor* value) {
  TF_RETURN_IF_ERROR(LockedReturnIfClosed());
  size_t index_size = static_cast<size_t>(index);
  if (index < 0 || (!dynamic_size_ && index_size >= tensors_.size())) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<string>()(1), ": Tried to write to index ",
        index, " but array is not resizeable and size is: ", tensors_.size());
  }
  if (dynamic_size_) {
    // We must grow the internal TensorArray
    if (index_size >= tensors_.capacity()) {
      tensors_.reserve(2 * (index_size + 1));
    }
    if (index_size >= tensors_.size()) {
      tensors_.resize(index_size + 1);
    }
  }
  TensorAndState& t = tensors_[index];

  Tensor* value_t = value->AccessTensor(ctx);
  if (value_t->dtype() != dtype_) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<string>()(1),
        ": Could not write to TensorArray index ", index,
        " because the value dtype is ", DataTypeString(value_t->dtype()),
        " but TensorArray dtype is ", DataTypeString(dtype_), ".");
  }
  if (!element_shape_.IsCompatibleWith(value_t->shape())) {
    return errors::InvalidArgument(
        "TensorArray ", handle_.vec<string>()(1),
        ": Could not write to TensorArray index ", index,
        " because the value shape is ", value_t->shape().DebugString(),
        " which is incompatible with the TensorArray's inferred element "
        "shape: ",
        element_shape_.DebugString(), " (consider setting infer_shape=False).");
  } else if (identical_element_shapes_ && !element_shape_.IsFullyDefined()) {
    element_shape_ = PartialTensorShape(value_t->shape().dim_sizes());
  }

  if (t.read) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                   ": Could not write to TensorArray index ",
                                   index, " because it has already been read.");
  }

  if (!multiple_writes_aggregate_ && t.written) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                   ": Could not write to TensorArray index ",
                                   index,
                                   " because it has already been written to.");
  }

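  // Aggregation path: this index was already written and
  // multiple_writes_aggregate is set, so sum the incoming value into the
  // stored one (allocating a local copy on the first aggregation).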
  if (t.written) {
    DCHECK(multiple_writes_aggregate_);

    // Check that value_t's shape matches t.shape.
    if (value_t->shape() != t.shape) {
      return errors::InvalidArgument(
          "TensorArray ", handle_.vec<string>()(1),
          ": Could not aggregate to TensorArray index ", index,
          " because the existing shape is ", t.shape.DebugString(),
          " but the new input shape is ", value_t->shape().DebugString(), ".");
    }

    if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
      // If the existing tensor is uninitialized but written == true, then
      // what was stored was just a shape, which just means zeros.  So all we
      // must do in this case is copy the reference over and return early.
      t.tensor = *value;
      return Status::OK();
    }

    Tensor* existing_t = t.tensor.AccessTensor(ctx);

    if (t.local_copy) {
      Status s = tensor_array::AddToTensor<Device, T>(ctx, existing_t,
                                                      existing_t, value_t);
      TF_RETURN_IF_ERROR(s);
    } else {
      PersistentTensor local_tensor;
      Tensor* local_tensor_t;
      TF_RETURN_IF_ERROR(ctx->allocate_persistent(
          dtype_, existing_t->shape(), &local_tensor, &local_tensor_t));
      Status s = tensor_array::AddToTensor<Device, T>(ctx, local_tensor_t,
                                                      existing_t, value_t);
      TF_RETURN_IF_ERROR(s);
      t.tensor = local_tensor;
      t.local_copy = true;
    }

    // We've aggregated the values, so disallow backprop on this
    // TensorArray.
    gradients_disallowed_ = true;
  } else {
    t.tensor = *value;
    t.shape = value_t->shape();
    t.written = true;
  }
  return Status::OK();
}

template <typename Device, typename T>
Status TensorArray::LockedRead(OpKernelContext* ctx, const int32 index,
                               PersistentTensor* value) {
  TF_RETURN_IF_ERROR(LockedReturnIfClosed());
  if ((index < 0) ||
      (!is_grad_ && (static_cast<size_t>(index) >= tensors_.size()))) {
    return errors::InvalidArgument("Tried to read from index ", index,
                                   " but array size is: ", tensors_.size());
  }
  size_t index_t = static_cast<size_t>(index);
  if ((is_grad_ && (index_t >= tensors_.size() || !tensors_[index].written)) ||
      (!is_grad_ && (index_t < tensors_.size() && !tensors_[index].written))) {
    // Special case returning zeros if this is a gradient read that happens
    // after a stop_gradients call with dynamic forward TensorArrays.
    // There is sometimes a race condition where the gradient is not
    // written due to stop_gradients, but is later read.
    TensorShape element_shape;
    if (is_grad_ && index_t < tensors_.size() &&
        tensors_[index].shape.dims() > 0) {
      // A gradient TensorArray has more specific gradient information
      // available for each entry.  A forward TensorArray must rely on
      // the global element_shape_ to fill in zeros on read.
      element_shape = tensors_[index].shape;
    } else if (!element_shape_.IsFullyDefined()) {
      return errors::InvalidArgument(
          "TensorArray ", handle_.vec<string>()(1),
          ": Could not read from TensorArray index ", index,
          ".  Furthermore, the element shape is not fully defined: ",
          element_shape_.DebugString(),
          ".  It is possible you are working with a resizeable TensorArray and "
          "stop_gradients is not allowing the gradients to be written.  If you "
          "set the full "
          "element_shape property on the forward TensorArray, the proper "
          "all-zeros tensor "
          "will be returned instead of incurring this error.");
    } else {
      element_shape_.AsTensorShape(&element_shape);  // Always succeeds.
    }
    if (index_t >= tensors_.size()) {
      // Fill in tensors_ up to index to have known shape.
      size_t old_tensors_size = tensors_.size();
      tensors_.resize(index + 1);
      for (size_t i = old_tensors_size; i < index + 1; ++i) {
        tensors_[i].shape = element_shape;
        tensors_[i].written = true;
      }
    } else {
      tensors_[index].shape = element_shape;
      tensors_[index].written = true;
    }
  }

  TensorAndState& t = tensors_[index];

  if (t.cleared) {
    return errors::InvalidArgument("TensorArray ", handle_.vec<string>()(1),
                                   ": Could not read index ", index,
                                   " twice because it was cleared after a "
                                   "previous read (perhaps try setting "
                                   "clear_after_read = false?).");
  }

  if (!t.tensor.IsInitialized() || t.tensor.NumElements() == 0) {
    // We stored just a shape, but no value.  This means create and
    // return zeros of the appropriate shape.
    Tensor* tensor_t;
    TF_RETURN_IF_ERROR(
        ctx->allocate_persistent(dtype_, t.shape, &t.tensor, &tensor_t));
    if (t.shape.num_elements() > 0) {
      Status s = tensor_array::TensorSetZero<Device, T>(ctx, tensor_t);
      if (!s.ok()) return s;
    }
  }

  // Data is available inside the tensor, copy the reference over.
  *value = t.tensor;

  if (clear_after_read_) {
    t.tensor = PersistentTensor();
    t.cleared = true;
  }
  t.read = true;
  return Status::OK();
}

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_