heeseon cheon

add code

# masknn.py -- mask autograd function and masked convolution
# (filename inferred from `import masknn` in utils.py below)
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class Masker(torch.autograd.Function):
    """Multiplies weights by a fixed 0/1 mask; the straight-through backward
    passes gradients to the weights and none to the mask."""
    @staticmethod
    def forward(ctx, x, mask):
        return x * mask

    @staticmethod
    def backward(ctx, grad):
        return grad, None


class MaskConv2d(nn.Conv2d):
    """Conv2d whose weight is elementwise-masked before the convolution."""
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(MaskConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                         padding, dilation, groups, bias, padding_mode)
        # one mask entry per weight; never trained, only overwritten by pruning
        self.mask = Parameter(torch.ones(self.weight.size()), requires_grad=False)

    def forward(self, inputs):
        masked_weight = Masker.apply(self.weight, self.mask)
        # newer PyTorch expects the bias as a third argument to _conv_forward
        return self._conv_forward(inputs, masked_weight, self.bias)
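# Usage sketch (my addition, not part of the original file): zeroing one
# filter's mask removes that filter's weight contribution, while gradients
# still reach the surviving weights through Masker's backward.
if __name__ == '__main__':
    conv = MaskConv2d(3, 8, kernel_size=3, padding=1, bias=False)
    conv.mask.data[0].zero_()            # prune filter 0 entirely
    y = conv(torch.randn(2, 3, 32, 32))
    print(y.shape, y[:, 0].abs().max())  # channel 0 is exactly zero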
# modeling.py -- CIFAR-10 training entry point (invoked by the shell script below)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from resnet_mask import *
from utils import *
def main(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(777)
    if device == 'cuda':
        torch.cuda.manual_seed_all(777)

    ## args
    layers = int(args.layers)
    prune_type = args.prune_type
    prune_rate = float(args.prune_rate)
    prune_imp = args.prune_imp
    reg = args.reg
    epochs = int(args.epochs)
    batch_size = int(args.batch_size)
    lr = float(args.lr)
    momentum = float(args.momentum)
    wd = float(args.wd)
    odecay = float(args.odecay)

    prune = {'type': prune_type, 'rate': prune_rate, 'imp': prune_imp} if prune_type else None
    if reg == 'reg_cov':
        reg = reg_cov

    # ImageNet-style configurations (kept for reference; unused on CIFAR-10)
    cfgs = {
        '18': (BasicBlock, [2, 2, 2, 2]),
        '34': (BasicBlock, [3, 4, 6, 3]),
        '50': (Bottleneck, [3, 4, 6, 3]),
        '101': (Bottleneck, [3, 4, 23, 3]),
        '152': (Bottleneck, [3, 8, 36, 3]),
    }
    cfgs_cifar = {
        '20': [3, 3, 3],
        '32': [5, 5, 5],
        '44': [7, 7, 7],
        '56': [9, 9, 9],
        '110': [18, 18, 18],
    }

    train_data_mean = (0.5, 0.5, 0.5)
    train_data_std = (0.5, 0.5, 0.5)
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(train_data_mean, train_data_std)
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(train_data_mean, train_data_std)
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    # the original hard-coded cfgs_cifar['56']; wire it to --layers instead
    model = ResNet_CIFAR(BasicBlock, cfgs_cifar[str(layers)], 10).to(device)
    image_size = 32
    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)
    lr_sche = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # training loop modeled on the reference main()
    best_acc1 = 0.0
    print('prune rate', prune_rate, 'regularization odecay', odecay)
    for epoch in range(epochs):
        acc1_train_cor, acc5_train_cor = train(trainloader, epoch=epoch, model=model,
                                               criterion=criterion, optimizer=optimizer,
                                               prune=prune, reg=reg, odecay=odecay)
        acc1_valid_cor, acc5_valid_cor = validate(testloader, epoch=epoch, model=model, criterion=criterion)
        lr_sche.step()  # the scheduler was created but never stepped in the original

        acc1_train = round(acc1_train_cor.item(), 4)
        acc5_train = round(acc5_train_cor.item(), 4)
        acc1_valid = round(acc1_valid_cor.item(), 4)
        acc5_valid = round(acc5_valid_cor.item(), 4)

        # remember best Acc@1 and save checkpoint and summary csv file
        is_best = acc1_valid > best_acc1
        best_acc1 = max(acc1_valid, best_acc1)
        if is_best:
            summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]
            print(summary)
            # save_model(arch_name, args.dataset, state, args.save)
            # save_summary(arch_name, args.dataset, args.save.split('.pth')[0], summary)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description="CIFAR-10 ResNet training with pruning/decorrelation")
    parser.add_argument('--layers', default=56)
    parser.add_argument('--prune_type', default=None, help='None / structured / unstructured')
    parser.add_argument('--prune_rate', default=0.9)
    parser.add_argument('--prune_imp', default='L2')
    parser.add_argument('--reg', default=None, help='None / reg_cov')
    parser.add_argument('--epochs', default=300)
    parser.add_argument('--batch_size', default=128)
    parser.add_argument('--lr', default=0.2)
    parser.add_argument('--momentum', default=0.9)
    parser.add_argument('--wd', default=1e-4)
    parser.add_argument('--odecay', default=1)
    args = parser.parse_args()
    main(args)
#!/bin/bash
RESULT_DIR=result_201203
if [ ! -d "$RESULT_DIR" ]; then
    mkdir -p "$RESULT_DIR"
fi
#python modeling_default.py > $RESULT_DIR/default.txt #&
#python modeling_pruning.py > $RESULT_DIR/pruning_prune90.txt &
#python modeling_decorrelation.py > $RESULT_DIR/decorrelation_lambda1.txt &
#python modeling_pruning+decorrelation.py > $RESULT_DIR/pruning+decorrelation_lambda1+prune90.txt
#python modeling.py --prune_type structured --prune_rate 0.5 > $RESULT_DIR/prune_05.txt
#python modeling.py --prune_type structured --prune_rate 0.6 > $RESULT_DIR/prune_06.txt
#python modeling.py --prune_type structured --prune_rate 0.8 > $RESULT_DIR/prune_08.txt &
#python modeling.py --prune_type structured --prune_rate 0.7 > $RESULT_DIR/prune_07.txt
#python modeling.py --reg reg_cov --odecay 0.9 > $RESULT_DIR/reg_9.txt
#python modeling.py --reg reg_cov --odecay 0.8 > $RESULT_DIR/reg_8.txt
#python modeling.py --reg reg_cov --odecay 0.7 > $RESULT_DIR/reg_7.txt
#python modeling.py --reg reg_cov --odecay 0.6 > $RESULT_DIR/reg_6.txt
#python modeling.py --reg reg_cov --odecay 0.5 > $RESULT_DIR/reg_5.txt
#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_05_reg_07.txt &
#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_05_reg_08.txt &
#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_05_reg_09.txt &
#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_06_reg_07.txt &
#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_06_reg_08.txt &
python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_06_reg_09.txt
# utils.py -- pruning and regularization helpers shared by modeling.py
import numpy as np
import torch
import torch.nn as nn
from torch.nn.functional import normalize  # used by reg_ortho; missing in the original

import masknn
import resnet_mask
def get_weight_threshold(model, rate, prune_imp='L1'):
    """Global magnitude threshold: the `rate` quantile of per-weight
    importance over all 4-D (conv) weight tensors."""
    importance_all = None
    for name, item in model.named_parameters():
        if len(item.size()) == 4 and 'mask' not in name:
            weights = item.data.view(-1).cpu()
            if prune_imp == 'L1':
                importance = weights.abs().numpy()
            elif prune_imp == 'L2':
                importance = weights.pow(2).numpy()
            else:
                # gradient-based criteria need a backward pass first; the
                # original read item.grad unconditionally and crashed without one
                grads = item.grad.data.view(-1).cpu()
                if prune_imp == 'grad':
                    importance = grads.abs().numpy()
                elif prune_imp == 'syn':
                    importance = (weights * grads).abs().numpy()
            if importance_all is None:
                importance_all = importance
            else:
                importance_all = np.append(importance_all, importance)
    threshold = np.sort(importance_all)[int(len(importance_all) * rate)]
    return threshold
def weight_prune(model, threshold, prune_imp='L1'):
    """Unstructured pruning: set mask=0 wherever importance <= threshold."""
    state = model.state_dict()
    for name, item in model.named_parameters():
        if 'weight' in name:
            key = name.replace('weight', 'mask')
            if key in state.keys():
                if prune_imp == 'L1':
                    mat = item.data.abs()
                elif prune_imp == 'L2':
                    mat = item.data.pow(2)
                elif prune_imp == 'grad':
                    mat = item.grad.data.abs()
                elif prune_imp == 'syn':
                    mat = (item.data * item.grad.data).abs()
                state[key].data.copy_(torch.gt(mat, threshold).float())
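# Usage sketch (my addition): one unstructured pruning step on a toy model
# built from masknn.MaskConv2d. L1/L2 importance needs no gradients; 'grad'
# and 'syn' require a backward pass before get_weight_threshold is called.
def _demo_unstructured(rate=0.5):
    model = nn.Sequential(masknn.MaskConv2d(3, 4, 3), nn.ReLU(),
                          masknn.MaskConv2d(4, 4, 3))
    thres = get_weight_threshold(model, rate, prune_imp='L1')
    weight_prune(model, thres, prune_imp='L1')
    return cal_sparsity(model)  # roughly `rate` of the conv weights masked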
def get_filter_mask(model, rate, prune_imp='L1'):
    """Structured pruning: rank whole conv filters by mean L1/L2 importance
    and return one keep/drop boolean per filter, globally thresholded."""
    importance_all = None
    for name, item in model.named_parameters():
        if len(item.size()) == 4 and 'weight' in name:
            filters = item.data.view(item.size(0), -1).cpu()
            weight_len = filters.size(1)
            if prune_imp == 'L1':
                importance = filters.abs().sum(dim=1).numpy() / weight_len
            elif prune_imp == 'L2':
                importance = filters.pow(2).sum(dim=1).numpy() / weight_len
            if importance_all is None:
                importance_all = importance
            else:
                importance_all = np.append(importance_all, importance)
    threshold = np.sort(importance_all)[int(len(importance_all) * rate)]
    #threshold = np.percentile(importance_all, rate)
    filter_mask = np.greater(importance_all, threshold)
    return filter_mask
def filter_prune(model, filter_mask):
    """Apply a filter mask: zero (or restore) whole mask filters, walking the
    conv layers in the same order get_filter_mask produced them."""
    idx = 0
    for name, item in model.named_parameters():
        if len(item.size()) == 4 and 'mask' in name:
            for i in range(item.size(0)):
                item.data[i, :, :, :] = 1 if filter_mask[idx] else 0
                idx += 1
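# Usage sketch (my addition): one structured pruning step. Both helpers must
# traverse the same model so the flat filter indices line up.
def _demo_structured(rate=0.5):
    model = nn.Sequential(masknn.MaskConv2d(3, 4, 3), nn.ReLU(),
                          masknn.MaskConv2d(4, 4, 3))
    fmask = get_filter_mask(model, rate, prune_imp='L1')
    filter_prune(model, fmask)  # filters with fmask False are zeroed out
    return cal_sparsity(model)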
def reg_ortho(mdl):
    """SRIP-style orthogonality penalty: squared spectral norm of
    (W^T W - I) per layer, estimated with one round of power iteration."""
    l2_reg = None
    for W in mdl.parameters():
        if W.ndimension() < 2:
            continue
        cols = W[0].numel()
        w1 = W.view(-1, cols)
        wt = torch.transpose(w1, 0, 1)
        m = torch.matmul(wt, w1)
        # device-aware identity; the original used deprecated Variable(...).cuda()
        ident = torch.eye(cols, device=W.device)
        w_tmp = m - ident
        height = w_tmp.size(0)
        # one power-iteration step approximates the top singular value sigma
        u = normalize(w_tmp.new_empty(height).normal_(0, 1), dim=0, eps=1e-12)
        v = normalize(torch.matmul(w_tmp.t(), u), dim=0, eps=1e-12)
        u = normalize(torch.matmul(w_tmp, v), dim=0, eps=1e-12)
        sigma = torch.dot(u, torch.matmul(w_tmp, v))
        if l2_reg is None:
            l2_reg = sigma ** 2
        else:
            l2_reg = l2_reg + sigma ** 2
    return l2_reg
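# Usage sketch (my addition): either regularizer plugs into the loss the same
# way train() below does, scaled by odecay.
def _demo_reg():
    model = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU(), nn.Conv2d(4, 4, 3))
    penalty = reg_ortho(model)  # differentiable scalar tensor
    penalty.backward()          # gradients reach the conv weights
    return penalty.item()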
def reg_cov(mdl):
    """Decorrelation penalty: for each 2-D kernel slice, sum the absolute
    upper-triangular entries of its row covariance. Rewritten with torch ops:
    the original used np.cov on detached tensors, so it contributed no
    gradient and crashed on CUDA tensors, and its `len(w_) == 2` gate only
    ever matched 2-row kernels; a plain 2-D check is assumed here."""
    cov_reg = 0
    for W in mdl.parameters():
        if W.ndimension() < 2:
            continue
        for w in W:
            for w_ in w:
                if w_.dim() == 2 and w_.size(1) > 1:
                    # row covariance with the same N-1 denominator as np.cov
                    centered = w_ - w_.mean(dim=1, keepdim=True)
                    cov_ = centered @ centered.t() / (w_.size(1) - 1)
                    cov_reg = cov_reg + torch.triu(cov_).abs().sum()
    return cov_reg
class AverageMeter(object):
    r"""Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
def accuracy(output, target, topk=(1,)):
    r"""Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # reshape, not view: the sliced tensor is non-contiguous on newer PyTorch
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
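# Usage sketch (my addition): accuracy() returns 1-element tensors, which is
# why train()/validate() below index acc1[0] before feeding the meters.
def _demo_accuracy():
    logits = torch.randn(8, 10)
    labels = torch.randint(0, 10, (8,))
    acc1, acc5 = accuracy(logits, labels, topk=(1, 5))
    top1 = AverageMeter('Acc@1', ':6.2f')
    top1.update(acc1[0], n=8)
    print(top1)  # e.g. "Acc@1  12.50 ( 12.50)"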
def cal_sparsity(model):
    """Returns (total weights, zeroed mask entries, sparsity in percent)."""
    mask_nonzeros = 0
    mask_length = 0
    total_weights = 0
    for name, item in model.named_parameters():
        if 'mask' in name:
            flatten = item.data.view(-1)
            np_flatten = flatten.cpu().numpy()
            mask_nonzeros += np.count_nonzero(np_flatten)
            mask_length += item.numel()
        if 'weight' in name or 'bias' in name:
            total_weights += item.numel()
    num_zero = mask_length - mask_nonzeros
    sparsity = (num_zero / total_weights) * 100
    return total_weights, num_zero, sparsity
def train(train_loader, epoch, model, criterion, optimizer, reg=None, prune=None, prune_freq=4, odecay=0, device='cuda'):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    model.train()
    for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # re-prune every prune_freq batches; masks are frozen after epoch 225
        if prune and (i + 1) % prune_freq == 0 and epoch <= 225:
            prune_imp = prune.get('imp', 'L1')
            if prune['type'] == 'structured':
                filter_mask = get_filter_mask(model, prune['rate'], prune_imp)
                filter_prune(model, filter_mask)
            elif prune['type'] == 'unstructured':
                # the original read prune['target_sparsity'], a key main() never sets
                thres = get_weight_threshold(model, prune['rate'], prune_imp)
                weight_prune(model, thres, prune_imp)

        outputs = model(inputs)
        if reg:
            loss = criterion(outputs, targets) + odecay * reg(model)
        else:
            loss = criterion(outputs, targets)

        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1[0], inputs.size(0))
        top5.update(acc5[0], inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('train {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5))
    if prune:
        num_total, num_zero, sparsity = cal_sparsity(model)
        print('sparsity {} ====> {:.2f}% || num_zero/num_total: {}/{}'.format(epoch, sparsity, num_zero, num_total))
    return top1.avg, top5.avg
def validate(val_loader, epoch, model, criterion, device='cuda'):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    model.eval()
    with torch.no_grad():
        for i, (inputs, targets) in enumerate(val_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0], inputs.size(0))
            top5.update(acc5[0], inputs.size(0))

    print('valid {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5))
    return top1.avg, top5.avg
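# Hypothetical sketch (my addition): main() references save_model/save_summary
# from the original repo, but they are not part of this diff. One plausible
# shape for save_summary, matching its commented-out call site:
# save_summary(arch_name, args.dataset, args.save.split('.pth')[0], summary)
import csv
import os

def save_summary(arch_name, dataset, prefix, summary):
    """Append one [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid] row."""
    path = '{}_{}_{}.csv'.format(arch_name, dataset, prefix)
    write_header = not os.path.isfile(path)
    with open(path, 'a', newline='') as f:
        writer = csv.writer(f)
        if write_header:
            writer.writerow(['epoch', 'acc1_train', 'acc5_train', 'acc1_valid', 'acc5_valid'])
        writer.writerow(summary)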