1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
17 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
18 
19 #include <memory>
20 #include <unordered_map>
21 #include "tensorflow/core/distributed_runtime/recent_request_ids.h"
22 #include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"
23 #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.h"
24 #include "tensorflow/core/distributed_runtime/worker.h"
25 
// gRPC types are forward-declared because this header only refers to them
// through pointers, so the full gRPC headers are not needed here.
namespace grpc {
class ServerBuilder;
class ByteBuffer;
}  // namespace grpc
30 
31 namespace tensorflow {
32 
33 class AsyncServiceInterface;
34 class ConfigProto;
35 struct WorkerEnv;
36 struct WorkerSession;
37 
38 class GrpcWorker : public Worker {
39  public:
40   GrpcWorker(WorkerEnv* env, const ConfigProto& config);
41 
42   // Specialized version of RecvTensor for gRPC, which avoids a copy.
43   virtual void GrpcRecvTensorAsync(CallOptions* opts,
44                                    const RecvTensorRequest* request,
45                                    ::grpc::ByteBuffer* response,
46                                    StatusCallback done);
47 
48   virtual void LoggingAsync(const LoggingRequest* request,
49                             LoggingResponse* response, StatusCallback done);
50 
51   virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
52                             RecvBufResponse* response, StatusCallback done);
53 
54   WorkerEnv* env();
55 
56  private:
57   RecentRequestIds recent_request_ids_;
58   const int32 recv_buf_max_chunk_;
59 };
60 
61 std::unique_ptr<GrpcWorker> NewGrpcWorker(WorkerEnv* worker_env,
62                                           const ConfigProto& config);
63 
64 struct GrpcWorkerServiceOptions {
65   // Map from GrpcWorkerMethod id to queue depth.  If set this overrides the
66   // default queue depth for a method.
67   std::unordered_map<int, int> queue_depth;
68   int num_serving_threads = 8;
69   int64 response_cache_bytes = 0;
70   int64 response_cache_expires_seconds = 0;
71 };
72 
73 // Returns an implementation of WorkerService rpc service.
74 std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
75     GrpcWorker* worker, ::grpc::ServerBuilder* builder,
76     GrpcWorkerServiceOptions opts = GrpcWorkerServiceOptions());
77 
78 }  // namespace tensorflow
79 
80 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_WORKER_SERVICE_H_
81