1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_element_type_converter.h"
17 
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22 
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/literal.h"
25 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
26 #include "tensorflow/compiler/xla/service/hlo_computation.h"
27 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
28 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
29 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
30 #include "tensorflow/compiler/xla/service/hlo_query.h"
31 #include "tensorflow/compiler/xla/shape_util.h"
32 #include "tensorflow/compiler/xla/types.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 
35 namespace xla {
36 namespace {
37 
ToElementType(HloInstruction * hlo,PrimitiveType type)38 HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) {
39   if (hlo->shape().element_type() != type) {
40     Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type);
41     hlo = hlo->parent()->AddInstruction(
42         HloInstruction::CreateConvert(shape, hlo));
43   }
44   CHECK_EQ(hlo->shape().element_type(), type);
45   return hlo;
46 }
47 
HasOperandType(HloInstruction * hlo,PrimitiveType type)48 bool HasOperandType(HloInstruction* hlo, PrimitiveType type) {
49   for (HloInstruction* operand : hlo->operands()) {
50     if (operand->shape().element_type() == type) {
51       return true;
52     }
53   }
54   return false;
55 }
56 
57 // Finds out the Tuple Shape of the new instruction after converting the element
58 // type of the operands of the original instruction from `from_type` to
59 // `to_type`.
60 //
61 // This routine assumes the resulting `shape` of the original instruction is a
62 // non-nested tuple. This assumption is currently safe as only kTuple, kInfeed,
63 // kOutfeed, kCall, kCustomCall and kBatchNorm* HLO instructions can produce
64 // results with tuple shapes, and this routine is only called to convert the
65 // result shapes of kBatchNorm* HLO instructions, which are non-nested tuples.
GetConvertedTupleShape(const Shape & shape,PrimitiveType from_type,PrimitiveType to_type)66 Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type,
67                              PrimitiveType to_type) {
68   std::vector<Shape> new_tuple_subshapes;
69   for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
70     Shape subshape = ShapeUtil::GetTupleElementShape(shape, i);
71     CHECK(!subshape.IsTuple());
72     if (subshape.element_type() == from_type) {
73       subshape = ShapeUtil::ChangeElementType(subshape, to_type);
74     }
75     new_tuple_subshapes.push_back(subshape);
76   }
77   return ShapeUtil::MakeTupleShape(new_tuple_subshapes);
78 }
79 
80 // Converts the elements of the result of `hlo` to produce a new tuple with
81 // shape `to_shape`.
82 //
83 // This routine assumes `hlo` is an instruction that produces a non-nested Tuple
84 // as a result.
ConvertTupleElements(HloInstruction * hlo,const Shape & to_shape)85 HloInstruction* ConvertTupleElements(HloInstruction* hlo,
86                                      const Shape& to_shape) {
87   const Shape& shape = hlo->shape();
88   HloComputation* computation = hlo->parent();
89   std::vector<HloInstruction*> tuple_elements;
90   for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
91     const Shape& ele_shape = ShapeUtil::GetTupleElementShape(shape, i);
92     HloInstruction* element = computation->AddInstruction(
93         HloInstruction::CreateGetTupleElement(ele_shape, hlo, i));
94     const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i);
95     CHECK(!ele_shape.IsTuple());
96     if (ele_shape.element_type() != to_ele_shape.element_type()) {
97       element = computation->AddInstruction(
98           HloInstruction::CreateConvert(to_ele_shape, element));
99     }
100     tuple_elements.push_back(element);
101   }
102   return computation->AddInstruction(
103       HloInstruction::CreateTuple(tuple_elements));
104 }
105 
106 }  // namespace
107 
HloElementTypeConverter(PrimitiveType eliminate_type,PrimitiveType replace_with_type)108 HloElementTypeConverter::HloElementTypeConverter(
109     PrimitiveType eliminate_type, PrimitiveType replace_with_type)
110     : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {}
111 
112 // This routine converts the arithmetic operations in the given module that use
113 // eliminate_type_ to operations that use replace_with_type_.
Run(HloModule * module)114 StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
115   XLA_VLOG_LINES(
116       3, "HloElementTypeConverter::Run(), before:\n" + module->ToString());
117 
118   if (eliminate_type_ == replace_with_type_) {
119     return false;
120   }
121 
122   HloCloneContext context(module);
123   bool changed = false;
124   for (auto* computation : module->computations()) {
125     for (auto* hlo : computation->MakeInstructionPostOrder()) {
126       const auto opcode = hlo->opcode();
127       // These are ops where it does not make sense to convert them.
128       if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant ||
129           opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert ||
130           opcode == HloOpcode::kBitcastConvert ||
131           opcode == HloOpcode::kGetTupleElement ||
132           opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) {
133         continue;
134       }
135 
136       // We cannot change a CustomCall since we have no way of adjusting the
137       // called binary to expect the updated type.
138       if (opcode == HloOpcode::kCustomCall) {
139         continue;
140       }
141 
142       // These are ops with embedded computations where it suffices to convert
143       // the embedded computations instead of converting the ops themselves.
144       if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
145           opcode == HloOpcode::kAllReduce || opcode == HloOpcode::kFusion ||
146           opcode == HloOpcode::kMap || opcode == HloOpcode::kReduce ||
147           opcode == HloOpcode::kReduceWindow || opcode == HloOpcode::kScatter ||
148           opcode == HloOpcode::kSelectAndScatter ||
149           opcode == HloOpcode::kSort || opcode == HloOpcode::kConditional) {
150         continue;
151       }
152       TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString();
153 
154       bool nullary = hlo->operands().empty();
155       bool wrong_element_type = hlo->shape().element_type() == eliminate_type_;
156       bool should_eliminate_type = (nullary && wrong_element_type) ||
157                                    HasOperandType(hlo, eliminate_type_);
158       if (!should_eliminate_type) {
159         // If this CHECK fires, then this was an instruction that does not take
160         // the elimination type as an operand but it does return it. This pass
161         // does not have a feature to change the output type in that case, so
162         // instead of silently failing to eliminate the type, it fails loudly.
163         TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_);
164         continue;
165       }
166 
167       // Handle instructions that perform arithmetic operations and contain
168       // operands with eliminate_type_.
169       //
170       // First, convert the operands with eliminate_type_ to operands with
171       // replace_with_type_.
172       std::vector<HloInstruction*> new_operands;
173       for (HloInstruction* operand : hlo->operands()) {
174         if (operand->shape().element_type() == eliminate_type_) {
175           operand = ToElementType(operand, replace_with_type_);
176         }
177         new_operands.push_back(operand);
178       }
179 
180       // Then find out the result type of the new instruction with the same
181       // opcode but using the converted operands, create the new instruction,
182       // and convert the result of the new instruction back to match the result
183       // type of the original instruction.
184       HloInstruction* new_hlo;
185       if (hlo->shape().element_type() == eliminate_type_) {
186         Shape shape =
187             ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_);
188 
189         new_hlo = computation->AddInstruction(
190             hlo->CloneWithNewOperands(shape, new_operands, &context));
191         TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
192 
193         new_hlo = ToElementType(new_hlo, eliminate_type_);
194       } else if (hlo->shape().IsTuple()) {
195         Shape old_shape = hlo->shape();
196         Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_,
197                                                  replace_with_type_);
198 
199         new_hlo = computation->AddInstruction(
200             hlo->CloneWithNewOperands(new_shape, new_operands, &context));
201         TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
202 
203         // Convert the elements of the result of `new_hlo` to produce a new
204         // tuple with shape `old_shape`.
205         new_hlo = ConvertTupleElements(new_hlo, old_shape);
206       } else {
207         new_hlo = computation->AddInstruction(
208             hlo->CloneWithNewOperands(hlo->shape(), new_operands, &context));
209         TF_RETURN_IF_ERROR(new_hlo->CopyAllControlDepsFrom(hlo));
210       }
211 
212       TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo));
213       TF_RETURN_IF_ERROR(hlo->DropAllControlDeps());
214 
215       // NB!  We want to replace and remove side effecting instructions like Rng
216       // as well so we can't rely HloComputation::ReplaceInstruction to reliably
217       // remove the replaced instruction.
218       TF_RETURN_IF_ERROR(computation->RemoveInstruction(hlo));
219       changed = true;
220     }
221   }
222   XLA_VLOG_LINES(
223       2, "HloElementTypeConverter::Run(), after:\n" + module->ToString());
224   return changed;
225 }
226 
227 }  // namespace xla
228