1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
17
18 #include <memory>
19 #include <string>
20
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_dce.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33
34 namespace xla {
35 namespace {
36 class HloInputOutputAliasConfigTest : public HloTestBase {
37 protected:
expect_aliased(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index,const HloInputOutputAliasConfig & config)38 void expect_aliased(const ShapeIndex& output_index, int64 param_number,
39 const ShapeIndex& param_index,
40 const HloInputOutputAliasConfig& config) {
41 absl::optional<ShapeIndex> aliased_output =
42 config.GetAliasedOutput(param_number, param_index);
43
44 EXPECT_TRUE(aliased_output);
45 EXPECT_EQ(aliased_output.value(), output_index);
46
47 absl::optional<HloInputOutputAliasConfig::Alias> aliased_param =
48 config.GetAliasedParameter(output_index);
49
50 EXPECT_TRUE(aliased_param);
51 EXPECT_EQ(aliased_param->parameter_number, param_number);
52 EXPECT_EQ(aliased_param->parameter_index, param_index);
53 }
54
expect_not_aliased(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index,const HloInputOutputAliasConfig & config)55 void expect_not_aliased(const ShapeIndex& output_index, int64 param_number,
56 const ShapeIndex& param_index,
57 const HloInputOutputAliasConfig& config) {
58 absl::optional<ShapeIndex> aliased_output =
59 config.GetAliasedOutput(param_number, param_index);
60
61 EXPECT_FALSE(aliased_output && aliased_output == output_index);
62
63 absl::optional<HloInputOutputAliasConfig::Alias> aliased_param =
64 config.GetAliasedParameter(output_index);
65
66 EXPECT_FALSE(aliased_param &&
67 aliased_param->parameter_number == param_number &&
68 aliased_param->parameter_index == param_index);
69 }
70 };
71
// Aliases tuple element {0} of the result with scalar parameter 1, then checks
// the alias is visible from both lookup directions while unrelated
// (output, parameter) pairings remain unaliased.
TEST_F(HloInputOutputAliasConfigTest, SimpleAliasing) {
  const string hlo_text = R"(
HloModule TEST

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  const Shape& result_shape =
      module->entry_computation()->root_instruction()->shape();
  HloInputOutputAliasConfig alias_config(result_shape);

  using AliasKind = HloInputOutputAliasConfig::AliasKind;
  TF_ASSERT_OK(alias_config.SetUpAlias(/*output_index=*/{0},
                                       /*param_number=*/1,
                                       /*param_index=*/{},
                                       /*kind=*/AliasKind::kUserAlias));

  // The alias that was set up.
  expect_aliased(/*output_index=*/{0}, /*param_number=*/1,
                 /*param_index=*/{}, alias_config);

  // Same parameter, other output element.
  expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
                     /*param_index=*/{}, alias_config);

  // Same output element, other parameter.
  expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
                     /*param_index=*/{}, alias_config);
}
102
// Aliases each element of a tuple-shaped parameter with the matching element
// of the tuple-shaped result and checks that:
//  * each element pair is aliased in both lookup directions,
//  * the elements are not cross-wired (output {1} <-> param {0} and
//    output {0} <-> param {1} are NOT aliased),
//  * unrelated pairings (a parameter the module does not have, or the whole
//    tuple parameter) are not reported as aliases.
TEST_F(HloInputOutputAliasConfigTest, SimpleAliasingWithTupleInput) {
  const string module_str = R"(
HloModule TEST

ENTRY main {
  param = (f32[], f32[]) parameter(0)
  gte1 = f32[] get-tuple-element(%param), index=0
  gte2 = f32[] get-tuple-element(%param), index=1
  ROOT root = (f32[], f32[]) tuple(%gte1, %gte2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  HloInputOutputAliasConfig config(
      module->entry_computation()->root_instruction()->shape());

  TF_ASSERT_OK(config.SetUpAlias(
      /*output_index=*/{0}, /*param_number=*/0,
      /*param_index=*/{0},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  TF_ASSERT_OK(config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0,
      /*param_index=*/{1},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  expect_aliased(/*output_index=*/{0}, /*param_number=*/0,
                 /*param_index=*/{0}, config);

  expect_aliased(/*output_index=*/{1}, /*param_number=*/0,
                 /*param_index=*/{1}, config);

  // The two tuple elements must not be cross-wired to each other.
  expect_not_aliased(/*output_index=*/{1}, /*param_number=*/0,
                     /*param_index=*/{0}, config);

  expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
                     /*param_index=*/{1}, config);

  // The module has no parameter 1, so nothing can alias it.
  expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
                     /*param_index=*/{}, config);

  // Only the tuple elements were aliased, not the tuple parameter as a whole.
  expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
                     /*param_index=*/{}, config);
}
142
// SetUpAlias accepts aliasing parameter 0 with both output elements, but the
// resulting config must fail Verify because one parameter buffer would back
// two distinct outputs.
TEST_F(HloInputOutputAliasConfigTest, InputDoNotAliasTwice) {
  const string hlo_text = R"(
HloModule TEST

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  HloInputOutputAliasConfig alias_config(
      module->entry_computation()->root_instruction()->shape());

  // Point output elements {0} and then {1} at the same parameter buffer.
  for (int64 output_element : {0, 1}) {
    TF_ASSERT_OK(alias_config.SetUpAlias(
        /*output_index=*/{output_element}, /*param_number=*/0,
        /*param_index=*/{},
        /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));
  }

  const auto byte_size_of = [](const Shape& shape) {
    return ShapeUtil::ByteSizeOf(shape);
  };
  ASSERT_IS_NOT_OK(alias_config.Verify(*module, byte_size_of));
}
173
// Aliased buffers must agree in byte size: the scalar parameter 0 cannot back
// the f32[4096] output element {1}. SetUpAlias accepts the pairing, but Verify
// (using ShapeUtil::ByteSizeOf as the size function) must reject it.
TEST_F(HloInputOutputAliasConfigTest, SizesMustMatch) {
  const string hlo_text = R"(
HloModule TEST

ENTRY main {
  a = f32[] parameter(0)
  b = f32[4096] parameter(1)
  ROOT root = (f32[], f32[4096]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  const HloInstruction* root = module->entry_computation()->root_instruction();
  HloInputOutputAliasConfig alias_config(root->shape());

  // Output {1} is f32[4096]; parameter 0 is f32[].
  TF_ASSERT_OK(alias_config.SetUpAlias(
      /*output_index=*/{1}, /*param_number=*/0, /*param_index=*/{},
      /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias));

  const auto byte_size_of = [](const Shape& shape) {
    return ShapeUtil::ByteSizeOf(shape);
  };
  ASSERT_IS_NOT_OK(alias_config.Verify(*module, byte_size_of));
}
199
// An output element can be aliased with at most one parameter: once output {0}
// is bound to parameter 0, SetUpAlias itself must refuse to bind it to
// parameter 1 as well.
TEST_F(HloInputOutputAliasConfigTest, OutputDoNotAliasTwice) {
  const string hlo_text = R"(
HloModule TEST

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  HloInputOutputAliasConfig alias_config(
      module->entry_computation()->root_instruction()->shape());

  const ShapeIndex shared_output = {0};
  const auto kind = HloInputOutputAliasConfig::AliasKind::kUserAlias;

  // First binding of output {0} succeeds.
  TF_ASSERT_OK(alias_config.SetUpAlias(shared_output, /*param_number=*/0,
                                       /*param_index=*/{}, kind));

  // A second binding of the same output element is rejected.
  ASSERT_IS_NOT_OK(alias_config.SetUpAlias(shared_output, /*param_number=*/1,
                                           /*param_index=*/{}, kind));
}
226 } // namespace
227 } // namespace xla
228