/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include "tensorflow/core/data/service/worker_impl.h"

#include "grpcpp/create_channel.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/split_provider.h"
#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/data/service/utils.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/public/session_options.h"

42 namespace tensorflow {
43 namespace data {
44
45 const constexpr uint64 kRetryIntervalMicros = 5ull * 1000 * 1000;
46
47 namespace {
48 auto* tf_data_service_created =
49 monitoring::Gauge<bool, 0>::New("/tensorflow/data/service/created",
50 "Whether a tf.data service server "
51 "has been created.");
52 } // namespace
53
DataServiceWorkerImpl(const experimental::WorkerConfig & config)54 DataServiceWorkerImpl::DataServiceWorkerImpl(
55 const experimental::WorkerConfig& config)
56 : config_(config) {
57 tf_data_service_created->GetCell()->Set(true);
58 }
59
~DataServiceWorkerImpl()60 DataServiceWorkerImpl::~DataServiceWorkerImpl() {
61 mutex_lock l(mu_);
62 cancelled_ = true;
63 task_completion_cv_.notify_one();
64 heartbeat_cv_.notify_one();
65 }
66
Start(const std::string & worker_address,const std::string & transfer_address)67 Status DataServiceWorkerImpl::Start(const std::string& worker_address,
68 const std::string& transfer_address) {
69 VLOG(3) << "Starting tf.data service worker at address " << worker_address;
70 worker_address_ = worker_address;
71 transfer_address_ = transfer_address;
72
73 dispatcher_ = absl::make_unique<DataServiceDispatcherClient>(
74 config_.dispatcher_address(), config_.protocol());
75 TF_RETURN_IF_ERROR(dispatcher_->Initialize());
76
77 Status s = Heartbeat();
78 while (!s.ok()) {
79 if (!errors::IsUnavailable(s) && !errors::IsAborted(s) &&
80 !errors::IsCancelled(s)) {
81 return s;
82 }
83 LOG(WARNING) << "Failed to register with dispatcher at "
84 << config_.dispatcher_address() << ": " << s;
85 Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);
86 s = Heartbeat();
87 }
88 LOG(INFO) << "Worker registered with dispatcher running at "
89 << config_.dispatcher_address();
90 task_completion_thread_ = absl::WrapUnique(
91 Env::Default()->StartThread({}, "data-service-worker-task-completion",
92 [this]() { TaskCompletionThread(); }));
93 heartbeat_thread_ = absl::WrapUnique(Env::Default()->StartThread(
94 {}, "data-service-worker-heartbeat", [this]() { HeartbeatThread(); }));
95 mutex_lock l(mu_);
96 registered_ = true;
97 return Status::OK();
98 }
99
ProcessTask(const ProcessTaskRequest * request,ProcessTaskResponse * response)100 Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
101 ProcessTaskResponse* response) {
102 mutex_lock l(mu_);
103 const TaskDef& task = request->task();
104 VLOG(3) << "Received request to process task " << task.task_id();
105 return ProcessTaskInternal(task);
106 }
107
ProcessTaskInternal(const TaskDef & task_def)108 Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
109 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
110 std::unique_ptr<Task>& task = tasks_[task_def.task_id()];
111 if (task) {
112 VLOG(1) << "Received request to process already-processed task "
113 << task->task_def.task_id();
114 return Status::OK();
115 }
116 task = absl::make_unique<Task>(task_def);
117 VLOG(3) << "Began processing for task " << task_def.task_id()
118 << " with processing mode " << task_def.processing_mode();
119 return Status::OK();
120 }
121
// Lazily builds the dataset, iterator, and task runner for `task`, guarded by
// the per-task mutex so concurrent GetElement calls initialize it only once.
// Returns OK without re-doing work if the task is already initialized.
Status DataServiceWorkerImpl::EnsureTaskInitialized(
    DataServiceWorkerImpl::Task& task) {
  mutex_lock l(task.mu);
  if (task.initialized) {
    return Status::OK();
  }
  standalone::Dataset::Params params;
  std::unique_ptr<standalone::Dataset> dataset;
  std::unique_ptr<standalone::Iterator> iterator;

  // The TaskDef carries the dataset either inline as a graph, or as a file
  // path to read it from.
  switch (task.task_def.dataset_case()) {
    case TaskDef::kDatasetDef:
      TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
          params, task.task_def.dataset_def().graph(), &dataset));
      break;
    case TaskDef::kPath: {
      DatasetDef def;
      Status s = ReadDatasetDef(task.task_def.path(), def);
      if (!s.ok()) {
        // Best-effort: if the file is unreadable, fetch the dataset
        // definition from the dispatcher instead of failing the task.
        LOG(INFO) << "Failed to read dataset from " << task.task_def.path()
                  << ": " << s << ". Falling back to reading from dispatcher.";
        TF_RETURN_IF_ERROR(
            dispatcher_->GetDatasetDef(task.task_def.dataset_id(), def));
      }
      TF_RETURN_IF_ERROR(
          standalone::Dataset::FromGraph(params, def.graph(), &dataset));
      break;
    }
    case TaskDef::DATASET_NOT_SET:
      return errors::Internal("Unrecognized dataset case: ",
                              task.task_def.dataset_case());
  }
  // The processing mode determines how the iterator consumes data: in
  // DISTRIBUTED_EPOCH mode, splits are requested from the dispatcher via a
  // split provider; in PARALLEL_EPOCHS mode, each task iterates the whole
  // dataset independently.
  switch (task.task_def.processing_mode()) {
    case DISTRIBUTED_EPOCH: {
      auto split_provider = absl::make_unique<DataServiceSplitProvider>(
          config_.dispatcher_address(), config_.protocol(),
          task.task_def.job_id(), config_.dispatcher_timeout_ms());
      TF_RETURN_IF_ERROR(
          dataset->MakeIterator(std::move(split_provider), &iterator));
      break;
    }
    case PARALLEL_EPOCHS:
      TF_RETURN_IF_ERROR(dataset->MakeIterator(&iterator));
      break;
    default:
      return errors::InvalidArgument("Unrecognized processing mode: ",
                                     task.task_def.processing_mode());
  }
  auto task_iterator = absl::make_unique<StandaloneTaskIterator>(
      std::move(dataset), std::move(iterator));
  TF_RETURN_IF_ERROR(TaskRunner::Create(task.task_def, std::move(task_iterator),
                                        task.task_runner));

  task.initialized = true;
  VLOG(3) << "Created iterator for task " << task.task_def.task_id();
  return Status::OK();
}
179
// RPC handler: produces the next element for the requested task. Returns
// Unavailable (so the client retries) when the worker isn't registered yet or
// the task hasn't arrived; reports end_of_sequence for finished tasks.
Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
                                         GetElementResponse* response) {
  VLOG(3) << "Received GetElement request for task " << request->task_id();
  Task* task;
  {
    mutex_lock l(mu_);
    if (!registered_) {
      // We need to reject requests until the worker has registered with the
      // dispatcher, so that we don't return NOT_FOUND for tasks that the worker
      // had before preemption.
      return errors::Unavailable(
          "Worker has not yet registered with dispatcher.");
    }
    auto it = tasks_.find(request->task_id());
    if (it == tasks_.end()) {
      if (finished_tasks_.contains(request->task_id())) {
        VLOG(3) << "Task is already finished";
        response->set_end_of_sequence(true);
        return Status::OK();
      } else {
        // Perhaps the worker hasn't gotten the task from the dispatcher yet.
        // Return Unavailable so that the client knows to continue retrying.
        return errors::Unavailable("Task ", request->task_id(), " not found");
      }
    }
    task = it->second.get();
    TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
  }
  // GetNext may block waiting for data, so it runs outside of `mu_`.
  TF_RETURN_IF_ERROR(task->task_runner->GetNext(*request, *response));
  if (response->end_of_sequence()) {
    // Queue the task for completion reporting and wake the completion thread.
    mutex_lock l(mu_);
    VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
    pending_completed_tasks_.insert(request->task_id());
    task_completion_cv_.notify_one();
  } else if (!response->skip_task()) {
    VLOG(3) << "Producing an element for task " << request->task_id();
  }

  return Status::OK();
}
220
GetWorkerTasks(const GetWorkerTasksRequest * request,GetWorkerTasksResponse * response)221 Status DataServiceWorkerImpl::GetWorkerTasks(
222 const GetWorkerTasksRequest* request, GetWorkerTasksResponse* response) {
223 mutex_lock l(mu_);
224 for (const auto& it : tasks_) {
225 Task* task = it.second.get();
226 TaskInfo* task_info = response->add_tasks();
227 task_info->set_worker_address(worker_address_);
228 task_info->set_task_id(task->task_def.task_id());
229 task_info->set_job_id(task->task_def.job_id());
230 }
231 return Status::OK();
232 }
233
// Background thread: waits for tasks to reach end-of-sequence (signaled via
// `task_completion_cv_`) and reports them to the dispatcher, retrying failed
// updates after a backoff. Exits when `cancelled_` is set by the destructor.
void DataServiceWorkerImpl::TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_) {
  while (true) {
    {
      mutex_lock l(mu_);
      // Sleep until there is something to report or we are shutting down.
      while (!cancelled_ && pending_completed_tasks_.empty()) {
        task_completion_cv_.wait(l);
      }
      if (cancelled_) {
        VLOG(3) << "Task completion thread shutting down";
        return;
      }
    }
    // The RPC runs outside of `mu_` so it doesn't block other operations.
    Status s = SendTaskUpdates();
    if (!s.ok()) {
      LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
      mutex_lock l(mu_);
      if (!cancelled_) {
        // Back off before retrying; the wait is interruptible so shutdown can
        // wake us early.
        task_completion_cv_.wait_for(
            l, std::chrono::microseconds(kRetryIntervalMicros));
      }
    }
  }
}
257
SendTaskUpdates()258 Status DataServiceWorkerImpl::SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_) {
259 std::vector<TaskProgress> task_progress;
260 {
261 mutex_lock l(mu_);
262 VLOG(3) << "Sending " << pending_completed_tasks_.size()
263 << " task updates to dispatcher";
264 task_progress.reserve(pending_completed_tasks_.size());
265 for (int task_id : pending_completed_tasks_) {
266 task_progress.emplace_back();
267 task_progress.back().set_task_id(task_id);
268 task_progress.back().set_completed(true);
269 }
270 }
271
272 TF_RETURN_IF_ERROR(dispatcher_->WorkerUpdate(worker_address_, task_progress));
273 mutex_lock l(mu_);
274 for (const auto& update : task_progress) {
275 pending_completed_tasks_.erase(update.task_id());
276 }
277 VLOG(3) << "Sent " << task_progress.size() << " task updates ";
278 return Status::OK();
279 }
280
// Background thread: sends a heartbeat to the dispatcher every
// `heartbeat_interval_ms`, skipping heartbeats until registration completes.
// Exits when `cancelled_` is set by the destructor.
void DataServiceWorkerImpl::HeartbeatThread() TF_LOCKS_EXCLUDED(mu_) {
  while (true) {
    int64 next_heartbeat_micros =
        Env::Default()->NowMicros() + (config_.heartbeat_interval_ms() * 1000);
    {
      mutex_lock l(mu_);
      // Timed wait until the next heartbeat is due; `heartbeat_cv_` lets
      // shutdown interrupt the sleep early.
      while (!cancelled_ &&
             Env::Default()->NowMicros() < next_heartbeat_micros) {
        int64 time_to_wait_micros =
            next_heartbeat_micros - Env::Default()->NowMicros();
        heartbeat_cv_.wait_for(l,
                               std::chrono::microseconds(time_to_wait_micros));
      }
      if (cancelled_) {
        VLOG(3) << "Heartbeat thread shutting down";
        return;
      }
      if (!registered_) {
        // Start() drives registration; until it succeeds, just keep waiting.
        VLOG(1) << "Not performing heartbeat; worker is not yet registered";
        continue;
      }
    }
    // Heartbeat performs the RPC outside of `mu_`.
    Status s = Heartbeat();
    if (!s.ok()) {
      LOG(WARNING) << "Failed to send heartbeat to dispatcher: " << s;
    }
  }
}
309
// Sends one heartbeat to the dispatcher: reports the ids of current tasks and
// applies the dispatcher's response by starting new tasks and deleting
// finished ones. The RPC itself runs without holding `mu_`.
Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
  std::vector<int64> current_tasks;
  {
    // Snapshot the current task ids under the lock.
    mutex_lock l(mu_);
    for (const auto& task : tasks_) {
      current_tasks.push_back(task.first);
    }
  }
  std::vector<TaskDef> new_tasks;
  std::vector<int64> tasks_to_delete;
  TF_RETURN_IF_ERROR(
      dispatcher_->WorkerHeartbeat(worker_address_, transfer_address_,
                                   current_tasks, new_tasks, tasks_to_delete));
  mutex_lock l(mu_);
  for (const auto& task : new_tasks) {
    // Failure to start one task is logged but doesn't fail the heartbeat.
    Status s = ProcessTaskInternal(task);
    if (!s.ok() && !errors::IsAlreadyExists(s)) {
      LOG(WARNING) << "Failed to start processing task " << task.task_id()
                   << ": " << s;
    }
  }
  for (int64 task_id : tasks_to_delete) {
    VLOG(3) << "Deleting task " << task_id
            << " at the request of the dispatcher";
    tasks_.erase(task_id);
    // Remember finished tasks so GetElement can report end_of_sequence
    // instead of "not found".
    finished_tasks_.insert(task_id);
  }
  return Status::OK();
}
339
}  // namespace data
}  // namespace tensorflow