1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Suite of datatypes to represent data-parallel kernel objects (code entities).
17 // Kernel is the untyped variant, whereas TypedKernel takes a type signature
18 // to do some template-based helper generation and give compile-time type
19 // checking for kernel launch parameters.
20 //
21 // Users typically don't see KernelBase, they see typed kernels, analogous to a
22 // typed function pointer. TypedKernels express their argument types via
23 // template parameters like so:
24 //
25 //  TypedKernel<DeviceMemory<int>*, int>
26 //
27 // Which expresses a data parallel kernel signature for:
28 //
29 //  void(int*, int);
30 //
31 // And for a const memory region:
32 //
33 //  TypedKernel<const DeviceMemory<int>&, int>
34 //
35 // Corresponds to a data parallel kernel signature for:
36 //
37 //  void(const int*, int)
38 //
// Note that kernels always have a void return type, so results typically must
// be memcpy'd from device memory to the host.
41 //
42 // Also note that a scalar integer residing in device memory and an array of
43 // integers residing in device memory have the same signature: DeviceMemory<T>.
44 // However, in the future, checks may be added for additional safety that arrays
45 // of minimum sizes are passed when those minimum sizes are contractually
46 // expected by the kernel.
47 //
48 // For user-defined types whose definitions are appropriately shared between the
49 // host code doing the launching and the kernel code being launched, the user
50 // defined types are similarly permitted to be expressed as residing in device
51 // memory:
52 //
53 //  TypedKernel<DeviceMemory<MyUserDefinedStructure>>
54 //
55 // And, when the alignment and padding are agreed upon, POD types will also be
56 // able to be passed by value; for example, it is a common idiom to specify a
57 // bunch of options simultaneously with a structure:
58 //
59 //  TypedKernel<MyOptionsStructurePassedByValue, DeviceMemory<float>>
60 //
61 // Which corresponds to a data parallel kernel signature like:
62 //
63 //  void(MyOptionsStructurePassedByValue value, float *result);
64 //
// Users typically won't need to type out the TypedKernel signature in full; it
// will be typedef'd by automatically generated code (for example, see
// stream_executor::executor_sample::VecReduceAddKernel).
68 
69 #ifndef TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
70 #define TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
71 
#include <array>
#include <cassert>
#include <memory>
#include <tuple>
#include <type_traits>
#include <vector>

