1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
17 #define TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
18 
19 #include "tensorflow/stream_executor/data_type.h"
20 #include "tensorflow/stream_executor/device_memory.h"
21 #include "tensorflow/stream_executor/platform/logging.h"
22 
23 namespace stream_executor {
24 
25 // Allows to represent a value that is either a host scalar or a scalar stored
26 // on the GPU device.
27 // See also the specialization for ElemT=void below.
28 template <typename ElemT>
29 class HostOrDeviceScalar {
30  public:
31   // Not marked as explicit because when using this constructor, we usually want
32   // to set this to a compile-time constant.
HostOrDeviceScalar(ElemT value)33   HostOrDeviceScalar(ElemT value) : value_(value), is_pointer_(false) {}
HostOrDeviceScalar(const DeviceMemory<ElemT> & pointer)34   explicit HostOrDeviceScalar(const DeviceMemory<ElemT>& pointer)
35       : pointer_(pointer), is_pointer_(true) {
36     CHECK_EQ(1, pointer.ElementCount());
37   }
38 
is_pointer()39   bool is_pointer() const { return is_pointer_; }
pointer()40   const DeviceMemory<ElemT>& pointer() const {
41     CHECK(is_pointer());
42     return pointer_;
43   }
value()44   const ElemT& value() const {
45     CHECK(!is_pointer());
46     return value_;
47   }
48 
49  private:
50   union {
51     ElemT value_;
52     DeviceMemory<ElemT> pointer_;
53   };
54   bool is_pointer_;
55 };
56 
57 // Specialization for wrapping a dynamically-typed value (via type erasure).
58 template <>
59 class HostOrDeviceScalar<void> {
60  public:
61   using DataType = dnn::DataType;
62 
63   // Constructors not marked as explicit because when using this constructor, we
64   // usually want to set this to a compile-time constant.
65 
66   // NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(float value)67   HostOrDeviceScalar(float value)
68       : float_(value), is_pointer_(false), dtype_(DataType::kFloat) {}
69   // NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(double value)70   HostOrDeviceScalar(double value)
71       : double_(value), is_pointer_(false), dtype_(DataType::kDouble) {}
72   // NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(Eigen::half value)73   HostOrDeviceScalar(Eigen::half value)
74       : half_(value), is_pointer_(false), dtype_(DataType::kHalf) {}
75   // NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(int8 value)76   HostOrDeviceScalar(int8 value)
77       : int8_(value), is_pointer_(false), dtype_(DataType::kInt8) {}
78   // NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(int32 value)79   HostOrDeviceScalar(int32 value)
80       : int32_(value), is_pointer_(false), dtype_(DataType::kInt32) {}
81   // NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(std::complex<float> value)82   HostOrDeviceScalar(std::complex<float> value)
83       : complex_float_(value),
84         is_pointer_(false),
85         dtype_(DataType::kComplexFloat) {}
86   // NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(std::complex<double> value)87   HostOrDeviceScalar(std::complex<double> value)
88       : complex_double_(value),
89         is_pointer_(false),
90         dtype_(DataType::kComplexDouble) {}
91   template <typename T>
HostOrDeviceScalar(const DeviceMemory<T> & pointer)92   explicit HostOrDeviceScalar(const DeviceMemory<T>& pointer)
93       : pointer_(pointer),
94         is_pointer_(true),
95         dtype_(dnn::ToDataType<T>::value) {
96     CHECK_EQ(1, pointer.ElementCount());
97   }
98   // Construct from statically-typed version.
99   template <typename T, typename std::enable_if<!std::is_same<T, void>::value,
100                                                 int>::type = 0>
101   // NOLINTNEXTLINE google-explicit-constructor
HostOrDeviceScalar(const HostOrDeviceScalar<T> & other)102   HostOrDeviceScalar(const HostOrDeviceScalar<T>& other) {
103     if (other.is_pointer()) {
104       *this = HostOrDeviceScalar(other.pointer());
105     } else {
106       *this = HostOrDeviceScalar(other.value());
107     }
108   }
109 
is_pointer()110   bool is_pointer() const { return is_pointer_; }
111   template <typename T>
pointer()112   const DeviceMemory<T>& pointer() const {
113     CHECK(is_pointer());
114     CHECK(dtype_ == dnn::ToDataType<T>::value);
115     return pointer_;
116   }
117   template <typename T>
value()118   const T& value() const {
119     CHECK(!is_pointer());
120     CHECK(dtype_ == dnn::ToDataType<T>::value);
121     return value_impl<T>();
122   }
opaque_pointer()123   const DeviceMemoryBase& opaque_pointer() const {
124     CHECK(is_pointer());
125     return pointer_;
126   }
opaque_value()127   const void* opaque_value() const {
128     CHECK(!is_pointer());
129     switch (dtype_) {
130       case DataType::kFloat:
131         return &float_;
132       case DataType::kDouble:
133         return &double_;
134       case DataType::kHalf:
135         return &half_;
136       case DataType::kInt8:
137         return &int8_;
138       case DataType::kInt32:
139         return &int32_;
140       case DataType::kComplexFloat:
141         return &complex_float_;
142       case DataType::kComplexDouble:
143         return &complex_double_;
144       default:
145         return nullptr;
146     }
147   }
data_type()148   DataType data_type() const { return dtype_; }
149 
150  private:
151   template <typename T>
152   const T& value_impl() const;
153 
154   union {
155     float float_;
156     double double_;
157     Eigen::half half_;
158     int8 int8_;
159     int32 int32_;
160     std::complex<float> complex_float_;
161     std::complex<double> complex_double_;
162     DeviceMemoryBase pointer_;
163   };
164   bool is_pointer_;
165   DataType dtype_;
166 };
167 
168 template <>
169 inline const float& HostOrDeviceScalar<void>::value_impl<float>() const {
170   return float_;
171 }
172 
173 template <>
174 inline const double& HostOrDeviceScalar<void>::value_impl<double>() const {
175   return double_;
176 }
177 
178 template <>
179 inline const Eigen::half& HostOrDeviceScalar<void>::value_impl<Eigen::half>()
180     const {
181   return half_;
182 }
183 
184 template <>
185 inline const int8& HostOrDeviceScalar<void>::value_impl<int8>() const {
186   return int8_;
187 }
188 
189 template <>
190 inline const int32& HostOrDeviceScalar<void>::value_impl<int32>() const {
191   return int32_;
192 }
193 
194 template <>
195 inline const std::complex<float>&
196 HostOrDeviceScalar<void>::value_impl<std::complex<float>>() const {
197   return complex_float_;
198 }
199 
200 template <>
201 inline const std::complex<double>&
202 HostOrDeviceScalar<void>::value_impl<std::complex<double>>() const {
203   return complex_double_;
204 }
205 
206 }  // namespace stream_executor
207 #endif  // TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
208