/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
// operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.

#include <set>

#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

// Returns a set containing the device ids contained in visible_device_list,
// or nullopt if the list is empty. Returns an error if the configuration
// string is malformed.
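// For example, a visible_device_list of "0,2" parses to the set {0, 2}; an
// empty string yields nullopt, meaning "use all visible devices".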
static xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
    const string& visible_device_list) {
  std::set<int> gpu_ids;
  if (visible_device_list.empty()) {
    return {{absl::nullopt}};
  }
  const std::vector<string> visible_devices =
      absl::StrSplit(visible_device_list, ',');
  for (const string& platform_gpu_id_str : visible_devices) {
    int32 platform_gpu_id;
    if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) {
      return errors::InvalidArgument(
          "Could not parse entry in 'visible_device_list': '",
          platform_gpu_id_str,
          "'. visible_device_list = ", visible_device_list);
    }
    gpu_ids.insert(platform_gpu_id);
  }
  return {{gpu_ids}};
}

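// Device factory for XLA_GPU devices: enumerates the GPUs visible through the
// StreamExecutor platform and creates one XlaDevice per allowed GPU ordinal.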
class XlaGpuDeviceFactory : public DeviceFactory {
 public:
  Status ListPhysicalDevices(std::vector<string>* devices) override;
  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override;
};

Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
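  // XLA devices are opt-in; with the default flags this factory reports no
  // devices. (The flag is typically supplied through the TF_XLA_FLAGS
  // environment variable, e.g. TF_XLA_FLAGS=--tf_xla_enable_xla_devices.)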
  XlaDeviceFlags* flags = GetXlaDeviceFlags();
  if (!flags->tf_xla_enable_xla_devices) {
    VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
    return Status::OK();
  }

  auto platform =
      se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
  if (!platform.ok()) {
    // Treat failures as non-fatal; there might not be a GPU in the machine.
    VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
    return Status::OK();
  }

  int device_count = platform.ValueOrDie()->VisibleDeviceCount();
  if (device_count <= 0) {
    return Status::OK();
  }

  for (int i = 0; i < device_count; ++i) {
    devices->push_back(
        absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i));
  }

  return Status::OK();
}

Status XlaGpuDeviceFactory::CreateDevices(
    const SessionOptions& session_options, const string& name_prefix,
    std::vector<std::unique_ptr<Device>>* devices) {
  XlaDeviceFlags* flags = GetXlaDeviceFlags();
  if (!flags->tf_xla_enable_xla_devices) {
    VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
    return Status::OK();
  }

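  // Describe to the XLA op registry how ops placed on XLA_GPU are compiled:
  // they go through the DEVICE_GPU_XLA_JIT compilation device, with
  // auto-clustering always enabled on this device.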
  XlaOpRegistry::DeviceRegistration registration;
  registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
  registration.autoclustering_policy =
      XlaOpRegistry::AutoclusteringPolicy::kAlways;
  registration.cluster_resource_variable_ops_unsafely = true;
  registration.cluster_stack_ops = false;
  registration.cluster_tensor_array_ops = true;
  registration.cluster_stateful_rng_ops = true;
  registration.cluster_control_trigger = true;
  registration.elide_assert_and_checknumerics = true;
  registration.cluster_variant_ops = true;
  registration.cluster_slow_ops = true;
  registration.cluster_inaccurate_ops = true;
  XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);

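  // Kernel registration is process-wide; the static local ensures it runs at
  // most once even if CreateDevices is invoked multiple times.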
  static XlaDeviceOpRegistrations* registrations =
      RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
  (void)registrations;

  auto platform =
      se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
  if (!platform.ok()) {
    // Treat failures as non-fatal; there might not be a GPU in the machine.
    VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
    return Status::OK();
  }

  auto iter = session_options.config.device_count().find("GPU");
  if (iter != session_options.config.device_count().end() &&
      iter->second == 0) {
    // GPUs were explicitly disabled via device_count, so create no XLA_GPU
    // devices either.
    return Status::OK();
  }

  string allowed_gpus =
      session_options.config.gpu_options().visible_device_list();
  absl::optional<std::set<int>> gpu_ids =
      ParseVisibleDeviceList(allowed_gpus).ValueOrDie();
  if (!gpu_ids) {
    gpu_ids.emplace();
    // visible_device_list was empty, so default to all visible devices.
    for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) {
      gpu_ids->insert(i);
    }
  }
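  // Create one XlaDevice per allowed GPU ordinal; devices that fail to
  // initialize their GPU device info are skipped rather than aborting device
  // creation.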
  for (int i : *gpu_ids) {
    XlaDevice::Options options;
    options.platform = platform.ValueOrDie();
    options.device_name_prefix = name_prefix;
    options.device_name = DEVICE_XLA_GPU;
    options.device_ordinal = i;
    options.compilation_device_name = DEVICE_GPU_XLA_JIT;
    options.use_multiple_streams = true;
    options.allowed_devices = gpu_ids;
    auto device = absl::make_unique<XlaDevice>(session_options, options);

    Status status = device->UseGpuDeviceInfo();
    if (!status.ok()) {
      LOG(INFO) << "Ignoring visible " << DEVICE_XLA_GPU
                << " device. Device number is " << i << ", reason: " << status;
      continue;
    }

    devices->push_back(std::move(device));
  }
  return Status::OK();
}

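// Register the factory under the XLA_GPU device type so the TensorFlow
// runtime can enumerate and instantiate these devices.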
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);

// Kernel registrations

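// The set of dtypes for which the XLA_GPU device registers kernels.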
constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
     DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);

REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);

}  // namespace tensorflow