#include "absl/strings/string_view.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
#include "tensorflow/stream_executor/platform/port.h"
83 
84 namespace stream_executor {
85 
86 class DeviceMemoryBase;
87 template <typename ElemT>
88 class DeviceMemory;
89 class StreamExecutor;
90 
91 namespace internal {
92 class KernelInterface;
93 }  // namespace internal
94 
95 // KernelMetadata holds runtime-queryable attributes of a loaded kernel, such as
96 // registers allocated, shared memory used, etc.
97 // Not all platforms support reporting of all information, so each accessor
98 // returns false if the associated field is not populated in the underlying
99 // platform.
100 class KernelMetadata {
101  public:
KernelMetadata()102   KernelMetadata()
103       : has_registers_per_thread_(false), has_shared_memory_bytes_(false) {}
104 
105   // Returns the number of registers used per thread executing this kernel.
106   bool registers_per_thread(int *registers_per_thread) const;
107 
108   // Sets the number of registers used per thread executing this kernel.
109   void set_registers_per_thread(int registers_per_thread);
110 
111   // Returns the amount of [static] shared memory used per block executing this
112   // kernel. Note that dynamic shared memory allocations are not (and can not)
113   // be reported here (since they're not specified until kernel launch time).
114   bool shared_memory_bytes(int *shared_memory_bytes) const;
115 
116   // Sets the amount of [static] shared memory used per block executing this
117   // kernel.
118   void set_shared_memory_bytes(int shared_memory_bytes);
119 
120  private:
121   // Holds the value returned by registers_per_thread above.
122   bool has_registers_per_thread_;
123   int registers_per_thread_;
124 
125   // Holds the value returned by shared_memory_bytes above.
126   bool has_shared_memory_bytes_;
127   int64 shared_memory_bytes_;
128 };
129 
130 // A data-parallel kernel (code entity) for launching via the StreamExecutor,
131 // analogous to a void* device function pointer. See TypedKernel for the typed
132 // variant.
133 //
134 // Thread-compatible.
135 class KernelBase {
136  public:
137   KernelBase(KernelBase &&from);
138 
139   // Constructs an "empty" (not-yet-loaded) kernel instance.
140   //
141   // parent is the StreamExecutor that will be responsible for loading the
142   // implementation of this kernel. It must not be null.
143   explicit KernelBase(StreamExecutor *parent);
144 
145   // Test-only constructor that can take a mock KernelInterface implementation.
146   KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation);
147 
148   // Releases resources associated with the kernel instance (i.e.
149   // platform-specific implementation).
150   ~KernelBase();
151 
152   // Returns the number of parameters that this kernel accepts. (Arity refers to
153   // nullary, unary, ...).
154   unsigned Arity() const;
155 
156   // Returns the StreamExecutor that represents the platform this kernel
157   // executes upon.
parent()158   StreamExecutor *parent() const { return parent_; }
159 
160   // Returns a const pointer to the (opaque) platform-dependent implementation.
implementation()161   const internal::KernelInterface *implementation() const {
162     return implementation_.get();
163   }
164 
165   // Returns a non-const pointer to the (opaque) platform-dependent
166   // implementation.
implementation()167   internal::KernelInterface *implementation() { return implementation_.get(); }
168 
set_metadata(const KernelMetadata & metadata)169   void set_metadata(const KernelMetadata &metadata) { metadata_ = metadata; }
170 
metadata()171   const KernelMetadata &metadata() const { return metadata_; }
172 
173   // Sets the preferred cache configuration for a kernel. This is just a
174   // suggestion to the runtime, and may not be honored during execution.
175   void SetPreferredCacheConfig(KernelCacheConfig config);
176 
177   // Gets the preferred cache configuration for a kernel.
178   KernelCacheConfig GetPreferredCacheConfig() const;
179 
180   void set_name(absl::string_view name);
name()181   const string &name() const { return name_; }
demangled_name()182   const string &demangled_name() const { return demangled_name_; }
183 
184  private:
185   // The StreamExecutor that loads this kernel object.
186   StreamExecutor *parent_;
187 
188   // Implementation delegated to for platform-specific functionality.
189   std::unique_ptr<internal::KernelInterface> implementation_;
190 
191   string name_;
192   string demangled_name_;
193 
194   KernelMetadata metadata_;
195 
196   SE_DISALLOW_COPY_AND_ASSIGN(KernelBase);
197 };
198 
199 // Whether T is a DeviceMemory-family pointer.
200 template <typename T>
201 struct IsDeviceMemoryPointer {
202   static constexpr bool value = false;
203 };
204 
205 template <typename U>
206 struct IsDeviceMemoryPointer<DeviceMemory<U> *> {
207   static constexpr bool value = true;
208 };
209 
210 template <>
211 struct IsDeviceMemoryPointer<DeviceMemoryBase *> {
212   static constexpr bool value = true;
213 };
214 
215 // Whether T is a DeviceMemory-family value-like thing (which includes a
216 // reference). This trait is useful because we pack values in the same manner as
217 // references.
218 template <typename T>
219 struct IsDeviceMemoryValueLike {
220   static constexpr bool value = false;
221 };
222 
223 template <typename U>
224 struct IsDeviceMemoryValueLike<DeviceMemory<U> &> {
225   static constexpr bool value = true;
226 };
227 
228 // We need to treat SharedDeviceMemory types differently than other DeviceMemory
229 // types (since they maintain no allocations), hence these specializations.
230 template <typename U>
231 struct IsDeviceMemoryValueLike<SharedDeviceMemory<U> &> {
232   static constexpr bool value = false;
233 };
234 
235 template <>
236 struct IsDeviceMemoryValueLike<DeviceMemoryBase &> {
237   static constexpr bool value = true;
238 };
239 
240 template <typename U>
241 struct IsDeviceMemoryValueLike<DeviceMemory<U>> {
242   static constexpr bool value = true;
243 };
244 
245 template <typename U>
246 struct IsDeviceMemoryValueLike<SharedDeviceMemory<U>> {
247   static constexpr bool value = false;
248 };
249 
250 template <>
251 struct IsDeviceMemoryValueLike<DeviceMemoryBase> {
252   static constexpr bool value = true;
253 };
254 
255 template <typename U>
256 struct IsSharedDeviceMemory {
257   static constexpr bool value = false;
258 };
259 
260 template <typename U>
261 struct IsSharedDeviceMemory<SharedDeviceMemory<U> &> {
262   static constexpr bool value = true;
263 };
264 
265 template <typename U>
266 struct IsSharedDeviceMemory<SharedDeviceMemory<U>> {
267   static constexpr bool value = true;
268 };
269 
// Basic data about a kernel argument.
//
// For a shared-memory argument, is_shared is true, address is null and size is
// the number of shared-memory bytes requested. For any other argument,
// is_shared is false and address/size give the location and byte size of the
// stored argument value.
struct KernelArg {
  bool is_shared;
  const void *address;
  size_t size;
};

