1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
17 #include "absl/memory/memory.h"
18 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
19 #include "tensorflow/compiler/xla/service/hlo_module.h"
20 #include "tensorflow/compiler/xla/shape_util.h"
21 #include "tensorflow/compiler/xla/test.h"
22 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
23 #include "tensorflow/core/platform/test.h"
24 
25 namespace xla {
26 namespace {
27 
28 class HloCreationUtilsTest : public HloTestBase {
29  protected:
CreateModuleWithProgramShape(PrimitiveType primitive_type,absl::Span<const int64> input_shape_dims,absl::Span<const int64> output_shape_dims,HloInstruction ** param,HloComputation ** entry_computation)30   std::unique_ptr<VerifiedHloModule> CreateModuleWithProgramShape(
31       PrimitiveType primitive_type, absl::Span<const int64> input_shape_dims,
32       absl::Span<const int64> output_shape_dims, HloInstruction** param,
33       HloComputation** entry_computation) {
34     Shape input_shape = ShapeUtil::MakeShape(primitive_type, input_shape_dims);
35     Shape output_shape =
36         ShapeUtil::MakeShape(primitive_type, output_shape_dims);
37     auto module = CreateNewVerifiedModule("test");
38     *entry_computation = module->AddEntryComputation(
39         CreateComputationWithSignature({&input_shape}, output_shape, "entry")
40             .ValueOrDie());
41     *param = (*entry_computation)->parameter_instruction(0);
42     return module;
43   }
44 };
45 
// Collapsing only the leading dimension of an s32[2] leaves it as s32[2].
TEST_F(HloCreationUtilsTest, CollapseFirst1Dim) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{2}, &param,
                                             &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_1_dims_collapsed,
                          CollapseFirstNDims(param, 1));
  entry_computation->set_root_instruction(first_1_dims_collapsed);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
  // EXPECT_EQ reports a test failure; CHECK_EQ would abort the whole binary.
  EXPECT_EQ(result_literal, LiteralUtil::CreateR1<int32>({3, 4}));
}
64 
// Collapsing the first two dimensions of s32[2,3,2] yields s32[6,2] with the
// elements kept in their original (row-major) order.
TEST_F(HloCreationUtilsTest, CollapseFirst2Dims) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(
      S32, /*input_shape_dims=*/{2, 3, 2}, /*output_shape_dims=*/{6, 2}, &param,
      &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_2_dims_collapsed,
                          CollapseFirstNDims(param, 2));
  entry_computation->set_root_instruction(first_2_dims_collapsed);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR3<int32>(
                                      {{{1, 2}, {3, 4}, {5, 6}},
                                       {{-1, -2}, {-3, -4}, {-5, -6}}})}));
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR2<int32>(
                {{1, 2}, {3, 4}, {5, 6}, {-1, -2}, {-3, -4}, {-5, -6}}));
}
87 
// Prepending one size-1 dimension turns s32[2] into s32[1,2].
TEST_F(HloCreationUtilsTest, Prepend1DegenerateDim) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{1, 2},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_1_degenerate_dim_prepended,
                          PrependDegenerateDims(param, 1));
  entry_computation->set_root_instruction(with_1_degenerate_dim_prepended);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32>({9, 10})}));
  EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9, 10}}));
}
106 
// Prepending two size-1 dimensions turns s32[2] into s32[1,1,2].
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDims) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{1, 1, 2},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended,
                          PrependDegenerateDims(param, 2));
  entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32>({9, 10})}));
  EXPECT_EQ(result_literal, LiteralUtil::CreateR3<int32>({{{9, 10}}}));
}
125 
// Prepending two size-1 dimensions to a scalar s32[] yields s32[1,1].
TEST_F(HloCreationUtilsTest, Prepend2DegenerateDimsToScalar) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{1, 1},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * with_2_degenerate_dims_prepended,
                          PrependDegenerateDims(param, 2));
  entry_computation->set_root_instruction(with_2_degenerate_dims_prepended);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32>(9)}));
  EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32>({{9}}));
}
144 
// Expanding the leading dimension of s32[6] into {3, 1, 2} yields s32[3,1,2]
// with the elements kept in their original order.
TEST_F(HloCreationUtilsTest, ExpandFirstDimInto3Dims) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{6},
                                             /*output_shape_dims=*/{3, 1, 2},
                                             &param, &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(HloInstruction * first_dim_expanded,
                          ExpandFirstDimIntoNDims(param, {3, 1, 2}));
  entry_computation->set_root_instruction(first_dim_expanded);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module,
                         {LiteralUtil::CreateR1<int32>({1, 2, 3, 4, 5, 6})}));
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR3<int32>({{{1, 2}}, {{3, 4}}, {{5, 6}}}));
}
165 
// Padding s32[2] with 3 leading and 1 trailing zero yields s32[6].
TEST_F(HloCreationUtilsTest, PadVectorWithZeros) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{2},
                                             /*output_shape_dims=*/{6}, &param,
                                             &entry_computation);

  TF_ASSERT_OK_AND_ASSIGN(
      HloInstruction * zero_padded_param,
      PadVectorWithZeros(param, /*zeros_to_prepend=*/3, /*zeros_to_append=*/1));
  entry_computation->set_root_instruction(zero_padded_param);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR1<int32>({3, 4})}));
  EXPECT_EQ(result_literal, LiteralUtil::CreateR1<int32>({0, 0, 0, 3, 4, 0}));
}
185 
// BroadcastZeros(S32, {2, 2}) produces a 2x2 zero matrix. The entry
// parameter is not used by the root; it only satisfies the program shape.
TEST_F(HloCreationUtilsTest, BroadcastZeros_S32) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(S32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{2, 2},
                                             &param, &entry_computation);

  HloInstruction* zeros = BroadcastZeros(entry_computation, S32, {2, 2});
  entry_computation->set_root_instruction(zeros);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<int32>(0)}));
  EXPECT_EQ(result_literal, LiteralUtil::CreateR2<int32>({{0, 0}, {0, 0}}));
}
204 
// BroadcastZeros(F32, {2, 2}) produces a 2x2 matrix of 0.0f. The entry
// parameter is not used by the root; it only satisfies the program shape.
TEST_F(HloCreationUtilsTest, BroadcastZeros_F32) {
  HloInstruction* param = nullptr;
  HloComputation* entry_computation = nullptr;

  auto module = CreateModuleWithProgramShape(F32, /*input_shape_dims=*/{},
                                             /*output_shape_dims=*/{2, 2},
                                             &param, &entry_computation);

  HloInstruction* zeros = BroadcastZeros(entry_computation, F32, {2, 2});
  entry_computation->set_root_instruction(zeros);

  HloEvaluator evaluator;
  TF_ASSERT_OK_AND_ASSIGN(
      Literal result_literal,
      evaluator.Evaluate(*module, {LiteralUtil::CreateR0<float>(0.0f)}));
  EXPECT_EQ(result_literal,
            LiteralUtil::CreateR2<float>({{0.0f, 0.0f}, {0.0f, 0.0f}}));
}
224 
225 }  // namespace
226 }  // namespace xla
227