이동주

calculate each data's variance, kl divergence

# run train.py --dataset cifar10 --model resnet18 --data_augmentation --cutout --length 16
# run train.py --dataset cifar100 --model resnet18 --data_augmentation --cutout --length 8
# run train.py --dataset svhn --model wideresnet --learning_rate 0.01 --epochs 160 --cutout --length 20
import pdb
import argparse
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR
from torchvision.utils import make_grid
from torchvision import datasets, transforms
from torch.utils.data.dataloader import RandomSampler
from util.misc import CSVLogger
from util.cutout import Cutout
from model.resnet import ResNet18
from model.wide_resnet import WideResNet
model_options = ['resnet18', 'wideresnet']
dataset_options = ['cifar10', 'cifar100', 'svhn']
parser = argparse.ArgumentParser(description='CNN')
parser.add_argument('--dataset', '-d', default='cifar10',
choices=dataset_options)
parser.add_argument('--model', '-a', default='resnet18',
choices=model_options)
parser.add_argument('--batch_size', type=int, default=128,
help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=200,
help='number of epochs to train (default: 20)')
parser.add_argument('--learning_rate', type=float, default=0.1,
help='learning rate')
parser.add_argument('--data_augmentation', action='store_true', default=False,
help='augment data by flipping and cropping')
parser.add_argument('--cutout', action='store_true', default=False,
help='apply cutout')
parser.add_argument('--n_holes', type=int, default=1,
help='number of holes to cut out from image')
parser.add_argument('--length', type=int, default=16,
help='length of the holes')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='enables CUDA training')
parser.add_argument('--seed', type=int, default=0,
help='random seed (default: 1)')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
cudnn.benchmark = True # Should make training should go faster for large models
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
test_id = args.dataset + '_' + args.model
print(args)
# Image Preprocessing
if args.dataset == 'svhn':
normalize = transforms.Normalize(mean=[x / 255.0 for x in[109.9, 109.7, 113.8]],
std=[x / 255.0 for x in [50.1, 50.6, 50.8]])
else:
normalize = transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
std=[x / 255.0 for x in [63.0, 62.1, 66.7]])
train_transform = transforms.Compose([])
if args.data_augmentation:
train_transform.transforms.append(transforms.RandomCrop(32, padding=4))
train_transform.transforms.append(transforms.RandomHorizontalFlip())
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(normalize)
if args.cutout:
train_transform.transforms.append(Cutout(n_holes=args.n_holes, length=args.length))
test_transform = transforms.Compose([
transforms.ToTensor(),
normalize])
if args.dataset == 'cifar10':
num_classes = 10
train_dataset = datasets.CIFAR10(root='data/',
train=True,
transform=train_transform,
download=True)
test_dataset = datasets.CIFAR10(root='data/',
train=False,
transform=test_transform,
download=True)
elif args.dataset == 'cifar100':
num_classes = 100
train_dataset = datasets.CIFAR100(root='data/',
train=True,
transform=train_transform,
download=True)
test_dataset = datasets.CIFAR100(root='data/',
train=False,
transform=test_transform,
download=True)
elif args.dataset == 'svhn':
num_classes = 10
train_dataset = datasets.SVHN(root='data/',
split='train',
transform=train_transform,
download=True)
extra_dataset = datasets.SVHN(root='data/',
split='extra',
transform=train_transform,
download=True)
# Combine both training splits (https://arxiv.org/pdf/1605.07146.pdf)
data = np.concatenate([train_dataset.data, extra_dataset.data], axis=0)
labels = np.concatenate([train_dataset.labels, extra_dataset.labels], axis=0)
train_dataset.data = data
train_dataset.labels = labels
test_dataset = datasets.SVHN(root='data/',
split='test',
transform=test_transform,
download=True)
# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=args.batch_size,
shuffle=False,
# sampler=RandomSampler(train_dataset, True, 40000),
pin_memory=True,
num_workers=0)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=args.batch_size,
shuffle=False,
pin_memory=True,
num_workers=0)
if args.model == 'resnet18':
cnn = ResNet18(num_classes=num_classes)
elif args.model == 'wideresnet':
if args.dataset == 'svhn':
cnn = WideResNet(depth=16, num_classes=num_classes, widen_factor=8,
dropRate=0.4)
else:
cnn = WideResNet(depth=28, num_classes=num_classes, widen_factor=10,
dropRate=0.3)
checkpoint = torch.load('/content/drive/MyDrive/capstone/Cutout/checkpoints/baseline_cifar10_resnet18.pt', map_location = torch.device('cuda:0'))
cnn = cnn.cuda()
cnn.load_state_dict(checkpoint)
criterion = nn.CrossEntropyLoss().cuda()
cnn_optimizer = torch.optim.SGD(cnn.parameters(), lr=args.learning_rate,
momentum=0.9, nesterov=True, weight_decay=5e-4)
if args.dataset == 'svhn':
scheduler = MultiStepLR(cnn_optimizer, milestones=[80, 120], gamma=0.1)
else:
scheduler = MultiStepLR(cnn_optimizer, milestones=[60, 120, 160], gamma=0.2)
filename = 'logs/' + test_id + '.csv'
csv_logger = CSVLogger(args=args, fieldnames=['epoch', 'train_acc', 'test_acc'], filename=filename)
def test(loader):
cnn.eval() # Change model to 'eval' mode (BN uses moving mean/var).
correct = 0.
total = 0.
for images, labels in loader:
images = images.cuda()
labels = labels.cuda()
with torch.no_grad():
pred = cnn(images)
pred = torch.max(pred.data, 1)[1]
total += labels.size(0)
correct += (pred == labels).sum().item()
val_acc = correct / total
cnn.train()
return val_acc
kl_sum = 0
y_bar = torch.Tensor([0] * 10).detach().cuda()
# y_bar 구하는 epoch
for epoch in range(args.epochs):
cnn.eval()
xentropy_loss_avg = 0.
correct = 0.
total = 0.
norm_const = 0
kldiv = 0
# pred_sum = torch.Tensor([0] * 10).detach().cuda()
progress_bar = tqdm(train_loader)
for i, (images, labels) in enumerate(progress_bar):
progress_bar.set_description('Epoch ' + str(epoch))
images = images.cuda()
labels = labels.cuda()
cnn.zero_grad()
pred = cnn(images)
xentropy_loss = criterion(pred, labels)
# xentropy_loss.backward()
# cnn_optimizer.step()
xentropy_loss_avg += xentropy_loss.item()
pred_softmax = nn.functional.softmax(pred).cuda()
# Calculate running average of accuracy
pred = torch.max(pred.data, 1)[1]
total += labels.size(0)
correct += (pred == labels.data).sum().item()
accuracy = correct / total
for a in range(pred_softmax.data.size()[0]):
for b in range(y_bar.size()[0]):
y_bar[b] += torch.log(pred_softmax.data[a][b])
# expectation(log y_hat)
# y_bar = [x / pred.data.size()[0] for x in y_bar]
# print(pred.data.size()[0], y_bar.size()[0]) # 128, 10
# print(pred)
# y_hat : 모델별 예측값 --> pred_softmax
# y_bar : 예측값들 평균값 -- > pred / total : pred_sum
# labes.data : ground_truth
# y_bar = pred_sum / (i+1)
# kl = torch.nn.functional.kl_div(pred, y_bar)
# kl_sum += kl
# for문 추가안하면 epoch별 iter마다 xentropy_loss_avg값의 1/iter이 xentropy값으로 출력
# for문 추가하면 epoch 별 iter 마다 xentropy_loss_avg 값은 동일하나 xentropy값 출력이 x_l_avg 값의 1/10으로 출력
# for문 상관 없이 pred, labels 값은 동일하게 확인됨.
# for a in range(list(pred_sum.size())[0]):
# for b in range(list(pred.size())[0]):
# if pred[b] == a:
# pred_sum[a] += 1
# variance calculate : E[KL_div(y_bar, y_hat)] -> expectation of KLDivLoss(pred_sum, pred)
# 한 epoch마다 계산해서 출력해야 할듯
# nn.functional.kl_div(pred_sum, pred)
# print('\n',i, ' ', xentropy_loss_avg)
progress_bar.set_postfix(
# y_hat = '%.5f' % pred,
# y_bar = '%.5f' % y_bar,
# groun_truth = '%.5f' % labels.data,
# kl = '%.3f' % kl.item(),
# kl_sum = '%.3f' % (kl_sum.item()),
# kl_div = '%.3f' % (kl_sum.item() / (i + 1)), # kl_div 호출
xentropy='%.3f' % (xentropy_loss_avg / (i + 1)),
acc='%.3f' % accuracy)
# pred_sum = [x / 40000 for x in pred_sum]
y_bar = torch.Tensor([x / 50000 for x in y_bar]).cuda()
y_bar = torch.exp(y_bar)
# print(y_bar)
for index in range(y_bar.size()[0]):
norm_const += y_bar[index]
print(y_bar)
print(norm_const)
# print(norm_const)
for index in range(y_bar.size()[0]):
y_bar[index] = y_bar[index] / norm_const
print(y_bar)
# print(y_bar)
# print(pred_softmax)
# print(y_bar)
# kldiv = torch.nn.functional.kl_div(y_bar, pred_softmax, reduction='batchmean')
# kl_sum += kldiv
# print(kldiv, kl_sum)
y_bar_copy = y_bar.clone().detach()
test_acc = test(test_loader)
# print(pred, labels.data)
tqdm.write('test_acc: %.3f' % (test_acc))
scheduler.step(epoch) # Use this line for PyTorch <1.4
# scheduler.step() # Use this line for PyTorch >=1.4
row = {'epoch': str(epoch), 'train_acc': str(accuracy), 'test_acc': str(test_acc)
}
csv_logger.writerow(row)
del pred
torch.cuda.empty_cache()
# kl_div 구하는 epoch
for epoch in range(args.epochs):
cnn.eval()
kldiv = 0
for i, (images, labels) in enumerate(progress_bar):
progress_bar.set_description('Epoch ' + str(epoch) + ': Calculate kl_div')
images = images.cuda()
labels = labels.cuda()
cnn.zero_grad()
pred = cnn(images)
pred_softmax = nn.functional.softmax(pred).cuda()
# 입력 두 개의 shape이 다르면 batchsize로 평균을 내서 반환.
kldiv = torch.nn.functional.kl_div(y_bar_copy, pred_softmax, reduction='sum')
kl_sum += kldiv.detach()
# print(y_bar_copy.size(), pred_softmax.size())
# print(kl_sum)
print("Average KL_div : ", abs(kl_sum / 50000))
# y_bar = torch.Tensor([x / 40000 for x in y_bar]).cuda()
# y_bar = torch.exp(y_bar)
# # print(y_bar)
# for index in range(y_bar.size()[0]):
# norm_const += y_bar[index]
# # print(norm_const)
# for index in range(y_bar.size()[0]):
# y_bar[index] = y_bar[index] / norm_const
# # print(y_bar)
# # print(pred_softmax)
# # print(y_bar)
# kldiv = torch.nn.functional.kl_div(y_bar, pred_softmax, reduction='batchmean')
# kl_sum += kldiv
# print(kldiv, kl_sum)
torch.save(cnn.state_dict(), 'checkpoints/' + test_id + '.pt')
csv_logger.close()