1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include "tensorflow/core/platform/cloud/gcs_dns_cache.h"
#ifndef _WIN32
#include <arpa/inet.h>
#include <netdb.h>
#else
#include <Windows.h>
#include <winsock2.h>
#include <ws2tcpip.h>
#endif
#include <sys/types.h>

#include <cerrno>
#include <cstring>
#include <memory>
26 
27 namespace tensorflow {
28 
29 namespace {
30 
31 const std::vector<string>& kCachedDomainNames =
32     *new std::vector<string>{"www.googleapis.com", "storage.googleapis.com"};
33 
// Logs, at ERROR level, why getaddrinfo() failed for `name`.
// On non-Windows builds an EAI_SYSTEM result is reported using strerror(errno);
// any other code (and every code on Windows) is described by gai_strerror().
inline void print_getaddrinfo_error(const string& name, int error_code) {
#ifdef _WIN32
  // TODO:WSAGetLastError is better than gai_strerror
  LOG(ERROR) << "Error resolving " << name << ": " << gai_strerror(error_code);
#else
  if (error_code != EAI_SYSTEM) {
    LOG(ERROR) << "Error resolving " << name << ": "
               << gai_strerror(error_code);
    return;
  }
  LOG(ERROR) << "Error resolving " << name
             << " (EAI_SYSTEM): " << strerror(errno);
#endif
}
48 
// Returns a reference to one element of `items`, chosen uniformly at random
// using the engine pointed to by `random` (which is advanced by the draw).
// `items` must be non-empty; CHECK-fails otherwise.
template <typename T>
const T& SelectRandomItemUniform(std::default_random_engine* random,
                                 const std::vector<T>& items) {
  CHECK_GT(items.size(), 0);
  const size_t last_index = items.size() - 1u;
  std::uniform_int_distribution<size_t> pick(0u, last_index);
  return items[pick(*random)];
}
59 }  // namespace
60 
GcsDnsCache(Env * env,int64 refresh_rate_secs)61 GcsDnsCache::GcsDnsCache(Env* env, int64 refresh_rate_secs)
62     : env_(env), refresh_rate_secs_(refresh_rate_secs) {}
63 
AnnotateRequest(HttpRequest * request)64 void GcsDnsCache::AnnotateRequest(HttpRequest* request) {
65   // TODO(saeta): Blacklist failing IP addresses.
66   mutex_lock l(mu_);
67   if (!started_) {
68     VLOG(1) << "Starting GCS DNS cache.";
69     DCHECK(!worker_) << "Worker thread already exists!";
70     // Perform DNS resolutions to warm the cache.
71     addresses_ = ResolveNames(kCachedDomainNames);
72 
73     // Note: we opt to use a thread instead of a delayed closure.
74     worker_.reset(env_->StartThread({}, "gcs_dns_worker",
75                                     [this]() { return WorkerThread(); }));
76     started_ = true;
77   }
78 
79   CHECK_EQ(kCachedDomainNames.size(), addresses_.size());
80   for (size_t i = 0; i < kCachedDomainNames.size(); ++i) {
81     const string& name = kCachedDomainNames[i];
82     const std::vector<string>& addresses = addresses_[i];
83     if (!addresses.empty()) {
84       const string& chosen_address =
85           SelectRandomItemUniform(&random_, addresses);
86       request->AddResolveOverride(name, 443, chosen_address);
87       VLOG(1) << "Annotated DNS mapping: " << name << " --> " << chosen_address;
88     } else {
89       LOG(WARNING) << "No IP addresses available for " << name;
90     }
91   }
92 }
93 
// Resolves `name` with getaddrinfo(), restricted to IPv4 / SOCK_STREAM, and
// returns each resulting address in textual form as produced by inet_ntop().
// Results whose family is not AF_INET are skipped with a warning, and results
// that inet_ntop() fails to format are logged and skipped. If getaddrinfo()
// itself fails, the error is logged and an empty vector is returned.
/* static */ std::vector<string> GcsDnsCache::ResolveName(const string& name) {
  VLOG(1) << "Resolving DNS name: " << name;

  // Value-initialization zeroes every field (equivalent to memset to 0).
  addrinfo hints{};
  hints.ai_family = AF_INET;  // Only use IPv4 for now.
  hints.ai_socktype = SOCK_STREAM;

  // Owns the list returned by getaddrinfo() so it is released with
  // freeaddrinfo() on every exit path, including the early error return.
  struct AddrInfoDeleter {
    void operator()(addrinfo* list) const { freeaddrinfo(list); }
  };
  addrinfo* raw_result = nullptr;
  const int return_code =
      getaddrinfo(name.c_str(), nullptr, &hints, &raw_result);
  std::unique_ptr<addrinfo, AddrInfoDeleter> result(raw_result);

  std::vector<string> output;
  if (return_code != 0) {
    print_getaddrinfo_error(name, return_code);
    return output;
  }

  for (const addrinfo* i = result.get(); i != nullptr; i = i->ai_next) {
    if (i->ai_family != AF_INET || i->ai_addr->sa_family != AF_INET) {
      // sa_family_t is an 8-bit type on some platforms (e.g. macOS/BSD);
      // cast so it is logged as a number rather than as a raw character.
      LOG(WARNING) << "Non-IPv4 address returned. ai_family: " << i->ai_family
                   << ". sa_family: "
                   << static_cast<int>(i->ai_addr->sa_family) << ".";
      continue;
    }
    char buf[INET_ADDRSTRLEN];
    void* address_ptr =
        &(reinterpret_cast<sockaddr_in*>(i->ai_addr)->sin_addr);
    if (inet_ntop(AF_INET, address_ptr, buf, INET_ADDRSTRLEN) == nullptr) {
      // NOTE(review): on Windows inet_ntop() reports failures through
      // WSAGetLastError(), so errno may not describe this error there.
      LOG(ERROR) << "Error converting response to IP address for " << name
                 << ": " << strerror(errno);
      continue;
    }
    output.emplace_back(buf);
    VLOG(1) << "... address: " << buf;
  }
  return output;
}
133 
134 // Performs DNS resolution for a set of DNS names. The return vector contains
135 // one element for each element in 'names', and each element is itself a
136 // vector of IP addresses (in textual form).
137 //
138 // If DNS resolution fails for any name, then that slot in the return vector
139 // will still be present, but will be an empty vector.
140 //
141 // Ensures: names.size() == return_value.size()
142 
ResolveNames(const std::vector<string> & names)143 std::vector<std::vector<string>> GcsDnsCache::ResolveNames(
144     const std::vector<string>& names) {
145   std::vector<std::vector<string>> all_addresses;
146   all_addresses.reserve(names.size());
147   for (const string& name : names) {
148     all_addresses.push_back(ResolveName(name));
149   }
150   return all_addresses;
151 }
152 
WorkerThread()153 void GcsDnsCache::WorkerThread() {
154   while (true) {
155     {
156       // Don't immediately re-resolve the addresses.
157       mutex_lock l(mu_);
158       if (cancelled_) return;
159       cond_var_.wait_for(l, std::chrono::seconds(refresh_rate_secs_));
160       if (cancelled_) return;
161     }
162 
163     // Resolve DNS values
164     auto new_addresses = ResolveNames(kCachedDomainNames);
165 
166     {
167       mutex_lock l(mu_);
168       // Update instance variables.
169       addresses_.swap(new_addresses);
170     }
171   }
172 }
173 
174 }  // namespace tensorflow
175