1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
17
18 #include <memory>
19 #include <string>
20
21 #include "absl/algorithm/container.h"
22 #include "tensorflow/compiler/xla/service/hlo_computation.h"
23 #include "tensorflow/compiler/xla/service/hlo_dce.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/compiler/xla/service/hlo_memory_scheduler.h"
26 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
27 #include "tensorflow/compiler/xla/service/hlo_ordering.h"
28 #include "tensorflow/compiler/xla/service/hlo_parser.h"
29 #include "tensorflow/compiler/xla/shape_util.h"
30 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
31 #include "tensorflow/compiler/xla/types.h"
32 #include "tensorflow/core/lib/core/status_test_util.h"
33
34 namespace xla {
35 namespace {
36 class DynamicParameterBindingTest : public HloTestBase {
37 protected:
38 // Serialize and then deserialize a binding.
SerializeAndDeserialize(DynamicParameterBinding * binding)39 void SerializeAndDeserialize(DynamicParameterBinding* binding) {
40 DynamicParameterBindingProto proto = binding->ToProto();
41 TF_ASSERT_OK_AND_ASSIGN(*binding,
42 DynamicParameterBinding::CreateFromProto(proto));
43 }
44 };
45
TEST_F(DynamicParameterBindingTest, SimpleBinding) {
  // 'b' is a dynamic shape; 'a' represents the real size of b's first
  // dimension.
  const string module_str = R"(
HloModule TEST

ENTRY main {
  a = f32[] parameter(0)
  b = f32[10] parameter(1)
  ROOT root = (f32[], f32[10]) tuple(%a, %b)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  DynamicParameterBinding binding;

  // Parameter 0 supplies the size of dimension 0 of parameter 1.
  // Fatal on failure: every check below depends on this binding existing.
  TF_ASSERT_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {}},
                   DynamicParameterBinding::DynamicDimension{1, {}, 0}));

  // Checks the expected binding on `to_check` and verifies it against
  // `module`. Run both before and after a proto round trip.
  auto test = [&](const DynamicParameterBinding& to_check) {
    absl::optional<DynamicParameterBinding::DynamicParameter> param =
        to_check.GetBinding(
            DynamicParameterBinding::DynamicDimension{/*parameter_num=*/1,
                                                      /*parameter_index=*/{},
                                                      /*dimension=*/0});
    // ASSERT rather than EXPECT: dereferencing an empty optional below would
    // be undefined behavior. A fatal failure here returns from the lambda only.
    ASSERT_TRUE(param.has_value());
    EXPECT_EQ(param->parameter_num, 0);
    EXPECT_EQ(param->parameter_index, ShapeIndex({}));
    TF_EXPECT_OK(to_check.Verify(*module));
  };
  test(binding);
  SerializeAndDeserialize(&binding);
  test(binding);
}
82
TEST_F(DynamicParameterBindingTest, TupleBinding) {
  // 'gte2' is a dynamic shape; 'gte1' represents the real size of gte2's first
  // dimension.
  const string module_str = R"(
HloModule TEST

ENTRY main {
  param = (f32[], f32[10]) parameter(0)
  gte1 = f32[] get-tuple-element(%param), index=0
  gte2 = f32[10] get-tuple-element(%param), index=1
  ROOT root = (f32[], f32[10]) tuple(%gte1, %gte2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  DynamicParameterBinding binding;

  // Tuple element {0} of parameter 0 supplies the size of dimension 0 of
  // tuple element {1} of the same parameter. Fatal on failure: the checks
  // below depend on it.
  TF_ASSERT_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
                   DynamicParameterBinding::DynamicDimension{0, {1}, 0}));

  // Checks the expected binding on `to_check` and verifies it against
  // `module`. Run both before and after a proto round trip.
  auto test = [&](const DynamicParameterBinding& to_check) {
    absl::optional<DynamicParameterBinding::DynamicParameter> param =
        to_check.GetBinding(
            DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
                                                      /*parameter_index=*/{1},
                                                      /*dimension=*/0});

    // ASSERT rather than EXPECT: dereferencing an empty optional below would
    // be undefined behavior. A fatal failure here returns from the lambda only.
    ASSERT_TRUE(param.has_value());
    EXPECT_EQ(param->parameter_num, 0);
    EXPECT_EQ(param->parameter_index, ShapeIndex({0}));
    TF_EXPECT_OK(to_check.Verify(*module));
  };
  test(binding);
  SerializeAndDeserialize(&binding);
  test(binding);
}
121
TEST_F(DynamicParameterBindingTest, TupleBindingWithMultiDimension) {
  // 'gte2' is a dynamic shape; 'gte1' represents the real size of both of
  // gte2's dimensions.
  const string module_str = R"(
HloModule TEST

ENTRY main {
  param = (f32[], f32[10, 10]) parameter(0)
  gte1 = f32[] get-tuple-element(%param), index=0
  gte2 = f32[10, 10] get-tuple-element(%param), index=1
  ROOT root = (f32[], f32[10, 10]) tuple(%gte1, %gte2)
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseHloString(module_str));

  DynamicParameterBinding binding;

  // Tuple element {0} of parameter 0 supplies the size of dimension 0 and
  // dimension 1 of tuple element {1}. Fatal on failure: the checks below
  // depend on both bindings.
  TF_ASSERT_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
                   DynamicParameterBinding::DynamicDimension{0, {1}, 0}));

  TF_ASSERT_OK(
      binding.Bind(DynamicParameterBinding::DynamicParameter{0, {0}},
                   DynamicParameterBinding::DynamicDimension{0, {1}, 1}));

  // Checks both bindings on `to_check` and verifies them against `module`.
  // Run both before and after a proto round trip.
  auto test = [&](const DynamicParameterBinding& to_check) {
    absl::optional<DynamicParameterBinding::DynamicParameter> param =
        to_check.GetBinding(
            DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
                                                      /*parameter_index=*/{1},
                                                      /*dimension=*/0});

    // ASSERT rather than EXPECT: dereferencing an empty optional below would
    // be undefined behavior. A fatal failure here returns from the lambda only.
    ASSERT_TRUE(param.has_value());
    EXPECT_EQ(param->parameter_num, 0);
    EXPECT_EQ(param->parameter_index, ShapeIndex({0}));

    // Query the second bound dimension (dimension 1). The previous version
    // re-queried dimension 0 here, so the second Bind() call went unchecked.
    absl::optional<DynamicParameterBinding::DynamicParameter> param2 =
        to_check.GetBinding(
            DynamicParameterBinding::DynamicDimension{/*parameter_num=*/0,
                                                      /*parameter_index=*/{1},
                                                      /*dimension=*/1});
    ASSERT_TRUE(param2.has_value());
    EXPECT_EQ(param2->parameter_num, 0);
    EXPECT_EQ(param2->parameter_index, ShapeIndex({0}));
    TF_EXPECT_OK(to_check.Verify(*module));
  };

  test(binding);

  SerializeAndDeserialize(&binding);

  // Test the binding again after deserialization.
  test(binding);
}
178
179 } // namespace
180 } // namespace xla
181