/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This is a registration-oriented interface for multiple platforms. It will
// replace the MachineManager singleton interface, as MachineManager does not
// currently support simultaneous use of multiple platforms.
//
// Usage:
//
// In your BUILD rule, add a dependency on a platform plugin that you'd like
// to use, such as:
//
//   //third_party/tensorflow/stream_executor/cuda:cuda_platform
//   //third_party/tensorflow/stream_executor/opencl:opencl_platform
//
// This will register platform plugins that can be discovered via this
// interface. Sample API usage:
//
//   port::StatusOr<Platform*> platform_status =
//      se::MultiPlatformManager::PlatformWithName("OpenCL");
//   if (!platform_status.ok()) { ... }
//   Platform* platform = platform_status.ValueOrDie();
//   LOG(INFO) << platform->VisibleDeviceCount() << " devices visible";
//   if (platform->VisibleDeviceCount() <= 0) { return; }
//
//   for (int i = 0; i < platform->VisibleDeviceCount(); ++i) {
//     port::StatusOr<StreamExecutor*> executor_status =
//        platform->ExecutorForDevice(i);
//     if (!executor_status.ok()) {
//       LOG(INFO) << "could not retrieve executor for device ordinal " << i
//                 << ": " << executor_status.status();
//       continue;
//     }
//     LOG(INFO) << "found usable executor: " << executor_status.ValueOrDie();
//   }
//
// A few things to note:
//  - There is no standard formatting/practice for identifying the name of a
//    platform. Ideally, a platform will list its registered name in its header
//    or in other associated documentation.
//  - Platform name lookup is case-insensitive. "OpenCL" or "opencl" (or even
//    "OpEnCl") would work correctly in the above example.
//
// And similarly, for standard interfaces (BLAS, RNG, etc.)
you can add 57 // dependencies on support libraries, e.g.: 58 // 59 // //third_party/tensorflow/stream_executor/cuda:pluton_blas_plugin 60 // //third_party/tensorflow/stream_executor/cuda:cudnn_plugin 61 // //third_party/tensorflow/stream_executor/cuda:cublas_plugin 62 // //third_party/tensorflow/stream_executor/cuda:curand_plugin 63 64 #ifndef TENSORFLOW_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ 65 #define TENSORFLOW_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ 66 67 #include <functional> 68 #include <map> 69 #include <memory> 70 #include <vector> 71 72 #include "absl/strings/string_view.h" 73 #include "tensorflow/stream_executor/lib/initialize.h" 74 #include "tensorflow/stream_executor/lib/status.h" 75 #include "tensorflow/stream_executor/lib/statusor.h" 76 #include "tensorflow/stream_executor/platform.h" 77 #include "tensorflow/stream_executor/platform/port.h" 78 79 namespace stream_executor { 80 81 // Manages multiple platforms that may be present on the current machine. 82 class MultiPlatformManager { 83 public: 84 // Registers a platform object, returns an error status if the platform is 85 // already registered. The associated listener, if not null, will be used to 86 // trace events for ALL executors for that platform. 87 // Takes ownership of platform. 88 static port::Status RegisterPlatform(std::unique_ptr<Platform> platform); 89 90 // Retrieves the platform registered with the given platform name (e.g. 91 // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the 92 // Platform's Id() method). 93 // 94 // If the platform has not already been initialized, it will be initialized 95 // with a default set of parameters. 96 // 97 // If the requested platform is not registered, an error status is returned. 98 // Ownership of the platform is NOT transferred to the caller -- 99 // the MultiPlatformManager owns the platforms in a singleton-like fashion. 
100 static port::StatusOr<Platform*> PlatformWithName(absl::string_view target); 101 static port::StatusOr<Platform*> PlatformWithId(const Platform::Id& id); 102 103 // Retrieves the platform registered with the given platform name (e.g. 104 // "CUDA", "OpenCL", ...) or id (an opaque, comparable value provided by the 105 // Platform's Id() method). 106 // 107 // The platform will be initialized with the given options. If the platform 108 // was already initialized, an error will be returned. 109 // 110 // If the requested platform is not registered, an error status is returned. 111 // Ownership of the platform is NOT transferred to the caller -- 112 // the MultiPlatformManager owns the platforms in a singleton-like fashion. 113 static port::StatusOr<Platform*> InitializePlatformWithName( 114 absl::string_view target, const std::map<string, string>& options); 115 116 static port::StatusOr<Platform*> InitializePlatformWithId( 117 const Platform::Id& id, const std::map<string, string>& options); 118 119 static std::vector<Platform*> AllPlatforms(); 120 121 // Although the MultiPlatformManager "owns" its platforms, it holds them as 122 // undecorated pointers to prevent races during program exit (between this 123 // object's data and the underlying platforms (e.g., CUDA, OpenCL). 124 // Because certain platforms have unpredictable deinitialization 125 // times/sequences, it is not possible to strucure a safe deinitialization 126 // sequence. Thus, we intentionally "leak" allocated platforms to defer 127 // cleanup to the OS. This should be acceptable, as these are one-time 128 // allocations per program invocation. 129 // The MultiPlatformManager should be considered the owner 130 // of any platforms registered with it, and leak checking should be disabled 131 // during allocation of such Platforms, to avoid spurious reporting at program 132 // exit. 133 134 // Interface for a listener that gets notfied at certain events. 
135 class Listener { 136 public: 137 virtual ~Listener() = default; 138 // Callback that is invoked when a Platform is registered. 139 virtual void PlatformRegistered(Platform* platform) = 0; 140 }; 141 // Registers a listeners to receive notifications about certain events. 142 // Precondition: No Platform has been registered yet. 143 static port::Status RegisterListener(std::unique_ptr<Listener> listener); 144 }; 145 146 } // namespace stream_executor 147 148 // multi_platform_manager.cc will define these instances. 149 // 150 // Registering a platform: 151 // REGISTER_MODULE_INITIALIZER_SEQUENCE(my_platform, multi_platform_manager); 152 // REGISTER_MODULE_INITIALIZER_SEQUENCE(multi_platform_manager_listener, 153 // my_platform); 154 // 155 // Registering a listener: 156 // REGISTER_MODULE_INITIALIZER_SEQUENCE(my_listener, 157 // multi_platform_manager_listener); 158 DECLARE_MODULE_INITIALIZER(multi_platform_manager); 159 DECLARE_MODULE_INITIALIZER(multi_platform_manager_listener); 160 161 #endif // TENSORFLOW_STREAM_EXECUTOR_MULTI_PLATFORM_MANAGER_H_ 162