1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/collective.h"
16
17 #include "tensorflow/core/framework/op_kernel.h"
18 #include "tensorflow/core/lib/core/errors.h"
19 #include "tensorflow/core/lib/hash/hash.h"
20 #include "tensorflow/core/lib/strings/str_util.h"
21 #include "tensorflow/core/lib/strings/strcat.h"
22
23 namespace tensorflow {
24
25 namespace {
26 // A RegistrationInfo object stores a collective implementation registration
27 // details. `factory` is used to create instances of the collective
28 // implementation.
29 struct RegistrationInfo {
30 // This constructor also creates, and stores in `param_resolver_instance`,
31 // what is effectively a static instance of the collective implementation.
32 // During param resolution of collective ops we return this static instance.
33 // The actual op execution gets a fresh instance using `factory`.
RegistrationInfotensorflow::__anon07f350080111::RegistrationInfo34 RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
35 : name(n),
36 factory(std::move(f)),
37 param_resolver_instance(this->factory()) {}
38 string name;
39 CollectiveRegistry::Factory factory;
40 CollectiveImplementationInterface* param_resolver_instance;
41 };
42
MutableCollectiveRegistry()43 std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
44 static std::vector<RegistrationInfo>* registry =
45 new std::vector<RegistrationInfo>;
46 return registry;
47 }
48 } // namespace
49
ToString() const50 string CollGroupParams::ToString() const {
51 return strings::StrCat("CollGroupParams {group_key=", group_key,
52 " group_size=", group_size,
53 " device_type=", device_type.type_string(),
54 " num_tasks=", num_tasks, "}");
55 }
56
// Copy-assigns the fields listed below from `other` into *this. Container
// fields are rebuilt element-by-element from `other`, so *this ends up with
// its own storage. Self-assignment is a no-op.
//
// NOTE(review): impl_details.collective_name (printed by ToString) is not
// assigned here, so *this keeps whatever value it already had -- confirm this
// is intentional before relying on this operator for a full copy.
CollInstanceParams& CollInstanceParams::operator=(
    const CollInstanceParams& other) {
  if (this == &other) return *this;

  instance_key = other.instance_key;
  type = other.type;
  data_type = other.data_type;
  shape = other.shape;
  // assign() replaces the previous contents; no separate clear() is needed.
  device_names.assign(other.device_names.begin(), other.device_names.end());
  task_names.assign(other.task_names.begin(), other.task_names.end());
  same_num_devices_per_task = other.same_num_devices_per_task;
  num_devices_per_task = other.num_devices_per_task;
  gpu_ring_order = other.gpu_ring_order;
  communicator_key = other.communicator_key;

  impl_details.subdiv_offsets.assign(other.impl_details.subdiv_offsets.begin(),
                                     other.impl_details.subdiv_offsets.end());
  impl_details.subdiv_permutations.clear();
  impl_details.subdiv_permutations.reserve(
      other.impl_details.subdiv_permutations.size());
  // Bind by const reference: iterating by value would copy every permutation
  // a second time before the copy that is actually stored is built.
  for (const auto& perm : other.impl_details.subdiv_permutations) {
    impl_details.subdiv_permutations.push_back(
        std::vector<int>(perm.begin(), perm.end()));
  }
  impl_details.subdiv_source_rank.assign(
      other.impl_details.subdiv_source_rank.begin(),
      other.impl_details.subdiv_source_rank.end());
  impl_details.dependencies = other.impl_details.dependencies;
  return *this;
}
86
ToString() const87 string CollInstanceParams::ToString() const {
88 string v = strings::StrCat("CollInstanceParams { instance_key=", instance_key,
89 " type=", type, " data_type=", data_type,
90 " shape=", shape.DebugString(), " devices {");
91 for (const auto& d : device_names) {
92 strings::StrAppend(&v, d, ",");
93 }
94 strings::StrAppend(&v, "} task_names={");
95 for (const auto& n : task_names) {
96 strings::StrAppend(&v, n, ", ");
97 }
98 strings::StrAppend(&v, "} num_devices_per_task={");
99 for (const auto dpt : num_devices_per_task) {
100 strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", ");
101 }
102 strings::StrAppend(&v, "}, collective_name=", impl_details.collective_name,
103 ", communicator_key=", str_util::CEscape(communicator_key),
104 ", subdiv_offsets={");
105 strings::StrAppend(&v, "}, subdiv_offsets={");
106 for (const auto& d : impl_details.subdiv_offsets) {
107 strings::StrAppend(&v, d, ",");
108 }
109 strings::StrAppend(&v, "}, subdiv_perms={");
110 for (const auto& p : impl_details.subdiv_permutations) {
111 strings::StrAppend(&v, "{");
112 for (const auto& i : p) {
113 strings::StrAppend(&v, i, ",");
114 }
115 strings::StrAppend(&v, "}"); // one subdiv
116 }
117 if (!impl_details.subdiv_source_rank.empty()) {
118 strings::StrAppend(&v, " subdiv_source_rank={");
119 for (const auto& r : impl_details.subdiv_source_rank) {
120 strings::StrAppend(&v, r, ",");
121 }
122 strings::StrAppend(&v, "}");
123 }
124 strings::StrAppend(&v, "}"); // all subdivs
125 return v;
126 }
127
ToString() const128 string CollTaskParams::ToString() const {
129 string v = strings::StrCat("CollTaskParams {is_local={");
130 for (const auto& b : is_local) {
131 strings::StrAppend(&v, static_cast<int>(b), ",");
132 }
133 strings::StrAppend(&v, "}}");
134 return v;
135 }
136
ToString() const137 string CollectiveParams::ToString() const {
138 string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
139 strings::StrAppend(&v, " ", instance.ToString());
140 strings::StrAppend(&v, " ", task.ToString());
141 strings::StrAppend(&v, " default_rank=", default_rank,
142 " is_source=", is_source, " source_rank=", source_rank,
143 " subdiv_rank={");
144 for (const auto& r : subdiv_rank) {
145 strings::StrAppend(&v, r, ",");
146 }
147 strings::StrAppend(&v, "}}");
148 return v;
149 }
150
CtxParams(OpKernelContext * ctx)151 /*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
152 OpKernelContext* ctx) {
153 return ctx->params_;
154 }
155
CollectiveContext(CollectiveExecutor * col_exec,const DeviceMgr * dev_mgr,OpKernelContext * ctx,OpKernelContext::Params * op_params,const CollectiveParams & col_params,const string & exec_key,int64 step_id,const Tensor * input,Tensor * output)156 CollectiveContext::CollectiveContext(CollectiveExecutor* col_exec,
157 const DeviceMgr* dev_mgr,
158 OpKernelContext* ctx,
159 OpKernelContext::Params* op_params,
160 const CollectiveParams& col_params,
161 const string& exec_key, int64 step_id,
162 const Tensor* input, Tensor* output)
163 : col_exec(col_exec),
164 dev_mgr(dev_mgr),
165 op_ctx(ctx),
166 op_params(op_params),
167 col_params(col_params),
168 exec_key(exec_key),
169 step_id(step_id),
170 input(input),
171 output(output),
172 device(nullptr),
173 device_name(col_params.instance.device_names[col_params.default_rank]) {}
174
175 /*static*/
176 int64 CollectiveExecutor::kInvalidId = -1;
177
178 /*static*/
Lookup(const string & collective_name,CollectiveImplementationInterface ** implementation)179 Status CollectiveRegistry::Lookup(
180 const string& collective_name,
181 CollectiveImplementationInterface** implementation) {
182 return LookupHelper(collective_name, implementation, false);
183 }
184
185 /*static*/
LookupParamResolverInstance(const string & collective_name,CollectiveImplementationInterface ** implementation)186 Status CollectiveRegistry::LookupParamResolverInstance(
187 const string& collective_name,
188 CollectiveImplementationInterface** implementation) {
189 return LookupHelper(collective_name, implementation, true);
190 }
191
192 /*static*/
GetAll(std::vector<CollectiveImplementationInterface * > * implementations)193 void CollectiveRegistry::GetAll(
194 std::vector<CollectiveImplementationInterface*>* implementations) {
195 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
196 for (const RegistrationInfo& reg_info : *registry)
197 implementations->emplace_back(reg_info.factory());
198 }
199
200 /*static*/
Register(const string & collective_name,Factory factory)201 Status CollectiveRegistry::Register(const string& collective_name,
202 Factory factory) {
203 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
204 for (const RegistrationInfo& reg_info : *registry) {
205 if (reg_info.name == collective_name)
206 return errors::Internal("Already registered collective ",
207 collective_name);
208 }
209 registry->emplace_back(collective_name, std::move(factory));
210 return Status::OK();
211 }
212
213 /*static*/
LookupHelper(const string & collective_name,CollectiveImplementationInterface ** implementation,bool param_resolver)214 Status CollectiveRegistry::LookupHelper(
215 const string& collective_name,
216 CollectiveImplementationInterface** implementation, bool param_resolver) {
217 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
218 for (const RegistrationInfo& reg_info : *registry) {
219 if (reg_info.name == collective_name) {
220 if (param_resolver) {
221 *implementation = reg_info.param_resolver_instance;
222 } else {
223 *implementation = reg_info.factory();
224 }
225 return Status::OK();
226 }
227 }
228 return errors::Internal(
229 "CollectiveRegistry::Lookup did not find collective implementation ",
230 collective_name);
231 }
232
233 } // namespace tensorflow
234