/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_

#include <functional>
#include <memory>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {
namespace jit {
class DeviceInfoCache;
class DeviceSet;

// Instances of DeviceId represent TensorFlow devices as integers.
//
// This helps avoid having to manipulate device names as strings when
// auto-clustering.
39 class DeviceId { 40 public: 41 DeviceId(DeviceId&&) = default; 42 DeviceId(const DeviceId&) = default; 43 DeviceId& operator=(const DeviceId&) = default; 44 45 bool operator==(const DeviceId& other) const { return id() == other.id(); } 46 bool operator!=(const DeviceId& other) const { return !(*this == other); } 47 48 private: 49 int id_; 50 DeviceId(int id)51 explicit DeviceId(int id) : id_(id) {} 52 id()53 int id() const { return id_; } 54 55 friend class DeviceInfoCache; 56 friend class DeviceSet; 57 }; 58 59 // A set of DeviceIds, represented as a bitmap. 60 class DeviceSet { 61 public: 62 void Insert(DeviceId device_id); 63 void UnionWith(const DeviceSet& other); 64 bool IsEmpty() const; 65 66 // Calls `func` on each DeviceId in the set. Stops iterating early if `func` 67 // return false. 68 // 69 // TODO(sanjoy): Change this to take a typed std::function if that's 70 // performance neutral. 71 template <typename FnTy> ForEach(FnTy func)72 void ForEach(FnTy func) const { 73 // This is really a poor man's iterator, we should consider writing a proper 74 // iterator if this ends up being used widely. 75 for (int word_index = 0, end = storage_.size(); word_index < end; 76 word_index++) { 77 uint64 word = storage_[word_index]; 78 while (word != 0) { 79 uint64 only_lowest_bit_set = word & -word; 80 // The number of trailing zeros in a non-zero word is the index of the 81 // least significant 1. 
82 int bit_index = ctz_uint64(word); 83 if (!func(DeviceId(word_index * kWordSize + bit_index))) { 84 return; 85 } 86 word ^= only_lowest_bit_set; 87 } 88 } 89 } 90 91 private: ctz_uint64(uint64 x)92 static int ctz_uint64(uint64 x) { 93 DCHECK_NE(x, 0); 94 #ifdef __GNUC__ 95 return __builtin_ctzl(x); 96 #else 97 int result = 0u; 98 while ((x & 1u) == 0u) { 99 x >>= 1; 100 ++result; 101 } 102 return result; 103 #endif 104 } 105 106 absl::InlinedVector<uint64, 1> storage_; 107 108 const int kWordSize = 64; 109 }; 110 111 // Caches some miscellaneous information about TF devices. Thread compatible. 112 class DeviceInfoCache { 113 public: IsGpu(DeviceId device)114 bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; } IsCpu(DeviceId device)115 bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; } 116 GetNameFor(DeviceId device)117 absl::string_view GetNameFor(DeviceId device) const { 118 return names_[device.id()]; 119 } 120 121 xla::StatusOr<DeviceId> GetIdFor(absl::string_view name); 122 123 using DeviceRegistration = const XlaOpRegistry::DeviceRegistration; 124 GetCompilationDevice(DeviceId device)125 DeviceRegistration* GetCompilationDevice(DeviceId device) const { 126 return id_to_compilation_device_[device.id()]; 127 } 128 GetCompilationDevice(absl::string_view name)129 xla::StatusOr<DeviceRegistration*> GetCompilationDevice( 130 absl::string_view name) { 131 TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name)); 132 return GetCompilationDevice(device_id); 133 } 134 GetDeviceTypeFor(DeviceId device)135 const DeviceType& GetDeviceTypeFor(DeviceId device) const { 136 return *id_to_device_type_[device.id()]; 137 } 138 139 using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>; 140 GetDeviceTypeFor(absl::string_view device_name)141 xla::StatusOr<DeviceTypeConstRef> GetDeviceTypeFor( 142 absl::string_view device_name) { 143 TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name)); 144 return 
std::cref(*id_to_device_type_[device_id.id()]); 145 } 146 147 string DebugString(const DeviceSet& device_set) const; 148 149 private: 150 absl::flat_hash_map<string, DeviceId> name_to_id_; 151 152 // These fields are populated for a device in GetIdFor, *before* we give out a 153 // DeviceId. 154 std::vector<const XlaOpRegistry::DeviceRegistration*> 155 id_to_compilation_device_; 156 std::vector<std::unique_ptr<DeviceType>> id_to_device_type_; 157 std::vector<string> names_; 158 std::vector<bool> is_cpu_; 159 std::vector<bool> is_gpu_; 160 }; 161 162 } // namespace jit 163 164 // Returns the DeviceType corresponding to 'device'. 165 Status DeviceNameToDeviceType(const string& device, DeviceType* device_type); 166 167 // Picks the device for which XLA should compile a cluster that contains 168 // operations placed in devices in `devices`. For instance a cluster that 169 // contains operations solely placed on the CPU will be compiled into a CPU 170 // executable by XLA, whereas a cluster that contains operations placed on the 171 // CPU and also operations placed on the GPU will be compiled into a GPU 172 // executable. 173 // 174 // Returns a non-OK Status if no unambiguous choice of device exists. 175 // 176 // We choose the device using the following rules: 177 // 178 // - It is an error for `device_names` to contain more than one device of the 179 // same type. 180 // - GPU is preferred over CPU. 181 // - If `allow_mixing_unknown_and_cpu` is true then unknown devices are 182 // preferred over CPU. 183 // - XLA devices count as "unrecognized devices". 184 // 185 // This set of rules above implicitly assume that XLA:GPU can compile all 186 // operations in the cluster that XLA:CPU can compile, and if 187 // `allow_mixing_unknown_and_cpu` then the unrecognized device can also compile 188 // all operations in the cluster that XLA:CPU can compile. 
189 // 190 // We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of 191 // the following things: 192 // 193 // - Let MarkForCompilationPass not inject CPU-placed operations into clusters 194 // that will run on unknown devices (because the unknown XLA backend may not 195 // support every operation supported by CPU). 196 // - Let BuildXlaOpsPass successfully infer a compilation device for a cluster 197 // that contains nodes placed on both the CPU and on unknown devices. In this 198 // case it is the responsibility of the optimization pass that injected the 199 // CPU nodes into the cluster to ensure that these nodes can be compiled by 200 // the unknown XLA backend. 201 xla::StatusOr<jit::DeviceId> PickDeviceForXla( 202 const jit::DeviceInfoCache& device_info_cache, 203 const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu); 204 205 // This is like `PickDeviceForXla` except that it returns nullopt (instead of a 206 // non-OK Status) if no unambiguous choice of device exists. 207 // 208 // We return a failing Status for errors unrelated to the device choice 209 // algorithm itself. 210 xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla( 211 const jit::DeviceInfoCache& device_info_cache, 212 const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu); 213 } // namespace tensorflow 214 215 #endif // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_ 216