1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/client_library.h"
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/compiler/xla/service/backend.h"
20 #include "tensorflow/compiler/xla/service/platform_util.h"
21 #include "tensorflow/compiler/xla/status_macros.h"
22 #include "tensorflow/compiler/xla/util.h"
23 #include "tensorflow/core/platform/logging.h"
24 
25 namespace xla {
26 
LocalClientOptions(se::Platform * platform,int number_of_replicas,int intra_op_parallelism_threads,const absl::optional<std::set<int>> & allowed_devices)27 LocalClientOptions::LocalClientOptions(
28     se::Platform* platform, int number_of_replicas,
29     int intra_op_parallelism_threads,
30     const absl::optional<std::set<int>>& allowed_devices)
31     : platform_(platform),
32       number_of_replicas_(number_of_replicas),
33       intra_op_parallelism_threads_(intra_op_parallelism_threads),
34       allowed_devices_(allowed_devices) {}
35 
set_platform(se::Platform * platform)36 LocalClientOptions& LocalClientOptions::set_platform(se::Platform* platform) {
37   platform_ = platform;
38   return *this;
39 }
40 
platform() const41 se::Platform* LocalClientOptions::platform() const { return platform_; }
42 
set_number_of_replicas(int number_of_replicas)43 LocalClientOptions& LocalClientOptions::set_number_of_replicas(
44     int number_of_replicas) {
45   number_of_replicas_ = number_of_replicas;
46   return *this;
47 }
48 
number_of_replicas() const49 int LocalClientOptions::number_of_replicas() const {
50   return number_of_replicas_;
51 }
52 
set_intra_op_parallelism_threads(int num_threads)53 LocalClientOptions& LocalClientOptions::set_intra_op_parallelism_threads(
54     int num_threads) {
55   intra_op_parallelism_threads_ = num_threads;
56   return *this;
57 }
58 
intra_op_parallelism_threads() const59 int LocalClientOptions::intra_op_parallelism_threads() const {
60   return intra_op_parallelism_threads_;
61 }
62 
set_allowed_devices(const absl::optional<std::set<int>> & allowed_devices)63 LocalClientOptions& LocalClientOptions::set_allowed_devices(
64     const absl::optional<std::set<int>>& allowed_devices) {
65   allowed_devices_ = allowed_devices;
66   return *this;
67 }
68 
allowed_devices() const69 const absl::optional<std::set<int>>& LocalClientOptions::allowed_devices()
70     const {
71   return allowed_devices_;
72 }
73 
Singleton()74 /* static */ ClientLibrary& ClientLibrary::Singleton() {
75   static ClientLibrary* c = new ClientLibrary;
76   return *c;
77 }
78 
79 ClientLibrary::ClientLibrary() = default;
80 ClientLibrary::~ClientLibrary() = default;
81 
GetOrCreateLocalClient(se::Platform * platform,const absl::optional<std::set<int>> & device_set)82 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
83     se::Platform* platform, const absl::optional<std::set<int>>& device_set) {
84   LocalClientOptions default_options;
85   default_options.set_platform(platform);
86   default_options.set_allowed_devices(device_set);
87   return GetOrCreateLocalClient(default_options);
88 }
89 
GetOrCreateLocalClient(const LocalClientOptions & options)90 /* static */ StatusOr<LocalClient*> ClientLibrary::GetOrCreateLocalClient(
91     const LocalClientOptions& options) {
92   se::Platform* platform = options.platform();
93   int replica_count = options.number_of_replicas();
94   ClientLibrary& client_library = Singleton();
95   tensorflow::mutex_lock lock(client_library.service_mutex_);
96 
97   if (platform == nullptr) {
98     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
99   }
100 
101   auto it = client_library.local_instances_.find(platform->id());
102   if (it != client_library.local_instances_.end()) {
103     return it->second->client.get();
104   }
105 
106   ServiceOptions service_options;
107   service_options.set_platform(platform);
108   service_options.set_number_of_replicas(replica_count);
109   service_options.set_intra_op_parallelism_threads(
110       options.intra_op_parallelism_threads());
111   service_options.set_allowed_devices(options.allowed_devices());
112   auto instance = absl::make_unique<LocalInstance>();
113   TF_ASSIGN_OR_RETURN(instance->service,
114                       LocalService::NewService(service_options));
115   instance->client = absl::make_unique<LocalClient>(instance->service.get());
116   LocalClient* cl = instance->client.get();
117 
118   client_library.local_instances_.insert(
119       std::make_pair(platform->id(), std::move(instance)));
120   return cl;
121 }
122 
LocalClientOrDie()123 /* static */ LocalClient* ClientLibrary::LocalClientOrDie() {
124   auto client_status = GetOrCreateLocalClient();
125   TF_CHECK_OK(client_status.status());
126   return client_status.ValueOrDie();
127 }
128 
GetXlaService(se::Platform * platform)129 /* static */ LocalService* ClientLibrary::GetXlaService(
130     se::Platform* platform) {
131   ClientLibrary& client_library = Singleton();
132   tensorflow::mutex_lock lock(client_library.service_mutex_);
133   auto it = client_library.local_instances_.find(platform->id());
134   CHECK(it != client_library.local_instances_.end());
135   return it->second->service.get();
136 }
137 
138 /* static */ StatusOr<CompileOnlyClient*>
GetOrCreateCompileOnlyClient(se::Platform * platform)139 ClientLibrary::GetOrCreateCompileOnlyClient(se::Platform* platform) {
140   ClientLibrary& client_library = Singleton();
141   tensorflow::mutex_lock lock(client_library.service_mutex_);
142 
143   if (platform == nullptr) {
144     TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
145   }
146 
147   auto it = client_library.compile_only_instances_.find(platform->id());
148   if (it != client_library.compile_only_instances_.end()) {
149     return it->second->client.get();
150   }
151 
152   auto instance = absl::make_unique<CompileOnlyInstance>();
153   TF_ASSIGN_OR_RETURN(instance->service,
154                       CompileOnlyService::NewService(platform));
155   instance->client =
156       absl::make_unique<CompileOnlyClient>(instance->service.get());
157   CompileOnlyClient* cl = instance->client.get();
158 
159   client_library.compile_only_instances_.insert(
160       std::make_pair(platform->id(), std::move(instance)));
161   return cl;
162 }
163 
DestroyLocalInstances()164 /* static */ void ClientLibrary::DestroyLocalInstances() {
165   ClientLibrary& client_library = Singleton();
166   tensorflow::mutex_lock lock(client_library.service_mutex_);
167 
168   client_library.local_instances_.clear();
169   client_library.compile_only_instances_.clear();
170 }
171 
172 }  // namespace xla
173