Showing 6 changed files with 465 additions and 0 deletions
code/Visualization.ipynb
0 → 100644
This diff could not be displayed because it is too large.
code/masknn.py
0 → 100644
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter


class Masker(torch.autograd.Function):
    """Applies a binary mask to a tensor in the forward pass while passing the
    incoming gradient straight through to the unmasked tensor, so pruned
    weights keep receiving gradient updates and can later be restored."""

    @staticmethod
    def forward(ctx, x, mask):
        return x * mask

    @staticmethod
    def backward(ctx, grad):
        # straight-through: full gradient for x, no gradient for the mask
        return grad, None


class MaskConv2d(nn.Conv2d):
    """nn.Conv2d whose weight is element-wise masked on every forward pass.
    The mask is a non-trainable Parameter, so it is saved in the state_dict
    and moves with the module across devices."""

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(MaskConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
                                         padding, dilation, groups, bias, padding_mode)
        self.mask = Parameter(torch.ones(self.weight.size()), requires_grad=False)

    def forward(self, inputs):
        masked_weight = Masker.apply(self.weight, self.mask)
        # PyTorch >= 1.8 signature; on older versions drop the bias argument
        return self._conv_forward(inputs, masked_weight, self.bias)
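A quick sanity check (a minimal sketch, not part of the commit; it assumes masknn.py is importable): zeroing a filter's mask removes that filter's contribution from the output, while the straight-through backward still delivers gradients to the pruned weights.

import torch
from masknn import MaskConv2d

torch.manual_seed(0)
conv = MaskConv2d(3, 8, kernel_size=3, padding=1)
conv.mask.data[0].zero_()                      # structurally prune the first filter
out = conv(torch.randn(2, 3, 32, 32))
print(torch.allclose(out[:, 0], conv.bias[0].expand_as(out[:, 0])))  # True: only the bias remains
out.sum().backward()
print(conv.weight.grad[0].abs().sum() > 0)     # True: pruned weights still receive gradients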
code/modeling.py
0 → 100644
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

from resnet_mask import *
from utils import *


def main(args):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    torch.manual_seed(777)
    if device == 'cuda':
        torch.cuda.manual_seed_all(777)

    # args
    layers = int(args.layers)
    prune_type = args.prune_type
    prune_rate = float(args.prune_rate)
    prune_imp = args.prune_imp
    reg = args.reg
    epochs = int(args.epochs)
    batch_size = int(args.batch_size)
    lr = float(args.lr)
    momentum = float(args.momentum)
    wd = float(args.wd)
    odecay = float(args.odecay)

    if prune_type:
        prune = {'type': prune_type, 'rate': prune_rate, 'imp': prune_imp}
    else:
        prune = None

    if reg == 'reg_cov':
        reg = reg_cov

    # block type and counts for ImageNet-style ResNets (kept for reference)
    cfgs = {
        '18': (BasicBlock, [2, 2, 2, 2]),
        '34': (BasicBlock, [3, 4, 6, 3]),
        '50': (Bottleneck, [3, 4, 6, 3]),
        '101': (Bottleneck, [3, 4, 23, 3]),
        '152': (Bottleneck, [3, 8, 36, 3]),
    }
    # blocks per stage for CIFAR-style ResNets (depth = 6n + 2)
    cfgs_cifar = {
        '20': [3, 3, 3],
        '32': [5, 5, 5],
        '44': [7, 7, 7],
        '56': [9, 9, 9],
        '110': [18, 18, 18],
    }

    train_data_mean = (0.5, 0.5, 0.5)
    train_data_std = (0.5, 0.5, 0.5)

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(train_data_mean, train_data_std)
    ])

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(train_data_mean, train_data_std)
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)

    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

    model = ResNet_CIFAR(BasicBlock, cfgs_cifar[str(layers)], 10).to(device)
    image_size = 32

    criterion = nn.CrossEntropyLoss().to(device)
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd)  # nesterov=args.nesterov
    lr_sche = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

    # training loop written after the reference main() function
    best_acc1 = 0.0

    print('prune rate', prune_rate, 'regularization odecay', odecay)

    for epoch in range(epochs):
        acc1_train_cor, acc5_train_cor = train(trainloader, epoch=epoch, model=model,
                                               criterion=criterion, optimizer=optimizer,
                                               prune=prune, reg=reg, odecay=odecay)
        acc1_valid_cor, acc5_valid_cor = validate(testloader, epoch=epoch, model=model, criterion=criterion)
        lr_sche.step()

        acc1_train = round(acc1_train_cor.item(), 4)
        acc5_train = round(acc5_train_cor.item(), 4)
        acc1_valid = round(acc1_valid_cor.item(), 4)
        acc5_valid = round(acc5_valid_cor.item(), 4)

        # remember the best Acc@1; checkpoint/summary saving is left disabled
        is_best = acc1_valid > best_acc1
        best_acc1 = max(acc1_valid, best_acc1)
        if is_best:
            summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]
            print(summary)
            # save_model(arch_name, args.dataset, state, args.save)
            # save_summary(arch_name, args.dataset, args.save.split('.pth')[0], summary)


if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description="")
    parser.add_argument('--layers', default=56)
    parser.add_argument('--prune_type', default=None, help='None / structured / unstructured')
    parser.add_argument('--prune_rate', default=0.9)
    parser.add_argument('--prune_imp', default='L2')
    parser.add_argument('--reg', default=None, help='None / reg_cov')
    parser.add_argument('--epochs', default=300)
    parser.add_argument('--batch_size', default=128)
    parser.add_argument('--lr', default=0.2)
    parser.add_argument('--momentum', default=0.9)
    parser.add_argument('--wd', default=1e-4)
    parser.add_argument('--odecay', default=1)
    args = parser.parse_args()

    main(args)
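For reference, a stand-alone sketch (not part of the commit) of how the StepLR schedule above evolves when lr_sche.step() is called once per epoch: the learning rate halves every 10 epochs.

import torch

param = torch.zeros(1, requires_grad=True)
opt = torch.optim.SGD([param], lr=0.2)
sche = torch.optim.lr_scheduler.StepLR(opt, step_size=10, gamma=0.5)
for epoch in range(30):
    opt.step()   # one epoch of training would happen here
    sche.step()
print(opt.param_groups[0]['lr'])  # 0.2 * 0.5**3 = 0.025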
code/resnet_mask.py
0 → 100644
This diff is collapsed and is not shown here.
code/run.sh
0 → 100644
#!/bin/bash
RESULT_DIR=result_201203

if [ ! -d $RESULT_DIR ]; then
    mkdir $RESULT_DIR
fi

#python modeling_default.py > $RESULT_DIR/default.txt #&
#python modeling_pruning.py > $RESULT_DIR/pruning_prune90.txt &
#python modeling_decorrelation.py > $RESULT_DIR/decorrelation_lambda1.txt &
#python modeling_pruning+decorrelation.py > $RESULT_DIR/pruning+decorrelation_lambda1+prune90.txt

#python modeling.py --prune_type structured --prune_rate 0.5 > $RESULT_DIR/prune_05.txt
#python modeling.py --prune_type structured --prune_rate 0.6 > $RESULT_DIR/prune_06.txt
#python modeling.py --prune_type structured --prune_rate 0.8 > $RESULT_DIR/prune_08.txt &
#python modeling.py --prune_type structured --prune_rate 0.7 > $RESULT_DIR/prune_07.txt

#python modeling.py --reg reg_cov --odecay 0.9 > $RESULT_DIR/reg_9.txt
#python modeling.py --reg reg_cov --odecay 0.8 > $RESULT_DIR/reg_8.txt
#python modeling.py --reg reg_cov --odecay 0.7 > $RESULT_DIR/reg_7.txt
#python modeling.py --reg reg_cov --odecay 0.6 > $RESULT_DIR/reg_6.txt
#python modeling.py --reg reg_cov --odecay 0.5 > $RESULT_DIR/reg_5.txt

#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_05_reg_07.txt &
#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_05_reg_08.txt &
#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_05_reg_09.txt &
#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_06_reg_07.txt &
#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_06_reg_08.txt &
python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_06_reg_09.txt
code/utils.py
0 → 100644
import numpy as np

import torch
import torch.nn.functional as F


def get_weight_threshold(model, rate, prune_imp='L1'):
    """Gather per-weight importance scores over all conv layers and return the
    score at the given quantile, used as a global pruning threshold."""
    importance_all = None
    for name, item in model.named_parameters():
        if len(item.size()) == 4 and 'mask' not in name:
            weights = item.data.view(-1).cpu()
            grads = item.grad.data.view(-1).cpu()

            if prune_imp == 'L1':
                importance = weights.abs().numpy()
            elif prune_imp == 'L2':
                importance = weights.pow(2).numpy()
            elif prune_imp == 'grad':
                importance = grads.abs().numpy()
            elif prune_imp == 'syn':
                importance = (weights * grads).abs().numpy()

            if importance_all is None:
                importance_all = importance
            else:
                importance_all = np.append(importance_all, importance)

    threshold = np.sort(importance_all)[int(len(importance_all) * rate)]
    return threshold


def weight_prune(model, threshold, prune_imp='L1'):
    """Zero the mask entries of all weights whose importance is below threshold."""
    state = model.state_dict()
    for name, item in model.named_parameters():
        if 'weight' in name:
            key = name.replace('weight', 'mask')
            if key in state.keys():
                if prune_imp == 'L1':
                    mat = item.data.abs()
                elif prune_imp == 'L2':
                    mat = item.data.pow(2)
                elif prune_imp == 'grad':
                    mat = item.grad.data.abs()
                elif prune_imp == 'syn':
                    mat = (item.data * item.grad.data).abs()
                state[key].data.copy_(torch.gt(mat, threshold).float())


def get_filter_mask(model, rate, prune_imp='L1'):
    """Score each conv filter by its normalized L1 or L2 norm and return a
    boolean mask that keeps the filters above the rate-quantile threshold."""
    importance_all = None
    for name, item in model.named_parameters():
        if len(item.size()) == 4 and 'weight' in name:
            filters = item.data.view(item.size(0), -1).cpu()
            weight_len = filters.size(1)
            if prune_imp == 'L1':
                importance = filters.abs().sum(dim=1).numpy() / weight_len
            elif prune_imp == 'L2':
                importance = filters.pow(2).sum(dim=1).numpy() / weight_len

            if importance_all is None:
                importance_all = importance
            else:
                importance_all = np.append(importance_all, importance)

    threshold = np.sort(importance_all)[int(len(importance_all) * rate)]
    filter_mask = np.greater(importance_all, threshold)
    return filter_mask


def filter_prune(model, filter_mask):
    """Apply a global boolean filter mask: the idx-th conv filter (in parameter
    order) is kept when filter_mask[idx] is True, zeroed otherwise."""
    idx = 0
    for name, item in model.named_parameters():
        if len(item.size()) == 4 and 'mask' in name:
            for i in range(item.size(0)):
                item.data[i, :, :, :] = 1 if filter_mask[idx] else 0
                idx += 1


def reg_ortho(mdl):
    """Orthogonality regularizer: for every weight tensor W (reshaped to 2-D),
    penalize the squared spectral norm of (W^T W - I), estimated with one
    round of power iteration."""
    l2_reg = None
    for W in mdl.parameters():
        if W.ndimension() < 2:
            continue
        cols = W[0].numel()
        w1 = W.view(-1, cols)
        wt = torch.transpose(w1, 0, 1)
        m = torch.matmul(wt, w1)
        ident = torch.eye(cols, device=W.device)

        w_tmp = m - ident
        height = w_tmp.size(0)
        # one power-iteration step to approximate the top singular value sigma
        u = F.normalize(w_tmp.new_empty(height).normal_(0, 1), dim=0, eps=1e-12)
        v = F.normalize(torch.matmul(w_tmp.t(), u), dim=0, eps=1e-12)
        u = F.normalize(torch.matmul(w_tmp, v), dim=0, eps=1e-12)
        sigma = torch.dot(u, torch.matmul(w_tmp, v))

        if l2_reg is None:
            l2_reg = sigma ** 2
        else:
            l2_reg = l2_reg + sigma ** 2
    return l2_reg


def reg_cov(mdl):
    """Decorrelation regularizer: for every 2-row slice of a weight tensor,
    accumulate the absolute upper-triangular part of its covariance matrix.
    Computed with torch ops so the penalty stays differentiable and
    contributes gradients to the loss."""
    cov_reg = 0
    for W in mdl.parameters():
        if W.ndimension() < 2:
            continue
        for w in W:
            for w_ in w:
                if w_.dim() > 0 and len(w_) == 2:
                    # covariance of the two rows (observations along dim 1)
                    centered = w_ - w_.mean(dim=1, keepdim=True)
                    cov_ = torch.matmul(centered, centered.t()) / (w_.size(1) - 1)
                    cov_reg = cov_reg + torch.triu(cov_).abs().sum()
    return cov_reg


class AverageMeter(object):
    r"""Computes and stores the average and current value."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


def accuracy(output, target, topk=(1,)):
    r"""Computes the accuracy over the k top predictions for the specified values of k."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # reshape (not view): correct is non-contiguous after the transpose
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def cal_sparsity(model):
    """Return (total weight count, number of masked-out weights, sparsity in %)."""
    mask_nonzeros = 0
    mask_length = 0
    total_weights = 0

    for name, item in model.named_parameters():
        if 'mask' in name:
            flatten = item.data.view(-1)
            np_flatten = flatten.cpu().numpy()

            mask_nonzeros += np.count_nonzero(np_flatten)
            mask_length += item.numel()

        if 'weight' in name or 'bias' in name:
            total_weights += item.numel()

    num_zero = mask_length - mask_nonzeros
    sparsity = (num_zero / total_weights) * 100
    return total_weights, num_zero, sparsity


def train(train_loader, epoch, model, criterion, optimizer, reg=None, prune=None, prune_freq=4, odecay=0, device='cuda'):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    model.train()

    for i, (inputs, targets) in enumerate(train_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # update the pruning masks every prune_freq batches; masks are frozen
        # after epoch 225 so the remaining epochs fine-tune a fixed topology
        if prune:
            if (i + 1) % prune_freq == 0 and epoch <= 225:
                if prune['type'] == 'structured':
                    filter_mask = get_filter_mask(model, prune['rate'], prune_imp=prune.get('imp', 'L1'))
                    filter_prune(model, filter_mask)
                elif prune['type'] == 'unstructured':
                    thres = get_weight_threshold(model, prune['rate'], prune_imp=prune.get('imp', 'L1'))
                    weight_prune(model, thres, prune_imp=prune.get('imp', 'L1'))

        outputs = model(inputs)

        if reg:
            oloss = reg(model)
            oloss = odecay * oloss
            loss = criterion(outputs, targets) + oloss
        else:
            loss = criterion(outputs, targets)

        acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        top1.update(acc1[0], inputs.size(0))
        top5.update(acc5[0], inputs.size(0))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    print('train {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5))
    if prune:
        num_total, num_zero, sparsity = cal_sparsity(model)
        print('sparsity {} ====> {:.2f}% || num_zero/num_total: {}/{}'.format(epoch, sparsity, num_zero, num_total))
    return top1.avg, top5.avg


def validate(val_loader, epoch, model, criterion, device='cuda'):
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    model.eval()

    with torch.no_grad():
        for i, (inputs, targets) in enumerate(val_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            acc1, acc5 = accuracy(outputs, targets, topk=(1, 5))
            losses.update(loss.item(), inputs.size(0))
            top1.update(acc1[0], inputs.size(0))
            top5.update(acc5[0], inputs.size(0))

    print('valid {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5))
    return top1.avg, top5.avg
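Putting the helpers together (a hypothetical, minimal sketch assuming masknn.py and utils.py above; the layer sizes are illustrative): structured pruning zeroes whole filters through the masks, and cal_sparsity reports the result.

import torch.nn as nn
from masknn import MaskConv2d
from utils import get_filter_mask, filter_prune, cal_sparsity

model = nn.Sequential(MaskConv2d(3, 16, 3, padding=1), nn.ReLU(),
                      MaskConv2d(16, 32, 3, padding=1))
filter_mask = get_filter_mask(model, rate=0.5, prune_imp='L2')  # keep the stronger half
filter_prune(model, filter_mask)
total, zeros, sparsity = cal_sparsity(model)
print('{}/{} weights zeroed ({:.2f}% sparse)'.format(zeros, total, sparsity))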