1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_
16 #define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_
17 
#include <functional>
#include <memory>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/protobuf.h"
26 
27 // gRPC response caching.  Most WorkerService methods cannot be retried directly
28 // as they will fail or deadlock.  To enable retrying, we can instead cache
29 // responses for a short period of time and reply to duplicate requests from the
30 // cache.
31 namespace tensorflow {
32 
33 // Union type to aid caching of either raw buffers (for RecvTensor RPCs) and
34 // protocol buffer messages (for all other RPCs).
35 class RPCResponse {
36  public:
RPCResponse()37   explicit RPCResponse() : buf_(nullptr), msg_(nullptr) {}
RPCResponse(::grpc::ByteBuffer * b)38   explicit RPCResponse(::grpc::ByteBuffer* b) : buf_(b), msg_(nullptr) {}
RPCResponse(protobuf::Message * m)39   explicit RPCResponse(protobuf::Message* m) : buf_(nullptr), msg_(m) {}
40 
41   // Encode this response into the target buffer.
42   void Encode(::grpc::ByteBuffer* tgt) const;
43 
44   // Copy from `src`: if this is a buffer, make a shallow copy.
45   // For protocol messages, parse the response from `src`.
46   void CopyFrom(const ::grpc::ByteBuffer& src);
47 
48  private:
49   ::grpc::ByteBuffer* buf_;
50   protobuf::Message* msg_;
51 };
52 
53 typedef std::function<void(StatusCallback)> ComputeFunc;
54 struct WorkerCacheEntry;
55 
56 // Track and cache the state of worker service RPCs.  An RPC can be in 3 states:
57 //
58 // * PENDING: this is the first call of the RPC, and it will transition to
59 // * ACTIVE: another thread is active processing this RPC
60 // * FINISHED: the worker has finished processing the method
61 //
62 // The response from completed RPCs are LRU cached until either `max_bytes`
63 // bytes are in use by the cache or they expire (according to `expire_time`).
64 class GrpcResponseCache {
65  public:
GrpcResponseCache(int64 max_bytes,int64 expire_time_seconds)66   GrpcResponseCache(int64 max_bytes, int64 expire_time_seconds)
67       : max_bytes_(max_bytes), expire_time_seconds_(expire_time_seconds) {}
68 
69   // Lookup the result for key.
70   // If it is finished, invoke `done_cb` immediately after filling `response`.
71   // If active, done_db will be invoked when the current call completes.
72   // Otherwise, invoke `compute_func` to fill the cache and invoke done_cb.
73   void LookupOrCompute(const string& key, RPCResponse response,
74                        ComputeFunc compute_func, StatusCallback done_cb);
75 
76   // Remove all stale or expired cache entries if the cache is full.
77   void MaybeCleanup();
78 
79  private:
80   int64 current_bytes_ GUARDED_BY(mu_) = 0;
81   const int64 max_bytes_;
82   const int64 expire_time_seconds_;
83 
84   std::unordered_map<string, std::shared_ptr<WorkerCacheEntry>> requests_
85       GUARDED_BY(mu_);
86   mutex mu_;
87 };
88 
89 }  // namespace tensorflow
90 
91 #endif  // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_RESPONSE_CACHE_H_
92