1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_query.h"
17 
18 #include "tensorflow/compiler/xla/literal.h"
19 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
20 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 
24 namespace xla {
25 namespace hlo_query {
26 
IsConstantR0F32(HloInstruction * instruction,float * out)27 bool IsConstantR0F32(HloInstruction* instruction, float* out) {
28   if (instruction->opcode() == HloOpcode::kConstant &&
29       ShapeUtil::IsScalarWithElementType(instruction->shape(), F32)) {
30     *out = instruction->literal().Get<float>({});
31     return true;
32   }
33 
34   return false;
35 }
36 
AllOperandsAreParametersOrConstants(const HloInstruction & instruction)37 bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction) {
38   for (const auto& operand : instruction.operands()) {
39     if (operand->opcode() != HloOpcode::kParameter &&
40         operand->opcode() != HloOpcode::kConstant) {
41       return false;
42     }
43   }
44   return true;
45 }
46 
AllOperandsAreParameters(const HloInstruction & instruction)47 bool AllOperandsAreParameters(const HloInstruction& instruction) {
48   for (const auto& operand : instruction.operands()) {
49     if (operand->opcode() != HloOpcode::kParameter) {
50       return false;
51     }
52   }
53   return true;
54 }
55 
AllOperandsAreConstants(const HloInstruction & instruction)56 bool AllOperandsAreConstants(const HloInstruction& instruction) {
57   for (const auto& operand : instruction.operands()) {
58     if (operand->opcode() != HloOpcode::kConstant) {
59       return false;
60     }
61   }
62   return true;
63 }
64 
// Returns the first operand of `instruction` (in operand order) for which
// `matcher` returns true, or nullptr if no operand matches. `matcher` is not
// invoked on operands after the first match.
HloInstruction* GetMatchingOperand(
    const std::function<bool(const HloInstruction*)>& matcher,
    HloInstruction* instruction) {
  HloInstruction* found = nullptr;
  for (HloInstruction* candidate : instruction->operands()) {
    if (matcher(candidate)) {
      found = candidate;
      break;
    }
  }
  return found;
}
75 
MatchBinaryInstructionOperand(const std::function<bool (const HloInstruction *)> & matcher,HloInstruction * instruction,HloInstruction ** matching_operand,HloInstruction ** other_operand)76 bool MatchBinaryInstructionOperand(
77     const std::function<bool(const HloInstruction*)>& matcher,
78     HloInstruction* instruction, HloInstruction** matching_operand,
79     HloInstruction** other_operand) {
80   CHECK_EQ(instruction->operand_count(), 2);
81   if (matcher(instruction->operand(0))) {
82     *matching_operand = instruction->mutable_operand(0);
83     *other_operand = instruction->mutable_operand(1);
84     return true;
85   }
86   if (matcher(instruction->operand(1))) {
87     *matching_operand = instruction->mutable_operand(1);
88     *other_operand = instruction->mutable_operand(0);
89     return true;
90   }
91   return false;
92 }
93 
MatchBinaryInstructionOperandOpcode(HloOpcode opcode,HloInstruction * instruction,HloInstruction ** matching_operand,HloInstruction ** other_operand)94 bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
95                                          HloInstruction* instruction,
96                                          HloInstruction** matching_operand,
97                                          HloInstruction** other_operand) {
98   return MatchBinaryInstructionOperand(
99       [opcode](const HloInstruction* instruction) {
100         return instruction->opcode() == opcode;
101       },
102       instruction, matching_operand, other_operand);
103 }
104 
IsScalarConstant(const HloInstruction * instruction)105 bool IsScalarConstant(const HloInstruction* instruction) {
106   return instruction->IsConstant() && ShapeUtil::IsScalar(instruction->shape());
107 }
108 
ContainsInstrWithOpcode(const HloComputation * comp,const absl::flat_hash_set<HloOpcode> & opcodes)109 bool ContainsInstrWithOpcode(const HloComputation* comp,
110                              const absl::flat_hash_set<HloOpcode>& opcodes) {
111   for (const auto* instr : comp->instructions()) {
112     if (opcodes.count(instr->opcode())) {
113       return true;
114     }
115     for (const HloComputation* subcomp : instr->called_computations()) {
116       if (ContainsInstrWithOpcode(subcomp, opcodes)) {
117         return true;
118       }
119     }
120   }
121   return false;
122 }
123 
ContainsLayoutConstrainedAllReduce(const HloModule & module)124 bool ContainsLayoutConstrainedAllReduce(const HloModule& module) {
125   for (auto computation : module.computations()) {
126     for (auto hlo : computation->instructions()) {
127       if (hlo->opcode() == HloOpcode::kAllReduce &&
128           DynCast<HloAllReduceInstruction>(hlo)->constrain_layout()) {
129         return true;
130       }
131     }
132   }
133   return false;
134 }
135 
NextChannelId(const HloModule & module)136 int64 NextChannelId(const HloModule& module) {
137   int64 next_channel_id = 1;
138   for (const HloComputation* comp : module.computations()) {
139     for (const HloInstruction* hlo : comp->instructions()) {
140       const HloChannelInstruction* channel_instr =
141           DynCast<HloChannelInstruction>(hlo);
142       if (channel_instr && channel_instr->channel_id()) {
143         next_channel_id =
144             std::max(next_channel_id, *channel_instr->channel_id() + 1);
145       }
146     }
147   }
148   return next_channel_id;
149 }
150 
HasX64TransformedHostTransfer(const HloModule & module)151 bool HasX64TransformedHostTransfer(const HloModule& module) {
152   for (auto computation : module.computations()) {
153     for (auto hlo : computation->instructions()) {
154       if (hlo->opcode() == HloOpcode::kSend) {
155         auto send = DynCast<HloSendInstruction>(hlo);
156         if (send->is_host_transfer() && send->operand(0)->shape().IsTuple()) {
157           return true;
158         }
159       } else if (hlo->opcode() == HloOpcode::kRecv) {
160         auto recv = DynCast<HloRecvInstruction>(hlo);
161         if (recv->is_host_transfer() &&
162             recv->shape().tuple_shapes(0).IsTuple()) {
163           return true;
164         }
165       }
166     }
167   }
168   return false;
169 }
170 
171 }  // namespace hlo_query
172 }  // namespace xla
173