/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gather_expander.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"

namespace xla {
namespace {

TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) {
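  // The indices operand implies more than 2^31 - 1 gather indices, so
  // GatherExpander should refuse to expand this gather and report an
  // UNIMPLEMENTED error instead of building a loop.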
  const string hlo_text = R"(
HloModule TensorFlowGatherMultipleBatchDims

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2147483647,5] parameter(1)
  ROOT gather = s32[2147483647,3,5] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  Status status = GatherExpander{}.Run(module.get()).status();
  EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);

  ASSERT_THAT(
      status.error_message(),
      ::testing::HasSubstr("Gather operations with more than 2147483647 gather "
                           "indices are not supported."));
}

TEST(GatherExpanderTest, AvoidDegenerateDims) {
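  // Each index selects one column of the 3x3 operand. After expansion we
  // expect exactly one while loop, and its loop state should not carry the
  // degenerate (size-1) dimensions a naive expansion would introduce.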
  const string hlo_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
  ASSERT_TRUE(changed);

  HloInstruction* while_instr = nullptr;
  for (auto* instr : module->entry_computation()->instructions()) {
    if (instr->opcode() == HloOpcode::kWhile) {
      ASSERT_EQ(while_instr, nullptr)
          << "Expected exactly one while instruction in the entry computation "
             "after gather expansion";
      while_instr = instr;
    }
  }

  ASSERT_NE(while_instr, nullptr)
      << "Expected exactly one while instruction in the entry computation "
         "after gather expansion";
  // We want to avoid creating a while loop with shapes that have degenerate
  // dimensions for TF gather. In this case we expect the loop state to have
  // the shape (sNN[], s32[3,3]{1,0}, s32[2]{0}, s32[2,3]{1,0}). The leading
  // sNN is an implementation detail of WhileUtil::MakeCountedLoop, so we don't
  // check it here (though in theory the form of the while loop state is itself
  // an implementation detail of WhileUtil::MakeCountedLoop).

  const Shape& while_shape = while_instr->shape();
  ASSERT_TRUE(while_shape.IsTuple());
  ASSERT_EQ(ShapeUtil::TupleElementCount(while_shape), 4);

  EXPECT_TRUE(ShapeUtil::SameDimensions(
      ShapeUtil::MakeShape(S32, {3, 3}),
      ShapeUtil::GetTupleElementShape(while_shape, 1)));

  EXPECT_TRUE(ShapeUtil::SameDimensions(
      ShapeUtil::MakeShape(S32, {2}),
      ShapeUtil::GetTupleElementShape(while_shape, 2)));

  EXPECT_TRUE(ShapeUtil::SameDimensions(
      ShapeUtil::MakeShape(S32, {2, 3}),
      ShapeUtil::GetTupleElementShape(while_shape, 3)));
}

TEST(GatherExpanderTest, CheckOpMetadata) {
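  // This gather is the same as in AvoidDegenerateDims; here we check that the
  // op metadata set on the original gather instruction is propagated to the
  // while loop produced by the expansion.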
  const string hlo_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));
  OpMetadata metadata;
  metadata.set_op_name("Gather");
  module->entry_computation()->root_instruction()->set_metadata(metadata);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
  ASSERT_TRUE(changed);

  HloInstruction* while_instr = nullptr;
  for (auto* instr : module->entry_computation()->instructions()) {
    if (instr->opcode() == HloOpcode::kWhile) {
      ASSERT_EQ(while_instr, nullptr)
          << "Expected exactly one while instruction in the entry computation "
             "after gather expansion";
      while_instr = instr;
    }
  }

  ASSERT_NE(while_instr, nullptr)
      << "Expected exactly one while instruction in the entry computation "
         "after gather expansion";
  EXPECT_EQ(while_instr->metadata().op_name(), "Gather");
}

}  // namespace
}  // namespace xla