• Home
  • History
  • Annotate
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file defines the operations used in the tf_framework dialect.
17 
18 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.h"
19 
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/DialectImplementation.h"  // from @llvm-project
22 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_status.cc.inc"
23 
24 namespace mlir {
25 namespace kernel_gen {
26 namespace tf_framework {
27 
/// Registers this dialect's contents: every op class listed in the generated
/// tf_framework_ops.cc.inc, plus the custom OpKernelContextType.
initialize()28 void TFFrameworkDialect::initialize() {
  // Defining GET_OP_LIST before including the generated .inc presumably
  // selects its op-list section, which expands to the comma-separated op
  // classes used as addOperations' template arguments.
29   addOperations<
30 #define GET_OP_LIST
31 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc"
32       >();
  // OpKernelContextType is the only custom type handled by parseType/printType.
33   addTypes<OpKernelContextType>();
34 }
35 
36 /// Parse a type registered to this dialect.
parseType(DialectAsmParser & parser) const37 Type TFFrameworkDialect::parseType(DialectAsmParser &parser) const {
38   StringRef keyword;
39   if (parser.parseKeyword(&keyword)) return Type();
40 
41   if (keyword == "op_kernel_context") {
42     return OpKernelContextType::get(getContext());
43   }
44 
45   parser.emitError(parser.getNameLoc(), "unknown TF Framework type: ")
46       << keyword;
47   return Type();
48 }
49 
50 /// Print a type registered to this dialect.
printType(Type type,DialectAsmPrinter & os) const51 void TFFrameworkDialect::printType(Type type, DialectAsmPrinter &os) const {
52   if (type.isa<OpKernelContextType>()) {
53     os << "op_kernel_context";
54     return;
55   }
56   llvm_unreachable("unexpected TF Framework type kind");
57 }
58 
59 template <typename OpTy>
Verify(OpTy op)60 LogicalResult Verify(OpTy op) {
61   return success();
62 }
63 
64 //===----------------------------------------------------------------------===//
65 // TFAllocOp
66 //===----------------------------------------------------------------------===//
67 template <>
Verify(TFAllocOp op)68 LogicalResult Verify<TFAllocOp>(TFAllocOp op) {
69   // Check that the total number of operands matches the number of dynamic
70   // dimensions specified in the memref type.
71   unsigned result_dyn_dims = op.getType().getNumDynamicDims();
72   unsigned dyn_sizes_count = op.dyn_sizes().size();
73   if (dyn_sizes_count != result_dyn_dims)
74     return op.emitOpError()
75            << "`dyn_sizes` count " << dyn_sizes_count
76            << " does not match dynamic dimensions count in the result type"
77            << op.getType();
78   return success();
79 }
80 
ConvertAttrToEnumValue(ErrorCode error_code)81 ::tensorflow::error::Code ConvertAttrToEnumValue(ErrorCode error_code) {
82   using ::tensorflow::error::Code;
83   switch (error_code) {
84     case ErrorCode::OK:
85       return Code::OK;
86     case ErrorCode::CANCELLED:
87       return Code::CANCELLED;
88     case ErrorCode::UNKNOWN:
89       return Code::UNKNOWN;
90     case ErrorCode::INVALID_ARGUMENT:
91       return Code::INVALID_ARGUMENT;
92     case ErrorCode::DEADLINE_EXCEEDED:
93       return Code::DEADLINE_EXCEEDED;
94     case ErrorCode::NOT_FOUND:
95       return Code::NOT_FOUND;
96     case ErrorCode::ALREADY_EXISTS:
97       return Code::ALREADY_EXISTS;
98     case ErrorCode::PERMISSION_DENIED:
99       return Code::PERMISSION_DENIED;
100     case ErrorCode::UNAUTHENTICATED:
101       return Code::UNAUTHENTICATED;
102     case ErrorCode::RESOURCE_EXHAUSTED:
103       return Code::RESOURCE_EXHAUSTED;
104     case ErrorCode::FAILED_PRECONDITION:
105       return Code::FAILED_PRECONDITION;
106     case ErrorCode::ABORTED:
107       return Code::ABORTED;
108     case ErrorCode::OUT_OF_RANGE:
109       return Code::OUT_OF_RANGE;
110     case ErrorCode::UNIMPLEMENTED:
111       return Code::UNIMPLEMENTED;
112     case ErrorCode::INTERNAL:
113       return Code::INTERNAL;
114     case ErrorCode::UNAVAILABLE:
115       return Code::UNAVAILABLE;
116     case ErrorCode::DATA_LOSS:
117       return Code::DATA_LOSS;
118   }
119 }
120 
121 //===----------------------------------------------------------------------===//
122 // MinimumBroadcastShapesOp
123 //===----------------------------------------------------------------------===//
124 template <>
Verify(MinimumBroadcastShapesOp op)125 LogicalResult Verify<MinimumBroadcastShapesOp>(MinimumBroadcastShapesOp op) {
126   // Check that the number of operands matches the number of outputs.
127   unsigned result_shapes_count = op.results().size();
128   unsigned operand_shapes_count = op.shapes().size();
129   if (operand_shapes_count != result_shapes_count) {
130     return op.emitOpError()
131            << "number of operand shapes " << operand_shapes_count
132            << " does not match number of result shapes " << result_shapes_count;
133   }
134   return success();
135 }
136 
137 }  // namespace tf_framework
138 }  // namespace kernel_gen
139 }  // namespace mlir
140 
141 #define GET_OP_CLASSES
142 #include "tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.cc.inc"
143