1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/platform/cloud/gcs_dns_cache.h"
17 #ifndef _WIN32
18 #include <arpa/inet.h>
19 #include <netdb.h>
20 #else
21 #include <Windows.h>
22 #include <winsock2.h>
23 #include <ws2tcpip.h>
24 #endif
25 #include <sys/types.h>
26
27 namespace tensorflow {
28
29 namespace {
30
// DNS names whose resolved addresses are kept in the cache; AnnotateRequest()
// adds a resolve override for each of them. The vector is allocated with `new`
// and bound to a reference, so it is never deleted (presumably to sidestep
// static-destruction-order problems at process exit -- TODO confirm).
const std::vector<string>& kCachedDomainNames =
    *new std::vector<string>{"www.googleapis.com", "storage.googleapis.com"};
33
// Logs, at ERROR severity, a description of a failed getaddrinfo() call.
//
// name:       the host name whose resolution failed.
// error_code: the non-zero value returned by getaddrinfo().
inline void print_getaddrinfo_error(const string& name, int error_code) {
#ifndef _WIN32
  // Snapshot errno before any logging machinery runs: constructing the log
  // message and streaming into it may overwrite errno before it is read.
  const int saved_errno = errno;
  if (error_code == EAI_SYSTEM) {
    // For EAI_SYSTEM the underlying cause is reported through errno.
    LOG(ERROR) << "Error resolving " << name
               << " (EAI_SYSTEM): " << strerror(saved_errno);
    return;
  }
#endif
  // TODO: on Windows, WSAGetLastError is better than gai_strerror.
  LOG(ERROR) << "Error resolving " << name << ": " << gai_strerror(error_code);
}
48
// Returns a reference to an element of `items` picked uniformly at random,
// drawing randomness from `*random`. CHECK-fails when `items` is empty.
template <typename T>
const T& SelectRandomItemUniform(std::default_random_engine* random,
                                 const std::vector<T>& items) {
  CHECK_GT(items.size(), 0);
  // Both bounds of the distribution are inclusive, hence size() - 1.
  const size_t last_index = items.size() - 1u;
  std::uniform_int_distribution<size_t> pick_index(0u, last_index);
  return items[pick_index(*random)];
}
59 } // namespace
60
// Constructs a DNS cache. `env` is used later to start the background refresh
// thread, and `refresh_rate_secs` is the number of seconds that thread waits
// between re-resolutions. No thread is started and no DNS lookup is performed
// here; both happen on the first AnnotateRequest() call.
GcsDnsCache::GcsDnsCache(Env* env, int64 refresh_rate_secs)
    : env_(env), refresh_rate_secs_(refresh_rate_secs) {}
63
AnnotateRequest(HttpRequest * request)64 void GcsDnsCache::AnnotateRequest(HttpRequest* request) {
65 // TODO(saeta): Blacklist failing IP addresses.
66 mutex_lock l(mu_);
67 if (!started_) {
68 VLOG(1) << "Starting GCS DNS cache.";
69 DCHECK(!worker_) << "Worker thread already exists!";
70 // Perform DNS resolutions to warm the cache.
71 addresses_ = ResolveNames(kCachedDomainNames);
72
73 // Note: we opt to use a thread instead of a delayed closure.
74 worker_.reset(env_->StartThread({}, "gcs_dns_worker",
75 [this]() { return WorkerThread(); }));
76 started_ = true;
77 }
78
79 CHECK_EQ(kCachedDomainNames.size(), addresses_.size());
80 for (size_t i = 0; i < kCachedDomainNames.size(); ++i) {
81 const string& name = kCachedDomainNames[i];
82 const std::vector<string>& addresses = addresses_[i];
83 if (!addresses.empty()) {
84 const string& chosen_address =
85 SelectRandomItemUniform(&random_, addresses);
86 request->AddResolveOverride(name, 443, chosen_address);
87 VLOG(1) << "Annotated DNS mapping: " << name << " --> " << chosen_address;
88 } else {
89 LOG(WARNING) << "No IP addresses available for " << name;
90 }
91 }
92 }
93
// Resolves `name` to its IPv4 addresses using getaddrinfo().
//
// Returns the addresses in textual (inet_ntop) form, in the order
// getaddrinfo() reported them. The result is empty when resolution fails (the
// failure is logged) or when no entry could be converted. Non-IPv4 entries
// and entries that fail conversion are logged and skipped.
/* static */ std::vector<string> GcsDnsCache::ResolveName(const string& name) {
  VLOG(1) << "Resolving DNS name: " << name;

  addrinfo hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_INET;  // Only use IPv4 for now.
  hints.ai_socktype = SOCK_STREAM;
  addrinfo* result = nullptr;
  const int return_code = getaddrinfo(name.c_str(), nullptr, &hints, &result);

  std::vector<string> output;
  if (return_code == 0) {
    for (const addrinfo* i = result; i != nullptr; i = i->ai_next) {
      if (i->ai_family != AF_INET || i->ai_addr->sa_family != AF_INET) {
        LOG(WARNING) << "Non-IPv4 address returned. ai_family: " << i->ai_family
                     << ". sa_family: " << i->ai_addr->sa_family << ".";
        continue;
      }
      char buf[INET_ADDRSTRLEN];
      void* address_ptr =
          &(reinterpret_cast<sockaddr_in*>(i->ai_addr)->sin_addr);
      if (inet_ntop(i->ai_addr->sa_family, address_ptr, buf,
                    INET_ADDRSTRLEN) == nullptr) {
        // Capture errno immediately: building the log message below may
        // overwrite it before strerror() would otherwise read it.
        const int saved_errno = errno;
        LOG(ERROR) << "Error converting response to IP address for " << name
                   << ": " << strerror(saved_errno);
      } else {
        output.emplace_back(buf);
        VLOG(1) << "... address: " << buf;
      }
    }
  } else {
    print_getaddrinfo_error(name, return_code);
  }
  // The address list is owned by this function and must be released here.
  if (result != nullptr) {
    freeaddrinfo(result);
  }
  return output;
}
133
134 // Performs DNS resolution for a set of DNS names. The return vector contains
135 // one element for each element in 'names', and each element is itself a
136 // vector of IP addresses (in textual form).
137 //
138 // If DNS resolution fails for any name, then that slot in the return vector
139 // will still be present, but will be an empty vector.
140 //
141 // Ensures: names.size() == return_value.size()
142
ResolveNames(const std::vector<string> & names)143 std::vector<std::vector<string>> GcsDnsCache::ResolveNames(
144 const std::vector<string>& names) {
145 std::vector<std::vector<string>> all_addresses;
146 all_addresses.reserve(names.size());
147 for (const string& name : names) {
148 all_addresses.push_back(ResolveName(name));
149 }
150 return all_addresses;
151 }
152
WorkerThread()153 void GcsDnsCache::WorkerThread() {
154 while (true) {
155 {
156 // Don't immediately re-resolve the addresses.
157 mutex_lock l(mu_);
158 if (cancelled_) return;
159 cond_var_.wait_for(l, std::chrono::seconds(refresh_rate_secs_));
160 if (cancelled_) return;
161 }
162
163 // Resolve DNS values
164 auto new_addresses = ResolveNames(kCachedDomainNames);
165
166 {
167 mutex_lock l(mu_);
168 // Update instance variables.
169 addresses_.swap(new_addresses);
170 }
171 }
172 }
173
174 } // namespace tensorflow
175