1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
17 #include "tensorflow/compiler/xla/service/hlo_computation.h"
18 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
19 #include "tensorflow/compiler/xla/service/hlo_module.h"
20
21 namespace xla {
22
Bind(const DynamicParameter & dynamic_parameter,const DynamicDimension & dynamic_dimension)23 Status DynamicParameterBinding::Bind(
24 const DynamicParameter& dynamic_parameter,
25 const DynamicDimension& dynamic_dimension) {
26 auto result = bindings_.emplace(dynamic_dimension, dynamic_parameter);
27 TF_RET_CHECK(result.second);
28 return Status::OK();
29 }
30
31 absl::optional<DynamicParameterBinding::DynamicParameter>
GetBinding(const DynamicDimension & dynamic_dimension) const32 DynamicParameterBinding::GetBinding(
33 const DynamicDimension& dynamic_dimension) const {
34 auto param_iter = bindings_.find(dynamic_dimension);
35 if (param_iter == bindings_.end()) {
36 return absl::nullopt;
37 }
38 return param_iter->second;
39 }
40
ToProto() const41 DynamicParameterBindingProto DynamicParameterBinding::ToProto() const {
42 DynamicParameterBindingProto result;
43 for (const auto& binding : bindings_) {
44 const DynamicDimension& dynamic_dimension = binding.first;
45 const DynamicParameter& dynamic_param = binding.second;
46 DynamicParameterBindingProto::Binding binding_proto;
47 binding_proto.set_dynamic_param_num(dynamic_param.parameter_num);
48 for (int64 i : dynamic_param.parameter_index) {
49 binding_proto.add_dynamic_param_index(i);
50 }
51
52 binding_proto.set_target_param_num(dynamic_dimension.parameter_num);
53
54 for (int64 i : dynamic_dimension.parameter_index) {
55 binding_proto.add_target_param_index(i);
56 }
57
58 binding_proto.set_target_param_dim_num(dynamic_dimension.dimension);
59 result.add_entries()->Swap(&binding_proto);
60 }
61 return result;
62 }
63
CreateFromProto(const DynamicParameterBindingProto & proto)64 StatusOr<DynamicParameterBinding> DynamicParameterBinding::CreateFromProto(
65 const DynamicParameterBindingProto& proto) {
66 DynamicParameterBinding result;
67 for (const DynamicParameterBindingProto::Binding& binding : proto.entries()) {
68 int64 dynamic_param_num = binding.dynamic_param_num();
69 ShapeIndex dynamic_param_index(binding.dynamic_param_index().begin(),
70 binding.dynamic_param_index().end());
71 int64 target_param_num = binding.target_param_num();
72 ShapeIndex target_param_index(binding.target_param_index().begin(),
73 binding.target_param_index().end());
74 int64 target_dim_num = binding.target_param_dim_num();
75
76 TF_RETURN_IF_ERROR(
77 result.Bind(DynamicParameter{dynamic_param_num, dynamic_param_index},
78 DynamicDimension{target_param_num, target_param_index,
79 target_dim_num}));
80 }
81
82 return result;
83 }
84
ToString() const85 string DynamicParameterBinding::ToString() const {
86 std::vector<string> pieces;
87 pieces.push_back("DynamicParameterBinding: ");
88 for (const auto& binding : bindings_) {
89 const DynamicDimension& dynamic_dimension = binding.first;
90 const DynamicParameter& dynamic_param = binding.second;
91 pieces.push_back(absl::StrFormat(
92 " -- Input param number %lld at %s has dim %lld as dynamic"
93 " dimension, which is represented by param number %lld at "
94 "%s",
95 dynamic_dimension.parameter_num,
96 dynamic_dimension.parameter_index.ToString(),
97 dynamic_dimension.dimension, dynamic_param.parameter_num,
98 dynamic_param.parameter_index.ToString()));
99 }
100 return absl::StrJoin(pieces, "\n");
101 }
102
ForEachBinding(BindingFn fn) const103 Status DynamicParameterBinding::ForEachBinding(BindingFn fn) const {
104 for (const auto& binding : bindings_) {
105 TF_RETURN_IF_ERROR(fn(binding.second, binding.first));
106 }
107 return Status::OK();
108 }
109
Verify(const HloModule & module) const110 Status DynamicParameterBinding::Verify(const HloModule& module) const {
111 const HloComputation* entry = module.entry_computation();
112 return ForEachBinding([&](const DynamicParameter& dynamic_parameter,
113 const DynamicDimension& dynamic_dimension)
114 -> Status {
115 TF_RET_CHECK(dynamic_parameter.parameter_num >= 0 &&
116 dynamic_parameter.parameter_num < entry->num_parameters());
117 TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters());
118 TF_RET_CHECK(ShapeUtil::IndexIsValid(
119 entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(),
120 dynamic_parameter.parameter_index));
121 TF_RET_CHECK(ShapeUtil::IndexIsValid(
122 entry->parameter_instruction(dynamic_dimension.parameter_num)->shape(),
123 dynamic_dimension.parameter_index));
124 TF_RET_CHECK(
125 dynamic_dimension.dimension <
126 ShapeUtil::GetSubshape(
127 entry->parameter_instruction(dynamic_dimension.parameter_num)
128 ->shape(),
129 dynamic_dimension.parameter_index)
130 .rank());
131 return Status::OK();
132 });
133 }
134
operator <<(std::ostream & out,const DynamicParameterBinding & binding)135 std::ostream& operator<<(std::ostream& out,
136 const DynamicParameterBinding& binding) {
137 out << binding.ToString();
138 return out;
139 }
140
141 } // namespace xla
142