1 //===-- X86InstrFMA3Info.cpp - X86 FMA3 Instruction Information -----------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file contains the implementation of the classes providing information
11 // about existing X86 FMA3 opcodes, classifying and grouping them.
12 //
13 //===----------------------------------------------------------------------===//
14 
#include "X86InstrFMA3Info.h"
#include "X86InstrInfo.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/Threading.h"
#include <algorithm>
#include <atomic>
#include <cassert>
#include <cstdint>
#include <iterator>
21 
22 using namespace llvm;
23 
// The macros below expand to rows of the FMA3 group tables. Each row holds the
// 132, 213 and 231 opcodes of one FMA3 operation plus its attribute bits.
// Expansion order is significant: the tables are binary-searched by opcode and
// asserted to be sorted (see verifyTables()), so reordering the lines inside a
// macro reorders the table rows.

// One row: the 132/213/231 opcodes X86::<Name><form><Suf> and Attrs.
#define FMA3GROUP(Name, Suf, Attrs) \
  { { X86::Name##132##Suf, X86::Name##213##Suf, X86::Name##231##Suf }, Attrs },

// Three rows: the unmasked opcodes, the "k" variant (flagged KMergeMasked) and
// the "kz" variant (flagged KZeroMasked).
#define FMA3GROUP_MASKED(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##k, Attrs | X86InstrFMA3Group::KMergeMasked) \
  FMA3GROUP(Name, Suf##kz, Attrs | X86InstrFMA3Group::KZeroMasked)

// All packed forms of one element type (Suf is PD or PS): the Y-suffixed and
// unsuffixed register/memory forms, plus masked Z128, Z256 and Z forms.
#define FMA3GROUP_PACKED_WIDTHS(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##Ym, Attrs) \
  FMA3GROUP(Name, Suf##Yr, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z128m, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z128r, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z256m, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Z256r, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zm, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zr, Attrs) \
  FMA3GROUP(Name, Suf##m, Attrs) \
  FMA3GROUP(Name, Suf##r, Attrs)

// Packed forms for both PD and PS.
#define FMA3GROUP_PACKED(Name, Attrs) \
  FMA3GROUP_PACKED_WIDTHS(Name, PD, Attrs) \
  FMA3GROUP_PACKED_WIDTHS(Name, PS, Attrs)

// All scalar forms of one element type (Suf is SD or SS). Plain forms carry
// Attrs; the _Int forms additionally carry Intrinsic, and the Z-suffixed _Int
// forms also get the masked k/kz variants.
#define FMA3GROUP_SCALAR_WIDTHS(Name, Suf, Attrs) \
  FMA3GROUP(Name, Suf##Zm, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zm_Int, Attrs | X86InstrFMA3Group::Intrinsic) \
  FMA3GROUP(Name, Suf##Zr, Attrs) \
  FMA3GROUP_MASKED(Name, Suf##Zr_Int, Attrs | X86InstrFMA3Group::Intrinsic) \
  FMA3GROUP(Name, Suf##m, Attrs) \
  FMA3GROUP(Name, Suf##m_Int, Attrs | X86InstrFMA3Group::Intrinsic) \
  FMA3GROUP(Name, Suf##r, Attrs) \
  FMA3GROUP(Name, Suf##r_Int, Attrs | X86InstrFMA3Group::Intrinsic)

// Scalar forms for both SD and SS. (No trailing line continuation here: a
// stray '\' would splice the following line into this macro.)
#define FMA3GROUP_SCALAR(Name, Attrs) \
  FMA3GROUP_SCALAR_WIDTHS(Name, SD, Attrs) \
  FMA3GROUP_SCALAR_WIDTHS(Name, SS, Attrs)

// Packed and scalar forms of one operation.
#define FMA3GROUP_FULL(Name, Attrs) \
  FMA3GROUP_PACKED(Name, Attrs) \
  FMA3GROUP_SCALAR(Name, Attrs)
65 
/// FMA3 groups for instructions flagged neither EVEX_RC nor EVEX_B; this is the
/// table getFMA3Group() searches in that case. Each row holds the 132/213/231
/// opcodes of one operation. getFMA3Group() uses std::lower_bound on the
/// column matching the queried form, so the rows must stay sorted by opcode
/// (verifyTables() asserts this in debug builds). Do not reorder these lines.
static const X86InstrFMA3Group Groups[] = {
  FMA3GROUP_FULL(VFMADD, 0)
  // VFMADDSUB / VFMSUBADD are listed with packed forms only.
  FMA3GROUP_PACKED(VFMADDSUB, 0)
  FMA3GROUP_FULL(VFMSUB, 0)
  FMA3GROUP_PACKED(VFMSUBADD, 0)
  FMA3GROUP_FULL(VFNMADD, 0)
  FMA3GROUP_FULL(VFNMSUB, 0)
};
74 
// Row generators for the AVX-512-only tables (broadcast and rounding forms).
// They build on FMA3GROUP / FMA3GROUP_MASKED. As with every table macro, the
// order of the lines is the order of the emitted rows, and the tables must
// stay sorted by opcode.

// Masked rows for one element type Elt (PD or PS) at the Z128, Z256 and Z
// widths, each with the opcode suffix Form appended (e.g. "mb").
#define FMA3GROUP_PACKED_AVX512_WIDTHS(Op, Elt, Form, Flags) \
  FMA3GROUP_MASKED(Op, Elt##Z128##Form, Flags) \
  FMA3GROUP_MASKED(Op, Elt##Z256##Form, Flags) \
  FMA3GROUP_MASKED(Op, Elt##Z##Form, Flags)

// The PD rows followed by the PS rows.
#define FMA3GROUP_PACKED_AVX512(Op, Form, Flags) \
  FMA3GROUP_PACKED_AVX512_WIDTHS(Op, PD, Form, Flags) \
  FMA3GROUP_PACKED_AVX512_WIDTHS(Op, PS, Form, Flags)

// Masked Z-width rows only (no Z128/Z256), for PD and then PS.
#define FMA3GROUP_PACKED_AVX512_ROUND(Op, Form, Flags) \
  FMA3GROUP_MASKED(Op, PDZ##Form, Flags) \
  FMA3GROUP_MASKED(Op, PSZ##Form, Flags)

// Scalar rows for SD and then SS. Each has a plain row plus masked _Int rows,
// and all of them receive the same Flags.
#define FMA3GROUP_SCALAR_AVX512_ROUND(Op, Form, Flags) \
  FMA3GROUP(Op, SDZ##Form, Flags) \
  FMA3GROUP_MASKED(Op, SDZ##Form##_Int, Flags) \
  FMA3GROUP(Op, SSZ##Form, Flags) \
  FMA3GROUP_MASKED(Op, SSZ##Form##_Int, Flags)
93 
/// FMA3 groups for the "mb"-suffixed packed AVX-512 opcodes. getFMA3Group()
/// searches this table when TSFlags has EVEX_B set and EVEX_RC clear. Rows
/// are binary-searched per form column and must stay sorted by opcode
/// (asserted by verifyTables() in debug builds). Do not reorder these lines.
static const X86InstrFMA3Group BroadcastGroups[] = {
  FMA3GROUP_PACKED_AVX512(VFMADD, mb, 0)
  FMA3GROUP_PACKED_AVX512(VFMADDSUB, mb, 0)
  FMA3GROUP_PACKED_AVX512(VFMSUB, mb, 0)
  FMA3GROUP_PACKED_AVX512(VFMSUBADD, mb, 0)
  FMA3GROUP_PACKED_AVX512(VFNMADD, mb, 0)
  FMA3GROUP_PACKED_AVX512(VFNMSUB, mb, 0)
};
102 
/// FMA3 groups for the "rb"-suffixed AVX-512 opcodes (packed and scalar).
/// getFMA3Group() searches this table whenever TSFlags has EVEX_RC set. Rows
/// are binary-searched per form column and must stay sorted by opcode
/// (asserted by verifyTables() in debug builds). Do not reorder these lines.
///
/// NOTE(review): the scalar rows pass Intrinsic as Attrs, and
/// FMA3GROUP_SCALAR_AVX512_ROUND applies Attrs to the plain SDZrb/SSZrb rows
/// as well as the _Int rows. FMA3GROUP_SCALAR_WIDTHS, by contrast, flags only
/// the _Int rows as Intrinsic. Confirm whether the non-_Int rounding forms
/// really need the Intrinsic flag; it may be overly conservative.
static const X86InstrFMA3Group RoundGroups[] = {
  FMA3GROUP_PACKED_AVX512_ROUND(VFMADD, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFMADD, rb, X86InstrFMA3Group::Intrinsic)
  FMA3GROUP_PACKED_AVX512_ROUND(VFMADDSUB, rb, 0)
  FMA3GROUP_PACKED_AVX512_ROUND(VFMSUB, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFMSUB, rb, X86InstrFMA3Group::Intrinsic)
  FMA3GROUP_PACKED_AVX512_ROUND(VFMSUBADD, rb, 0)
  FMA3GROUP_PACKED_AVX512_ROUND(VFNMADD, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFNMADD, rb, X86InstrFMA3Group::Intrinsic)
  FMA3GROUP_PACKED_AVX512_ROUND(VFNMSUB, rb, 0)
  FMA3GROUP_SCALAR_AVX512_ROUND(VFNMSUB, rb, X86InstrFMA3Group::Intrinsic)
};
115 
verifyTables()116 static void verifyTables() {
117 #ifndef NDEBUG
118   static std::atomic<bool> TableChecked(false);
119   if (!TableChecked.load(std::memory_order_relaxed)) {
120     assert(std::is_sorted(std::begin(Groups), std::end(Groups)) &&
121            std::is_sorted(std::begin(RoundGroups), std::end(RoundGroups)) &&
122            std::is_sorted(std::begin(BroadcastGroups),
123                           std::end(BroadcastGroups)) &&
124            "FMA3 tables not sorted!");
125     TableChecked.store(true, std::memory_order_relaxed);
126   }
127 #endif
128 }
129 
130 /// Returns a reference to a group of FMA3 opcodes to where the given
131 /// \p Opcode is included. If the given \p Opcode is not recognized as FMA3
132 /// and not included into any FMA3 group, then nullptr is returned.
getFMA3Group(unsigned Opcode,uint64_t TSFlags)133 const X86InstrFMA3Group *llvm::getFMA3Group(unsigned Opcode, uint64_t TSFlags) {
134 
135   // FMA3 instructions have a well defined encoding pattern we can exploit.
136   uint8_t BaseOpcode = X86II::getBaseOpcodeFor(TSFlags);
137   bool IsFMA3 = ((TSFlags & X86II::EncodingMask) == X86II::VEX ||
138                  (TSFlags & X86II::EncodingMask) == X86II::EVEX) &&
139                 (TSFlags & X86II::OpMapMask) == X86II::T8 &&
140                 (TSFlags & X86II::OpPrefixMask) == X86II::PD &&
141                 ((BaseOpcode >= 0x96 && BaseOpcode <= 0x9F) ||
142                  (BaseOpcode >= 0xA6 && BaseOpcode <= 0xAF) ||
143                  (BaseOpcode >= 0xB6 && BaseOpcode <= 0xBF));
144   if (!IsFMA3)
145     return nullptr;
146 
147   verifyTables();
148 
149   ArrayRef<X86InstrFMA3Group> Table;
150   if (TSFlags & X86II::EVEX_RC)
151     Table = makeArrayRef(RoundGroups);
152   else if (TSFlags & X86II::EVEX_B)
153     Table = makeArrayRef(BroadcastGroups);
154   else
155     Table = makeArrayRef(Groups);
156 
157   // FMA 132 instructions have an opcode of 0x96-0x9F
158   // FMA 213 instructions have an opcode of 0xA6-0xAF
159   // FMA 231 instructions have an opcode of 0xB6-0xBF
160   unsigned FormIndex = ((BaseOpcode - 0x90) >> 4) & 0x3;
161 
162   auto I = std::lower_bound(Table.begin(), Table.end(), Opcode,
163                             [FormIndex](const X86InstrFMA3Group &Group,
164                                         unsigned Opcode) {
165                               return Group.Opcodes[FormIndex] < Opcode;
166                             });
167   assert(I != Table.end() && I->Opcodes[FormIndex] == Opcode &&
168          "Couldn't find FMA3 opcode!");
169   return I;
170 }
171