1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the attributes used in the TensorFlow dialect.
17 
18 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
19 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
20 
21 #include "llvm/ADT/StringRef.h"
22 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
23 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
24 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
25 
26 namespace mlir {
27 namespace TF {
28 
namespace detail {

// Forward declarations of the storage types that the attribute classes in this
// header name as their AttrBase storage parameter. They are intentionally left
// incomplete here; this header never needs their layout.
struct FuncAttrStorage;
struct ShapeAttrStorage;

}  // namespace detail
35 
36 class ShapeAttr : public Attribute::AttrBase<ShapeAttr, Attribute,
37                                              detail::ShapeAttrStorage> {
38  public:
39   using Base::Base;
40 
41   // Get or create a shape attribute. If shape is llvm::None, then it is
42   // unranked. Otherwise it is ranked. And for ranked shapes, the value of the
43   // dimension size must be >= -1. The value of -1 means the dimension is
44   // dynamic. Otherwise, the dimension is static.
45   static ShapeAttr get(mlir::MLIRContext* context,
46                        llvm::Optional<ArrayRef<int64_t>> shape);
47 
48   // Get or create a shape attribute from a ShapedType type.
49   static ShapeAttr get(mlir::MLIRContext* context, ShapedType shaped_type);
50 
51   llvm::Optional<ArrayRef<int64_t>> getValue() const;
52 
53   bool hasRank() const;
54 
55   // If this is ranked, return the rank. Otherwise, abort.
56   int64_t getRank() const;
57 
58   // If this is ranked, return the shape. Otherwise, abort.
59   ArrayRef<int64_t> getShape() const;
60 
61   // If this is unranked type or any dimension has unknown size (<0), it doesn't
62   // have static shape. If all dimensions have known size (>= 0), it has static
63   // shape.
64   bool hasStaticShape() const;
65 };
66 
67 // Custom attribute to model AttrValue.value.func (NameAttrList type attribute).
68 // This attribute holds a SymbolRefAttr, for the NameAttrList.name string and a
69 // DictionaryAttr for the NameAttrList.attr map<string, AttrValue>. It is
70 // currently printed and parsed for the following format:
71 //
72 //   #tf.func<@symbol, {attr = "value"}>
73 //
74 // where the first element is the SymbolRefAttr and the second element is the
75 // DictionaryAttr.
76 class FuncAttr
77     : public Attribute::AttrBase<FuncAttr, Attribute, detail::FuncAttrStorage> {
78  public:
79   using Base::Base;
80 
81   static FuncAttr get(mlir::MLIRContext* context, llvm::StringRef name,
82                       DictionaryAttr attr);
83 
84   static FuncAttr get(mlir::MLIRContext* context, SymbolRefAttr symbol,
85                       DictionaryAttr attr);
86 
87   SymbolRefAttr GetName() const;
88 
89   DictionaryAttr GetAttrs() const;
90 };
91 
92 }  // namespace TF
93 }  // namespace mlir
94 
95 #endif  // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_
96