1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
17 
18 #include <memory>
19 #include <string>
20 
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_dce.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33 
34 namespace xla {
35 namespace {
36 class DynamicParameterBindingTest : public HloTestBase {
37  protected:
38   // Serialize and then deserialize a binding.
SerializeAndDeserialize(DynamicParameterBinding * binding)39   void SerializeAndDeserialize(DynamicParameterBinding* binding) {
40     DynamicParameterBindingProto proto = binding->ToProto();
41     TF_ASSERT_OK_AND_ASSIGN(*binding,
42                             DynamicParameterBinding::CreateFromProto(proto));
43   }
44 };
45 
TEST_F(DynamicParameterBindingTest, SimpleBinding) {
  // 'b' is a dynamic shape; 'a' represents the real size of b's first
  // dimension.
  const string module_str = R"(
HloModule TEST

ENTRY main {
  a = f32[] parameter(0)
  b = f32[10] parameter(1)
  ROOT root = (f32[], f32[10]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  DynamicParameterBinding binding;

  // Bind the scalar parameter 0 as the size of dimension 0 of parameter 1.
  // Fatal on failure: every check below depends on this binding existing.
  TF_ASSERT_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}},
                   DynamicParameterBinding::DynamicDimension{1, {}, 0}));

  // Checks that `to_check` maps dimension 0 of parameter 1 back to
  // parameter 0 and that it verifies against `module`.
  auto check_binding = [&](const DynamicParameterBinding& to_check) {
    absl::optional<DynamicParameterBinding::DynamicParameter> param =
        to_check.GetBinding(
            DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1,
                                                      /*parameter_index=*/{},
                                                      /*dimension=*/0});
    // ASSERT rather than EXPECT: dereferencing an empty optional below would
    // be undefined behavior. A fatal failure here only leaves this lambda.
    ASSERT_TRUE(param.has_value());
    EXPECT_EQ(param->parameter_num, 0);
    EXPECT_EQ(param->parameter_index, ShapeIndex({}));
    TF_EXPECT_OK(to_check.Verify(*module));
  };
  check_binding(binding);
  SerializeAndDeserialize(&binding);
  // The binding must survive a proto round trip unchanged.
  check_binding(binding);
}
82 
TEST_F(DynamicParameterBindingTest, TupleBinding) {
  // 'gte2' is a dynamic shape; 'gte1' represents the real size of gte2's first
  // dimension.
  const string module_str = R"(
HloModule TEST

ENTRY main {
  param = (f32[], f32[10]) parameter(0)
  gte1 = f32[] get-tuple-element(%param), index=0
  gte2 = f32[10] get-tuple-element(%param), index=1
  ROOT root = (f32[], f32[10]) tuple(%gte1, %gte2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  DynamicParameterBinding binding;

  // Bind tuple element {0} of parameter 0 as the size of dimension 0 of tuple
  // element {1} of the same parameter. Fatal on failure: the checks below
  // depend on this binding existing.
  TF_ASSERT_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
                   DynamicParameterBinding::DynamicDimension{0, {1}, 0}));

  // Checks that `to_check` maps dimension 0 of element {1} back to element
  // {0} of parameter 0 and that it verifies against `module`.
  auto check_binding = [&](const DynamicParameterBinding& to_check) {
    absl::optional<DynamicParameterBinding::DynamicParameter> param =
        to_check.GetBinding(
            DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
                                                      /*parameter_index=*/{1},
                                                      /*dimension=*/0});
    // ASSERT rather than EXPECT: dereferencing an empty optional below would
    // be undefined behavior. A fatal failure here only leaves this lambda.
    ASSERT_TRUE(param.has_value());
    EXPECT_EQ(param->parameter_num, 0);
    EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
    TF_EXPECT_OK(to_check.Verify(*module));
  };
  check_binding(binding);
  SerializeAndDeserialize(&binding);
  // The binding must survive a proto round trip unchanged.
  check_binding(binding);
}
121 
TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) {
  // 'gte2' is a dynamic shape; 'gte1' represents the real size of both of
  // gte2's dimensions.
  const string module_str = R"(
HloModule TEST

ENTRY main {
  param = (f32[], f32[10, 10]) parameter(0)
  gte1 = f32[] get-tuple-element(%param), index=0
  gte2 = f32[10, 10] get-tuple-element(%param), index=1
  ROOT root = (f32[], f32[10, 10]) tuple(%gte1, %gte2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  DynamicParameterBinding binding;

  // Bind tuple element {0} of parameter 0 as the size of BOTH dimensions of
  // tuple element {1}. Fatal on failure: the checks below depend on these
  // bindings existing.
  TF_ASSERT_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
                   DynamicParameterBinding::DynamicDimension{0, {1}, 0}));

  TF_ASSERT_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
                   DynamicParameterBinding::DynamicDimension{0, {1}, 1}));

  // Checks that `to_check` maps each of dimensions 0 and 1 of element {1}
  // back to element {0} of parameter 0 and that it verifies against `module`.
  auto check_binding = [&](const DynamicParameterBinding& to_check) {
    absl::optional<DynamicParameterBinding::DynamicParameter> param =
        to_check.GetBinding(
            DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
                                                      /*parameter_index=*/{1},
                                                      /*dimension=*/0});
    // ASSERT rather than EXPECT: dereferencing an empty optional below would
    // be undefined behavior. A fatal failure here only leaves this lambda.
    ASSERT_TRUE(param.has_value());
    EXPECT_EQ(param->parameter_num, 0);
    EXPECT_EQ(param->parameter_index, ShapeIndex({0}));

    // Look up dimension 1 so the second Bind() call above is actually
    // covered; querying dimension 0 again would only re-test the first one.
    absl::optional<DynamicParameterBinding::DynamicParameter> param2 =
        to_check.GetBinding(
            DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
                                                      /*parameter_index=*/{1},
                                                      /*dimension=*/1});
    ASSERT_TRUE(param2.has_value());
    EXPECT_EQ(param2->parameter_num, 0);
    EXPECT_EQ(param2->parameter_index, ShapeIndex({0}));
    TF_EXPECT_OK(to_check.Verify(*module));
  };

  check_binding(binding);

  SerializeAndDeserialize(&binding);

  // Test the binding again after deserialization.
  check_binding(binding);
}
178 
179 }  // namespace
180 }  // namespace xla
181