LowerToLLVM.cpp
9.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
//====- LowerToLLVM.cpp - Lowering from Toy+Affine+Std to LLVM ------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the lowering of the remaining Toy operations, together
// with the affine, loop, and standard operations, to the LLVM dialect. This
// lowering expects that all calls have been inlined, and all shapes have been
// resolved.
//
//===----------------------------------------------------------------------===//
#include "toy/Dialect.h"
#include "toy/Passes.h"
#include "mlir/Conversion/AffineToStandard/AffineToStandard.h"
#include "mlir/Conversion/LoopToStandard/ConvertLoopToStandard.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Dialect/AffineOps/AffineOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LoopOps/LoopOps.h"
#include "mlir/Dialect/StandardOps/Ops.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/Sequence.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// ToyToLLVM RewritePatterns
//===----------------------------------------------------------------------===//
namespace {
/// Conversion pattern for `toy.print`. The operation is replaced by a nest of
/// `loop.for` operations, one per dimension of the printed memref; the
/// innermost body loads a single element and hands it to `printf`.
class PrintOpLowering : public ConversionPattern {
public:
  /// Registers the pattern against `toy.print` with a benefit of 1.
  explicit PrintOpLowering(MLIRContext *context)
      : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {}

  /// Emits the loop nest and the `printf` calls for `op`, then erases `op`.
  /// Always reports a successful match. The shape of the first operand's
  /// memref type provides the loop bounds.
  PatternMatchResult
  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputType = (*op->operand_type_begin()).cast<MemRefType>();
    auto dims = inputType.getShape();
    auto loc = op->getLoc();
    auto *llvmDialect =
        op->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
    assert(llvmDialect && "expected llvm dialect to be registered");

    ModuleOp enclosingModule = op->getParentOfType<ModuleOp>();

    // Module-level helpers: the `printf` declaration and the two constant
    // strings ("%f " and "\n", each with an explicit trailing NUL).
    auto printfRef = getOrInsertPrintf(rewriter, enclosingModule, llvmDialect);
    Value formatSpecifierCst = getOrCreateGlobalString(
        loc, rewriter, "frmt_spec", StringRef("%f \0", 4), enclosingModule,
        llvmDialect);
    Value newLineCst = getOrCreateGlobalString(
        loc, rewriter, "nl", StringRef("\n\0", 2), enclosingModule,
        llvmDialect);

    // Build one loop per dimension, each nested inside the previous one.
    SmallVector<Value, 4> inductionVars;
    const unsigned rank = dims.size();
    for (unsigned dim = 0; dim < rank; ++dim) {
      auto zero = rewriter.create<ConstantIndexOp>(loc, 0);
      auto extent = rewriter.create<ConstantIndexOp>(loc, dims[dim]);
      auto one = rewriter.create<ConstantIndexOp>(loc, 1);
      auto forOp = rewriter.create<loop::ForOp>(loc, zero, extent, one);

      // Empty the freshly created body so it can be populated by hand
      // (presumably this drops the terminator the builder added).
      forOp.getBody()->clear();
      inductionVars.push_back(forOp.getInductionVar());

      // Fill the body back to front: every loop except the innermost one
      // prints a newline just before its terminator.
      rewriter.setInsertionPointToStart(forOp.getBody());
      const bool isInnermost = (dim + 1 == rank);
      if (!isInnermost)
        rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
                                newLineCst);
      rewriter.create<loop::TerminatorOp>(loc);

      // Whatever is emitted next (the next loop, or the element print) goes
      // at the top of this body, ahead of the newline and the terminator.
      rewriter.setInsertionPointToStart(forOp.getBody());
    }

    // Load the element addressed by the induction variables and print it.
    auto printOp = cast<toy::PrintOp>(op);
    auto element = rewriter.create<LoadOp>(loc, printOp.input(), inductionVars);
    rewriter.create<CallOp>(loc, printfRef, rewriter.getIntegerType(32),
                            ArrayRef<Value>({formatSpecifierCst, element}));

    // The original operation is now fully replaced; tell the rewriter.
    rewriter.eraseOp(op);
    return matchSuccess();
  }

