1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
17
18 #include <chrono> // NOLINT (required by TF interfaces)
19 #include <cstdlib>
20 #include <memory>
21 #include <string>
22 #include <utility>
23 #include <vector>
24
25 #include "absl/strings/str_format.h"
26 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
27 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
28 #include "tensorflow/compiler/xla/util.h"
29
30 namespace xla {
31 namespace gpu {
32
GetNcclAllGatherConfig(mlir::lmhlo::AllGatherOp op,int64 replica_count)33 static NcclAllGatherConfig GetNcclAllGatherConfig(mlir::lmhlo::AllGatherOp op,
34 int64 replica_count) {
35 NcclAllGatherConfig config;
36 config.config = GetNcclCollectiveConfigForMlir(op, replica_count);
37 return config;
38 }
39
CanImplement(const HloInstruction * hlo)40 /*static*/ bool NcclAllGatherThunk::CanImplement(const HloInstruction* hlo) {
41 auto operands_are_supported = [hlo]() {
42 return absl::c_all_of(hlo->operands(), [](HloInstruction* operand) {
43 return LayoutUtil::IsDenseArray(operand->shape()) &&
44 IsTypeSupportedByNccl(operand->shape().element_type());
45 });
46 };
47 return (Cast<HloAllGatherInstruction>(hlo)->all_gather_dimension() == 0) &&
48 operands_are_supported();
49 }
50
CanImplement(mlir::lmhlo::AllGatherOp op)51 /*static*/ bool NcclAllGatherThunk::CanImplement(mlir::lmhlo::AllGatherOp op) {
52 bool operands_are_supported =
53 absl::c_all_of(op.operands(), [](mlir::Value operand) {
54 Shape shape = TypeToShape(operand.getType());
55 return LayoutUtil::IsDenseArray(shape) &&
56 IsTypeSupportedByNccl(shape.element_type());
57 });
58 return op.all_gather_dimension() == 0 && operands_are_supported;
59 }
60
NcclAllGatherThunk(ThunkInfo thunk_info,mlir::lmhlo::AllGatherOp op,int64 replica_count,std::vector<NcclAllGatherThunk::Buffer> buffers)61 NcclAllGatherThunk::NcclAllGatherThunk(
62 ThunkInfo thunk_info, mlir::lmhlo::AllGatherOp op, int64 replica_count,
63 std::vector<NcclAllGatherThunk::Buffer> buffers)
64 : NcclCollectiveThunk(Thunk::kNcclAllGather, thunk_info),
65 config_(GetNcclAllGatherConfig(op, replica_count)),
66 buffers_(std::move(buffers)) {
67 CHECK_EQ(config_.config.operand_count, buffers_.size());
68 }
69
RunNcclCollective(const ExecuteParams & params,ncclComm_t comm)70 Status NcclAllGatherThunk::RunNcclCollective(const ExecuteParams& params,
71 ncclComm_t comm) {
72 #if XLA_ENABLE_XCCL
73 int device_ordinal = params.stream->parent()->device_ordinal();
74 VLOG(3) << "Performing all-gather from device ordinal: " << device_ordinal;
75
76 cudaStream_t* cu_stream = reinterpret_cast<cudaStream_t*>(
77 params.stream->implementation()->GpuStreamMemberHack());
78
79 XLA_CUDA_RETURN_IF_ERROR(ncclGroupStart());
80 for (size_t i = 0; i < buffers_.size(); ++i) {
81 const Buffer& buffer = buffers_[i];
82 const void* send_buffer =
83 params.buffer_allocations->GetDeviceAddress(buffer.source_buffer)
84 .opaque();
85 void* recv_buffer =
86 params.buffer_allocations->GetDeviceAddress(buffer.destination_buffer)
87 .opaque();
88
89 TF_ASSIGN_OR_RETURN(ncclDataType_t datatype,
90 ToNcclDataType(config_.config.operand_element_type[i]));
91
92 VLOG(3) << absl::StreamFormat(
93 "Calling ncclAllGather(send_buffer=%p, recv_buffer=%p, count=%d, "
94 "comm=%p, stream=%p)",
95 send_buffer, recv_buffer, buffer.element_count,
96 static_cast<const void*>(comm), cu_stream);
97
98 XLA_CUDA_RETURN_IF_ERROR(ncclAllGather(send_buffer, recv_buffer,
99 buffer.element_count, datatype, comm,
100 *cu_stream));
101 }
102 XLA_CUDA_RETURN_IF_ERROR(ncclGroupEnd());
103
104 VLOG(3) << "Done performing all-gather for ordinal: " << device_ordinal;
105 return Status::OK();
106 #else // XLA_ENABLE_XCCL
107 return Unimplemented(
108 "NCCL support is not available: this binary was not built with a CUDA "
109 "compiler, which is necessary to build the NCCL source library.");
110 #endif // XLA_ENABLE_XCCL
111 }
112
113 } // namespace gpu
114 } // namespace xla
115