1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // Registers the XLA_GPU device, which is an XlaDevice instantiation that runs
17 // operators using XLA via the XLA "CUDA" or "ROCM" (GPU) backend.
18
#include <set>
#include <utility>

#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_split.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_device_ops.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/status.h"
32
33 namespace tensorflow {
34
// Returns the set of device ids listed in visible_device_list, or nullopt
// when the list is empty (meaning "all devices are visible"). Returns an
// InvalidArgument error if the configuration string is malformed.
ParseVisibleDeviceList(const string & visible_device_list)38 static xla::StatusOr<absl::optional<std::set<int>>> ParseVisibleDeviceList(
39 const string& visible_device_list) {
40 std::set<int> gpu_ids;
41 if (visible_device_list.empty()) {
42 return {{absl::nullopt}};
43 }
44 const std::vector<string> visible_devices =
45 absl::StrSplit(visible_device_list, ',');
46 for (const string& platform_gpu_id_str : visible_devices) {
47 int32 platform_gpu_id;
48 if (!absl::SimpleAtoi(platform_gpu_id_str, &platform_gpu_id)) {
49 return errors::InvalidArgument(
50 "Could not parse entry in 'visible_device_list': '",
51 platform_gpu_id_str,
52 "'. visible_device_list = ", visible_device_list);
53 }
54 gpu_ids.insert(platform_gpu_id);
55 }
56 return {{gpu_ids}};
57 }
58
// Device factory for the XLA_GPU device. Both entry points are no-ops
// (returning OK without registering anything) unless the
// tf_xla_enable_xla_devices flag is set.
class XlaGpuDeviceFactory : public DeviceFactory {
 public:
  // Appends one "/physical_device:XLA_GPU:<n>" name per visible GPU.
  Status ListPhysicalDevices(std::vector<string>* devices) override;
  // Creates one XlaDevice per GPU permitted by the session's
  // visible_device_list; also registers the XLA compilation device and
  // device kernels as a side effect.
  Status CreateDevices(const SessionOptions& options, const string& name_prefix,
                       std::vector<std::unique_ptr<Device>>* devices) override;
};
65
ListPhysicalDevices(std::vector<string> * devices)66 Status XlaGpuDeviceFactory::ListPhysicalDevices(std::vector<string>* devices) {
67 XlaDeviceFlags* flags = GetXlaDeviceFlags();
68 if (!flags->tf_xla_enable_xla_devices) {
69 VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
70 return Status::OK();
71 }
72
73 auto platform =
74 se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
75 if (!platform.ok()) {
76 // Treat failures as non-fatal; there might not be a GPU in the machine.
77 VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
78 return Status::OK();
79 }
80
81 int device_count = platform.ValueOrDie()->VisibleDeviceCount();
82 if (device_count <= 0) {
83 return Status::OK();
84 }
85
86 for (int i = 0; i < device_count; ++i) {
87 devices->push_back(
88 absl::StrCat("/physical_device:", DEVICE_XLA_GPU, ":", i));
89 }
90
91 return Status::OK();
92 }
93
CreateDevices(const SessionOptions & session_options,const string & name_prefix,std::vector<std::unique_ptr<Device>> * devices)94 Status XlaGpuDeviceFactory::CreateDevices(
95 const SessionOptions& session_options, const string& name_prefix,
96 std::vector<std::unique_ptr<Device>>* devices) {
97 XlaDeviceFlags* flags = GetXlaDeviceFlags();
98 if (!flags->tf_xla_enable_xla_devices) {
99 VLOG(1) << "Not creating XLA devices, tf_xla_enable_xla_devices not set";
100 return Status::OK();
101 }
102
103 XlaOpRegistry::DeviceRegistration registration;
104 registration.compilation_device_name = DEVICE_GPU_XLA_JIT;
105 registration.autoclustering_policy =
106 XlaOpRegistry::AutoclusteringPolicy::kAlways;
107 registration.cluster_resource_variable_ops_unsafely = true;
108 registration.cluster_stack_ops = false;
109 registration.cluster_tensor_array_ops = true;
110 registration.cluster_stateful_rng_ops = true;
111 registration.cluster_control_trigger = true;
112 registration.elide_assert_and_checknumerics = true;
113 registration.cluster_variant_ops = true;
114 registration.cluster_slow_ops = true;
115 registration.cluster_inaccurate_ops = true;
116 XlaOpRegistry::RegisterCompilationDevice(DEVICE_XLA_GPU, registration);
117
118 static XlaDeviceOpRegistrations* registrations =
119 RegisterXlaDeviceKernels(DEVICE_XLA_GPU, DEVICE_GPU_XLA_JIT);
120 (void)registrations;
121
122 auto platform =
123 se::MultiPlatformManager::PlatformWithName(tensorflow::GpuPlatformName());
124 if (!platform.ok()) {
125 // Treat failures as non-fatal; there might not be a GPU in the machine.
126 VLOG(1) << "Failed to create XLA_GPU device: " << platform.status();
127 return Status::OK();
128 }
129
130 auto iter = session_options.config.device_count().find("GPU");
131 if (iter != session_options.config.device_count().end() &&
132 iter->second == 0) {
133 // Device count for GPU is 0.
134 return Status::OK();
135 }
136
137 string allowed_gpus =
138 session_options.config.gpu_options().visible_device_list();
139 absl::optional<std::set<int>> gpu_ids =
140 ParseVisibleDeviceList(allowed_gpus).ValueOrDie();
141 if (!gpu_ids) {
142 gpu_ids.emplace();
143 // Fill the gpu_ids set with all devices if config string is empty.
144 for (int i = 0; i < platform.ValueOrDie()->VisibleDeviceCount(); ++i) {
145 gpu_ids->insert(i);
146 }
147 }
148 for (int i : *gpu_ids) {
149 XlaDevice::Options options;
150 options.platform = platform.ValueOrDie();
151 options.device_name_prefix = name_prefix;
152 options.device_name = DEVICE_XLA_GPU;
153 options.device_ordinal = i;
154 options.compilation_device_name = DEVICE_GPU_XLA_JIT;
155 options.use_multiple_streams = true;
156 options.allowed_devices = gpu_ids;
157 auto device = absl::make_unique<XlaDevice>(session_options, options);
158
159 Status status = device->UseGpuDeviceInfo();
160 if (!status.ok()) {
161 LOG(INFO) << "Ignoring visible " << DEVICE_GPU_XLA_JIT
162 << " device. Device number is " << i << ", reason: " << status;
163 continue;
164 }
165
166 devices->push_back(std::move(device));
167 }
168 return Status::OK();
169 }
170
// Makes the XLA_GPU device known to TensorFlow's device layer.
REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);

// Kernel registrations

// The TensorFlow dtypes accepted by the XLA_GPU launch/compile/run kernels
// and the generic XLA device kernels registered below.
constexpr std::array<DataType, 16> kAllXlaGpuTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_INT8, DT_QINT8, DT_INT16, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
     DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_COMPILE_KERNEL(DEVICE_XLA_GPU, XlaCompileOp, kAllXlaGpuTypes);
REGISTER_XLA_RUN_KERNEL(DEVICE_XLA_GPU, XlaRunOp, kAllXlaGpuTypes);

REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);
185
186 } // namespace tensorflow
187