1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"
17
18 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
19 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
20 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
21 #include "tensorflow/core/framework/rng_alg.h"
22
23 namespace mlir {
24 namespace TF {
25
26 namespace {
27 // Returns int or float DenseElementsAttr with scalar shape with the given
28 // element type and the integer value.
GetScalarOfType(Type ty,int64_t raw_value)29 static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
30 RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
31 if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
32 FloatAttr attr = FloatAttr::get(float_ty, raw_value);
33 return DenseElementsAttr::get(scalar_ty, attr);
34 }
35
36 auto int_ty = ty.cast<IntegerType>();
37 IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
38 return DenseElementsAttr::get(scalar_ty, attr);
39 }
40
41 // Returns subtype of `resource` if present. Otherwise an unranked tensor type
42 // of `element_type` is returned.
GetResourceSubtypeOrDefault(Value resource,Type element_type)43 static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) {
44 auto resource_type = resource.getType()
45 .cast<TensorType>()
46 .getElementType()
47 .cast<ResourceType>();
48 if (resource_type.getSubtypes().size() == 1)
49 return resource_type.getSubtypes().front();
50
51 return UnrankedTensorType::get(element_type);
52 }
53
HasResourceSubtype(Value resource)54 static bool HasResourceSubtype(Value resource) {
55 return resource.getType()
56 .cast<TensorType>()
57 .getElementType()
58 .cast<ResourceType>()
59 .getSubtypes()
60 .size() == 1;
61 }
62
GetResourceSubtype(Value resource)63 static Type GetResourceSubtype(Value resource) {
64 return resource.getType()
65 .cast<TensorType>()
66 .getElementType()
67 .cast<ResourceType>()
68 .getSubtypes()
69 .front();
70 }
71
72 // Decompose tf.RngReadAndSkip.
73 //
74 // For Philox, the resource variable holds a tensor<3xi64> with the state:
75 // [counter_lo, counter_hi, key]
76 //
77 // RngReadAndSkip increments the 128 bit counter value by 256 * delta and
78 // returns the original state value.
79 //
80 // For Threefry, the resource variable holds a tensor<2xi64> with the state:
81 // [counter, key]
82 //
83 // RngReadAndSkip increments the 64 bit counter value by 256 * delta and
84 // returns a tensor<3xi64> value [counter, key, 0].
class DecomposeRngReadAndSkipOp : public RewritePattern {
 public:
  // Registers the pattern against tf.RngReadAndSkip and declares up front
  // every op kind the rewrite may create, with a pattern benefit of 1.
  explicit DecomposeRngReadAndSkipOp(MLIRContext *context)
      : RewritePattern(RngReadAndSkipOp::getOperationName(),
                       {
                           AddV2Op::getOperationName(),
                           AssignVariableOp::getOperationName(),
                           CastOp::getOperationName(),
                           ConstOp::getOperationName(),
                           LessOp::getOperationName(),
                           MulOp::getOperationName(),
                           PadOp::getOperationName(),
                           PackOp::getOperationName(),
                           ReadVariableOp::getOperationName(),
                           SelectV2Op::getOperationName(),
                           UnpackOp::getOperationName(),
                       },
                       1, context) {}

  // Rewrites tf.RngReadAndSkip into primitive ops: read the RNG state from
  // the resource variable, advance the counter by 256 * delta (with carry
  // propagation across counter words), write the new state back, and return
  // the ORIGINAL (pre-increment) state padded to the fixed output shape.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto rng_op = cast<RngReadAndSkipOp>(op);

