1 //===-- lib/Evaluate/complex.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "flang/Evaluate/complex.h"
10 #include "llvm/Support/raw_ostream.h"
11
12 namespace Fortran::evaluate::value {
13
14 template <typename R>
Add(const Complex & that,Rounding rounding) const15 ValueWithRealFlags<Complex<R>> Complex<R>::Add(
16 const Complex &that, Rounding rounding) const {
17 RealFlags flags;
18 Part reSum{re_.Add(that.re_, rounding).AccumulateFlags(flags)};
19 Part imSum{im_.Add(that.im_, rounding).AccumulateFlags(flags)};
20 return {Complex{reSum, imSum}, flags};
21 }
22
23 template <typename R>
Subtract(const Complex & that,Rounding rounding) const24 ValueWithRealFlags<Complex<R>> Complex<R>::Subtract(
25 const Complex &that, Rounding rounding) const {
26 RealFlags flags;
27 Part reDiff{re_.Subtract(that.re_, rounding).AccumulateFlags(flags)};
28 Part imDiff{im_.Subtract(that.im_, rounding).AccumulateFlags(flags)};
29 return {Complex{reDiff, imDiff}, flags};
30 }
31
32 template <typename R>
Multiply(const Complex & that,Rounding rounding) const33 ValueWithRealFlags<Complex<R>> Complex<R>::Multiply(
34 const Complex &that, Rounding rounding) const {
35 // (a + ib)*(c + id) -> ac - bd + i(ad + bc)
36 RealFlags flags;
37 Part ac{re_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
38 Part bd{im_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
39 Part ad{re_.Multiply(that.im_, rounding).AccumulateFlags(flags)};
40 Part bc{im_.Multiply(that.re_, rounding).AccumulateFlags(flags)};
41 Part acbd{ac.Subtract(bd, rounding).AccumulateFlags(flags)};
42 Part adbc{ad.Add(bc, rounding).AccumulateFlags(flags)};
43 return {Complex{acbd, adbc}, flags};
44 }
45
46 template <typename R>
Divide(const Complex & that,Rounding rounding) const47 ValueWithRealFlags<Complex<R>> Complex<R>::Divide(
48 const Complex &that, Rounding rounding) const {
49 // (a + ib)/(c + id) -> [(a+ib)*(c-id)] / [(c+id)*(c-id)]
50 // -> [ac+bd+i(bc-ad)] / (cc+dd)
51 // -> ((ac+bd)/(cc+dd)) + i((bc-ad)/(cc+dd))
52 // but to avoid overflows, scale by d/c if c>=d, else c/d
53 Part scale; // <= 1.0
54 RealFlags flags;
55 bool cGEd{that.re_.ABS().Compare(that.im_.ABS()) != Relation::Less};
56 if (cGEd) {
57 scale = that.im_.Divide(that.re_, rounding).AccumulateFlags(flags);
58 } else {
59 scale = that.re_.Divide(that.im_, rounding).AccumulateFlags(flags);
60 }
61 Part den;
62 if (cGEd) {
63 Part dS{scale.Multiply(that.im_, rounding).AccumulateFlags(flags)};
64 den = dS.Add(that.re_, rounding).AccumulateFlags(flags);
65 } else {
66 Part cS{scale.Multiply(that.re_, rounding).AccumulateFlags(flags)};
67 den = cS.Add(that.im_, rounding).AccumulateFlags(flags);
68 }
69 Part aS{scale.Multiply(re_, rounding).AccumulateFlags(flags)};
70 Part bS{scale.Multiply(im_, rounding).AccumulateFlags(flags)};
71 Part re1, im1;
72 if (cGEd) {
73 re1 = re_.Add(bS, rounding).AccumulateFlags(flags);
74 im1 = im_.Subtract(aS, rounding).AccumulateFlags(flags);
75 } else {
76 re1 = aS.Add(im_, rounding).AccumulateFlags(flags);
77 im1 = bS.Subtract(re_, rounding).AccumulateFlags(flags);
78 }
79 Part re{re1.Divide(den, rounding).AccumulateFlags(flags)};
80 Part im{im1.Divide(den, rounding).AccumulateFlags(flags)};
81 return {Complex{re, im}, flags};
82 }
83
DumpHexadecimal() const84 template <typename R> std::string Complex<R>::DumpHexadecimal() const {
85 std::string result{'('};
86 result += re_.DumpHexadecimal();
87 result += ',';
88 result += im_.DumpHexadecimal();
89 result += ')';
90 return result;
91 }
92
93 template <typename R>
AsFortran(llvm::raw_ostream & o,int kind) const94 llvm::raw_ostream &Complex<R>::AsFortran(llvm::raw_ostream &o, int kind) const {
95 re_.AsFortran(o << '(', kind);
96 im_.AsFortran(o << ',', kind);
97 return o << ')';
98 }
99
100 template class Complex<Real<Integer<16>, 11>>;
101 template class Complex<Real<Integer<16>, 8>>;
102 template class Complex<Real<Integer<32>, 24>>;
103 template class Complex<Real<Integer<64>, 53>>;
104 template class Complex<Real<Integer<80>, 64>>;
105 template class Complex<Real<Integer<128>, 113>>;
106 } // namespace Fortran::evaluate::value
107