heeseon cheon

add code

This diff could not be displayed because it is too large.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class Masker(torch.autograd.Function):
@staticmethod
def forward(ctx, x, mask):
return x * mask
@staticmethod
def backward(ctx, grad):
return grad, None
class MaskConv2d(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1,
padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
super(MaskConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
padding, dilation, groups, bias, padding_mode)
self.mask = Parameter(torch.ones(self.weight.size()), requires_grad=False)
def forward(self, inputs):
masked_weight = Masker.apply(self.weight, self.mask)
return super(MaskConv2d, self)._conv_forward(inputs, masked_weight)
import time
import random
import pathlib
from os.path import isfile
import copy
import sys
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms
from resnet_mask import *
from utils import *
def main(args):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(777)
if device =='cuda':
torch.cuda.manual_seed_all(777)
## args
layers = int(args.layers)
prune_type = args.prune_type
prune_rate = float(args.prune_rate)
prune_imp = args.prune_imp
reg = args.reg
epochs = int(args.epochs)
batch_size = int(args.batch_size)
lr = float(args.lr)
momentum = float(args.momentum)
wd = float(args.wd)
odecay = float(args.odecay)
if prune_type:
prune = {'type':prune_type, 'rate':prune_rate}
else:
prune = None
if reg == 'reg_cov':
reg = reg_cov
cfgs = {
'18': (BasicBlock, [2, 2, 2, 2]),
'34': (BasicBlock, [3, 4, 6, 3]),
'50': (Bottleneck, [3, 4, 6, 3]),
'101': (Bottleneck, [3, 4, 23, 3]),
'152': (Bottleneck, [3, 8, 36, 3]),
}
cfgs_cifar = {
'20': [3, 3, 3],
'32': [5, 5, 5],
'44': [7, 7, 7],
'56': [9, 9, 9],
'110': [18, 18, 18],
}
train_data_mean = (0.5, 0.5, 0.5)
train_data_std = (0.5, 0.5, 0.5)
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(train_data_mean, train_data_std)
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(train_data_mean, train_data_std)
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=4)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=4)
classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
model = ResNet_CIFAR(BasicBlock, cfgs_cifar['56'], 10).to(device)
image_size = 32
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd) #nesterov=args.nesterov)
lr_sche = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
##### main 함수 보고 train 짜기
best_acc1 = 0.0
print('prune rate', prune_rate, 'regularization odecay', odecay)
for epoch in range(epochs):
acc1_train_cor, acc5_train_cor = train(trainloader, epoch=epoch, model=model,
criterion=criterion, optimizer=optimizer,
prune=prune, reg=reg, odecay=odecay)
acc1_valid_cor, acc5_valid_cor = validate(testloader, epoch=epoch, model=model, criterion=criterion)
acc1_train = round(acc1_train_cor.item(), 4)
acc5_train = round(acc5_train_cor.item(), 4)
acc1_valid = round(acc1_valid_cor.item(), 4)
acc5_valid = round(acc5_valid_cor.item(), 4)
# remember best Acc@1 and save checkpoint and summary csv file
# summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]
is_best = acc1_valid > best_acc1
best_acc1 = max(acc1_valid, best_acc1)
if is_best:
summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]
print(summary)
# save_model(arch_name, args.dataset, state, args.save)
# save_summary(arch_name, args.dataset, args.save.split('.pth')[0], summary)
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description="")
parser.add_argument('--layers', default=56)
parser.add_argument('--prune_type', default=None, help='None / structured / unstructured')
parser.add_argument('--prune_rate', default=0.9)
parser.add_argument('--prune_imp', default='L2')
parser.add_argument('--reg', default=None, help='None / reg_cov')
parser.add_argument('--epochs', default=300)
parser.add_argument('--batch_size', default=128)
parser.add_argument('--lr', default=0.2)
parser.add_argument('--momentum', default=0.9)
parser.add_argument('--wd', default=1e-4)
parser.add_argument('--odecay', default=1)
args = parser.parse_args()
main(args)
import torch
import torch.nn as nn
import masknn as mnn
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return mnn.MaskConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return mnn.MaskConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
__constants__ = ['downsample']
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
self.block_name = str(block.__name__)
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = mnn.MaskConv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, mnn.MaskConv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
class ResNet_CIFAR(nn.Module):
def __init__(self, block, layers, num_classes=10, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet_CIFAR, self).__init__()
self.block_name = str(block.__name__)
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 16
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = mnn.MaskConv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(block, 16, layers[0])
self.layer2 = self._make_layer(block, 32, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 64, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(64 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, mnn.MaskConv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def _forward_impl(self, x):
# See note [TorchScript super()]
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def forward(self, x):
return self._forward_impl(x)
# Model configurations
cfgs = {
'18': (BasicBlock, [2, 2, 2, 2]),
'34': (BasicBlock, [3, 4, 6, 3]),
'50': (Bottleneck, [3, 4, 6, 3]),
'101': (Bottleneck, [3, 4, 23, 3]),
'152': (Bottleneck, [3, 8, 36, 3]),
}
cfgs_cifar = {
'20': [3, 3, 3],
'32': [5, 5, 5],
'44': [7, 7, 7],
'56': [9, 9, 9],
'110': [18, 18, 18],
}
def resnet(data='cifar10', **kwargs):
r"""ResNet models from "[Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)"
Args:
data (str): the name of datasets
"""
num_layers = str(kwargs.get('num_layers'))
# set pruner
global mnn
mnn = kwargs.get('mnn')
assert mnn is not None, "Please specify proper pruning method"
if data in ['cifar10', 'cifar100']:
if num_layers in cfgs_cifar.keys():
model = ResNet_CIFAR(BasicBlock, cfgs_cifar[num_layers], int(data[5:]))
else:
model = None
image_size = 32
elif data == 'imagenet':
if num_layers in cfgs.keys():
block, layers = cfgs[num_layers]
model = ResNet(block, layers, 1000)
else:
model = None
image_size = 224
else:
model = None
image_size = None
return model, image_size
\ No newline at end of file
#!/bin/bash
RESULT_DIR=result_201203
if [ ! -d $RESULT_DIR ]; then
mkdir $RESULT_DIR
fi
#python modeling_default.py > $RESULT_DIR/default.txt #&
#python modeling_pruning.py > $RESULT_DIR/pruning_prune90.txt &
#python modeling_decorrelation.py > $RESULT_DIR/decorrelation_lambda1.txt &
#python modeling_pruning+decorrelation.py > $RESULT_DIR/pruning+decorrelation_lambda1+prune90.txt
#python modeling.py --prune_type structured --prune_rate 0.5 > $RESULT_DIR/prune_05.txt
#python modeling.py --prune_type structured --prune_rate 0.6 > $RESULT_DIR/prune_06.txt
#python modeling.py --prune_type structured --prune_rate 0.8 > $RESULT_DIR/prune_08.txt &
#python modeling.py --prune_type structured --prune_rate 0.7 > $RESULT_DIR/prune_07.txt
#python modeling.py --reg reg_cov --odecay 0.9 > $RESULT_DIR/reg_9.txt
#python modeling.py --reg reg_cov --odecay 0.8 > $RESULT_DIR/reg_8.txt
#python modeling.py --reg reg_cov --odecay 0.7 > $RESULT_DIR/reg_7.txt
#python modeling.py --reg reg_cov --odecay 0.6 > $RESULT_DIR/reg_6.txt
#python modeling.py --reg reg_cov --odecay 0.5 > $RESULT_DIR/reg_5.txt
#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_05_reg_07.txt &
#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_05_reg_08.txt &
#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_05_reg_09.txt &
#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_06_reg_07.txt &
#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_06_reg_08.txt &
python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_06_reg_09.txt
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import masknn
import resnet_mask
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import numpy as np
import sys
def get_weight_threshold(model, rate, prune_imp='L1'):
importance_all = None
for name, item in model.named_parameters():
#module.named_parameters():
if len(item.size())==4 and 'mask' not in name:
weights = item.data.view(-1).cpu()
grads = item.grad.data.view(-1).cpu()
if prune_imp == 'L1':
importance = weights.abs().numpy()
elif prune_imp == 'L2':
importance = weights.pow(2).numpy()
elif prune_imp == 'grad':
importance = grads.abs().numpy()
elif prune_imp == 'syn':
importance = (weights * grads).abs().numpy()
if importance_all is None:
importance_all = importance
else:
importance_all = np.append(importance_all, importance)
threshold = np.sort(importance_all)[int(len(importance_all) * rate)]
return threshold
def weight_prune(model, threshold, prune_imp='L1'):
state = model.state_dict()
for name, item in model.named_parameters():
if 'weight' in name:
key = name.replace('weight', 'mask')
if key in state.keys():
if prune_imp == 'L1':
mat = item.data.abs()
elif prune_imp == 'L2':
mat = item.data.pow(2)
elif prune_imp == 'grad':
mat = item.grad.data.abs()
elif prune_imp == 'syn':
mat = (item.data * item.grad.data).abs()
state[key].data.copy_(torch.gt(mat, threshold).float())
def get_filter_mask(model, rate, prune_imp='L1'):
importance_all = None
for name, item in model.named_parameters():
#.module.named_parameters():
if len(item.size())==4 and 'weight' in name:
filters = item.data.view(item.size(0), -1).cpu()
weight_len = filters.size(1)
if prune_imp =='L1':
importance = filters.abs().sum(dim=1).numpy() / weight_len
elif prune_imp == 'L2':
importance = filters.pow(2).sum(dim=1).numpy() / weight_len
if importance_all is None:
importance_all = importance
else:
importance_all = np.append(importance_all, importance)
threshold = np.sort(importance_all)[int(len(importance_all) * rate)]
#threshold = np.percentile(importance_all, rate)
filter_mask = np.greater(importance_all, threshold)
return filter_mask
def filter_prune(model, filter_mask):
idx = 0
for name, item in model.named_parameters():
#.module.named_parameters():
if len(item.size())==4 and 'mask' in name:
for i in range(item.size(0)):
item.data[i,:,:,:] = 1 if filter_mask[idx] else 0
idx += 1
def reg_ortho(mdl):
l2_reg = None
for W in mdl.parameters():
if W.ndimension() < 2:
continue
else:
cols = W[0].numel()
rows = W.shape[0]
w1 = W.view(-1,cols)
wt = torch.transpose(w1,0,1)
m = torch.matmul(wt,w1)
ident = Variable(torch.eye(cols,cols))
ident = ident.cuda()
w_tmp = (m - ident)
height = w_tmp.size(0)
u = normalize(w_tmp.new_empty(height).normal_(0,1), dim=0, eps=1e-12)
v = normalize(torch.matmul(w_tmp.t(), u), dim=0, eps=1e-12)
u = normalize(torch.matmul(w_tmp, v), dim=0, eps=1e-12)
sigma = torch.dot(u, torch.matmul(w_tmp, v))
if l2_reg is None:
l2_reg = (sigma)**2
else:
l2_reg = l2_reg + (sigma)**2
return l2_reg
def reg_cov(mdl):
cov_reg = 0
for W in mdl.parameters():
if W.ndimension() < 2:
continue
else:
for w in W:
for w_ in w:
if w_.dim() > 0 and len(w_) == 2:
cov_ = np.cov(w_.detach().numpy())
cov_upper = np.triu(cov_)
cov_upper_abs = np.absolute(cov_upper)
cov_upper_abs_sum = np.sum(cov_upper_abs)
cov_reg += cov_upper_abs_sum
return cov_reg
class AverageMeter(object):
r"""Computes and stores the average and current value
"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
def accuracy(output, target, topk=(1,)):
r"""Computes the accuracy over the $k$ top predictions for the specified values of k
"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = []
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
def cal_sparsity(model):
mask_nonzeros = 0
mask_length = 0
total_weights = 0
for name, item in model.named_parameters():
#.module.named_parameters():
if 'mask' in name:
flatten = item.data.view(-1)
np_flatten = flatten.cpu().numpy()
mask_nonzeros += np.count_nonzero(np_flatten)
mask_length += item.numel()
if 'weight' in name or 'bias' in name:
total_weights += item.numel()
num_zero = mask_length - mask_nonzeros
sparsity = (num_zero / total_weights) * 100
return total_weights, num_zero, sparsity
def train(train_loader, epoch, model, criterion, optimizer, reg=None, prune=None, prune_freq=4, odecay=0, device='cuda'):
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
model.train()
for i, (inputs, targets) in enumerate(train_loader):
inputs = inputs.to(device)
targets = targets.to(device)
if prune:
if (i+1) % prune_freq == 0 and epoch <= 225:
if prune['type'] == 'structured':
filter_mask = get_filter_mask(model, prune['rate'])
filter_prune(model, filter_mask)
elif prune['type'] == 'unstructured':
thres = get_weight_threshold(model, prune['target_sparsity'])
weight_prune(model, thres)
outputs = model(inputs)
if reg:
oloss = reg(model)
oloss = odecay * oloss
loss = criterion(outputs, targets) + oloss
else:
loss = criterion(outputs, targets)
acc1, acc5 = accuracy(outputs, targets, topk=(1,5))
losses.update(loss.item(), inputs.size(0))
top1.update(acc1[0], inputs.size(0))
top5.update(acc5[0], inputs.size(0))
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('train {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5))
if prune:
num_total, num_zero, sparsity = cal_sparsity(model)
print('sparsity {} ====> {:.2f}% || num_zero/num_total: {}/{}'.format(epoch, sparsity, num_zero, num_total))
return top1.avg, top5.avg
def validate(val_loader, epoch, model, criterion, device='cuda'):
losses = AverageMeter('Loss', ':.4e')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')
model.eval()
with torch.no_grad():
for i, (inputs, targets) in enumerate(val_loader):
inputs = inputs.to(device)
targets = targets.to(device)
outputs = model(inputs)
loss = criterion(outputs, targets)
acc1, acc5 = accuracy(outputs, targets, topk=(1,5))
losses.update(loss.item(), inputs.size(0))
top1.update(acc1[0], inputs.size(0))
top5.update(acc5[0], inputs.size(0))
print('valid {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5))
return top1.avg, top5.avg