1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ 17 #define TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ 18 19 #include <functional> 20 21 #include "absl/strings/string_view.h" 22 #include "absl/types/optional.h" 23 #include "tensorflow/core/data/dataset.pb.h" 24 #include "tensorflow/core/data/service/worker.pb.h" 25 #include "tensorflow/core/platform/status.h" 26 27 namespace tensorflow { 28 namespace data { 29 30 // Client for communicating with the tf.data service transfer server. 31 class DataTransferClient { 32 public: 33 struct Config { 34 absl::string_view protocol; 35 std::string address; 36 }; 37 using FactoryT = 38 std::function<Status(Config, std::unique_ptr<DataTransferClient>*)>; 39 virtual ~DataTransferClient() = default; 40 41 // Fetches the next element. 42 virtual Status GetElement(const GetElementRequest& req, 43 GetElementResponse& resp) = 0; 44 45 // Makes a best effort to cancel all outstanding calls in progress for the 46 // client, and causes further calls to return Cancelled status. 47 virtual void TryCancel() = 0; 48 49 // Registers a DataTransferClient factory under `name`. 50 static void Register(std::string name, FactoryT factory); 51 52 // Builds a DataTransferClient from the factory registered under `name`. 53 static Status Build(std::string name, Config config, 54 std::unique_ptr<DataTransferClient>* out); 55 }; 56 57 // Server for communicating with the tf.data service transfer client. 58 class DataTransferServer { 59 public: 60 using GetElementT = 61 std::function<Status(const GetElementRequest*, GetElementResponse*)>; 62 virtual ~DataTransferServer() = default; 63 64 // Starts DataTransferServer, it should be available for requests afterwards. 65 virtual Status Start() = 0; 66 67 // Return the port that this server is listening on. 68 virtual int get_port() = 0; 69 70 // Register a DataTransferServer factory under `name`. 71 static void Register( 72 std::string name, 73 std::function<std::shared_ptr<DataTransferServer>(GetElementT)> factory); 74 75 // Builds a DataTransferServer from the factory registered with `name`. 76 static Status Build(std::string name, GetElementT get_element, 77 std::shared_ptr<DataTransferServer>* out); 78 }; 79 80 } // namespace data 81 } // namespace tensorflow 82 83 #endif // TENSORFLOW_CORE_DATA_SERVICE_DATA_TRANSFER_H_ 84