PassGen.cpp
6.23 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
//===- PassGen.cpp - MLIR pass registration generator ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// PassGen uses the description of passes to generate base classes for passes
// and command line registration.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/GenInfo.h"
#include "mlir/TableGen/Pass.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// GEN: Pass base class generation
//===----------------------------------------------------------------------===//
/// The code snippet used to generate the start of a pass base class. It is
/// expanded with llvm::formatv in emitPassDecl, which supplies the three
/// arguments below. The snippet ends inside the class, at the `protected:`
/// section; the caller is responsible for emitting the members and the closing
/// `};`.
///
/// {0}: The def name of the pass record.
/// {1}: The base class for the pass.
/// {2}: The command line argument for the pass.
///
/// Note: `{{` in the snippet is the formatv escape for a literal `{` in the
/// generated code.
const char *const passDeclBegin = R"(
//===----------------------------------------------------------------------===//
// {0}
//===----------------------------------------------------------------------===//
template <typename DerivedT>
class {0}Base : public {1} {
public:
{0}Base() : {1}(::mlir::TypeID::get<DerivedT>()) {{}
{0}Base(const {0}Base &) : {1}(::mlir::TypeID::get<DerivedT>()) {{}
/// Returns the command-line argument attached to this pass.
::llvm::StringRef getArgument() const override { return "{2}"; }
/// Returns the derived pass name.
::llvm::StringRef getName() const override { return "{0}"; }
/// Support isa/dyn_cast functionality for the derived pass class.
static bool classof(const ::mlir::Pass *pass) {{
return pass->getTypeID() == ::mlir::TypeID::get<DerivedT>();
}
/// A clone method to create a copy of this pass.
std::unique_ptr<::mlir::Pass> clonePass() const override {{
return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
}
protected:
)";
/// Emit the declarations for each of the pass options.
///
/// For every option of `pass`, one member declaration is written to `os`:
/// a `::mlir::Pass::Option` (or `::mlir::Pass::ListOption` for list options)
/// templated on the option type, brace-initialized with `*this`, the option's
/// command line argument and an `::llvm::cl::desc(...)` description. When
/// present, the default value is appended as `::llvm::cl::init(...)` and any
/// additional flags are appended verbatim.
static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
  for (const PassOption &opt : pass.getOptions()) {
    os.indent(2) << "::mlir::Pass::"
                 << (opt.isListOption() ? "ListOption" : "Option");

    os << llvm::formatv("<{0}> {1}{{*this, \"{2}\", ::llvm::cl::desc(\"{3}\")",
                        opt.getType(), opt.getCppVariableName(),
                        opt.getArgument(), opt.getDescription());

    // Stream the contained string, not the Optional wrapper, so this matches
    // the handling of the additional flags below and does not depend on a
    // stream operator for Optional.
    if (Optional<StringRef> defaultVal = opt.getDefaultValue())
      os << ", ::llvm::cl::init(" << *defaultVal << ")";
    if (Optional<StringRef> additionalFlags = opt.getAdditionalFlags())
      os << ", " << *additionalFlags;
    os << "};\n";
  }
}
/// Write one `::mlir::Pass::Statistic` member declaration to `os` for every
/// statistic attached to `pass`.
static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
  // {0}: C++ variable name, {1}: statistic name, {2}: statistic description.
  const char *const statisticDeclFmt =
      " ::mlir::Pass::Statistic {0}{{this, \"{1}\", \"{2}\"};\n";

  for (const PassStatistic &statistic : pass.getStatistics())
    os << llvm::formatv(statisticDeclFmt, statistic.getCppVariableName(),
                        statistic.getName(), statistic.getDescription());
}
static void emitPassDecl(const Pass &pass, raw_ostream &os) {
StringRef defName = pass.getDef()->getName();
os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
pass.getArgument());
emitPassOptionDecls(pass, os);
emitPassStatisticDecls(pass, os);
os << "};\n";
}
/// Emit the base class declarations for all of the given passes. The classes
/// are wrapped in an `#ifdef GEN_PASS_CLASSES` guard, and the macro is
/// undefined again just before the guard is closed.
static void emitPassDecls(ArrayRef<Pass> passes, raw_ostream &os) {
  os << "#ifdef GEN_PASS_CLASSES\n";

  for (const Pass &currentPass : passes) {
    emitPassDecl(currentPass, os);
  }

  os << "#undef GEN_PASS_CLASSES\n"
        "#endif // GEN_PASS_CLASSES\n";
}
//===----------------------------------------------------------------------===//
// GEN: Pass registration generation
//===----------------------------------------------------------------------===//
/// Emit the code for registering each of the given passes with the global
/// PassRegistry.
///
/// The output has three parts:
///  1. Under `GEN_PASS_REGISTRATION`, a `#define GEN_PASS_REGISTRATION_<Def>`
///     for every pass.
///  2. For every pass, a `::mlir::registerPass(...)` call guarded by
///     `GEN_PASS_REGISTRATION_<Def>`, followed by an `#undef` of that macro.
///  3. Under `GEN_PASS_REGISTRATION`, an `#undef` of every per-pass macro, and
///     finally an `#undef` of `GEN_PASS_REGISTRATION` itself.
static void emitRegistration(ArrayRef<Pass> passes, raw_ostream &os) {
  // Part 1: the umbrella macro switches on every per-pass macro.
  os << "#ifdef GEN_PASS_REGISTRATION\n";
  for (const Pass &pass : passes)
    os << "#define GEN_PASS_REGISTRATION_" << pass.getDef()->getName() << "\n";
  os << "#endif // GEN_PASS_REGISTRATION\n";

  // Part 2: one guarded registration call per pass.
  for (const Pass &pass : passes) {
    StringRef defName = pass.getDef()->getName();

    os << "#ifdef GEN_PASS_REGISTRATION_" << defName << "\n";
    os << llvm::formatv("::mlir::registerPass(\"{0}\", \"{1}\", []() -> "
                        "std::unique_ptr<::mlir::Pass> {{ return {2}; });\n",
                        pass.getArgument(), pass.getSummary(),
                        pass.getConstructor());
    os << "#endif // GEN_PASS_REGISTRATION_" << defName << "\n";
    os << "#undef GEN_PASS_REGISTRATION_" << defName << "\n";
  }

  // Part 3: clean up the per-pass macros and the umbrella macro.
  os << "#ifdef GEN_PASS_REGISTRATION\n";
  for (const Pass &pass : passes)
    os << "#undef GEN_PASS_REGISTRATION_" << pass.getDef()->getName() << "\n";
  os << "#endif // GEN_PASS_REGISTRATION\n";
  os << "#undef GEN_PASS_REGISTRATION\n";
}
//===----------------------------------------------------------------------===//
// GEN: Registration hooks
//===----------------------------------------------------------------------===//
/// Generator entry point: writes the autogenerated-file banner, wraps every
/// record derived from `PassBase` in a `Pass`, and then emits the pass base
/// class declarations followed by the registration code.
static void emitDecls(const llvm::RecordKeeper &recordKeeper, raw_ostream &os) {
  os << "/* Autogenerated by mlir-tblgen; don't manually edit */\n";

  // Gather the pass definitions in the order the record keeper returns them.
  std::vector<Pass> passes;
  for (const llvm::Record *record :
       recordKeeper.getAllDerivedDefinitions("PassBase"))
    passes.emplace_back(record);

  emitPassDecls(passes, os);
  emitRegistration(passes, os);
}
/// Registers the `gen-pass-decls` generator. The callback forwards the record
/// keeper and output stream to emitDecls and always returns `false`.
///
/// The description previously read "Generate operation documentation", which
/// did not match what this generator emits (pass base classes and pass
/// registration code).
static mlir::GenRegistration
    genRegister("gen-pass-decls", "Generate pass declarations",
                [](const llvm::RecordKeeper &records, raw_ostream &os) {
                  emitDecls(records, os);
                  return false;
                });