1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // This file defines the operations used in the tf_framework dialect. 17 18 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h" 19 20 #include "mlir/IR/Builders.h" // from @llvm-project 21 #include "mlir/IR/DialectImplementation.h" // from @llvm-project 22 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_status.cc.inc" 23 24 namespace mlir { 25 namespace kernel_gen { 26 namespace tf_framework { 27 initialize()28void TFFrameworkDialect::initialize() { 29 addOperations< 30 #define GET_OP_LIST 31 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc" 32 >(); 33 addTypes<OpKernelContextType>(); 34 } 35 36 /// Parse a type registered to this dialect. parseType(DialectAsmParser & parser) const37Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const { 38 StringRef keyword; 39 if (parser.parseKeyword(&keyword)) return Type(); 40 41 if (keyword == "op_kernel_context") { 42 return OpKernelContextType::get(getContext()); 43 } 44 45 parser.emitError(parser.getNameLoc(), "unknown TF Framework type: ") 46 << keyword; 47 return Type(); 48 } 49 50 /// Print a type registered to this dialect. printType(Type type,DialectAsmPrinter & os) const51void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const { 52 if (type.isa<OpKernelContextType>()) { 53 os << "op_kernel_context"; 54 return; 55 } 56 llvm_unreachable("unexpected TF Framework type kind"); 57 } 58 59 template <typename OpTy> Verify(OpTy op)60LogicalResult Verify(OpTy op) { 61 return success(); 62 } 63 64 //===----------------------------------------------------------------------===// 65 // TFAllocOp 66 //===----------------------------------------------------------------------===// 67 template <> Verify(TFAllocOp op)68LogicalResult Verify<TFAllocOp>(TFAllocOp op) { 69 // Check that the total number of operands matches the number of dynamic 70 // dimensions specified in the memref type. 71 unsigned result_dyn_dims = op.getType().getNumDynamicDims(); 72 unsigned dyn_sizes_count = op.dyn_sizes().size(); 73 if (dyn_sizes_count != result_dyn_dims) 74 return op.emitOpError() 75 << "`dyn_sizes` count " << dyn_sizes_count 76 << " does not match dynamic dimensions count in the result type" 77 << op.getType(); 78 return success(); 79 } 80 ConvertAttrToEnumValue(ErrorCode error_code)81::tensorflow::error::Code ConvertAttrToEnumValue(ErrorCode error_code) { 82 using ::tensorflow::error::Code; 83 switch (error_code) { 84 case ErrorCode::OK: 85 return Code::OK; 86 case ErrorCode::CANCELLED: 87 return Code::CANCELLED; 88 case ErrorCode::UNKNOWN: 89 return Code::UNKNOWN; 90 case ErrorCode::INVALID_ARGUMENT: 91 return Code::INVALID_ARGUMENT; 92 case ErrorCode::DEADLINE_EXCEEDED: 93 return Code::DEADLINE_EXCEEDED; 94 case ErrorCode::NOT_FOUND: 95 return Code::NOT_FOUND; 96 case ErrorCode::ALREADY_EXISTS: 97 return Code::ALREADY_EXISTS; 98 case ErrorCode::PERMISSION_DENIED: 99 return Code::PERMISSION_DENIED; 100 case ErrorCode::UNAUTHENTICATED: 101 return Code::UNAUTHENTICATED; 102 case ErrorCode::RESOURCE_EXHAUSTED: 103 return Code::RESOURCE_EXHAUSTED; 104 case ErrorCode::FAILED_PRECONDITION: 105 return Code::FAILED_PRECONDITION; 106 case ErrorCode::ABORTED: 107 return Code::ABORTED; 108 case ErrorCode::OUT_OF_RANGE: 109 return Code::OUT_OF_RANGE; 110 case ErrorCode::UNIMPLEMENTED: 111 return Code::UNIMPLEMENTED; 112 case ErrorCode::INTERNAL: 113 return Code::INTERNAL; 114 case ErrorCode::UNAVAILABLE: 115 return Code::UNAVAILABLE; 116 case ErrorCode::DATA_LOSS: 117 return Code::DATA_LOSS; 118 } 119 } 120 121 //===----------------------------------------------------------------------===// 122 // MinimumBroadcastShapesOp 123 //===----------------------------------------------------------------------===// 124 template <> Verify(MinimumBroadcastShapesOp op)125LogicalResult Verify<MinimumBroadcastShapesOp>(MinimumBroadcastShapesOp op) { 126 // Check that the number of operands matches the number of outputs. 127 unsigned result_shapes_count = op.results().size(); 128 unsigned operand_shapes_count = op.shapes().size(); 129 if (operand_shapes_count != result_shapes_count) { 130 return op.emitOpError() 131 << "number of operand shapes " << operand_shapes_count 132 << " does not match number of result shapes " << result_shapes_count; 133 } 134 return success(); 135 } 136 137 } // namespace tf_framework 138 } // namespace kernel_gen 139 } // namespace mlir 140 141 #define GET_OP_CLASSES 142 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc" 143