// An iterator for traversing all the arguments of a KernelArgsArray in their
// original order, interleaving shared-memory arguments with the others.
//
// The iterator only borrows the arrays it is constructed from; they must
// outlive it. Iterators are copyable and copy-assignable; a copy advances
// independently of the original.
class KernelArgIterator {
 public:
  // number_of_argument_addresses: number of entries in arg_addresses_data and
  //   arg_sizes_data (the non-shared-memory arguments).
  // number_of_shared_memory_arguments: number of entries in shmem_bytes_data
  //   and shmem_indices_data.
  // shmem_indices_data: positions, within the combined argument list, at which
  //   the shared-memory arguments appear. They must be in increasing order,
  //   since next() consumes them sequentially.
  KernelArgIterator(size_t number_of_argument_addresses,
                    size_t number_of_shared_memory_arguments,
                    const void *const *arg_addresses_data,
                    const size_t *arg_sizes_data,
                    const size_t *shmem_bytes_data,
                    const size_t *shmem_indices_data)
      : arg_index_(0),
        number_of_arguments_(number_of_argument_addresses +
                             number_of_shared_memory_arguments),
        arg_address_iter_(arg_addresses_data),
        arg_size_iter_(arg_sizes_data),
        shmem_bytes_iter_(shmem_bytes_data),
        shmem_indices_iter_(shmem_indices_data),
        shmem_indices_end_(shmem_indices_data +
                           number_of_shared_memory_arguments) {}

  // Returns true if another argument is present in the iterator.
  bool has_next() const { return arg_index_ < number_of_arguments_; }

  // Returns the next argument in the iterator and advances past it.
  //
  // Returns a value-initialized KernelArg (not shared, null address, size 0)
  // if there is no next argument.
  KernelArg next() {
    KernelArg result = {};
    if (!has_next()) {
      return result;
    } else if ((shmem_indices_iter_ != shmem_indices_end_) &&
               (arg_index_ == *shmem_indices_iter_)) {
      // The current position holds a shared-memory argument: it has a size
      // but no host-side address.
      result.is_shared = true;
      result.address = nullptr;
      result.size = *shmem_bytes_iter_;
      ++shmem_indices_iter_;
      ++shmem_bytes_iter_;
    } else {
      result.is_shared = false;
      result.address = *arg_address_iter_;
      result.size = *arg_size_iter_;
      ++arg_address_iter_;
      ++arg_size_iter_;
    }
    ++arg_index_;
    return result;
  }

 private:
  // Position of the next argument within the combined argument list.
  size_t arg_index_;
  // Total number of arguments (shared-memory and otherwise).
  size_t number_of_arguments_;
  const void *const *arg_address_iter_;
  const size_t *arg_size_iter_;
  const size_t *shmem_bytes_iter_;
  const size_t *shmem_indices_iter_;
  // One past the last shared-memory index. Not declared const so that the
  // iterator stays copy-assignable.
  const size_t *shmem_indices_end_;
};
333 
// Base class for KernelArgsArray.
//
// Supports all the getter methods that do not depend on the compile-time
// number-of-arguments template parameter.
//
// This class exists as a way to pass kernel arguments to
// StreamExecutorInterface::Launch. That Launch method is virtual, so it can't
// be templated to accept any KernelArgsArray type; therefore a reference to
// this base type is passed instead.
//
// Performance is not a concern here because each of these methods will be
// called at most once per kernel launch. Past performance concerns with
// KernelArgsArray have been in reference to the argument packing routines,
// which are called once per kernel argument. Those packing routines are now
// handled by the templated KernelArgsArray subclass of this class, where they
// can take advantage of compile-time knowledge of the number of arguments in
// order to be very efficient.
class KernelArgsArrayBase {
 public:
  virtual ~KernelArgsArrayBase() = default;

