1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_
17 #define TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_
18 
#include <map>
#include <string>
20 
21 #include "absl/base/macros.h"
22 #include "tensorflow/stream_executor/blas.h"
23 #include "tensorflow/stream_executor/dnn.h"
24 #include "tensorflow/stream_executor/fft.h"
25 #include "tensorflow/stream_executor/lib/status.h"
26 #include "tensorflow/stream_executor/lib/statusor.h"
27 #include "tensorflow/stream_executor/platform.h"
28 #include "tensorflow/stream_executor/plugin.h"
29 #include "tensorflow/stream_executor/rng.h"
30 
31 namespace stream_executor {
32 
namespace internal {
// Forward declaration only: this header never needs the complete type, because
// the factory signatures in PluginRegistry take it by pointer.
class StreamExecutorInterface;
}  // namespace internal
36 
37 // The PluginRegistry is a singleton that maintains the set of registered
// "support library" plugins. Currently, there are four kinds of plugins:
// BLAS, DNN, FFT, and RNG. Each interface is defined in the corresponding
// {kind}.h header (blas.h, dnn.h, fft.h, rng.h).
41 //
42 // At runtime, a StreamExecutor object will query the singleton registry to
43 // retrieve the plugin kind that StreamExecutor was configured with (refer to
44 // the StreamExecutor and PluginConfig declarations).
45 //
46 // Plugin libraries are best registered using REGISTER_MODULE_INITIALIZER,
47 // but can be registered at any time. When registering a DSO-backed plugin, it
48 // is usually a good idea to load the DSO at registration time, to prevent
49 // late-loading from distorting performance/benchmarks as much as possible.
50 class PluginRegistry {
51  public:
52   typedef blas::BlasSupport* (*BlasFactory)(internal::StreamExecutorInterface*);
53   typedef dnn::DnnSupport* (*DnnFactory)(internal::StreamExecutorInterface*);
54   typedef fft::FftSupport* (*FftFactory)(internal::StreamExecutorInterface*);
55   typedef rng::RngSupport* (*RngFactory)(internal::StreamExecutorInterface*);
56 
57   // Gets (and creates, if necessary) the singleton PluginRegistry instance.
58   static PluginRegistry* Instance();
59 
60   // Registers the specified factory with the specified platform.
61   // Returns a non-successful status if the factory has already been registered
62   // with that platform (but execution should be otherwise unaffected).
63   template <typename FactoryT>
64   port::Status RegisterFactory(Platform::Id platform_id, PluginId plugin_id,
65                                const std::string& name, FactoryT factory);
66 
67   // Registers the specified factory as usable by _all_ platform types.
68   // Reports errors just as RegisterFactory.
69   template <typename FactoryT>
70   port::Status RegisterFactoryForAllPlatforms(PluginId plugin_id,
71                                               const std::string& name,
72                                               FactoryT factory);
73 
74   // TODO(b/22689637): Setter for temporary mapping until all users are using
75   // MultiPlatformManager / PlatformId.
76   void MapPlatformKindToId(PlatformKind platform_kind,
77                            Platform::Id platform_id);
78 
79   // Potentially sets the plugin identified by plugin_id to be the default
80   // for the specified platform and plugin kind. If this routine is called
81   // multiple types for the same PluginKind, the PluginId given in the last call
82   // will be used.
83   bool SetDefaultFactory(Platform::Id platform_id, PluginKind plugin_kind,
84                          PluginId plugin_id);
85 
86   // Return true if the factory/id has been registered for the
87   // specified platform and plugin kind and false otherwise.
88   bool HasFactory(Platform::Id platform_id, PluginKind plugin_kind,
89                   PluginId plugin) const;
90 
91   // Retrieves the factory registered for the specified kind,
92   // or a port::Status on error.
93   template <typename FactoryT>
94   port::StatusOr<FactoryT> GetFactory(Platform::Id platform_id,
95                                       PluginId plugin_id);
96 
97   // TODO(b/22689637): Deprecated/temporary. Will be deleted once all users are
98   // on MultiPlatformManager / PlatformId.
99   template <typename FactoryT>
100   ABSL_DEPRECATED("Use MultiPlatformManager / PlatformId instead.")
101   port::StatusOr<FactoryT> GetFactory(PlatformKind platform_kind,
102                                       PluginId plugin_id);
103 
104  private:
105   // Containers for the sets of registered factories, by plugin kind.
106   struct PluginFactories {
107     std::map<PluginId, BlasFactory> blas;
108     std::map<PluginId, DnnFactory> dnn;
109     std::map<PluginId, FftFactory> fft;
110     std::map<PluginId, RngFactory> rng;
111   };
112 
113   // Simple structure to hold the currently configured default plugins (for a
114   // particular Platform).
115   struct DefaultFactories {
116     DefaultFactories();
117     PluginId blas, dnn, fft, rng;
118   };
119 
120   PluginRegistry();
121 
122   // Actually performs the work of registration.
123   template <typename FactoryT>
124   port::Status RegisterFactoryInternal(PluginId plugin_id,
125                                        const std::string& plugin_name,
126                                        FactoryT factory,
127                                        std::map<PluginId, FactoryT>* factories);
128 
129   // Actually performs the work of factory retrieval.
130   template <typename FactoryT>
131   port::StatusOr<FactoryT> GetFactoryInternal(
132       PluginId plugin_id, const std::map<PluginId, FactoryT>& factories,
133       const std::map<PluginId, FactoryT>& generic_factories) const;
134 
135   // Returns true if the specified plugin has been registered with the specified
136   // platform factories. Unlike the other overload of this method, this does
137   // not implicitly examine the default factory lists.
138   bool HasFactory(const PluginFactories& factories, PluginKind plugin_kind,
139                   PluginId plugin) const;
140 
141   // The singleton itself.
142   static PluginRegistry* instance_;
143 
144   // TODO(b/22689637): Temporary mapping until all users are using
145   // MultiPlatformManager / PlatformId.
146   std::map<PlatformKind, Platform::Id> platform_id_by_kind_;
147 
148   // The set of registered factories, keyed by platform ID.
149   std::map<Platform::Id, PluginFactories> factories_;
150 
151   // Plugins supported for all platform kinds.
152   PluginFactories generic_factories_;
153 
154   // The sets of default factories, keyed by platform ID.
155   std::map<Platform::Id, DefaultFactories> default_factories_;
156 
157   // Lookup table for plugin names.
158   std::map<PluginId, std::string> plugin_names_;
159 
160   SE_DISALLOW_COPY_AND_ASSIGN(PluginRegistry);
161 };
162 
163 // Explicit specializations are defined in plugin_registry.cc.
164 #define DECLARE_PLUGIN_SPECIALIZATIONS(FACTORY_TYPE)                          \
165   template <>                                                                 \
166   port::Status PluginRegistry::RegisterFactory<PluginRegistry::FACTORY_TYPE>( \
167       Platform::Id platform_id, PluginId plugin_id, const std::string& name,  \
168       PluginRegistry::FACTORY_TYPE factory);                                  \
169   template <>                                                                 \
170   port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
171       Platform::Id platform_id, PluginId plugin_id);                          \
172   template <>                                                                 \
173   port::StatusOr<PluginRegistry::FACTORY_TYPE> PluginRegistry::GetFactory(    \
174       PlatformKind platform_kind, PluginId plugin_id)
175 
176 DECLARE_PLUGIN_SPECIALIZATIONS(BlasFactory);
177 DECLARE_PLUGIN_SPECIALIZATIONS(DnnFactory);
178 DECLARE_PLUGIN_SPECIALIZATIONS(FftFactory);
179 DECLARE_PLUGIN_SPECIALIZATIONS(RngFactory);
180 #undef DECL_PLUGIN_SPECIALIZATIONS
181 
182 }  // namespace stream_executor
183 
184 #endif  // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_
185