1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
17 #define TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
18 
19 #include <functional>
20 #include <memory>
21 
22 #include "absl/container/flat_hash_map.h"
23 #include "absl/strings/string_view.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
26 #include "tensorflow/compiler/xla/status_macros.h"
27 #include "tensorflow/compiler/xla/statusor.h"
28 #include "tensorflow/core/framework/types.h"
29 
30 namespace tensorflow {
31 namespace jit {
// Forward declarations.  Both classes are defined later in this header, but
// DeviceId (declared first) names them as friends.
class DeviceInfoCache;
class DeviceSet;
34 
// A DeviceId is a small integer handle that stands in for a TensorFlow device.
//
// Working with these handles lets auto-clustering avoid manipulating device
// names as strings.
//
// Only the friend classes DeviceInfoCache and DeviceSet can create a DeviceId
// from an integer or read the integer back; all other code can only copy and
// compare DeviceIds.
class DeviceId {
 public:
  DeviceId(const DeviceId&) = default;
  DeviceId(DeviceId&&) = default;
  DeviceId& operator=(const DeviceId&) = default;

  // Two DeviceIds are equal exactly when they wrap the same integer.
  bool operator==(const DeviceId& other) const { return id_ == other.id_; }
  bool operator!=(const DeviceId& other) const { return id_ != other.id_; }

 private:
  friend class DeviceInfoCache;
  friend class DeviceSet;

  // Wraps the raw integer `id`.  Private so that only the friends above can
  // create new ids.
  explicit DeviceId(int id) : id_(id) {}

  // Returns the wrapped integer.
  int id() const { return id_; }

  int id_;
};
58 
59 // A set of DeviceIds, represented as a bitmap.
60 class DeviceSet {
61  public:
62   void Insert(DeviceId device_id);
63   void UnionWith(const DeviceSet& other);
64   bool IsEmpty() const;
65 
66   // Calls `func` on each DeviceId in the set.  Stops iterating early if `func`
67   // return false.
68   //
69   // TODO(sanjoy): Change this to take a typed std::function if that's
70   // performance neutral.
71   template <typename FnTy>
ForEach(FnTy func)72   void ForEach(FnTy func) const {
73     // This is really a poor man's iterator, we should consider writing a proper
74     // iterator if this ends up being used widely.
75     for (int word_index = 0, end = storage_.size(); word_index < end;
76          word_index++) {
77       uint64 word = storage_[word_index];
78       while (word != 0) {
79         uint64 only_lowest_bit_set = word & -word;
80         // The number of trailing zeros in a non-zero word is the index of the
81         // least significant 1.
82         int bit_index = ctz_uint64(word);
83         if (!func(DeviceId(word_index * kWordSize + bit_index))) {
84           return;
85         }
86         word ^= only_lowest_bit_set;
87       }
88     }
89   }
90 
91  private:
ctz_uint64(uint64 x)92   static int ctz_uint64(uint64 x) {
93     DCHECK_NE(x, 0);
94 #ifdef __GNUC__
95     return __builtin_ctzl(x);
96 #else
97     int result = 0u;
98     while ((x & 1u) == 0u) {
99       x >>= 1;
100       ++result;
101     }
102     return result;
103 #endif
104   }
105 
106   absl::InlinedVector<uint64, 1> storage_;
107 
108   const int kWordSize = 64;
109 };
110 
111 // Caches some miscellaneous information about TF devices.  Thread compatible.
112 class DeviceInfoCache {
113  public:
IsGpu(DeviceId device)114   bool IsGpu(DeviceId device) const { return is_gpu_[device.id()]; }
IsCpu(DeviceId device)115   bool IsCpu(DeviceId device) const { return is_cpu_[device.id()]; }
116 
GetNameFor(DeviceId device)117   absl::string_view GetNameFor(DeviceId device) const {
118     return names_[device.id()];
119   }
120 
121   xla::StatusOr<DeviceId> GetIdFor(absl::string_view name);
122 
123   using DeviceRegistration = const XlaOpRegistry::DeviceRegistration;
124 
GetCompilationDevice(DeviceId device)125   DeviceRegistration* GetCompilationDevice(DeviceId device) const {
126     return id_to_compilation_device_[device.id()];
127   }
128 
GetCompilationDevice(absl::string_view name)129   xla::StatusOr<DeviceRegistration*> GetCompilationDevice(
130       absl::string_view name) {
131     TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(name));
132     return GetCompilationDevice(device_id);
133   }
134 
GetDeviceTypeFor(DeviceId device)135   const DeviceType& GetDeviceTypeFor(DeviceId device) const {
136     return *id_to_device_type_[device.id()];
137   }
138 
139   using DeviceTypeConstRef = std::reference_wrapper<const DeviceType>;
140 
GetDeviceTypeFor(absl::string_view device_name)141   xla::StatusOr<DeviceTypeConstRef> GetDeviceTypeFor(
142       absl::string_view device_name) {
143     TF_ASSIGN_OR_RETURN(DeviceId device_id, GetIdFor(device_name));
144     return std::cref(*id_to_device_type_[device_id.id()]);
145   }
146 
147   string DebugString(const DeviceSet& device_set) const;
148 
149  private:
150   absl::flat_hash_map<string, DeviceId> name_to_id_;
151 
152   // These fields are populated for a device in GetIdFor, *before* we give out a
153   // DeviceId.
154   std::vector<const XlaOpRegistry::DeviceRegistration*>
155       id_to_compilation_device_;
156   std::vector<std::unique_ptr<DeviceType>> id_to_device_type_;
157   std::vector<string> names_;
158   std::vector<bool> is_cpu_;
159   std::vector<bool> is_gpu_;
160 };
161 
162 }  // namespace jit
163 
// Computes the DeviceType corresponding to the device named `device`.
//
// The result is delivered through the `device_type` out-parameter rather than
// the return value; the returned Status reports success or failure.
// NOTE(review): presumably `device_type` must be non-null and is only
// meaningful when the returned Status is OK -- confirm against the definition.
Status DeviceNameToDeviceType(const string& device, DeviceType* device_type);
166 
// Picks the device for which XLA should compile a cluster that contains
// operations placed on the devices in `devices`.  For instance a cluster that
// contains operations solely placed on the CPU will be compiled into a CPU
// executable by XLA, whereas a cluster that contains operations placed on the
// CPU and also operations placed on the GPU will be compiled into a GPU
// executable.
//
// Returns a non-OK Status if no unambiguous choice of device exists.
//
// We choose the device using the following rules:
//
//  - It is an error for `devices` to contain more than one device of the
//    same type.
//  - GPU is preferred over CPU.
//  - If `allow_mixing_unknown_and_cpu` is true then unknown devices are
//    preferred over CPU.
//  - XLA devices count as "unrecognized devices".
//
// The rules above implicitly assume that XLA:GPU can compile all operations in
// the cluster that XLA:CPU can compile, and, if `allow_mixing_unknown_and_cpu`
// is true, that the unrecognized device can also compile all operations in the
// cluster that XLA:CPU can compile.
//
// We provide the `allow_mixing_unknown_and_cpu` knob so that we can do both of
// the following things:
//
// - Let MarkForCompilationPass not inject CPU-placed operations into clusters
//   that will run on unknown devices (because the unknown XLA backend may not
//   support every operation supported by CPU).
// - Let BuildXlaOpsPass successfully infer a compilation device for a cluster
//   that contains nodes placed on both the CPU and on unknown devices.  In this
//   case it is the responsibility of the optimization pass that injected the
//   CPU nodes into the cluster to ensure that these nodes can be compiled by
//   the unknown XLA backend.
//
// `device_info_cache` is the cache that `devices` is interpreted against.
xla::StatusOr<jit::DeviceId> PickDeviceForXla(
    const jit::DeviceInfoCache& device_info_cache,
    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
204 
// This is like `PickDeviceForXla` (same parameters) except that it returns
// nullopt (instead of a non-OK Status) if no unambiguous choice of device
// exists.
//
// We return a failing Status for errors unrelated to the device choice
// algorithm itself.  So the three possible outcomes are: an OK result holding a
// DeviceId, an OK result holding nullopt, or a non-OK Status.
xla::StatusOr<absl::optional<jit::DeviceId>> MaybePickDeviceForXla(
    const jit::DeviceInfoCache& device_info_cache,
    const jit::DeviceSet& devices, bool allow_mixing_unknown_and_cpu);
213 }  // namespace tensorflow
214 
215 #endif  // TENSORFLOW_COMPILER_JIT_DEVICE_INFO_CACHE_H_
216