1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
23 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
24 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
25 #include "tensorflow/core/lib/core/status.h"
26 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
27 
28 namespace xla {
29 namespace gpu {
30 
31 class GpuExecutable;
32 
33 // Thunk acts as the bridge between IrEmitter and GpuExecutable. It stores the
34 // metadata IrEmitter generates for GpuExecutable to invoke an HloInstruction.
35 //
36 // Thunk provides the Initialize and ExecuteOnStream interface for GpuExecutable
37 // to initialize and execute the invocation respectively. Its subclasses are
38 // supposed to override these interfaces to launch a generated kernel or call an
39 // external library function (such as operations in cuBLAS).
40 //
41 // This is thread-compatible.
42 class Thunk {
43  public:
44   enum Kind {
45     kCholesky,
46     kConditional,
47     kConvolution,
48     kCopy,
49     kCudnnBatchNormBackward,
50     kCudnnBatchNormForwardInference,
51     kCudnnBatchNormForwardTraining,
52     kNcclAllReduce,
53     kFft,
54     kGemm,
55     kInfeed,
56     kKernel,
57     kMemset32BitValue,
58     kMemzero,
59     kOutfeed,
60     kSequential,
61     kTriangularSolve,
62     kTuple,
63     kWhile,
64   };
65 
66   // The hlo_instruction argument is meant to be the instruction this thunk was
67   // generated from, but Thunk never uses this argument other than to save it
68   // to Thunk::hlo_instruction, so it can be null.
Thunk(Kind kind,const HloInstruction * hlo_instruction)69   explicit Thunk(Kind kind, const HloInstruction* hlo_instruction)
70       : kind_(kind), hlo_instruction_(hlo_instruction) {}
~Thunk()71   virtual ~Thunk() {}
72   Thunk(const Thunk&) = delete;
73   Thunk& operator=(const Thunk&) = delete;
74 
kind()75   Kind kind() const { return kind_; }
hlo_instruction()76   const HloInstruction* hlo_instruction() const { return hlo_instruction_; }
77 
78   // Prepares the thunk for execution on the given StreamExecutor.
79   //
80   // This may be called multiple times.  Its main purpose is to give us a chance
81   // to do initialization outside of ExecuteOnStream() so that the
82   // time spent initializing doesn't count towards our execution profile.
Initialize(const GpuExecutable &,se::StreamExecutor *)83   virtual Status Initialize(const GpuExecutable& /*executable*/,
84                             se::StreamExecutor* /*executor*/) {
85     return Status::OK();
86   }
87 
88   // Returns true if this kernel will autotune for the stream device the next
89   // time it is run.
WillAutotuneKernel(se::Stream *)90   virtual bool WillAutotuneKernel(se::Stream* /*stream*/) { return false; }
91 
92   // Execute the kernel for the thunk on the given stream. This method must be
93   // called after Initialize and can be called multiple times over Thunk's
94   // lifetime. 'stream' and 'profiler' must be non-null.
95   //
96   // Precondition: Initialize(stream->parent()) has been called.
97   virtual Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
98                                  se::Stream* stream,
99                                  HloExecutionProfiler* profiler) = 0;
100 
101  private:
102   Kind kind_;
103   const HloInstruction* hlo_instruction_;
104 };
105 
106 // A sequence of thunks.
107 using ThunkSequence = std::vector<std::unique_ptr<Thunk>>;
108 
109 absl::string_view ThunkKindToString(Thunk::Kind);
110 std::ostream& operator<<(std::ostream& os, Thunk::Kind kind);
111 
112 }  // namespace gpu
113 }  // namespace xla
114 
115 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_THUNK_H_
116