/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_support.h"

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

namespace xla {

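// BF16 operands are supported by default only for opcodes that forward data
// without computing on it, and for kConvert when it converts from BF16.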
bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo,
                                          int64 operand_index) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    case HloOpcode::kConvert:
      CHECK_EQ(operand_index, 0);
      return hlo.operand(0)->shape().element_type() == BF16;
    default:
      break;
  }
  return false;
}

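// As with operands, BF16 output is supported by default only for
// data-forwarding opcodes, and for kConvert when it converts to BF16.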
bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    case HloOpcode::kConvert:
      return hlo.shape().element_type() == BF16;
    default:
      break;
  }
  return false;
}

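// Mixed BF16/F32 precision among operands and output is allowed by default
// only for opcodes that forward or repackage data, plus kConvert, whose
// purpose is to change precision.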
bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCustomCall:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    default:
      break;
  }
  return false;
}

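// An operand's precision is "effectively" the output precision when the op
// only moves or selects operand elements into the output, so narrowing that
// operand to BF16 is equivalent to narrowing the output.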
/* static */
bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
    const HloInstruction& hlo, int64 operand_index) {
  switch (hlo.opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllToAll:
    case HloOpcode::kBroadcast:
    case HloOpcode::kClamp:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kDomain:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPad:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
      return true;
    case HloOpcode::kBitcast:
      return hlo.shape().element_type() ==
             hlo.operand(0)->shape().element_type();
    case HloOpcode::kDynamicSlice:
      return operand_index == 0;
    case HloOpcode::kDynamicUpdateSlice:
      return operand_index == 0 || operand_index == 1;
    case HloOpcode::kGather:
      return operand_index == 0;
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect:
      return operand_index == 1 || operand_index == 2;
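    // For reductions, operand precision propagates to the output only if
    // every non-parameter instruction in the reduction computation also
    // propagates the precision of each of its operands.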
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow: {
      HloComputation* reduce_comp = hlo.called_computations()[0];
      for (HloInstruction* inst : reduce_comp->instructions()) {
        if (inst->opcode() == HloOpcode::kParameter) {
          continue;
        }
        for (int64 i = 0; i < inst->operand_count(); ++i) {
          if (!EffectiveOperandPrecisionIsOutputPrecision(*inst, i)) {
            return false;
          }
        }
      }
      return true;
    }
    default:
      break;
  }
  return false;
}

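// By default, assume no op implicitly narrows its operands to BF16 precision;
// backends whose ops compute at reduced precision can override this.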
bool BFloat16Support::EffectiveOperandPrecisionIsBF16(
    const HloInstruction& hlo, int64 operand_index) const {
  return false;
}

}  // namespace xla