  // Gets the number of arguments added so far, including shared memory
  // arguments.
  virtual size_t number_of_arguments() const = 0;

  // Gets the total number of shared memory bytes added so far (in
  // KernelArgsArray, the sum of the sizes of all shared memory arguments).
  virtual uint64 number_of_shared_bytes() const = 0;

  // Gets the list of argument addresses. In KernelArgsArray this list holds
  // only the non-shared-memory arguments, since shared memory arguments have
  // no host-side address.
  virtual port::ArraySlice<const void *> argument_addresses() const = 0;

  // Gets an iterator over the arguments in the array. For KernelArgsArray,
  // the iterator visits shared and non-shared arguments in the order in which
  // they were added.
  virtual KernelArgIterator arg_iterator() const = 0;
};
368 
369 // A list of arguments for a kernel call.
370 //
371 // The template parameter kNumArgs is the maximum number of arguments which can
372 // be stored in the list.
373 //
374 // Contains a list of addresses for non-shared-memory arguments and a list of
375 // sizes for shared-memory arguments. Since the shared-memory arguments may be
376 // interspersed with the non-shared-memory arguments, it also stores a list of
377 // the indices at which the shared-memory arguments appeared.
378 //
379 // For example, if the argument address list contains {a, b, c, d, e}, the
380 // shared-memory arguments list contains the sizes of {A, B, C}, and the
381 // shared-memory indices list contains {0, 3, 5}, then the original list of
382 // arguments was {A, a, b, B, c, C, d, e}.
383 //
384 // This way of storing the arguments makes CUDA kernel calls efficient because
385 // they only require the argument address list and the total number of shared
386 // bytes, but it also makes it possible for OpenCL kernel calls because they
387 // depend on the location of each shared-memory argument and its size.
388 //
389 // Note that the code for adding arguments has been identified as a performance
390 // hotspot in some real-world applications so this structure has been optimized
391 // for the performance of argument adding.
392 template <size_t kNumArgs>
393 class KernelArgsArray : public KernelArgsArrayBase {
394  public:
395   explicit KernelArgsArray()
396       : total_shared_memory_bytes_(0),
397         number_of_argument_addresses_(0),
398         number_of_shared_memory_arguments_(0) {}
399 
400   // Adds an argument to the list.
401   //
402   // Note that the address of the argument is stored, so the input must not go
403   // out of scope before the instance of this class that calls this method does.
404   template <typename T>
405   void add_argument(const T &arg) {
406     argument_addresses_[number_of_argument_addresses_] =
407         static_cast<const void *>(&arg);
408     argument_sizes_[number_of_argument_addresses_] = sizeof(arg);
409     ++number_of_argument_addresses_;
410   }
411 
412   // Adds a device memory argument to the list.
413   void add_device_memory_argument(const DeviceMemoryBase &arg) {
414     const void **copy_ptr =
415         &device_memory_opaque_pointers_[number_of_argument_addresses_];
416     *copy_ptr = arg.opaque();
417     argument_addresses_[number_of_argument_addresses_] = copy_ptr;
418     argument_sizes_[number_of_argument_addresses_] = sizeof(void *);
419     ++number_of_argument_addresses_;
420   }
421 
422   // Adds a shared memory argument to the list.
423   //
424   // The only significant information about a shared argument is its size, so
425   // that is the only parameter in this function.
426   void add_shared_bytes(size_t number_of_bytes) {
427     shared_memory_indices_[number_of_shared_memory_arguments_] =
428         number_of_argument_addresses_ + number_of_shared_memory_arguments_;
429     shared_memory_bytes_[number_of_shared_memory_arguments_] = number_of_bytes;
430     ++number_of_shared_memory_arguments_;
431     total_shared_memory_bytes_ += number_of_bytes;
432   }
433 
434   // Gets the number of arguments added so far, including shared memory
435   // arguments.
436   size_t number_of_arguments() const override {
437     return number_of_argument_addresses_ + number_of_shared_memory_arguments_;
438   }
439 
440   // Gets the total number of shared memory bytes added so far.
441   uint64 number_of_shared_bytes() const override {
442     return total_shared_memory_bytes_;
443   }
444 
445   // Gets the list of argument addresses.
446   port::ArraySlice<const void *> argument_addresses() const override {
447     return port::ArraySlice<const void *>(argument_addresses_.data(),
448                                           number_of_argument_addresses_);
449   }
450 
451   // Gets an iterator to the arguments in the array.
452   KernelArgIterator arg_iterator() const override {
453     return KernelArgIterator(
454         number_of_argument_addresses_, number_of_shared_memory_arguments_,
455         argument_addresses_.data(), argument_sizes_.data(),
456         shared_memory_bytes_.data(), shared_memory_indices_.data());
457   }
458 
459  private:
460   // A place to store copies of opaque pointers from device memory arguments.
461   std::array<const void *, kNumArgs> device_memory_opaque_pointers_;
462 
463   // Addresses for non-shared-memory arguments.
464   std::array<const void *, kNumArgs> argument_addresses_;
465 
466   // Sizes for non-shared-memory arguments.
467   std::array<size_t, kNumArgs> argument_sizes_;
468 
469   // Size in bytes for each shared memory argument.
470   std::array<size_t, kNumArgs> shared_memory_bytes_;
471 
472   // Indices in the arguments array for shared memory arguments.
473   std::array<size_t, kNumArgs> shared_memory_indices_;
474 
475   // Total of all shared memory sizes.
476   size_t total_shared_memory_bytes_;
477 
478   // Number of significant entries in argument_addresses_ and argument_sizes_.
479   size_t number_of_argument_addresses_;
480 
481   // Number of significant entries in shared_memory_bytes_ and
482   // shared_memory_indices_.
483   size_t number_of_shared_memory_arguments_;
484 };
485 
486 // Typed variant of KernelBase, like a typed device function pointer. See the
487 // file comment for details and example usage.
488 //
489 // This class contains template metaprogramming magic to type check the
490 // parameters passed to a kernel launch are acceptable, and subsequently pack
491 // them into a form which can be used by the StreamExecutorInterface
492 // implementation. (i.e.  CUDA and OpenCL both bind void*s with associated
493 // sizes as kernel arguments.)
494 //
495 // Thread-compatible.
496 template <typename... Params>
497 class TypedKernel : public KernelBase {
498  public:
499   static constexpr size_t kNumberOfParameters = sizeof...(Params);
500 
501   // Delegates to KernelBase::KernelBase(), see that constructor.
502   explicit TypedKernel(StreamExecutor *parent) : KernelBase(parent) {}
503 
504   // Test-only constructor that can take a mock KernelInterface implementation.
505   // Takes ownership of implementation, it should not be null.
506   TypedKernel(StreamExecutor *parent, internal::KernelInterface *implementation)
507       : KernelBase(parent, implementation) {}
508 
509  private:
510   // Stream needs access to the specific parameter-packing functionality that
511   // the TypedKernel provides for its corresponding type signature (and no other
512   // type signatures).
513   friend class Stream;
514 
515   // This is the main entry point into the magic. Packs the parameters (which
516   // must type check against the class template) into the args and sizes
517   // arrays.
518   //
519   // Const refs are taken as parameters on all of the handlers to avoid
520   // implicit type promotion of integers.
521   //
522   // WARNING: as a performance optimization this method may store pointers to
523   // some of the input parameters in the kernel args structure, so any params
524   // passed into this method must live at least as long as the kernel args
525   // structure.
526   void PackParams(KernelArgsArray<kNumberOfParameters> *args,
527                   Params &... params) const {
528     PackOneParam(args, params...);
529   }
530 
531   template <typename T, typename... RestOfParams>
532   void PackOneParam(KernelArgsArray<kNumberOfParameters> *args, const T &arg,
533                     const RestOfParams &... rest) const {
534     PackOneParam(args, arg);
535     PackOneParam(args, rest...);
536   }
537 
538   // Packs one (non-DeviceMemoryBase) parameter into the arg and sizes array.
539   // The enable_if<> is for excluding DeviceMemoryBase args, which have a
540   // separate implementation below.
541   template <typename T>
542   void PackOneParam(
543       KernelArgsArray<kNumberOfParameters> *args, const T &arg,
544       typename std::enable_if<!IsDeviceMemoryValueLike<T>::value &&
545                               !IsDeviceMemoryPointer<T>::value &&
546                               !IsSharedDeviceMemory<T>::value>::type * =
547           nullptr) const {
548     static_assert(!std::is_pointer<T>::value,
549                   "cannot pass raw pointer to the device");
550     static_assert(!std::is_convertible<T, DeviceMemoryBase>::value,
551                   "cannot pass device memory as a normal value");
552     args->add_argument(arg);
553   }
554 
555   // DeviceMemoryBase family reference override.
556   template <typename T>
557   void PackOneParam(
558       KernelArgsArray<kNumberOfParameters> *args, const T &arg,
559       typename std::enable_if<IsDeviceMemoryValueLike<T>::value>::type * =
560           nullptr) const {
561     args->add_device_memory_argument(arg);
562   }
563 
564   // DeviceMemoryBase family pointer override.
565   template <typename T>
566   void PackOneParam(
567       KernelArgsArray<kNumberOfParameters> *args, T arg,
568       typename std::enable_if<IsDeviceMemoryPointer<T>::value>::type * =
569           nullptr) const {
570     DeviceMemoryBase *ptr = static_cast<DeviceMemoryBase *>(arg);
571     args->add_device_memory_argument(*ptr);
572   }
573 
574   // Dynamic shared device memory has a size, but no associated allocation on
575   // the host; internally, the device will allocate storage.
576   template <typename T>
577   void PackOneParam(
578       KernelArgsArray<kNumberOfParameters> *args, T arg,
579       typename std::enable_if<IsSharedDeviceMemory<T>::value>::type * =
580           nullptr) const {
581     args->add_shared_bytes(arg.size());
582   }
583 
584   // Base case for variadic template expansion - nothing to do!
585   void PackOneParam(KernelArgsArray<kNumberOfParameters> *args) const {}
586 
587   SE_DISALLOW_COPY_AND_ASSIGN(TypedKernel);
588 };
589 
590 // Template metaprogramming helper type that helps us produce better error
591 // messages at compile time when the are mismatches between the parameter
592 // type list and the argument type list.
593 template <typename ParamTuple, typename ArgTuple>
594 struct KernelInvocationChecker {
595   // Whether the parameter tuple and argument tuple match in length.
596   static constexpr bool kLengthMatches =
597       std::tuple_size<ParamTuple>::value == std::tuple_size<ArgTuple>::value;
598 
599   // The (matching) length of the parameters and arguments type lists.
600   static constexpr int kTupleLength =
601       static_cast<int>(std::tuple_size<ArgTuple>::value);
602 
603   // Helper trait to say whether the parameter wants a DeviceMemory-reference
604   // compatible type. This is for inexact type matches, so that it doesn't have
605   // to be precisely a const DeviceMemory<T>&, but can also be a value that
606   // represents the same.
607   template <typename ParamType, typename ArgType>
608   struct IsCompatibleDeviceMemoryRef {
609     static constexpr bool value = false;
610   };
611 
612   // See type trait definition above.
613   template <typename U>
614   struct IsCompatibleDeviceMemoryRef<const DeviceMemory<U> &, DeviceMemory<U>> {
615     static constexpr bool value = true;
616   };
617 
618   // See type trait definition above.
619   template <typename U>
620   struct IsCompatibleDeviceMemoryRef<const SharedDeviceMemory<U> &,
621                                      SharedDeviceMemory<U>> {
622     static constexpr bool value = true;
623   };
624 
625   // Returns whether ParamT and ArgT are compatible for data parallel kernel
626   // parameter packing without any assert functionality.
627   template <typename ParamT, typename ArgT>
628   static constexpr bool CompatibleNoAssert() {
629     return std::is_same<typename std::remove_const<ParamT>::type,
630                         ArgT>::value ||
631            IsCompatibleDeviceMemoryRef<ParamT, ArgT>::value;
632   }
633 
634   // Checks whether ParamT and ArgT are compatible for data parallel kernel
635   // parameter packing. kArgumentNumber is unused, it just for error display.
636   //
637   // NOTE: if you encounter an error here, you can see the mismatch by looking
638   // at the end of the last error message, which will be of the form:
639   //
640   //    ...::Compatible<const stream_executor::DeviceMemory<OneThing> &,
641   //                    stream_executor::DeviceMemory<AnotherThing>, true,
642   //                    0>'
643   //    requested here
644   //
645   // This means that the 0th argument you passed to the kernel invocation should
646   // have been DeviceMemory<OneThing> but was observed to be
647   // DeviceMemory<AnotherThing>.
648   template <typename ParamT, typename ArgT, bool kShouldStaticAssert,
649             int kArgumentNumber>
650   static constexpr bool Compatible() {
651     static_assert(
652         kShouldStaticAssert ? CompatibleNoAssert<ParamT, ArgT>() : true,
653         "parameter type (LHS) is not compatible with argument type (RHS)");
654     return CompatibleNoAssert<ParamT, ArgT>();
655   }
656 
657   // Checks the parameter/argument match at kArgumentNumber for an out of bounds
658   // argument number.
659   //
660   // This is the base case: we've run out of argument to check, so we're all
661   // good.
662   template <int kArgumentNumber, bool kShouldStaticAssert>
663   static constexpr bool CheckParam(
664       typename std::enable_if<(kArgumentNumber < 0)>::type *dummy = nullptr) {
665     return true;
666   }
667 
668   // Checks the parameter/argument match at kArgumentNumber.
669   // kShouldStaticAssert determines whether to assert out on a mismatch, or just
670   // yield the constexpr boolean value.
671   template <int kArgumentNumber, bool kShouldStaticAssert>
672   static constexpr bool CheckParam(
673       typename std::enable_if<kArgumentNumber >= 0>::type *dummy = nullptr) {
674     typedef typename std::tuple_element<kArgumentNumber, ParamTuple>::type
675         ParamT;
676     typedef typename std::tuple_element<kArgumentNumber, ArgTuple>::type ArgT;
677     return Compatible<ParamT, ArgT, kShouldStaticAssert, kArgumentNumber>() &&
678            CheckParam<kArgumentNumber - 1, kShouldStaticAssert>();
679   }
680 
681   // Checks the parameters/arguments for match, but doesn't static assert out.
682   // This is useful for testing/inspecting whether a set of parameters match in
683   // things like tests.
684   static constexpr bool CheckAllNoStaticAssert() {
685     return kLengthMatches && CheckParam<kTupleLength - 1, false>();
686   }
687 
688   // Checks the parameters and static asserts out with a helpful error message
689   // (and useful template parameters in the instantiation stack) if there is an
690   // error.
691   static constexpr bool CheckAllStaticAssert() {
692     static_assert(kLengthMatches,
693                   "argument length mismatched against typed kernel parameters");
694     return kLengthMatches && CheckParam<kTupleLength - 1, true>();
695   }
696 };
697 
698 // This is a convenience type for checking whether a typed kernel matches
699 // against a type list.
700 template <typename KernelT, typename... Params>
701 struct KernelParamsOk {
702   static constexpr bool kResult = false;
703 };
704 
705 // See above.
706 template <typename... Params, typename... Args>
707 struct KernelParamsOk<TypedKernel<Params...>, Args...> {
708   static constexpr bool kResult = KernelInvocationChecker<
709       std::tuple<Params...>, std::tuple<Args...>>::CheckAllNoStaticAssert();
710 };
711 
712 }  // namespace stream_executor
713 
714 #endif  // TENSORFLOW_STREAM_EXECUTOR_KERNEL_H_
715