# get_confidence.py
import torch.multiprocessing as mp
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import argparse
import numpy as np
from get_mean_std import get_params
from model import mobilenetv3
import parmap


# Side length (pixels) images are resized to before inference.
resize_size = 64

# Number of target classes.
class_num = 7

# Random seeds used to train the checkpoints (order matches `checkpoints`).
seeds = [39396, 2798, 3843, 62034, 8817, 65014, 45385]

# Number of GPUs available on this machine.
gpu = 4

# Saved checkpoints, one per seed above.
checkpoints = [
    "output/ErrorType/39396_model=MobilenetV3-ep=3000-block=4/checkpoint.pth.tar",
    "output/ErrorType/2798_model=MobilenetV3-ep=3000-block=4/checkpoint.pth.tar",
    "output/ErrorType/3843_model=MobilenetV3-ep=3000-block=4/checkpoint.pth.tar",
    "output/ErrorType/62034_model=MobilenetV3-ep=3000-block=4/checkpoint.pth.tar",
    "output/ErrorType/8817_model=MobilenetV3-ep=3000-block=4/checkpoint.pth.tar",
    "output/ErrorType/65014_model=MobilenetV3-ep=3000-block=4/checkpoint.pth.tar",
    "output/ErrorType/45385_model=MobilenetV3-ep=3000-block=4/checkpoint.pth.tar"
]

class AverageMeter(object):
    """Tracks the latest value, running sum, count, and mean of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.val = 0
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Record a new observation `val`, weighted as `n` samples."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k.

    Args:
        output: (batch, num_classes) tensor of class scores.
        target: (batch,) tensor of ground-truth class indices.
        topk: iterable of k values to compute precision for.

    Returns:
        List of scalar tensors, precision@k in percent, one per k in `topk`.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Top-maxk predicted class indices per sample, transposed to (maxk, batch).
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        # Bug fix: correct[:k] is a slice of a transposed (non-contiguous)
        # tensor, so .view(-1) raises a RuntimeError on modern PyTorch;
        # .reshape(-1) handles both contiguous and non-contiguous inputs.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def get_models():
    """Load one model per checkpoint, distributing them round-robin over GPUs.

    Returns:
        List of DataParallel-wrapped MobileNetV3 models (one per entry in the
        module-level `checkpoints` list), each with shared memory enabled so
        spawned worker processes can use them.
    """
    models = []
    for idx, checkpoint in enumerate(checkpoints):
        gpu_idx = idx % gpu

        # Bug fix: without map_location, torch.load restores tensors onto the
        # device they were saved from (often cuda:0), piling every checkpoint
        # onto a single GPU instead of the round-robin target.
        weights = torch.load(checkpoint, map_location=f'cuda:{gpu_idx}')
        model = mobilenetv3(n_class=class_num)

        torch.cuda.set_device(gpu_idx)
        with torch.cuda.device(gpu_idx):
            model = model.cuda()

        # The state dict carries DataParallel's 'module.' prefix (it was saved
        # from a wrapped model), so wrap before calling load_state_dict.
        model = torch.nn.DataParallel(model, device_ids=[gpu_idx], output_device=gpu_idx)
        model.load_state_dict(weights['state_dict'])

        # Move parameters to shared memory for the spawned worker processes.
        model.share_memory()
        models.append(model)
    return models

def get_loader(path, resize_size):
    """Build an unshuffled DataLoader over an ImageFolder of grayscale images.

    Args:
        path: root directory laid out for torchvision's ImageFolder.
        resize_size: square side length (pixels) images are resized to.

    Returns:
        DataLoader yielding (image, label) batches of size 256 in folder order.
    """
    # Dataset-specific normalization statistics (single channel after Grayscale).
    mean, std = get_params(path, resize_size)
    normalize = transforms.Normalize(mean=[mean[0].item()],
                                     std=[std[0].item()])

    transform = transforms.Compose([
        transforms.Resize((resize_size, resize_size)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        normalize
    ])
    # Bug fix: use the `path` parameter rather than the global `args.path`,
    # which only exists when this module is run as a script.
    dataset = datasets.ImageFolder(path, transform)
    kwargs = {'num_workers': 4, 'pin_memory': True}

    loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=False, **kwargs)

    return loader

def get_data(processnum, model, loader, return_dict):
    """Evaluate `model` over `loader` and record its top-1 accuracy.

    Args:
        processnum: worker index; selects the GPU (processnum % gpu) and is
            the key under which the result is stored.
        model: network to evaluate (switched to eval mode here).
        loader: DataLoader yielding (input, target) batches.
        return_dict: shared dict receiving {processnum: mean top-1 accuracy}.
    """
    with torch.no_grad():
        meter = AverageMeter()
        model.eval()
        device = processnum % gpu
        for batch in loader:
            images, labels = batch

            labels = labels.cuda(device)
            images = images.cuda(device)

            logits = model(images)

            top1_acc = accuracy(logits, labels, topk=(1,))[0]

            # Weight each batch by its size so the mean is per-sample.
            meter.update(top1_acc.item(), images.size(0))

        return_dict[processnum] = meter.avg



if __name__ == '__main__':
    # 'spawn' start method is required to use CUDA-initialized models in
    # child processes; it must be set before any process machinery is created.
    mp.set_start_method('spawn')
    parser = argparse.ArgumentParser()
    parser.add_argument("--path", required=True, help="path")
    args = parser.parse_args()

    # Manager-backed dict so worker processes can report results back.
    manager = mp.Manager()
    return_dict = manager.dict()
    

    # get one loader, shared by every worker process
    loader = get_loader(args.path, resize_size)

    # one model per checkpoint, placed round-robin across GPUs
    models = get_models()

    # NOTE(review): a single loader is shared across all processes rather than
    # one per worker, which can raise errors (as the original comment warned).
    processes = []
    for i, model in enumerate(models):
        p = mp.Process(target=get_data, args=(i, model, loader, return_dict))
        p.start()
        processes.append(p)

    # Wait for every evaluation worker to finish before reading results.
    for p in processes: p.join()

    # Process index idx corresponds to seeds[idx] / checkpoints[idx].
    for idx, seed in enumerate(seeds):
        print(f"process {idx}, seed {seed} : {return_dict[idx]}")
    
    # Variance of top-1 accuracy across the differently-seeded checkpoints.
    print(f"total variance : {np.var(return_dict.values())}")
    #print(return_dict.values())