1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// Casting utility functions for HLO instructions.
17 
18 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
19 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
20 
#include <type_traits>
#include <typeinfo>

#include "tensorflow/core/platform/logging.h"
23 
24 namespace xla {
25 
class HloInstruction;

// SFINAE helper: EnableIfDerivedFromHlo<Derived> names `void` when Derived is
// HloInstruction or a class derived from it, and is ill-formed otherwise.
// It is used as the type of an unnamed, defaulted non-type template parameter
// (`EnableIfDerivedFromHlo<T>* = nullptr`), so the casting helpers in this
// header can only be instantiated for HLO instruction types.
template <typename Derived>
using EnableIfDerivedFromHlo = typename std::enable_if<
    std::is_base_of<HloInstruction, Derived>::value, void>::type;
31 
32 // TODO(b/93238915): Switch implementation from C++'s dynamic_cast to LLVM-like
33 // RTTI if it turns out to be a performance issue.
34 // Casts an HloInstruction pointer to one of its subclasses, dies if argument is
35 // nullptr or runtime information does not match.
36 //
37 // Similar to LLVM's cast.
38 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
Cast(const HloInstruction * instruction)39 const T* Cast(const HloInstruction* instruction) {
40   CHECK(instruction != nullptr);
41   const T* casted = dynamic_cast<const T*>(instruction);
42   CHECK(casted != nullptr);
43   return casted;
44 }
45 
46 // Non-const overload of Cast.
47 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
Cast(HloInstruction * instruction)48 T* Cast(HloInstruction* instruction) {
49   return const_cast<T*>(
50       Cast<T>(const_cast<const HloInstruction*>(instruction)));
51 }
52 
53 // Works just like the Cast, except that it allows for a null pointer as an
54 // argument which it then propagates.
55 //
56 // Similar to LLVM's cast_or_null.
57 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
CastOrNull(const HloInstruction * instruction)58 const T* CastOrNull(const HloInstruction* instruction) {
59   return instruction != nullptr ? Cast<T>(instruction) : nullptr;
60 }
61 
62 // Non-const overload of CastOrNull.
63 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
CastOrNull(HloInstruction * instruction)64 T* CastOrNull(HloInstruction* instruction) {
65   return const_cast<T*>(
66       CastOrNull<T>(const_cast<const HloInstruction*>(instruction)));
67 }
68 
69 // Casts an HloInstruction pointer to one of its subclasses, dies if argument is
70 // nullptr, returns nullptr if runtime information does not match.
71 //
72 // Similar to LLVM's dyn_cast.
73 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
DynCast(const HloInstruction * instruction)74 const T* DynCast(const HloInstruction* instruction) {
75   CHECK(instruction != nullptr);
76   return dynamic_cast<const T*>(instruction);
77 }
78 
79 // Non-const overload of DynCast.
80 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
DynCast(HloInstruction * instruction)81 T* DynCast(HloInstruction* instruction) {
82   return const_cast<T*>(
83       DynCast<T>(const_cast<const HloInstruction*>(instruction)));
84 }
85 
86 // Works just like the DynCast, except that it allows for a null pointer as an
87 // argument which it then propagates.
88 //
89 // Similar to LLVM's dyn_cast_or_null.
90 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
DynCastOrNull(const HloInstruction * instruction)91 const T* DynCastOrNull(const HloInstruction* instruction) {
92   return instruction != nullptr ? DynCast<T>(instruction) : nullptr;
93 }
94 
95 // Non-const overload of DynCastOrNull.
96 template <class T, EnableIfDerivedFromHlo<T>* = nullptr>
DynCastOrNull(HloInstruction * instruction)97 T* DynCastOrNull(HloInstruction* instruction) {
98   return const_cast<T*>(
99       DynCastOrNull<T>(const_cast<const HloInstruction*>(instruction)));
100 }
101 
102 }  // namespace xla
103 
104 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_CASTING_UTILS_H_
105