/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"

namespace xla {

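// Default backend behavior: a BF16 operand is supported only by opcodes that
// pass data through without computing on its values (calls, control flow,
// tuple plumbing, domains), and by kConvert when its single operand is
// already BF16.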
bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo,
                                          int64 operand_index) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    case HloOpcode::kConvert:
      CHECK_EQ(operand_index, 0);
      return hlo.operand(0)->shape().element_type() == BF16;
    default:
      break;
  }
  return false;
}

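// Mirrors SupportsBF16Operand: a BF16 output is supported only by the same
// pass-through opcodes, and by kConvert when its result type is BF16.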
bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDomain:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    case HloOpcode::kConvert:
      return hlo.shape().element_type() == BF16;
    default:
      break;
  }
  return false;
}

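// By default, mixing BF16 and F32 among an instruction's operands and output
// is allowed only for opcodes that merely route data, plus kConvert, whose
// purpose is to change the element type.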
bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const {
  switch (hlo.opcode()) {
    case HloOpcode::kCall:
    case HloOpcode::kConditional:
    case HloOpcode::kConvert:
    case HloOpcode::kCustomCall:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kTuple:
    case HloOpcode::kWhile:
      return true;
    default:
      break;
  }
  return false;
}

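// Returns true if changing the precision of the given operand is equivalent
// to changing the precision of hlo's output: these opcodes pass the operand's
// values through, at most rearranging or selecting among them. For
// kReduce/kReduceWindow the property must also hold for every operand of
// every instruction in the attached reduction computation, which is checked
// recursively.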
/* static */
bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
    const HloInstruction& hlo, int64 operand_index) {
  switch (hlo.opcode()) {
    case HloOpcode::kAbs:
    case HloOpcode::kAllGather:
    case HloOpcode::kAllToAll:
    case HloOpcode::kBroadcast:
    case HloOpcode::kClamp:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kDomain:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kPad:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kSlice:
    case HloOpcode::kSort:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
      return true;
    case HloOpcode::kBitcast:
      return hlo.shape().element_type() ==
             hlo.operand(0)->shape().element_type();
    case HloOpcode::kDynamicSlice:
      return operand_index == 0;
    case HloOpcode::kDynamicUpdateSlice:
      return operand_index == 0 || operand_index == 1;
    case HloOpcode::kGather:
      return operand_index == 0;
    case HloOpcode::kSelect:
    case HloOpcode::kTupleSelect:
      return operand_index == 1 || operand_index == 2;
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow: {
      HloComputation* reduce_comp = hlo.called_computations()[0];
      for (HloInstruction* inst : reduce_comp->instructions()) {
        if (inst->opcode() == HloOpcode::kParameter) {
          continue;
        }
        for (int64 i = 0; i < inst->operand_count(); ++i) {
          if (!EffectiveOperandPrecisionIsOutputPrecision(*inst, i)) {
            return false;
          }
        }
      }
      return true;
    }
    default:
      break;
  }
  return false;
}

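// By default, no opcode implicitly reduces an operand to BF16 precision;
// subclasses for backends that internally lower certain operands to BF16 can
// override this.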
bool BFloat16Support::EffectiveOperandPrecisionIsBF16(
    const HloInstruction& hlo, int64 operand_index) const {
  return false;
}

}  // namespace xla
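
// Illustrative sketch, kept as a comment: assuming a backend whose hardware
// natively computes elementwise adds in BF16, it could subclass
// BFloat16Support and widen the default queries. The class name
// `MyBackendBF16Support` is hypothetical and not part of XLA.
//
//   class MyBackendBF16Support : public xla::BFloat16Support {
//    public:
//     bool SupportsBF16Operand(const xla::HloInstruction& hlo,
//                              xla::int64 operand_index) const override {
//       // Accept BF16 inputs to kAdd; defer to the defaults otherwise.
//       return hlo.opcode() == xla::HloOpcode::kAdd ||
//              BFloat16Support::SupportsBF16Operand(hlo, operand_index);
//     }
//     bool SupportsBF16Output(const xla::HloInstruction& hlo) const override {
//       // Accept BF16 outputs from kAdd; defer to the defaults otherwise.
//       return hlo.opcode() == xla::HloOpcode::kAdd ||
//              BFloat16Support::SupportsBF16Output(hlo);
//     }
//   };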