1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/xla/service/tuple_simplifier.h" 17 18 #include <memory> 19 #include <utility> 20 21 #include "tensorflow/compiler/xla/literal_util.h" 22 #include "tensorflow/compiler/xla/service/hlo_computation.h" 23 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 24 #include "tensorflow/compiler/xla/service/hlo_matchers.h" 25 #include "tensorflow/compiler/xla/service/hlo_opcode.h" 26 #include "tensorflow/compiler/xla/shape_util.h" 27 #include "tensorflow/compiler/xla/test.h" 28 #include "tensorflow/compiler/xla/tests/hlo_test_base.h" 29 #include "tensorflow/compiler/xla/types.h" 30 #include "tensorflow/core/lib/core/status_test_util.h" 31 32 namespace op = xla::testing::opcode_matchers; 33 34 namespace xla { 35 namespace { 36 37 class TupleSimplifierTest : public HloTestBase { 38 protected: 39 void Run(HloModule* module, bool change_expected) { 40 TupleSimplifier simplifier; 41 auto changed_status = simplifier.Run(module); 42 TF_ASSERT_OK(changed_status.status()); 43 EXPECT_EQ(change_expected, changed_status.ValueOrDie()); 44 } 45 46 const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); 47 const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( 48 {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {}), 49 ShapeUtil::MakeShape(F32, {})}); 50 }; 51 52 TEST_F(TupleSimplifierTest, TupleOfParameters) { 53 // A Tuple constructed of a bunch of parameters should not be changed. 54 HloComputation::Builder builder(TestName()); 55 HloInstruction* param0 = builder.AddInstruction( 56 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 57 HloInstruction* param1 = builder.AddInstruction( 58 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 59 HloInstruction* param2 = builder.AddInstruction( 60 HloInstruction::CreateParameter(2, scalar_shape_, "param2")); 61 builder.AddInstruction(HloInstruction::CreateTuple({param0, param1, param2})); 62 auto module = CreateNewModule(); 63 module->AddEntryComputation(builder.Build()); 64 65 Run(module.get(), /*change_expected=*/false); 66 } 67 68 TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { 69 // A GTE of a tuple parameter should not be changed. 70 HloComputation::Builder builder(TestName()); 71 HloInstruction* param = builder.AddInstruction( 72 HloInstruction::CreateParameter(0, tuple_shape_, "param")); 73 builder.AddInstruction( 74 HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); 75 auto module = CreateNewModule(); 76 module->AddEntryComputation(builder.Build()); 77 78 Run(module.get(), /*change_expected=*/false); 79 } 80 81 TEST_F(TupleSimplifierTest, GteOfTuple) { 82 // A GTE of a Tuple should be short-circuited. 83 HloComputation::Builder builder(TestName()); 84 HloInstruction* param0 = builder.AddInstruction( 85 HloInstruction::CreateParameter(0, scalar_shape_, "param0")); 86 HloInstruction* param1 = builder.AddInstruction( 87 HloInstruction::CreateParameter(1, scalar_shape_, "param1")); 88 HloInstruction* param2 = builder.AddInstruction( 89 HloInstruction::CreateParameter(2, scalar_shape_, "param2")); 90 HloInstruction* tuple = builder.AddInstruction( 91 HloInstruction::CreateTuple({param0, param1, param2})); 92 HloInstruction* gte = builder.AddInstruction( 93 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); 94 95 auto module = CreateNewModule(); 96 auto computation = module->AddEntryComputation(builder.Build()); 97 98 EXPECT_THAT(computation->root_instruction(), gte); 99 100 Run(module.get(), /*change_expected=*/true); 101 102 EXPECT_THAT(computation->root_instruction(), param1); 103 } 104 105 TEST_F(TupleSimplifierTest, GteOfTupleChain) { 106 // Verify a chain of GTE/Tuple instructions is collapsed. 107 HloComputation::Builder builder(TestName()); 108 HloInstruction* param = builder.AddInstruction( 109 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 110 111 const int kChainLength = 10; 112 HloInstruction* element = param; 113 for (int i = 0; i < kChainLength; ++i) { 114 HloInstruction* tuple = builder.AddInstruction( 115 HloInstruction::CreateTuple({element, element, element})); 116 element = builder.AddInstruction( 117 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); 118 } 119 builder.AddInstruction( 120 HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element)); 121 122 auto module = CreateNewModule(); 123 auto computation = module->AddEntryComputation(builder.Build()); 124 125 EXPECT_THAT(computation->root_instruction(), 126 op::Negate(op::GetTupleElement(op::Tuple()))); 127 128 Run(module.get(), /*change_expected=*/true); 129 130 EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); 131 } 132 133 TEST_F(TupleSimplifierTest, NestedGteOfTuples) { 134 // Verify a nesting of GTE/Tuple instructions is collapsed. Tuples are nested 135 // to some depth with a chain of Tuple instructions, then extracted with a 136 // chain of GTE instructions. 137 HloComputation::Builder builder(TestName()); 138 HloInstruction* param = builder.AddInstruction( 139 HloInstruction::CreateParameter(0, scalar_shape_, "param")); 140 141 const int kNestingDepth = 5; 142 HloInstruction* nested_tuple = param; 143 for (int i = 0; i < kNestingDepth; ++i) { 144 nested_tuple = builder.AddInstruction( 145 HloInstruction::CreateTuple({nested_tuple, nested_tuple})); 146 } 147 148 HloInstruction* element = nested_tuple; 149 for (int i = 0; i < kNestingDepth; ++i) { 150 element = builder.AddInstruction(HloInstruction::CreateGetTupleElement( 151 ShapeUtil::GetTupleElementShape(element->shape(), 0), element, 0)); 152 } 153 154 auto module = CreateNewModule(); 155 auto computation = module->AddEntryComputation(builder.Build()); 156 157 EXPECT_THAT(computation->root_instruction(), element); 158 159 Run(module.get(), /*change_expected=*/true); 160 161 EXPECT_THAT(computation->root_instruction(), param); 162 } 163 164 TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { 165 // Verify that a tuple constructed of GTE instructions operating on the same 166 // tuple are collapsed. 167 HloComputation::Builder builder(TestName()); 168 HloInstruction* tuple_param = builder.AddInstruction( 169 HloInstruction::CreateParameter(0, tuple_shape_, "param")); 170 HloInstruction* gte0 = builder.AddInstruction( 171 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0)); 172 HloInstruction* gte1 = builder.AddInstruction( 173 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1)); 174 HloInstruction* gte2 = builder.AddInstruction( 175 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 2)); 176 HloInstruction* tuple = 177 builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); 178 179 auto module = CreateNewModule(); 180 auto computation = module->AddEntryComputation(builder.Build()); 181 182 EXPECT_THAT(computation->root_instruction(), tuple); 183 184 Run(module.get(), /*change_expected=*/true); 185 186 EXPECT_THAT(computation->root_instruction(), tuple_param); 187 } 188 189 TEST_F(TupleSimplifierTest, IncompatibleTuples) { 190 // Verify that a tuple->GTE->tuple construct is not simplified if the input 191 // and output tuple are not compatible shapes. 192 HloComputation::Builder builder(TestName()); 193 HloInstruction* tuple_param = builder.AddInstruction( 194 HloInstruction::CreateParameter(0, tuple_shape_, "param")); 195 HloInstruction* gte0 = builder.AddInstruction( 196 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0)); 197 HloInstruction* gte1 = builder.AddInstruction( 198 HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1)); 199 // Output tuple has only two elements. Parameter tuple has three elements so 200 // simplification is not possible. 201 HloInstruction* tuple = 202 builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); 203 204 auto module = CreateNewModule(); 205 auto computation = module->AddEntryComputation(builder.Build()); 206 207 EXPECT_THAT(computation->root_instruction(), tuple); 208 209 Run(module.get(), /*change_expected=*/false); 210 211 EXPECT_THAT(computation->root_instruction(), tuple); 212 } 213 214 } // namespace 215 } // namespace xla 216