QuantizeUtils.cpp
6.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
//===- QuantizeUtils.cpp - Support utilities for quantization -------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Quant/QuantizeUtils.h"
#include "mlir/Dialect/Quant/UniformSupport.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/StandardTypes.h"
using namespace mlir;
using namespace mlir::quant;
/// Converts a possibly primitive, real expressed value attribute to a
/// corresponding storage attribute (typically FloatAttr -> IntegerAttr).
/// quantizedElementType is the QuantizedType that describes the expressed
/// origRealValue.
/// Returns the converted Attribute, or nullptr if conversion is not possible
/// (currently only FloatAttr inputs are handled).
/// On success, outConvertedType is set to quantizedElementType's storage type;
/// on failure it is left untouched.
static Attribute convertPrimitiveValueAttr(
    Attribute origRealValue, QuantizedType quantizedElementType,
    const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
  auto floatAttr = origRealValue.dyn_cast<FloatAttr>();
  if (!floatAttr)
    return nullptr;

  Type storageType = quantizedElementType.getStorageType();
  outConvertedType = storageType;
  return IntegerAttr::get(storageType,
                          converter.quantizeFloatToInt(floatAttr.getValue()));
}
/// Quantizes the real values held by a DenseFPElementsAttr with `converter`
/// and packs the resulting storage integers into a DenseElementsAttr
/// (typically a DenseIntElementsAttr). The result's shaped type is the
/// storage-type counterpart of the input's type produced by
/// quantizedElementType, e.g. tensor<4xf32> -> tensor<4xi8>.
/// For a splat input only the single splat value is quantized.
/// Returns nullptr if the input's type cannot be cast to a storage-based
/// ShapedType.
static DenseElementsAttr
convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
                           QuantizedType quantizedElementType,
                           const UniformQuantizedValueConverter &converter) {
  // Resolve the storage-based shaped type up front; without it there is
  // nothing to build.
  Type storageBasedType = quantizedElementType.castExpressedToStorageType(
      realFPElementsAttr.getType());
  auto storageShapedType = storageBasedType.dyn_cast_or_null<ShapedType>();
  if (!storageShapedType)
    return nullptr;

  SmallVector<APInt, 8> storageValues;
  if (!realFPElementsAttr.isSplat()) {
    storageValues.reserve(realFPElementsAttr.getNumElements());
    for (const APFloat &element : realFPElementsAttr)
      storageValues.push_back(converter.quantizeFloatToInt(element));
  } else {
    // A splat is fully described by its one value.
    storageValues.push_back(
        converter.quantizeFloatToInt(*realFPElementsAttr.begin()));
  }

  return DenseIntElementsAttr::get(storageShapedType, storageValues);
}
/// Converts a real expressed SparseElementsAttr to a corresponding
/// SparseElementsAttr containing quantized storage values assuming the given
/// quantizedElementType and converter. The sparse indices are reused as-is;
/// only the values and the shaped type are converted.
/// Returns nullptr if the sparse values are not a DenseFPElementsAttr, if the
/// values cannot be converted, or if the sparse shaped type cannot be cast to
/// its storage-type-based equivalent.
static SparseElementsAttr
convertSparseElementsAttr(SparseElementsAttr realSparseAttr,
                          QuantizedType quantizedElementType,
                          const UniformQuantizedValueConverter &converter) {
  // Only floating-point sparse values can be quantized here.
  auto realDenseAttr =
      realSparseAttr.getValues().dyn_cast<DenseFPElementsAttr>();
  if (!realDenseAttr)
    return nullptr;

  DenseElementsAttr quantDenseAttr =
      convertDenseFPElementsAttr(realDenseAttr, quantizedElementType, converter);
  if (!quantDenseAttr)
    return nullptr;

  // Cast from an expressed-type-based type to storage-type-based type,
  // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
  auto newSparseType =
      quantizedElementType.castExpressedToStorageType(realSparseAttr.getType())
          .dyn_cast_or_null<ShapedType>();
  if (!newSparseType)
    return nullptr;

  return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
                                 quantDenseAttr);
}
/// Converts a real expressed Attribute to a corresponding Attribute containing
/// quantized storage values assuming the given uniform quantizedElementType and
/// converter.
/// DenseFPElementsAttr and SparseElementsAttr constants are handled directly;
/// anything else is attempted as a primitive (scalar) attribute.
/// Returns nullptr if the conversion is not supported, in which case
/// outConvertedType is left unchanged. On success, outConvertedType receives
/// the type of the returned attribute.
Attribute mlir::quant::quantizeAttrUniform(
    Attribute realValue, UniformQuantizedType quantizedElementType,
    const UniformQuantizedValueConverter &converter, Type &outConvertedType) {
  // Fork to handle different variants of constants supported.
  if (auto denseFPAttr = realValue.dyn_cast<DenseFPElementsAttr>()) {
    // Dense tensor or vector constant.
    auto converted =
        convertDenseFPElementsAttr(denseFPAttr, quantizedElementType, converter);
    // The helper returns nullptr on failure; getType() must not be called on
    // a null attribute.
    if (converted)
      outConvertedType = converted.getType();
    return converted;
  }

  if (auto sparseAttr = realValue.dyn_cast<SparseElementsAttr>()) {
    // Sparse tensor or vector constant.
    auto converted =
        convertSparseElementsAttr(sparseAttr, quantizedElementType, converter);
    // Same null guard as above: the sparse helper can fail in several ways.
    if (converted)
      outConvertedType = converted.getType();
    return converted;
  }

  // Nothing else matched: try to convert a primitive.
  return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
                                   outConvertedType);
}
/// Converts an attribute whose type is based on
/// quantizedElementType.getExpressedType() into one based on
/// quantizedElementType.getStorageType().
/// UniformQuantizedType is delegated to quantizeAttrUniform;
/// UniformQuantizedPerAxisType is handled by a
/// UniformQuantizedPerAxisValueConverter. Any other quantized type is
/// unsupported.
/// Returns nullptr if the conversion is not supported.
/// On success, the converted type is written to outConvertedType.
Attribute mlir::quant::quantizeAttr(Attribute realValue,
                                    QuantizedType quantizedElementType,
                                    Type &outConvertedType) {
  if (auto uniformType =
          quantizedElementType.dyn_cast<UniformQuantizedType>()) {
    UniformQuantizedValueConverter uniformConverter(uniformType);
    return quantizeAttrUniform(realValue, uniformType, uniformConverter,
                               outConvertedType);
  }

  auto perAxisType =
      quantizedElementType.dyn_cast<UniformQuantizedPerAxisType>();
  if (!perAxisType)
    return nullptr; // Unsupported quantized type.

  UniformQuantizedPerAxisValueConverter perAxisConverter(perAxisType);
  auto result = perAxisConverter.convert(realValue);
  // TODO: Is outConvertedType actually needed by callers? Consider removing it.
  if (result)
    outConvertedType = result.getType();
  return result;
}