    // The RNG algorithm must be statically known so we can decide the
    // counter layout; bail out of the match otherwise.
    DenseIntElementsAttr alg_constant;
    if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
      return rewriter.notifyMatchFailure(
          op, "unable to determine algorithm statically");
    }

    if (alg_constant.getNumElements() != 1) {
      return rewriter.notifyMatchFailure(op, "expected alg to be a scalar");
    }

    // Only Philox (128-bit counter) and Threefry (64-bit counter) are
    // supported; any other algorithm value leaves the op undecomposed.
    uint64_t alg_value = ((*alg_constant.int_value_begin()).getZExtValue());
    tensorflow::Algorithm alg;
    if (tensorflow::RNG_ALG_PHILOX == alg_value) {
      alg = tensorflow::RNG_ALG_PHILOX;
    } else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
      alg = tensorflow::RNG_ALG_THREEFRY;
    } else {
      return rewriter.notifyMatchFailure(op, "unsupported alg");
    }

    // The op result must have the fixed shape
    // [RNG_MAX_COUNTER_SIZE + RNG_KEY_SIZE] of i64 words.
    Type state_element_type = rewriter.getI64Type();
    RankedTensorType op_type = RankedTensorType::get(
        {tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
        state_element_type);
    if (op_type != rng_op.getType()) {
      return rewriter.notifyMatchFailure(op, "unexpected op type");
    }

    // The resource must carry a known subtype so we can type the variable
    // read/write ops below.
    if (!HasResourceSubtype(rng_op.resource())) {
      return rewriter.notifyMatchFailure(op, "missing resource subtype");
    }

    // The stored state is [counter..., key]: counter_size words for the
    // algorithm's counter plus RNG_KEY_SIZE words for the key.
    int counter_size = tensorflow::GetCounterSize(alg);
    int state_size = counter_size + tensorflow::RNG_KEY_SIZE;
    RankedTensorType res_type =
        RankedTensorType::get({state_size}, state_element_type);
    if (res_type != GetResourceSubtype(rng_op.resource())) {
      return rewriter.notifyMatchFailure(op, "unexpected resource subtype");
    }

    Location loc = op->getLoc();

    // Read the state value from the resource.
    Value state =
        rewriter.create<ReadVariableOp>(loc, res_type, rng_op.resource());

    // Extract the key and counter from the state by unpacking it into
    // state_size scalar i64 words along axis 0.
    RankedTensorType word_type = RankedTensorType::get({}, state_element_type);
    auto unpacked = rewriter.create<UnpackOp>(
        loc, SmallVector<Type, 4>(state_size, word_type), state, 0);
    Value key = unpacked.getResult(counter_size);

    SmallVector<Value, 4> counter;
    for (int i = 0; i < counter_size; ++i) {
      counter.push_back(unpacked.getResult(i));
    }

    // Set the increment to 256 * delta. Done in unsigned 64-bit so the
    // addition below wraps modulo 2^64 instead of being signed overflow.
    Type u64 = rewriter.getIntegerType(64, /*isSigned=*/false);
    RankedTensorType u64_scalar = RankedTensorType::get({}, u64);
    Value step_size = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 256));
    Value increment =
        rewriter.create<MulOp>(loc, u64_scalar, step_size, rng_op.delta());

    // Increment the counter word by word, least-significant word first.
    // After each word, the carry into the next word is 1 iff the u64
    // addition wrapped (detected via new_word < old_word), else 0.
    SmallVector<Value, 4> pack_args;
    RankedTensorType word_u64_type = RankedTensorType::get({}, u64);
    Value zero_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 0));
    Value one_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 1));
    for (int i = 0; i < counter_size; ++i) {
      Value word = counter[i];
      Value word_u64 = rewriter.create<CastOp>(loc, word_u64_type, word);
      Value new_word_u64 = rewriter.create<AddV2Op>(loc, word_u64, increment);
      Value new_word = rewriter.create<CastOp>(loc, word_type, new_word_u64);
      pack_args.push_back(new_word);

      // Carry for the next (more significant) word.
      Value overflow = rewriter.create<LessOp>(loc, new_word_u64, word_u64);
      increment = rewriter.create<SelectV2Op>(loc, overflow, one_u64, zero_u64);
    }

    // Save the new state value to the resource (key is unchanged).
    pack_args.push_back(key);
    Value new_state = rewriter.create<PackOp>(loc, res_type, pack_args);
    rewriter.create<AssignVariableOp>(loc, rng_op.resource(), new_state);

    // Pad the original state as necessary to fill the output shape: Threefry
    // has a shorter counter than RNG_MAX_COUNTER_SIZE, and PadOp zero-fills
    // the tail (yielding [counter, key, 0]). For Philox, pad is 0.
    int pad = tensorflow::RNG_MAX_COUNTER_SIZE - counter_size;
    Type i64 = rewriter.getI64Type();
    RankedTensorType paddings_ty = RankedTensorType::get({1, 2}, i64);
    std::vector<int64_t> paddings_values = {0, pad};
    Value paddings = rewriter.create<ConstOp>(
        loc, DenseIntElementsAttr::get(paddings_ty, paddings_values));
    Value output = rewriter.create<PadOp>(loc, op_type, state, paddings);

    // The op returns the state as it was BEFORE the skip.
    rewriter.replaceOp(op, output);
    return success();
  }
};
206
207 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
208 } // namespace
209
// Populates `patterns` with all resource-op decomposition patterns: the
// hand-written tf.RngReadAndSkip decomposition above plus the TableGen
// generated patterns pulled in from the generated .inc file.
void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
                                          OwningRewritePatternList *patterns) {
  patterns->insert<DecomposeRngReadAndSkipOp>(context);
  populateWithGenerated(context, *patterns);
}
215
216 } // namespace TF
217 } // namespace mlir
218