1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_
18 
19 #include "absl/container/flat_hash_set.h"
20 #include "tensorflow/compiler/xla/service/hlo_computation.h"
21 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
22 #include "tensorflow/compiler/xla/service/hlo_module.h"
23 
24 namespace xla {
25 
26 // Helper interface for making queries about the HLO IR.
27 namespace hlo_query {
28 
29 // Returns true if `instruction` is a constant rank-0 (scalar) float32 and, in
30 // that case, stores the constant's value in *out.
31 // Precondition: out != nullptr
32 bool IsConstantR0F32(HloInstruction* instruction, float* out);
33 
34 // Returns whether every operand of the instruction is either a parameter or a
35 // constant.
36 bool AllOperandsAreParametersOrConstants(const HloInstruction& instruction);
37 
38 // Returns whether every operand of the instruction is a parameter.
39 bool AllOperandsAreParameters(const HloInstruction& instruction);
40 
41 // Returns whether every operand of the instruction is a constant.
42 bool AllOperandsAreConstants(const HloInstruction& instruction);
43 
44 // Returns whether `instruction` is a scalar constant.
45 bool IsScalarConstant(const HloInstruction* instruction);
46 
47 // Returns whether the given computation contains an instruction whose opcode
48 // is in `opcodes`. Both comp's own instructions and the instructions of any
49 // computations nested within it are checked.
50 bool ContainsInstrWithOpcode(const HloComputation* comp,
51                              const absl::flat_hash_set<HloOpcode>& opcodes);
52 
53 // Returns an operand of `instruction` that satisfies `matcher`. If multiple
54 // operands match, the first matching operand is returned. If no operand
55 // matches, nullptr is returned.
56 HloInstruction* GetMatchingOperand(
57     const std::function<bool(const HloInstruction*)>& matcher,
58     HloInstruction* instruction);
59 
60 // Returns whether a binary instruction has an operand satisfying `matcher`.
61 // On a match, sets *matching_operand to the matching operand and
62 // *other_operand to the other operand. Note: in the case where both operands
63 // match, the first operand of the instruction is returned.
64 bool MatchBinaryInstructionOperand(
65     const std::function<bool(const HloInstruction*)>& matcher,
66     HloInstruction* instruction, HloInstruction** matching_operand,
67     HloInstruction** other_operand);
68 
69 // Returns whether a binary instruction has an operand with a given opcode.
70 // This is a special case of MatchBinaryInstructionOperand.
71 bool MatchBinaryInstructionOperandOpcode(HloOpcode opcode,
72                                          HloInstruction* instruction,
73                                          HloInstruction** matching_operand,
74                                          HloInstruction** other_operand);
75 
76 // Returns whether the module contains any all-reduce instruction with a
77 // constrained layout.
78 bool ContainsLayoutConstrainedAllReduce(const HloModule& module);
79 
80 // Returns the next channel id that is available for use in the given module
81 // (for HloChannelInstructions).
82 int64 NextChannelId(const HloModule& module);
83 
84 // Returns whether the module contains a host send/recv with an X64 data type.
85 // This function is called after X64Rewriter, so X64 host transfers have
86 // already been rewritten into tuple-shaped transfers.
87 bool HasX64TransformedHostTransfer(const HloModule& module);
88 
89 }  // namespace hlo_query
90 }  // namespace xla
91 
92 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_QUERY_H_
93