# replace_int.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from lsq_int import *
from models.mobilenet import *

# Module-level replacement state shared across recursive replace_int() calls:
# running indices of converted conv / activation layers, and the name of the
# most recently converted Conv2d (used to fold a following BatchNorm2d).
conv_idx, act_idx = -1, -1
former_conv = None

def replace_int(model, bit_width=8):
    """Recursively replace float modules in *model* with quantized equivalents.

    Mutates *model* in place: Conv2d -> FuseConv2dQ (weights/bias carried
    over), BatchNorm2d is folded into the most recently converted conv and
    replaced by Identity, ReLU/Hardswish/LeakyReLU -> quantized activations,
    Linear -> QLinear, AdaptiveAvgPool2d -> QAvgPool2d.  Only nn.Sequential
    children are descended into; other container types are left untouched
    (ResNet/MobileNet block support is commented out below).

    Args:
        model: nn.Module whose direct children are rewritten.
        bit_width: target weight/activation bit width (default 8).

    Returns:
        The same *model* instance, modified in place.
    """
    # BUGFIX: former_conv must be global as well — it is initialized at module
    # level and read in the BatchNorm2d branch, which may run in a different
    # (recursive) call than the one that set it.  Without the declaration,
    # L21's assignment made it function-local, and a BatchNorm2d seen before
    # any Conv2d in the same scope raised UnboundLocalError.
    global conv_idx, act_idx, former_conv

    for name, module in model.named_children():
        if isinstance(module, nn.Sequential):  # conventional containers
            replace_int(model._modules[name], bit_width)

        elif isinstance(module, nn.Conv2d):
            former_conv = name  # remember for BN folding below
            conv_idx += 1
            bias = module.bias is not None

            new_conv = FuseConv2dQ(module.in_channels, module.out_channels,
                                   module.kernel_size, stride=module.stride,
                                   padding=module.padding, dilation=module.dilation,
                                   groups=module.groups, bias=bias, wbit=bit_width)
            # Reuse the original Parameters so pretrained weights are kept.
            new_conv.weight = module.weight
            if bias:
                new_conv.bias = module.bias
            model._modules[name] = new_conv

        elif isinstance(module, nn.BatchNorm2d):
            # Fold BN statistics into the preceding quantized conv, then
            # neutralize the BN slot.  NOTE(review): this assumes the BN lives
            # in the same container as its conv (true for Conv-BN Sequentials);
            # a cross-container pair would raise KeyError here — TODO confirm.
            model._modules[former_conv].replace_bn(module)
            model._modules[name] = nn.Identity()

        elif isinstance(module, nn.ReLU):
            act_idx += 1
            model._modules[name] = QReLU(abit=bit_width, inplace=False, dequantize=True)

        elif isinstance(module, nn.Hardswish):
            act_idx += 1
            model._modules[name] = QHswish(abit=bit_width, inplace=False, dequantize=True)

        elif isinstance(module, nn.LeakyReLU):
            act_idx += 1
            model._modules[name] = QLeakyReLU(abit=bit_width, inplace=False, dequantize=True)

        elif isinstance(module, nn.Linear):
            bias = module.bias is not None
            new_fc = QLinear(module.in_features, module.out_features, bias, wbit=bit_width)
            new_fc.weight = module.weight
            if bias:
                new_fc.bias = module.bias
            model._modules[name] = new_fc

        elif isinstance(module, nn.AdaptiveAvgPool2d):
            model._modules[name] = QAvgPool2d(abit=bit_width, dequantize=True, output_size=module.output_size)

        # elif isinstance(module, BasicBlock) or isinstance(module, Bottleneck): #- ResNet support
        #     replace_sq(model.__dict__['_modules'][name], bit_width)

        # elif isinstance(module, InvertedResidual): #mv2
        #     replace_sq(model.__dict__['_modules'][name], bit_width)

        # Any other module type is left in place unchanged (the original
        # else-branch reassigned module to itself, a no-op).

    return model