1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gather_expander.h"
17 #include "tensorflow/compiler/xla/service/hlo_parser.h"
18 #include "tensorflow/compiler/xla/test.h"
19 #include "tensorflow/compiler/xla/tests/test_macros.h"
20 
21 namespace xla {
22 namespace {
// Runs GatherExpander on a gather whose indices operand is s32[2147483647,5]
// and checks that the pass fails with UNIMPLEMENTED and an error message
// stating that gathers with more than 2147483647 gather indices are not
// supported.
TEST(GatherExpanderTest, ErrorStatusOnTooManyIndices) {
  const string hlo_text = R"(
HloModule TensorFlowGatherMultipleBatchDims

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2147483647,5] parameter(1)
  ROOT gather = s32[2147483647,3,5] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  const Status status = GatherExpander{}.Run(module.get()).status();

  // The status code is a precondition for the message check: if the pass
  // unexpectedly succeeds (empty message) or fails for another reason, stop
  // here rather than also reporting a misleading substring mismatch.
  ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);

  // Last check in the test, so a non-fatal expectation is sufficient.
  EXPECT_THAT(
      status.error_message(),
      ::testing::HasSubstr("Gather operations with more than 2147483647 gather "
                           "indices are not supported."));
}
49 
// Expands a TF-style gather (rank-1 s32[2] indices with index_vector_dim=1)
// and checks that exactly one while loop appears in the entry computation,
// and that its loop-state tuple elements have the expected dimensions:
// no degenerate (size-1) dimensions in them.
TEST(GatherExpanderTest, AvoidDegenerateDims) {
  const string hlo_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));
  TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
  ASSERT_TRUE(changed);

  // Find the single while instruction; encountering a second one fails the
  // test immediately.
  HloInstruction* loop = nullptr;
  for (HloInstruction* candidate :
       module->entry_computation()->instructions()) {
    if (candidate->opcode() != HloOpcode::kWhile) {
      continue;
    }
    ASSERT_EQ(loop, nullptr)
        << "Expected exactly one while instruction in the entry computation "
           "after gather expansion";
    loop = candidate;
  }
  ASSERT_NE(loop, nullptr)
      << "Expected exactly one while instruction in the entry computation "
         "after gather expansion";

  // For a TF gather, the loop should not be built over shapes with degenerate
  // dimensions. The expected loop state is
  //   (sNN[], s32[3,3]{1,0}, s32[2]{0}, s32[2,3]{1,0}).
  // The leading scalar comes from WhileUtil::MakeCountedLoop and is not
  // checked here (arguably the whole tuple layout is a detail of that helper).
  const Shape& loop_state = loop->shape();
  ASSERT_TRUE(loop_state.IsTuple());
  ASSERT_EQ(ShapeUtil::TupleElementCount(loop_state), 4);

  // Compares only the dimensions (not element type or layout) of loop-state
  // element `index` against `expected`.
  auto state_element_matches = [&loop_state](int index,
                                             const Shape& expected) {
    return ShapeUtil::SameDimensions(
        expected, ShapeUtil::GetTupleElementShape(loop_state, index));
  };

  EXPECT_TRUE(state_element_matches(1, ShapeUtil::MakeShape(S32, {3, 3})));
  EXPECT_TRUE(state_element_matches(2, ShapeUtil::MakeShape(S32, {2})));
  EXPECT_TRUE(state_element_matches(3, ShapeUtil::MakeShape(S32, {2, 3})));
}
107 
// Attaches op_name "Gather" to the root gather and checks that the single
// while loop GatherExpander produces in the entry computation carries that
// same op_name in its metadata.
TEST(GatherExpanderTest, CheckOpMetadata) {
  const string hlo_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(hlo_text));

  // Tag the gather (the entry root) so its metadata can be traced through
  // the expansion.
  OpMetadata gather_metadata;
  gather_metadata.set_op_name("Gather");
  HloInstruction* gather = module->entry_computation()->root_instruction();
  gather->set_metadata(gather_metadata);

  TF_ASSERT_OK_AND_ASSIGN(bool changed, GatherExpander{}.Run(module.get()));
  ASSERT_TRUE(changed);

  // Find the single while instruction; encountering a second one fails the
  // test immediately.
  HloInstruction* loop = nullptr;
  for (HloInstruction* candidate :
       module->entry_computation()->instructions()) {
    if (candidate->opcode() != HloOpcode::kWhile) {
      continue;
    }
    ASSERT_EQ(loop, nullptr)
        << "Expected exactly one while instruction in the entry computation "
           "after gather expansion";
    loop = candidate;
  }
  ASSERT_NE(loop, nullptr)
      << "Expected exactly one while instruction in the entry computation "
         "after gather expansion";

  EXPECT_EQ(loop->metadata().op_name(), "Gather");
}
146 }  // namespace
147 }  // namespace xla
148