/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/data/service/worker_impl.h"

#include "grpcpp/create_channel.h"
#include "absl/memory/memory.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/data/dataset.pb.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
#include "tensorflow/core/data/service/data_service.h"
#include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
#include "tensorflow/core/data/service/grpc_util.h"
#include "tensorflow/core/data/service/split_provider.h"
#include "tensorflow/core/data/service/task_runner.h"
#include "tensorflow/core/data/service/utils.h"
#include "tensorflow/core/data/standalone.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
#include "tensorflow/core/lib/monitoring/gauge.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/snappy.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {
namespace data {

// How long to wait before retrying when registration or a task update fails.
constexpr uint64 kRetryIntervalMicros = 5ull * 1000 * 1000;  // 5 seconds.

namespace {
auto* tf_data_service_created =
    monitoring::Gauge<bool, 0>::New("/tensorflow/data/service/created",
                                    "Whether a tf.data service server "
                                    "has been created.");
}  // namespace

DataServiceWorkerImpl::DataServiceWorkerImpl(
    const experimental::WorkerConfig& config)
    : config_(config) {
  tf_data_service_created->GetCell()->Set(true);
}

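// Signals cancellation and wakes the background threads so they can exit.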
DataServiceWorkerImpl::~DataServiceWorkerImpl() {
  mutex_lock l(mu_);
  cancelled_ = true;
  task_completion_cv_.notify_one();
  heartbeat_cv_.notify_one();
}

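// Registers the worker with the dispatcher, retrying transient
// (Unavailable/Aborted/Cancelled) errors, then starts the task-completion and
// heartbeat background threads.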
Status DataServiceWorkerImpl::Start(const std::string& worker_address,
                                    const std::string& transfer_address) {
  VLOG(3) << "Starting tf.data service worker at address " << worker_address;
  worker_address_ = worker_address;
  transfer_address_ = transfer_address;

  dispatcher_ = absl::make_unique<DataServiceDispatcherClient>(
      config_.dispatcher_address(), config_.protocol());
  TF_RETURN_IF_ERROR(dispatcher_->Initialize());

  Status s = Heartbeat();
  while (!s.ok()) {
    if (!errors::IsUnavailable(s) && !errors::IsAborted(s) &&
        !errors::IsCancelled(s)) {
      return s;
    }
    LOG(WARNING) << "Failed to register with dispatcher at "
                 << config_.dispatcher_address() << ": " << s;
    Env::Default()->SleepForMicroseconds(kRetryIntervalMicros);
    s = Heartbeat();
  }
  LOG(INFO) << "Worker registered with dispatcher running at "
            << config_.dispatcher_address();
  task_completion_thread_ = absl::WrapUnique(
      Env::Default()->StartThread({}, "data-service-worker-task-completion",
                                  [this]() { TaskCompletionThread(); }));
  heartbeat_thread_ = absl::WrapUnique(Env::Default()->StartThread(
      {}, "data-service-worker-heartbeat", [this]() { HeartbeatThread(); }));
  mutex_lock l(mu_);
  registered_ = true;
  return Status::OK();
}

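// RPC handler: records a task assigned by the dispatcher so that it can be
// initialized lazily when the first GetElement request arrives.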
Status DataServiceWorkerImpl::ProcessTask(const ProcessTaskRequest* request,
                                          ProcessTaskResponse* response) {
  mutex_lock l(mu_);
  const TaskDef& task = request->task();
  VLOG(3) << "Received request to process task " << task.task_id();
  return ProcessTaskInternal(task);
}

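// Stores the TaskDef if it is not already known. Building the dataset and
// iterator is deferred to EnsureTaskInitialized.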
Status DataServiceWorkerImpl::ProcessTaskInternal(const TaskDef& task_def)
    TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
  std::unique_ptr<Task>& task = tasks_[task_def.task_id()];
  if (task) {
    VLOG(1) << "Received request to process already-processed task "
            << task->task_def.task_id();
    return Status::OK();
  }
  task = absl::make_unique<Task>(task_def);
  VLOG(3) << "Began processing for task " << task_def.task_id()
          << " with processing mode " << task_def.processing_mode();
  return Status::OK();
}

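// Lazily builds the standalone dataset, iterator, and task runner for a task.
// Path-based tasks fall back to fetching the dataset definition from the
// dispatcher if the file cannot be read.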
Status DataServiceWorkerImpl::EnsureTaskInitialized(
    DataServiceWorkerImpl::Task& task) {
  mutex_lock l(task.mu);
  if (task.initialized) {
    return Status::OK();
  }
  standalone::Dataset::Params params;
  std::unique_ptr<standalone::Dataset> dataset;
  std::unique_ptr<standalone::Iterator> iterator;

  switch (task.task_def.dataset_case()) {
    case TaskDef::kDatasetDef:
      TF_RETURN_IF_ERROR(standalone::Dataset::FromGraph(
          params, task.task_def.dataset_def().graph(), &dataset));
      break;
    case TaskDef::kPath: {
      DatasetDef def;
      Status s = ReadDatasetDef(task.task_def.path(), def);
      if (!s.ok()) {
        LOG(INFO) << "Failed to read dataset from " << task.task_def.path()
                  << ": " << s << ". Falling back to reading from dispatcher.";
        TF_RETURN_IF_ERROR(
            dispatcher_->GetDatasetDef(task.task_def.dataset_id(), def));
      }
      TF_RETURN_IF_ERROR(
          standalone::Dataset::FromGraph(params, def.graph(), &dataset));
      break;
    }
    case TaskDef::DATASET_NOT_SET:
      return errors::Internal("Unrecognized dataset case: ",
                              task.task_def.dataset_case());
  }
  switch (task.task_def.processing_mode()) {
    case DISTRIBUTED_EPOCH: {
      auto split_provider = absl::make_unique<DataServiceSplitProvider>(
          config_.dispatcher_address(), config_.protocol(),
          task.task_def.job_id(), config_.dispatcher_timeout_ms());
      TF_RETURN_IF_ERROR(
          dataset->MakeIterator(std::move(split_provider), &iterator));
      break;
    }
    case PARALLEL_EPOCHS:
      TF_RETURN_IF_ERROR(dataset->MakeIterator(&iterator));
      break;
    default:
      return errors::InvalidArgument("Unrecognized processing mode: ",
                                     task.task_def.processing_mode());
  }
  auto task_iterator = absl::make_unique<StandaloneTaskIterator>(
      std::move(dataset), std::move(iterator));
  TF_RETURN_IF_ERROR(TaskRunner::Create(task.task_def, std::move(task_iterator),
                                        task.task_runner));

  task.initialized = true;
  VLOG(3) << "Created iterator for task " << task.task_def.task_id();
  return Status::OK();
}

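// RPC handler: returns the next element for a task. When the iterator is
// exhausted, sets end_of_sequence and queues a completion update for the
// dispatcher.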
Status DataServiceWorkerImpl::GetElement(const GetElementRequest* request,
                                         GetElementResponse* response) {
  VLOG(3) << "Received GetElement request for task " << request->task_id();
  Task* task;
  {
    mutex_lock l(mu_);
    if (!registered_) {
      // We need to reject requests until the worker has registered with the
      // dispatcher, so that we don't return NOT_FOUND for tasks that the
      // worker had before preemption.
      return errors::Unavailable(
          "Worker has not yet registered with dispatcher.");
    }
    auto it = tasks_.find(request->task_id());
    if (it == tasks_.end()) {
      if (finished_tasks_.contains(request->task_id())) {
        VLOG(3) << "Task is already finished";
        response->set_end_of_sequence(true);
        return Status::OK();
      } else {
        // Perhaps the worker hasn't gotten the task from the dispatcher yet.
        // Return Unavailable so that the client knows to continue retrying.
        return errors::Unavailable("Task ", request->task_id(), " not found");
      }
    }
    task = it->second.get();
    TF_RETURN_IF_ERROR(EnsureTaskInitialized(*task));
  }
  TF_RETURN_IF_ERROR(task->task_runner->GetNext(*request, *response));
  if (response->end_of_sequence()) {
    mutex_lock l(mu_);
    VLOG(3) << "Reached end_of_sequence for task " << request->task_id();
    pending_completed_tasks_.insert(request->task_id());
    task_completion_cv_.notify_one();
  } else if (!response->skip_task()) {
    VLOG(3) << "Producing an element for task " << request->task_id();
  }

  return Status::OK();
}

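// RPC handler: reports the tasks currently assigned to this worker.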
Status DataServiceWorkerImpl::GetWorkerTasks(
    const GetWorkerTasksRequest* request, GetWorkerTasksResponse* response) {
  mutex_lock l(mu_);
  for (const auto& it : tasks_) {
    Task* task = it.second.get();
    TaskInfo* task_info = response->add_tasks();
    task_info->set_worker_address(worker_address_);
    task_info->set_task_id(task->task_def.task_id());
    task_info->set_job_id(task->task_def.job_id());
  }
  return Status::OK();
}

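// Background thread that reports completed tasks to the dispatcher, retrying
// after kRetryIntervalMicros when an update fails.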
void DataServiceWorkerImpl::TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_) {
  while (true) {
    {
      mutex_lock l(mu_);
      while (!cancelled_ && pending_completed_tasks_.empty()) {
        task_completion_cv_.wait(l);
      }
      if (cancelled_) {
        VLOG(3) << "Task completion thread shutting down";
        return;
      }
    }
    Status s = SendTaskUpdates();
    if (!s.ok()) {
      LOG(WARNING) << "Failed to send task updates to dispatcher: " << s;
      mutex_lock l(mu_);
      if (!cancelled_) {
        task_completion_cv_.wait_for(
            l, std::chrono::microseconds(kRetryIntervalMicros));
      }
    }
  }
}

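// Sends a completion update to the dispatcher for every pending completed
// task, then clears the tasks that were reported successfully.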
Status DataServiceWorkerImpl::SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_) {
  std::vector<TaskProgress> task_progress;
  {
    mutex_lock l(mu_);
    VLOG(3) << "Sending " << pending_completed_tasks_.size()
            << " task updates to dispatcher";
    task_progress.reserve(pending_completed_tasks_.size());
    for (int task_id : pending_completed_tasks_) {
      task_progress.emplace_back();
      task_progress.back().set_task_id(task_id);
      task_progress.back().set_completed(true);
    }
  }

  TF_RETURN_IF_ERROR(dispatcher_->WorkerUpdate(worker_address_, task_progress));
  mutex_lock l(mu_);
  for (const auto& update : task_progress) {
    pending_completed_tasks_.erase(update.task_id());
  }
  VLOG(3) << "Sent " << task_progress.size() << " task updates ";
  return Status::OK();
}

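// Background thread that sends a heartbeat to the dispatcher every
// heartbeat_interval_ms milliseconds once the worker has registered.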
void DataServiceWorkerImpl::HeartbeatThread() TF_LOCKS_EXCLUDED(mu_) {
  while (true) {
    int64 next_heartbeat_micros =
        Env::Default()->NowMicros() + (config_.heartbeat_interval_ms() * 1000);
    {
      mutex_lock l(mu_);
      while (!cancelled_ &&
             Env::Default()->NowMicros() < next_heartbeat_micros) {
        int64 time_to_wait_micros =
            next_heartbeat_micros - Env::Default()->NowMicros();
        heartbeat_cv_.wait_for(l,
                               std::chrono::microseconds(time_to_wait_micros));
      }
      if (cancelled_) {
        VLOG(3) << "Heartbeat thread shutting down";
        return;
      }
      if (!registered_) {
        VLOG(1) << "Not performing heartbeat; worker is not yet registered";
        continue;
      }
    }
    Status s = Heartbeat();
    if (!s.ok()) {
      LOG(WARNING) << "Failed to send heartbeat to dispatcher: " << s;
    }
  }
}

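// Sends a heartbeat reporting the worker's current tasks and applies the new
// and deleted tasks returned by the dispatcher.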
Status DataServiceWorkerImpl::Heartbeat() TF_LOCKS_EXCLUDED(mu_) {
  std::vector<int64> current_tasks;
  {
    mutex_lock l(mu_);
    for (const auto& task : tasks_) {
      current_tasks.push_back(task.first);
    }
  }
  std::vector<TaskDef> new_tasks;
  std::vector<int64> tasks_to_delete;
  TF_RETURN_IF_ERROR(
      dispatcher_->WorkerHeartbeat(worker_address_, transfer_address_,
                                   current_tasks, new_tasks, tasks_to_delete));
  mutex_lock l(mu_);
  for (const auto& task : new_tasks) {
    Status s = ProcessTaskInternal(task);
    if (!s.ok() && !errors::IsAlreadyExists(s)) {
      LOG(WARNING) << "Failed to start processing task " << task.task_id()
                   << ": " << s;
    }
  }
  for (int64 task_id : tasks_to_delete) {
    VLOG(3) << "Deleting task " << task_id
            << " at the request of the dispatcher";
    tasks_.erase(task_id);
    finished_tasks_.insert(task_id);
  }
  return Status::OK();
}

}  // namespace data
}  // namespace tensorflow