1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file defines the attributes used in the TensorFlow dialect. 17 18 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ 19 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ 20 21 #include "llvm/ADT/StringRef.h" 22 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project 23 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project 24 #include "mlir/IR/MLIRContext.h" // from @llvm-project 25 26 namespace mlir { 27 namespace TF { 28 29 namespace detail { 30 31 struct ShapeAttrStorage; 32 struct FuncAttrStorage; 33 34 } // namespace detail 35 36 class ShapeAttr : public Attribute::AttrBase<ShapeAttr, Attribute, 37 detail::ShapeAttrStorage> { 38 public: 39 using Base::Base; 40 41 // Get or create a shape attribute. If shape is llvm::None, then it is 42 // unranked. Otherwise it is ranked. And for ranked shapes, the value of the 43 // dimension size must be >= -1. The value of -1 means the dimension is 44 // dynamic. Otherwise, the dimension is static. 45 static ShapeAttr get(mlir::MLIRContext* context, 46 llvm::Optional<ArrayRef<int64_t>> shape); 47 48 // Get or create a shape attribute from a ShapedType type. 49 static ShapeAttr get(mlir::MLIRContext* context, ShapedType shaped_type); 50 51 llvm::Optional<ArrayRef<int64_t>> getValue() const; 52 53 bool hasRank() const; 54 55 // If this is ranked, return the rank. Otherwise, abort. 56 int64_t getRank() const; 57 58 // If this is ranked, return the shape. Otherwise, abort. 59 ArrayRef<int64_t> getShape() const; 60 61 // If this is unranked type or any dimension has unknown size (<0), it doesn't 62 // have static shape. If all dimensions have known size (>= 0), it has static 63 // shape. 64 bool hasStaticShape() const; 65 }; 66 67 // Custom attribute to model AttrValue.value.func (NameAttrList type attribute). 68 // This attribute holds a SymbolRefAttr, for the NameAttrList.name string and a 69 // DictionaryAttr for the NameAttrList.attr map<string, AttrValue>. It is 70 // currently printed and parsed for the following format: 71 // 72 // #tf.func<@symbol, {attr = "value"}> 73 // 74 // where the first element is the SymbolRefAttr and the second element is the 75 // DictionaryAttr. 76 class FuncAttr 77 : public Attribute::AttrBase<FuncAttr, Attribute, detail::FuncAttrStorage> { 78 public: 79 using Base::Base; 80 81 static FuncAttr get(mlir::MLIRContext* context, llvm::StringRef name, 82 DictionaryAttr attr); 83 84 static FuncAttr get(mlir::MLIRContext* context, SymbolRefAttr symbol, 85 DictionaryAttr attr); 86 87 SymbolRefAttr GetName() const; 88 89 DictionaryAttr GetAttrs() const; 90 }; 91 92 } // namespace TF 93 } // namespace mlir 94 95 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_ATTRIBUTES_H_ 96