private:
  /// Returns a symbol reference to `printf`. When `module` has no LLVM
  /// function of that name yet, a declaration with signature
  /// `i32 (i8*, ...)` is first inserted at the start of the module body.
  static FlatSymbolRefAttr getOrInsertPrintf(PatternRewriter &rewriter,
                                             ModuleOp module,
                                             LLVM::LLVMDialect *llvmDialect) {
    auto *context = module.getContext();
    if (!module.lookupSymbol<LLVM::LLVMFuncOp>("printf")) {
      auto i32Ty = LLVM::LLVMType::getInt32Ty(llvmDialect);
      auto i8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
      auto printfTy = LLVM::LLVMType::getFunctionTy(i32Ty, i8PtrTy,
                                                    /*isVarArg=*/true);

      // The guard restores the caller's insertion point once the
      // declaration has been placed in the module.
      PatternRewriter::InsertionGuard insertGuard(rewriter);
      rewriter.setInsertionPointToStart(module.getBody());
      rewriter.create<LLVM::LLVMFuncOp>(module.getLoc(), "printf", printfTy);
    }
    return SymbolRefAttr::get("printf", context);
  }

  /// Returns an `i8*` value addressing the first character of the global
  /// string `name`. If the module has no such global, an internal constant
  /// `i8` array initialised with `value` is created at the start of the
  /// module body first. The address computation is emitted at the builder's
  /// current insertion point.
  static Value getOrCreateGlobalString(Location loc, OpBuilder &builder,
                                       StringRef name, StringRef value,
                                       ModuleOp module,
                                       LLVM::LLVMDialect *llvmDialect) {
    LLVM::GlobalOp global = module.lookupSymbol<LLVM::GlobalOp>(name);
    if (!global) {
      // Define the global at the entry of the module, then return to the
      // previous insertion point when the guard goes out of scope.
      OpBuilder::InsertionGuard insertGuard(builder);
      builder.setInsertionPointToStart(module.getBody());
      auto arrayTy = LLVM::LLVMType::getArrayTy(
          LLVM::LLVMType::getInt8Ty(llvmDialect), value.size());
      global = builder.create<LLVM::GlobalOp>(loc, arrayTy, /*isConstant=*/true,
                                              LLVM::Linkage::Internal, name,
                                              builder.getStringAttr(value));
    }

    // Take the global's address and index it with [0, 0] to reach the first
    // character.
    Value globalPtr = builder.create<LLVM::AddressOfOp>(loc, global);
    auto i64Ty = LLVM::LLVMType::getInt64Ty(llvmDialect);
    Value cst0 = builder.create<LLVM::ConstantOp>(
        loc, i64Ty, builder.getIntegerAttr(builder.getIndexType(), 0));
    auto i8PtrTy = LLVM::LLVMType::getInt8PtrTy(llvmDialect);
    Value firstChar = builder.create<LLVM::GEPOp>(
        loc, i8PtrTy, globalPtr, ArrayRef<Value>({cst0, cst0}));
    return firstChar;
  }
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// ToyToLLVMLoweringPass
//===----------------------------------------------------------------------===//
namespace {
/// Module pass that converts the operations remaining in a module to the LLVM
/// dialect. The conversion itself is implemented in `runOnModule`, which is
/// defined out of line after this declaration.
struct ToyToLLVMLoweringPass : public ModulePass<ToyToLLVMLoweringPass> {
/// Runs a full conversion on the current module and signals pass failure if
/// that conversion does not succeed.
void runOnModule() final;
};
} // end anonymous namespace
void ToyToLLVMLoweringPass::runOnModule() {
// The first thing to define is the conversion target. This will define the
// final target for this lowering. For this lowering, we are only targeting
// the LLVM dialect.
ConversionTarget target(getContext());
target.addLegalDialect<LLVM::LLVMDialect>();
target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
// During this lowering, we will also be lowering the MemRef types, that are
// currently being operated on, to a representation in LLVM. Do perform this
// conversion we use a TypeConverter as part of the lowering. This converter
// details how one type maps to another. This is necessary now that we will be
// doing more complicated lowerings, involving loop region arguments.
LLVMTypeConverter typeConverter(&getContext());
// Now that the conversion target has been defined, we need to provide the
// patterns used for lowering. At this point of the compilation process, we
// have a combination of `toy`, `affine`, and `std` operations. Luckily, there
// are already exists a set of patterns to transform `affine` and `std`
// dialects. These patterns lowering in multiple stages, relying on transitive
// lowerings. Transitive lowering, or A->B->C lowering, is when multiple
// patterns must be applied to fully transform an illegal operation into a
// set of legal ones.
OwningRewritePatternList patterns;
populateAffineToStdConversionPatterns(patterns, &getContext());
populateLoopToStdConversionPatterns(patterns, &getContext());
populateStdToLLVMConversionPatterns(typeConverter, patterns);
// The only remaining operation to lower from the `toy` dialect, is the
// PrintOp.
patterns.insert<PrintOpLowering>(&getContext());
// We want to completely lower to LLVM, so we use a `FullConversion`. This
// ensures that only legal operations will remain after the conversion.
auto module = getModule();
if (failed(applyFullConversion(module, target, patterns, &typeConverter)))
signalPassFailure();
}
/// Create a pass that lowers the remaining `Toy` operations, as well as
/// `Affine` and `Std`, to the LLVM dialect for codegen.
std::unique_ptr<mlir::Pass> mlir::toy::createLowerToLLVMPass() {
  std::unique_ptr<mlir::Pass> pass = std::make_unique<ToyToLLVMLoweringPass>();
  return pass;
}