1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
18
19 #include <string>
20
21 #include "llvm/ADT/DenseMapInfo.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "mlir/IR/Attributes.h" // from @llvm-project
25 #include "mlir/IR/OpImplementation.h" // from @llvm-project
26 #include "mlir/IR/Operation.h" // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
29 #include "tensorflow/core/framework/resource_mgr.h"
30
31 namespace mlir {
32 namespace TF {
33
34 //===----------------------------------------------------------------------===//
35 // TensorFlow Contraction Fusion.
36 //===----------------------------------------------------------------------===//
37
38 struct ContractionFusion {
39 explicit ContractionFusion(
40 StringRef output_kernel, ArrayRef<int> additional_arguments = {},
41 ArrayRef<NamedAttribute> additional_attributes = {})
42 : output_kernel(output_kernel.str()),
43 additional_arguments(additional_arguments.begin(),
44 additional_arguments.end()),
45 additional_attributes(additional_attributes.begin(),
46 additional_attributes.end()) {}
47
48 // Name of the output kernel implementing the contraction fusion.
49 std::string output_kernel;
50
51 // Indices of additional arguments that will be forwarded to the fused
52 // operation (e.g. forward bias vector if fusing BiasAdd operation).
53 SmallVector<int, 4> additional_arguments;
54
55 // Add additional attributes to the fused node.
56 SmallVector<NamedAttribute, 4> additional_attributes;
57 };
58
59 //===----------------------------------------------------------------------===//
60 // TensorFlow Resource Handles.
61 //===----------------------------------------------------------------------===//
62
IsResourceHandleAnonymous(StringRef name)63 inline bool IsResourceHandleAnonymous(StringRef name) {
64 return name == ::tensorflow::ResourceHandle::ANONYMOUS_NAME;
65 }
66
// Helper struct representing an identifier for a resource handle. For resource
// handles created explicitly and shared across resource allocator ops,
// `container`, `name`, and `device` can be set. If a resource handle is tied
// to an instance of an operation (e.g. TensorFlow runtime operation caching),
// `op` can be set instead.
72 struct ResourceHandle {
ResourceHandleResourceHandle73 ResourceHandle(StringRef container, StringRef name, StringRef device,
74 Operation* op)
75 : container(container), name(name), device(device), op(op) {}
76
77 bool operator==(const ResourceHandle& rhs) const {
78 return container == rhs.container && name == rhs.name &&
79 device == rhs.device && op == rhs.op;
80 }
81
82 // Make ResourceHandle hashable.
83 friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle);
84
85 StringRef container;
86 StringRef name;
87 StringRef device;
88 Operation* op = nullptr;
89 };
90
91 // Make ResourceHandle hashable.
hash_value(const ResourceHandle & resource_handle)92 inline ::llvm::hash_code hash_value(const ResourceHandle& resource_handle) {
93 return ::llvm::hash_combine(resource_handle.container, resource_handle.name,
94 resource_handle.device, resource_handle.op);
95 }
96
// Helper struct holding a resource handle value and the unique id associated
// with that resource handle.
99 struct ResourceHandleValueAndId {
ResourceHandleValueAndIdResourceHandleValueAndId100 ResourceHandleValueAndId(Value value, int64_t id) : value(value), id(id) {}
101
102 Value value;
103 int64_t id = -1;
104 };
105
106 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
107 } // namespace TF
108 } // namespace mlir
109
110 namespace llvm {
111 template <>
112 struct DenseMapInfo<mlir::TF::ResourceHandle> {
113 static mlir::TF::ResourceHandle getEmptyKey() {
114 return {/*container=*/"", /*name=*/"", /*device=*/"",
115 /*op=*/DenseMapInfo<mlir::Operation*>::getEmptyKey()};
116 }
117
118 static mlir::TF::ResourceHandle getTombstoneKey() {
119 return {/*container=*/"", /*name=*/"", /*device=*/"",
120 /*op=*/DenseMapInfo<mlir::Operation*>::getTombstoneKey()};
121 }
122
123 static unsigned getHashValue(
124 const mlir::TF::ResourceHandle& resource_handle) {
125 return mlir::TF::hash_value(resource_handle);
126 }
127
128 static bool isEqual(const mlir::TF::ResourceHandle& lhs,
129 const mlir::TF::ResourceHandle& rhs) {
130 return lhs == rhs;
131 }
132 };
133 } // namespace llvm
134
135 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
136