1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/optimize_input_output_buffer_alias.h"
17 
18 #include <memory>
19 
20 #include "absl/memory/memory.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/test_helpers.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 #include "tensorflow/compiler/xla/tests/test_utils.h"
26 #include "tensorflow/core/platform/test.h"
27 
28 namespace xla {
29 
30 // Tests that UserBufferAlias properly maps input and output buffer indices of
31 // various shapes for aliasing.
32 class OptimizeInputOutputBufferAliasTest : public HloTestBase {
33  protected:
OptimizeInputOutputBufferAliasTest()34   OptimizeInputOutputBufferAliasTest() {
35     r1f32_ = ShapeUtil::MakeShape(F32, {4});
36     r2f32_ = ShapeUtil::MakeShape(F32, {4, 5});
37     r3f32_ = ShapeUtil::MakeShape(F32, {4, 5, 6});
38     r4f32_ = ShapeUtil::MakeShape(F32, {4, 5, 6, 7});
39 
40     auto size_func = [](const Shape& shape) {
41       return ShapeUtil::ByteSizeOf(shape);
42     };
43 
44     optimize_pass_ =
45         absl::make_unique<OptimizeInputOutputBufferAlias>(size_func);
46   }
47 
48   // Returns the number of output indices that aliases with the input.
AliasCount()49   int64 AliasCount() {
50     int64 count = 0;
51 
52     config_.ForEachAlias(
53         [&](const ShapeIndex&, const HloInputOutputAliasConfig::Alias&) {
54           count++;
55         });
56     return count;
57   }
58 
BuildAliasConfig(const Shape & input_shape,const Shape & output_shape)59   bool BuildAliasConfig(const Shape& input_shape, const Shape& output_shape) {
60     config_ = HloInputOutputAliasConfig(output_shape);
61     auto changed = optimize_pass_->Build(input_shape, output_shape, &config_);
62     TF_CHECK_OK(changed.status());
63 
64     return changed.ValueOrDie();
65   }
66 
67   std::unique_ptr<OptimizeInputOutputBufferAlias> optimize_pass_;
68 
69   HloInputOutputAliasConfig config_;
70 
71   Shape r1f32_;
72   Shape r2f32_;
73   Shape r3f32_;
74   Shape r4f32_;
75 };
76 
77 // All shapes are different, so no aliasing is available.
TEST_F(OptimizeInputOutputBufferAliasTest,AllDifferentBufferSizes)78 TEST_F(OptimizeInputOutputBufferAliasTest, AllDifferentBufferSizes) {
79   Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_});
80   Shape output = ShapeUtil::MakeTupleShape({r3f32_, r4f32_});
81   bool changed = BuildAliasConfig(input, output);
82   EXPECT_FALSE(changed);
83   EXPECT_EQ(AliasCount(), 0);
84 }
85 
86 // Input and output shapes are equal, so buffers can alias at the same index.
TEST_F(OptimizeInputOutputBufferAliasTest,OrderedNonNestedTuple)87 TEST_F(OptimizeInputOutputBufferAliasTest, OrderedNonNestedTuple) {
88   Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_});
89   Shape output = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_});
90   bool changed = BuildAliasConfig(input, output);
91   EXPECT_TRUE(changed);
92   EXPECT_EQ(AliasCount(), 4);
93 
94   EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{0});
95   EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex{1});
96   EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{2});
97   EXPECT_EQ(config_.GetAliasedOutput(0, {3}), ShapeIndex{3});
98 }
99 
100 // Only a subset of the tuple element shapes match between the input and the
101 // output.
TEST_F(OptimizeInputOutputBufferAliasTest,PartialReuseNonNestedTuple)102 TEST_F(OptimizeInputOutputBufferAliasTest, PartialReuseNonNestedTuple) {
103   Shape input = ShapeUtil::MakeTupleShape({r1f32_, r1f32_, r2f32_, r2f32_});
104   Shape output = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_});
105   bool changed = BuildAliasConfig(input, output);
106   EXPECT_TRUE(changed);
107 
108   EXPECT_EQ(AliasCount(), 2);
109 
110   EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{0});
111   EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{1});
112 }
113 
114 // The output shape is reverse of the input shape, but we can still reuse all
115 // the buffers.
TEST_F(OptimizeInputOutputBufferAliasTest,UnorderedNonNestedTuple)116 TEST_F(OptimizeInputOutputBufferAliasTest, UnorderedNonNestedTuple) {
117   Shape input = ShapeUtil::MakeTupleShape({r1f32_, r2f32_, r3f32_, r4f32_});
118   Shape output = ShapeUtil::MakeTupleShape({r4f32_, r3f32_, r2f32_, r1f32_});
119   bool changed = BuildAliasConfig(input, output);
120   EXPECT_TRUE(changed);
121 
122   EXPECT_EQ(AliasCount(), 4);
123 
124   EXPECT_EQ(config_.GetAliasedOutput(0, {0}), ShapeIndex{3});
125   EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex{2});
126   EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex{1});
127   EXPECT_EQ(config_.GetAliasedOutput(0, {3}), ShapeIndex{0});
128 }
129 
// Input and output both contain a nested tuple, at different positions and
// with elements in a different order. Three aliases are expected, crossing
// nesting levels: input {0, 0} -> output {0}, input {1} -> output {1, 1} and
// input {2} -> output {1, 0}.
TEST_F(OptimizeInputOutputBufferAliasTest, UnorderedNestedTuple) {
  const Shape nested_param = ShapeUtil::MakeTupleShape({r1f32_});
  const Shape param_shape =
      ShapeUtil::MakeTupleShape({nested_param, r2f32_, r3f32_, r4f32_});

  const Shape nested_result = ShapeUtil::MakeTupleShape({r3f32_, r2f32_});
  const Shape result_shape =
      ShapeUtil::MakeTupleShape({r1f32_, nested_result, r2f32_});

  EXPECT_TRUE(BuildAliasConfig(param_shape, result_shape));
  EXPECT_EQ(AliasCount(), 3);

  EXPECT_EQ(config_.GetAliasedOutput(0, {0, 0}), ShapeIndex{0});
  EXPECT_EQ(config_.GetAliasedOutput(0, {1}), ShapeIndex({1, 1}));
  EXPECT_EQ(config_.GetAliasedOutput(0, {2}), ShapeIndex({1, 0}));
}
144 
145 }  // namespace xla
146