1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_
17 #define TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_
18 
19 #include <memory>
20 #include "tensorflow/core/platform/cloud/auth_provider.h"
21 #include "tensorflow/core/platform/cloud/compute_engine_metadata_client.h"
22 #include "tensorflow/core/platform/cloud/oauth_client.h"
23 #include "tensorflow/core/platform/mutex.h"
24 #include "tensorflow/core/platform/thread_annotations.h"
25 
26 namespace tensorflow {
27 
28 /// Implementation based on Google Application Default Credentials.
29 class GoogleAuthProvider : public AuthProvider {
30  public:
31   GoogleAuthProvider(std::shared_ptr<ComputeEngineMetadataClient>
32                          compute_engine_metadata_client);
33   explicit GoogleAuthProvider(std::unique_ptr<OAuthClient> oauth_client,
34                               std::shared_ptr<ComputeEngineMetadataClient>
35                                   compute_engine_metadata_client,
36                               Env* env);
~GoogleAuthProvider()37   virtual ~GoogleAuthProvider() {}
38 
39   /// \brief Returns the short-term authentication bearer token.
40   ///
41   /// Safe for concurrent use by multiple threads.
42   Status GetToken(string* token) override;
43 
44  private:
45   /// \brief Gets the bearer token from files.
46   ///
47   /// Tries the file from $GOOGLE_APPLICATION_CREDENTIALS and the
48   /// standard gcloud tool's location.
49   Status GetTokenFromFiles() EXCLUSIVE_LOCKS_REQUIRED(mu_);
50 
51   /// Gets the bearer token from Google Compute Engine environment.
52   Status GetTokenFromGce() EXCLUSIVE_LOCKS_REQUIRED(mu_);
53 
54   /// Gets the bearer token from the system env variable, for testing purposes.
55   Status GetTokenForTesting() EXCLUSIVE_LOCKS_REQUIRED(mu_);
56 
57   std::unique_ptr<OAuthClient> oauth_client_;
58   std::shared_ptr<ComputeEngineMetadataClient> compute_engine_metadata_client_;
59   Env* env_;
60   mutex mu_;
61   string current_token_ GUARDED_BY(mu_);
62   uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0;
63   TF_DISALLOW_COPY_AND_ASSIGN(GoogleAuthProvider);
64 };
65 
66 }  // namespace tensorflow
67 
68 #endif  // TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_
69