1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
17 #include "absl/memory/memory.h"
18 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
19 #include "tensorflow/compiler/xla/service/hlo_module.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/core/platform/test.h"
24
25 namespace xla {
26 namespace {
27
28 class HloCreationUtilsTest : public HloTestBase {
29 protected:
CreateModuleWithProgramShape(PrimitiveType primitive_type,absl::Span<const int64> input_shape_dims,absl::Span<const int64> output_shape_dims,HloInstruction ** param,HloComputation ** entry_computation)30 std::unique_ptr<VerifiedHloModule> CreateModuleWithProgramShape(
31 PrimitiveType primitive_type, absl::Span<const int64> input_shape_dims,
32 absl::Span<const int64> output_shape_dims, HloInstruction** param,
33 HloComputation** entry_computation) {
34 Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
35 Shape output_shape =
36 ShapeUtil::MakeShape(primitive_type, output_shape_dims);
37 auto module = CreateNewVerifiedModule("test");
38 *entry_computation = module->AddEntryComputation(
39 CreateComputationWithSignature({&input_shape}, output_shape, "entry")
40 .ValueOrDie());
41 *param = (*entry_computation)->parameter_instruction(0);
42 return module;
43 }
44 };
45
// CollapseFirstNDims(param, 1) on an s32[2] parameter: collapsing a single
// leading dimension keeps the s32[2] shape and the element values.
TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{2}, &param,
                                             &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed,
                          CollapseFirstNDims(param, 1));
  entry_computation->set_root_instruction(first_1_dims_collapsed);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
  // EXPECT_EQ (not CHECK_EQ) so a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4}));
}
64
// CollapseFirstNDims(param, 2) on an s32[2,3,2] parameter yields s32[6,2]:
// the two leading dimensions are merged in row-major order.
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(
      S32, /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, &param,
      &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_2_dims_collapsed,
                          CollapseFirstNDims(param, 2));
  entry_computation->set_root_instruction(first_2_dims_collapsed);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR3<int32>(
                                      {{{1, 2}, {3, 4}, {5, 6}},
                                       {{-1, -2}, {-3, -4}, {-5, -6}}})}));
  // EXPECT_EQ (not CHECK_EQ) so a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR2<int32>(
                {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
87
// PrependDegenerateDims(param, 1) on an s32[2] parameter yields s32[1,2].
TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{1, 2},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended,
                          PrependDegenerateDims(param, 1));
  entry_computation->set_root_instruction(with_1_degenerate_dim_prepended);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32>({9, 10})}));
  // EXPECT_EQ (not CHECK_EQ) so a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}}));
}
106
// PrependDegenerateDims(param, 2) on an s32[2] parameter yields s32[1,1,2].
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{1, 1, 2},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended,
                          PrependDegenerateDims(param, 2));
  entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32>({9, 10})}));
  // EXPECT_EQ (not CHECK_EQ) so a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}}));
}
125
// PrependDegenerateDims(param, 2) on a scalar s32[] parameter yields s32[1,1].
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{1, 1},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended,
                          PrependDegenerateDims(param, 2));
  entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32>(9)}));
  // EXPECT_EQ (not CHECK_EQ) so a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}}));
}
144
// ExpandFirstDimIntoNDims(param, {3, 1, 2}) on an s32[6] parameter yields
// s32[3,1,2], splitting the flat vector in row-major order.
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{6},
                                             /*output_shape_dims=*/{3, 1, 2},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_dim_expanded,
                          ExpandFirstDimIntoNDims(param, {3, 1, 2}));
  entry_computation->set_root_instruction(first_dim_expanded);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module,
                         {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
  // EXPECT_EQ (not CHECK_EQ) so a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
165
// PadVectorWithZeros on an s32[2] parameter with 3 leading and 1 trailing
// zero yields s32[6].
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{6}, &param,
                                             &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * zero_padded_param,
      PadVectorWithZeros(param, /*zeros_to_prepend=*/3, /*zeros_to_append=*/1));
  entry_computation->set_root_instruction(zero_padded_param);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
  // EXPECT_EQ (not CHECK_EQ) so a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
185
// BroadcastZeros(entry, S32, {2, 2}) produces an s32[2,2] of zeros. The
// entry's s32[] parameter is not used by the new root; it is supplied only
// because the entry signature declares one.
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{2, 2},
                                             &param, &entry_computation);

  // `entry_computation` is the computation AddEntryComputation returned, so
  // it is the module's entry computation.
  HloInstruction* zeros = BroadcastZeros(entry_computation, S32, {2, 2});
  entry_computation->set_root_instruction(zeros);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32>(0)}));
  // EXPECT_EQ (not CHECK_EQ) so a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
}
204
// BroadcastZeros(entry, F32, {2, 2}) produces an f32[2,2] of zeros. The
// entry's f32[] parameter is not used by the new root; it is supplied only
// because the entry signature declares one.
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{2, 2},
                                             &param, &entry_computation);

  // `entry_computation` is the computation AddEntryComputation returned, so
  // it is the module's entry computation.
  HloInstruction* zeros = BroadcastZeros(entry_computation, F32, {2, 2});
  entry_computation->set_root_instruction(zeros);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<float>(0.0f)}));
  // Exact comparison is fine here: every expected element is exactly 0.0f.
  // EXPECT_EQ (not CHECK_EQ) so a mismatch is reported as a test failure
  // instead of aborting the whole test binary.
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
224
225 } // namespace
226 } // namespace xla
227