/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/rng_bit_generator_expander.h"
17 
18 #include "tensorflow/compiler/xla/client/lib/prng.h"
19 #include "tensorflow/compiler/xla/client/xla_builder.h"
20 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
21 #include "tensorflow/compiler/xla/service/hlo_computation.h"
22 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
23 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
24 #include "tensorflow/compiler/xla/service/hlo_module.h"
25 #include "tensorflow/compiler/xla/shape.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/compiler/xla/util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/stream_executor/lib/statusor.h"
31 
32 namespace xla {
33 namespace {
34 
GetPhiloxStateOp(XlaOp input_state,const Shape & state_shape)35 XlaOp GetPhiloxStateOp(XlaOp input_state, const Shape& state_shape) {
36   if (state_shape.dimensions(0) >= 3) {
37     return Slice(input_state, {1}, {3}, {1});
38   }
39   return Rev(input_state, {0});
40 }
41 
GetPhiloxOutputStateOp(XlaOp output_state,const Shape & state_shape)42 XlaOp GetPhiloxOutputStateOp(XlaOp output_state, const Shape& state_shape) {
43   if (state_shape.dimensions(0) < 3) {
44     output_state = Slice(output_state, {0}, {1}, {1});
45   }
46   return output_state;
47 }
48 
49 }  // namespace
50 
InstructionMatchesPattern(HloInstruction * instruction)51 bool RngBitGeneratorExpander::InstructionMatchesPattern(
52     HloInstruction* instruction) {
53   return instruction->opcode() == HloOpcode::kRngBitGenerator;
54 }
55 
GetGeneratorComputation(const Shape & data_shape,const Shape & state_shape,RandomAlgorithm algorithm,HloModule * module)56 StatusOr<HloComputation*> RngBitGeneratorExpander::GetGeneratorComputation(
57     const Shape& data_shape, const Shape& state_shape,
58     RandomAlgorithm algorithm, HloModule* module) {
59   RngGeneratorKey cache_key{data_shape, state_shape, algorithm, module};
60   auto it = computation_cache_.find(cache_key);
61   if (it != computation_cache_.end()) {
62     return it->second;
63   }
64 
65   XlaBuilder builder("rng");
66   XlaOp state_param = Parameter(&builder, 0, state_shape, "state");
67   XlaOp key_op = Reshape(Slice(state_param, {0}, {1}, {1}), {});
68   RngOutput output;
69   switch (algorithm) {
70     case RandomAlgorithm::RNG_THREE_FRY:
71       output = ThreeFryBitGenerator(key_op, Slice(state_param, {1}, {2}, {1}),
72                                     data_shape);
73       break;
74     case RandomAlgorithm::RNG_PHILOX:
75       output = PhiloxBitGenerator(
76           key_op, GetPhiloxStateOp(state_param, state_shape), data_shape);
77       output.state = GetPhiloxOutputStateOp(output.state, state_shape);
78       break;
79     default:
80       return Unimplemented("Unsupported random algorthm: %s",
81                            RandomAlgorithm_Name(algorithm));
82   }
83 
84   XlaOp final_state =
85       ConcatInDim(&builder, {Reshape(key_op, {1}), output.state}, 0);
86   Tuple(&builder, {final_state, output.value});
87   TF_ASSIGN_OR_RETURN(XlaComputation xla_computation, builder.Build());
88 
89   TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
90                       xla_computation.GetProgramShape());
91   HloModuleConfig config(program_shape);
92   TF_ASSIGN_OR_RETURN(auto new_module, HloModule::CreateFromProto(
93                                            xla_computation.proto(), config));
94   HloCloneContext context(module);
95   HloComputation* new_computation =
96       module->DeepCloneComputation(new_module->entry_computation(), &context);
97   computation_cache_.emplace(cache_key, new_computation);
98   return new_computation;
99 }
100 
ExpandInstruction(HloInstruction * hlo)101 StatusOr<HloInstruction*> RngBitGeneratorExpander::ExpandInstruction(
102     HloInstruction* hlo) {
103   HloRngBitGeneratorInstruction* rng = Cast<HloRngBitGeneratorInstruction>(hlo);
104   RandomAlgorithm algorithm = rng->algorithm();
105   if (algorithm == RandomAlgorithm::RNG_DEFAULT) {
106     algorithm = default_algorithm_;
107   }
108 
109   HloModule* module = hlo->parent()->parent();
110   const Shape& data_shape = rng->shape().tuple_shapes(1);
111   const Shape& state_shape = rng->operand(0)->shape();
112   TF_ASSIGN_OR_RETURN(
113       HloComputation * generator_computation,
114       GetGeneratorComputation(data_shape, state_shape, algorithm, module));
115   return hlo->parent()->AddInstruction(HloInstruction::CreateCall(
116       ShapeUtil::MakeTupleShape({state_shape, data_shape}),
117       {hlo->mutable_operand(0)}, generator_computation));
118 }
119 
120 }  // namespace xla
121