TestDialect.cpp
11.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
//===- TestDialect.cpp - MLIR Dialect for Testing -------------------------===//
//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "TestDialect.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/StringSwitch.h"
using namespace mlir;
//===----------------------------------------------------------------------===//
// TestDialect Interfaces
//===----------------------------------------------------------------------===//
namespace {
// Test support for interacting with the AsmPrinter.
struct TestOpAsmInterface : public OpAsmDialectInterface {
using OpAsmDialectInterface::OpAsmDialectInterface;
void getAsmResultNames(Operation *op,
OpAsmSetValueNameFn setNameFn) const final {
if (auto asmOp = dyn_cast<AsmDialectInterfaceOp>(op))
setNameFn(asmOp, "result");
}
void getAsmBlockArgumentNames(Block *block,
OpAsmSetValueNameFn setNameFn) const final {
auto op = block->getParentOp();
auto arrayAttr = op->getAttrOfType<ArrayAttr>("arg_names");
if (!arrayAttr)
return;
auto args = block->getArguments();
auto e = std::min(arrayAttr.size(), args.size());
for (unsigned i = 0; i < e; ++i) {
if (auto strAttr = arrayAttr.getValue()[i].dyn_cast<StringAttr>())
setNameFn(args[i], strAttr.getValue());
}
}
};
struct TestOpFolderDialectInterface : public OpFolderDialectInterface {
using OpFolderDialectInterface::OpFolderDialectInterface;
/// Registered hook to check if the given region, which is attached to an
/// operation that is *not* isolated from above, should be used when
/// materializing constants.
bool shouldMaterializeInto(Region *region) const final {
// If this is a one region operation, then insert into it.
return isa<OneRegionOp>(region->getParentOp());
}
};
/// This class defines the interface for handling inlining with standard
/// operations.
struct TestInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
//===--------------------------------------------------------------------===//
// Analysis Hooks
//===--------------------------------------------------------------------===//
bool isLegalToInline(Region *, Region *, BlockAndValueMapping &) const final {
// Inlining into test dialect regions is legal.
return true;
}
bool isLegalToInline(Operation *, Region *,
BlockAndValueMapping &) const final {
return true;
}
bool shouldAnalyzeRecursively(Operation *op) const final {
// Analyze recursively if this is not a functional region operation, it
// froms a separate functional scope.
return !isa<FunctionalRegionOp>(op);
}
//===--------------------------------------------------------------------===//
// Transformation Hooks
//===--------------------------------------------------------------------===//
/// Handle the given inlined terminator by replacing it with a new operation
/// as necessary.
void handleTerminator(Operation *op,
ArrayRef<Value> valuesToRepl) const final {
// Only handle "test.return" here.
auto returnOp = dyn_cast<TestReturnOp>(op);
if (!returnOp)
return;
// Replace the values directly with the return operands.
assert(returnOp.getNumOperands() == valuesToRepl.size());
for (const auto &it : llvm::enumerate(returnOp.getOperands()))
valuesToRepl[it.index()].replaceAllUsesWith(it.value());
}
/// Attempt to materialize a conversion for a type mismatch between a call
/// from this dialect, and a callable region. This method should generate an
/// operation that takes 'input' as the only operand, and produces a single
/// result of 'resultType'. If a conversion can not be generated, nullptr
/// should be returned.
Operation *materializeCallConversion(OpBuilder &builder, Value input,
Type resultType,
Location conversionLoc) const final {
// Only allow conversion for i16/i32 types.
if (!(resultType.isInteger(16) || resultType.isInteger(32)) ||
!(input.getType().isInteger(16) || input.getType().isInteger(32)))
return nullptr;
return builder.create<TestCastOp>(conversionLoc, resultType, input);
}
};
} // end anonymous namespace
//===----------------------------------------------------------------------===//
// TestDialect
//===----------------------------------------------------------------------===//
// Construct the test dialect in `context`: register its operations and the
// dialect interfaces defined earlier in this file.
TestDialect::TestDialect(MLIRContext *context)
: Dialect(getDialectName(), context) {
// The operation list is the content of TestOps.cpp.inc, included here with
// GET_OP_LIST defined so that it expands inside the template argument list.
addOperations<
#define GET_OP_LIST
#include "TestOps.cpp.inc"
>();
addInterfaces<TestOpAsmInterface, TestOpFolderDialectInterface,
TestInlinerInterface>();
// NOTE(review): presumably lets unregistered operations be used with this
// dialect - confirm against the Dialect API.
allowUnknownOperations();
}
/// Verify a dialect attribute attached to `op`. The attribute named
/// "test.invalid_attr" is rejected with an error emitted on `op`; every other
/// attribute is accepted.
LogicalResult TestDialect::verifyOperationAttribute(Operation *op,
                                                    NamedAttribute namedAttr) {
  bool isInvalidAttr = namedAttr.first == "test.invalid_attr";
  if (!isInvalidAttr)
    return success();
  return op->emitError() << "invalid to use 'test.invalid_attr'";
}
/// Verify a dialect attribute attached to a region argument of `op`. The
/// attribute named "test.invalid_attr" is rejected with an error emitted on
/// `op`; every other attribute is accepted. The region and argument indices
/// are not consulted.
LogicalResult TestDialect::verifyRegionArgAttribute(Operation *op,
                                                    unsigned regionIndex,
                                                    unsigned argIndex,
                                                    NamedAttribute namedAttr) {
  bool isInvalidAttr = namedAttr.first == "test.invalid_attr";
  if (!isInvalidAttr)
    return success();
  return op->emitError() << "invalid to use 'test.invalid_attr'";
}
/// Verify a dialect attribute attached to a region result of `op`. The
/// attribute named "test.invalid_attr" is rejected with an error emitted on
/// `op`; every other attribute is accepted. The region and result indices are
/// not consulted.
LogicalResult
TestDialect::verifyRegionResultAttribute(Operation *op, unsigned regionIndex,
                                         unsigned resultIndex,
                                         NamedAttribute namedAttr) {
  bool isInvalidAttr = namedAttr.first == "test.invalid_attr";
  if (!isInvalidAttr)
    return success();
  return op->emitError() << "invalid to use 'test.invalid_attr'";
}
//===----------------------------------------------------------------------===//
// Test IsolatedRegionOp - parse passthrough region arguments.
//===----------------------------------------------------------------------===//
/// Parse the custom form of IsolatedRegionOp: a single operand, resolved as
/// an index-typed value, followed by a body region. The operand's parsed
/// info doubles as the region's only argument, with name shadowing enabled.
static ParseResult parseIsolatedRegionOp(OpAsmParser &parser,
                                         OperationState &result) {
  OpAsmParser::OperandType operandInfo;
  Type indexType = parser.getBuilder().getIndexType();

  // Parse the input operand, then resolve it against the index type.
  if (parser.parseOperand(operandInfo))
    return failure();
  if (parser.resolveOperand(operandInfo, indexType, result.operands))
    return failure();

  // Parse the body region, reusing the operand info as the argument info.
  Region *region = result.addRegion();
  return parser.parseRegion(*region, operandInfo, indexType,
                            /*enableNameShadowing=*/true);
}
/// Print IsolatedRegionOp as its mnemonic, its operand, and its region. The
/// region's arguments are shadowed by the operand and the entry block
/// arguments are not printed.
static void print(OpAsmPrinter &p, IsolatedRegionOp op) {
  Value input = op.getOperand();
  Region &body = op.region();

  p << "test.isolated_region ";
  p.printOperand(input);
  p.shadowRegionArgs(body, input);
  p.printRegion(body, /*printEntryBlockArgs=*/false);
}
//===----------------------------------------------------------------------===//
// Test parser.
//===----------------------------------------------------------------------===//
/// Parse the custom form of WrappedKeywordOp: a single keyword, which is
/// stored on the operation as the string attribute "keyword".
static ParseResult parseWrappedKeywordOp(OpAsmParser &parser,
                                         OperationState &result) {
  StringRef parsedKeyword;
  if (parser.parseKeyword(&parsedKeyword))
    return failure();

  auto keywordAttr = parser.getBuilder().getStringAttr(parsedKeyword);
  result.addAttribute("keyword", keywordAttr);
  return success();
}
/// Print WrappedKeywordOp as its operation name, a space, and its keyword.
static void print(OpAsmPrinter &p, WrappedKeywordOp op) {
  p << WrappedKeywordOp::getOperationName();
  p << " ";
  p << op.keyword();
}
//===----------------------------------------------------------------------===//
// Test WrappingRegionOp - wrapping op exercising `parseGenericOperation()`.
//===----------------------------------------------------------------------===//
/// Parse the custom form of the wrapping region op: the keyword `wraps`
/// followed by one operation in generic form. The wrapped operation is placed
/// in a fresh block of the op's single region and is followed by a
/// TestReturnOp that forwards all of its results. The wrapping op takes its
/// result types from that terminator's operands and its location from the
/// wrapped operation.
static ParseResult parseWrappingRegionOp(OpAsmParser &parser,
                                         OperationState &result) {
  if (parser.parseKeyword("wraps"))
    return failure();

  // Parse the wrapped op into a new block of the body region.
  Region &body = *result.addRegion();
  body.push_back(new Block);
  Block &entryBlock = body.back();
  Operation *wrappedOp =
      parser.parseGenericOperation(&entryBlock, entryBlock.begin());
  if (!wrappedOp)
    return failure();

  // Terminate the inner region with a return whose operands are the values
  // returned by the wrapped operation.
  SmallVector<Value, 8> returnOperands(wrappedOp->getResults());
  OpBuilder builder(parser.getBuilder().getContext());
  builder.setInsertionPointToEnd(&entryBlock);
  Operation *terminator =
      builder.create<TestReturnOp>(wrappedOp->getLoc(), returnOperands);

  // The wrapping op's result types come from the terminator operands.
  result.types.append(terminator->operand_type_begin(),
                      terminator->operand_type_end());

  // Use the location of the wrapped op for the "test.wrapping_region" op.
  result.location = wrappedOp->getLoc();
  return success();
}
/// Print WrappingRegionOp as its operation name, the keyword `wraps`, and the
/// generic form of the first operation in its region's first block.
static void print(OpAsmPrinter &p, WrappingRegionOp op) {
  Operation &wrappedOp = op.region().front().front();
  p << op.getOperationName() << " wraps ";
  p.printGenericOp(&wrappedOp);
}
//===----------------------------------------------------------------------===//
// Test PolyForOp - parse list of region arguments.
//===----------------------------------------------------------------------===//
/// Parse the custom form of PolyForOp: an undelimited list of region
/// arguments followed by the body region. Every argument is given the index
/// type.
static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) {
  // Parse the list of region arguments without a delimiter.
  SmallVector<OpAsmParser::OperandType, 4> inductionVars;
  if (parser.parseRegionArgumentList(inductionVars))
    return failure();

  // One index type per parsed argument.
  Type indexType = parser.getBuilder().getIndexType();
  SmallVector<Type, 4> inductionVarTypes(inductionVars.size(), indexType);

  // Parse the body region with those arguments.
  Region *region = result.addRegion();
  return parser.parseRegion(*region, inductionVars, inductionVarTypes);
}
//===----------------------------------------------------------------------===//
// Test removing op with inner ops.
//===----------------------------------------------------------------------===//
namespace {
/// Rewrite pattern that matches every TestOpWithRegionPattern and erases it
/// through the rewriter.
struct TestRemoveOpWithInnerOps
    : public OpRewritePattern<TestOpWithRegionPattern> {
  using OpRewritePattern::OpRewritePattern;

  /// Always matches: erase `op` and report success.
  PatternMatchResult matchAndRewrite(TestOpWithRegionPattern op,
                                     PatternRewriter &rewriter) const override {
    rewriter.eraseOp(op);
    return matchSuccess();
  }
};
} // end anonymous namespace
/// Add the canonicalization patterns of TestOpWithRegionPattern to `results`:
/// a single TestRemoveOpWithInnerOps pattern created for `context`.
void TestOpWithRegionPattern::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<TestRemoveOpWithInnerOps>(context);
}
/// Fold this operation to its own operand. The constant `operands` are not
/// consulted.
OpFoldResult TestOpWithRegionFold::fold(ArrayRef<Attribute> operands) {
  Value folded = operand();
  return folded;
}
/// Fold this operation by appending each of its operands, in order, to
/// `results`. The constant `operands` are not consulted and the fold always
/// succeeds.
LogicalResult TestOpWithVariadicResultsAndFolder::fold(
    ArrayRef<Attribute> operands, SmallVectorImpl<OpFoldResult> &results) {
  for (Value forwarded : this->operands())
    results.push_back(forwarded);
  return success();
}
/// Infer the return types from the first two operands. When their types
/// differ, an optional error is emitted at `location` and its result is
/// returned. Otherwise the single inferred return type is the type of the
/// first operand. `attributes` and `regions` are not consulted.
LogicalResult mlir::OpWithInferTypeInterfaceOp::inferReturnTypes(
    llvm::Optional<Location> location, ValueRange operands,
    ArrayRef<NamedAttribute> attributes, RegionRange regions,
    SmallVectorImpl<Type> &inferedReturnTypes) {
  Type lhsType = operands[0].getType();
  Type rhsType = operands[1].getType();
  if (lhsType != rhsType)
    return emitOptionalError(location, "operand type mismatch ", lhsType,
                             " vs ", rhsType);

  inferedReturnTypes.assign({lhsType});
  return success();
}
// Static initialization for Test dialect registration: this file-local
// DialectRegistration object is constructed during static initialization of
// this translation unit.
static mlir::DialectRegistration<mlir::TestDialect> testDialect;
#include "TestOpEnums.cpp.inc"
#define GET_OP_CLASSES
#include "TestOps.cpp.inc"