1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
18 
19 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
20 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
21 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
22 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
23 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
24 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
25 #include "tensorflow/core/platform/types.h"
26 
27 namespace xla {
28 namespace gpu {
29 
30 // Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas.
31 class NcclAllReduceThunk : public Thunk {
32  public:
33   // Returns whether NCCL operations appear possible to perform; e.g. if we
34   // haven't done a build with the CUDA compiler enabled, we can't compile the
35   // NCCL header, and thus this will be false.
36   //
37   // When this is false, the ExecuteOnStream() call will simply return a status
38   // error.
39   static bool NcclIsEnabled();
40 
41   // TODO(b/125951860): Plumb more datatypes / reduction operators. Initial
42   // implementation is simply F32 summation.
43   NcclAllReduceThunk(int64 replica_count, int64 element_count,
44                      const BufferAllocation::Slice& source_buffer,
45                      const BufferAllocation::Slice& destination_buffer,
46                      const HloInstruction* all_reduce);
47 
48   Status ExecuteOnStream(const BufferAllocations& buffer_allocations,
49                          se::Stream* stream,
50                          HloExecutionProfiler* profiler) override;
51 
52  private:
53   const int64 replica_count_;
54   const int64 element_count_;
55   const BufferAllocation::Slice source_buffer_;
56   const BufferAllocation::Slice destination_buffer_;
57 };
58 
59 }  // namespace gpu
60 }  // namespace xla
61 
62 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_
63