1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
16 #define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
17 
18 #include <memory>
19 
20 #include "absl/types/span.h"
21 #include "tensorflow/c/eager/abstract_tensor_handle.h"
22 #include "tensorflow/c/tensor_interface.h"
23 #include "tensorflow/core/framework/types.pb.h"
24 #include "tensorflow/core/platform/status.h"
25 
26 namespace tensorflow {
27 
28 // Abstract interface to an operation.
29 // This interface allows building and executing an operation in either
30 // tracing or immediate execution mode.
31 class AbstractOperation {
32  protected:
33   enum AbstractOperationKind {
34     kGraph,
35     kMlir,
36     kEager,
37     kTfrt,
38     kTape,
39     kOpHandler
40   };
AbstractOperation(AbstractOperationKind kind)41   explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {}
~AbstractOperation()42   virtual ~AbstractOperation() {}
43 
44  public:
getKind()45   AbstractOperationKind getKind() const { return kind_; }
46 
47   // Release any underlying resources, including the interface object.
48   //
49   // WARNING: The destructor of this class is marked as protected to disallow
50   // clients from directly destroying this object since it may manage it's own
51   // lifetime through ref counting. Thus this must be allocated on the heap and
52   // clients MUST call Release() in order to destroy an instance of this class.
53   virtual void Release() = 0;
54 
55   virtual Status Reset(const char* op, const char* raw_device_name) = 0;
56 
57   virtual const string& Name() const = 0;
58 
59   // Returns the operation's device name.
60   //
61   // The value returned may be different from the one set by SetDeviceName, but
62   // it will be compatible with it: the name will be updated by device placement
63   // logic to refer to the specific device chosen.
64   //
65   // Example: If one calls `op->SetDeviceName("/device:GPU")`, the value
66   // returned by DeviceName should be "/device:GPU:*" until a particular GPU is
67   // chosen for the operation by the device placement logic in the
68   // executor. After that, the value returned by DeviceName will be a full
69   // device name such as "/job:localhost/replica:0/task:0/device:GPU:1".
70   virtual const string& DeviceName() const = 0;
71 
72   // Sets the operation device name.
73   //
74   // The given `name` must be parseable by DeviceNameUtils::ParseFullName, and
75   // the result will be used as a constraint for device placement. See the
76   // documentation for DeviceName for more details.
77   //
78   // The value will override the previous value - that is, no "merging" of
79   // existing and given constraints will be performed.
80   virtual Status SetDeviceName(const char* name) = 0;
81 
82   virtual Status AddInput(AbstractTensorHandle* input) = 0;
83   virtual Status AddInputList(
84       absl::Span<AbstractTensorHandle* const> inputs) = 0;
85   virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals,
86                          int* num_retvals) = 0;
87 
88   virtual Status SetAttrString(const char* attr_name, const char* data,
89                                size_t length) = 0;
90   virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0;
91   virtual Status SetAttrFloat(const char* attr_name, float value) = 0;
92   virtual Status SetAttrBool(const char* attr_name, bool value) = 0;
93   virtual Status SetAttrType(const char* attr_name, DataType value) = 0;
94   virtual Status SetAttrShape(const char* attr_name, const int64_t* dims,
95                               const int num_dims) = 0;
96   virtual Status SetAttrFunction(const char* attr_name,
97                                  const AbstractOperation* value) = 0;
98   virtual Status SetAttrFunctionName(const char* attr_name, const char* value,
99                                      size_t length) = 0;
100   virtual Status SetAttrTensor(const char* attr_name,
101                                AbstractTensorInterface* tensor) = 0;
102   virtual Status SetAttrStringList(const char* attr_name,
103                                    const void* const* values,
104                                    const size_t* lengths, int num_values) = 0;
105   virtual Status SetAttrFloatList(const char* attr_name, const float* values,
106                                   int num_values) = 0;
107   virtual Status SetAttrIntList(const char* attr_name, const int64_t* values,
108                                 int num_values) = 0;
109   virtual Status SetAttrTypeList(const char* attr_name, const DataType* values,
110                                  int num_values) = 0;
111   virtual Status SetAttrBoolList(const char* attr_name,
112                                  const unsigned char* values,
113                                  int num_values) = 0;
114   virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims,
115                                   const int* num_dims, int num_values) = 0;
116   virtual Status SetAttrFunctionList(
117       const char* attr_name, absl::Span<const AbstractOperation*> values) = 0;
118 
119  private:
120   const AbstractOperationKind kind_;
121 };
122 
123 namespace internal {
124 struct AbstractOperationDeleter {
operatorAbstractOperationDeleter125   void operator()(AbstractOperation* p) const {
126     if (p != nullptr) {
127       p->Release();
128     }
129   }
130 };
131 }  // namespace internal
132 
133 using AbstractOperationPtr =
134     std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>;
135 
136 }  // namespace tensorflow
137 
138 #endif  // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_
139