/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"

#include <string>

#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/device_memory.h"

namespace xla {
namespace gpu {
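// Translates the HLO TriangularSolveOptions into the corresponding
// StreamExecutor BLAS enums up front; the solve itself is issued later in
// ExecuteOnStream.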
TriangularSolveThunk::TriangularSolveThunk(
    const TriangularSolveOptions& options,
    const BufferAllocation::Slice& a_buffer,
    const BufferAllocation::Slice& b_buffer, PrimitiveType type,
    int64 batch_size, int64 m, int64 n, int64 a_batch_stride,
    int64 b_batch_stride, const HloInstruction* hlo)
    : Thunk(Kind::kTriangularSolve, hlo),
      uplo_(options.lower() ? se::blas::UpperLower::kLower
                            : se::blas::UpperLower::kUpper),
      side_(options.left_side() ? se::blas::Side::kLeft
                                : se::blas::Side::kRight),
      unit_diagonal_(options.unit_diagonal() ? se::blas::Diagonal::kUnit
                                             : se::blas::Diagonal::kNonUnit),
      a_buffer_(a_buffer),
      b_buffer_(b_buffer),
      type_(type),
      batch_size_(batch_size),
      m_(m),
      n_(n),
      a_batch_stride_(a_batch_stride),
      b_batch_stride_(b_batch_stride) {
  transpose_a_ = [&] {
    switch (options.transpose_a()) {
      case TriangularSolveOptions::NO_TRANSPOSE:
        return se::blas::Transpose::kNoTranspose;
      case TriangularSolveOptions::TRANSPOSE:
        return se::blas::Transpose::kTranspose;
      case TriangularSolveOptions::ADJOINT:
        return se::blas::Transpose::kConjugateTranspose;
      default:
        LOG(ERROR) << "Invalid triangular solve transpose value "
                   << options.transpose_a();
        return se::blas::Transpose::kNoTranspose;
    }
  }();
}

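// Solves the batched triangular system by issuing one BLAS Trsm call per
// batch element on the given stream.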
Status TriangularSolveThunk::ExecuteOnStream(
    const BufferAllocations& buffer_allocations, se::Stream* stream,
    HloExecutionProfiler* profiler) {
  VLOG(3) << "uplo=" << se::blas::UpperLowerString(uplo_)
          << " side=" << se::blas::SideString(side_)
          << " diagonal=" << se::blas::DiagonalString(unit_diagonal_)
          << " batch_size=" << batch_size_ << " m=" << m_ << " n=" << n_
          << " a_batch_stride=" << a_batch_stride_
          << " b_batch_stride=" << b_batch_stride_;

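  // BLAS Trsm operates on column-major matrices. A is m x m when applied
  // from the left and n x n when applied from the right, so its leading
  // dimension follows the side; B is always m x n.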
  const int lda = side_ == se::blas::Side::kLeft ? m_ : n_;
  const int ldb = m_;

  char* a_base = static_cast<char*>(
      buffer_allocations.GetDeviceAddress(a_buffer_).opaque());
  char* b_base = static_cast<char*>(
      buffer_allocations.GetDeviceAddress(b_buffer_).opaque());
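  // The batch strides are byte offsets between consecutive matrices, so each
  // iteration wraps a fresh DeviceMemoryBase view over the i-th slice of the
  // A and B buffers.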
  for (int64 i = 0; i < batch_size_; ++i) {
    bool launch_ok;
    se::DeviceMemoryBase a_data =
        se::DeviceMemoryBase(a_base + i * a_batch_stride_, a_batch_stride_);
    se::DeviceMemoryBase b_data =
        se::DeviceMemoryBase(b_base + i * b_batch_stride_, b_batch_stride_);
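    // Dispatch on the element type; alpha is fixed at 1 so the solve is
    // unscaled, and B is overwritten in place with the solution.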
    switch (type_) {
      case F32: {
        se::DeviceMemory<float> b_data_typed(b_data);
        launch_ok = stream
                        ->ThenBlasTrsm(side_, uplo_, transpose_a_,
                                       unit_diagonal_, m_, n_, /*alpha=*/1.0f,
                                       se::DeviceMemory<float>(a_data), lda,
                                       &b_data_typed, ldb)
                        .ok();
        break;
      }
      case F64: {
        se::DeviceMemory<double> b_data_typed(b_data);
        launch_ok = stream
                        ->ThenBlasTrsm(side_, uplo_, transpose_a_,
                                       unit_diagonal_, m_, n_, /*alpha=*/1.0,
                                       se::DeviceMemory<double>(a_data), lda,
                                       &b_data_typed, ldb)
                        .ok();
        break;
      }
      case C64: {
        se::DeviceMemory<std::complex<float>> b_data_typed(b_data);
        launch_ok =
            stream
                ->ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
                               n_, /*alpha=*/1.0f,
                               se::DeviceMemory<std::complex<float>>(a_data),
                               lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      case C128: {
        se::DeviceMemory<std::complex<double>> b_data_typed(b_data);
        launch_ok =
            stream
                ->ThenBlasTrsm(side_, uplo_, transpose_a_, unit_diagonal_, m_,
                               n_, /*alpha=*/1.0,
                               se::DeviceMemory<std::complex<double>>(a_data),
                               lda, &b_data_typed, ldb)
                .ok();
        break;
      }
      default:
        return InvalidArgument("Invalid type for triangular solve %d", type_);
    }
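    // A false return from ThenBlasTrsm means the BLAS call could not be
    // enqueued on the stream; surface it as an internal error.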
    if (!launch_ok) {
      return InternalError("Unable to launch triangular solve for thunk %p",
                           this);
    }
  }
  return Status::OK();
}

}  // namespace gpu
}  // namespace xla