1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/common_runtime/device_set.h"
17 
#include <algorithm>
#include <set>
#include <utility>
#include <vector>
21 
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_factory.h"
24 #include "tensorflow/core/lib/core/stringpiece.h"
25 #include "tensorflow/core/lib/gtl/map_util.h"
26 
27 namespace tensorflow {
28 
DeviceSet()29 DeviceSet::DeviceSet() {}
30 
~DeviceSet()31 DeviceSet::~DeviceSet() {}
32 
AddDevice(Device * device)33 void DeviceSet::AddDevice(Device* device) {
34   mutex_lock l(devices_mu_);
35   devices_.push_back(device);
36   prioritized_devices_.clear();
37   prioritized_device_types_.clear();
38   for (const string& name :
39        DeviceNameUtils::GetNamesForDeviceMappings(device->parsed_name())) {
40     device_by_name_.insert({name, device});
41   }
42 }
43 
FindMatchingDevices(const DeviceNameUtils::ParsedName & spec,std::vector<Device * > * devices) const44 void DeviceSet::FindMatchingDevices(const DeviceNameUtils::ParsedName& spec,
45                                     std::vector<Device*>* devices) const {
46   // TODO(jeff): If we are going to repeatedly lookup the set of devices
47   // for the same spec, maybe we should have a cache of some sort
48   devices->clear();
49   for (Device* d : devices_) {
50     if (DeviceNameUtils::IsCompleteSpecification(spec, d->parsed_name())) {
51       devices->push_back(d);
52     }
53   }
54 }
55 
FindDeviceByName(const string & name) const56 Device* DeviceSet::FindDeviceByName(const string& name) const {
57   return gtl::FindPtrOrNull(device_by_name_, name);
58 }
59 
60 // static
DeviceTypeOrder(const DeviceType & d)61 int DeviceSet::DeviceTypeOrder(const DeviceType& d) {
62   return DeviceFactory::DevicePriority(d.type_string());
63 }
64 
DeviceTypeComparator(const DeviceType & a,const DeviceType & b)65 static bool DeviceTypeComparator(const DeviceType& a, const DeviceType& b) {
66   // First sort by prioritized device type (higher is preferred) and
67   // then by device name (lexicographically).
68   auto a_priority = DeviceSet::DeviceTypeOrder(a);
69   auto b_priority = DeviceSet::DeviceTypeOrder(b);
70   if (a_priority != b_priority) {
71     return a_priority > b_priority;
72   }
73 
74   return StringPiece(a.type()) < StringPiece(b.type());
75 }
76 
PrioritizedDeviceTypeList() const77 std::vector<DeviceType> DeviceSet::PrioritizedDeviceTypeList() const {
78   std::vector<DeviceType> result;
79   std::set<string> seen;
80   for (Device* d : devices_) {
81     const auto& t = d->device_type();
82     if (seen.insert(t).second) {
83       result.emplace_back(t);
84     }
85   }
86   std::sort(result.begin(), result.end(), DeviceTypeComparator);
87   return result;
88 }
89 
SortPrioritizedDeviceTypeVector(PrioritizedDeviceTypeVector * vector)90 void DeviceSet::SortPrioritizedDeviceTypeVector(
91     PrioritizedDeviceTypeVector* vector) {
92   if (vector == nullptr) return;
93 
94   auto device_sort = [](const PrioritizedDeviceTypeVector::value_type& a,
95                         const PrioritizedDeviceTypeVector::value_type& b) {
96     // First look at set priorities.
97     if (a.second != b.second) {
98       return a.second > b.second;
99     }
100     // Then fallback to default priorities.
101     return DeviceTypeComparator(a.first, b.first);
102   };
103 
104   std::sort(vector->begin(), vector->end(), device_sort);
105 }
106 
SortPrioritizedDeviceVector(PrioritizedDeviceVector * vector)107 void DeviceSet::SortPrioritizedDeviceVector(PrioritizedDeviceVector* vector) {
108   auto device_sort = [](const std::pair<Device*, int32>& a,
109                         const std::pair<Device*, int32>& b) {
110     if (a.second != b.second) {
111       return a.second > b.second;
112     }
113 
114     const string& a_type_name = a.first->device_type();
115     const string& b_type_name = b.first->device_type();
116     if (a_type_name != b_type_name) {
117       auto a_priority = DeviceFactory::DevicePriority(a_type_name);
118       auto b_priority = DeviceFactory::DevicePriority(b_type_name);
119       if (a_priority != b_priority) {
120         return a_priority > b_priority;
121       }
122     }
123 
124     if (a.first->IsLocal() != b.first->IsLocal()) {
125       return a.first->IsLocal();
126     }
127 
128     return StringPiece(a.first->name()) < StringPiece(b.first->name());
129   };
130   std::sort(vector->begin(), vector->end(), device_sort);
131 }
132 
133 namespace {
134 
UpdatePrioritizedVectors(const std::vector<Device * > & devices,PrioritizedDeviceVector * prioritized_devices,PrioritizedDeviceTypeVector * prioritized_device_types)135 void UpdatePrioritizedVectors(
136     const std::vector<Device*>& devices,
137     PrioritizedDeviceVector* prioritized_devices,
138     PrioritizedDeviceTypeVector* prioritized_device_types) {
139   if (prioritized_devices->size() != devices.size()) {
140     for (Device* d : devices) {
141       prioritized_devices->emplace_back(
142           d, DeviceSet::DeviceTypeOrder(DeviceType(d->device_type())));
143     }
144     DeviceSet::SortPrioritizedDeviceVector(prioritized_devices);
145   }
146 
147   if (prioritized_device_types != nullptr &&
148       prioritized_device_types->size() != devices.size()) {
149     std::set<DeviceType> seen;
150     for (const std::pair<Device*, int32>& p : *prioritized_devices) {
151       DeviceType t(p.first->device_type());
152       if (seen.insert(t).second) {
153         prioritized_device_types->emplace_back(t, p.second);
154       }
155     }
156   }
157 }
158 
159 }  // namespace
160 
prioritized_devices() const161 const PrioritizedDeviceVector& DeviceSet::prioritized_devices() const {
162   mutex_lock l(devices_mu_);
163   UpdatePrioritizedVectors(devices_, &prioritized_devices_,
164                            /* prioritized_device_types */ nullptr);
165   return prioritized_devices_;
166 }
167 
prioritized_device_types() const168 const PrioritizedDeviceTypeVector& DeviceSet::prioritized_device_types() const {
169   mutex_lock l(devices_mu_);
170   UpdatePrioritizedVectors(devices_, &prioritized_devices_,
171                            &prioritized_device_types_);
172   return prioritized_device_types_;
173 }
174 
175 }  // namespace tensorflow
176