/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/rpc/grpc_response_cache.h"

#include <algorithm>

#include "tensorflow/core/platform/env.h"

namespace tensorflow {

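// A single cache entry, keyed by request id.  An entry moves from PENDING to
// ACTIVE while its computation is in flight, then to FINISHED once the
// response has been recorded.  A FINISHED entry whose deadline has passed is
// reset to PENDING and recomputed on the next lookup.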
struct WorkerCacheEntry {
  enum class State {
    PENDING = 0,
    ACTIVE = 1,
    FINISHED = 2,
  };

  State state = State::PENDING;
  int64 expires_seconds;

  ::grpc::ByteBuffer response_buf;
  Status response_status;

  // Additional retries may arrive while a request is still executing.  The
  // callbacks for these calls are queued in `callbacks` and evaluated after
  // the original request is completed.
  std::vector<std::pair<RPCResponse, StatusCallback>> callbacks;
};

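// Serializes this response into `tgt`.  If this response wraps a raw
// ::grpc::ByteBuffer, the buffer is copied directly; otherwise the protobuf
// message is serialized into a single slice.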
void RPCResponse::Encode(::grpc::ByteBuffer* tgt) const {
  if (buf_ != nullptr) {
    *tgt = *buf_;
  } else {
    CHECK(msg_ != nullptr);
    // ByteSizeLong() also caches sub-message sizes, which
    // SerializeWithCachedSizesToArray() relies on below.
    ::grpc::Slice slice(msg_->ByteSizeLong());
    msg_->SerializeWithCachedSizesToArray(
        const_cast<uint8*>(reinterpret_cast<const uint8*>(slice.begin())));
    ::grpc::ByteBuffer tmp(&slice, 1);
    tgt->Swap(&tmp);
  }
}

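// Restores this response from a cached buffer.  For a raw ByteBuffer response
// the buffer is copied wholesale; for a protobuf response the message is
// parsed from the buffer's single slice (see the comment below).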
void RPCResponse::CopyFrom(const ::grpc::ByteBuffer& src) {
  if (buf_ != nullptr) {
    *buf_ = src;
    return;
  }

  CHECK(msg_ != nullptr);
  // We create a single slice when encoding protocol messages.
  std::vector<::grpc::Slice> slices;
  if (src.Dump(&slices).ok()) {
    msg_->ParseFromArray(slices[0].begin(), slices[0].size());
  } else {
    LOG(ERROR) << "Failed to decode cached buffer.";
  }
}

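// Looks up `key` in the cache.  Three outcomes are possible:
//   * A fresh FINISHED entry exists: the cached response is copied into
//     `response` and `done_cb` is invoked immediately.
//   * An ACTIVE entry exists: the (response, done_cb) pair is queued and
//     served when the in-flight computation completes.
//   * Otherwise: `compute_func` is run, and its result is fanned out to every
//     queued callback, including this one.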
void GrpcResponseCache::LookupOrCompute(const string& key, RPCResponse response,
                                        ComputeFunc compute_func,
                                        StatusCallback done_cb) {
  VLOG(1) << "Lookup " << key;
  std::shared_ptr<WorkerCacheEntry> req;
  MaybeCleanup();
  {
    mutex_lock m(mu_);

    if (requests_.find(key) != requests_.end()) {
      req = requests_[key];
    } else {
      req.reset(new WorkerCacheEntry);
      requests_[key] = req;
    }

    if (req->state == WorkerCacheEntry::State::FINISHED) {
      if (req->expires_seconds > Env::Default()->NowSeconds()) {
        VLOG(1) << "Reuse cached response for " << key;
        response.CopyFrom(req->response_buf);
        done_cb(req->response_status);
        return;
      }
      VLOG(1) << "Found expired cache entry for " << key;
      req->state = WorkerCacheEntry::State::PENDING;
      req->response_buf.Clear();
    }

    req->callbacks.push_back(std::make_pair(response, done_cb));

    if (req->state == WorkerCacheEntry::State::ACTIVE) {
      VLOG(1) << "Found active request for " << key
              << ".  Adding entry to response queue.";
      return;
    }

    VLOG(2) << "No cache entry for " << key << ", running user computation.";
    req->state = WorkerCacheEntry::State::ACTIVE;
    req->expires_seconds = Env::Default()->NowSeconds() + expire_time_seconds_;
  }

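  // The computation runs outside the lock; its completion callback re-acquires
  // mu_ to publish the result and flush any retries queued in the meantime.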
  compute_func([this, key, req, response](Status status) {
    mutex_lock m(mu_);
    response.Encode(&req->response_buf);
    current_bytes_ += req->response_buf.Length();

    req->response_status = status;
    req->state = WorkerCacheEntry::State::FINISHED;

    VLOG(1) << "Operation for " << key << " finished. "
            << "Status: " << status << ", " << req->response_buf.Length()
            << " response bytes, " << req->callbacks.size()
            << " pending callbacks.";
    for (auto& cb : req->callbacks) {
      cb.first.CopyFrom(req->response_buf);
      cb.second(req->response_status);
    }
    req->callbacks.clear();
  });
}
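
// Example (hypothetical sketch, not part of this file): a service handler
// could deduplicate retried requests by routing its work through the cache.
// `request_key`, `response_proto`, `ComputeResponseAsync`, and `SendResponse`
// are assumed names used for illustration only:
//
//   cache->LookupOrCompute(
//       request_key, RPCResponse(&response_proto),
//       [&](StatusCallback done) {
//         ComputeResponseAsync(&response_proto, done);
//       },
//       [&](Status s) { SendResponse(s); });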

// Evict entries once the cache has grown past max_bytes_.  Active requests
// are always kept; finished entries are dropped if expired, and otherwise
// retained only until half of max_bytes_ is in use, which leaves headroom for
// in-flight requests and amortizes the cost of cleanup.
void GrpcResponseCache::MaybeCleanup() {
  mutex_lock m(mu_);
  if (current_bytes_ < max_bytes_) {
    return;
  }

  VLOG(1) << "Cleanup: " << current_bytes_ << " -> " << max_bytes_;
  std::vector<std::pair<string, std::shared_ptr<WorkerCacheEntry>>>
      ordered_entries;
  ordered_entries.reserve(requests_.size());
  for (const auto& p : requests_) {
    ordered_entries.push_back(std::make_pair(p.first, p.second));
  }

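  // Order entries newest-expiring first so the retention loop below keeps the
  // freshest responses.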
  std::sort(ordered_entries.begin(), ordered_entries.end(),
            [](const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& a,
               const std::pair<string, std::shared_ptr<WorkerCacheEntry>>& b) {
              return a.second->expires_seconds > b.second->expires_seconds;
            });

  std::unordered_map<string, std::shared_ptr<WorkerCacheEntry>> kept;
  int64 now = Env::Default()->NowSeconds();
  int64 bytes_used = 0;

  // Always keep active requests.
  for (auto& pair : ordered_entries) {
    if (pair.second->state != WorkerCacheEntry::State::FINISHED) {
      kept.insert(pair);
    }
  }

  // Keep unexpired, finished requests up to half of max_bytes_.  This reduces
  // the chance of overfilling the cache when active requests complete and
  // amortizes cache cleanup cost.
  for (auto& pair : ordered_entries) {
    if (pair.second->expires_seconds < now || bytes_used >= max_bytes_ / 2) {
      break;
    }

    if (pair.second->state == WorkerCacheEntry::State::FINISHED) {
      kept.insert(pair);
      bytes_used += pair.second->response_buf.Length();
    }
  }

  VLOG(1) << "Cleaned cache.  Bytes used: " << current_bytes_ << " -> "
          << bytes_used << ". Cache size: " << requests_.size() << " -> "
          << kept.size();
  current_bytes_ = bytes_used;
  std::swap(requests_, kept);
}

}  // namespace tensorflow