TestSlicing.cpp
2.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
//===- TestSlicing.cpp - Testing slice functionality ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements a simple testing pass for slicing.
//
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
using namespace mlir;
/// Create a function next to the parent function of `op` whose body is a
/// clone of the backward slice of `op`.
///
/// Naming and signature:
///  - The new function's name is the parent function's name with `suffix`
///    appended.
///  - It takes the same argument types as the parent and returns nothing.
///  - It is inserted immediately before the parent function.
///
/// Body:
///  - Uses of the parent's arguments inside the slice are remapped to the new
///    function's arguments.
///  - `op` itself is not cloned: `getBackwardSlice` collects only the ops that
///    `op` transitively depends on.
///
/// Returns failure (after emitting an error on `op`) if `op` is not nested
/// inside a FuncOp.
static LogicalResult createBackwardSliceFunction(Operation *op,
                                                 StringRef suffix) {
  FuncOp parentFuncOp = op->getParentOfType<FuncOp>();
  if (!parentFuncOp)
    return op->emitError("expected op to be nested in a function");
  OpBuilder builder(parentFuncOp);
  Location loc = op->getLoc();
  std::string clonedFuncOpName = parentFuncOp.getName().str() + suffix.str();
  // The body is terminated by a `return` without operands, so the clone must
  // not declare results. Otherwise it would fail verification whenever the
  // parent function returns values. Keep only the parent's input types.
  auto clonedFuncType =
      builder.getFunctionType(parentFuncOp.getType().getInputs(), {});
  FuncOp clonedFuncOp =
      builder.create<FuncOp>(loc, clonedFuncOpName, clonedFuncType);
  BlockAndValueMapping mapper;
  builder.setInsertionPointToEnd(clonedFuncOp.addEntryBlock());
  // Redirect uses of the parent's arguments to the clone's arguments.
  for (auto arg : enumerate(parentFuncOp.getArguments()))
    mapper.map(arg.value(), clonedFuncOp.getArgument(arg.index()));
  llvm::SetVector<Operation *> slice;
  getBackwardSlice(op, &slice);
  for (Operation *slicedOp : slice)
    builder.clone(*slicedOp, mapper);
  builder.create<ReturnOp>(loc);
  return success();
}
namespace {
/// Module pass that, for every Linalg op found in the module's functions,
/// materializes that op's backward slice as a separate function so the
/// result of slice analysis can be inspected in the output IR.
class SliceAnalysisTestPass
    : public PassWrapper<SliceAnalysisTestPass, OperationPass<ModuleOp>> {
public:
  SliceAnalysisTestPass() = default;

  /// User-provided copy constructor: the base is default-constructed rather
  /// than copied. The pass holds no members of its own to copy.
  SliceAnalysisTestPass(const SliceAnalysisTestPass &) {}

  /// Entry point invoked by the pass manager on the module.
  void runOnOperation() override;
};
} // namespace
void SliceAnalysisTestPass::runOnOperation() {
ModuleOp module = getOperation();
auto funcOps = module.getOps<FuncOp>();
unsigned opNum = 0;
for (auto funcOp : funcOps) {
// TODO: For now this is just looking for Linalg ops. It can be generalized
// to look for other ops using flags.
funcOp.walk([&](Operation *op) {
if (!isa<linalg::LinalgOp>(op))
return WalkResult::advance();
std::string append =
std::string("__backward_slice__") + std::to_string(opNum);
createBackwardSliceFunction(op, append);
opNum++;
return WalkResult::advance();
});
}
}
namespace mlir {
void registerSliceAnalysisTestPass() {
PassRegistration<SliceAnalysisTestPass> pass(
"slice-analysis-test", "Test Slice analysis functionality.");
}
} // namespace mlir