1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/collective.h"
16 
17 #include "tensorflow/core/framework/op_kernel.h"
18 #include "tensorflow/core/lib/core/errors.h"
19 #include "tensorflow/core/lib/hash/hash.h"
20 #include "tensorflow/core/lib/strings/str_util.h"
21 #include "tensorflow/core/lib/strings/strcat.h"
22 
23 namespace tensorflow {
24 
25 namespace {
26 // A RegistrationInfo object stores a collective implementation registration
27 // details.  `factory` is used to create instances of the collective
28 // implementation.
29 struct RegistrationInfo {
30   // This constructor also creates, and stores in `param_resolver_instance`,
31   // what is effectively a static instance of the collective implementation.
32   // During param resolution of collective ops we return this static instance.
33   // The actual op execution gets a fresh instance using `factory`.
RegistrationInfotensorflow::__anon07f350080111::RegistrationInfo34   RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
35       : name(n),
36         factory(std::move(f)),
37         param_resolver_instance(this->factory()) {}
38   string name;
39   CollectiveRegistry::Factory factory;
40   CollectiveImplementationInterface* param_resolver_instance;
41 };
42 
MutableCollectiveRegistry()43 std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
44   static std::vector<RegistrationInfo>* registry =
45       new std::vector<RegistrationInfo>;
46   return registry;
47 }
48 }  // namespace
49 
ToString() const50 string CollGroupParams::ToString() const {
51   return strings::StrCat("CollGroupParams {group_key=", group_key,
52                          " group_size=", group_size,
53                          " device_type=", device_type.type_string(),
54                          " num_tasks=", num_tasks, "}");
55 }
56 
// Deep-copies the fields of `other` listed below into this object and returns
// *this. Assigning an object to itself is a no-op.
//
// NOTE(review): impl_details.collective_name (a field that ToString prints) is
// not copied, so the destination keeps its previous value for that field. This
// behavior is preserved as-is because callers may rely on it; confirm whether
// it is intentional. Any member not listed here is likewise left untouched.
CollInstanceParams& CollInstanceParams::operator=(
    const CollInstanceParams& other) {
  if (this == &other) {
    return *this;
  }
  instance_key = other.instance_key;
  type = other.type;
  data_type = other.data_type;
  shape = other.shape;
  // assign() replaces the previous contents, so no prior clear() is needed.
  device_names.assign(other.device_names.begin(), other.device_names.end());
  task_names.assign(other.task_names.begin(), other.task_names.end());
  same_num_devices_per_task = other.same_num_devices_per_task;
  num_devices_per_task = other.num_devices_per_task;
  gpu_ring_order = other.gpu_ring_order;
  communicator_key = other.communicator_key;
  impl_details.subdiv_offsets.assign(
      other.impl_details.subdiv_offsets.begin(),
      other.impl_details.subdiv_offsets.end());
  impl_details.subdiv_permutations.clear();
  // Iterate by const reference. A by-value loop variable would copy each
  // permutation once more before the stored element is even built.
  for (const auto& p : other.impl_details.subdiv_permutations) {
    impl_details.subdiv_permutations.push_back(
        std::vector<int>(p.begin(), p.end()));
  }
  impl_details.subdiv_source_rank.assign(
      other.impl_details.subdiv_source_rank.begin(),
      other.impl_details.subdiv_source_rank.end());
  impl_details.dependencies = other.impl_details.dependencies;
  return *this;
}
86 
ToString() const87 string CollInstanceParams::ToString() const {
88   string v = strings::StrCat("CollInstanceParams { instance_key=", instance_key,
89                              " type=", type, " data_type=", data_type,
90                              " shape=", shape.DebugString(), " devices {");
91   for (const auto& d : device_names) {
92     strings::StrAppend(&v, d, ",");
93   }
94   strings::StrAppend(&v, "} task_names={");
95   for (const auto& n : task_names) {
96     strings::StrAppend(&v, n, ", ");
97   }
98   strings::StrAppend(&v, "} num_devices_per_task={");
99   for (const auto dpt : num_devices_per_task) {
100     strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", ");
101   }
102   strings::StrAppend(&v, "}, collective_name=", impl_details.collective_name,
103                      ", communicator_key=", str_util::CEscape(communicator_key),
104                      ", subdiv_offsets={");
105   strings::StrAppend(&v, "}, subdiv_offsets={");
106   for (const auto& d : impl_details.subdiv_offsets) {
107     strings::StrAppend(&v, d, ",");
108   }
109   strings::StrAppend(&v, "}, subdiv_perms={");
110   for (const auto& p : impl_details.subdiv_permutations) {
111     strings::StrAppend(&v, "{");
112     for (const auto& i : p) {
113       strings::StrAppend(&v, i, ",");
114     }
115     strings::StrAppend(&v, "}");  // one subdiv
116   }
117   if (!impl_details.subdiv_source_rank.empty()) {
118     strings::StrAppend(&v, " subdiv_source_rank={");
119     for (const auto& r : impl_details.subdiv_source_rank) {
120       strings::StrAppend(&v, r, ",");
121     }
122     strings::StrAppend(&v, "}");
123   }
124   strings::StrAppend(&v, "}");  // all subdivs
125   return v;
126 }
127 
ToString() const128 string CollTaskParams::ToString() const {
129   string v = strings::StrCat("CollTaskParams {is_local={");
130   for (const auto& b : is_local) {
131     strings::StrAppend(&v, static_cast<int>(b), ",");
132   }
133   strings::StrAppend(&v, "}}");
134   return v;
135 }
136 
ToString() const137 string CollectiveParams::ToString() const {
138   string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
139   strings::StrAppend(&v, " ", instance.ToString());
140   strings::StrAppend(&v, " ", task.ToString());
141   strings::StrAppend(&v, " default_rank=", default_rank,
142                      " is_source=", is_source, " source_rank=", source_rank,
143                      " subdiv_rank={");
144   for (const auto& r : subdiv_rank) {
145     strings::StrAppend(&v, r, ",");
146   }
147   strings::StrAppend(&v, "}}");
148   return v;
149 }
150 
CtxParams(OpKernelContext * ctx)151 /*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
152     OpKernelContext* ctx) {
153   return ctx->params_;
154 }
155 
// Bundles the state one collective execution needs: the executor, the device
// manager, the op kernel context and its params, the resolved collective
// params, the execution key, the step id, and the input/output tensors.
// Whether `col_params` and `exec_key` are stored by reference or copied depends
// on the member declarations in collective.h (not visible here).
// `device` starts as nullptr; presumably the caller fills it in later.
CollectiveContext::CollectiveContext(CollectiveExecutor* col_exec,
                                     const DeviceMgr* dev_mgr,
                                     OpKernelContext* ctx,
                                     OpKernelContext::Params* op_params,
                                     const CollectiveParams& col_params,
                                     const string& exec_key, int64 step_id,
                                     const Tensor* input, Tensor* output)
    : col_exec(col_exec),
      dev_mgr(dev_mgr),
      op_ctx(ctx),
      op_params(op_params),
      col_params(col_params),
      exec_key(exec_key),
      step_id(step_id),
      input(input),
      output(output),
      device(nullptr),
      // `col_params` here names the constructor parameter, not the member.
      // NOTE(review): assumes default_rank is a valid index into
      // instance.device_names; it is not bounds-checked here.
      device_name(col_params.instance.device_names[col_params.default_rank]) {}
174 
175 /*static*/
176 int64 CollectiveExecutor::kInvalidId = -1;
177 
178 /*static*/
Lookup(const string & collective_name,CollectiveImplementationInterface ** implementation)179 Status CollectiveRegistry::Lookup(
180     const string& collective_name,
181     CollectiveImplementationInterface** implementation) {
182   return LookupHelper(collective_name, implementation, false);
183 }
184 
185 /*static*/
LookupParamResolverInstance(const string & collective_name,CollectiveImplementationInterface ** implementation)186 Status CollectiveRegistry::LookupParamResolverInstance(
187     const string& collective_name,
188     CollectiveImplementationInterface** implementation) {
189   return LookupHelper(collective_name, implementation, true);
190 }
191 
192 /*static*/
GetAll(std::vector<CollectiveImplementationInterface * > * implementations)193 void CollectiveRegistry::GetAll(
194     std::vector<CollectiveImplementationInterface*>* implementations) {
195   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
196   for (const RegistrationInfo& reg_info : *registry)
197     implementations->emplace_back(reg_info.factory());
198 }
199 
200 /*static*/
Register(const string & collective_name,Factory factory)201 Status CollectiveRegistry::Register(const string& collective_name,
202                                     Factory factory) {
203   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
204   for (const RegistrationInfo& reg_info : *registry) {
205     if (reg_info.name == collective_name)
206       return errors::Internal("Already registered collective ",
207                               collective_name);
208   }
209   registry->emplace_back(collective_name, std::move(factory));
210   return Status::OK();
211 }
212 
213 /*static*/
LookupHelper(const string & collective_name,CollectiveImplementationInterface ** implementation,bool param_resolver)214 Status CollectiveRegistry::LookupHelper(
215     const string& collective_name,
216     CollectiveImplementationInterface** implementation, bool param_resolver) {
217   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
218   for (const RegistrationInfo& reg_info : *registry) {
219     if (reg_info.name == collective_name) {
220       if (param_resolver) {
221         *implementation = reg_info.param_resolver_instance;
222       } else {
223         *implementation = reg_info.factory();
224       }
225       return Status::OK();
226     }
227   }
228   return errors::Internal(
229       "CollectiveRegistry::Lookup did not find collective implementation ",
230       collective_name);
231 }
232 
233 }  // namespace tensorflow
234