1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/device_util.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "tensorflow/compiler/xla/status_macros.h"
21
22 namespace tensorflow {
23 namespace jit {
24 using xla::StatusOr;
25
Insert(DeviceId device_id)26 void DeviceSet::Insert(DeviceId device_id) {
27 int word_index = device_id.id() / kWordSize;
28 int bit_index = device_id.id() % kWordSize;
29 const int storage_size = storage_.size();
30 if (word_index >= storage_size) {
31 storage_.resize(word_index + 1, 0);
32 }
33
34 storage_[word_index] |= (1ull << bit_index);
35 }
36
UnionWith(const DeviceSet & other)37 void DeviceSet::UnionWith(const DeviceSet& other) {
38 if (other.storage_.size() > storage_.size()) {
39 storage_.resize(other.storage_.size(), 0);
40 }
41
42 for (int i = 0, end = other.storage_.size(); i < end; i++) {
43 storage_[i] |= other.storage_[i];
44 }
45 }
46
IsEmpty() const47 bool DeviceSet::IsEmpty() const {
48 return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
49 }
50
GetIdFor(absl::string_view name)51 xla::StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
52 TF_RET_CHECK(!name.empty());
53
54 auto it = name_to_id_.find(name);
55 if (it != name_to_id_.end()) {
56 return it->second;
57 }
58
59 int new_id = names_.size();
60 names_.push_back(string(name));
61 id_to_device_type_.push_back(absl::make_unique<DeviceType>(""));
62 DeviceType* device_type = id_to_device_type_.back().get();
63 TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type));
64
65 is_cpu_.push_back(device_type->type_string() == DEVICE_CPU);
66 is_gpu_.push_back(device_type->type_string() == DEVICE_GPU);
67
68 name_to_id_.emplace(string(name), DeviceId(new_id));
69
70 const XlaOpRegistry::DeviceRegistration* compilation_device;
71 if (!XlaOpRegistry::GetCompilationDevice(device_type->type(),
72 &compilation_device)) {
73 compilation_device = nullptr;
74 }
75 id_to_compilation_device_.push_back(compilation_device);
76
77 return DeviceId(new_id);
78 }
79
DebugString(const DeviceSet & device_set) const80 string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
81 std::vector<string> names;
82 device_set.ForEach([&](DeviceId device_id) {
83 names.push_back(string(GetNameFor(device_id)));
84 return true;
85 });
86
87 return absl::StrCat("[", absl::StrJoin(names, ","), "]");
88 }
89 } // namespace jit
90
DeviceNameToDeviceType(const string & device,DeviceType * device_type)91 Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
92 DeviceNameUtils::ParsedName parsed;
93 if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
94 return errors::Internal("Malformed assigned device '", device, "'");
95 }
96 *device_type = DeviceType(parsed.type);
97 return Status::OK();
98 }
99
// Picks the single device that a computation placed on `devices` should be
// compiled for, or reports why no such device exists.
//
// Each device is bucketed as GPU, CPU, or "unknown" (neither IsGpu nor IsCpu
// according to `device_info_cache`).
//  * Within the GPU bucket and within the CPU bucket, a second device is
//    accepted only if both names parse, are compatible, and one is a
//    specification of the other. In that case the more specific one is kept.
//    Otherwise the set is rejected as having multiple devices of that kind.
//  * Any second unknown device is always rejected.
//  * Unknown + GPU is always rejected. Unknown + CPU is rejected unless
//    `allow_mixing_unknown_and_cpu` is true.
//  * On success the preference order is GPU, then unknown, then CPU.
//  * An empty `devices` is a failure.
//
// If `failure_to_pick_is_error` is true, every failure is returned as an
// errors::Internal status. Otherwise failures yield an OK result holding
// absl::nullopt.
xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
    const jit::DeviceInfoCache& device_info_cache,
    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu,
    bool failure_to_pick_is_error) {
// Returns from the enclosing function: `failing_status` when
// failure_to_pick_is_error is set, otherwise an OK absl::nullopt. Because this
// is a macro, `failing_status` is only evaluated on the error branch. The
// (possibly expensive) diagnostic strings are therefore never built when the
// caller only wants an optional.
#define FAILED_TO_PICK_DEVICE(failing_status) \
  do {                                        \
    if (failure_to_pick_is_error) {           \
      return failing_status;                  \
    } else {                                  \
      return {absl::nullopt};                 \
    }                                         \
  } while (false)

  // The (so far) chosen device of each kind, if any was seen.
  absl::optional<jit::DeviceId> maybe_gpu_device;
  absl::optional<jit::DeviceId> maybe_cpu_device;
  absl::optional<jit::DeviceId> maybe_unknown_device;

  // Set when a conflicting second device of the corresponding kind is found.
  bool multiple_cpu_devices = false;
  bool multiple_gpu_devices = false;
  bool multiple_unknown_devices = false;

  // Returns 'true' if d0 and d1 are conflicting devices. If they are
  // compatible, update d1 with a more specific one.
  // A name that fails to parse is treated as conflicting.
  // TODO(sanjoy): Cache DeviceNameUtils::ParsedName inside device_info_cache.
  const auto is_multiple_devices =
      [&](const jit::DeviceId& d0, absl::optional<jit::DeviceId>* d1) -> bool {
    const absl::string_view name0 = device_info_cache.GetNameFor(d0);
    const absl::string_view name1 = device_info_cache.GetNameFor(d1->value());

    DeviceNameUtils::ParsedName parsed0, parsed1;
    if (!DeviceNameUtils::ParseFullName(name0, &parsed0) ||
        !DeviceNameUtils::ParseFullName(name1, &parsed1) ||
        !DeviceNameUtils::AreCompatibleDevNames(parsed0, parsed1)) {
      return true;
    }

    // *d1 is already at least as specific as d0: keep it.
    if (DeviceNameUtils::IsSpecification(parsed0, parsed1)) {
      return false;
    }

    // d0 is the more specific name: replace *d1 with it.
    if (DeviceNameUtils::IsSpecification(parsed1, parsed0)) {
      *d1 = d0;
      return false;
    }

    // Compatible but neither refines the other.
    return true;
  };

  // Bucket every device. Returning false from the callback after a conflict
  // presumably stops the iteration early (DeviceSet::ForEach contract, defined
  // in the header -- confirm there).
  devices.ForEach([&](jit::DeviceId device) {
    if (device_info_cache.IsGpu(device)) {
      if (maybe_gpu_device) {
        multiple_gpu_devices = is_multiple_devices(device, &maybe_gpu_device);
        if (multiple_gpu_devices) return false;
      } else {
        maybe_gpu_device = device;
      }
    } else if (device_info_cache.IsCpu(device)) {
      if (maybe_cpu_device) {
        multiple_cpu_devices = is_multiple_devices(device, &maybe_cpu_device);
        if (multiple_cpu_devices) return false;
      } else {
        maybe_cpu_device = device;
      }
    } else {
      // Unknown devices are never merged: any second one is a conflict.
      if (maybe_unknown_device) {
        multiple_unknown_devices = true;
        return false;
      }
      maybe_unknown_device = device;
    }

    return true;
  });

  if (multiple_cpu_devices) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Multiple CPU devices ", device_info_cache.DebugString(devices)));
  }

  if (multiple_gpu_devices) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Multiple GPU devices ", device_info_cache.DebugString(devices)));
  }

  if (multiple_unknown_devices) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Multiple unknown devices ", device_info_cache.DebugString(devices)));
  }

  if (maybe_unknown_device && maybe_gpu_device) {
    FAILED_TO_PICK_DEVICE(errors::Internal(
        "Found both unknown and GPU devices: ",
        device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
        device_info_cache.GetNameFor(*maybe_gpu_device)));
  }

  if (!allow_mixing_unknown_and_cpu) {
    if (maybe_unknown_device && maybe_cpu_device) {
      FAILED_TO_PICK_DEVICE(errors::Internal(
          "Found both unknown and CPU devices: ",
          device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
          device_info_cache.GetNameFor(*maybe_cpu_device)));
    }
  }

  // Preference order: GPU, then unknown, then CPU.
  if (maybe_gpu_device) {
    return {*maybe_gpu_device};
  } else if (maybe_unknown_device) {
    return {*maybe_unknown_device};
  } else if (maybe_cpu_device) {
    return {*maybe_cpu_device};
  }

  // Nothing was bucketed, i.e. `devices` was empty.
  FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!"));

#undef FAILED_TO_PICK_DEVICE
}
217
PickDeviceForXla(const jit::DeviceInfoCache & device_info_cache,const jit::DeviceSet & devices,bool allow_mixing_unknown_and_cpu)218 xla::StatusOr<jit::DeviceId> PickDeviceForXla(
219 const jit::DeviceInfoCache& device_info_cache,
220 const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
221 TF_ASSIGN_OR_RETURN(absl::optional<jit::DeviceId> device_id,
222 PickDeviceForXlaImpl(device_info_cache, devices,
223 allow_mixing_unknown_and_cpu,
224 /*failure_to_pick_is_error=*/true));
225 return *device_id;
226 }
227
MaybePickDeviceForXla(const jit::DeviceInfoCache & device_info_cache,const jit::DeviceSet & devices,bool allow_mixing_unknown_and_cpu)228 xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
229 const jit::DeviceInfoCache& device_info_cache,
230 const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
231 return PickDeviceForXlaImpl(device_info_cache, devices,
232 allow_mixing_unknown_and_cpu,
233 /*failure_to_pick_is_error=*/false);
234 }
235 } // namespace tensorflow
236