1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_
17 #define TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_
18 
19 #include <map>
20 
21 #include "absl/base/macros.h"
22 #include "tensorflow/stream_executor/blas.h"
23 #include "tensorflow/stream_executor/dnn.h"
24 #include "tensorflow/stream_executor/fft.h"
25 #include "tensorflow/stream_executor/lib/status.h"
26 #include "tensorflow/stream_executor/lib/statusor.h"
27 #include "tensorflow/stream_executor/platform.h"
28 #include "tensorflow/stream_executor/platform/mutex.h"
29 #include "tensorflow/stream_executor/plugin.h"
30 #include "tensorflow/stream_executor/rng.h"
31 
32 namespace stream_executor {
33 
namespace internal {
// Forward declaration only: this header refers to StreamExecutorInterface
// exclusively through pointers (as the argument type of the plugin factory
// functions), so the complete type is not needed here.
class StreamExecutorInterface;
}  // namespace internal
37 
38 // The PluginRegistry is a singleton that maintains the set of registered
39 // "support library" plugins. Currently, there are four kinds of plugins:
40 // BLAS, DNN, FFT, and RNG. Each interface is defined in the corresponding
41 // gpu_{kind}.h header.
42 //
43 // At runtime, a StreamExecutor object will query the singleton registry to
44 // retrieve the plugin kind that StreamExecutor was configured with (refer to
45 // the StreamExecutor and PluginConfig declarations).
46 //
47 // Plugin libraries are best registered using REGISTER_MODULE_INITIALIZER,
48 // but can be registered at any time. When registering a DSO-backed plugin, it
49 // is usually a good idea to load the DSO at registration time, to prevent
50 // late-loading from distorting performance/benchmarks as much as possible.
51 class PluginRegistry {
52  public:
53   typedef blas::BlasSupport* (*BlasFactory)(internal::StreamExecutorInterface*);
54   typedef dnn::DnnSupport* (*DnnFactory)(internal::StreamExecutorInterface*);
55   typedef fft::FftSupport* (*FftFactory)(internal::StreamExecutorInterface*);
56   typedef rng::RngSupport* (*RngFactory)(internal::StreamExecutorInterface*);
57 
58   // Gets (and creates, if necessary) the singleton PluginRegistry instance.
59   static PluginRegistry* Instance();
60 
61   // Registers the specified factory with the specified platform.
62   // Returns a non-successful status if the factory has already been registered
63   // with that platform (but execution should be otherwise unaffected).
64   template <typename FactoryT>
65   port::Status RegisterFactory(Platform::Id platform_id, PluginId plugin_id,
66                                const string& name, FactoryT factory);
67 
68   // Registers the specified factory as usable by _all_ platform types.
69   // Reports errors just as RegisterFactory.
70   template <typename FactoryT>
71   port::Status RegisterFactoryForAllPlatforms(PluginId plugin_id,
72                                               const string& name,
73                                               FactoryT factory);
74 
75   // TODO(b/22689637): Setter for temporary mapping until all users are using
76   // MultiPlatformManager / PlatformId.
77   void MapPlatformKindToId(PlatformKind platform_kind,
78                            Platform::Id platform_id);
79 
80   // Potentially sets the plugin identified by plugin_id to be the default
81   // for the specified platform and plugin kind. If this routine is called
82   // multiple types for the same PluginKind, the PluginId given in the last call
83   // will be used.
84   bool SetDefaultFactory(Platform::Id platform_id, PluginKind plugin_kind,
85                          PluginId plugin_id);
86 
87   // Return true if the factory/id has been registered for the
88   // specified platform and plugin kind and false otherwise.
89   bool HasFactory(Platform::Id platform_id, PluginKind plugin_kind,
90                   PluginId plugin) const;
91 
92   // Retrieves the factory registered for the specified kind,
93   // or a port::Status on error.
94   template <typename FactoryT>
95   port::StatusOr<FactoryT> GetFactory(Platform::Id platform_id,
96                                       PluginId plugin_id);
97 
98   // TODO(b/22689637): Deprecated/temporary. Will be deleted once all users are
99   // on MultiPlatformManager / PlatformId.
100   template <typename FactoryT>
101   ABSL_DEPRECATED("Use MultiPlatformManager / PlatformId instead.")
102   port::StatusOr<FactoryT> GetFactory(PlatformKind platform_kind,
103                                       PluginId plugin_id);
104 
105  private:
106   // Containers for the sets of registered factories, by plugin kind.
107   struct PluginFactories {
108     std::map<PluginId, BlasFactory> blas;
109     std::map<PluginId, DnnFactory> dnn;
110     std::map<PluginId, FftFactory> fft;
111     std::map<PluginId, RngFactory> rng;
112   };
113 
114   // Simple structure to hold the currently configured default plugins (for a
115   // particular Platform).
116   struct DefaultFactories {
117     DefaultFactories();
118     PluginId blas, dnn, fft, rng;
119   };
120 
121   PluginRegistry();
122 
123   // Actually performs the work of registration.
124   template <typename FactoryT>
125   port::Status RegisterFactoryInternal(PluginId plugin_id,
126                                        const string& plugin_name,
127                                        FactoryT factory,
128                                        std::map<PluginId, FactoryT>* factories);
129 
130   // Actually performs the work of factory retrieval.
131   template <typename FactoryT>
132   port::StatusOr<FactoryT> GetFactoryInternal(
133       PluginId plugin_id, const std::map<PluginId, FactoryT>& factories,
134       const std::map<PluginId, FactoryT>& generic_factories) const;
135 
136   // Returns true if the specified plugin has been registered with the specified
137   // platform factories. Unlike the other overload of this method, this does
138   // not implicitly examine the default factory lists.
139   bool HasFactory(const PluginFactories& factories, PluginKind plugin_kind,
140                   PluginId plugin) const;
141 
142   // The singleton itself.
143   static PluginRegistry* instance_;
144 
145   // TODO(b/22689637): Temporary mapping until all users are using
146   // MultiPlatformManager / PlatformId.
147   std::map<PlatformKind, Platform::Id> platform_id_by_kind_;
148 
149   // The set of registered factories, keyed by platform ID.
150   std::map<Platform::Id, PluginFactories> factories_;
151 
152   // Plugins supported for all platform kinds.
153   PluginFactories generic_factories_;
154 
155   // The sets of default factories, keyed by platform ID.
156   std::map<Platform::Id, DefaultFactories> default_factories_;
157 
158   // Lookup table for plugin names.
159   std::map<PluginId, string> plugin_names_;
160 
161   SE_DISALLOW_COPY_AND_ASSIGN(PluginRegistry);
162 };
163 
164 }  // namespace stream_executor
165 
166 #endif  // TENSORFLOW_STREAM_EXECUTOR_PLUGIN_REGISTRY_H_
167