SCFToSPIRV.cpp
12.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
//===- SCFToSPIRV.cpp - Convert SCF ops to SPIR-V dialect -----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the conversion patterns from SCF ops to SPIR-V dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h"
#include "mlir/Dialect/SCF/SCF.h"
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/IR/Module.h"
using namespace mlir;
namespace mlir {
/// State shared between the SCF to SPIR-V conversion patterns of this file.
struct ScfToSPIRVContextImpl {
// Maps a SPIR-V region control flow operation (spv.loop or spv.selection) to
// the VariableOps created to hold the results of the SCF op it was lowered
// from. The order of the VariableOps matches the order of the results.
// Entries are appended by replaceSCFOutputValue and read when the scf.yield
// of the region is lowered.
DenseMap<Operation *, SmallVector<spirv::VariableOp, 8>> outputVars;
};
} // namespace mlir
/// ScfToSPIRVContext stores information about the lowering of the scf regions
/// that needs to be used later on. When scf.for/scf.if are lowered, VariableOps
/// are created to hold their results. These VariableOps have to be tracked
/// because the stores into them are only inserted when the matching scf.yield
/// is lowered: the StoreOps cannot be created earlier as they may use a
/// different type than the yield operands.
ScfToSPIRVContext::ScfToSPIRVContext()
    : impl(std::make_unique<ScfToSPIRVContextImpl>()) {}
