LLVMIRConversionGen.cpp
11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
//===- LLVMIRConversionGen.cpp - MLIR LLVM IR builder generator -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file uses tablegen definitions of the LLVM IR Dialect operations to
// generate the code building the LLVM IR from it.
//
//===----------------------------------------------------------------------===//
#include "mlir/Support/LogicalResult.h"
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/TableGen/Record.h"
#include "llvm/TableGen/TableGenBackend.h"
using namespace llvm;
using namespace mlir;
static bool emitError(const Twine &message) {
llvm::errs() << message << "\n";
return false;
}
namespace {
// Helper structure describing where a substring is located inside a string,
// as a starting position and a length. It is kept as a plain aggregate so it
// can be brace-initialized as `{pos, length}`.
struct StringLoc {
  size_t pos;
  size_t length;

  // Take a substring identified by this location in the given string.
  StringRef in(StringRef str) const { return str.substr(pos, length); }

  // A location is valid unless its position is `std::string::npos`. The
  // conversion is const so that it can also be evaluated on const locations.
  explicit operator bool() const { return pos != std::string::npos; }
};
} // namespace
// Locate the first TableGen variable in `str`. A variable is a `$` character
// followed by any number of alphanumeric characters or underscores. The
// returned location covers the variable including its leading `$`. The escape
// sequence `$$` is reported as a location of length 2. If `str` contains no
// `$`, the returned location has position `std::string::npos` and length 0.
static StringLoc findNextVariable(StringRef str) {
  const size_t dollar = str.find('$');
  if (dollar == std::string::npos)
    return {dollar, 0};

  // A second `$` directly after the first one forms the escape sequence.
  const size_t next = dollar + 1;
  if (next != str.size() && str[next] == '$')
    return {dollar, 2};

  // Otherwise the name extends up to the first character that cannot be part
  // of an identifier, or to the end of the string if there is none.
  auto isIdentifierChar = [](char c) { return isAlnum(c) || c == '_'; };
  const size_t stop = str.find_if_not(isIdentifierChar, next);
  const size_t end = stop == std::string::npos ? str.size() : stop;
  return {dollar, end - dollar};
}
// Return true if `name` names the variadic operand of `op`. Only the last
// operand of `op` is inspected: it must be of variable length and carry the
// given name. An op without operands never matches.
static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
  const unsigned operandCount = op.getNumOperands();
  if (operandCount == 0)
    return false;

  const auto &lastOperand = op.getOperand(operandCount - 1);
  if (!lastOperand.isVariableLength())
    return false;
  return lastOperand.name == name;
}
// Return true if `name` is the name of one of the results of `op`.
static bool isResultName(const tblgen::Operator &op, StringRef name) {
  const int resultCount = op.getNumResults();
  int index = 0;
  // Advance until a result with a matching name is found or all are visited.
  while (index < resultCount && !(op.getResultName(index) == name))
    ++index;
  return index < resultCount;
}
// Return true if `name` is the name of one of the attributes of `op`.
static bool isAttributeName(const tblgen::Operator &op, StringRef name) {
  for (const tblgen::NamedAttribute &attr : op.getAttributes()) {
    if (attr.name == name)
      return true;
  }
  return false;
}
// Return true if `name` is the name of one of the operands of `op`.
static bool isOperandName(const tblgen::Operator &op, StringRef name) {
  const int operandCount = op.getNumOperands();
  for (int index = 0; index < operandCount; ++index) {
    const bool matches = op.getOperand(index).name == name;
    if (matches)
      return true;
  }
  return false;
}
// Emit to `os` the operator-name driven check and the call to LLVM IRBuilder
// for one definition of a LLVM IR Dialect operation. Return true on success.
static bool emitOneBuilder(const Record &record, raw_ostream &os) {
auto op = tblgen::Operator(record);
if (!record.getValue("llvmBuilder"))
return emitError("no 'llvmBuilder' field for op " + op.getOperationName());
// Return early if there is no builder specified.
auto builderStrRef = record.getValueAsString("llvmBuilder");
if (builderStrRef.empty())
return true;
// Progressively create the builder string by replacing $-variables with
// value lookups. Keep only the not-yet-traversed part of the builder pattern
// to avoid re-traversing the string multiple times.
std::string builder;
llvm::raw_string_ostream bs(builder);
while (auto loc = findNextVariable(builderStrRef)) {
auto name = loc.in(builderStrRef).drop_front();
// First, insert the non-matched part as is.
bs << builderStrRef.substr(0, loc.pos);
// Then, rewrite the name based on its kind.
bool isVariadicOperand = isVariadicOperandName(op, name);
if (isOperandName(op, name)) {
auto result = isVariadicOperand
? formatv("lookupValues(op.{0}())", name)
: formatv("valueMapping.lookup(op.{0}())", name);
bs << result;
} else if (isAttributeName(op, name)) {
bs << formatv("op.{0}()", name);
} else if (isResultName(op, name)) {
bs << formatv("valueMapping[op.{0}()]", name);
} else if (name == "_resultType") {
bs << "op.getResult().getType().cast<LLVM::LLVMType>()."
"getUnderlyingType()";
} else if (name == "_hasResult") {
bs << "opInst.getNumResults() == 1";
} else if (name == "_location") {
bs << "opInst.getLoc()";
} else if (name == "_numOperands") {
bs << "opInst.getNumOperands()";
} else if (name == "$") {
bs << '$';
} else {
return emitError(name + " is neither an argument nor a result of " +
op.getOperationName());
}
// Finally, only keep the untraversed part of the string.
builderStrRef = builderStrRef.substr(loc.pos + loc.length);
}
// Output the check and the rewritten builder string.
os << "if (auto op = dyn_cast<" << op.getQualCppClassName()
<< ">(opInst)) {\n";
os << bs.str() << builderStrRef << "\n";
os << " return success();\n";
os << "}\n";
return true;
}
// Emit all builders. Returns false on success because of the generator
// registration requirements.
static bool emitBuilders(const RecordKeeper &recordKeeper, raw_ostream &os) {
for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_OpBase")) {
if (!emitOneBuilder(*def, os))
return true;
}
return false;
}
namespace {
// Wrapper class around a Tablegen definition of an LLVM enum attribute case.
// Adds access to the LLVM-specific `llvmEnumerant` field.
class LLVMEnumAttrCase : public tblgen::EnumAttrCase {
public:
  using tblgen::EnumAttrCase::EnumAttrCase;

