replace_int.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from lsq_int import *          # FuseConv2dQ, QReLU, QHswish, QLeakyReLU, QLinear, QAvgPool2d
from models.mobilenet import *

conv_idx = -1
act_idx = -1
def replace_int(model, bit_width=8):
    """Recursively swap float modules for their integer-quantized counterparts.

    Each BatchNorm2d is folded into the preceding Conv2d (via
    FuseConv2dQ.replace_bn) and replaced with nn.Identity, so every
    BatchNorm2d is expected to follow its Conv2d inside the same container.
    """
    global conv_idx, act_idx
    former_conv = None  # name of the most recent Conv2d seen in this container
    for name, module in model.named_children():
        if isinstance(module, nn.Sequential):  #- conventional containers
            replace_int(module, bit_width)
        elif isinstance(module, nn.Conv2d):
            former_conv = name
            conv_idx += 1
            bias = module.bias is not None
            new_conv = FuseConv2dQ(module.in_channels, module.out_channels,
                                   module.kernel_size, stride=module.stride,
                                   padding=module.padding, dilation=module.dilation,
                                   groups=module.groups, bias=bias, wbit=bit_width)
            new_conv.weight = module.weight
            if bias:
                new_conv.bias = module.bias
            setattr(model, name, new_conv)
        elif isinstance(module, nn.BatchNorm2d):
            # Fold the BN parameters into the preceding quantized conv,
            # then drop the standalone BN layer.
            getattr(model, former_conv).replace_bn(module)
            setattr(model, name, nn.Identity())
        elif isinstance(module, nn.ReLU):
            act_idx += 1
            setattr(model, name, QReLU(abit=bit_width, inplace=False, dequantize=True))
        elif isinstance(module, nn.Hardswish):
            act_idx += 1
            setattr(model, name, QHswish(abit=bit_width, inplace=False, dequantize=True))
        elif isinstance(module, nn.LeakyReLU):
            act_idx += 1
            setattr(model, name, QLeakyReLU(abit=bit_width, inplace=False, dequantize=True))
        elif isinstance(module, nn.Linear):
            bias = module.bias is not None
            new_fc = QLinear(module.in_features, module.out_features, bias, wbit=bit_width)
            new_fc.weight = module.weight
            if bias:
                new_fc.bias = module.bias
            setattr(model, name, new_fc)
        elif isinstance(module, nn.AdaptiveAvgPool2d):
            setattr(model, name, QAvgPool2d(abit=bit_width, dequantize=True,
                                            output_size=module.output_size))
        # elif isinstance(module, (BasicBlock, Bottleneck)):  #- ResNet support
        #     replace_int(module, bit_width)
        # elif isinstance(module, InvertedResidual):  #- MobileNetV2 support
        #     replace_int(module, bit_width)
    return model
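

# Minimal usage sketch (an assumption, not part of the original file):
# `mobilenet_v2` is a hypothetical factory name; substitute whichever
# constructor models.mobilenet actually exports. Any model built from
# Conv2d/BatchNorm2d/ReLU/Linear modules inside nn.Sequential containers
# is converted the same way.
if __name__ == "__main__":
    float_model = mobilenet_v2()  # hypothetical constructor name
    quant_model = replace_int(float_model, bit_width=8)
    with torch.no_grad():
        out = quant_model(torch.randn(1, 3, 224, 224))
    print(quant_model)
    print(out.shape)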