1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_parameter_binding.h"
17 #include "tensorflow/compiler/xla/service/hlo_computation.h"
18 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
19 #include "tensorflow/compiler/xla/service/hlo_module.h"
20 
21 namespace xla {
22 
Bind(const DynamicParameter & dynamic_parameter,const DynamicDimension & dynamic_dimension)23 Status DynamicParameterBinding::Bind(
24     const DynamicParameter& dynamic_parameter,
25     const DynamicDimension& dynamic_dimension) {
26   auto result = bindings_.emplace(dynamic_dimension, dynamic_parameter);
27   TF_RET_CHECK(result.second);
28   return Status::OK();
29 }
30 
31 absl::optional<DynamicParameterBinding::DynamicParameter>
GetBinding(const DynamicDimension & dynamic_dimension) const32 DynamicParameterBinding::GetBinding(
33     const DynamicDimension& dynamic_dimension) const {
34   auto param_iter = bindings_.find(dynamic_dimension);
35   if (param_iter == bindings_.end()) {
36     return absl::nullopt;
37   }
38   return param_iter->second;
39 }
40 
ToProto() const41 DynamicParameterBindingProto DynamicParameterBinding::ToProto() const {
42   DynamicParameterBindingProto result;
43   for (const auto& binding : bindings_) {
44     const DynamicDimension& dynamic_dimension = binding.first;
45     const DynamicParameter& dynamic_param = binding.second;
46     DynamicParameterBindingProto::Binding binding_proto;
47     binding_proto.set_dynamic_param_num(dynamic_param.parameter_num);
48     for (int64 i : dynamic_param.parameter_index) {
49       binding_proto.add_dynamic_param_index(i);
50     }
51 
52     binding_proto.set_target_param_num(dynamic_dimension.parameter_num);
53 
54     for (int64 i : dynamic_dimension.parameter_index) {
55       binding_proto.add_target_param_index(i);
56     }
57 
58     binding_proto.set_target_param_dim_num(dynamic_dimension.dimension);
59     result.add_entries()->Swap(&binding_proto);
60   }
61   return result;
62 }
63 
CreateFromProto(const DynamicParameterBindingProto & proto)64 StatusOr<DynamicParameterBinding> DynamicParameterBinding::CreateFromProto(
65     const DynamicParameterBindingProto& proto) {
66   DynamicParameterBinding result;
67   for (const DynamicParameterBindingProto::Binding& binding : proto.entries()) {
68     int64 dynamic_param_num = binding.dynamic_param_num();
69     ShapeIndex dynamic_param_index(binding.dynamic_param_index().begin(),
70                                    binding.dynamic_param_index().end());
71     int64 target_param_num = binding.target_param_num();
72     ShapeIndex target_param_index(binding.target_param_index().begin(),
73                                   binding.target_param_index().end());
74     int64 target_dim_num = binding.target_param_dim_num();
75 
76     TF_RETURN_IF_ERROR(
77         result.Bind(DynamicParameter{dynamic_param_num, dynamic_param_index},
78                     DynamicDimension{target_param_num, target_param_index,
79                                      target_dim_num}));
80   }
81 
82   return result;
83 }
84 
ToString() const85 string DynamicParameterBinding::ToString() const {
86   std::vector<string> pieces;
87   pieces.push_back("DynamicParameterBinding: ");
88   for (const auto& binding : bindings_) {
89     const DynamicDimension& dynamic_dimension = binding.first;
90     const DynamicParameter& dynamic_param = binding.second;
91     pieces.push_back(absl::StrFormat(
92         " -- Input param number %lld at %s has dim %lld as dynamic"
93         " dimension, which is represented by param number %lld at "
94         "%s",
95         dynamic_dimension.parameter_num,
96         dynamic_dimension.parameter_index.ToString(),
97         dynamic_dimension.dimension, dynamic_param.parameter_num,
98         dynamic_param.parameter_index.ToString()));
99   }
100   return absl::StrJoin(pieces, "\n");
101 }
102 
ForEachBinding(BindingFn fn) const103 Status DynamicParameterBinding::ForEachBinding(BindingFn fn) const {
104   for (const auto& binding : bindings_) {
105     TF_RETURN_IF_ERROR(fn(binding.second, binding.first));
106   }
107   return Status::OK();
108 }
109 
Verify(const HloModule & module) const110 Status DynamicParameterBinding::Verify(const HloModule& module) const {
111   const HloComputation* entry = module.entry_computation();
112   return ForEachBinding([&](const DynamicParameter& dynamic_parameter,
113                             const DynamicDimension& dynamic_dimension)
114                             -> Status {
115     TF_RET_CHECK(dynamic_parameter.parameter_num >= 0 &&
116                  dynamic_parameter.parameter_num < entry->num_parameters());
117     TF_RET_CHECK(dynamic_dimension.parameter_num < entry->num_parameters());
118     TF_RET_CHECK(ShapeUtil::IndexIsValid(
119         entry->parameter_instruction(dynamic_parameter.parameter_num)->shape(),
120         dynamic_parameter.parameter_index));
121     TF_RET_CHECK(ShapeUtil::IndexIsValid(
122         entry->parameter_instruction(dynamic_dimension.parameter_num)->shape(),
123         dynamic_dimension.parameter_index));
124     TF_RET_CHECK(
125         dynamic_dimension.dimension <
126         ShapeUtil::GetSubshape(
127             entry->parameter_instruction(dynamic_dimension.parameter_num)
128                 ->shape(),
129             dynamic_dimension.parameter_index)
130             .rank());
131     return Status::OK();
132   });
133 }
134 
operator <<(std::ostream & out,const DynamicParameterBinding & binding)135 std::ostream& operator<<(std::ostream& out,
136                          const DynamicParameterBinding& binding) {
137   out << binding.ToString();
138   return out;
139 }
140 
141 }  // namespace xla
142