1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_ 16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_ 17 18 #include "absl/types/variant.h" 19 #include "tensorflow/core/common_runtime/eager/context.h" 20 #include "tensorflow/core/framework/tensor.h" 21 #include "tensorflow/core/lib/core/status.h" 22 23 namespace tensorflow { 24 25 // Local Tensor Handle: Handle to a Tensor present on the local host. 26 class LocalTensorHandleData { 27 public: LocalTensorHandleData()28 LocalTensorHandleData() : ctrl_(absl::in_place_type<BlockingControl>) {} LocalTensorHandleData(tensorflow::Tensor && t)29 explicit LocalTensorHandleData(tensorflow::Tensor&& t) 30 : tensor_(std::move(t)), 31 forwarding_protection_tensor_(tensor_), 32 ctrl_(absl::in_place_type<NonBlockingControl>) {} 33 34 // A local tensor handle should be able to satisfy all of these requests. 35 Status Tensor(const tensorflow::Tensor** t) const; 36 Status TensorValue(tensorflow::TensorValue* t); 37 Status Shape(TensorShape* shape) const; 38 Status NumDims(int* num_dims) const; 39 Status Dim(int dim_index, int64* dim) const; 40 Status NumElements(int64* num_elements) const; 41 Status Unprotect(); 42 IsReady()43 bool IsReady() const { 44 return absl::visit([](auto& data) { return data.IsReady(); }, ctrl_); 45 } 46 WaitReady(const char * caller)47 Status WaitReady(const char* caller) const { 48 return absl::visit([caller](auto& data) { return data.WaitReady(caller); }, 49 ctrl_); 50 } Poison(Status status)51 void Poison(Status status) { 52 return absl::visit([status](auto& data) { data.Poison(status); }, ctrl_); 53 } IsPoisoned()54 Status IsPoisoned() const { 55 return absl::visit([](auto& data) { return data.IsPoisoned(); }, ctrl_); 56 } 57 58 Status SetTensor(tensorflow::Tensor&& t); 59 60 string DebugString() const; 61 62 private: 63 tensorflow::Tensor tensor_; 64 // TensorHandle has its own reference counting which is distinct from the 65 // backing Tensor. As a result, if the Tensor reference count is 1 while 66 // executing an op, the TensorBuffer could be reused for the output. We avoid 67 // this behavior maintaining another reference count with the 68 // forwarding_protection_tensor_ Tensor. When Unprotect() is called, we 69 // release this Tensor to allow forwarding. 70 tensorflow::Tensor forwarding_protection_tensor_; 71 72 // We distinguish between ready and empty tensors with the ctrl_ variant. 73 // which contains 2 implementations of the waiting logic. The 74 // NonBlockingControl is a simple no-op class whereas the BlockingControl 75 // actually uses a mutex. By using a variant we avoid the overhead of 76 // constructing and destructing the mutex for ready local tensors. 77 class NonBlockingControl { 78 public: IsReady()79 bool IsReady() const { return true; } WaitReady(const char * caller)80 Status WaitReady(const char* caller) const { return Status::OK(); } Poison(Status status)81 void Poison(Status status) {} IsPoisoned()82 Status IsPoisoned() const { return Status::OK(); } 83 }; 84 85 class BlockingControl { 86 public: IsReady()87 bool IsReady() const { 88 tf_shared_lock l(mu_); 89 return is_ready_; 90 } 91 void SetReady(); 92 Status WaitReady(const char* caller) const; 93 void Poison(Status status); IsPoisoned()94 Status IsPoisoned() const { 95 tf_shared_lock l(mu_); 96 return is_poisoned_; 97 } 98 99 private: 100 mutable mutex mu_; 101 bool is_ready_ TF_GUARDED_BY(mu_); 102 Status is_poisoned_ TF_GUARDED_BY(mu_); 103 }; 104 105 absl::variant<NonBlockingControl, BlockingControl> ctrl_; 106 }; 107 108 } // namespace tensorflow 109 110 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_ 111