1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/platform/cloud/google_auth_provider.h"
17 #ifndef _WIN32
18 #include <pwd.h>
19 #include <unistd.h>
20 #else
21 #include <sys/types.h>
22 #endif
23 #include <fstream>
24 #include <utility>
25 #include "absl/strings/match.h"
26 #include "include/json/json.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/io/path.h"
29 #include "tensorflow/core/lib/strings/base64.h"
30 #include "tensorflow/core/platform/cloud/retrying_utils.h"
31 #include "tensorflow/core/platform/env.h"
32 
33 namespace tensorflow {
34 
35 namespace {
36 
// ----------------------------------------------------------------------------
// Environment variables consulted while looking for credentials.
// ----------------------------------------------------------------------------

// Names the file holding local Application Default Credentials.
constexpr char kGoogleApplicationCredentials[] =
    "GOOGLE_APPLICATION_CREDENTIALS";

// When present, its value is used directly as the bearer token (for tests).
constexpr char kGoogleAuthTokenForTesting[] = "GOOGLE_AUTH_TOKEN_FOR_TESTING";

// When present, replaces the default gcloud config folder under $HOME.
constexpr char kCloudSdkConfig[] = "CLOUDSDK_CONFIG";

// A value of "true" (compared case-insensitively) disables any attempt to
// reach the GCE metadata service.
constexpr char kNoGceCheck[] = "NO_GCE_CHECK";

// ----------------------------------------------------------------------------
// Local credential file locations.
// ----------------------------------------------------------------------------

// gcloud's config folder, expressed relative to the home directory.
constexpr char kGCloudConfigFolder[] = ".config/gcloud/";

// File name that `gcloud auth` writes inside the config folder.
constexpr char kWellKnownCredentialsFile[] =
    "application_default_credentials.json";

// ----------------------------------------------------------------------------
// Remote endpoints and OAuth parameters.
// ----------------------------------------------------------------------------

// Token endpoint used when exchanging a refresh token.
constexpr char kOAuthV3Url[] = "https://www.googleapis.com/oauth2/v3/token";

// Token endpoint used when signing in with a service-account private key.
constexpr char kOAuthV4Url[] = "https://www.googleapis.com/oauth2/v4/token";

// Metadata-server path from which the bearer token is read on Google Compute
// Engine.
constexpr char kGceTokenPath[] = "instance/service-accounts/default/token";

// Scope requested for service-account tokens.
constexpr char kOAuthScope[] = "https://www.googleapis.com/auth/cloud-platform";

// ----------------------------------------------------------------------------
// Token reuse policy.
// ----------------------------------------------------------------------------

// A cached token is reused only while it expires more than this many seconds
// from now.
constexpr int kExpirationTimeMarginSec = 60;
76 
/// Reports whether `filename` can be opened for reading.
///
/// Only an open attempt is made; no other file attributes are inspected.
/// NOTE(review): on POSIX systems opening a directory for reading can succeed,
/// so a directory path may also pass this check — confirm whether that matters.
bool IsFile(const string& filename) {
  return std::ifstream(filename.c_str()).good();
}
82 
83 /// Returns the credentials file name from the env variable.
GetEnvironmentVariableFileName(string * filename)84 Status GetEnvironmentVariableFileName(string* filename) {
85   if (!filename) {
86     return errors::FailedPrecondition("'filename' cannot be nullptr.");
87   }
88   const char* result = std::getenv(kGoogleApplicationCredentials);
89   if (!result || !IsFile(result)) {
90     return errors::NotFound(strings::StrCat("$", kGoogleApplicationCredentials,
91                                             " is not set or corrupt."));
92   }
93   *filename = result;
94   return Status::OK();
95 }
96 
97 /// Returns the well known file produced by command 'gcloud auth login'.
GetWellKnownFileName(string * filename)98 Status GetWellKnownFileName(string* filename) {
99   if (!filename) {
100     return errors::FailedPrecondition("'filename' cannot be nullptr.");
101   }
102   string config_dir;
103   const char* config_dir_override = std::getenv(kCloudSdkConfig);
104   if (config_dir_override) {
105     config_dir = config_dir_override;
106   } else {
107     // Determine the home dir path.
108     const char* home_dir = std::getenv("HOME");
109     if (!home_dir) {
110       return errors::FailedPrecondition("Could not read $HOME.");
111     }
112     config_dir = io::JoinPath(home_dir, kGCloudConfigFolder);
113   }
114   auto result = io::JoinPath(config_dir, kWellKnownCredentialsFile);
115   if (!IsFile(result)) {
116     return errors::NotFound(
117         "Could not find the credentials file in the standard gcloud location.");
118   }
119   *filename = result;
120   return Status::OK();
121 }
122 
123 }  // namespace
124 
GoogleAuthProvider(std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client)125 GoogleAuthProvider::GoogleAuthProvider(
126     std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client)
127     : GoogleAuthProvider(std::unique_ptr<OAuthClient>(new OAuthClient()),
128                          std::move(compute_engine_metadata_client),
129                          Env::Default()) {}
130 
GoogleAuthProvider(std::unique_ptr<OAuthClient> oauth_client,std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client,Env * env)131 GoogleAuthProvider::GoogleAuthProvider(
132     std::unique_ptr<OAuthClient> oauth_client,
133     std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client,
134     Env* env)
135     : oauth_client_(std::move(oauth_client)),
136       compute_engine_metadata_client_(
137           std::move(compute_engine_metadata_client)),
138       env_(env) {}
139 
GetToken(string * t)140 Status GoogleAuthProvider::GetToken(string* t) {
141   mutex_lock lock(mu_);
142   const uint64 now_sec = env_->NowSeconds();
143 
144   if (now_sec + kExpirationTimeMarginSec < expiration_timestamp_sec_) {
145     *t = current_token_;
146     return Status::OK();
147   }
148 
149   if (GetTokenForTesting().ok()) {
150     *t = current_token_;
151     return Status::OK();
152   }
153 
154   auto token_from_files_status = GetTokenFromFiles();
155   if (token_from_files_status.ok()) {
156     *t = current_token_;
157     return Status::OK();
158   }
159 
160   char* no_gce_check_var = std::getenv(kNoGceCheck);
161   bool skip_gce_check = no_gce_check_var != nullptr &&
162                         absl::EqualsIgnoreCase(no_gce_check_var, "true");
163   Status token_from_gce_status;
164   if (skip_gce_check) {
165     token_from_gce_status =
166         Status(error::CANCELLED,
167                strings::StrCat("GCE check skipped due to presence of $",
168                                kNoGceCheck, " environment variable."));
169   } else {
170     token_from_gce_status = GetTokenFromGce();
171   }
172 
173   if (token_from_gce_status.ok()) {
174     *t = current_token_;
175     return Status::OK();
176   }
177 
178   LOG(WARNING)
179       << "All attempts to get a Google authentication bearer token failed, "
180       << "returning an empty token. Retrieving token from files failed with \""
181       << token_from_files_status.ToString() << "\"."
182       << " Retrieving token from GCE failed with \""
183       << token_from_gce_status.ToString() << "\".";
184 
185   // Public objects can still be accessed with an empty bearer token,
186   // so return an empty token instead of failing.
187   *t = "";
188 
189   // We only want to keep returning our empty token if we've tried and failed
190   // the (potentially slow) task of detecting GCE.
191   if (skip_gce_check) {
192     expiration_timestamp_sec_ = 0;
193   } else {
194     expiration_timestamp_sec_ = UINT64_MAX;
195   }
196   current_token_ = "";
197 
198   return Status::OK();
199 }
200 
GetTokenFromFiles()201 Status GoogleAuthProvider::GetTokenFromFiles() {
202   string credentials_filename;
203   if (!GetEnvironmentVariableFileName(&credentials_filename).ok() &&
204       !GetWellKnownFileName(&credentials_filename).ok()) {
205     return errors::NotFound("Could not locate the credentials file.");
206   }
207 
208   Json::Value json;
209   Json::Reader reader;
210   std::ifstream credentials_fstream(credentials_filename);
211   if (!reader.parse(credentials_fstream, json)) {
212     return errors::FailedPrecondition(
213         "Couldn't parse the JSON credentials file.");
214   }
215   if (json.isMember("refresh_token")) {
216     TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromRefreshTokenJson(
217         json, kOAuthV3Url, &current_token_, &expiration_timestamp_sec_));
218   } else if (json.isMember("private_key")) {
219     TF_RETURN_IF_ERROR(oauth_client_->GetTokenFromServiceAccountJson(
220         json, kOAuthV4Url, kOAuthScope, &current_token_,
221         &expiration_timestamp_sec_));
222   } else {
223     return errors::FailedPrecondition(
224         "Unexpected content of the JSON credentials file.");
225   }
226   return Status::OK();
227 }
228 
GetTokenFromGce()229 Status GoogleAuthProvider::GetTokenFromGce() {
230   std::vector<char> response_buffer;
231   const uint64 request_timestamp_sec = env_->NowSeconds();
232 
233   TF_RETURN_IF_ERROR(compute_engine_metadata_client_->GetMetadata(
234       kGceTokenPath, &response_buffer));
235   StringPiece response =
236       StringPiece(&response_buffer[0], response_buffer.size());
237 
238   TF_RETURN_IF_ERROR(oauth_client_->ParseOAuthResponse(
239       response, request_timestamp_sec, &current_token_,
240       &expiration_timestamp_sec_));
241 
242   return Status::OK();
243 }
244 
GetTokenForTesting()245 Status GoogleAuthProvider::GetTokenForTesting() {
246   const char* token = std::getenv(kGoogleAuthTokenForTesting);
247   if (!token) {
248     return errors::NotFound("The env variable for testing was not set.");
249   }
250   expiration_timestamp_sec_ = UINT64_MAX;
251   current_token_ = token;
252   return Status::OK();
253 }
254 
255 }  // namespace tensorflow
256