1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
16 #define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
17 
18 #include <functional>
19 #include <memory>
20 #include <set>
21 #include <string>
22 #include <unordered_map>
23 #include <vector>
24 
25 #include "tensorflow/core/framework/collective.h"
26 #include "tensorflow/core/framework/device_attributes.pb.h"
27 #include "tensorflow/core/lib/gtl/flatmap.h"
28 #include "tensorflow/core/platform/thread_annotations.h"
29 
30 namespace tensorflow {
31 class CompleteGroupRequest;
32 class CompleteGroupResponse;
33 class CompleteInstanceRequest;
34 class CompleteInstanceResponse;
35 class ConfigProto;
36 class DeviceMgr;
37 
38 // Implements ParamResolverInterface for a single-task context.
39 // It also implements the functionality necessary to serve as the
40 // group leader for param resolution in a multi-task context.
41 class CollectiveParamResolverLocal : public ParamResolverInterface {
42  public:
43   CollectiveParamResolverLocal(const ConfigProto& config,
44                                const DeviceMgr* dev_mgr,
45                                DeviceResolverInterface* dev_resolver,
46                                const string& task_name);
47 
~CollectiveParamResolverLocal()48   ~CollectiveParamResolverLocal() override {}
49 
50   void CompleteParamsAsync(const DeviceAttributes& device, CollectiveParams* cp,
51                            CancellationManager* cancel_mgr,
52                            const StatusCallback& done) override;
53 
54   void CompleteGroupAsync(const CompleteGroupRequest* request,
55                           CompleteGroupResponse* response,
56                           CancellationManager* cancel_mgr,
57                           const StatusCallback& done) override;
58 
59   void CompleteInstanceAsync(const CompleteInstanceRequest* request,
60                              CompleteInstanceResponse* response,
61                              CancellationManager* cancel_mgr,
62                              const StatusCallback& done) override;
63 
64   void StartAbort(const Status& s) override;
65 
66  protected:
67   // For access to InstanceRec and CompleteDefaultRanking.
68   friend class CollectiveParamResolverLocalTest;
69 
70   // Used to complete/verify CollGroup.
71   struct GroupRec {
72     mutable mutex mu;
73     CollGroupParams group TF_GUARDED_BY(mu);
74     Status status TF_GUARDED_BY(mu);
75     std::unordered_map<string, DeviceAttributes> devices TF_GUARDED_BY(mu);
76     std::vector<StatusCallback> waiting TF_GUARDED_BY(mu);
77   };
78 
79   // Finds the GroupRec that corresponds to cp->group_key.
80   // Also populates cp->group from that group_rec.
81   // Will wait until GroupRec is fully populated or an error arises before
82   // calling done.  Callback GroupRec* arg is only valid if status is ok.
83   // Ownership of GroupRec stays with this object and does not pass to the
84   // callback.
85   typedef std::function<void(const Status& s, const GroupRec* gr)>
86       GroupRecCallback;
87   void CompleteGroupLocal(const DeviceAttributes& device, CollectiveParams* cp,
88                           const GroupRecCallback& done,
89                           CancellationManager* cancel_mgr)
90       TF_LOCKS_EXCLUDED(group_mu_);
91 
92   // Finishes the group parameters once all members of the group are there.
93   void FinishGroup(GroupRec* gr) TF_EXCLUSIVE_LOCKS_REQUIRED(gr->mu);
94 
95   // Used to complete/verify CollInstance.
96   struct InstanceRec;
97 
98   typedef std::function<void(InstanceRec*)> IRConsumer;
99   struct InstanceRec {
100     mutex mu;
101     // Values to be shared by all instances, constant after initialization.
102     CollectiveParams* shared;
103     // If an error occurs during initialization this structure stays in the
104     // table with a non-OK status. Purging the table and restarting needs to be
105     // done at a higher level.
106     Status status TF_GUARDED_BY(mu);
107 
108     // These fields are used to count the instances that have called
109     // in and become known while resolving broadcast source identity and
110     // communicator key.
111     int source_rank TF_GUARDED_BY(mu);
112     string communicator_key TF_GUARDED_BY(mu);
113     int known_count TF_GUARDED_BY(mu);
114     std::vector<bool> known TF_GUARDED_BY(mu);
115     std::vector<IRConsumer> known_waiters TF_GUARDED_BY(mu);
116 
InstanceRecInstanceRec117     InstanceRec()
118         : shared(new CollectiveParams()), source_rank(-1), known_count(0) {}
~InstanceRecInstanceRec119     ~InstanceRec() { shared->Unref(); }
120   };
121 
122   // Find the InstanceRec with the same instance_key as cp.  If it doesn't
123   // already exist, create and initialize from gr and cp.
124   // created is set to true if a new IRec is created, false otherwise.
125   //
126   // Precondition: *gr must be a complete GroupRec, i.e. the value set
127   // by CompleteGroupLocal. *cp must be populated with all the fields
128   // required by InitInstanceSharedParams.  Ownership of InstanceRec stays
129   // with this object and does not pass to the callback.
130   InstanceRec* GetOrCreateInstanceRec(const GroupRec* gr, CollectiveParams* cp,
131                                       bool* created)
132       TF_LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
133 
134   // Populate *ir with device membership from gr, then initialize to be specific
135   // to cp->instance_key, i.e. order the devices and tasks.
136   //
137   // Preconditions:
138   //  cp is populated with all DeviceLocalities
139   void InitInstanceSharedParams(const GroupRec* gr, const CollectiveParams* cp,
140                                 InstanceRec* ir) TF_LOCKS_EXCLUDED(gr->mu);
141 
142   // Establishes the final order of gp->device_names and gp->task_names by
143   // considering localities of all devices.
144   void CompleteDefaultRanking(const std::vector<DeviceAttributes>& attributes,
145                               CollGroupParams* gp);
146 
147   // Finish populating *cp.
148   // Precondition: *gr has been fully populated by CompleteGroupLocal.
149   void CompleteInstanceLocal(const string& device, const GroupRec* gr,
150                              CollectiveParams* cp, bool is_source,
151                              const StatusCallback& done)
152       TF_LOCKS_EXCLUDED(instance_mu_, gr->mu, group_mu_);
153 
154   // Finish populating *cp from fully initialized *ir.
155   // Precondition: *gr and *ir are fully populated.
156   void CompleteInstanceFromInitializedIRec(const string& device,
157                                            const GroupRec* gr,
158                                            CollectiveParams* cp,
159                                            InstanceRec* ir, bool is_source,
160                                            const StatusCallback& done)
161       TF_LOCKS_EXCLUDED(ir->mu);
162 
163   // Complete instance params after waiting for group.
164   // Precondition: *cp has complete group data and default_rank.
165   void WaitForGroup(InstanceRec* ir, CollectiveParams* cp, bool is_source,
166                     const IRConsumer& f) TF_LOCKS_EXCLUDED(ir->mu);
167 
168   // If cp.device_names contains only devices local to this process
169   // populates *localities, else returns an error.
170   Status GetLocalDeviceLocalities(const CollectiveParams& cp,
171                                   std::vector<DeviceLocality>* localities);
172 
173   // Sets CollTaskParams.is_local and CollectiveParams.default_rank.
174   // Precondition: cp->device_names is fully populated and in final order.
175   void CompleteTaskIsLocal(const string& task_name, CollectiveParams* cp);
176 
177   // Sets cp->instance_default_rank according to location of device in
178   // current ordering of cp->instance.device_names.
179   void SetDefaultRank(const string& device, CollectiveParams* cp);
180 
181   // Sets cp->instance.type based on collective op type, and attempts to assign
182   // best implementation.
183   void AssignCollectiveType(CollectiveParams* cp);
184 
185   void StartAbortLocal(const Status& s)
186       TF_LOCKS_EXCLUDED(status_mu_, group_mu_, instance_mu_);
187 
188   const bool nccl_;
189   const DeviceMgr* dev_mgr_;
190   DeviceResolverInterface* dev_resolver_;  // Not owned.
191   string task_name_;
192   mutex group_mu_;
193   gtl::FlatMap<int32, std::unique_ptr<GroupRec>> group_table_
194       TF_GUARDED_BY(group_mu_);
195   mutex instance_mu_;
196   gtl::FlatMap<int32, gtl::FlatMap<int32, std::unique_ptr<InstanceRec>>>
197       instance_table_ TF_GUARDED_BY(instance_mu_);
198   mutex status_mu_;
199   Status status_ TF_GUARDED_BY(status_mu_);
200 };
201 
202 }  // namespace tensorflow
203 
204 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
205