/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"

#include <algorithm>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/core/platform/env.h"

namespace tensorflow {

struct WorkerCacheEntry {
  enum class State {
    PENDING = 0,
    ACTIVE = 1,
    FINISHED = 2,
  };

  State state = State::PENDING;
  int64 expires_seconds = 0;

  ::grpc::ByteBuffer response_buf;
  Status response_status;

  // Additional retries may arrive while a request is still executing. The
  // callbacks for these calls are queued in `callbacks` and evaluated after
  // the original request completes.
  std::vector<std::pair<RPCResponse, StatusCallback>> callbacks;
};

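// Writes this response into `tgt`: if a raw ByteBuffer is attached, it is
// copied wholesale; otherwise the protocol message is serialized into a
// single slice.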
void RPCResponse::Encode(::grpc::ByteBuffer* tgt) const {
  if (buf_ != nullptr) {
    *tgt = *buf_;
  } else {
    CHECK(msg_ != nullptr);
    ::grpc::Slice slice(msg_->ByteSizeLong());
    msg_->SerializeWithCachedSizesToArray(
        const_cast<uint8*>(reinterpret_cast<const uint8*>(slice.begin())));
    ::grpc::ByteBuffer tmp(&slice, 1);
    tgt->Swap(&tmp);
  }
}

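// Restores this response from a cached buffer produced by Encode(). Raw
// buffers are copied directly; protocol messages are parsed from the single
// slice that Encode() emits.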
void RPCResponse::CopyFrom(const ::grpc::ByteBuffer& src) {
  if (buf_ != nullptr) {
    *buf_ = src;
    return;
  }

  CHECK(msg_ != nullptr);
  // We create a single slice when encoding protocol messages.
  std::vector<::grpc::Slice> slices;
  if (src.Dump(&slices).ok()) {
    msg_->ParseFromArray(slices[0].begin(), slices[0].size());
  } else {
    LOG(ERROR) << "Failed to decode cached buffer.";
  }
}

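// Looks up `key` in the cache. On a hit, completes `done_cb` with the cached
// response; if an identical request is already in flight, queues the callback
// behind it. Otherwise runs `compute_func` and caches its result.
//
// Illustrative call site (a sketch only; `cache`, `response_buf`, the
// RPCResponse constructor used, and the handler wiring are assumptions, not
// taken from this file):
//
//   cache->LookupOrCompute(
//       request_id, RPCResponse(&response_buf),
//       [handler](StatusCallback done) { handler(done); },
//       [call](const Status& s) { call->Finish(s); });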
void GrpcResponseCache::LookupOrCompute(const string& key, RPCResponse response,
                                        ComputeFunc compute_func,
                                        StatusCallback done_cb) {
  VLOG(1) << "Lookup " << key;
  std::shared_ptr<WorkerCacheEntry> req;
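  // MaybeCleanup acquires mu_ itself, so call it before taking the lock below.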
  MaybeCleanup();
  {
    mutex_lock m(mu_);

    if (requests_.find(key) != requests_.end()) {
      req = requests_[key];
    } else {
      req.reset(new WorkerCacheEntry);
      requests_[key] = req;
    }

    if (req->state == WorkerCacheEntry::State::FINISHED) {
      if (req->expires_seconds > Env::Default()->NowSeconds()) {
        VLOG(1) << "Reuse cached response for " << key;
        response.CopyFrom(req->response_buf);
        done_cb(req->response_status);
        return;
      }
      VLOG(1) << "Found expired cache entry for " << key;
      req->state = WorkerCacheEntry::State::PENDING;
      req->response_buf.Clear();
    }

    req->callbacks.push_back(std::make_pair(response, done_cb));

    if (req->state == WorkerCacheEntry::State::ACTIVE) {
      VLOG(1) << "Found active request for " << key
              << ". Adding entry to response queue.";
      return;
    }

    VLOG(2) << "No cache entry for " << key << ", running user computation.";
    req->state = WorkerCacheEntry::State::ACTIVE;
    req->expires_seconds = Env::Default()->NowSeconds() + expire_time_seconds_;
  }

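  // Run the user computation. On completion, cache the encoded response and
  // flush every callback that queued up behind this request while it was
  // active.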
  compute_func([this, key, req, response](Status status) {
    mutex_lock m(mu_);
    response.Encode(&req->response_buf);
    current_bytes_ += req->response_buf.Length();

    req->response_status = status;
    req->state = WorkerCacheEntry::State::FINISHED;

    VLOG(1) << "Operation for " << key << " finished. "
            << "Status: " << status << ", " << req->response_buf.Length()
            << " response bytes, " << req->callbacks.size()
            << " pending callbacks.";
    for (auto& cb : req->callbacks) {
      cb.first.CopyFrom(req->response_buf);
      cb.second(req->response_status);
    }
    req->callbacks.clear();
  });
}

// If the cache is over its byte budget, evict expired entries and trim
// finished entries down to half of max_bytes_. Active entries are always
// kept.
void GrpcResponseCache::MaybeCleanup() {
  mutex_lock m(mu_);
  if (current_bytes_ < max_bytes_) {
    return;
  }

  VLOG(1) << "Cleanup: " << current_bytes_ << " -> " << max_bytes_;
  std::vector<std::pair<string, std::shared_ptr<WorkerCacheEntry>>>
      ordered_entries;
  ordered_entries.reserve(requests_.size());
  for (const auto& p : requests_) {
    ordered_entries.push_back(std::make_pair(p.first, p.second));
  }

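  // Sort by descending expiration time so the longest-lived entries come
  // first and are preferentially retained below.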
  std::sort(ordered_entries.begin(), ordered_entries.end(),
            [](const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& a,
               const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& b) {
              return a.second->expires_seconds > b.second->expires_seconds;
            });

  std::unordered_map<string, std::shared_ptr<WorkerCacheEntry>> kept;
  int64 now = Env::Default()->NowSeconds();
  int64 bytes_used = 0;

  // Always keep active requests.
  for (auto& pair : ordered_entries) {
    if (pair.second->state != WorkerCacheEntry::State::FINISHED) {
      kept.insert(pair);
    }
  }

  // Keep unexpired, finished requests up to half of max_bytes_. This reduces
  // the chance of overfilling the cache when active requests complete and
  // amortizes cache cleanup cost.
  for (auto& pair : ordered_entries) {
    if (pair.second->expires_seconds < now || bytes_used >= max_bytes_ / 2) {
      break;
    }

    if (pair.second->state == WorkerCacheEntry::State::FINISHED) {
      kept.insert(pair);
      bytes_used += pair.second->response_buf.Length();
    }
  }

  VLOG(1) << "Cleaned cache. Bytes used: " << current_bytes_ << " -> "
          << bytes_used << ". Cache size: " << requests_.size() << " -> "
          << kept.size();
  current_bytes_ = bytes_used;
  std::swap(requests_, kept);
}

}  // namespace tensorflow