1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_
17 
18 #include "absl/types/variant.h"
19 #include "tensorflow/core/common_runtime/eager/context.h"
20 #include "tensorflow/core/framework/tensor.h"
21 #include "tensorflow/core/lib/core/status.h"
22 
23 namespace tensorflow {
24 
25 // Local Tensor Handle: Handle to a Tensor present on the local host.
26 class LocalTensorHandleData {
27  public:
LocalTensorHandleData()28   LocalTensorHandleData() : ctrl_(absl::in_place_type<BlockingControl>) {}
LocalTensorHandleData(tensorflow::Tensor && t)29   explicit LocalTensorHandleData(tensorflow::Tensor&& t)
30       : tensor_(std::move(t)),
31         forwarding_protection_tensor_(tensor_),
32         ctrl_(absl::in_place_type<NonBlockingControl>) {}
33 
34   // A local tensor handle should be able to satisfy all of these requests.
35   Status Tensor(const tensorflow::Tensor** t) const;
36   Status TensorValue(tensorflow::TensorValue* t);
37   Status Shape(TensorShape* shape) const;
38   Status NumDims(int* num_dims) const;
39   Status Dim(int dim_index, int64* dim) const;
40   Status NumElements(int64* num_elements) const;
41   Status Unprotect();
42 
IsReady()43   bool IsReady() const {
44     return absl::visit([](auto& data) { return data.IsReady(); }, ctrl_);
45   }
46 
WaitReady(const char * caller)47   Status WaitReady(const char* caller) const {
48     return absl::visit([caller](auto& data) { return data.WaitReady(caller); },
49                        ctrl_);
50   }
Poison(Status status)51   void Poison(Status status) {
52     return absl::visit([status](auto& data) { data.Poison(status); }, ctrl_);
53   }
IsPoisoned()54   Status IsPoisoned() const {
55     return absl::visit([](auto& data) { return data.IsPoisoned(); }, ctrl_);
56   }
57 
58   Status SetTensor(tensorflow::Tensor&& t);
59 
60   string DebugString() const;
61 
62  private:
63   tensorflow::Tensor tensor_;
64   // TensorHandle has its own reference counting which is distinct from the
65   // backing Tensor. As a result, if the Tensor reference count is 1 while
66   // executing an op, the TensorBuffer could be reused for the output. We avoid
67   // this behavior maintaining another reference count with the
68   // forwarding_protection_tensor_ Tensor. When Unprotect() is called, we
69   // release this Tensor to allow forwarding.
70   tensorflow::Tensor forwarding_protection_tensor_;
71 
72   // We distinguish between ready and empty tensors with the ctrl_ variant.
73   // which contains 2 implementations of the waiting logic. The
74   // NonBlockingControl is a simple no-op class whereas the BlockingControl
75   // actually uses a mutex. By using a variant we avoid the overhead of
76   // constructing and destructing the mutex for ready local tensors.
77   class NonBlockingControl {
78    public:
IsReady()79     bool IsReady() const { return true; }
WaitReady(const char * caller)80     Status WaitReady(const char* caller) const { return Status::OK(); }
Poison(Status status)81     void Poison(Status status) {}
IsPoisoned()82     Status IsPoisoned() const { return Status::OK(); }
83   };
84 
85   class BlockingControl {
86    public:
IsReady()87     bool IsReady() const {
88       tf_shared_lock l(mu_);
89       return is_ready_;
90     }
91     void SetReady();
92     Status WaitReady(const char* caller) const;
93     void Poison(Status status);
IsPoisoned()94     Status IsPoisoned() const {
95       tf_shared_lock l(mu_);
96       return is_poisoned_;
97     }
98 
99    private:
100     mutable mutex mu_;
101     bool is_ready_ TF_GUARDED_BY(mu_);
102     Status is_poisoned_ TF_GUARDED_BY(mu_);
103   };
104 
105   absl::variant<NonBlockingControl, BlockingControl> ctrl_;
106 };
107 
108 }  // namespace tensorflow
109 
110 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_TENSOR_HANDLE_DATA_H_
111