1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo.pb.h"
19 #include "tensorflow/compiler/xla/service/hlo_module.h"
20 
21 namespace xla {
22 
OutputHasAlias(const ShapeIndex & output_index) const23 bool HloInputOutputAliasConfig::OutputHasAlias(
24     const ShapeIndex& output_index) const {
25   return alias_.element(output_index).has_value();
26 }
27 
SetUpAlias(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index,HloInputOutputAliasConfig::AliasKind must_alias)28 Status HloInputOutputAliasConfig::SetUpAlias(
29     const ShapeIndex& output_index, int64 param_number,
30     const ShapeIndex& param_index,
31     HloInputOutputAliasConfig::AliasKind must_alias) {
32   TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
33       << "Trying to set up alias at " << output_index.ToString()
34       << " which is an invalid index for shape "
35       << ShapeUtil::HumanString(alias_.shape());
36   TF_RET_CHECK(param_number >= 0) << param_number;
37   TF_RET_CHECK(!OutputHasAlias(output_index))
38       << "Output index " << output_index << " already has an alias setup";
39   // Output can't be aliased with multiple parameters.
40   TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat(
41       "Trying to set up output alias for param %lld at %s but failed: output "
42       "index %s is already aliased with param %lld at %s",
43       param_number, param_index.ToString(), output_index.ToString(),
44       alias_.element(output_index)->parameter_number,
45       alias_.element(output_index)->parameter_index.ToString());
46   (*alias_.mutable_element(output_index)) =
47       Alias(param_number, param_index, must_alias);
48   VLOG(4) << "Set up alias between output index " << output_index.ToString()
49           << " and parameter " << param_index << " at index "
50           << param_index.ToString();
51   return Status::OK();
52 }
53 
ToProto() const54 HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
55   HloInputOutputAliasProto result;
56   alias_.ForEachElement(
57       [&](const ShapeIndex& index, const absl::optional<Alias>& data) {
58         if (data) {
59           HloInputOutputAliasProto::AliasEntryProto entry;
60           for (int64 i : index) {
61             entry.add_output_shape_index(i);
62           }
63           entry.set_parameter_number(data->parameter_number);
64           for (int64 i : data->parameter_index) {
65             entry.add_parameter_shape_index(i);
66           }
67           if (data->must_alias()) {
68             entry.set_kind(Kind::MUST_ALIAS);
69           } else {
70             entry.set_kind(Kind::MAY_ALIAS);
71           }
72           result.add_entries()->Swap(&entry);
73         }
74       });
75   return result;
76 }
77 
CreateFromProto(Shape output_shape,const HloInputOutputAliasProto & proto)78 StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
79     Shape output_shape, const HloInputOutputAliasProto& proto) {
80   HloInputOutputAliasConfig result(std::move(output_shape));
81   for (const HloInputOutputAliasProto::AliasEntryProto& entry :
82        proto.entries()) {
83     ShapeIndex output_index(entry.output_shape_index().begin(),
84                             entry.output_shape_index().end());
85     int64 param_number = entry.parameter_number();
86     ShapeIndex param_index(entry.parameter_shape_index().begin(),
87                            entry.parameter_shape_index().end());
88     AliasKind kind = entry.kind() == Kind::MAY_ALIAS ? kMayAlias : kMustAlias;
89     TF_RETURN_IF_ERROR(
90         result.SetUpAlias(output_index, param_number, param_index, kind));
91   }
92   return result;
93 }
94 
shape() const95 const Shape& HloInputOutputAliasConfig::shape() const { return alias_.shape(); }
96 
ToString() const97 string HloInputOutputAliasConfig::ToString() const {
98   std::vector<string> pieces;
99   pieces.push_back("HloInputOutputAliasConfig");
100   pieces.push_back(
101       absl::StrFormat("  Output shape: %s", alias_.shape().ToString()));
102 
103   ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
104     pieces.push_back(absl::StrFormat(
105         "  OutputIndex %s is %saliased with parameter %lld at %s:",
106         output_index.ToString(), alias.kind == kMustAlias ? "must-" : "may-",
107         alias.parameter_number, alias.parameter_index.ToString()));
108   });
109   return absl::StrJoin(pieces, "\n");
110 }
111 
ToShortString() const112 string HloInputOutputAliasConfig::ToShortString() const {
113   std::vector<string> pieces;
114   for (const auto& p : alias_) {
115     const ShapeIndex& index = p.first;
116     if (absl::optional<Alias> alias = p.second) {
117       pieces.push_back(
118           absl::StrFormat("%s: %s", index.ToString(), alias->ToString()));
119     }
120   }
121   return absl::StrJoin(pieces, ", ");
122 }
123 
ParameterMustAlias(int64 param_number,const ShapeIndex & param_index) const124 bool HloInputOutputAliasConfig::ParameterMustAlias(
125     int64 param_number, const ShapeIndex& param_index) const {
126   bool result = false;
127   alias_.ForEachElement(
128       [&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
129         if (alias && alias->parameter_number == param_number &&
130             alias->parameter_index == param_index && alias->must_alias()) {
131           result = true;
132         }
133       });
134   return result;
135 }
136 
GetAliasedOutput(int64 param_number,const ShapeIndex & param_index) const137 absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
138     int64 param_number, const ShapeIndex& param_index) const {
139   absl::optional<ShapeIndex> output;
140   alias_.ForEachElement(
141       [&](const xla::ShapeIndex& output_index, absl::optional<Alias> alias) {
142         if (alias && alias->parameter_number == param_number &&
143             alias->parameter_index == param_index) {
144           output = output_index;
145         }
146       });
147   return output;
148 }
149 
150 absl::optional<HloInputOutputAliasConfig::Alias>
GetAliasedParameter(const ShapeIndex & output_index) const151 HloInputOutputAliasConfig::GetAliasedParameter(
152     const ShapeIndex& output_index) const {
153   CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
154       << ToString() << " " << alias_.shape().ToString() << " " << output_index;
155   return alias_.element(output_index);
156 }
157 
ForEachAlias(AliasFn fn) const158 void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const {
159   alias_.ForEachElement(
160       [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
161         if (aliased) {
162           fn(output_index, *aliased);
163         }
164       });
165 }
166 
ForEachAliasWithStatus(AliasFnWithStatus fn) const167 Status HloInputOutputAliasConfig::ForEachAliasWithStatus(
168     AliasFnWithStatus fn) const {
169   return alias_.ForEachElementWithStatus(
170       [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
171         if (aliased) {
172           TF_RETURN_IF_ERROR(fn(output_index, *aliased));
173         }
174         return Status::OK();
175       });
176 }
177 
Verify(const HloModule & module,std::function<int64 (const Shape &)> size_func) const178 Status HloInputOutputAliasConfig::Verify(
179     const HloModule& module,
180     std::function<int64(const Shape&)> size_func) const {
181   std::vector<ShapeTree<bool>> param_has_seen;
182   const HloComputation* entry = module.entry_computation();
183   for (int64 i = 0; i < entry->num_parameters(); ++i) {
184     HloInstruction* param = entry->parameter_instruction(i);
185     param_has_seen.emplace_back(param->shape());
186   }
187   return ForEachAliasWithStatus([&](const ShapeIndex& output_index,
188                                     const Alias& alias) -> Status {
189     const HloInstruction* root = entry->root_instruction();
190 
191     TF_RET_CHECK(0 <= alias.parameter_number);
192     TF_RET_CHECK(entry->num_parameters() > alias.parameter_number);
193     const Shape& param_shape =
194         entry->parameter_instruction(alias.parameter_number)->shape();
195     const Shape& output_shape = root->shape();
196     TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, alias.parameter_index));
197     TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index));
198 
199     const Shape& param_subshape =
200         ShapeUtil::GetSubshape(param_shape, alias.parameter_index);
201     const Shape& output_subshape =
202         ShapeUtil::GetSubshape(output_shape, output_index);
203     TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape));
204     TF_RET_CHECK(LayoutUtil::IsDenseArray(output_subshape));
205 
206     if (size_func(param_subshape) != size_func(output_subshape)) {
207       return InternalError(
208           "Expected aliased input %lld at index %s and output at index %s to "
209           "have the same size. Input sub-shape is %s with size %lld, output "
210           "sub-shape is %s with size %lld",
211           alias.parameter_number, alias.parameter_index.ToString(),
212           output_index.ToString(),
213           ShapeUtil::HumanStringWithLayout(param_subshape),
214           size_func(param_subshape),
215           ShapeUtil::HumanStringWithLayout(output_subshape),
216           size_func(output_subshape));
217     }
218 
219     // Check each alias.parameter_number and alias.parameter_index pair only
220     // show up once. No input can be aliased with output buffers.
221     TF_RET_CHECK(param_has_seen[alias.parameter_number].element(
222                      alias.parameter_index) == false);
223     *(param_has_seen[alias.parameter_number].mutable_element(
224         alias.parameter_index)) = true;
225     return Status::OK();
226   });
227 }
228 
operator <<(std::ostream & out,const HloInputOutputAliasConfig & config)229 std::ostream& operator<<(std::ostream& out,
230                          const HloInputOutputAliasConfig& config) {
231   out << config.ToString();
232   return out;
233 }
234 }  // namespace xla
235