  // Constructs a case from a non LLVM-specific enum attribute case by
  // wrapping the same underlying definition.
  explicit LLVMEnumAttrCase(const tblgen::EnumAttrCase &other)
      : tblgen::EnumAttrCase(&other.getDef()) {}

  // Returns the C++ enumerant for the LLVM API, read from the
  // `llvmEnumerant` field of the definition.
  StringRef getLLVMEnumerant() const {
    return def->getValueAsString("llvmEnumerant");
  }
};

// Wrapper class around a Tablegen definition of an LLVM enum attribute. Adds
// access to the LLVM-specific `llvmClassName` field.
class LLVMEnumAttr : public tblgen::EnumAttr {
public:
  using tblgen::EnumAttr::EnumAttr;

  // Returns the C++ enum name for the LLVM API, read from the
  // `llvmClassName` field of the definition.
  StringRef getLLVMClassName() const {
    return def->getValueAsString("llvmClassName");
  }

  // Returns all associated cases viewed as LLVM-specific enum cases, in the
  // order reported by the base class.
  std::vector<LLVMEnumAttrCase> getAllCases() const {
    const auto genericCases = tblgen::EnumAttr::getAllCases();
    std::vector<LLVMEnumAttrCase> llvmCases;
    llvmCases.reserve(genericCases.size());
    for (const auto &genericCase : genericCases)
      llvmCases.emplace_back(genericCase);
    return llvmCases;
  }
};
} // namespace
// Emits to `os` the static function "LLVMClass convert<Enum>ToLLVM(Enum)" for
// the enum attribute described by `record`. The generated body is a switch
// with one case per enumerant that returns the corresponding LLVM API
// enumerant, followed by an `llvm_unreachable` for unknown values.
static void emitOneEnumToConversion(const llvm::Record *record,
                                    raw_ostream &os) {
  LLVMEnumAttr attr(record);
  const StringRef llvmName = attr.getLLVMClassName();
  const StringRef mlirName = attr.getEnumClassName();
  const StringRef mlirNamespace = attr.getCppNamespace();

  // Function header and start of the switch.
  os << formatv("static {0} convert{1}ToLLVM({2}::{1} value) {{\n", llvmName,
                mlirName, mlirNamespace);
  os << " switch (value) {\n";

  // One case per enumerant, mapping the MLIR symbol to the LLVM enumerant.
  for (const auto &enumCase : attr.getAllCases()) {
    const StringRef mlirSymbol = enumCase.getSymbol();
    const StringRef llvmSymbol = enumCase.getLLVMEnumerant();
    os << formatv(" case {0}::{1}::{2}:\n", mlirNamespace, mlirName,
                  mlirSymbol);
    os << formatv(" return {0}::{1};\n", llvmName, llvmSymbol);
  }

  // Close the switch and the function.
  os << " }\n";
  os << formatv(" llvm_unreachable(\"unknown {0} type\");\n", mlirName);
  os << "}\n\n";
}
// Emits to `os` the static function "Enum convert<Enum>FromLLVM(LLVMClass)"
// for the enum attribute described by `record`. The generated body is a switch
// with one case per LLVM API enumerant that returns the corresponding MLIR
// LLVM dialect enum attribute case, followed by an `llvm_unreachable` for
// unknown values.
static void emitOneEnumFromConversion(const llvm::Record *record,
                                      raw_ostream &os) {
  LLVMEnumAttr enumAttr(record);
  StringRef llvmClass = enumAttr.getLLVMClassName();
  StringRef cppClassName = enumAttr.getEnumClassName();
  StringRef cppNamespace = enumAttr.getCppNamespace();

  // Emit the function converting the enum attribute from its LLVM counterpart.
  os << formatv("static {0}::{1} convert{1}FromLLVM({2} value) {{\n",
                cppNamespace, cppClassName, llvmClass);
  os << " switch (value) {\n";

  for (const auto &enumerant : enumAttr.getAllCases()) {
    StringRef llvmEnumerant = enumerant.getLLVMEnumerant();
    StringRef cppEnumerant = enumerant.getSymbol();
    os << formatv(" case {0}::{1}:\n", llvmClass, llvmEnumerant);
    os << formatv(" return {0}::{1}::{2};\n", cppNamespace, cppClassName,
                  cppEnumerant);
  }

  os << " }\n";
  // Terminate the `llvm_unreachable` line with a newline so that the closing
  // brace of the generated function is emitted on its own line, as in the
  // "ToLLVM" conversion.
  os << formatv(" llvm_unreachable(\"unknown {0} type\");\n",
                enumAttr.getLLVMClassName());
  os << "}\n\n";
}
// Emits conversion functions between MLIR enum attribute case and corresponding
// LLVM API enumerants for all registered LLVM dialect enum attributes.
template <bool ConvertTo>
static bool emitEnumConversionDefs(const RecordKeeper &recordKeeper,
raw_ostream &os) {
for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_EnumAttr"))
if (ConvertTo)
emitOneEnumToConversion(def, os);
else
emitOneEnumFromConversion(def, os);
return false;
}
// Registration of the generators defined in this file. Each registration
// associates a generator name and a description with the function that
// produces the output.

// Emits the LLVM IR builder code for the LLVM dialect operations.
static mlir::GenRegistration
genLLVMIRConversions("gen-llvmir-conversions",
"Generate LLVM IR conversions", emitBuilders);
// Emits the functions converting MLIR enum attributes to LLVM API enumerants.
static mlir::GenRegistration
genEnumToLLVMConversion("gen-enum-to-llvmir-conversions",
"Generate conversions of EnumAttrs to LLVM IR",
emitEnumConversionDefs</*ConvertTo=*/true>);
// Emits the functions converting LLVM API enumerants to MLIR enum attributes.
static mlir::GenRegistration
genEnumFromLLVMConversion("gen-enum-from-llvmir-conversions",
"Generate conversions of EnumAttrs from LLVM IR",
emitEnumConversionDefs</*ConvertTo=*/false>);