1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/device_util.h"
17 
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "tensorflow/compiler/xla/status_macros.h"
21 
22 namespace tensorflow {
23 namespace jit {
24 using xla::StatusOr;
25 
Insert(DeviceId device_id)26 void DeviceSet::Insert(DeviceId device_id) {
27   int word_index = device_id.id() / kWordSize;
28   int bit_index = device_id.id() % kWordSize;
29   const int storage_size = storage_.size();
30   if (word_index >= storage_size) {
31     storage_.resize(word_index + 1, 0);
32   }
33 
34   storage_[word_index] |= (1ull << bit_index);
35 }
36 
UnionWith(const DeviceSet & other)37 void DeviceSet::UnionWith(const DeviceSet& other) {
38   if (other.storage_.size() > storage_.size()) {
39     storage_.resize(other.storage_.size(), 0);
40   }
41 
42   for (int i = 0, end = other.storage_.size(); i < end; i++) {
43     storage_[i] |= other.storage_[i];
44   }
45 }
46 
IsEmpty() const47 bool DeviceSet::IsEmpty() const {
48   return absl::c_all_of(storage_, [&](uint64 val) { return val == 0; });
49 }
50 
GetIdFor(absl::string_view name)51 xla::StatusOr<DeviceId> DeviceInfoCache::GetIdFor(absl::string_view name) {
52   TF_RET_CHECK(!name.empty());
53 
54   auto it = name_to_id_.find(name);
55   if (it != name_to_id_.end()) {
56     return it->second;
57   }
58 
59   int new_id = names_.size();
60   names_.push_back(string(name));
61   id_to_device_type_.push_back(absl::make_unique<DeviceType>(""));
62   DeviceType* device_type = id_to_device_type_.back().get();
63   TF_RETURN_IF_ERROR(DeviceNameToDeviceType(names_.back(), device_type));
64 
65   is_cpu_.push_back(device_type->type_string() == DEVICE_CPU);
66   is_gpu_.push_back(device_type->type_string() == DEVICE_GPU);
67 
68   name_to_id_.emplace(string(name), DeviceId(new_id));
69 
70   const XlaOpRegistry::DeviceRegistration* compilation_device;
71   if (!XlaOpRegistry::GetCompilationDevice(device_type->type(),
72                                            &compilation_device)) {
73     compilation_device = nullptr;
74   }
75   id_to_compilation_device_.push_back(compilation_device);
76 
77   return DeviceId(new_id);
78 }
79 
DebugString(const DeviceSet & device_set) const80 string DeviceInfoCache::DebugString(const DeviceSet& device_set) const {
81   std::vector<string> names;
82   device_set.ForEach([&](DeviceId device_id) {
83     names.push_back(string(GetNameFor(device_id)));
84     return true;
85   });
86 
87   return absl::StrCat("[", absl::StrJoin(names, ","), "]");
88 }
89 }  // namespace jit
90 
DeviceNameToDeviceType(const string & device,DeviceType * device_type)91 Status DeviceNameToDeviceType(const string& device, DeviceType* device_type) {
92   DeviceNameUtils::ParsedName parsed;
93   if (!DeviceNameUtils::ParseFullName(device, &parsed)) {
94     return errors::Internal("Malformed assigned device '", device, "'");
95   }
96   *device_type = DeviceType(parsed.type);
97   return Status::OK();
98 }
99 
PickDeviceForXlaImpl(const jit::DeviceInfoCache & device_info_cache,const jit::DeviceSet & devices,bool allow_mixing_unknown_and_cpu,bool failure_to_pick_is_error)100 xla::StatusOr<absl::optional<jit::DeviceId>> PickDeviceForXlaImpl(
101     const jit::DeviceInfoCache& device_info_cache,
102     const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu,
103     bool failure_to_pick_is_error) {
104 #define FAILED_TO_PICK_DEVICE(failing_status) \
105   do {                                        \
106     if (failure_to_pick_is_error) {           \
107       return failing_status;                  \
108     } else {                                  \
109       return {absl::nullopt};                 \
110     }                                         \
111   } while (false)
112 
113   absl::optional<jit::DeviceId> maybe_gpu_device;
114   absl::optional<jit::DeviceId> maybe_cpu_device;
115   absl::optional<jit::DeviceId> maybe_unknown_device;
116 
117   bool multiple_cpu_devices = false;
118   bool multiple_gpu_devices = false;
119   bool multiple_unknown_devices = false;
120 
121   // Returns 'true' if d0 and d1 are conflicting devices. If they are
122   // compatible, update d1 with a more specific one.
123   // TODO(sanjoy): Cache DeviceNameUtils::ParsedName inside device_info_cache.
124   const auto is_multiple_devices =
125       [&](const jit::DeviceId& d0, absl::optional<jit::DeviceId>* d1) -> bool {
126     const absl::string_view name0 = device_info_cache.GetNameFor(d0);
127     const absl::string_view name1 = device_info_cache.GetNameFor(d1->value());
128 
129     DeviceNameUtils::ParsedName parsed0, parsed1;
130     if (!DeviceNameUtils::ParseFullName(name0, &parsed0) ||
131         !DeviceNameUtils::ParseFullName(name1, &parsed1) ||
132         !DeviceNameUtils::AreCompatibleDevNames(parsed0, parsed1)) {
133       return true;
134     }
135 
136     if (DeviceNameUtils::IsSpecification(parsed0, parsed1)) {
137       return false;
138     }
139 
140     if (DeviceNameUtils::IsSpecification(parsed1, parsed0)) {
141       *d1 = d0;
142       return false;
143     }
144 
145     return true;
146   };
147 
148   devices.ForEach([&](jit::DeviceId device) {
149     if (device_info_cache.IsGpu(device)) {
150       if (maybe_gpu_device) {
151         multiple_gpu_devices = is_multiple_devices(device, &maybe_gpu_device);
152         if (multiple_gpu_devices) return false;
153       } else {
154         maybe_gpu_device = device;
155       }
156     } else if (device_info_cache.IsCpu(device)) {
157       if (maybe_cpu_device) {
158         multiple_cpu_devices = is_multiple_devices(device, &maybe_cpu_device);
159         if (multiple_cpu_devices) return false;
160       } else {
161         maybe_cpu_device = device;
162       }
163     } else {
164       if (maybe_unknown_device) {
165         multiple_unknown_devices = true;
166         return false;
167       }
168       maybe_unknown_device = device;
169     }
170 
171     return true;
172   });
173 
174   if (multiple_cpu_devices) {
175     FAILED_TO_PICK_DEVICE(errors::Internal(
176         "Multiple CPU devices ", device_info_cache.DebugString(devices)));
177   }
178 
179   if (multiple_gpu_devices) {
180     FAILED_TO_PICK_DEVICE(errors::Internal(
181         "Multiple GPU devices ", device_info_cache.DebugString(devices)));
182   }
183 
184   if (multiple_unknown_devices) {
185     FAILED_TO_PICK_DEVICE(errors::Internal(
186         "Multiple unknown devices ", device_info_cache.DebugString(devices)));
187   }
188 
189   if (maybe_unknown_device && maybe_gpu_device) {
190     FAILED_TO_PICK_DEVICE(errors::Internal(
191         "Found both unknown and GPU devices: ",
192         device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
193         device_info_cache.GetNameFor(*maybe_gpu_device)));
194   }
195 
196   if (!allow_mixing_unknown_and_cpu) {
197     if (maybe_unknown_device && maybe_cpu_device) {
198       FAILED_TO_PICK_DEVICE(errors::Internal(
199           "Found both unknown and CPU devices: ",
200           device_info_cache.GetNameFor(*maybe_unknown_device), ", ",
201           device_info_cache.GetNameFor(*maybe_cpu_device)));
202     }
203   }
204 
205   if (maybe_gpu_device) {
206     return {*maybe_gpu_device};
207   } else if (maybe_unknown_device) {
208     return {*maybe_unknown_device};
209   } else if (maybe_cpu_device) {
210     return {*maybe_cpu_device};
211   }
212 
213   FAILED_TO_PICK_DEVICE(errors::Internal("Empty device set!"));
214 
215 #undef FAILED_TO_PICK_DEVICE
216 }
217 
PickDeviceForXla(const jit::DeviceInfoCache & device_info_cache,const jit::DeviceSet & devices,bool allow_mixing_unknown_and_cpu)218 xla::StatusOr<jit::DeviceId> PickDeviceForXla(
219     const jit::DeviceInfoCache& device_info_cache,
220     const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
221   TF_ASSIGN_OR_RETURN(absl::optional<jit::DeviceId> device_id,
222                       PickDeviceForXlaImpl(device_info_cache, devices,
223                                            allow_mixing_unknown_and_cpu,
224                                            /*failure_to_pick_is_error=*/true));
225   return *device_id;
226 }
227 
MaybePickDeviceForXla(const jit::DeviceInfoCache & device_info_cache,const jit::DeviceSet & devices,bool allow_mixing_unknown_and_cpu)228 xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
229     const jit::DeviceInfoCache& device_info_cache,
230     const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu) {
231   return PickDeviceForXlaImpl(device_info_cache, devices,
232                               allow_mixing_unknown_and_cpu,
233                               /*failure_to_pick_is_error=*/false);
234 }
235 }  // namespace tensorflow
236