# lsq_int.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from decimal import Decimal
import numpy as np

# Parent (mixin) class shared by all quantization modules below
class LSQModule:
    def __init__(self, abit=None, wbit=None, ibit=None, dequantize=True, scale=None):
        self.abit = abit
        self.wbit = wbit
        self.ibit = ibit
        self.dequantize = dequantize
        self.register_buffer('init_state', torch.zeros(1))
        self.scale = scale

    # member variable setters
    def set_abit(self, v):
        self.abit = v
    def set_wbit(self, v):
        self.wbit = v
    def set_ibit(self, v):
        self.ibit = v
    def set_dequantize(self, v):
        self.dequantize = v
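
# Note: LSQModule is used as a mixin, not as an nn.Module. register_buffer()
# and the nn.Parameter scale are only registered correctly because every
# concrete class below also inherits from an nn.Module subclass and calls that
# class's __init__ before LSQModule.__init__.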

class QAvgPool2d(nn.AdaptiveAvgPool2d, LSQModule):
    def __init__(self, abit, dequantize=True, output_size=(1,1)):
        super(QAvgPool2d, self).__init__(output_size)
        LSQModule.__init__(self, abit=abit, dequantize=dequantize,
                                    scale=nn.Parameter(torch.Tensor(1)))
    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'output_size=' + str(self.output_size) \
            + ', abit=' + str(self.abit) \
            + ')'
    def forward(self, x):
        # x is an (activation, scale) tuple produced by the previous quantized layer.
        former_scale = x[1]
        x = x[0]
        x = super().forward(x)
        Qn = - (2 ** (self.abit - 1))
        Qp =  2 ** (self.abit - 1) - 1

        # Rescale from the incoming scale to this layer's scale:
        # real value = x * former_scale, requantized as real / act_scale.
        act_scale = self.scale
        down_scale = act_scale / former_scale

        # Route the division through numpy (cast via Decimal) before rounding
        # back into the integer range.
        x = x.cpu().detach().numpy().astype(Decimal)
        down_scale = down_scale.cpu().detach().numpy().astype(Decimal)
        output = x / down_scale
        output = torch.from_numpy(output.astype(np.float32)).cuda()
        x = torch.round(output).clamp(Qn, Qp)

        return x, act_scale

class QMaxPool2d(nn.MaxPool2d, LSQModule):
    def __init__(self, kernel_size=3, stride=2, padding=1):
        super(QMaxPool2d, self).__init__(kernel_size=kernel_size, stride=stride, padding=padding)
        LSQModule.__init__(self)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'kernel_size=' + str(self.kernel_size) \
            + ', stride=' + str(self.stride) \
            + ', padding=' + str(self.padding) \
            + ')'

    def forward(self, x, act_scale=None):
        # Max-pooling preserves the quantization scale, so values pass through unchanged.
        result = super().forward(x)
        return result

class QReLU(nn.Module, LSQModule):
    def __init__(self, abit, dequantize=True, inplace=False):
        super(QReLU, self).__init__()
        LSQModule.__init__(self, abit=abit, dequantize=dequantize,
                                scale=nn.Parameter(torch.Tensor(1)))
        self.inplace = inplace

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'abit=' + str(self.abit) \
            + ', dequantize=' + str(self.dequantize) \
            + ', inplace=' + str(self.inplace) \
            + ', init_state=' + str(self.init_state) \
            + ')'

    def forward(self, x):
        x = F.relu(x)
        Qn = 0.
        Qp = (2 ** self.abit) - 1
        if self.training and self.init_state == 0:
            self.scale.data.copy_(2 * x.abs().mean() / math.sqrt(Qp))
            self.init_state.fill_(1)

        g = 1.0 / math.sqrt(x.numel() * Qp)
        act_scale = grad_scale(self.scale, g)
        x = round_pass((x / act_scale).clamp(Qn, Qp))
        if self.dequantize:
            x = x * act_scale
        return x, act_scale
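
# QReLU above and QConv2d further down follow the LSQ recipe: the step size is
# initialized to 2 * E[|x|] / sqrt(Qp) on the first training batch, its gradient
# is scaled by g = 1 / sqrt(numel * Qp) via grad_scale(), and rounding uses the
# straight-through estimator implemented by round_pass().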

class QLeakyReLU(nn.Module, LSQModule):
    def __init__(self, abit, negative_slope=0.1, dequantize=True, inplace=False):
        super(QLeakyReLU, self).__init__()
        LSQModule.__init__(self, abit=abit, dequantize=dequantize,
                                scale=nn.Parameter(torch.Tensor(1)))
        self.inplace = inplace
        self.negative_slope = negative_slope

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'abit=' + str(self.abit) \
               + ', negative_slope=' + str(self.negative_slope) \
               + ', inplace=' + str(self.inplace) \
               + ')'

    def forward(self, input):
        # input is an (activation, scale) tuple; deq_scale is the incoming scale.
        deq_scale = input[1]
        input = input[0]

        Qn = - (2 ** (self.abit - 1))
        Qp =  2 ** (self.abit - 1) - 1

        # Integer-domain LeakyReLU: positive values are rescaled by
        # deq_scale / self.scale, negative values additionally by negative_slope.
        down_scale = deq_scale / self.scale
        slope_scale = self.negative_slope * down_scale

        input = input.cpu().detach().numpy().astype(Decimal)
        down_scale = down_scale.cpu().detach().numpy().astype(Decimal)
        slope_scale = slope_scale.cpu().detach().numpy().astype(Decimal)

        output = np.where(input < 0, input * slope_scale, input * down_scale).astype(np.float32)
        output = torch.from_numpy(output).cuda()

        x = torch.round(output).clamp(Qn, Qp)
        return x, self.scale

class QHswish(nn.Hardswish, LSQModule):
    def __init__(self, abit, dequantize=True, inplace=False):
        super(QHswish, self).__init__(inplace=inplace)
        LSQModule.__init__(self, abit=abit, dequantize=dequantize,
                                scale=nn.Parameter(torch.Tensor(1)))
        self.inplace = inplace

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'abit=' + str(self.abit) \
               + ', inplace=' + str(self.inplace) \
               + ')'

    def forward(self, input):
        # input is an (activation, scale) tuple; deq_scale is the incoming scale.
        deq_scale = input[1]
        x = input[0]

        Qn = - (2 ** (self.abit - 1))
        Qp =  2 ** (self.abit - 1) - 1

        q_scale = self.scale
        down_scale = deq_scale / q_scale

        # Integer-domain hard-swish: with real value v = x * deq_scale,
        # hswish(v) = 0 for v <= -3, v for v >= 3, and v * (v + 3) / 6 in between.
        # Expressed on the quantized x and rescaled by 1 / q_scale this becomes
        # x * (c1 * x + c2) with c1 = down_scale * deq_scale / 6 and
        # c2 = down_scale / 2; the +/-3 thresholds map to +/-round(3 / deq_scale).
        flag = int(torch.round(3 / deq_scale))
        c1 = (down_scale * deq_scale / 6).cpu().detach().numpy().astype(Decimal)
        c2 = (down_scale / 2).cpu().detach().numpy().astype(Decimal)
        down_scale = down_scale.cpu().detach().numpy().astype(Decimal)

        x = x.cpu().detach().numpy().astype(Decimal)

        x = np.where(x <= -flag, x * 0, x)
        x = np.where(x >= flag, down_scale * x, x * (c1 * x + c2)).astype(np.float32)
        x = torch.from_numpy(x).cuda()

        x = torch.round(x).clamp(Qn, Qp)

        return x, self.scale

class QConv2d(nn.Conv2d, LSQModule):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False, wbit=32, dequantize=True):
        super(QConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
        LSQModule.__init__(self, wbit=wbit, dequantize=dequantize,
                                scale=nn.Parameter(torch.Tensor(1)))
    
    def __repr__(self):  # show detailed attributes on print(model)
        return self.__class__.__name__ + '(' \
               + 'in_channels=' + str(self.in_channels) \
               + ', out_channels=' + str(self.out_channels) \
               + ', bias=' + str(self.bias is not None) \
               + ', kernel_size=' + str(self.kernel_size) \
               + ', stride=' + str(self.stride) \
               + ', groups=' + str(self.groups) \
               + ', padding=' + str(self.padding) \
               + ', wbit=' + str(self.wbit) \
               + ')'
    
    def forward(self, x, act_scale=None):
        Qn = - (2 ** (self.wbit - 1))
        Qp =  2 ** (self.wbit - 1) - 1
        if self.training and self.init_state == 0:
            # LSQ step-size initialization: s = 2 * E[|w|] / sqrt(Qp)
            self.scale.data.copy_(2 * self.weight.abs().mean() / math.sqrt(Qp))
            self.init_state.fill_(1)

        g = 1.0 / math.sqrt(x.numel() * Qp)
        scale = grad_scale(self.scale, g)

        # Quantize into local tensors instead of overwriting self.weight/self.bias,
        # so the latent full-precision parameters are preserved across iterations
        # and the straight-through gradient from round_pass() can reach them.
        q_weight = round_pass((self.weight / scale).clamp(Qn, Qp))
        if self.dequantize:
            q_weight = q_weight * scale

        q_bias = self.bias
        if self.bias is not None:
            bias_scale = scale * act_scale
            q_bias = round_pass((self.bias / bias_scale).clamp(Qn, Qp))
            if self.dequantize:
                q_bias = q_bias * bias_scale

        output = F.conv2d(x, q_weight, q_bias, self.stride, self.padding, self.dilation, self.groups)
        return output

class QLinear(nn.Linear, LSQModule):
    def __init__(self, in_features, out_features, bias=True, wbit=32, dequantize=True):
        super(QLinear, self).__init__(in_features, out_features, bias)
        LSQModule.__init__(self, wbit=wbit, dequantize=dequantize,
                                scale=nn.Parameter(torch.Tensor(1)))

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'in_features=' + str(self.in_features) \
               + ', out_features=' + str(self.out_features) \
               + ', bias=' + str(self.bias is not None) \
               + ', wbit=' + str(self.wbit) \
               + ')'

    def forward(self, input, act_scale=None):
        # Default to the full-precision parameters so the layer also works
        # when wbit >= 32 or when it has no bias.
        cur_weight, cur_bias = self.weight, self.bias

        if self.wbit < 32:
            Qn = - (2 ** (self.wbit - 1))
            Qp =  2 ** (self.wbit - 1) - 1

            scale = self.scale
            cur_weight = torch.round((self.weight.data / scale).clamp(Qn, Qp))

            if self.bias is not None:
                bias_scale = scale * act_scale
                cur_bias = torch.round(self.bias.data / bias_scale)

        output = F.linear(input, cur_weight, cur_bias)
        return output

class Input_Quantizer(nn.Module, LSQModule):
    def __init__(self, abit=8, dequantize=True):
        super(Input_Quantizer, self).__init__()
        LSQModule.__init__(self, abit=abit, dequantize=dequantize,
                                scale=nn.Parameter(torch.Tensor(1)))

    def __repr__(self):
        return self.__class__.__name__ + '(' \
            + 'abit=' + str(self.abit) \
            + ', dequantize=' + str(self.dequantize) \
            + ', init_state=' + str(self.init_state) \
            + ')'

    def forward(self, x):
        Qn = - (2 ** (self.abit - 1))
        Qp = (2 ** (self.abit - 1)) - 1

        # Quantize the network input with the learned scale; no dequantization,
        # the integer activation and its scale travel on together as a tuple.
        x = torch.round((x / self.scale).clamp(Qn, Qp))

        return x, self.scale

class FuseConv2dQ(QConv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                    padding=0, dilation=1, groups=1, bias=True, wbit=32, dequantize=True):
        super(FuseConv2dQ, self).__init__(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias,
            wbit=wbit, dequantize=dequantize)
            
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        act_scale = x[1]
        x = x[0]

        # simulate BN folding into the convolution
        f_weight, f_bias = self.fusing()
        Qn = - (2 ** (self.wbit - 1))
        Qp =  2 ** (self.wbit - 1) - 1

        scale = self.scale
        q_weight = torch.round((f_weight.data / scale).clamp(Qn, Qp))
        bias_scale = scale*act_scale
        q_bias = torch.round((f_bias / bias_scale))

        output = F.conv2d(x, q_weight, q_bias, self.stride, self.padding, self.dilation, self.groups)
        # output *= bias_scale # dequantize

        return output, bias_scale

    def replace_bn(self, bn_module):
        self.bn = bn_module
        self.bn.track_running_stats = False

    def fusing(self):
        # Fold BatchNorm into the convolution:
        #   W_fused = W * gamma / sqrt(var + eps)
        #   b_fused = beta + (b - mu) * gamma / sqrt(var + eps)
        std = torch.sqrt(self.bn.running_var + self.bn.eps)
        f_weight = self.weight * (self.bn.weight / std).reshape([len(self.bn.weight), 1, 1, 1])
        if self.bias is not None:
            f_bias = self.bn.bias + (self.bias - self.bn.running_mean) * (self.bn.weight / std)
        else:
            f_bias = self.bn.bias - self.bn.running_mean * (self.bn.weight / std)
        return f_weight, f_bias

def grad_scale(x, scale):
    # Forward: returns x unchanged. Backward: the gradient is multiplied by
    # `scale` (the LSQ gradient-scaling trick for the step-size parameter).
    y = x
    y_grad = x * scale
    output = (y - y_grad).detach() + y_grad

    return output

def round_pass(x):
    # Rounding with a straight-through estimator: forward uses round(x),
    # backward passes the gradient through unchanged.
    y = torch.round(x)
    y_grad = x
    output = (y - y_grad).detach() + y_grad

    return output
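
if __name__ == "__main__":
    # Minimal smoke test: a sketch of how the modules are wired together, not
    # part of the original training pipeline. The scale values below are
    # placeholders standing in for a trained LSQ checkpoint.
    quant_in = Input_Quantizer(abit=8)
    conv = FuseConv2dQ(3, 8, kernel_size=3, padding=1, bias=True, wbit=8)
    with torch.no_grad():
        quant_in.scale.fill_(0.02)
        conv.scale.fill_(0.01)

    # Integer-domain path: each module consumes and returns an (activation, scale) tuple.
    y, y_scale = conv(quant_in(torch.randn(1, 3, 32, 32)))
    print('FuseConv2dQ out:', y.shape, 'scale:', float(y_scale))

    # Fake-quant training path: QReLU takes a plain tensor and initializes its
    # own scale on the first batch seen in training mode.
    relu = QReLU(abit=4)
    a, a_scale = relu(torch.randn(2, 8, 16, 16))
    print('QReLU out:', a.shape, 'scale:', float(a_scale))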