heeseon cheon

add code

masknn.py
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class Masker(torch.autograd.Function):
    """Multiply by a fixed mask in the forward pass; pass gradients straight
    through in the backward pass so that masked (pruned) weights still
    receive gradient updates."""

    @staticmethod
    def forward(ctx, x, mask):
        return x * mask

    @staticmethod
    def backward(ctx, grad):
        # Straight-through estimator: the input gradient is passed unchanged,
        # and the mask itself gets no gradient.
        return grad, None


class MaskConv2d(nn.Conv2d):
    """nn.Conv2d with a non-trainable, per-weight binary mask for pruning."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(MaskConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                         padding, dilation, groups, bias, padding_mode)
        self.mask = Parameter(torch.ones(self.weight.size()), requires_grad=False)

    def forward(self, inputs):
        masked_weight = Masker.apply(self.weight, self.mask)
        # _conv_forward takes (input, weight, bias) on PyTorch >= 1.8;
        # on older versions drop self.bias and use the 2-argument form.
        return self._conv_forward(inputs, masked_weight, self.bias)
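A quick smoke test (hypothetical, not part of this commit; it assumes the module above is importable as masknn): the mask removes a filter from the forward pass, yet the pruned filter still receives a gradient through the straight-through backward, which is what lets pruned weights recover if the mask is later relaxed.

import torch
from masknn import MaskConv2d

conv = MaskConv2d(3, 8, kernel_size=3, padding=1)
conv.mask.data[0].zero_()               # "prune" the first filter
out = conv(torch.randn(2, 3, 32, 32))   # filter 0 now contributes nothing
out.sum().backward()
print(conv.weight.grad[0].abs().sum())  # nonzero despite the mask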
modeling.py

import time
import random
import pathlib
from os.path import isfile
import copy
import sys

import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

from resnet_mask import *
from utils import *


def main(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(777)
    if device == 'cuda':
        torch.cuda.manual_seed_all(777)

    ## args
    layers = int(args.layers)
    prune_type = args.prune_type
    prune_rate = float(args.prune_rate)
    prune_imp = args.prune_imp
    epochs = int(args.epochs)
    batch_size = int(args.batch_size)
    lr = float(args.lr)
    momentum = float(args.momentum)
    wd = float(args.wd)
    odecay = float(args.odecay)

    if prune_type:
        prune = {'type': prune_type, 'rate': prune_rate, 'imp': prune_imp}
    else:
        prune = None

    reg = reg_cov if args.reg == 'reg_cov' else None

    # ImageNet-style ResNet configs (unused below; kept for reference).
    cfgs = {
        '18': (BasicBlock, [2, 2, 2, 2]),
        '34': (BasicBlock, [3, 4, 6, 3]),
        '50': (Bottleneck, [3, 4, 6, 3]),
        '101': (Bottleneck, [3, 4, 23, 3]),
        '152': (Bottleneck, [3, 8, 36, 3]),
    }
    # CIFAR ResNet configs: blocks per stage for ResNet-(6n+2).
    cfgs_cifar = {
        '20': [3, 3, 3],
        '32': [5, 5, 5],
        '44': [7, 7, 7],
        '56': [9, 9, 9],
        '110': [18, 18, 18],
    }

    train_data_mean = (0.5, 0.5, 0.5)
    train_data_std = (0.5, 0.5, 0.5)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(train_data_mean, train_data_std)
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(train_data_mean, train_data_std)
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    model = ResNet_CIFAR(BasicBlock, cfgs_cifar[str(layers)], 10).to(device)
    image_size = 32

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)  # nesterov=args.nesterov
    lr_sche = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    ##### train loop written by following the reference main()
    best_acc1 = 0.0

    print('prune rate', prune_rate, 'regularization odecay', odecay)

    for epoch in range(epochs):

        acc1_train_cor, acc5_train_cor = train(trainloader, epoch=epoch, model=model,
                                               criterion=criterion, optimizer=optimizer,
                                               prune=prune, reg=reg, odecay=odecay)
        acc1_valid_cor, acc5_valid_cor = validate(testloader, epoch=epoch, model=model, criterion=criterion)
        lr_sche.step()  # advance the StepLR schedule once per epoch

        acc1_train = round(acc1_train_cor.item(), 4)
        acc5_train = round(acc5_train_cor.item(), 4)
        acc1_valid = round(acc1_valid_cor.item(), 4)
        acc5_valid = round(acc5_valid_cor.item(), 4)

        # remember best Acc@1; checkpoint/summary saving is left disabled
        is_best = acc1_valid > best_acc1
        best_acc1 = max(acc1_valid, best_acc1)
        if is_best:
            summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]
            print(summary)
            # save_model(arch_name, args.dataset, state, args.save)
            # save_summary(arch_name, args.dataset, args.save.split('.pth')[0], summary)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description="")
    parser.add_argument('--layers', default=56)
    parser.add_argument('--prune_type', default=None, help='None / structured / unstructured')
    parser.add_argument('--prune_rate', default=0.9)
    parser.add_argument('--prune_imp', default='L2')
    parser.add_argument('--reg', default=None, help='None / reg_cov')
    parser.add_argument('--epochs', default=300)
    parser.add_argument('--batch_size', default=128)
    parser.add_argument('--lr', default=0.2)
    parser.add_argument('--momentum', default=0.9)
    parser.add_argument('--wd', default=1e-4)
    parser.add_argument('--odecay', default=1)
    args = parser.parse_args()

    main(args)
run script (shell)
#!/bin/bash
RESULT_DIR=result_201203

if [ ! -d "$RESULT_DIR" ]; then
    mkdir "$RESULT_DIR"
fi

#python modeling_default.py > $RESULT_DIR/default.txt #&
#python modeling_pruning.py > $RESULT_DIR/pruning_prune90.txt &
#python modeling_decorrelation.py > $RESULT_DIR/decorrelation_lambda1.txt &
#python modeling_pruning+decorrelation.py > $RESULT_DIR/pruning+decorrelation_lambda1+prune90.txt

#python modeling.py --prune_type structured --prune_rate 0.5 > $RESULT_DIR/prune_05.txt
#python modeling.py --prune_type structured --prune_rate 0.6 > $RESULT_DIR/prune_06.txt
#python modeling.py --prune_type structured --prune_rate 0.8 > $RESULT_DIR/prune_08.txt &
#python modeling.py --prune_type structured --prune_rate 0.7 > $RESULT_DIR/prune_07.txt

#python modeling.py --reg reg_cov --odecay 0.9 > $RESULT_DIR/reg_9.txt
#python modeling.py --reg reg_cov --odecay 0.8 > $RESULT_DIR/reg_8.txt
#python modeling.py --reg reg_cov --odecay 0.7 > $RESULT_DIR/reg_7.txt
#python modeling.py --reg reg_cov --odecay 0.6 > $RESULT_DIR/reg_6.txt
#python modeling.py --reg reg_cov --odecay 0.5 > $RESULT_DIR/reg_5.txt

#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_05_reg_07.txt &
#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_05_reg_08.txt &
#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_05_reg_09.txt &
#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_06_reg_07.txt &
#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_06_reg_08.txt &
python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_06_reg_09.txt
utils.py
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

import masknn
import resnet_mask

def get_weight_threshold(model, rate, prune_imp='L1'):
    """Collect an importance score for every conv weight and return the value
    at the `rate` quantile, used as a global unstructured-pruning threshold."""
    importance_all = None
    for name, item in model.named_parameters():
        if len(item.size()) == 4 and 'mask' not in name:
            weights = item.data.view(-1).cpu()

            if prune_imp == 'L1':
                importance = weights.abs().numpy()
            elif prune_imp == 'L2':
                importance = weights.pow(2).numpy()
            elif prune_imp == 'grad':
                # gradient-based importance requires a prior backward pass
                importance = item.grad.data.view(-1).cpu().abs().numpy()
            elif prune_imp == 'syn':
                grads = item.grad.data.view(-1).cpu()
                importance = (weights * grads).abs().numpy()

            if importance_all is None:
                importance_all = importance
            else:
                importance_all = np.append(importance_all, importance)

    threshold = np.sort(importance_all)[int(len(importance_all) * rate)]
    return threshold

def weight_prune(model, threshold, prune_imp='L1'):
    """Zero the mask entries whose weight importance falls below `threshold`
    (unstructured, per-weight pruning)."""
    state = model.state_dict()
    for name, item in model.named_parameters():
        if 'weight' in name:
            key = name.replace('weight', 'mask')
            if key in state.keys():
                if prune_imp == 'L1':
                    mat = item.data.abs()
                elif prune_imp == 'L2':
                    mat = item.data.pow(2)
                elif prune_imp == 'grad':
                    mat = item.grad.data.abs()
                elif prune_imp == 'syn':
                    mat = (item.data * item.grad.data).abs()
                state[key].data.copy_(torch.gt(mat, threshold).float())

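# Example (hypothetical): one global unstructured pruning step at 90%
# sparsity for a model built from MaskConv2d layers ('grad'/'syn'
# importance additionally requires a prior backward pass):
#
#   thres = get_weight_threshold(model, rate=0.9, prune_imp='L1')
#   weight_prune(model, thres, prune_imp='L1')
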
def get_filter_mask(model, rate, prune_imp='L1'):
    """Score every conv filter by its mean L1/L2 norm and return a boolean
    mask keeping the filters above the global `rate` quantile
    (structured, per-filter pruning)."""
    importance_all = None
    for name, item in model.named_parameters():
        if len(item.size()) == 4 and 'weight' in name:
            filters = item.data.view(item.size(0), -1).cpu()
            weight_len = filters.size(1)
            if prune_imp == 'L1':
                importance = filters.abs().sum(dim=1).numpy() / weight_len
            elif prune_imp == 'L2':
                importance = filters.pow(2).sum(dim=1).numpy() / weight_len

            if importance_all is None:
                importance_all = importance
            else:
                importance_all = np.append(importance_all, importance)

    threshold = np.sort(importance_all)[int(len(importance_all) * rate)]
    #threshold = np.percentile(importance_all, rate)
    filter_mask = np.greater(importance_all, threshold)
    return filter_mask


def filter_prune(model, filter_mask):
    """Apply a boolean per-filter mask: zero the whole mask slice of each
    pruned filter, set the rest to one."""
    idx = 0
    for name, item in model.named_parameters():
        if len(item.size()) == 4 and 'mask' in name:
            for i in range(item.size(0)):
                item.data[i, :, :, :] = 1 if filter_mask[idx] else 0
                idx += 1

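# Example (hypothetical): zero out the lowest-scoring ~50% of filters:
#
#   filter_mask = get_filter_mask(model, rate=0.5, prune_imp='L1')
#   filter_prune(model, filter_mask)
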
def reg_ortho(mdl):
    """Spectral-norm orthogonality regularizer (SRIP-style): penalize the
    largest singular value of W^T W - I, estimated with one round of the
    power method."""
    l2_reg = None
    for W in mdl.parameters():
        if W.ndimension() < 2:
            continue
        cols = W[0].numel()
        w1 = W.view(-1, cols)
        wt = torch.transpose(w1, 0, 1)
        m = torch.matmul(wt, w1)
        ident = torch.eye(cols, device=W.device)

        w_tmp = m - ident
        height = w_tmp.size(0)
        u = F.normalize(w_tmp.new_empty(height).normal_(0, 1), dim=0, eps=1e-12)
        v = F.normalize(torch.matmul(w_tmp.t(), u), dim=0, eps=1e-12)
        u = F.normalize(torch.matmul(w_tmp, v), dim=0, eps=1e-12)
        sigma = torch.dot(u, torch.matmul(w_tmp, v))

        if l2_reg is None:
            l2_reg = sigma ** 2
        else:
            l2_reg = l2_reg + sigma ** 2
    return l2_reg

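# Example (hypothetical): reg_ortho plugs into the loss the same way
# reg_cov does in train() below, weighted by odecay:
#
#   loss = criterion(outputs, targets) + odecay * reg_ortho(model)
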
def reg_cov(mdl):
    """Covariance regularizer: for every 2-D kernel slice, add the absolute
    values of the upper triangle (variances and covariances) of the row
    covariance matrix. Computed in torch so the penalty stays differentiable
    and works on CUDA tensors."""
    cov_reg = 0
    for W in mdl.parameters():
        if W.ndimension() < 2:
            continue
        for w in W:
            for w_ in w:
                if w_.dim() == 2:
                    # row covariance with np.cov's (N-1) normalization
                    centered = w_ - w_.mean(dim=1, keepdim=True)
                    cov_ = centered.matmul(centered.t()) / (w_.size(1) - 1)
                    cov_reg = cov_reg + torch.triu(cov_).abs().sum()
    return cov_reg

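# Worked example (hypothetical): for a 2x2 kernel slice [[1., 2.], [3., 4.]],
# each row has variance 0.5 and the two rows have covariance 0.5, so the
# upper-triangular absolute sum added to cov_reg is 0.5 + 0.5 + 0.5 = 1.5.
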
class AverageMeter(object):
    r"""Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

def accuracy(output, target, topk=(1,)):
    r"""Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape(): correct[:k] is non-contiguous, so view() would fail
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

def cal_sparsity(model):
    """Return (total weight count, number of masked-out weights, sparsity %)."""
    mask_nonzeros = 0
    mask_length = 0
    total_weights = 0

    for name, item in model.named_parameters():
        if 'mask' in name:
            flatten = item.data.view(-1)
            np_flatten = flatten.cpu().numpy()

            mask_nonzeros += np.count_nonzero(np_flatten)
            mask_length += item.numel()

        if 'weight' in name or 'bias' in name:
            total_weights += item.numel()

    num_zero = mask_length - mask_nonzeros
    sparsity = (num_zero / total_weights) * 100
    return total_weights, num_zero, sparsity

def train(train_loader, epoch, model, criterion, optimizer, reg=None, prune=None, prune_freq=4, odecay=0, device='cuda'):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    model.train()

    for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # re-prune every prune_freq iterations during the first 225 epochs
        if prune:
            if (i + 1) % prune_freq == 0 and epoch <= 225:
                if prune['type'] == 'structured':
                    filter_mask = get_filter_mask(model, prune['rate'], prune.get('imp', 'L1'))
                    filter_prune(model, filter_mask)
                elif prune['type'] == 'unstructured':
                    thres = get_weight_threshold(model, prune['rate'], prune.get('imp', 'L1'))
                    weight_prune(model, thres, prune.get('imp', 'L1'))

        outputs = model(inputs)

        if reg:
            oloss = reg(model)
            loss = criterion(outputs, targets) + odecay * oloss
        else:
            loss = criterion(outputs, targets)

        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1[0], inputs.size(0))
        top5.update(acc5[0], inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('train {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5))
    if prune:
        num_total, num_zero, sparsity = cal_sparsity(model)
        print('sparsity {} ====> {:.2f}% || num_zero/num_total: {}/{}'.format(epoch, sparsity, num_zero, num_total))
    return top1.avg, top5.avg

def validate(val_loader, epoch, model, criterion, device='cuda'):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    model.eval()

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(val_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0], inputs.size(0))
            top5.update(acc5[0], inputs.size(0))

    print('valid {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5))
    return top1.avg, top5.avg