1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
17 #define TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
18 
19 #include "grpcpp/impl/codegen/client_context.h"
20 #include "absl/container/flat_hash_set.h"
21 #include "tensorflow/core/data/service/data_transfer.h"
22 #include "tensorflow/core/data/service/dispatcher.grpc.pb.h"
23 #include "tensorflow/core/data/service/worker.grpc.pb.h"
24 #include "tensorflow/core/data/service/worker.pb.h"
25 #include "tensorflow/core/framework/dataset.h"
26 #include "tensorflow/core/framework/op_kernel.h"
27 
28 namespace tensorflow {
29 namespace data {
30 
// Version number for communication between tf.data servers. Bump it whenever
// a change to that communication is not backwards compatible.
constexpr int kDataServiceVersion{1};
34 
35 // Modes for how a tf.data service job should process a dataset.
36 enum class ProcessingMode : int64 {
37   UNSET = 0,
38   // Each tf.data worker processes an entire epoch. If a dataset contains 2
39   // elements and there are 3 workers, the job will produce 6 elements.
40   PARALLEL_EPOCHS = 1,
41   // Processing of a single epoch is distributed across all tf.data workers.
42   DISTRIBUTED_EPOCH = 2,
43 };
44 
45 // Parses a string representing a processing mode and stores the result in
46 // `mode`. Returns an InvalidArgument status if the string is not recognized.
47 Status ParseProcessingMode(const std::string& s, ProcessingMode& mode);
48 
49 // Converts a processing mode to its corresponding string.
50 std::string ProcessingModeToString(ProcessingMode mode);
51 
52 // Base class for data service clients. Data service clients are
53 // threadsafe.
54 class DataServiceClientBase {
55  public:
DataServiceClientBase(const std::string & address,const std::string & protocol)56   DataServiceClientBase(const std::string& address, const std::string& protocol)
57       : address_(address), protocol_(protocol) {}
58 
59   virtual ~DataServiceClientBase() = default;
60   // Not copyable or movable.
61   DataServiceClientBase(const DataServiceClientBase&) = delete;
62   DataServiceClientBase& operator=(const DataServiceClientBase&) = delete;
63 
64   // Initializes the client. Calling `Initialize()` is not required since the
65   // first RPC will perform any necessary initialization. However, it can be
66   // useful to call `Initialize()` proactively so that any errors that happen
67   // during initialization can be surfaced earlier.
Initialize()68   Status Initialize() { return EnsureInitialized(); }
69 
70  protected:
71   // Initializes the client if it isn't already initialized.
72   virtual Status EnsureInitialized() = 0;
73 
74   const std::string address_;
75   const std::string protocol_;
76 };
77 
78 // Client for communicating with the tf.data service dispatcher.
79 class DataServiceDispatcherClient : public DataServiceClientBase {
80  public:
DataServiceDispatcherClient(const std::string & address,const std::string & protocol)81   DataServiceDispatcherClient(const std::string& address,
82                               const std::string& protocol)
83       : DataServiceClientBase(address, protocol) {}
84 
85   // Sends a heartbeat to the dispatcher. If the worker wasn't already
86   // registered with the dispatcher, this will register the worker. The
87   // dispatcher will report which new tasks the worker should run, and which
88   // tasks it should delete. This is stored into `new_tasks` and
89   // `tasks_to_delete`.
90   Status WorkerHeartbeat(const std::string& worker_address,
91                          const std::string& transfer_address,
92                          const std::vector<int64>& current_tasks,
93                          std::vector<TaskDef>& new_tasks,
94                          std::vector<int64>& tasks_to_delete);
95 
96   // Updates the dispatcher with information about the worker's state.
97   Status WorkerUpdate(const std::string& worker_address,
98                       std::vector<TaskProgress>& task_progress);
99 
100   // Gets a dataset definition for the given dataset id, and stores the
101   // definition in `dataset_def`.
102   Status GetDatasetDef(int64 dataset_id, DatasetDef& dataset_def);
103 
104   // Gets the next split for the specified job id and repetition.
105   Status GetSplit(int64 job_id, int64 repetition, Tensor& split,
106                   bool& end_of_splits);
107 
108   // Registers a dataset with the tf.data service, and stores the generated
109   // dataset id in `dataset_id`.
110   Status RegisterDataset(GraphDef dataset, int64& dataset_id);
111 
112   // If `job_key` is set, looks up a job matching `job_key`. If `job_key` is
113   // absent or no matching job is found, creates a new job. The resulting job
114   // id is stored in `job_client_id`.
115   Status GetOrCreateJob(int64 dataset_id, ProcessingMode processing_mode,
116                         const absl::optional<JobKey>& job_key,
117                         absl::optional<int64> num_consumers,
118                         int64& job_client_id);
119 
120   // Releases a job client id, indicating that the id will no longer be used to
121   // read from the job.
122   Status ReleaseJobClient(int64 job_client_id);
123 
124   // Heartbeats to the dispatcher, getting back the tasks that should be
125   // running, and whether the job is finished.
126   Status ClientHeartbeat(ClientHeartbeatRequest& req,
127                          ClientHeartbeatResponse& resp);
128 
129   // Queries the dispatcher for its registered workers. The worker info will be
130   // stored in `workers`.
131   Status GetWorkers(std::vector<WorkerInfo>& workers);
132 
133  protected:
134   Status EnsureInitialized() override;
135 
136  private:
137   mutex mu_;
138   // Initialization is guarded by `mu_`, but using the stub does not require
139   // holding `mu_`
140   std::unique_ptr<DispatcherService::Stub> stub_;
141 };
142 
143 // Client for communicating with the tf.data service worker.
144 class DataServiceWorkerClient : public DataServiceClientBase {
145  public:
DataServiceWorkerClient(const std::string & address,const std::string & protocol,const std::string & transfer_protocol)146   DataServiceWorkerClient(const std::string& address,
147                           const std::string& protocol,
148                           const std::string& transfer_protocol)
149       : DataServiceClientBase(address, protocol),
150         transfer_protocol_(transfer_protocol) {}
151 
152   // Fetches an element from the worker.
153   Status GetElement(const GetElementRequest& req, GetElementResponse& resp);
154 
155   // Makes a best effort to cancel all outstanding calls in progress for the
156   // client, and causes further calls to return Cancelled status.
157   void TryCancel();
158 
159  protected:
160   Status EnsureInitialized() override;
161 
162  private:
163   const std::string transfer_protocol_;
164   mutex mu_;
165   // Initialization is guarded by `mu_`, but using the stub does not require
166   // holding `mu_`
167   std::unique_ptr<DataTransferClient> client_;
168 };
169 
170 // Creates and initializes a new tf.data service dispatcher client.
171 Status CreateDataServiceDispatcherClient(
172     const std::string& address, const std::string& protocol,
173     std::unique_ptr<DataServiceDispatcherClient>& out);
174 
175 // Creates and initializes a new tf.data service worker client.
176 Status CreateDataServiceWorkerClient(
177     const std::string& address, const std::string& protocol,
178     const std::string& transfer_protocol,
179     std::unique_ptr<DataServiceWorkerClient>& out);
180 
181 }  // namespace data
182 }  // namespace tensorflow
183 
184 #endif  // TENSORFLOW_CORE_DATA_SERVICE_DATA_SERVICE_H_
185