/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/distributed_runtime/rpc_collective_executor_mgr.h"

#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/common_runtime/collective_executor_mgr.h"
#include "tensorflow/core/common_runtime/collective_rma_local.h"
#include "tensorflow/core/distributed_runtime/collective_param_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/collective_rma_distributed.h"
#include "tensorflow/core/distributed_runtime/device_resolver_distributed.h"
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/lib/random/random.h"

namespace tensorflow {

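// When this task is itself the configured collective group leader,
// group_leader_ is stored as the empty string so that step-id sequences are
// generated locally; otherwise it holds the leader's task name so sequences
// can be fetched from that task over RPC.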
RpcCollectiveExecutorMgr::RpcCollectiveExecutorMgr(
    const ConfigProto& config, const DeviceMgr* dev_mgr,
    std::unique_ptr<DeviceResolverDistributed> dev_resolver,
    std::unique_ptr<CollectiveParamResolverDistributed> param_resolver,
    std::unique_ptr<NcclCommunicatorInterface> nccl_communicator,
    WorkerCacheInterface* worker_cache, const string& task_name)
    : CollectiveExecutorMgr(config, dev_mgr, std::move(dev_resolver),
                            std::move(param_resolver),
                            std::move(nccl_communicator)),
      worker_cache_(worker_cache),
      task_name_(task_name) {
  group_leader_ = (task_name == config.experimental().collective_group_leader())
                      ? ""
                      : config.experimental().collective_group_leader();
}

RpcCollectiveExecutorMgr::~RpcCollectiveExecutorMgr() {
  for (auto it : sequence_table_) {
    delete it.second;
  }
}

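// Builds a per-step executor whose remote-access object can reach devices on
// other tasks through worker_cache_, in addition to local devices.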
CollectiveExecutor* RpcCollectiveExecutorMgr::Create(int64 step_id) {
  CollectiveRemoteAccessDistributed* rma =
      new CollectiveRemoteAccessDistributed(dev_mgr_, dev_resolver_.get(),
                                            work_queue_, worker_cache_, step_id,
                                            task_name_);
  return new BaseCollectiveExecutor(this, rma, step_id, dev_mgr_,
                                    &gpu_ring_order_, work_queue_);
}

namespace {
// StepId must leave the most-significant 7 bits empty for future use.
static const int64 kStepIdMask = (((1uLL << 56) - 1) | (1uLL << 56));

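// Equivalently, kStepIdMask == 0x01ffffffffffffff: the low 57 bits are set,
// so the masking below guarantees that the 7 most-significant bits of every
// generated step id are zero.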
int64 NewRandomStepId() {
  int64 step_id = random::New64();
  // Leave the most-significant 7 bits clear for future use.
  step_id &= kStepIdMask;
  return step_id;
}
}  // namespace

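// Establishes a fresh next_step_id_ for graph_key: the group leader draws a
// new random step id locally, while every other task asks the leader for its
// current sequence via GetStepSequenceAsync and caches the result through
// UpdateStepSequences.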
void RpcCollectiveExecutorMgr::RefreshStepIdSequenceAsync(
    int64 graph_key, const StatusCallback& done) {
  if (group_leader_.empty()) {
    mutex_lock l(sequence_mu_);
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(graph_key);
    if (it == sequence_table_.end()) {
      gks = new GraphKeySequence(graph_key);
      sequence_table_[graph_key] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = NewRandomStepId();
    done(Status::OK());
  } else {
    WorkerInterface* wi = worker_cache_->GetOrCreateWorker(group_leader_);
    GetStepSequenceRequest* req = new GetStepSequenceRequest;
    GetStepSequenceResponse* resp = new GetStepSequenceResponse;
    req->add_graph_key(graph_key);
    wi->GetStepSequenceAsync(
        req, resp, [this, req, resp, done](const Status& s) {
          if (!s.ok()) {
            LOG(ERROR) << "Bad response [" << s
                       << "] from GetStepSequenceAsync call to "
                       << group_leader_;
            done(s);
          } else {
            done(UpdateStepSequences(*resp));
          }
          delete req;
          delete resp;
        });
  }
}

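// Serves step-sequence requests from other tasks. Only the group leader is
// expected to receive this call; a non-leader answers with an Internal error.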
void RpcCollectiveExecutorMgr::GetStepSequenceAsync(
    const GetStepSequenceRequest* request, GetStepSequenceResponse* response,
    const StatusCallback& done) {
  if (!group_leader_.empty()) {
    LOG(ERROR) << "GetStepSequence called at non-group-leader";
    done(errors::Internal("GetStepSequenceAsync called at non-group-leader"));
  } else {
    mutex_lock l(sequence_mu_);
    for (int64 graph_key : request->graph_key()) {
      auto it = sequence_table_.find(graph_key);
      GraphKeySequence* gks = nullptr;
      if (it == sequence_table_.end()) {
        gks = new GraphKeySequence(graph_key);
        gks->next_step_id_ = NewRandomStepId();
        sequence_table_[graph_key] = gks;
      } else {
        gks = it->second;
      }
      StepSequence* ss = response->add_step_sequence();
      ss->set_graph_key(graph_key);
      ss->set_next_step_id(gks->next_step_id_);
    }
    done(Status::OK());
  }
}

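// Copies the step sequences returned by the group leader into the local
// sequence_table_, creating entries for graph keys not seen before.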
Status RpcCollectiveExecutorMgr::UpdateStepSequences(
    const GetStepSequenceResponse& resp) {
  mutex_lock l(sequence_mu_);
  for (const StepSequence& ss : resp.step_sequence()) {
    GraphKeySequence* gks = nullptr;
    auto it = sequence_table_.find(ss.graph_key());
    if (it == sequence_table_.end()) {
      gks = new GraphKeySequence(ss.graph_key());
      sequence_table_[ss.graph_key()] = gks;
    } else {
      gks = it->second;
    }
    gks->next_step_id_ = ss.next_step_id();
  }
  return Status::OK();
}

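// Returns the cached next step id for graph_key, or kInvalidId if no
// sequence is known for that key (callers presumably need to refresh the
// sequence before proceeding).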
int64 RpcCollectiveExecutorMgr::NextStepId(int64 graph_key) {
  mutex_lock l(sequence_mu_);
  auto it = sequence_table_.find(graph_key);
  if (it != sequence_table_.end()) {
    return it->second->next_step_id_;
  }
  return CollectiveExecutor::kInvalidId;
}

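// Marks step_id as consumed for graph_key. If it matches the expected next
// step id, the sequence advances by one (still under kStepIdMask); otherwise
// the entry is invalidated, which appears intended to force a sequence
// refresh before the next use.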
void RpcCollectiveExecutorMgr::RetireStepId(int64 graph_key, int64 step_id) {
  mutex_lock l(sequence_mu_);
  auto it = sequence_table_.find(graph_key);
  if (it != sequence_table_.end()) {
    if (step_id == it->second->next_step_id_) {
      it->second->next_step_id_ = (it->second->next_step_id_ + 1) & kStepIdMask;
    } else {
      it->second->next_step_id_ = CollectiveExecutor::kInvalidId;
    }
  } else {
    LOG(ERROR) << "Failed to find graph_key " << graph_key << " to retire.";
  }
}

}  // namespace tensorflow