1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/collective.h"
16 
17 #include "absl/strings/escaping.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/lib/core/errors.h"
20 #include "tensorflow/core/lib/hash/hash.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 
24 namespace tensorflow {
25 
26 namespace {
27 // A RegistrationInfo object stores a collective implementation registration
28 // details.  `factory` is used to create instances of the collective
29 // implementation.
30 struct RegistrationInfo {
31   // This constructor also creates, and stores in `param_resolver_instance`,
32   // what is effectively a static instance of the collective implementation.
33   // During param resolution of collective ops we return this static instance.
34   // The actual op execution gets a fresh instance using `factory`.
RegistrationInfotensorflow::__anon07f350080111::RegistrationInfo35   RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
36       : name(n),
37         factory(std::move(f)),
38         param_resolver_instance(this->factory()) {}
39   string name;
40   CollectiveRegistry::Factory factory;
41   CollectiveImplementationInterface* param_resolver_instance;
42 };
43 
MutableCollectiveRegistry()44 std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
45   static std::vector<RegistrationInfo>* registry =
46       new std::vector<RegistrationInfo>;
47   return registry;
48 }
49 }  // namespace
50 
ToString() const51 string CollGroupRuntimeDetails::ToString() const {
52   return strings::StrCat("CollGroupRuntimeDetails {communicator_key=",
53                          absl::CEscape(communicator_key), "}");
54 }
55 
ToString() const56 string CollGroupParams::ToString() const {
57   string v = strings::StrCat(
58       "CollGroupParams {group_key=", group_key, " group_size=", group_size,
59       " device_type=", device_type.type_string(), " num_tasks=", num_tasks,
60       " runtime_details=", runtime_details.ToString(), " devices {");
61   for (const auto& d : device_names) {
62     strings::StrAppend(&v, d, ",");
63   }
64   strings::StrAppend(&v, "} task_names={");
65   for (const auto& n : task_names) {
66     strings::StrAppend(&v, n, ", ");
67   }
68   strings::StrAppend(&v, "} num_devices_per_task={");
69   for (const auto& dpt : num_devices_per_task) {
70     strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", ");
71   }
72   strings::StrAppend(&v, "}");
73   return v;
74 }
75 
// Deep-copies `other` into `*this` (self-assignment is a no-op).
//
// Copied fields:
// - instance_key, type, data_type, shape
// - the permuter-only fields `devices` and `permutation`
// - impl_details: subdiv_offsets, subdiv_permutations, subdiv_source_rank,
//   dependencies
//
// NOTE(review): impl_details.collective_name is NOT copied, so the destination
// keeps whatever name it already had. Callers may depend on this, so confirm
// before adding it. Also note the implicitly generated copy constructor does
// copy collective_name, so copy-construction and copy-assignment differ.
CollInstanceParams& CollInstanceParams::operator=(
    const CollInstanceParams& other) {
  if (this != &other) {
    instance_key = other.instance_key;
    type = other.type;
    data_type = other.data_type;
    shape = other.shape;
    // std::vector copy-assignment yields the same element-wise deep copy.
    // Unlike the old by-value range-for, it does not copy each inner
    // permutation vector twice.
    impl_details.subdiv_offsets = other.impl_details.subdiv_offsets;
    impl_details.subdiv_permutations = other.impl_details.subdiv_permutations;
    impl_details.subdiv_source_rank = other.impl_details.subdiv_source_rank;
    impl_details.dependencies = other.impl_details.dependencies;
    devices = other.devices;
    permutation = other.permutation;
  }
  return *this;
}
100 
ToString() const101 string CollInstanceParams::ToString() const {
102   string v =
103       strings::StrCat("CollInstanceParams { instance_key=", instance_key,
104                       " type=", type, " data_type=", DataTypeString(data_type),
105                       " shape=", shape.DebugString(), " devices {");
106   strings::StrAppend(&v, "}, collective_name=", impl_details.collective_name,
107                      ", subdiv_offsets={");
108   strings::StrAppend(&v, "}, subdiv_offsets={");
109   for (const auto& d : impl_details.subdiv_offsets) {
110     strings::StrAppend(&v, d, ",");
111   }
112   strings::StrAppend(&v, "}, subdiv_perms={");
113   for (const auto& p : impl_details.subdiv_permutations) {
114     strings::StrAppend(&v, "{");
115     for (const auto& i : p) {
116       strings::StrAppend(&v, i, ",");
117     }
118     strings::StrAppend(&v, "}");  // one subdiv
119   }
120   if (!impl_details.subdiv_source_rank.empty()) {
121     strings::StrAppend(&v, " subdiv_source_rank={");
122     for (const auto& r : impl_details.subdiv_source_rank) {
123       strings::StrAppend(&v, r, ",");
124     }
125     strings::StrAppend(&v, "}");
126   }  // all subdivs
127   if (type == PERMUTE_COLLECTIVE) {
128     strings::StrAppend(&v, "}, permute_devices {");
129     for (const auto& d : devices) {
130       strings::StrAppend(&v, d, ",");
131     }
132     strings::StrAppend(&v, "}, permute_permutation {");
133     for (const auto& p : permutation) {
134       strings::StrAppend(&v, p, ",");
135     }
136     strings::StrAppend(&v, "}");
137   }
138   return v;
139 }
140 
ToString() const141 string CollTaskParams::ToString() const {
142   string v = strings::StrCat("CollTaskParams {is_local={");
143   for (const auto& b : is_local) {
144     strings::StrAppend(&v, static_cast<int>(b), ",");
145   }
146   strings::StrAppend(&v, "}}");
147   return v;
148 }
149 
ToString() const150 string CollectiveParams::ToString() const {
151   string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
152   strings::StrAppend(&v, " ", instance.ToString());
153   strings::StrAppend(&v, " ", task.ToString());
154   strings::StrAppend(&v, " default_rank=", default_rank,
155                      " is_source=", is_source, " source_rank=", source_rank,
156                      " subdiv_rank={");
157   for (const auto& r : subdiv_rank) {
158     strings::StrAppend(&v, r, ",");
159   }
160   strings::StrAppend(&v, "}}");
161   return v;
162 }
163 
CtxParams(OpKernelContext * ctx)164 /*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
165     OpKernelContext* ctx) {
166   return ctx->params_;
167 }
168 
CollectiveContext(CollectiveExecutor * col_exec,NcclCommunicatorInterface * nccl_communicator,const DeviceMgr * dev_mgr,OpKernelContext * ctx,OpKernelContext::Params * op_params,const CollectiveParams * col_params,const string & exec_key,int64 step_id,const Tensor * input,Tensor * output)169 CollectiveContext::CollectiveContext(
170     CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
171     const DeviceMgr* dev_mgr, OpKernelContext* ctx,
172     OpKernelContext::Params* op_params, const CollectiveParams* col_params,
173     const string& exec_key, int64 step_id, const Tensor* input, Tensor* output)
174     : col_exec(col_exec),
175       nccl_communicator(nccl_communicator),
176       dev_mgr(dev_mgr),
177       op_ctx(ctx),
178       op_params(op_params),
179       col_params(col_params),
180       exec_key(exec_key),
181       step_id(step_id),
182       input(input),
183       output(output),
184       device(nullptr),
185       device_name(col_params->group.device_names[col_params->default_rank]) {}
186 
187 /*static*/
188 int64 CollectiveExecutor::kInvalidId = -1;
189 
190 /*static*/
Lookup(const string & collective_name,CollectiveImplementationInterface ** implementation)191 Status CollectiveRegistry::Lookup(
192     const string& collective_name,
193     CollectiveImplementationInterface** implementation) {
194   return LookupHelper(collective_name, implementation, false);
195 }
196 
197 /*static*/
LookupParamResolverInstance(const string & collective_name,CollectiveImplementationInterface ** implementation)198 Status CollectiveRegistry::LookupParamResolverInstance(
199     const string& collective_name,
200     CollectiveImplementationInterface** implementation) {
201   return LookupHelper(collective_name, implementation, true);
202 }
203 
204 /*static*/
GetAll(std::vector<CollectiveImplementationInterface * > * implementations)205 void CollectiveRegistry::GetAll(
206     std::vector<CollectiveImplementationInterface*>* implementations) {
207   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
208   for (const RegistrationInfo& reg_info : *registry)
209     implementations->emplace_back(reg_info.factory());
210 }
211 
212 /*static*/
Register(const string & collective_name,Factory factory)213 Status CollectiveRegistry::Register(const string& collective_name,
214                                     Factory factory) {
215   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
216   for (const RegistrationInfo& reg_info : *registry) {
217     if (reg_info.name == collective_name)
218       return errors::Internal("Already registered collective ",
219                               collective_name);
220   }
221   registry->emplace_back(collective_name, std::move(factory));
222   return Status::OK();
223 }
224 
225 /*static*/
LookupHelper(const string & collective_name,CollectiveImplementationInterface ** implementation,bool param_resolver)226 Status CollectiveRegistry::LookupHelper(
227     const string& collective_name,
228     CollectiveImplementationInterface** implementation, bool param_resolver) {
229   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
230   for (const RegistrationInfo& reg_info : *registry) {
231     if (reg_info.name == collective_name) {
232       if (param_resolver) {
233         *implementation = reg_info.param_resolver_instance;
234       } else {
235         *implementation = reg_info.factory();
236       }
237       return Status::OK();
238     }
239   }
240   return errors::Internal(
241       "CollectiveRegistry::Lookup did not find collective implementation ",
242       collective_name);
243 }
244 
245 }  // namespace tensorflow
246