Showing
6 changed files
with
803 additions
and
0 deletions
code/Visualization.ipynb
0 → 100644
This diff could not be displayed because it is too large.
code/masknn.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +import torch.nn.functional as F | ||
4 | +from torch.nn.parameter import Parameter | ||
5 | + | ||
6 | + | ||
7 | +class Masker(torch.autograd.Function): | ||
8 | + @staticmethod | ||
9 | + def forward(ctx, x, mask): | ||
10 | + return x * mask | ||
11 | + | ||
12 | + @staticmethod | ||
13 | + def backward(ctx, grad): | ||
14 | + return grad, None | ||
15 | + | ||
16 | + | ||
17 | +class MaskConv2d(nn.Conv2d): | ||
18 | + def __init__(self, in_channels, out_channels, kernel_size, stride=1, | ||
19 | + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'): | ||
20 | + super(MaskConv2d, self).__init__(in_channels, out_channels, kernel_size, stride, | ||
21 | + padding, dilation, groups, bias, padding_mode) | ||
22 | + self.mask = Parameter(torch.ones(self.weight.size()), requires_grad=False) | ||
23 | + | ||
24 | + def forward(self, inputs): | ||
25 | + masked_weight = Masker.apply(self.weight, self.mask) | ||
26 | + return super(MaskConv2d, self)._conv_forward(inputs, masked_weight) |
code/modeling.py
0 → 100644
1 | +import time | ||
2 | +import random | ||
3 | +import pathlib | ||
4 | +from os.path import isfile | ||
5 | +import copy | ||
6 | +import sys | ||
7 | + | ||
8 | +import numpy as np | ||
9 | +import cv2 | ||
10 | + | ||
11 | +import torch | ||
12 | +import torch.nn as nn | ||
13 | +import torch.nn.functional as F | ||
14 | +import torch.optim as optim | ||
15 | +import torch.backends.cudnn as cudnn | ||
16 | + | ||
17 | +from torch.autograd import Variable | ||
18 | +import torchvision | ||
19 | +import torchvision.transforms as transforms | ||
20 | + | ||
21 | +from resnet_mask import * | ||
22 | +from utils import * | ||
23 | + | ||
24 | + | ||
25 | +def main(args): | ||
26 | + device = 'cuda' if torch.cuda.is_available() else 'cpu' | ||
27 | + torch.manual_seed(777) | ||
28 | + if device =='cuda': | ||
29 | + torch.cuda.manual_seed_all(777) | ||
30 | + | ||
31 | + ## args | ||
32 | + layers = int(args.layers) | ||
33 | + prune_type = args.prune_type | ||
34 | + prune_rate = float(args.prune_rate) | ||
35 | + prune_imp = args.prune_imp | ||
36 | + reg = args.reg | ||
37 | + epochs = int(args.epochs) | ||
38 | + batch_size = int(args.batch_size) | ||
39 | + lr = float(args.lr) | ||
40 | + momentum = float(args.momentum) | ||
41 | + wd = float(args.wd) | ||
42 | + odecay = float(args.odecay) | ||
43 | + | ||
44 | + if prune_type: | ||
45 | + prune = {'type':prune_type, 'rate':prune_rate} | ||
46 | + else: | ||
47 | + prune = None | ||
48 | + | ||
49 | + if reg == 'reg_cov': | ||
50 | + reg = reg_cov | ||
51 | + | ||
52 | + cfgs = { | ||
53 | + '18': (BasicBlock, [2, 2, 2, 2]), | ||
54 | + '34': (BasicBlock, [3, 4, 6, 3]), | ||
55 | + '50': (Bottleneck, [3, 4, 6, 3]), | ||
56 | + '101': (Bottleneck, [3, 4, 23, 3]), | ||
57 | + '152': (Bottleneck, [3, 8, 36, 3]), | ||
58 | + } | ||
59 | + cfgs_cifar = { | ||
60 | + '20': [3, 3, 3], | ||
61 | + '32': [5, 5, 5], | ||
62 | + '44': [7, 7, 7], | ||
63 | + '56': [9, 9, 9], | ||
64 | + '110': [18, 18, 18], | ||
65 | + } | ||
66 | + | ||
67 | + train_data_mean = (0.5, 0.5, 0.5) | ||
68 | + train_data_std = (0.5, 0.5, 0.5) | ||
69 | + | ||
70 | + transform_train = transforms.Compose([ | ||
71 | + transforms.RandomCrop(32, padding=4), | ||
72 | + transforms.ToTensor(), | ||
73 | + transforms.Normalize(train_data_mean, train_data_std) | ||
74 | + ]) | ||
75 | + | ||
76 | + transform_test = transforms.Compose([ | ||
77 | + transforms.ToTensor(), | ||
78 | + transforms.Normalize(train_data_mean, train_data_std) | ||
79 | + ]) | ||
80 | + | ||
81 | + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) | ||
82 | + trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=4) | ||
83 | + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) | ||
84 | + testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=4) | ||
85 | + | ||
86 | + classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck') | ||
87 | + | ||
88 | + model = ResNet_CIFAR(BasicBlock, cfgs_cifar['56'], 10).to(device) | ||
89 | + image_size = 32 | ||
90 | + | ||
91 | + criterion = nn.CrossEntropyLoss().to(device) | ||
92 | + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd) #nesterov=args.nesterov) | ||
93 | + lr_sche = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5) | ||
94 | + | ||
95 | + ##### main 함수 보고 train 짜기 | ||
96 | + best_acc1 = 0.0 | ||
97 | + | ||
98 | + print('prune rate', prune_rate, 'regularization odecay', odecay) | ||
99 | + | ||
100 | + for epoch in range(epochs): | ||
101 | + | ||
102 | + acc1_train_cor, acc5_train_cor = train(trainloader, epoch=epoch, model=model, | ||
103 | + criterion=criterion, optimizer=optimizer, | ||
104 | + prune=prune, reg=reg, odecay=odecay) | ||
105 | + acc1_valid_cor, acc5_valid_cor = validate(testloader, epoch=epoch, model=model, criterion=criterion) | ||
106 | + | ||
107 | + acc1_train = round(acc1_train_cor.item(), 4) | ||
108 | + acc5_train = round(acc5_train_cor.item(), 4) | ||
109 | + acc1_valid = round(acc1_valid_cor.item(), 4) | ||
110 | + acc5_valid = round(acc5_valid_cor.item(), 4) | ||
111 | + | ||
112 | + # remember best Acc@1 and save checkpoint and summary csv file | ||
113 | + # summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid] | ||
114 | + | ||
115 | + is_best = acc1_valid > best_acc1 | ||
116 | + best_acc1 = max(acc1_valid, best_acc1) | ||
117 | + if is_best: | ||
118 | + summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid] | ||
119 | + print(summary) | ||
120 | + # save_model(arch_name, args.dataset, state, args.save) | ||
121 | + # save_summary(arch_name, args.dataset, args.save.split('.pth')[0], summary) | ||
122 | + | ||
123 | +if __name__ == '__main__': | ||
124 | + import argparse | ||
125 | + parser = argparse.ArgumentParser(description="") | ||
126 | + parser.add_argument('--layers', default=56) | ||
127 | + parser.add_argument('--prune_type', default=None, help='None / structured / unstructured') | ||
128 | + parser.add_argument('--prune_rate', default=0.9) | ||
129 | + parser.add_argument('--prune_imp', default='L2') | ||
130 | + parser.add_argument('--reg', default=None, help='None / reg_cov') | ||
131 | + parser.add_argument('--epochs', default=300) | ||
132 | + parser.add_argument('--batch_size', default=128) | ||
133 | + parser.add_argument('--lr', default=0.2) | ||
134 | + parser.add_argument('--momentum', default=0.9) | ||
135 | + parser.add_argument('--wd', default=1e-4) | ||
136 | + parser.add_argument('--odecay', default=1) | ||
137 | + args = parser.parse_args() | ||
138 | + | ||
139 | + main(args) |
code/resnet_mask.py
0 → 100644
1 | +import torch | ||
2 | +import torch.nn as nn | ||
3 | +import masknn as mnn | ||
4 | + | ||
5 | + | ||
6 | +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): | ||
7 | + """3x3 convolution with padding""" | ||
8 | + return mnn.MaskConv2d(in_planes, out_planes, kernel_size=3, stride=stride, | ||
9 | + padding=dilation, groups=groups, bias=False, dilation=dilation) | ||
10 | + | ||
11 | + | ||
12 | +def conv1x1(in_planes, out_planes, stride=1): | ||
13 | + """1x1 convolution""" | ||
14 | + return mnn.MaskConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) | ||
15 | + | ||
16 | + | ||
17 | +class BasicBlock(nn.Module): | ||
18 | + expansion = 1 | ||
19 | + __constants__ = ['downsample'] | ||
20 | + | ||
21 | + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | ||
22 | + base_width=64, dilation=1, norm_layer=None): | ||
23 | + super(BasicBlock, self).__init__() | ||
24 | + if norm_layer is None: | ||
25 | + norm_layer = nn.BatchNorm2d | ||
26 | + if groups != 1 or base_width != 64: | ||
27 | + raise ValueError('BasicBlock only supports groups=1 and base_width=64') | ||
28 | + if dilation > 1: | ||
29 | + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") | ||
30 | + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 | ||
31 | + self.conv1 = conv3x3(inplanes, planes, stride) | ||
32 | + self.bn1 = norm_layer(planes) | ||
33 | + self.relu = nn.ReLU(inplace=True) | ||
34 | + self.conv2 = conv3x3(planes, planes) | ||
35 | + self.bn2 = norm_layer(planes) | ||
36 | + self.downsample = downsample | ||
37 | + self.stride = stride | ||
38 | + | ||
39 | + def forward(self, x): | ||
40 | + identity = x | ||
41 | + | ||
42 | + out = self.conv1(x) | ||
43 | + out = self.bn1(out) | ||
44 | + out = self.relu(out) | ||
45 | + | ||
46 | + out = self.conv2(out) | ||
47 | + out = self.bn2(out) | ||
48 | + | ||
49 | + if self.downsample is not None: | ||
50 | + identity = self.downsample(x) | ||
51 | + | ||
52 | + out += identity | ||
53 | + out = self.relu(out) | ||
54 | + | ||
55 | + return out | ||
56 | + | ||
57 | + | ||
58 | +class Bottleneck(nn.Module): | ||
59 | + expansion = 4 | ||
60 | + __constants__ = ['downsample'] | ||
61 | + | ||
62 | + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, | ||
63 | + base_width=64, dilation=1, norm_layer=None): | ||
64 | + super(Bottleneck, self).__init__() | ||
65 | + if norm_layer is None: | ||
66 | + norm_layer = nn.BatchNorm2d | ||
67 | + width = int(planes * (base_width / 64.)) * groups | ||
68 | + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 | ||
69 | + self.conv1 = conv1x1(inplanes, width) | ||
70 | + self.bn1 = norm_layer(width) | ||
71 | + self.conv2 = conv3x3(width, width, stride, groups, dilation) | ||
72 | + self.bn2 = norm_layer(width) | ||
73 | + self.conv3 = conv1x1(width, planes * self.expansion) | ||
74 | + self.bn3 = norm_layer(planes * self.expansion) | ||
75 | + self.relu = nn.ReLU(inplace=True) | ||
76 | + self.downsample = downsample | ||
77 | + self.stride = stride | ||
78 | + | ||
79 | + def forward(self, x): | ||
80 | + identity = x | ||
81 | + | ||
82 | + out = self.conv1(x) | ||
83 | + out = self.bn1(out) | ||
84 | + out = self.relu(out) | ||
85 | + | ||
86 | + out = self.conv2(out) | ||
87 | + out = self.bn2(out) | ||
88 | + out = self.relu(out) | ||
89 | + | ||
90 | + out = self.conv3(out) | ||
91 | + out = self.bn3(out) | ||
92 | + | ||
93 | + if self.downsample is not None: | ||
94 | + identity = self.downsample(x) | ||
95 | + | ||
96 | + out += identity | ||
97 | + out = self.relu(out) | ||
98 | + | ||
99 | + return out | ||
100 | + | ||
101 | + | ||
102 | +class ResNet(nn.Module): | ||
103 | + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, | ||
104 | + groups=1, width_per_group=64, replace_stride_with_dilation=None, | ||
105 | + norm_layer=None): | ||
106 | + super(ResNet, self).__init__() | ||
107 | + self.block_name = str(block.__name__) | ||
108 | + if norm_layer is None: | ||
109 | + norm_layer = nn.BatchNorm2d | ||
110 | + self._norm_layer = norm_layer | ||
111 | + | ||
112 | + self.inplanes = 64 | ||
113 | + self.dilation = 1 | ||
114 | + if replace_stride_with_dilation is None: | ||
115 | + # each element in the tuple indicates if we should replace | ||
116 | + # the 2x2 stride with a dilated convolution instead | ||
117 | + replace_stride_with_dilation = [False, False, False] | ||
118 | + if len(replace_stride_with_dilation) != 3: | ||
119 | + raise ValueError("replace_stride_with_dilation should be None " | ||
120 | + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) | ||
121 | + self.groups = groups | ||
122 | + self.base_width = width_per_group | ||
123 | + self.conv1 = mnn.MaskConv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, | ||
124 | + bias=False) | ||
125 | + self.bn1 = norm_layer(self.inplanes) | ||
126 | + self.relu = nn.ReLU(inplace=True) | ||
127 | + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) | ||
128 | + self.layer1 = self._make_layer(block, 64, layers[0]) | ||
129 | + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, | ||
130 | + dilate=replace_stride_with_dilation[0]) | ||
131 | + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, | ||
132 | + dilate=replace_stride_with_dilation[1]) | ||
133 | + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, | ||
134 | + dilate=replace_stride_with_dilation[2]) | ||
135 | + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||
136 | + self.fc = nn.Linear(512 * block.expansion, num_classes) | ||
137 | + | ||
138 | + for m in self.modules(): | ||
139 | + if isinstance(m, mnn.MaskConv2d): | ||
140 | + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | ||
141 | + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | ||
142 | + nn.init.constant_(m.weight, 1) | ||
143 | + nn.init.constant_(m.bias, 0) | ||
144 | + | ||
145 | + # Zero-initialize the last BN in each residual branch, | ||
146 | + # so that the residual branch starts with zeros, and each residual block behaves like an identity. | ||
147 | + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | ||
148 | + if zero_init_residual: | ||
149 | + for m in self.modules(): | ||
150 | + if isinstance(m, Bottleneck): | ||
151 | + nn.init.constant_(m.bn3.weight, 0) | ||
152 | + elif isinstance(m, BasicBlock): | ||
153 | + nn.init.constant_(m.bn2.weight, 0) | ||
154 | + | ||
155 | + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | ||
156 | + norm_layer = self._norm_layer | ||
157 | + downsample = None | ||
158 | + previous_dilation = self.dilation | ||
159 | + if dilate: | ||
160 | + self.dilation *= stride | ||
161 | + stride = 1 | ||
162 | + if stride != 1 or self.inplanes != planes * block.expansion: | ||
163 | + downsample = nn.Sequential( | ||
164 | + conv1x1(self.inplanes, planes * block.expansion, stride), | ||
165 | + norm_layer(planes * block.expansion), | ||
166 | + ) | ||
167 | + | ||
168 | + layers = [] | ||
169 | + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, | ||
170 | + self.base_width, previous_dilation, norm_layer)) | ||
171 | + self.inplanes = planes * block.expansion | ||
172 | + for _ in range(1, blocks): | ||
173 | + layers.append(block(self.inplanes, planes, groups=self.groups, | ||
174 | + base_width=self.base_width, dilation=self.dilation, | ||
175 | + norm_layer=norm_layer)) | ||
176 | + | ||
177 | + return nn.Sequential(*layers) | ||
178 | + | ||
179 | + def _forward_impl(self, x): | ||
180 | + # See note [TorchScript super()] | ||
181 | + x = self.conv1(x) | ||
182 | + x = self.bn1(x) | ||
183 | + x = self.relu(x) | ||
184 | + x = self.maxpool(x) | ||
185 | + | ||
186 | + x = self.layer1(x) | ||
187 | + x = self.layer2(x) | ||
188 | + x = self.layer3(x) | ||
189 | + x = self.layer4(x) | ||
190 | + | ||
191 | + x = self.avgpool(x) | ||
192 | + x = torch.flatten(x, 1) | ||
193 | + x = self.fc(x) | ||
194 | + return x | ||
195 | + | ||
196 | + def forward(self, x): | ||
197 | + return self._forward_impl(x) | ||
198 | + | ||
199 | +class ResNet_CIFAR(nn.Module): | ||
200 | + def __init__(self, block, layers, num_classes=10, zero_init_residual=False, | ||
201 | + groups=1, width_per_group=64, replace_stride_with_dilation=None, | ||
202 | + norm_layer=None): | ||
203 | + super(ResNet_CIFAR, self).__init__() | ||
204 | + self.block_name = str(block.__name__) | ||
205 | + if norm_layer is None: | ||
206 | + norm_layer = nn.BatchNorm2d | ||
207 | + self._norm_layer = norm_layer | ||
208 | + | ||
209 | + self.inplanes = 16 | ||
210 | + self.dilation = 1 | ||
211 | + if replace_stride_with_dilation is None: | ||
212 | + # each element in the tuple indicates if we should replace | ||
213 | + # the 2x2 stride with a dilated convolution instead | ||
214 | + replace_stride_with_dilation = [False, False, False] | ||
215 | + if len(replace_stride_with_dilation) != 3: | ||
216 | + raise ValueError("replace_stride_with_dilation should be None " | ||
217 | + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) | ||
218 | + self.groups = groups | ||
219 | + self.base_width = width_per_group | ||
220 | + self.conv1 = mnn.MaskConv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, | ||
221 | + bias=False) | ||
222 | + self.bn1 = norm_layer(self.inplanes) | ||
223 | + self.relu = nn.ReLU(inplace=True) | ||
224 | + self.layer1 = self._make_layer(block, 16, layers[0]) | ||
225 | + self.layer2 = self._make_layer(block, 32, layers[1], stride=2, | ||
226 | + dilate=replace_stride_with_dilation[0]) | ||
227 | + self.layer3 = self._make_layer(block, 64, layers[2], stride=2, | ||
228 | + dilate=replace_stride_with_dilation[1]) | ||
229 | + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) | ||
230 | + self.fc = nn.Linear(64 * block.expansion, num_classes) | ||
231 | + | ||
232 | + for m in self.modules(): | ||
233 | + if isinstance(m, mnn.MaskConv2d): | ||
234 | + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') | ||
235 | + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): | ||
236 | + nn.init.constant_(m.weight, 1) | ||
237 | + nn.init.constant_(m.bias, 0) | ||
238 | + | ||
239 | + # Zero-initialize the last BN in each residual branch, | ||
240 | + # so that the residual branch starts with zeros, and each residual block behaves like an identity. | ||
241 | + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 | ||
242 | + if zero_init_residual: | ||
243 | + for m in self.modules(): | ||
244 | + if isinstance(m, Bottleneck): | ||
245 | + nn.init.constant_(m.bn3.weight, 0) | ||
246 | + elif isinstance(m, BasicBlock): | ||
247 | + nn.init.constant_(m.bn2.weight, 0) | ||
248 | + | ||
249 | + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): | ||
250 | + norm_layer = self._norm_layer | ||
251 | + downsample = None | ||
252 | + previous_dilation = self.dilation | ||
253 | + if dilate: | ||
254 | + self.dilation *= stride | ||
255 | + stride = 1 | ||
256 | + if stride != 1 or self.inplanes != planes * block.expansion: | ||
257 | + downsample = nn.Sequential( | ||
258 | + conv1x1(self.inplanes, planes * block.expansion, stride), | ||
259 | + norm_layer(planes * block.expansion), | ||
260 | + ) | ||
261 | + | ||
262 | + layers = [] | ||
263 | + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, | ||
264 | + self.base_width, previous_dilation, norm_layer)) | ||
265 | + self.inplanes = planes * block.expansion | ||
266 | + for _ in range(1, blocks): | ||
267 | + layers.append(block(self.inplanes, planes, groups=self.groups, | ||
268 | + base_width=self.base_width, dilation=self.dilation, | ||
269 | + norm_layer=norm_layer)) | ||
270 | + | ||
271 | + return nn.Sequential(*layers) | ||
272 | + | ||
273 | + def _forward_impl(self, x): | ||
274 | + # See note [TorchScript super()] | ||
275 | + x = self.conv1(x) | ||
276 | + x = self.bn1(x) | ||
277 | + x = self.relu(x) | ||
278 | + | ||
279 | + x = self.layer1(x) | ||
280 | + x = self.layer2(x) | ||
281 | + x = self.layer3(x) | ||
282 | + | ||
283 | + x = self.avgpool(x) | ||
284 | + x = torch.flatten(x, 1) | ||
285 | + x = self.fc(x) | ||
286 | + return x | ||
287 | + | ||
288 | + def forward(self, x): | ||
289 | + return self._forward_impl(x) | ||
290 | + | ||
291 | + | ||
292 | +# Model configurations | ||
293 | +cfgs = { | ||
294 | + '18': (BasicBlock, [2, 2, 2, 2]), | ||
295 | + '34': (BasicBlock, [3, 4, 6, 3]), | ||
296 | + '50': (Bottleneck, [3, 4, 6, 3]), | ||
297 | + '101': (Bottleneck, [3, 4, 23, 3]), | ||
298 | + '152': (Bottleneck, [3, 8, 36, 3]), | ||
299 | +} | ||
300 | +cfgs_cifar = { | ||
301 | + '20': [3, 3, 3], | ||
302 | + '32': [5, 5, 5], | ||
303 | + '44': [7, 7, 7], | ||
304 | + '56': [9, 9, 9], | ||
305 | + '110': [18, 18, 18], | ||
306 | +} | ||
307 | + | ||
308 | + | ||
309 | +def resnet(data='cifar10', **kwargs): | ||
310 | + r"""ResNet models from "[Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)" | ||
311 | + Args: | ||
312 | + data (str): the name of datasets | ||
313 | + """ | ||
314 | + num_layers = str(kwargs.get('num_layers')) | ||
315 | + | ||
316 | + # set pruner | ||
317 | + global mnn | ||
318 | + mnn = kwargs.get('mnn') | ||
319 | + assert mnn is not None, "Please specify proper pruning method" | ||
320 | + | ||
321 | + if data in ['cifar10', 'cifar100']: | ||
322 | + if num_layers in cfgs_cifar.keys(): | ||
323 | + model = ResNet_CIFAR(BasicBlock, cfgs_cifar[num_layers], int(data[5:])) | ||
324 | + else: | ||
325 | + model = None | ||
326 | + image_size = 32 | ||
327 | + elif data == 'imagenet': | ||
328 | + if num_layers in cfgs.keys(): | ||
329 | + block, layers = cfgs[num_layers] | ||
330 | + model = ResNet(block, layers, 1000) | ||
331 | + else: | ||
332 | + model = None | ||
333 | + image_size = 224 | ||
334 | + else: | ||
335 | + model = None | ||
336 | + image_size = None | ||
337 | + | ||
338 | + return model, image_size | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
code/run.sh
0 → 100644
1 | +#!/bin/bash | ||
2 | +RESULT_DIR=result_201203 | ||
3 | + | ||
4 | +if [ ! -d $RESULT_DIR ]; then | ||
5 | + mkdir $RESULT_DIR | ||
6 | +fi | ||
7 | + | ||
8 | +#python modeling_default.py > $RESULT_DIR/default.txt #& | ||
9 | +#python modeling_pruning.py > $RESULT_DIR/pruning_prune90.txt & | ||
10 | +#python modeling_decorrelation.py > $RESULT_DIR/decorrelation_lambda1.txt & | ||
11 | +#python modeling_pruning+decorrelation.py > $RESULT_DIR/pruning+decorrelation_lambda1+prune90.txt | ||
12 | + | ||
13 | +#python modeling.py --prune_type structured --prune_rate 0.5 > $RESULT_DIR/prune_05.txt | ||
14 | +#python modeling.py --prune_type structured --prune_rate 0.6 > $RESULT_DIR/prune_06.txt | ||
15 | +#python modeling.py --prune_type structured --prune_rate 0.8 > $RESULT_DIR/prune_08.txt & | ||
16 | +#python modeling.py --prune_type structured --prune_rate 0.7 > $RESULT_DIR/prune_07.txt | ||
17 | + | ||
18 | +#python modeling.py --reg reg_cov --odecay 0.9 > $RESULT_DIR/reg_9.txt | ||
19 | +#python modeling.py --reg reg_cov --odecay 0.8 > $RESULT_DIR/reg_8.txt | ||
20 | +#python modeling.py --reg reg_cov --odecay 0.7 > $RESULT_DIR/reg_7.txt | ||
21 | +#python modeling.py --reg reg_cov --odecay 0.6 > $RESULT_DIR/reg_6.txt | ||
22 | +#python modeling.py --reg reg_cov --odecay 0.5 > $RESULT_DIR/reg_5.txt | ||
23 | + | ||
24 | +#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_05_reg_07.txt & | ||
25 | +#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_05_reg_08.txt & | ||
26 | +#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_05_reg_09.txt & | ||
27 | +#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_06_reg_07.txt & | ||
28 | +#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_06_reg_08.txt & | ||
29 | +python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_06_reg_09.txt | ||
30 | + |
code/utils.py
0 → 100644
1 | +import numpy as np # linear algebra | ||
2 | +import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv) | ||
3 | +import masknn | ||
4 | +import resnet_mask | ||
5 | + | ||
6 | +import torch | ||
7 | +import torch.nn as nn | ||
8 | +import torchvision | ||
9 | +import torchvision.transforms as transforms | ||
10 | +import numpy as np | ||
11 | + | ||
12 | +import sys | ||
13 | + | ||
14 | + | ||
15 | +def get_weight_threshold(model, rate, prune_imp='L1'): | ||
16 | + importance_all = None | ||
17 | + for name, item in model.named_parameters(): | ||
18 | + #module.named_parameters(): | ||
19 | + if len(item.size())==4 and 'mask' not in name: | ||
20 | + weights = item.data.view(-1).cpu() | ||
21 | + grads = item.grad.data.view(-1).cpu() | ||
22 | + | ||
23 | + if prune_imp == 'L1': | ||
24 | + importance = weights.abs().numpy() | ||
25 | + elif prune_imp == 'L2': | ||
26 | + importance = weights.pow(2).numpy() | ||
27 | + elif prune_imp == 'grad': | ||
28 | + importance = grads.abs().numpy() | ||
29 | + elif prune_imp == 'syn': | ||
30 | + importance = (weights * grads).abs().numpy() | ||
31 | + | ||
32 | + | ||
33 | + if importance_all is None: | ||
34 | + importance_all = importance | ||
35 | + else: | ||
36 | + importance_all = np.append(importance_all, importance) | ||
37 | + | ||
38 | + threshold = np.sort(importance_all)[int(len(importance_all) * rate)] | ||
39 | + return threshold | ||
40 | + | ||
41 | + | ||
42 | +def weight_prune(model, threshold, prune_imp='L1'): | ||
43 | + state = model.state_dict() | ||
44 | + for name, item in model.named_parameters(): | ||
45 | + if 'weight' in name: | ||
46 | + key = name.replace('weight', 'mask') | ||
47 | + if key in state.keys(): | ||
48 | + if prune_imp == 'L1': | ||
49 | + mat = item.data.abs() | ||
50 | + elif prune_imp == 'L2': | ||
51 | + mat = item.data.pow(2) | ||
52 | + elif prune_imp == 'grad': | ||
53 | + mat = item.grad.data.abs() | ||
54 | + elif prune_imp == 'syn': | ||
55 | + mat = (item.data * item.grad.data).abs() | ||
56 | + state[key].data.copy_(torch.gt(mat, threshold).float()) | ||
57 | + | ||
58 | + | ||
59 | +def get_filter_mask(model, rate, prune_imp='L1'): | ||
60 | + importance_all = None | ||
61 | + for name, item in model.named_parameters(): | ||
62 | + #.module.named_parameters(): | ||
63 | + if len(item.size())==4 and 'weight' in name: | ||
64 | + filters = item.data.view(item.size(0), -1).cpu() | ||
65 | + weight_len = filters.size(1) | ||
66 | + if prune_imp =='L1': | ||
67 | + importance = filters.abs().sum(dim=1).numpy() / weight_len | ||
68 | + elif prune_imp == 'L2': | ||
69 | + importance = filters.pow(2).sum(dim=1).numpy() / weight_len | ||
70 | + | ||
71 | + if importance_all is None: | ||
72 | + importance_all = importance | ||
73 | + else: | ||
74 | + importance_all = np.append(importance_all, importance) | ||
75 | + | ||
76 | + | ||
77 | + threshold = np.sort(importance_all)[int(len(importance_all) * rate)] | ||
78 | + #threshold = np.percentile(importance_all, rate) | ||
79 | + filter_mask = np.greater(importance_all, threshold) | ||
80 | + return filter_mask | ||
81 | + | ||
82 | + | ||
83 | +def filter_prune(model, filter_mask): | ||
84 | + idx = 0 | ||
85 | + for name, item in model.named_parameters(): | ||
86 | + #.module.named_parameters(): | ||
87 | + if len(item.size())==4 and 'mask' in name: | ||
88 | + for i in range(item.size(0)): | ||
89 | + item.data[i,:,:,:] = 1 if filter_mask[idx] else 0 | ||
90 | + idx += 1 | ||
91 | + | ||
92 | + | ||
93 | +def reg_ortho(mdl): | ||
94 | + l2_reg = None | ||
95 | + for W in mdl.parameters(): | ||
96 | + if W.ndimension() < 2: | ||
97 | + continue | ||
98 | + else: | ||
99 | + cols = W[0].numel() | ||
100 | + rows = W.shape[0] | ||
101 | + w1 = W.view(-1,cols) | ||
102 | + wt = torch.transpose(w1,0,1) | ||
103 | + m = torch.matmul(wt,w1) | ||
104 | + ident = Variable(torch.eye(cols,cols)) | ||
105 | + ident = ident.cuda() | ||
106 | + | ||
107 | + w_tmp = (m - ident) | ||
108 | + height = w_tmp.size(0) | ||
109 | + u = normalize(w_tmp.new_empty(height).normal_(0,1), dim=0, eps=1e-12) | ||
110 | + v = normalize(torch.matmul(w_tmp.t(), u), dim=0, eps=1e-12) | ||
111 | + u = normalize(torch.matmul(w_tmp, v), dim=0, eps=1e-12) | ||
112 | + sigma = torch.dot(u, torch.matmul(w_tmp, v)) | ||
113 | + | ||
114 | + if l2_reg is None: | ||
115 | + l2_reg = (sigma)**2 | ||
116 | + else: | ||
117 | + l2_reg = l2_reg + (sigma)**2 | ||
118 | + return l2_reg | ||
119 | + | ||
120 | + | ||
121 | +def reg_cov(mdl): | ||
122 | + cov_reg = 0 | ||
123 | + for W in mdl.parameters(): | ||
124 | + if W.ndimension() < 2: | ||
125 | + continue | ||
126 | + else: | ||
127 | + for w in W: | ||
128 | + for w_ in w: | ||
129 | + if w_.dim() > 0 and len(w_) == 2: | ||
130 | + cov_ = np.cov(w_.detach().numpy()) | ||
131 | + cov_upper = np.triu(cov_) | ||
132 | + cov_upper_abs = np.absolute(cov_upper) | ||
133 | + cov_upper_abs_sum = np.sum(cov_upper_abs) | ||
134 | + cov_reg += cov_upper_abs_sum | ||
135 | + | ||
136 | + return cov_reg | ||
137 | + | ||
138 | + | ||
139 | +class AverageMeter(object): | ||
140 | + r"""Computes and stores the average and current value | ||
141 | + """ | ||
142 | + def __init__(self, name, fmt=':f'): | ||
143 | + self.name = name | ||
144 | + self.fmt = fmt | ||
145 | + self.reset() | ||
146 | + | ||
147 | + def reset(self): | ||
148 | + self.val = 0 | ||
149 | + self.avg = 0 | ||
150 | + self.sum = 0 | ||
151 | + self.count = 0 | ||
152 | + | ||
153 | + def update(self, val, n=1): | ||
154 | + self.val = val | ||
155 | + self.sum += val * n | ||
156 | + self.count += n | ||
157 | + self.avg = self.sum / self.count | ||
158 | + | ||
159 | + def __str__(self): | ||
160 | + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' | ||
161 | + return fmtstr.format(**self.__dict__) | ||
162 | + | ||
163 | + | ||
164 | +def accuracy(output, target, topk=(1,)): | ||
165 | + r"""Computes the accuracy over the $k$ top predictions for the specified values of k | ||
166 | + """ | ||
167 | + with torch.no_grad(): | ||
168 | + maxk = max(topk) | ||
169 | + batch_size = target.size(0) | ||
170 | + | ||
171 | + _, pred = output.topk(maxk, 1, True, True) | ||
172 | + pred = pred.t() | ||
173 | + correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
174 | + | ||
175 | + res = [] | ||
176 | + for k in topk: | ||
177 | + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
178 | + res.append(correct_k.mul_(100.0 / batch_size)) | ||
179 | + return res | ||
180 | + | ||
181 | + | ||
182 | +def cal_sparsity(model): | ||
183 | + mask_nonzeros = 0 | ||
184 | + mask_length = 0 | ||
185 | + total_weights = 0 | ||
186 | + | ||
187 | + for name, item in model.named_parameters(): | ||
188 | + #.module.named_parameters(): | ||
189 | + if 'mask' in name: | ||
190 | + flatten = item.data.view(-1) | ||
191 | + np_flatten = flatten.cpu().numpy() | ||
192 | + | ||
193 | + mask_nonzeros += np.count_nonzero(np_flatten) | ||
194 | + mask_length += item.numel() | ||
195 | + | ||
196 | + if 'weight' in name or 'bias' in name: | ||
197 | + total_weights += item.numel() | ||
198 | + | ||
199 | + num_zero = mask_length - mask_nonzeros | ||
200 | + sparsity = (num_zero / total_weights) * 100 | ||
201 | + return total_weights, num_zero, sparsity | ||
202 | + | ||
203 | + | ||
204 | +def train(train_loader, epoch, model, criterion, optimizer, reg=None, prune=None, prune_freq=4, odecay=0, device='cuda'): | ||
205 | + losses = AverageMeter('Loss', ':.4e') | ||
206 | + top1 = AverageMeter('Acc@1', ':6.2f') | ||
207 | + top5 = AverageMeter('Acc@5', ':6.2f') | ||
208 | + | ||
209 | + model.train() | ||
210 | + | ||
211 | + for i, (inputs, targets) in enumerate(train_loader): | ||
212 | + inputs = inputs.to(device) | ||
213 | + targets = targets.to(device) | ||
214 | + | ||
215 | + if prune: | ||
216 | + if (i+1) % prune_freq == 0 and epoch <= 225: | ||
217 | + if prune['type'] == 'structured': | ||
218 | + filter_mask = get_filter_mask(model, prune['rate']) | ||
219 | + filter_prune(model, filter_mask) | ||
220 | + elif prune['type'] == 'unstructured': | ||
221 | + thres = get_weight_threshold(model, prune['target_sparsity']) | ||
222 | + weight_prune(model, thres) | ||
223 | + | ||
224 | + outputs = model(inputs) | ||
225 | + | ||
226 | + if reg: | ||
227 | + oloss = reg(model) | ||
228 | + oloss = odecay * oloss | ||
229 | + loss = criterion(outputs, targets) + oloss | ||
230 | + else: | ||
231 | + loss = criterion(outputs, targets) | ||
232 | + | ||
233 | + acc1, acc5 = accuracy(outputs, targets, topk=(1,5)) | ||
234 | + losses.update(loss.item(), inputs.size(0)) | ||
235 | + top1.update(acc1[0], inputs.size(0)) | ||
236 | + top5.update(acc5[0], inputs.size(0)) | ||
237 | + | ||
238 | + optimizer.zero_grad() | ||
239 | + loss.backward() | ||
240 | + optimizer.step() | ||
241 | + | ||
242 | + print('train {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5)) | ||
243 | + if prune: | ||
244 | + num_total, num_zero, sparsity = cal_sparsity(model) | ||
245 | + print('sparsity {} ====> {:.2f}% || num_zero/num_total: {}/{}'.format(epoch, sparsity, num_zero, num_total)) | ||
246 | + return top1.avg, top5.avg | ||
247 | + | ||
248 | + | ||
249 | +def validate(val_loader, epoch, model, criterion, device='cuda'): | ||
250 | + losses = AverageMeter('Loss', ':.4e') | ||
251 | + top1 = AverageMeter('Acc@1', ':6.2f') | ||
252 | + top5 = AverageMeter('Acc@5', ':6.2f') | ||
253 | + | ||
254 | + model.eval() | ||
255 | + | ||
256 | + with torch.no_grad(): | ||
257 | + for i, (inputs, targets) in enumerate(val_loader): | ||
258 | + inputs = inputs.to(device) | ||
259 | + targets = targets.to(device) | ||
260 | + | ||
261 | + outputs = model(inputs) | ||
262 | + loss = criterion(outputs, targets) | ||
263 | + | ||
264 | + acc1, acc5 = accuracy(outputs, targets, topk=(1,5)) | ||
265 | + losses.update(loss.item(), inputs.size(0)) | ||
266 | + top1.update(acc1[0], inputs.size(0)) | ||
267 | + top5.update(acc5[0], inputs.size(0)) | ||
268 | + | ||
269 | + print('valid {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5)) | ||
270 | + return top1.avg, top5.avg |
-
Please register or login to post a comment