1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ 18 19 #include "tensorflow/compiler/xla/service/buffer_assignment.h" 20 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" 21 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" 22 #include "tensorflow/compiler/xla/service/gpu/thunk.h" 23 #include "tensorflow/compiler/xla/service/hlo_instruction.h" 24 #include "tensorflow/core/platform/stream_executor_no_cuda.h" 25 #include "tensorflow/core/platform/types.h" 26 27 namespace xla { 28 namespace gpu { 29 30 // Thunk that performs a NCCL-based All-Reduce among CUDA GPU-based replicas. 31 class NcclAllReduceThunk : public Thunk { 32 public: 33 // Returns whether NCCL operations appear possible to perform; e.g. if we 34 // haven't done a build with the CUDA compiler enabled, we can't compile the 35 // NCCL header, and thus this will be false. 36 // 37 // When this is false, the ExecuteOnStream() call will simply return a status 38 // error. 39 static bool NcclIsEnabled(); 40 41 // TODO(b/125951860): Plumb more datatypes / reduction operators. Initial 42 // implementation is simply F32 summation. 43 NcclAllReduceThunk(int64 replica_count, int64 element_count, 44 const BufferAllocation::Slice& source_buffer, 45 const BufferAllocation::Slice& destination_buffer, 46 const HloInstruction* all_reduce); 47 48 Status ExecuteOnStream(const BufferAllocations& buffer_allocations, 49 se::Stream* stream, 50 HloExecutionProfiler* profiler) override; 51 52 private: 53 const int64 replica_count_; 54 const int64 element_count_; 55 const BufferAllocation::Slice source_buffer_; 56 const BufferAllocation::Slice destination_buffer_; 57 }; 58 59 } // namespace gpu 60 } // namespace xla 61 62 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_NCCL_ALL_REDUCE_THUNK_H_ 63