1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ 16 #define TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ 17 18 #include "absl/container/flat_hash_map.h" 19 #include "absl/container/flat_hash_set.h" 20 #include "tensorflow/core/data/service/common.pb.h" 21 #include "tensorflow/core/data/service/data_service.h" 22 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h" 23 #include "tensorflow/core/data/service/task_runner.h" 24 #include "tensorflow/core/data/service/worker.pb.h" 25 #include "tensorflow/core/data/standalone.h" 26 #include "tensorflow/core/lib/core/status.h" 27 #include "tensorflow/core/protobuf/service_config.pb.h" 28 #include "tensorflow/core/public/session.h" 29 30 namespace tensorflow { 31 namespace data { 32 33 // A TensorFlow DataService serves dataset elements over RPC. 34 class DataServiceWorkerImpl { 35 public: 36 explicit DataServiceWorkerImpl(const experimental::WorkerConfig& config); 37 ~DataServiceWorkerImpl(); 38 39 // Starts the worker. The worker needs to know its own address so that it can 40 // register with the dispatcher. This is set in `Start` instead of in the 41 // constructor because the worker may be binding to port `0`, in which case 42 // the address isn't known until the worker has started and decided which port 43 // to bind to. 44 Status Start(const std::string& worker_address, 45 const std::string& transfer_address); 46 47 // See worker.proto for API documentation. 48 49 /// Dispatcher-facing API. 50 Status ProcessTask(const ProcessTaskRequest* request, 51 ProcessTaskResponse* response); 52 53 /// Client-facing API. 54 Status GetElement(const GetElementRequest* request, 55 GetElementResponse* response); 56 Status GetWorkerTasks(const GetWorkerTasksRequest* request, 57 GetWorkerTasksResponse* response); 58 59 private: 60 struct Task { TaskTask61 explicit Task(TaskDef task_def) : task_def(std::move(task_def)) {} 62 63 TaskDef task_def; 64 mutex mu; 65 bool initialized TF_GUARDED_BY(mu) = false; 66 std::unique_ptr<TaskRunner> task_runner; 67 }; 68 69 // Sends task status to the dispatcher and checks for dispatcher commands. 70 Status SendTaskUpdates() TF_LOCKS_EXCLUDED(mu_); 71 // Creates an iterator to process a task. 72 Status ProcessTaskInternal(const TaskDef& task) 73 TF_EXCLUSIVE_LOCKS_REQUIRED(mu_); 74 Status EnsureTaskInitialized(Task& task); 75 // A thread for notifying the dispatcher when tasks complete. 76 void TaskCompletionThread() TF_LOCKS_EXCLUDED(mu_); 77 // A thread for doing periodic heartbeats to the dispatcher. 78 void HeartbeatThread() TF_LOCKS_EXCLUDED(mu_); 79 // Performs a heartbeat to the dispatcher. 80 Status Heartbeat() TF_LOCKS_EXCLUDED(mu_); 81 82 const experimental::WorkerConfig config_; 83 // The worker's own address. 84 std::string worker_address_; 85 std::string transfer_address_; 86 std::unique_ptr<DataServiceDispatcherClient> dispatcher_; 87 88 mutex mu_; 89 // Information about tasks, keyed by task ids. 90 absl::flat_hash_map<int64, std::unique_ptr<Task>> tasks_ TF_GUARDED_BY(mu_); 91 // Ids of tasks that have finished. 92 absl::flat_hash_set<int64> finished_tasks_ TF_GUARDED_BY(mu_); 93 // Completed tasks which haven't yet been communicated to the dispatcher. 94 absl::flat_hash_set<int64> pending_completed_tasks_ TF_GUARDED_BY(mu_); 95 bool cancelled_ TF_GUARDED_BY(mu_) = false; 96 // Whether the worker has registered with the dispatcher yet. 97 bool registered_ TF_GUARDED_BY(mu_) = false; 98 // A thread for notifying the dispatcher when tasks complete. 99 std::unique_ptr<Thread> task_completion_thread_; 100 condition_variable task_completion_cv_ TF_GUARDED_BY(mu_); 101 // A thread for performing regular heartbeats to the dispatcher. 102 std::unique_ptr<Thread> heartbeat_thread_; 103 condition_variable heartbeat_cv_ TF_GUARDED_BY(mu_); 104 105 TF_DISALLOW_COPY_AND_ASSIGN(DataServiceWorkerImpl); 106 }; 107 108 } // namespace data 109 } // namespace tensorflow 110 111 #endif // TENSORFLOW_CORE_DATA_SERVICE_WORKER_IMPL_H_ 112