1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_
18 
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
21 
22 namespace mlir {
23 namespace TF {
24 
25 // Copies attributes that satisfy the given predicate from `from` to `to`.
26 template <typename Predicate>
CopyAttributes(Operation * from,Operation * to,Predicate P)27 void CopyAttributes(Operation *from, Operation *to, Predicate P) {
28   for (const NamedAttribute &attr : from->getAttrs())
29     if (P(attr)) to->setAttr(attr.first, attr.second);
30 }
31 
32 // Copies attributes whose name begins with an _ from `from` to `to`.
CopyUnderscoredAttributes(Operation * from,Operation * to)33 inline void CopyUnderscoredAttributes(Operation *from, Operation *to) {
34   CopyAttributes(from, to, [](const NamedAttribute &attr) {
35     return attr.first.strref().front() == '_';
36   });
37 }
38 
39 // Copies attributes that are either `device` or whose name begins with an _
40 // from `from` to `to`.
41 // TODO(b/158769932): This should be a general feature instead post some policy
42 // discussion.
CopyDeviceAndUnderscoredAttributes(Operation * from,Operation * to)43 inline void CopyDeviceAndUnderscoredAttributes(Operation *from, Operation *to) {
44   auto device = mlir::Identifier::get("device", from->getContext());
45   CopyAttributes(from, to, [&device](const NamedAttribute &attr) {
46     return attr.first.strref().front() == '_' || attr.first == device;
47   });
48 }
49 
// Forward declarations of the pass-through ops that GetValueAsConstant looks
// through. That function treats result i of these ops as forwarding operand i.
// Only declarations are visible in this header, so the complete op definitions
// must be available wherever GetValueAsConstant is instantiated (`isa<>` needs
// them).
// TODO(jpienaar): Remove these and use a trait instead.
class IdentityOp;
class IdentityNOp;
54 
55 // Returns if a value corresponds to a constant, returns the matched constant
56 // as an attribute.
57 template <typename AttrT>
GetValueAsConstant(Value val,AttrT & attr)58 bool GetValueAsConstant(Value val, AttrT &attr) {
59   while (auto result = val.dyn_cast<OpResult>()) {
60     Operation *op = result.getOwner();
61     if (!isa<IdentityOp>(op) && !isa<IdentityNOp>(op)) break;
62     val = op->getOperand(result.getResultNumber());
63   }
64   return matchPattern(val, m_Constant(&attr));
65 }
66 
67 }  // namespace TF
68 }  // namespace mlir
69 
70 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_ATTRIBUTE_UTILS_H_
71