1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
17 #include "tensorflow/compiler/xla/service/hlo_module.h"
18 
19 namespace xla {
20 
OutputHasAlias(const ShapeIndex & output_index) const21 bool HloInputOutputAliasConfig::OutputHasAlias(
22     const ShapeIndex& output_index) const {
23   return alias_.element(output_index).has_value();
24 }
25 
SetUpAlias(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index,AliasKind kind)26 Status HloInputOutputAliasConfig::SetUpAlias(const ShapeIndex& output_index,
27                                              int64 param_number,
28                                              const ShapeIndex& param_index,
29                                              AliasKind kind) {
30   TF_RET_CHECK(kind == AliasKind::kUserAlias || kind == AliasKind::kSystemAlias)
31       << kind;
32   TF_RET_CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index))
33       << absl::StrCat("Tring to set up alias at ", output_index.ToString(),
34                       " which is an invalid index for shape ",
35                       ShapeUtil::HumanString(alias_.shape()));
36   TF_RET_CHECK(param_number >= 0) << param_number;
37   TF_RET_CHECK(!OutputHasAlias(output_index))
38       << "Output index " << output_index << " already has an alias setup";
39   // Output can't be aliased with multiple parameters.
40   TF_RET_CHECK(!alias_.element(output_index)) << absl::StrFormat(
41       "Trying to set up output alias for param %lld at %s but failed: output "
42       "index %s is already aliased with param %lld at %s",
43       param_number, param_index.ToString(), output_index.ToString(),
44       alias_.element(output_index)->parameter_number,
45       alias_.element(output_index)->parameter_index.ToString());
46   (*alias_.mutable_element(output_index)) =
47       Alias(kind, param_number, param_index);
48   VLOG(4) << "Set up alias between output index " << output_index.ToString()
49           << " and parameter " << param_index << " at index "
50           << param_index.ToString();
51   return Status::OK();
52 }
53 
ToProto() const54 HloInputOutputAliasProto HloInputOutputAliasConfig::ToProto() const {
55   HloInputOutputAliasProto result;
56   alias_.ForEachElement(
57       [&](const ShapeIndex& index, const absl::optional<Alias>& data) {
58         if (data) {
59           HloInputOutputAliasProto::AliasEntryProto entry;
60           switch (data->kind) {
61             case AliasKind::kUserAlias:
62               entry.set_kind(HloInputOutputAliasProto::USER_ALIAS);
63               break;
64             case AliasKind::kSystemAlias:
65               entry.set_kind(HloInputOutputAliasProto::SYSTEM_ALIAS);
66               break;
67             default:
68               LOG(FATAL) << "Unknown alias kind " << data->kind;
69           }
70           for (int64 i : index) {
71             entry.add_output_shape_index(i);
72           }
73           entry.set_parameter_number(data->parameter_number);
74           for (int64 i : data->parameter_index) {
75             entry.add_parameter_shape_index(i);
76           }
77           result.add_entries()->Swap(&entry);
78         }
79       });
80   return result;
81 }
82 
CreateFromProto(const Shape & output_shape,const HloInputOutputAliasProto & proto)83 StatusOr<HloInputOutputAliasConfig> HloInputOutputAliasConfig::CreateFromProto(
84     const Shape& output_shape, const HloInputOutputAliasProto& proto) {
85   HloInputOutputAliasConfig result(output_shape);
86   for (const HloInputOutputAliasProto::AliasEntryProto& entry :
87        proto.entries()) {
88     ShapeIndex output_index(entry.output_shape_index().begin(),
89                             entry.output_shape_index().end());
90     int64 param_number = entry.parameter_number();
91     ShapeIndex param_index(entry.parameter_shape_index().begin(),
92                            entry.parameter_shape_index().end());
93     // Handle backward compatibility with existing protos, which only knew of
94     // system aliases.
95     AliasKind kind = AliasKind::kSystemAlias;
96     if (entry.kind() == HloInputOutputAliasProto::USER_ALIAS) {
97       kind = AliasKind::kUserAlias;
98     }
99     TF_RETURN_IF_ERROR(
100         result.SetUpAlias(output_index, param_number, param_index, kind));
101   }
102   return result;
103 }
104 
ToString() const105 string HloInputOutputAliasConfig::ToString() const {
106   std::vector<string> pieces;
107   pieces.push_back("HloInputOutputAliasConfig");
108 
109   ForEachAlias([&](const ShapeIndex& output_index, const Alias& alias) {
110     const char* kind = alias.kind == AliasKind::kUserAlias ? "USER" : "SYSTEM";
111     pieces.push_back(absl::StrFormat(
112         "  OutputIndex %s is aliased (kind=%s) with parameter %lld at %s:",
113         output_index.ToString(), kind, alias.parameter_number,
114         alias.parameter_index.ToString()));
115   });
116   return absl::StrJoin(pieces, "\n");
117 }
118 
119 HloInputOutputAliasConfig::AliasKind
ParameterAliasKind(int64 param_number,const ShapeIndex & param_index) const120 HloInputOutputAliasConfig::ParameterAliasKind(
121     int64 param_number, const ShapeIndex& param_index) const {
122   AliasKind kind = AliasKind::kNoAlias;
123   alias_.ForEachElement(
124       [&](const xla::ShapeIndex&, absl::optional<Alias> alias) {
125         if (alias && alias->parameter_number == param_number &&
126             alias->parameter_index == param_index) {
127           kind = alias->kind;
128         }
129       });
130   return kind;
131 }
132 
GetAliasedOutput(int64 param_number,const ShapeIndex & param_index) const133 absl::optional<ShapeIndex> HloInputOutputAliasConfig::GetAliasedOutput(
134     int64 param_number, const ShapeIndex& param_index) const {
135   absl::optional<ShapeIndex> output;
136   alias_.ForEachElement(
137       [&](const xla::ShapeIndex& output_index, absl::optional<Alias> alias) {
138         if (alias && alias->parameter_number == param_number &&
139             alias->parameter_index == param_index) {
140           output = output_index;
141         }
142       });
143   return output;
144 }
145 
146 absl::optional<HloInputOutputAliasConfig::Alias>
GetAliasedParameter(const ShapeIndex & output_index) const147 HloInputOutputAliasConfig::GetAliasedParameter(
148     const ShapeIndex& output_index) const {
149   CHECK(ShapeUtil::IndexIsValid(alias_.shape(), output_index));
150   return alias_.element(output_index);
151 }
152 
ForEachAlias(AliasFn fn) const153 void HloInputOutputAliasConfig::ForEachAlias(AliasFn fn) const {
154   alias_.ForEachElement(
155       [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
156         if (aliased) {
157           fn(output_index, *aliased);
158         }
159       });
160 }
161 
ForEachAliasWithStatus(AliasFnWithStatus fn) const162 Status HloInputOutputAliasConfig::ForEachAliasWithStatus(
163     AliasFnWithStatus fn) const {
164   return alias_.ForEachElementWithStatus(
165       [&](const ShapeIndex& output_index, absl::optional<Alias> aliased) {
166         if (aliased) {
167           TF_RETURN_IF_ERROR(fn(output_index, *aliased));
168         }
169         return Status::OK();
170       });
171 }
172 
Verify(const HloModule & module,std::function<int64 (const Shape &)> size_func) const173 Status HloInputOutputAliasConfig::Verify(
174     const HloModule& module,
175     std::function<int64(const Shape&)> size_func) const {
176   std::vector<ShapeTree<bool>> param_has_seen;
177   const HloComputation* entry = module.entry_computation();
178   for (int64 i = 0; i < entry->num_parameters(); ++i) {
179     HloInstruction* param = entry->parameter_instruction(i);
180     param_has_seen.emplace_back(param->shape());
181   }
182   return ForEachAliasWithStatus([&](const ShapeIndex& output_index,
183                                     const Alias& alias) -> Status {
184     const HloInstruction* root = entry->root_instruction();
185 
186     TF_RET_CHECK(0 <= alias.parameter_number);
187     TF_RET_CHECK(entry->num_parameters() > alias.parameter_number);
188     const Shape& param_shape =
189         entry->parameter_instruction(alias.parameter_number)->shape();
190     const Shape& output_shape = root->shape();
191     TF_RET_CHECK(ShapeUtil::IndexIsValid(param_shape, alias.parameter_index));
192     TF_RET_CHECK(ShapeUtil::IndexIsValid(output_shape, output_index));
193 
194     const Shape& param_subshape =
195         ShapeUtil::GetSubshape(param_shape, alias.parameter_index);
196     const Shape& output_subshape =
197         ShapeUtil::GetSubshape(output_shape, output_index);
198     TF_RET_CHECK(LayoutUtil::IsDenseArray(param_subshape));
199     TF_RET_CHECK(LayoutUtil::IsDenseArray(output_subshape));
200 
201     if (size_func(param_subshape) != size_func(output_subshape)) {
202       return InternalError(
203           "Expected aliased input %lld at index %s and output at index %s to "
204           "have the same size. Input sub-shape is %s with size %lld, output "
205           "sub-shape is %s with size %lld",
206           alias.parameter_number, alias.parameter_index.ToString(),
207           output_index.ToString(),
208           ShapeUtil::HumanStringWithLayout(param_subshape),
209           size_func(param_subshape),
210           ShapeUtil::HumanStringWithLayout(output_subshape),
211           size_func(output_subshape));
212     }
213 
214     // Check each alias.parameter_number and alias.parameter_index pair only
215     // show up once. No input can be aliased with output buffers.
216     TF_RET_CHECK(param_has_seen[alias.parameter_number].element(
217                      alias.parameter_index) == false);
218     *(param_has_seen[alias.parameter_number].mutable_element(
219         alias.parameter_index)) = true;
220     return Status::OK();
221   });
222 }
223 
operator <<(std::ostream & out,const HloInputOutputAliasConfig & config)224 std::ostream& operator<<(std::ostream& out,
225                          const HloInputOutputAliasConfig& config) {
226   out << config.ToString();
227   return out;
228 }
229 }  // namespace xla
230