/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"

#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {

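// If this task is the configured collective group leader, group_leader_ is
// left empty; otherwise it records the leader's task name so step-id
// sequences can be fetched from it over RPC.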
RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    std::unique_ptr<DeviceResolverDistributed> dev_resolver,
    std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
    std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
    WorkerCacheInterface* worker_cache, const string& task_name)
    : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
                            std::move(param_resolver),
                            std::move(nccl_communicator)),
      worker_cache_(worker_cache),
      task_name_(task_name) {
  group_leader_ = (task_name == config.experimental().collective_group_leader())
                      ? ""
                      : config.experimental().collective_group_leader();
}

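// The sequence table owns its GraphKeySequence entries; delete them here.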
RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
  for (auto it : sequence_table_) {
    delete it.second;
  }
}

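// Creates a per-step CollectiveExecutor whose remote memory access goes
// through the worker cache, so collective ops can span tasks.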
CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
  CollectiveRemoteAccessDistributed* rma =
      new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
                                            work_queue_, worker_cache_, step_id,
                                            task_name_);
  return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                    &gpu_ring_order_, work_queue_);
}

namespace {
// StepId must leave the most-significant 7 bits empty for future use.
static const int64 kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56));

int64 NewRandomStepId() {
  int64 step_id = random::New64();
  // Leave the most-significant 7 bits clear for future use.
  step_id &= kStepIdMask;
  return step_id;
}
}  // namespace

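// If this task is the group leader, pick a fresh random step id locally;
// otherwise ask the leader for the next step-id sequence over RPC and fold
// the response into the local sequence table.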
void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
    int64 graph_key, const StatusCallback& done) {
  if (group_leader_.empty()) {
    mutex_lock l(sequence_mu_);
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(graph_key);
    if (it == sequence_table_.end()) {
      gks = new GraphKeySequence(graph_key);
      sequence_table_[graph_key] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = NewRandomStepId();
    done(Status::OK());
  } else {
    WorkerInterface* wi = worker_cache_->GetOrCreateWorker(group_leader_);
    GetStepSequenceRequest* req = new GetStepSequenceRequest;
    GetStepSequenceResponse* resp = new GetStepSequenceResponse;
    req->add_graph_key(graph_key);
    wi->GetStepSequenceAsync(
        req, resp, [this, req, resp, done](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Bad response [" << s
                       << "] from GetStepSequenceAsync call to "
                       << group_leader_;
            done(s);
          } else {
            done(UpdateStepSequences(*resp));
          }
          delete req;
          delete resp;
        });
  }
}

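// Serves step-id sequences to other tasks. Only the group leader should be
// asked; a non-leader answers with an Internal error.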
void RpcCollectiveExecutorMgr::GetStepSequenceAsync(
    const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
    const StatusCallback& done) {
  if (!group_leader_.empty()) {
    LOG(ERROR) << "GetStepSequence called at non-group-leader";
    done(errors::Internal("GetStepSequenceAsync called at non-group-leader"));
  } else {
    mutex_lock l(sequence_mu_);
    for (int64 graph_key : request->graph_key()) {
      auto it = sequence_table_.find(graph_key);
      GraphKeySequence* gks = nullptr;
      if (it == sequence_table_.end()) {
        gks = new GraphKeySequence(graph_key);
        gks->next_step_id_ = NewRandomStepId();
        sequence_table_[graph_key] = gks;
      } else {
        gks = it->second;
      }
      StepSequence* ss = response->add_step_sequence();
      ss->set_graph_key(graph_key);
      ss->set_next_step_id(gks->next_step_id_);
    }
    done(Status::OK());
  }
}

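// Merges the step-id sequences returned by the group leader into the local
// sequence table, creating entries for graph keys seen for the first time.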
Status RpcCollectiveExecutorMgr::UpdateStepSequences(
    const GetStepSequenceResponse& resp) {
  mutex_lock l(sequence_mu_);
  for (const StepSequence& ss : resp.step_sequence()) {
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(ss.graph_key());
    if (it == sequence_table_.end()) {
      gks = new GraphKeySequence(ss.graph_key());
      sequence_table_[ss.graph_key()] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = ss.next_step_id();
  }
  return Status::OK();
}

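// Returns the cached next step id for graph_key, or kInvalidId if no
// sequence has been established yet.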
int64 RpcCollectiveExecutorMgr::NextStepId(int64 graph_key) {
  mutex_lock l(sequence_mu_);
  auto it = sequence_table_.find(graph_key);
  if (it != sequence_table_.end()) {
    return it->second->next_step_id_;
  }
  return CollectiveExecutor::kInvalidId;
}

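// Advances the sequence for graph_key when the retired step id matches the
// expected next id; otherwise invalidates the sequence so it must be
// refreshed before reuse.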
void RpcCollectiveExecutorMgr::RetireStepId(int64 graph_key, int64 step_id) {
  mutex_lock l(sequence_mu_);
  auto it = sequence_table_.find(graph_key);
  if (it != sequence_table_.end()) {
    if (step_id == it->second->next_step_id_) {
      it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask;
    } else {
      it->second->next_step_id_ = CollectiveExecutor::kInvalidId;
    }
  } else {
    LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire.";
  }
}

}  // namespace tensorflow