/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h"
17 
18 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
19 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
20 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
21 #include "tensorflow/core/framework/rng_alg.h"
22 
23 namespace mlir {
24 namespace TF {
25 
26 namespace {
27 // Returns int or float DenseElementsAttr with scalar shape with the given
28 // element type and the integer value.
GetScalarOfType(Type ty,int64_t raw_value)29 static DenseElementsAttr GetScalarOfType(Type ty, int64_t raw_value) {
30   RankedTensorType scalar_ty = RankedTensorType::get({}, ty);
31   if (auto float_ty = ty.dyn_cast_or_null<FloatType>()) {
32     FloatAttr attr = FloatAttr::get(float_ty, raw_value);
33     return DenseElementsAttr::get(scalar_ty, attr);
34   }
35 
36   auto int_ty = ty.cast<IntegerType>();
37   IntegerAttr attr = IntegerAttr::get(int_ty, raw_value);
38   return DenseElementsAttr::get(scalar_ty, attr);
39 }
40 
41 // Returns subtype of `resource` if present. Otherwise an unranked tensor type
42 // of `element_type` is returned.
GetResourceSubtypeOrDefault(Value resource,Type element_type)43 static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) {
44   auto resource_type = resource.getType()
45                            .cast<TensorType>()
46                            .getElementType()
47                            .cast<ResourceType>();
48   if (resource_type.getSubtypes().size() == 1)
49     return resource_type.getSubtypes().front();
50 
51   return UnrankedTensorType::get(element_type);
52 }
53 
HasResourceSubtype(Value resource)54 static bool HasResourceSubtype(Value resource) {
55   return resource.getType()
56              .cast<TensorType>()
57              .getElementType()
58              .cast<ResourceType>()
59              .getSubtypes()
60              .size() == 1;
61 }
62 
GetResourceSubtype(Value resource)63 static Type GetResourceSubtype(Value resource) {
64   return resource.getType()
65       .cast<TensorType>()
66       .getElementType()
67       .cast<ResourceType>()
68       .getSubtypes()
69       .front();
70 }
71 
72 // Decompose tf.RngReadAndSkip.
73 //
74 // For Philox, the resource variable holds a tensor<3xi64> with the state:
75 //   [counter_lo, counter_hi, key]
76 //
77 //   RngReadAndSkip increments the 128 bit counter value by 256 * delta and
78 //   returns the original state value.
79 //
80 // For Threefry, the resource variable holds a tensor<2xi64> with the state:
81 //   [counter, key]
82 //
83 //   RngReadAndSkip increments the 64 bit counter value by 256 * delta and
84 //   returns a tensor<3xi64> value [counter, key, 0].
85 class DecomposeRngReadAndSkipOp : public RewritePattern {
86  public:
DecomposeRngReadAndSkipOp(MLIRContext * context)87   explicit DecomposeRngReadAndSkipOp(MLIRContext *context)
88       : RewritePattern(RngReadAndSkipOp::getOperationName(),
89                        {
90                            AddV2Op::getOperationName(),
91                            AssignVariableOp::getOperationName(),
92                            CastOp::getOperationName(),
93                            ConstOp::getOperationName(),
94                            LessOp::getOperationName(),
95                            MulOp::getOperationName(),
96                            PadOp::getOperationName(),
97                            PackOp::getOperationName(),
98                            ReadVariableOp::getOperationName(),
99                            SelectV2Op::getOperationName(),
100                            UnpackOp::getOperationName(),
101                        },
102                        1, context) {}
103 
matchAndRewrite(Operation * op,PatternRewriter & rewriter) const104   LogicalResult matchAndRewrite(Operation *op,
105                                 PatternRewriter &rewriter) const override {
106     auto rng_op = cast<RngReadAndSkipOp>(op);
107 
108     DenseIntElementsAttr alg_constant;
109     if (!matchPattern(rng_op.alg(), m_Constant(&alg_constant))) {
110       return rewriter.notifyMatchFailure(
111           op, "unable to determine algorithm statically");
112     }
113 
114     if (alg_constant.getNumElements() != 1) {
115       return rewriter.notifyMatchFailure(op, "expected alg to be a scalar");
116     }
117 
118     uint64_t alg_value = ((*alg_constant.int_value_begin()).getZExtValue());
119     tensorflow::Algorithm alg;
120     if (tensorflow::RNG_ALG_PHILOX == alg_value) {
121       alg = tensorflow::RNG_ALG_PHILOX;
122     } else if (tensorflow::RNG_ALG_THREEFRY == alg_value) {
123       alg = tensorflow::RNG_ALG_THREEFRY;
124     } else {
125       return rewriter.notifyMatchFailure(op, "unsupported alg");
126     }
127 
128     Type state_element_type = rewriter.getI64Type();
129     RankedTensorType op_type = RankedTensorType::get(
130         {tensorflow::RNG_MAX_COUNTER_SIZE + tensorflow::RNG_KEY_SIZE},
131         state_element_type);
132     if (op_type != rng_op.getType()) {
133       return rewriter.notifyMatchFailure(op, "unexpected op type");
134     }
135 
136     if (!HasResourceSubtype(rng_op.resource())) {
137       return rewriter.notifyMatchFailure(op, "missing resource subtype");
138     }
139 
140     int counter_size = tensorflow::GetCounterSize(alg);
141     int state_size = counter_size + tensorflow::RNG_KEY_SIZE;
142     RankedTensorType res_type =
143         RankedTensorType::get({state_size}, state_element_type);
144     if (res_type != GetResourceSubtype(rng_op.resource())) {
145       return rewriter.notifyMatchFailure(op, "unexpected resource subtype");
146     }
147 
148     Location loc = op->getLoc();
149 
150     // Read the state value from the resource.
151     Value state =
152         rewriter.create<ReadVariableOp>(loc, res_type, rng_op.resource());
153 
154     // Extract the key and counter from the state.
155     RankedTensorType word_type = RankedTensorType::get({}, state_element_type);
156     auto unpacked = rewriter.create<UnpackOp>(
157         loc, SmallVector<Type, 4>(state_size, word_type), state, 0);
158     Value key = unpacked.getResult(counter_size);
159 
160     SmallVector<Value, 4> counter;
161     for (int i = 0; i < counter_size; ++i) {
162       counter.push_back(unpacked.getResult(i));
163     }
164 
165     // Set the increment to 256 * delta.
166     Type u64 = rewriter.getIntegerType(64, /*isSigned=*/false);
167     RankedTensorType u64_scalar = RankedTensorType::get({}, u64);
168     Value step_size = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 256));
169     Value increment =
170         rewriter.create<MulOp>(loc, u64_scalar, step_size, rng_op.delta());
171 
172     // Increment the counter.
173     SmallVector<Value, 4> pack_args;
174     RankedTensorType word_u64_type = RankedTensorType::get({}, u64);
175     Value zero_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 0));
176     Value one_u64 = rewriter.create<ConstOp>(loc, GetScalarOfType(u64, 1));
177     for (int i = 0; i < counter_size; ++i) {
178       Value word = counter[i];
179       Value word_u64 = rewriter.create<CastOp>(loc, word_u64_type, word);
180       Value new_word_u64 = rewriter.create<AddV2Op>(loc, word_u64, increment);
181       Value new_word = rewriter.create<CastOp>(loc, word_type, new_word_u64);
182       pack_args.push_back(new_word);
183 
184       Value overflow = rewriter.create<LessOp>(loc, new_word_u64, word_u64);
185       increment = rewriter.create<SelectV2Op>(loc, overflow, one_u64, zero_u64);
186     }
187 
188     // Save the new state value to the resource.
189     pack_args.push_back(key);
190     Value new_state = rewriter.create<PackOp>(loc, res_type, pack_args);
191     rewriter.create<AssignVariableOp>(loc, rng_op.resource(), new_state);
192 
193     // Pad the original state as necessary to fill the output shape.
194     int pad = tensorflow::RNG_MAX_COUNTER_SIZE - counter_size;
195     Type i64 = rewriter.getI64Type();
196     RankedTensorType paddings_ty = RankedTensorType::get({1, 2}, i64);
197     std::vector<int64_t> paddings_values = {0, pad};
198     Value paddings = rewriter.create<ConstOp>(
199         loc, DenseIntElementsAttr::get(paddings_ty, paddings_values));
200     Value output = rewriter.create<PadOp>(loc, op_type, state, paddings);
201 
202     rewriter.replaceOp(op, output);
203     return success();
204   }
205 };
206 
207 #include "tensorflow/compiler/mlir/tensorflow/transforms/generated_decompose_resource_ops.inc"
208 }  // namespace
209 
PopulateDecomposeResourceOpsPatterns(MLIRContext * context,OwningRewritePatternList * patterns)210 void PopulateDecomposeResourceOpsPatterns(MLIRContext *context,
211                                           OwningRewritePatternList *patterns) {
212   patterns->insert<DecomposeRngReadAndSkipOp>(context);
213   populateWithGenerated(context, *patterns);
214 }
215 
216 }  // namespace TF
217 }  // namespace mlir
218