heeseon cheon

add code

This diff could not be displayed because it is too large.
1 +import torch
2 +import torch.nn as nn
3 +import torch.nn.functional as F
4 +from torch.nn.parameter import Parameter
5 +
6 +
7 +class Masker(torch.autograd.Function):
8 + @staticmethod
9 + def forward(ctx, x, mask):
10 + return x * mask
11 +
12 + @staticmethod
13 + def backward(ctx, grad):
14 + return grad, None
15 +
16 +
17 +class MaskConv2d(nn.Conv2d):
18 + def __init__(self, in_channels, out_channels, kernel_size, stride=1,
19 + padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
20 + super(MaskConv2d, self).__init__(in_channels, out_channels, kernel_size, stride,
21 + padding, dilation, groups, bias, padding_mode)
22 + self.mask = Parameter(torch.ones(self.weight.size()), requires_grad=False)
23 +
24 + def forward(self, inputs):
25 + masked_weight = Masker.apply(self.weight, self.mask)
26 + return super(MaskConv2d, self)._conv_forward(inputs, masked_weight)
1 +import time
2 +import random
3 +import pathlib
4 +from os.path import isfile
5 +import copy
6 +import sys
7 +
8 +import numpy as np
9 +import cv2
10 +
11 +import torch
12 +import torch.nn as nn
13 +import torch.nn.functional as F
14 +import torch.optim as optim
15 +import torch.backends.cudnn as cudnn
16 +
17 +from torch.autograd import Variable
18 +import torchvision
19 +import torchvision.transforms as transforms
20 +
21 +from resnet_mask import *
22 +from utils import *
23 +
24 +
25 +def main(args):
26 + device = 'cuda' if torch.cuda.is_available() else 'cpu'
27 + torch.manual_seed(777)
28 + if device =='cuda':
29 + torch.cuda.manual_seed_all(777)
30 +
31 + ## args
32 + layers = int(args.layers)
33 + prune_type = args.prune_type
34 + prune_rate = float(args.prune_rate)
35 + prune_imp = args.prune_imp
36 + reg = args.reg
37 + epochs = int(args.epochs)
38 + batch_size = int(args.batch_size)
39 + lr = float(args.lr)
40 + momentum = float(args.momentum)
41 + wd = float(args.wd)
42 + odecay = float(args.odecay)
43 +
44 + if prune_type:
45 + prune = {'type':prune_type, 'rate':prune_rate}
46 + else:
47 + prune = None
48 +
49 + if reg == 'reg_cov':
50 + reg = reg_cov
51 +
52 + cfgs = {
53 + '18': (BasicBlock, [2, 2, 2, 2]),
54 + '34': (BasicBlock, [3, 4, 6, 3]),
55 + '50': (Bottleneck, [3, 4, 6, 3]),
56 + '101': (Bottleneck, [3, 4, 23, 3]),
57 + '152': (Bottleneck, [3, 8, 36, 3]),
58 + }
59 + cfgs_cifar = {
60 + '20': [3, 3, 3],
61 + '32': [5, 5, 5],
62 + '44': [7, 7, 7],
63 + '56': [9, 9, 9],
64 + '110': [18, 18, 18],
65 + }
66 +
67 + train_data_mean = (0.5, 0.5, 0.5)
68 + train_data_std = (0.5, 0.5, 0.5)
69 +
70 + transform_train = transforms.Compose([
71 + transforms.RandomCrop(32, padding=4),
72 + transforms.ToTensor(),
73 + transforms.Normalize(train_data_mean, train_data_std)
74 + ])
75 +
76 + transform_test = transforms.Compose([
77 + transforms.ToTensor(),
78 + transforms.Normalize(train_data_mean, train_data_std)
79 + ])
80 +
81 + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
82 + trainloader = torch.utils.data.DataLoader(trainset, batch_size=256, shuffle=True, num_workers=4)
83 + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
84 + testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=False, num_workers=4)
85 +
86 + classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
87 +
88 + model = ResNet_CIFAR(BasicBlock, cfgs_cifar['56'], 10).to(device)
89 + image_size = 32
90 +
91 + criterion = nn.CrossEntropyLoss().to(device)
92 + optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=wd) #nesterov=args.nesterov)
93 + lr_sche = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
94 +
95 + ##### main 함수 보고 train 짜기
96 + best_acc1 = 0.0
97 +
98 + print('prune rate', prune_rate, 'regularization odecay', odecay)
99 +
100 + for epoch in range(epochs):
101 +
102 + acc1_train_cor, acc5_train_cor = train(trainloader, epoch=epoch, model=model,
103 + criterion=criterion, optimizer=optimizer,
104 + prune=prune, reg=reg, odecay=odecay)
105 + acc1_valid_cor, acc5_valid_cor = validate(testloader, epoch=epoch, model=model, criterion=criterion)
106 +
107 + acc1_train = round(acc1_train_cor.item(), 4)
108 + acc5_train = round(acc5_train_cor.item(), 4)
109 + acc1_valid = round(acc1_valid_cor.item(), 4)
110 + acc5_valid = round(acc5_valid_cor.item(), 4)
111 +
112 + # remember best Acc@1 and save checkpoint and summary csv file
113 + # summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]
114 +
115 + is_best = acc1_valid > best_acc1
116 + best_acc1 = max(acc1_valid, best_acc1)
117 + if is_best:
118 + summary = [epoch, acc1_train, acc5_train, acc1_valid, acc5_valid]
119 + print(summary)
120 + # save_model(arch_name, args.dataset, state, args.save)
121 + # save_summary(arch_name, args.dataset, args.save.split('.pth')[0], summary)
122 +
123 +if __name__ == '__main__':
124 + import argparse
125 + parser = argparse.ArgumentParser(description="")
126 + parser.add_argument('--layers', default=56)
127 + parser.add_argument('--prune_type', default=None, help='None / structured / unstructured')
128 + parser.add_argument('--prune_rate', default=0.9)
129 + parser.add_argument('--prune_imp', default='L2')
130 + parser.add_argument('--reg', default=None, help='None / reg_cov')
131 + parser.add_argument('--epochs', default=300)
132 + parser.add_argument('--batch_size', default=128)
133 + parser.add_argument('--lr', default=0.2)
134 + parser.add_argument('--momentum', default=0.9)
135 + parser.add_argument('--wd', default=1e-4)
136 + parser.add_argument('--odecay', default=1)
137 + args = parser.parse_args()
138 +
139 + main(args)
1 +import torch
2 +import torch.nn as nn
3 +import masknn as mnn
4 +
5 +
6 +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
7 + """3x3 convolution with padding"""
8 + return mnn.MaskConv2d(in_planes, out_planes, kernel_size=3, stride=stride,
9 + padding=dilation, groups=groups, bias=False, dilation=dilation)
10 +
11 +
12 +def conv1x1(in_planes, out_planes, stride=1):
13 + """1x1 convolution"""
14 + return mnn.MaskConv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
15 +
16 +
17 +class BasicBlock(nn.Module):
18 + expansion = 1
19 + __constants__ = ['downsample']
20 +
21 + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
22 + base_width=64, dilation=1, norm_layer=None):
23 + super(BasicBlock, self).__init__()
24 + if norm_layer is None:
25 + norm_layer = nn.BatchNorm2d
26 + if groups != 1 or base_width != 64:
27 + raise ValueError('BasicBlock only supports groups=1 and base_width=64')
28 + if dilation > 1:
29 + raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
30 + # Both self.conv1 and self.downsample layers downsample the input when stride != 1
31 + self.conv1 = conv3x3(inplanes, planes, stride)
32 + self.bn1 = norm_layer(planes)
33 + self.relu = nn.ReLU(inplace=True)
34 + self.conv2 = conv3x3(planes, planes)
35 + self.bn2 = norm_layer(planes)
36 + self.downsample = downsample
37 + self.stride = stride
38 +
39 + def forward(self, x):
40 + identity = x
41 +
42 + out = self.conv1(x)
43 + out = self.bn1(out)
44 + out = self.relu(out)
45 +
46 + out = self.conv2(out)
47 + out = self.bn2(out)
48 +
49 + if self.downsample is not None:
50 + identity = self.downsample(x)
51 +
52 + out += identity
53 + out = self.relu(out)
54 +
55 + return out
56 +
57 +
58 +class Bottleneck(nn.Module):
59 + expansion = 4
60 + __constants__ = ['downsample']
61 +
62 + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
63 + base_width=64, dilation=1, norm_layer=None):
64 + super(Bottleneck, self).__init__()
65 + if norm_layer is None:
66 + norm_layer = nn.BatchNorm2d
67 + width = int(planes * (base_width / 64.)) * groups
68 + # Both self.conv2 and self.downsample layers downsample the input when stride != 1
69 + self.conv1 = conv1x1(inplanes, width)
70 + self.bn1 = norm_layer(width)
71 + self.conv2 = conv3x3(width, width, stride, groups, dilation)
72 + self.bn2 = norm_layer(width)
73 + self.conv3 = conv1x1(width, planes * self.expansion)
74 + self.bn3 = norm_layer(planes * self.expansion)
75 + self.relu = nn.ReLU(inplace=True)
76 + self.downsample = downsample
77 + self.stride = stride
78 +
79 + def forward(self, x):
80 + identity = x
81 +
82 + out = self.conv1(x)
83 + out = self.bn1(out)
84 + out = self.relu(out)
85 +
86 + out = self.conv2(out)
87 + out = self.bn2(out)
88 + out = self.relu(out)
89 +
90 + out = self.conv3(out)
91 + out = self.bn3(out)
92 +
93 + if self.downsample is not None:
94 + identity = self.downsample(x)
95 +
96 + out += identity
97 + out = self.relu(out)
98 +
99 + return out
100 +
101 +
102 +class ResNet(nn.Module):
103 + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
104 + groups=1, width_per_group=64, replace_stride_with_dilation=None,
105 + norm_layer=None):
106 + super(ResNet, self).__init__()
107 + self.block_name = str(block.__name__)
108 + if norm_layer is None:
109 + norm_layer = nn.BatchNorm2d
110 + self._norm_layer = norm_layer
111 +
112 + self.inplanes = 64
113 + self.dilation = 1
114 + if replace_stride_with_dilation is None:
115 + # each element in the tuple indicates if we should replace
116 + # the 2x2 stride with a dilated convolution instead
117 + replace_stride_with_dilation = [False, False, False]
118 + if len(replace_stride_with_dilation) != 3:
119 + raise ValueError("replace_stride_with_dilation should be None "
120 + "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
121 + self.groups = groups
122 + self.base_width = width_per_group
123 + self.conv1 = mnn.MaskConv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
124 + bias=False)
125 + self.bn1 = norm_layer(self.inplanes)
126 + self.relu = nn.ReLU(inplace=True)
127 + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
128 + self.layer1 = self._make_layer(block, 64, layers[0])
129 + self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
130 + dilate=replace_stride_with_dilation[0])
131 + self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
132 + dilate=replace_stride_with_dilation[1])
133 + self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
134 + dilate=replace_stride_with_dilation[2])
135 + self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
136 + self.fc = nn.Linear(512 * block.expansion, num_classes)
137 +
138 + for m in self.modules():
139 + if isinstance(m, mnn.MaskConv2d):
140 + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
141 + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
142 + nn.init.constant_(m.weight, 1)
143 + nn.init.constant_(m.bias, 0)
144 +
145 + # Zero-initialize the last BN in each residual branch,
146 + # so that the residual branch starts with zeros, and each residual block behaves like an identity.
147 + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
148 + if zero_init_residual:
149 + for m in self.modules():
150 + if isinstance(m, Bottleneck):
151 + nn.init.constant_(m.bn3.weight, 0)
152 + elif isinstance(m, BasicBlock):
153 + nn.init.constant_(m.bn2.weight, 0)
154 +
155 + def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
156 + norm_layer = self._norm_layer
157 + downsample = None
158 + previous_dilation = self.dilation
159 + if dilate:
160 + self.dilation *= stride
161 + stride = 1
162 + if stride != 1 or self.inplanes != planes * block.expansion:
163 + downsample = nn.Sequential(
164 + conv1x1(self.inplanes, planes * block.expansion, stride),
165 + norm_layer(planes * block.expansion),
166 + )
167 +
168 + layers = []
169 + layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
170 + self.base_width, previous_dilation, norm_layer))
171 + self.inplanes = planes * block.expansion
172 + for _ in range(1, blocks):
173 + layers.append(block(self.inplanes, planes, groups=self.groups,
174 + base_width=self.base_width, dilation=self.dilation,
175 + norm_layer=norm_layer))
176 +
177 + return nn.Sequential(*layers)
178 +
179 + def _forward_impl(self, x):
180 + # See note [TorchScript super()]
181 + x = self.conv1(x)
182 + x = self.bn1(x)
183 + x = self.relu(x)
184 + x = self.maxpool(x)
185 +
186 + x = self.layer1(x)
187 + x = self.layer2(x)
188 + x = self.layer3(x)
189 + x = self.layer4(x)
190 +
191 + x = self.avgpool(x)
192 + x = torch.flatten(x, 1)
193 + x = self.fc(x)
194 + return x
195 +
196 + def forward(self, x):
197 + return self._forward_impl(x)
198 +
199 +class ResNet_CIFAR(nn.Module):
200 + def __init__(self, block, layers, num_classes=10, zero_init_residual=False,
201 + groups=1, width_per_group=64, replace_stride_with_dilation=None,
202 + norm_layer=None):
203 + super(ResNet_CIFAR, self).__init__()
204 + self.block_name = str(block.__name__)
205 + if norm_layer is None:
206 + norm_layer = nn.BatchNorm2d
207 + self._norm_layer = norm_layer
208 +
209 + self.inplanes = 16
210 + self.dilation = 1
211 + if replace_stride_with_dilation is None:
212 + # each element in the tuple indicates if we should replace
213 + # the 2x2 stride with a dilated convolution instead
214 + replace_stride_with_dilation = [False, False, False]
215 + if len(replace_stride_with_dilation) != 3:
216 + raise ValueError("replace_stride_with_dilation should be None "
217 + "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
218 + self.groups = groups
219 + self.base_width = width_per_group
220 + self.conv1 = mnn.MaskConv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
221 + bias=False)
222 + self.bn1 = norm_layer(self.inplanes)
223 + self.relu = nn.ReLU(inplace=True)
224 + self.layer1 = self._make_layer(block, 16, layers[0])
225 + self.layer2 = self._make_layer(block, 32, layers[1], stride=2,
226 + dilate=replace_stride_with_dilation[0])
227 + self.layer3 = self._make_layer(block, 64, layers[2], stride=2,
228 + dilate=replace_stride_with_dilation[1])
229 + self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
230 + self.fc = nn.Linear(64 * block.expansion, num_classes)
231 +
232 + for m in self.modules():
233 + if isinstance(m, mnn.MaskConv2d):
234 + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
235 + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
236 + nn.init.constant_(m.weight, 1)
237 + nn.init.constant_(m.bias, 0)
238 +
239 + # Zero-initialize the last BN in each residual branch,
240 + # so that the residual branch starts with zeros, and each residual block behaves like an identity.
241 + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
242 + if zero_init_residual:
243 + for m in self.modules():
244 + if isinstance(m, Bottleneck):
245 + nn.init.constant_(m.bn3.weight, 0)
246 + elif isinstance(m, BasicBlock):
247 + nn.init.constant_(m.bn2.weight, 0)
248 +
249 + def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
250 + norm_layer = self._norm_layer
251 + downsample = None
252 + previous_dilation = self.dilation
253 + if dilate:
254 + self.dilation *= stride
255 + stride = 1
256 + if stride != 1 or self.inplanes != planes * block.expansion:
257 + downsample = nn.Sequential(
258 + conv1x1(self.inplanes, planes * block.expansion, stride),
259 + norm_layer(planes * block.expansion),
260 + )
261 +
262 + layers = []
263 + layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
264 + self.base_width, previous_dilation, norm_layer))
265 + self.inplanes = planes * block.expansion
266 + for _ in range(1, blocks):
267 + layers.append(block(self.inplanes, planes, groups=self.groups,
268 + base_width=self.base_width, dilation=self.dilation,
269 + norm_layer=norm_layer))
270 +
271 + return nn.Sequential(*layers)
272 +
273 + def _forward_impl(self, x):
274 + # See note [TorchScript super()]
275 + x = self.conv1(x)
276 + x = self.bn1(x)
277 + x = self.relu(x)
278 +
279 + x = self.layer1(x)
280 + x = self.layer2(x)
281 + x = self.layer3(x)
282 +
283 + x = self.avgpool(x)
284 + x = torch.flatten(x, 1)
285 + x = self.fc(x)
286 + return x
287 +
288 + def forward(self, x):
289 + return self._forward_impl(x)
290 +
291 +
292 +# Model configurations
293 +cfgs = {
294 + '18': (BasicBlock, [2, 2, 2, 2]),
295 + '34': (BasicBlock, [3, 4, 6, 3]),
296 + '50': (Bottleneck, [3, 4, 6, 3]),
297 + '101': (Bottleneck, [3, 4, 23, 3]),
298 + '152': (Bottleneck, [3, 8, 36, 3]),
299 +}
300 +cfgs_cifar = {
301 + '20': [3, 3, 3],
302 + '32': [5, 5, 5],
303 + '44': [7, 7, 7],
304 + '56': [9, 9, 9],
305 + '110': [18, 18, 18],
306 +}
307 +
308 +
309 +def resnet(data='cifar10', **kwargs):
310 + r"""ResNet models from "[Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)"
311 + Args:
312 + data (str): the name of datasets
313 + """
314 + num_layers = str(kwargs.get('num_layers'))
315 +
316 + # set pruner
317 + global mnn
318 + mnn = kwargs.get('mnn')
319 + assert mnn is not None, "Please specify proper pruning method"
320 +
321 + if data in ['cifar10', 'cifar100']:
322 + if num_layers in cfgs_cifar.keys():
323 + model = ResNet_CIFAR(BasicBlock, cfgs_cifar[num_layers], int(data[5:]))
324 + else:
325 + model = None
326 + image_size = 32
327 + elif data == 'imagenet':
328 + if num_layers in cfgs.keys():
329 + block, layers = cfgs[num_layers]
330 + model = ResNet(block, layers, 1000)
331 + else:
332 + model = None
333 + image_size = 224
334 + else:
335 + model = None
336 + image_size = None
337 +
338 + return model, image_size
...\ No newline at end of file ...\ No newline at end of file
1 +#!/bin/bash
2 +RESULT_DIR=result_201203
3 +
4 +if [ ! -d $RESULT_DIR ]; then
5 + mkdir $RESULT_DIR
6 +fi
7 +
8 +#python modeling_default.py > $RESULT_DIR/default.txt #&
9 +#python modeling_pruning.py > $RESULT_DIR/pruning_prune90.txt &
10 +#python modeling_decorrelation.py > $RESULT_DIR/decorrelation_lambda1.txt &
11 +#python modeling_pruning+decorrelation.py > $RESULT_DIR/pruning+decorrelation_lambda1+prune90.txt
12 +
13 +#python modeling.py --prune_type structured --prune_rate 0.5 > $RESULT_DIR/prune_05.txt
14 +#python modeling.py --prune_type structured --prune_rate 0.6 > $RESULT_DIR/prune_06.txt
15 +#python modeling.py --prune_type structured --prune_rate 0.8 > $RESULT_DIR/prune_08.txt &
16 +#python modeling.py --prune_type structured --prune_rate 0.7 > $RESULT_DIR/prune_07.txt
17 +
18 +#python modeling.py --reg reg_cov --odecay 0.9 > $RESULT_DIR/reg_9.txt
19 +#python modeling.py --reg reg_cov --odecay 0.8 > $RESULT_DIR/reg_8.txt
20 +#python modeling.py --reg reg_cov --odecay 0.7 > $RESULT_DIR/reg_7.txt
21 +#python modeling.py --reg reg_cov --odecay 0.6 > $RESULT_DIR/reg_6.txt
22 +#python modeling.py --reg reg_cov --odecay 0.5 > $RESULT_DIR/reg_5.txt
23 +
24 +#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_05_reg_07.txt &
25 +#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_05_reg_08.txt &
26 +#python modeling.py --prune_type structured --prune_rate 0.5 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_05_reg_09.txt &
27 +#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.7 > $RESULT_DIR/prune_06_reg_07.txt &
28 +#python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.8 > $RESULT_DIR/prune_06_reg_08.txt &
29 +python modeling.py --prune_type structured --prune_rate 0.6 --reg reg_cov --odecay 0.9 > $RESULT_DIR/prune_06_reg_09.txt
30 +
1 +import numpy as np # linear algebra
2 +import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
3 +import masknn
4 +import resnet_mask
5 +
6 +import torch
7 +import torch.nn as nn
8 +import torchvision
9 +import torchvision.transforms as transforms
10 +import numpy as np
11 +
12 +import sys
13 +
14 +
15 +def get_weight_threshold(model, rate, prune_imp='L1'):
16 + importance_all = None
17 + for name, item in model.named_parameters():
18 + #module.named_parameters():
19 + if len(item.size())==4 and 'mask' not in name:
20 + weights = item.data.view(-1).cpu()
21 + grads = item.grad.data.view(-1).cpu()
22 +
23 + if prune_imp == 'L1':
24 + importance = weights.abs().numpy()
25 + elif prune_imp == 'L2':
26 + importance = weights.pow(2).numpy()
27 + elif prune_imp == 'grad':
28 + importance = grads.abs().numpy()
29 + elif prune_imp == 'syn':
30 + importance = (weights * grads).abs().numpy()
31 +
32 +
33 + if importance_all is None:
34 + importance_all = importance
35 + else:
36 + importance_all = np.append(importance_all, importance)
37 +
38 + threshold = np.sort(importance_all)[int(len(importance_all) * rate)]
39 + return threshold
40 +
41 +
42 +def weight_prune(model, threshold, prune_imp='L1'):
43 + state = model.state_dict()
44 + for name, item in model.named_parameters():
45 + if 'weight' in name:
46 + key = name.replace('weight', 'mask')
47 + if key in state.keys():
48 + if prune_imp == 'L1':
49 + mat = item.data.abs()
50 + elif prune_imp == 'L2':
51 + mat = item.data.pow(2)
52 + elif prune_imp == 'grad':
53 + mat = item.grad.data.abs()
54 + elif prune_imp == 'syn':
55 + mat = (item.data * item.grad.data).abs()
56 + state[key].data.copy_(torch.gt(mat, threshold).float())
57 +
58 +
59 +def get_filter_mask(model, rate, prune_imp='L1'):
60 + importance_all = None
61 + for name, item in model.named_parameters():
62 + #.module.named_parameters():
63 + if len(item.size())==4 and 'weight' in name:
64 + filters = item.data.view(item.size(0), -1).cpu()
65 + weight_len = filters.size(1)
66 + if prune_imp =='L1':
67 + importance = filters.abs().sum(dim=1).numpy() / weight_len
68 + elif prune_imp == 'L2':
69 + importance = filters.pow(2).sum(dim=1).numpy() / weight_len
70 +
71 + if importance_all is None:
72 + importance_all = importance
73 + else:
74 + importance_all = np.append(importance_all, importance)
75 +
76 +
77 + threshold = np.sort(importance_all)[int(len(importance_all) * rate)]
78 + #threshold = np.percentile(importance_all, rate)
79 + filter_mask = np.greater(importance_all, threshold)
80 + return filter_mask
81 +
82 +
83 +def filter_prune(model, filter_mask):
84 + idx = 0
85 + for name, item in model.named_parameters():
86 + #.module.named_parameters():
87 + if len(item.size())==4 and 'mask' in name:
88 + for i in range(item.size(0)):
89 + item.data[i,:,:,:] = 1 if filter_mask[idx] else 0
90 + idx += 1
91 +
92 +
93 +def reg_ortho(mdl):
94 + l2_reg = None
95 + for W in mdl.parameters():
96 + if W.ndimension() < 2:
97 + continue
98 + else:
99 + cols = W[0].numel()
100 + rows = W.shape[0]
101 + w1 = W.view(-1,cols)
102 + wt = torch.transpose(w1,0,1)
103 + m = torch.matmul(wt,w1)
104 + ident = Variable(torch.eye(cols,cols))
105 + ident = ident.cuda()
106 +
107 + w_tmp = (m - ident)
108 + height = w_tmp.size(0)
109 + u = normalize(w_tmp.new_empty(height).normal_(0,1), dim=0, eps=1e-12)
110 + v = normalize(torch.matmul(w_tmp.t(), u), dim=0, eps=1e-12)
111 + u = normalize(torch.matmul(w_tmp, v), dim=0, eps=1e-12)
112 + sigma = torch.dot(u, torch.matmul(w_tmp, v))
113 +
114 + if l2_reg is None:
115 + l2_reg = (sigma)**2
116 + else:
117 + l2_reg = l2_reg + (sigma)**2
118 + return l2_reg
119 +
120 +
121 +def reg_cov(mdl):
122 + cov_reg = 0
123 + for W in mdl.parameters():
124 + if W.ndimension() < 2:
125 + continue
126 + else:
127 + for w in W:
128 + for w_ in w:
129 + if w_.dim() > 0 and len(w_) == 2:
130 + cov_ = np.cov(w_.detach().numpy())
131 + cov_upper = np.triu(cov_)
132 + cov_upper_abs = np.absolute(cov_upper)
133 + cov_upper_abs_sum = np.sum(cov_upper_abs)
134 + cov_reg += cov_upper_abs_sum
135 +
136 + return cov_reg
137 +
138 +
139 +class AverageMeter(object):
140 + r"""Computes and stores the average and current value
141 + """
142 + def __init__(self, name, fmt=':f'):
143 + self.name = name
144 + self.fmt = fmt
145 + self.reset()
146 +
147 + def reset(self):
148 + self.val = 0
149 + self.avg = 0
150 + self.sum = 0
151 + self.count = 0
152 +
153 + def update(self, val, n=1):
154 + self.val = val
155 + self.sum += val * n
156 + self.count += n
157 + self.avg = self.sum / self.count
158 +
159 + def __str__(self):
160 + fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
161 + return fmtstr.format(**self.__dict__)
162 +
163 +
164 +def accuracy(output, target, topk=(1,)):
165 + r"""Computes the accuracy over the $k$ top predictions for the specified values of k
166 + """
167 + with torch.no_grad():
168 + maxk = max(topk)
169 + batch_size = target.size(0)
170 +
171 + _, pred = output.topk(maxk, 1, True, True)
172 + pred = pred.t()
173 + correct = pred.eq(target.view(1, -1).expand_as(pred))
174 +
175 + res = []
176 + for k in topk:
177 + correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
178 + res.append(correct_k.mul_(100.0 / batch_size))
179 + return res
180 +
181 +
182 +def cal_sparsity(model):
183 + mask_nonzeros = 0
184 + mask_length = 0
185 + total_weights = 0
186 +
187 + for name, item in model.named_parameters():
188 + #.module.named_parameters():
189 + if 'mask' in name:
190 + flatten = item.data.view(-1)
191 + np_flatten = flatten.cpu().numpy()
192 +
193 + mask_nonzeros += np.count_nonzero(np_flatten)
194 + mask_length += item.numel()
195 +
196 + if 'weight' in name or 'bias' in name:
197 + total_weights += item.numel()
198 +
199 + num_zero = mask_length - mask_nonzeros
200 + sparsity = (num_zero / total_weights) * 100
201 + return total_weights, num_zero, sparsity
202 +
203 +
204 +def train(train_loader, epoch, model, criterion, optimizer, reg=None, prune=None, prune_freq=4, odecay=0, device='cuda'):
205 + losses = AverageMeter('Loss', ':.4e')
206 + top1 = AverageMeter('Acc@1', ':6.2f')
207 + top5 = AverageMeter('Acc@5', ':6.2f')
208 +
209 + model.train()
210 +
211 + for i, (inputs, targets) in enumerate(train_loader):
212 + inputs = inputs.to(device)
213 + targets = targets.to(device)
214 +
215 + if prune:
216 + if (i+1) % prune_freq == 0 and epoch <= 225:
217 + if prune['type'] == 'structured':
218 + filter_mask = get_filter_mask(model, prune['rate'])
219 + filter_prune(model, filter_mask)
220 + elif prune['type'] == 'unstructured':
221 + thres = get_weight_threshold(model, prune['target_sparsity'])
222 + weight_prune(model, thres)
223 +
224 + outputs = model(inputs)
225 +
226 + if reg:
227 + oloss = reg(model)
228 + oloss = odecay * oloss
229 + loss = criterion(outputs, targets) + oloss
230 + else:
231 + loss = criterion(outputs, targets)
232 +
233 + acc1, acc5 = accuracy(outputs, targets, topk=(1,5))
234 + losses.update(loss.item(), inputs.size(0))
235 + top1.update(acc1[0], inputs.size(0))
236 + top5.update(acc5[0], inputs.size(0))
237 +
238 + optimizer.zero_grad()
239 + loss.backward()
240 + optimizer.step()
241 +
242 + print('train {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5))
243 + if prune:
244 + num_total, num_zero, sparsity = cal_sparsity(model)
245 + print('sparsity {} ====> {:.2f}% || num_zero/num_total: {}/{}'.format(epoch, sparsity, num_zero, num_total))
246 + return top1.avg, top5.avg
247 +
248 +
249 +def validate(val_loader, epoch, model, criterion, device='cuda'):
250 + losses = AverageMeter('Loss', ':.4e')
251 + top1 = AverageMeter('Acc@1', ':6.2f')
252 + top5 = AverageMeter('Acc@5', ':6.2f')
253 +
254 + model.eval()
255 +
256 + with torch.no_grad():
257 + for i, (inputs, targets) in enumerate(val_loader):
258 + inputs = inputs.to(device)
259 + targets = targets.to(device)
260 +
261 + outputs = model(inputs)
262 + loss = criterion(outputs, targets)
263 +
264 + acc1, acc5 = accuracy(outputs, targets, topk=(1,5))
265 + losses.update(loss.item(), inputs.size(0))
266 + top1.update(acc1[0], inputs.size(0))
267 + top5.update(acc5[0], inputs.size(0))
268 +
269 + print('valid {i} ====> Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(i=epoch, top1=top1, top5=top5))
270 + return top1.avg, top5.avg