1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_input_output_alias_config.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_dce.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 
34 namespace xla {
35 namespace {
36 class HloInputOutputAliasConfigTest : public HloTestBase {
37  protected:
expect_aliased(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index,const HloInputOutputAliasConfig & config)38   void expect_aliased(const ShapeIndex& output_index, int64 param_number,
39                       const ShapeIndex& param_index,
40                       const HloInputOutputAliasConfig& config) {
41     absl::optional<ShapeIndex> aliased_output =
42         config.GetAliasedOutput(param_number, param_index);
43 
44     EXPECT_TRUE(aliased_output);
45     EXPECT_EQ(aliased_output.value(), output_index);
46 
47     absl::optional<HloInputOutputAliasConfig::Alias> aliased_param =
48         config.GetAliasedParameter(output_index);
49 
50     EXPECT_TRUE(aliased_param);
51     EXPECT_EQ(aliased_param->parameter_number, param_number);
52     EXPECT_EQ(aliased_param->parameter_index, param_index);
53   }
54 
expect_not_aliased(const ShapeIndex & output_index,int64 param_number,const ShapeIndex & param_index,const HloInputOutputAliasConfig & config)55   void expect_not_aliased(const ShapeIndex& output_index, int64 param_number,
56                           const ShapeIndex& param_index,
57                           const HloInputOutputAliasConfig& config) {
58     absl::optional<ShapeIndex> aliased_output =
59         config.GetAliasedOutput(param_number, param_index);
60 
61     EXPECT_FALSE(aliased_output && aliased_output == output_index);
62 
63     absl::optional<HloInputOutputAliasConfig::Alias> aliased_param =
64         config.GetAliasedParameter(output_index);
65 
66     EXPECT_FALSE(aliased_param &&
67                  aliased_param->parameter_number == param_number &&
68                  aliased_param->parameter_index == param_index);
69   }
70 };
71 
// Aliases output {0} with (scalar) parameter 1. The alias must be visible from
// both sides, while neither output {1}/param 1 nor output {0}/param 0 may be
// reported as aliased.
TEST_F(HloInputOutputAliasConfigTest, SimpleAliasing) {
  constexpr char kHloText[] = R"(
HloModule TEST

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));

  // The alias config is keyed on the entry computation's result shape.
  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  HloInputOutputAliasConfig alias_config(result_shape);

  using AliasKind = HloInputOutputAliasConfig::AliasKind;
  TF_ASSERT_OK(alias_config.SetUpAlias(/*output_index=*/{0},
                                       /*param_number=*/1,
                                       /*param_index=*/{},
                                       /*kind=*/AliasKind::kUserAlias));

  // The configured pair is aliased in both directions.
  expect_aliased(/*output_index=*/{0}, /*param_number=*/1,
                 /*param_index=*/{}, alias_config);

  // Output {1} was never paired with parameter 1.
  expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
                     /*param_index=*/{}, alias_config);

  // Parameter 0 was never paired with output {0}.
  expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
                     /*param_index=*/{}, alias_config);
}
102 
// Aliases each element of the tuple-shaped parameter 0 with the matching
// element of the tuple result (output {i} <-> param 0 at {i}), then checks
// those aliases and two pairings that must not be aliased.
TEST_F(HloInputOutputAliasConfigTest, SimpleAliasingWithTupleInput) {
  constexpr char kHloText[] = R"(
HloModule TEST

ENTRY main {
  param = (f32[], f32[]) parameter(0)
  gte1 = f32[] get-tuple-element(%param), index=0
  gte2 = f32[] get-tuple-element(%param), index=1
  ROOT root = (f32[], f32[]) tuple(%gte1, %gte2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));

  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  HloInputOutputAliasConfig alias_config(result_shape);

  using AliasKind = HloInputOutputAliasConfig::AliasKind;

  // Element-wise aliasing: output {0} <-> param 0 {0}, output {1} <-> param 0 {1}.
  for (int64 element = 0; element < 2; ++element) {
    TF_ASSERT_OK(alias_config.SetUpAlias(/*output_index=*/{element},
                                         /*param_number=*/0,
                                         /*param_index=*/{element},
                                         /*kind=*/AliasKind::kUserAlias));
  }

  for (int64 element = 0; element < 2; ++element) {
    expect_aliased(/*output_index=*/{element}, /*param_number=*/0,
                   /*param_index=*/{element}, alias_config);
  }

  // There is no parameter 1, so nothing can be aliased to it.
  expect_not_aliased(/*output_index=*/{1}, /*param_number=*/1,
                     /*param_index=*/{}, alias_config);

  // Only tuple elements were aliased, never the whole parameter 0.
  expect_not_aliased(/*output_index=*/{0}, /*param_number=*/0,
                     /*param_index=*/{}, alias_config);
}
142 
// Aliasing the same parameter buffer to two different outputs is accepted by
// SetUpAlias, but must be rejected by Verify.
TEST_F(HloInputOutputAliasConfigTest, InputDoNotAliasTwice) {
  constexpr char kHloText[] = R"(
HloModule TEST

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));

  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  HloInputOutputAliasConfig alias_config(result_shape);

  using AliasKind = HloInputOutputAliasConfig::AliasKind;

  // Both outputs {0} and {1} claim parameter 0.
  for (int64 output = 0; output < 2; ++output) {
    TF_ASSERT_OK(alias_config.SetUpAlias(/*output_index=*/{output},
                                         /*param_number=*/0,
                                         /*param_index=*/{},
                                         /*kind=*/AliasKind::kUserAlias));
  }

  const auto byte_size_of = [](const Shape& shape) {
    return ShapeUtil::ByteSizeOf(shape);
  };
  ASSERT_IS_NOT_OK(alias_config.Verify(*hlo_module, byte_size_of));
}
173 
// Output {1} (f32[4096]) is aliased with parameter 0 (a scalar f32). Verify
// must reject this because the two buffers differ in byte size.
TEST_F(HloInputOutputAliasConfigTest, SizesMustMatch) {
  constexpr char kHloText[] = R"(
HloModule TEST

ENTRY main {
  a = f32[] parameter(0)
  b = f32[4096] parameter(1)
  ROOT root = (f32[], f32[4096]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));

  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  HloInputOutputAliasConfig alias_config(result_shape);

  using AliasKind = HloInputOutputAliasConfig::AliasKind;
  TF_ASSERT_OK(alias_config.SetUpAlias(/*output_index=*/{1},
                                       /*param_number=*/0,
                                       /*param_index=*/{},
                                       /*kind=*/AliasKind::kUserAlias));

  const auto byte_size_of = [](const Shape& shape) {
    return ShapeUtil::ByteSizeOf(shape);
  };
  ASSERT_IS_NOT_OK(alias_config.Verify(*hlo_module, byte_size_of));
}
199 
// An output index may be aliased with at most one parameter: a second
// SetUpAlias call for output {0} must fail immediately.
TEST_F(HloInputOutputAliasConfigTest, OutputDoNotAliasTwice) {
  constexpr char kHloText[] = R"(
HloModule TEST

ENTRY main {
  a = f32[] parameter(0)
  b = f32[] parameter(1)
  ROOT root = (f32[], f32[]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> hlo_module,
                          ParseHloString(kHloText));

  const Shape& result_shape =
      hlo_module->entry_computation()->root_instruction()->shape();
  HloInputOutputAliasConfig alias_config(result_shape);

  // Aliases output {0} with the whole of parameter `param_number`.
  auto alias_first_output_to = [&alias_config](int64 param_number) {
    return alias_config.SetUpAlias(
        /*output_index=*/{0}, /*param_number=*/param_number,
        /*param_index=*/{},
        /*kind=*/HloInputOutputAliasConfig::AliasKind::kUserAlias);
  };

  TF_ASSERT_OK(alias_first_output_to(0));
  ASSERT_IS_NOT_OK(alias_first_output_to(1));
}
226 }  // namespace
227 }  // namespace xla
228