1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/client_library.h"
17 
18 #include "tensorflow/compiler/xla/service/backend.h"
19 #include "tensorflow/compiler/xla/service/platform_util.h"
20 #include "tensorflow/compiler/xla/status_macros.h"
21 #include "tensorflow/compiler/xla/util.h"
22 #include "tensorflow/core/platform/logging.h"
23 
24 namespace xla {
25 
LocalClientOptions(perftools::gputools::Platform * platform,int number_of_replicas,int intra_op_parallelism_threads)26 LocalClientOptions::LocalClientOptions(perftools::gputools::Platform* platform,
27                                        int number_of_replicas,
28                                        int intra_op_parallelism_threads)
29     : platform_(platform),
30       number_of_replicas_(number_of_replicas),
31       intra_op_parallelism_threads_(intra_op_parallelism_threads) {}
32 
set_platform(perftools::gputools::Platform * platform)33 LocalClientOptions& LocalClientOptions::set_platform(
34     perftools::gputools::Platform* platform) {
35   platform_ = platform;
36   return *this;
37 }
38 
platform() const39 perftools::gputools::Platform* LocalClientOptions::platform() const {
40   return platform_;
41 }
42 
set_number_of_replicas(int number_of_replicas)43 LocalClientOptions& LocalClientOptions::set_number_of_replicas(
44     int number_of_replicas) {
45   number_of_replicas_ = number_of_replicas;
46   return *this;
47 }
48 
number_of_replicas() const49 int LocalClientOptions::number_of_replicas() const {
50   return number_of_replicas_;
51 }
52 
set_intra_op_parallelism_threads(int num_threads)53 LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads(
54     int num_threads) {
55   intra_op_parallelism_threads_ = num_threads;
56   return *this;
57 }
58 
intra_op_parallelism_threads() const59 int LocalClientOptions::intra_op_parallelism_threads() const {
60   return intra_op_parallelism_threads_;
61 }
62 
Singleton()63 /* static */ ClientLibrary& ClientLibrary::Singleton() {
64   static ClientLibrary* c = new ClientLibrary;
65   return *c;
66 }
67 
68 ClientLibrary::ClientLibrary() = default;
69 ClientLibrary::~ClientLibrary() = default;
70 
GetOrCreateLocalClient(perftools::gputools::Platform * platform)71 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
72     perftools::gputools::Platform* platform) {
73   LocalClientOptions default_options;
74   default_options.set_platform(platform);
75   return GetOrCreateLocalClient(default_options);
76 }
77 
GetOrCreateLocalClient(const LocalClientOptions & options)78 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
79     const LocalClientOptions& options) {
80   perftools::gputools::Platform* platform = options.platform();
81   int replica_count = options.number_of_replicas();
82   ClientLibrary& client_library = Singleton();
83   tensorflow::mutex_lock lock(client_library.service_mutex_);
84 
85   if (platform == nullptr) {
86     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
87   }
88 
89   auto it = client_library.local_instances_.find(platform->id());
90   if (it != client_library.local_instances_.end()) {
91     return it->second->client.get();
92   }
93 
94   ServiceOptions service_options;
95   service_options.set_platform(platform);
96   service_options.set_number_of_replicas(replica_count);
97   service_options.set_intra_op_parallelism_threads(
98       options.intra_op_parallelism_threads());
99 
100   auto instance = MakeUnique<LocalInstance>();
101   TF_ASSIGN_OR_RETURN(instance->service,
102                       LocalService::NewService(service_options));
103   instance->client = MakeUnique<LocalClient>(instance->service.get());
104   LocalClient* cl = instance->client.get();
105 
106   client_library.local_instances_.insert(
107       std::make_pair(platform->id(), std::move(instance)));
108   return cl;
109 }
110 
LocalClientOrDie()111 /* static */ LocalClient* ClientLibrary::LocalClientOrDie() {
112   auto client_status = GetOrCreateLocalClient();
113   TF_CHECK_OK(client_status.status());
114   return client_status.ValueOrDie();
115 }
116 
GetXlaService(perftools::gputools::Platform * platform)117 /* static */ LocalService* ClientLibrary::GetXlaService(
118     perftools::gputools::Platform* platform) {
119   ClientLibrary& client_library = Singleton();
120   tensorflow::mutex_lock lock(client_library.service_mutex_);
121   auto it = client_library.local_instances_.find(platform->id());
122   CHECK(it != client_library.local_instances_.end());
123   return it->second->service.get();
124 }
125 
126 /* static */ StatusOr<CompileOnlyClient*>
GetOrCreateCompileOnlyClient(perftools::gputools::Platform * platform)127 ClientLibrary::GetOrCreateCompileOnlyClient(
128     perftools::gputools::Platform* platform) {
129   ClientLibrary& client_library = Singleton();
130   tensorflow::mutex_lock lock(client_library.service_mutex_);
131 
132   if (platform == nullptr) {
133     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
134   }
135 
136   auto it = client_library.compile_only_instances_.find(platform->id());
137   if (it != client_library.compile_only_instances_.end()) {
138     return it->second->client.get();
139   }
140 
141   auto instance = MakeUnique<CompileOnlyInstance>();
142   TF_ASSIGN_OR_RETURN(instance->service,
143                       CompileOnlyService::NewService(platform));
144   instance->client = MakeUnique<CompileOnlyClient>(instance->service.get());
145   CompileOnlyClient* cl = instance->client.get();
146 
147   client_library.compile_only_instances_.insert(
148       std::make_pair(platform->id(), std::move(instance)));
149   return cl;
150 }
151 
DestroyLocalInstances()152 /* static */ void ClientLibrary::DestroyLocalInstances() {
153   ClientLibrary& client_library = Singleton();
154   tensorflow::mutex_lock lock(client_library.service_mutex_);
155 
156   client_library.local_instances_.clear();
157   client_library.compile_only_instances_.clear();
158 }
159 
160 }  // namespace xla
161