1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ 16 #define TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ 17 18 #include <memory> 19 20 #include "absl/types/span.h" 21 #include "tensorflow/c/eager/abstract_tensor_handle.h" 22 #include "tensorflow/c/tensor_interface.h" 23 #include "tensorflow/core/framework/types.pb.h" 24 #include "tensorflow/core/platform/status.h" 25 26 namespace tensorflow { 27 28 // Abstract interface to an operation. 29 // This interface allows building and executing an operation in either 30 // tracing or immediate execution mode. 31 class AbstractOperation { 32 protected: 33 enum AbstractOperationKind { 34 kGraph, 35 kMlir, 36 kEager, 37 kTfrt, 38 kTape, 39 kOpHandler 40 }; AbstractOperation(AbstractOperationKind kind)41 explicit AbstractOperation(AbstractOperationKind kind) : kind_(kind) {} ~AbstractOperation()42 virtual ~AbstractOperation() {} 43 44 public: getKind()45 AbstractOperationKind getKind() const { return kind_; } 46 47 // Release any underlying resources, including the interface object. 48 // 49 // WARNING: The destructor of this class is marked as protected to disallow 50 // clients from directly destroying this object since it may manage it's own 51 // lifetime through ref counting. Thus this must be allocated on the heap and 52 // clients MUST call Release() in order to destroy an instance of this class. 53 virtual void Release() = 0; 54 55 virtual Status Reset(const char* op, const char* raw_device_name) = 0; 56 57 virtual const string& Name() const = 0; 58 59 // Returns the operation's device name. 60 // 61 // The value returned may be different from the one set by SetDeviceName, but 62 // it will be compatible with it: the name will be updated by device placement 63 // logic to refer to the specific device chosen. 64 // 65 // Example: If one calls `op->SetDeviceName("/device:GPU")`, the value 66 // returned by DeviceName should be "/device:GPU:*" until a particular GPU is 67 // chosen for the operation by the device placement logic in the 68 // executor. After that, the value returned by DeviceName will be a full 69 // device name such as "/job:localhost/replica:0/task:0/device:GPU:1". 70 virtual const string& DeviceName() const = 0; 71 72 // Sets the operation device name. 73 // 74 // The given `name` must be parseable by DeviceNameUtils::ParseFullName, and 75 // the result will be used as a constraint for device placement. See the 76 // documentation for DeviceName for more details. 77 // 78 // The value will override the previous value - that is, no "merging" of 79 // existing and given constraints will be performed. 80 virtual Status SetDeviceName(const char* name) = 0; 81 82 virtual Status AddInput(AbstractTensorHandle* input) = 0; 83 virtual Status AddInputList( 84 absl::Span<AbstractTensorHandle* const> inputs) = 0; 85 virtual Status Execute(absl::Span<AbstractTensorHandle*> retvals, 86 int* num_retvals) = 0; 87 88 virtual Status SetAttrString(const char* attr_name, const char* data, 89 size_t length) = 0; 90 virtual Status SetAttrInt(const char* attr_name, int64_t value) = 0; 91 virtual Status SetAttrFloat(const char* attr_name, float value) = 0; 92 virtual Status SetAttrBool(const char* attr_name, bool value) = 0; 93 virtual Status SetAttrType(const char* attr_name, DataType value) = 0; 94 virtual Status SetAttrShape(const char* attr_name, const int64_t* dims, 95 const int num_dims) = 0; 96 virtual Status SetAttrFunction(const char* attr_name, 97 const AbstractOperation* value) = 0; 98 virtual Status SetAttrFunctionName(const char* attr_name, const char* value, 99 size_t length) = 0; 100 virtual Status SetAttrTensor(const char* attr_name, 101 AbstractTensorInterface* tensor) = 0; 102 virtual Status SetAttrStringList(const char* attr_name, 103 const void* const* values, 104 const size_t* lengths, int num_values) = 0; 105 virtual Status SetAttrFloatList(const char* attr_name, const float* values, 106 int num_values) = 0; 107 virtual Status SetAttrIntList(const char* attr_name, const int64_t* values, 108 int num_values) = 0; 109 virtual Status SetAttrTypeList(const char* attr_name, const DataType* values, 110 int num_values) = 0; 111 virtual Status SetAttrBoolList(const char* attr_name, 112 const unsigned char* values, 113 int num_values) = 0; 114 virtual Status SetAttrShapeList(const char* attr_name, const int64_t** dims, 115 const int* num_dims, int num_values) = 0; 116 virtual Status SetAttrFunctionList( 117 const char* attr_name, absl::Span<const AbstractOperation*> values) = 0; 118 119 private: 120 const AbstractOperationKind kind_; 121 }; 122 123 namespace internal { 124 struct AbstractOperationDeleter { operatorAbstractOperationDeleter125 void operator()(AbstractOperation* p) const { 126 if (p != nullptr) { 127 p->Release(); 128 } 129 } 130 }; 131 } // namespace internal 132 133 using AbstractOperationPtr = 134 std::unique_ptr<AbstractOperation, internal::AbstractOperationDeleter>; 135 136 } // namespace tensorflow 137 138 #endif // TENSORFLOW_C_EAGER_ABSTRACT_OPERATION_H_ 139