1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
18 
19 #include <string>
20 
21 #include "llvm/ADT/DenseMapInfo.h"
22 #include "llvm/ADT/Hashing.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "mlir/IR/Attributes.h"  // from @llvm-project
25 #include "mlir/IR/OpImplementation.h"  // from @llvm-project
26 #include "mlir/IR/Operation.h"  // from @llvm-project
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_verifiers.h"
29 #include "tensorflow/core/framework/resource_mgr.h"
30 
31 namespace mlir {
32 namespace TF {
33 
34 //===----------------------------------------------------------------------===//
35 // TensorFlow Contraction Fusion.
36 //===----------------------------------------------------------------------===//
37 
38 struct ContractionFusion {
39   explicit ContractionFusion(
40       StringRef output_kernel, ArrayRef<int> additional_arguments = {},
41       ArrayRef<NamedAttribute> additional_attributes = {})
42       : output_kernel(output_kernel.str()),
43         additional_arguments(additional_arguments.begin(),
44                              additional_arguments.end()),
45         additional_attributes(additional_attributes.begin(),
46                               additional_attributes.end()) {}
47 
48   // Name of the output kernel implementing the contraction fusion.
49   std::string output_kernel;
50 
51   // Indices of additional arguments that will be forwarded to the fused
52   // operation (e.g. forward bias vector if fusing BiasAdd operation).
53   SmallVector<int, 4> additional_arguments;
54 
55   // Add additional attributes to the fused node.
56   SmallVector<NamedAttribute, 4> additional_attributes;
57 };
58 
59 //===----------------------------------------------------------------------===//
60 // TensorFlow Resource Handles.
61 //===----------------------------------------------------------------------===//
62 
IsResourceHandleAnonymous(StringRef name)63 inline bool IsResourceHandleAnonymous(StringRef name) {
64   return name == ::tensorflow::ResourceHandle::ANONYMOUS_NAME;
65 }
66 
67 // Helper struct representing an identifier for a resource handle. For resource
68 // handles created explicitly and shared across resource allocator ops,
69 // `container`, `name`, and `device` can be set. If an resource handle is tied
70 // to an instance of an operation (e.g. TensorFlow runtime operation caching),
71 // `op` can be set instead.
72 struct ResourceHandle {
ResourceHandleResourceHandle73   ResourceHandle(StringRef container, StringRef name, StringRef device,
74                  Operation* op)
75       : container(container), name(name), device(device), op(op) {}
76 
77   bool operator==(const ResourceHandle& rhs) const {
78     return container == rhs.container && name == rhs.name &&
79            device == rhs.device && op == rhs.op;
80   }
81 
82   // Make ResourceHandle hashable.
83   friend ::llvm::hash_code hash_value(const ResourceHandle& resource_handle);
84 
85   StringRef container;
86   StringRef name;
87   StringRef device;
88   Operation* op = nullptr;
89 };
90 
91 // Make ResourceHandle hashable.
hash_value(const ResourceHandle & resource_handle)92 inline ::llvm::hash_code hash_value(const ResourceHandle& resource_handle) {
93   return ::llvm::hash_combine(resource_handle.container, resource_handle.name,
94                               resource_handle.device, resource_handle.op);
95 }
96 
97 // Helper struct holding a resource handle value and unique id associated to the
98 // resource handle.
99 struct ResourceHandleValueAndId {
ResourceHandleValueAndIdResourceHandleValueAndId100   ResourceHandleValueAndId(Value value, int64_t id) : value(value), id(id) {}
101 
102   Value value;
103   int64_t id = -1;
104 };
105 
106 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h.inc"
107 }  // namespace TF
108 }  // namespace mlir
109 
110 namespace llvm {
111 template <>
112 struct DenseMapInfo<mlir::TF::ResourceHandle> {
113   static mlir::TF::ResourceHandle getEmptyKey() {
114     return {/*container=*/"", /*name=*/"", /*device=*/"",
115             /*op=*/DenseMapInfo<mlir::Operation*>::getEmptyKey()};
116   }
117 
118   static mlir::TF::ResourceHandle getTombstoneKey() {
119     return {/*container=*/"", /*name=*/"", /*device=*/"",
120             /*op=*/DenseMapInfo<mlir::Operation*>::getTombstoneKey()};
121   }
122 
123   static unsigned getHashValue(
124       const mlir::TF::ResourceHandle& resource_handle) {
125     return mlir::TF::hash_value(resource_handle);
126   }
127 
128   static bool isEqual(const mlir::TF::ResourceHandle& lhs,
129                       const mlir::TF::ResourceHandle& rhs) {
130     return lhs == rhs;
131   }
132 };
133 }  // namespace llvm
134 
135 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_OP_INTERFACES_H_
136