QuantOps.cpp
3.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
//===- QuantOps.cpp - Quantization Type and Ops Implementation --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/QuantOps.h"
#include "TypeDetail.h"
#include "mlir/Dialect/Quant/QuantTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/MathExtras.h"
#include <numeric>
using namespace mlir;
using namespace mlir::quant;
using namespace mlir::quant::detail;
/// Dialect set-up hook: adds the quantized type classes and the dialect's
/// operations through the `addTypes` / `addOperations` helpers.
void QuantizationDialect::initialize() {
// The three quantized type classes that belong to this dialect.
addTypes<AnyQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType>();
// The template argument list below is not spelled out by hand: it is pasted
// in by the preprocessor from the included .inc file. GET_OP_LIST must be
// defined immediately before the include so that the op-list portion of that
// file is emitted here (NOTE(review): the .inc file is not visible from this
// file; presumably it expands to a comma-separated list of op classes).
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"
>();
}
/// Folds a storage cast that undoes an immediately preceding storage cast.
///
/// For a chain `x -> scast -> scast -> y`, when the type of `x` equals the
/// result type of this (second) cast, the two casts cancel out and `x` is
/// returned as the folded value. In every other case an empty OpFoldResult is
/// returned, signalling that nothing was folded. `operands` is not consulted.
OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
  // Only a cast whose input is itself produced by a storage cast qualifies.
  if (auto producerCast = arg().getDefiningOp<StorageCastOp>()) {
    auto originalValue = producerCast.arg();
    // The pair is a round trip exactly when we end up at the starting type.
    if (originalValue.getType() == getType())
      return originalValue;
  }
  return OpFoldResult();
}
/// Returns true when the quantization specification `quantSpec` is acceptable
/// for a value whose expressed type is `expressed`.
///
/// A specification is accepted only if it is a TypeAttr whose wrapped type is
/// not itself a tensor or vector type, and that wrapped type is either:
///   * a QuantizedType that reports `expressed` as a compatible expressed
///     type, or
///   * equal to the element type of `expressed`, where `expressed` is a
///     tensor or vector type.
/// Everything else is rejected.
static bool isValidQuantizationSpec(Attribute quantSpec, Type expressed) {
  auto specAttr = quantSpec.dyn_cast<TypeAttr>();
  if (!specAttr)
    return false;

  // The spec describes a single element, so container types are not allowed.
  Type specType = specAttr.getValue();
  if (specType.isa<TensorType, VectorType>())
    return false;

  // A quantized spec decides compatibility with the expressed type itself.
  if (auto quantizedSpec = specType.dyn_cast<QuantizedType>())
    return quantizedSpec.isCompatibleExpressedType(expressed);

  // A non-quantized spec has to be identical to the container's element type.
  if (auto asTensor = expressed.dyn_cast<TensorType>())
    return specType == asTensor.getElementType();
  if (auto asVector = expressed.dyn_cast<VectorType>())
    return specType == asVector.getElementType();

  return false;
}
/// Verifies a QuantizeRegionOp.
///
/// The op must carry exactly one input spec per operand and one output spec
/// per result, and every spec must be valid (per isValidQuantizationSpec) for
/// the type of the operand/result it is paired with. The first violation found
/// is reported as an op error; otherwise success is returned.
static LogicalResult verifyRegionOp(QuantizeRegionOp op) {
  auto inputSpecs = op.input_specs();
  auto outputSpecs = op.output_specs();

  // Each operand and each result needs its own specification attribute.
  const bool inputsCovered = op.getNumOperands() == inputSpecs.size();
  const bool outputsCovered = op.getNumResults() == outputSpecs.size();
  if (!(inputsCovered && outputsCovered))
    return op.emitOpError(
        "has unmatched operands/results number and spec attributes number");

  // Check every operand type against its positional input spec.
  for (auto typeAndSpec : llvm::zip(op.getOperandTypes(), inputSpecs)) {
    Type operandType;
    Attribute spec;
    std::tie(operandType, spec) = typeAndSpec;
    if (isValidQuantizationSpec(spec, operandType))
      continue;
    return op.emitOpError() << "has incompatible specification " << spec
                            << " and input type " << operandType;
  }

  // Check every result type against its positional output spec.
  for (auto typeAndSpec : llvm::zip(op.getResultTypes(), outputSpecs)) {
    Type resultType;
    Attribute spec;
    std::tie(resultType, spec) = typeAndSpec;
    if (isValidQuantizationSpec(spec, resultType))
      continue;
    return op.emitOpError() << "has incompatible specification " << spec
                            << " and output type " << resultType;
  }

  return success();
}
#define GET_OP_CLASSES
#include "mlir/Dialect/Quant/QuantOps.cpp.inc"