// Defaulted in this file, where ScfToSPIRVContextImpl (defined above) is a
// complete type.
// NOTE(review): presumably the header only forward-declares the Impl type, so
// the destructor cannot be defaulted there -- confirm against SCFToSPIRV.h.
ScfToSPIRVContext::~ScfToSPIRVContext() = default;
namespace {
/// Common base class for all SCF to SPIR-V conversion patterns.
///
/// On top of the state held by SPIRVOpLowering, every pattern keeps a pointer
/// to the ScfToSPIRVContextImpl shared by all patterns, which records the
/// VariableOps created to hold the results of lowered SCF regions.
template <typename OpTy>
class SCFToSPIRVPattern : public SPIRVOpLowering<OpTy> {
public:
  /// \param context            MLIR context forwarded to SPIRVOpLowering.
  /// \param converter          type converter forwarded to SPIRVOpLowering.
  /// \param scfToSPIRVContext  state shared between the patterns; the pattern
  ///                           stores the raw pointer and never deletes it.
  // The constructor is named with the injected-class-name rather than a
  // template-id: `SCFToSPIRVPattern<OpTy>(...)` is rejected as a constructor
  // declarator starting with C++20.
  SCFToSPIRVPattern(MLIRContext *context, SPIRVTypeConverter &converter,
                    ScfToSPIRVContextImpl *scfToSPIRVContext)
      : SPIRVOpLowering<OpTy>(context, converter),
        scfToSPIRVContext(scfToSPIRVContext) {}

protected:
  /// Shared lowering state, see ScfToSPIRVContextImpl.
  ScfToSPIRVContextImpl *scfToSPIRVContext;
};
/// Pattern to convert a scf::ForOp within kernel functions into spirv::LoopOp.
/// The results of the loop, if any, are replaced by loads from VariableOps
/// recorded in the shared ScfToSPIRVContextImpl.
class ForOpConversion final : public SCFToSPIRVPattern<scf::ForOp> {
public:
// Inherit the (context, converter, scfToSPIRVContext) constructor.
using SCFToSPIRVPattern<scf::ForOp>::SCFToSPIRVPattern;
/// `operands` are the already converted operands of `forOp`.
LogicalResult
matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
/// Pattern to convert a scf::IfOp within kernel functions into
/// spirv::SelectionOp. The results of the scf.if, if any, are replaced by
/// loads from VariableOps recorded in the shared ScfToSPIRVContextImpl.
class IfOpConversion final : public SCFToSPIRVPattern<scf::IfOp> {
public:
// Inherit the (context, converter, scfToSPIRVContext) constructor.
using SCFToSPIRVPattern<scf::IfOp>::SCFToSPIRVPattern;
/// `operands` are the already converted operands of `ifOp`.
LogicalResult
matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
/// Pattern to lower scf::YieldOp. The yielded values are stored into the
/// VariableOps recorded for the parent operation in the shared
/// ScfToSPIRVContextImpl, and the scf.yield itself is erased.
class TerminatorOpConversion final : public SCFToSPIRVPattern<scf::YieldOp> {
public:
// Inherit the (context, converter, scfToSPIRVContext) constructor.
using SCFToSPIRVPattern<scf::YieldOp>::SCFToSPIRVPattern;
/// `operands` are the already converted operands of `terminatorOp`.
LogicalResult
matchAndRewrite(scf::YieldOp terminatorOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
/// Helper function to replaces SCF op outputs with SPIR-V variable loads.
/// We create VariableOp to handle the results value of the control flow region.
/// spv.loop/spv.selection currently don't yield value. Right after the loop
/// we load the value from the allocation and use it as the SCF op result.
template <typename ScfOp, typename OpTy>
static void replaceSCFOutputValue(ScfOp scfOp, OpTy newOp,
SPIRVTypeConverter &typeConverter,
ConversionPatternRewriter &rewriter,
ScfToSPIRVContextImpl *scfToSPIRVContext,
ArrayRef<Type> returnTypes) {
Location loc = scfOp.getLoc();
auto &allocas = scfToSPIRVContext->outputVars[newOp];
SmallVector<Value, 8> resultValue;
for (Type convertedType : returnTypes) {
auto pointerType =
spirv::PointerType::get(convertedType, spirv::StorageClass::Function);
rewriter.setInsertionPoint(newOp);
auto alloc = rewriter.create<spirv::VariableOp>(
loc, pointerType, spirv::StorageClass::Function,
/*initializer=*/nullptr);
allocas.push_back(alloc);
rewriter.setInsertionPointAfter(newOp);
Value loadResult = rewriter.create<spirv::LoadOp>(loc, alloc);
resultValue.push_back(loadResult);
}
rewriter.replaceOp(scfOp, resultValue);
}
//===----------------------------------------------------------------------===//
// scf::ForOp.
//===----------------------------------------------------------------------===//
/// Lowers `forOp` to a spirv::LoopOp.
///
/// A new header block is inserted into the spirv::LoopOp; it takes the
/// induction variable followed by the loop carried values as block arguments
/// and conditionally branches to either the loop body or the merge block. The
/// blocks of `forOp` are moved into the spirv::LoopOp right after that header.
/// The results of `forOp` are replaced through replaceSCFOutputValue, using
/// the types of the converted init operands. `operands` are the converted
/// operands of `forOp`. Always returns success().
LogicalResult
ForOpConversion::matchAndRewrite(scf::ForOp forOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const {
// scf::ForOp can be lowered to the structured control flow represented by
// spirv::LoopOp by making the continue block of the spirv::LoopOp the loop
// latch and the merge block the exit block. The resulting spirv::LoopOp has a
// single back edge from the continue to header block, and a single exit from
// header to merge.
scf::ForOpAdaptor forOperands(operands);
auto loc = forOp.getLoc();
auto loopControl = rewriter.getI32IntegerAttr(
static_cast<uint32_t>(spirv::LoopControl::None));
auto loopOp = rewriter.create<spirv::LoopOp>(loc, loopControl);
loopOp.addEntryAndMergeBlock();
// The insertion point is moved around a lot below; the guard takes care of
// restoring it when this function returns.
OpBuilder::InsertionGuard guard(rewriter);
// Create the block for the header.
auto *header = new Block();
// Insert the header as the second block of the loop region.
loopOp.body().getBlocks().insert(std::next(loopOp.body().begin(), 1), header);
// Create the new induction variable to use. The header arguments are the
// induction variable followed by one argument per loop carried value, typed
// like the corresponding converted init operand.
BlockArgument newIndVar =
header->addArgument(forOperands.lowerBound().getType());
for (Value arg : forOperands.initArgs())
header->addArgument(arg.getType());
Block *body = forOp.getBody();
// Apply signature conversion to the body of the forOp. It has a single block,
// whose first argument is the induction variable, followed by the loop
// carried values. Each of them is replaced with the header argument at the
// same position.
TypeConverter::SignatureConversion signatureConverter(
body->getNumArguments());
signatureConverter.remapInput(0, newIndVar);
for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
signatureConverter.remapInput(i, header->getArgument(i));
body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
signatureConverter);
// Move the blocks from the forOp into the loopOp, at position 2 of the loop
// region, i.e. right after the header. This is the body of the loopOp.
rewriter.inlineRegionBefore(forOp.getOperation()->getRegion(0), loopOp.body(),
std::next(loopOp.body().begin(), 2));
// Initial values of the header arguments: the lower bound followed by the
// init operands.
SmallVector<Value, 8> args(1, forOperands.lowerBound());
args.append(forOperands.initArgs().begin(), forOperands.initArgs().end());
// Branch into it from the entry.
rewriter.setInsertionPointToEnd(&(loopOp.body().front()));
rewriter.create<spirv::BranchOp>(loc, header, args);
// Generate the rest of the loop header: keep iterating while the induction
// variable is (signed) less than the upper bound, otherwise exit to the merge
// block.
rewriter.setInsertionPointToEnd(header);
auto *mergeBlock = loopOp.getMergeBlock();
auto cmpOp = rewriter.create<spirv::SLessThanOp>(
loc, rewriter.getI1Type(), newIndVar, forOperands.upperBound());
rewriter.create<spirv::BranchConditionalOp>(
loc, cmpOp, body, ArrayRef<Value>(), mergeBlock, ArrayRef<Value>());
// Generate instructions to increment the step of the induction variable and
// branch to the header.
// NOTE(review): the scf.yield lowering expects this branch to be the
// terminator of the block containing the scf.yield, so getContinueBlock()
// presumably returns the inlined body block -- confirm against spirv::LoopOp.
Block *continueBlock = loopOp.getContinueBlock();
rewriter.setInsertionPointToEnd(continueBlock);
// Add the step to the induction variable and branch to the header. Only the
// induction variable is forwarded here; when the loop carries values, the
// scf.yield lowering replaces this branch with one that also forwards the
// yielded values.
Value updatedIndVar = rewriter.create<spirv::IAddOp>(
loc, newIndVar.getType(), newIndVar, forOperands.step());
rewriter.create<spirv::BranchOp>(loc, header, updatedIndVar);
// Infer the return types from the init operands. Vector type may get
// converted to CooperativeMatrix or to Vector type, to avoid having complex
// extra logic to figure out the right type we just infer it from the Init
// operands.
SmallVector<Type, 8> initTypes;
for (auto arg : forOperands.initArgs())
initTypes.push_back(arg.getType());
replaceSCFOutputValue(forOp, loopOp, typeConverter, rewriter,
scfToSPIRVContext, initTypes);
return success();
}
//===----------------------------------------------------------------------===//
// scf::IfOp.
//===----------------------------------------------------------------------===//
/// Lowers `ifOp` to a spirv::SelectionOp.
///
/// A selection header block conditionally branches to the inlined `then`
/// region or, when present, to the inlined `else` region (to the merge block
/// otherwise). Both regions branch to the merge block. The results of `ifOp`
/// are replaced through replaceSCFOutputValue. `operands` are the converted
/// operands of `ifOp`.
///
/// Returns failure(), without modifying the IR, when a result type of `ifOp`
/// cannot be converted; returns success() otherwise.
LogicalResult
IfOpConversion::matchAndRewrite(scf::IfOp ifOp, ArrayRef<Value> operands,
                                ConversionPatternRewriter &rewriter) const {
  // When lowering `scf::IfOp` we explicitly create a selection header block
  // before the control flow diverges and a merge block where control flow
  // subsequently converges.
  scf::IfOpAdaptor ifOperands(operands);
  auto loc = ifOp.getLoc();

  // Convert the result types first, before anything is created or moved. A
  // failed conversion yields a null type, which must not reach the creation
  // of the result variables; bailing out here leaves the IR untouched.
  SmallVector<Type, 8> returnTypes;
  for (auto result : ifOp.results()) {
    auto convertedType = typeConverter.convertType(result.getType());
    if (!convertedType)
      return failure();
    returnTypes.push_back(convertedType);
  }

  // Create `spv.selection` operation, selection header block and merge block.
  auto selectionControl = rewriter.getI32IntegerAttr(
      static_cast<uint32_t>(spirv::SelectionControl::None));
  auto selectionOp = rewriter.create<spirv::SelectionOp>(loc, selectionControl);
  selectionOp.addMergeBlock();
  auto *mergeBlock = selectionOp.getMergeBlock();

  // The insertion point is moved around below; the guard takes care of
  // restoring it when this function returns.
  OpBuilder::InsertionGuard guard(rewriter);
  auto *selectionHeaderBlock = new Block();
  selectionOp.body().getBlocks().push_front(selectionHeaderBlock);

  // Inline `then` region before the merge block and branch to it.
  auto &thenRegion = ifOp.thenRegion();
  auto *thenBlock = &thenRegion.front();
  rewriter.setInsertionPointToEnd(&thenRegion.back());
  rewriter.create<spirv::BranchOp>(loc, mergeBlock);
  rewriter.inlineRegionBefore(thenRegion, mergeBlock);

  // Without an `else` region the false edge goes straight to the merge block.
  auto *elseBlock = mergeBlock;
  // If `else` region is not empty, inline that region before the merge block
  // and branch to it.
  if (!ifOp.elseRegion().empty()) {
    auto &elseRegion = ifOp.elseRegion();
    elseBlock = &elseRegion.front();
    rewriter.setInsertionPointToEnd(&elseRegion.back());
    rewriter.create<spirv::BranchOp>(loc, mergeBlock);
    rewriter.inlineRegionBefore(elseRegion, mergeBlock);
  }

  // Create a `spv.BranchConditional` operation for selection header block.
  rewriter.setInsertionPointToEnd(selectionHeaderBlock);
  rewriter.create<spirv::BranchConditionalOp>(loc, ifOperands.condition(),
                                              thenBlock, ArrayRef<Value>(),
                                              elseBlock, ArrayRef<Value>());

  replaceSCFOutputValue(ifOp, selectionOp, typeConverter, rewriter,
                        scfToSPIRVContext, returnTypes);
  return success();
}
/// Yield is lowered to stores to the VariableOp created during lowering of the
/// parent region. For loops we also need to update the branch looping back to
/// the header with the loop carried values.
///
/// `operands` are the converted operands of `terminatorOp`. Returns failure(),
/// without modifying the IR, when values are yielded but no matching set of
/// VariableOps has been recorded for the parent operation; returns success()
/// otherwise.
LogicalResult TerminatorOpConversion::matchAndRewrite(
    scf::YieldOp terminatorOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  // If the region returns values, store each value into the associated
  // VariableOp created during lowering of the parent region.
  if (!operands.empty()) {
    Operation *parentOp = terminatorOp.getParentOp();
    // Look the parent up without inserting a default entry into the map, and
    // reject the yield instead of asserting (or indexing out of bounds in
    // release builds) when the recorded variables don't match the operands,
    // e.g. because the parent has not been lowered by the patterns of this
    // file.
    auto varsIt = scfToSPIRVContext->outputVars.find(parentOp);
    if (varsIt == scfToSPIRVContext->outputVars.end() ||
        varsIt->second.size() != operands.size())
      return failure();
    auto &allocas = varsIt->second;

    auto loc = terminatorOp.getLoc();
    for (unsigned i = 0, e = operands.size(); i < e; i++)
      rewriter.create<spirv::StoreOp>(loc, allocas[i], operands[i]);
    if (isa<spirv::LoopOp>(parentOp)) {
      // For loops we also need to update the branch jumping back to the header:
      // the existing branch only forwards the updated induction variable, so
      // it is replaced by one that forwards the yielded values as well.
      auto br =
          cast<spirv::BranchOp>(rewriter.getInsertionBlock()->getTerminator());
      SmallVector<Value, 8> args(br.getBlockArguments());
      args.append(operands.begin(), operands.end());
      rewriter.setInsertionPoint(br);
      rewriter.create<spirv::BranchOp>(terminatorOp.getLoc(), br.getTarget(),
                                       args);
      rewriter.eraseOp(br);
    }
  }
  rewriter.eraseOp(terminatorOp);
  return success();
}
/// Appends the patterns lowering scf.for, scf.if and scf.yield to the SPIR-V
/// dialect to `patterns`. All three patterns are constructed with `context`,
/// `typeConverter` and the implementation object of `scfToSPIRVContext`,
/// through which they share the VariableOps created for region results.
void mlir::populateSCFToSPIRVPatterns(MLIRContext *context,
                                      SPIRVTypeConverter &typeConverter,
                                      ScfToSPIRVContext &scfToSPIRVContext,
                                      OwningRewritePatternList &patterns) {
  auto *sharedState = scfToSPIRVContext.getImpl();
  patterns.insert<ForOpConversion, IfOpConversion, TerminatorOpConversion>(
      context, typeConverter, sharedState);
}