/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"

#include <complex>
#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {

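// This thunk wraps a batched BLAS triangular solve (trsm): it records the
// operand buffers, element type, matrix dimensions, and the solve options
// taken from the HLO TriangularSolveOptions proto.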
TriangularSolveThunk::TriangularSolveThunk(
    const TriangularSolveOptions& options,
    const BufferAllocation::Slice& a_buffer,
    const BufferAllocation::Slice& b_buffer, PrimitiveType type,
    int64 batch_size, int64 m, int64 n, int64 a_batch_stride,
    int64 b_batch_stride, const HloInstruction* hlo)
    : Thunk(Kind::kTriangularSolve, hlo),
      uplo_(options.lower() ? se::blas::UpperLower::kLower
                            : se::blas::UpperLower::kUpper),
      side_(options.left_side() ? se::blas::Side::kLeft
                                : se::blas::Side::kRight),
      unit_diagonal_(options.unit_diagonal() ? se::blas::Diagonal::kUnit
                                             : se::blas::Diagonal::kNonUnit),
      a_buffer_(a_buffer),
      b_buffer_(b_buffer),
      type_(type),
      batch_size_(batch_size),
      m_(m),
      n_(n),
      a_batch_stride_(a_batch_stride),
      b_batch_stride_(b_batch_stride) {
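  // Translate the proto's transpose option into the corresponding
  // StreamExecutor BLAS enum. An immediately-invoked lambda lets the switch
  // produce a single value to assign to transpose_a_.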
  transpose_a_ = [&] {
    switch (options.transpose_a()) {
      case TriangularSolveOptions::NO_TRANSPOSE:
        return se::blas::Transpose::kNoTranspose;
      case TriangularSolveOptions::TRANSPOSE:
        return se::blas::Transpose::kTranspose;
      case TriangularSolveOptions::ADJOINT:
        return se::blas::Transpose::kConjugateTranspose;
      default:
        LOG(ERROR) << "Invalid triangular solve transpose value "
                   << options.transpose_a();
        return se::blas::Transpose::kNoTranspose;
    }
  }();
}

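// Runs one BLAS trsm call per batch element on the given stream. trsm solves
// op(A) X = alpha B (or X op(A) = alpha B for a right-side solve) and writes
// the solution X over B, so b_buffer_ is updated in place.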
Status TriangularSolveThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_)
          << " side=" << se::blas::SideString(side_)
          << " diagonal=" << se::blas::DiagonalString(unit_diagonal_)
          << " batch_size=" << batch_size_ << " m=" << m_ << " n=" << n_
          << " a_batch_stride=" << a_batch_stride_
          << " b_batch_stride=" << b_batch_stride_;

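  // Leading dimensions follow the BLAS column-major convention: A is m x m
  // for a left-side solve and n x n for a right-side solve, while B is
  // always m x n.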
  const int lda = side_ == se::blas::Side::kLeft ? m_ : n_;
  const int ldb = m_;

  char* a_base = static_cast<char*>(
      buffer_allocations.GetDeviceAddress(a_buffer_).opaque());
  char* b_base = static_cast<char*>(
      buffer_allocations.GetDeviceAddress(b_buffer_).opaque());
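  // ThenBlasTrsm operates on a single matrix pair per call, so each batch
  // element gets its own launch. The batch strides are byte offsets between
  // consecutive matrices, hence the char* pointer arithmetic below.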
  for (int64 i = 0; i < batch_size_; ++i) {
    bool launch_ok;
    se::DeviceMemoryBase a_data =
        se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_);
    se::DeviceMemoryBase b_data =
        se::DeviceMemoryBase(b_base + i * b_batch_stride_, b_batch_stride_);
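    // Dispatch on the element type: each case wraps the raw device pointers
    // in typed DeviceMemory views and launches trsm with alpha = 1 (a plain
    // solve, with no additional scaling of B).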
    switch (type_) {
      case F32: {
        se::DeviceMemory<float> b_data_typed(b_data);
        launch_ok = stream
                        ->ThenBlasTrsm(side_, uplo_, transpose_a_,
                                       unit_diagonal_, m_, n_, /*alpha=*/1.0f,
                                       se::DeviceMemory<float>(a_data), lda,
                                       &b_data_typed, ldb)
                        .ok();
        break;
      }
      case F64: {
        se::DeviceMemory<double> b_data_typed(b_data);
        launch_ok = stream
                        ->ThenBlasTrsm(side_, uplo_, transpose_a_,
                                       unit_diagonal_, m_, n_, /*alpha=*/1.0,
                                       se::DeviceMemory<double>(a_data), lda,
                                       &b_data_typed, ldb)
                        .ok();
        break;
      }
      case C64: {
        se::DeviceMemory<std::complex<float>> b_data_typed(b_data);
        launch_ok =
            stream
                ->ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
                               n_, /*alpha=*/1.0f,
                               se::DeviceMemory<std::complex<float>>(a_data),
                               lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      case C128: {
        se::DeviceMemory<std::complex<double>> b_data_typed(b_data);
        launch_ok =
            stream
                ->ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
                               n_, /*alpha=*/1.0,
                               se::DeviceMemory<std::complex<double>>(a_data),
                               lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      default:
        return InvalidArgument("Invalid type for triangular solve %d", type_);
    }
    if (!launch_ok) {
      return InternalError("Unable to launch triangular solve for thunk %p",
                           this);
    }
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla