# make_noisy.py
import torch
import torch.nn as nn
import torch.nn.functional as F  # F.softmax is used in the pseudo-labeling loop below
from model import mobilenetv3
import argparse
import torchvision
from torchvision.transforms import transforms
import torchvision.datasets as datasets
from augmentations import RandAugment
from get_mean_std import get_params
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np
import os
import cv2
from utils import MyImageFolder

class ConcatDataset(torch.utils.data.Dataset):
    """Concatenate datasets end-to-end so every sample is a single (input, target) pair."""

    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        # Walk the datasets until the index falls inside one of them.
        for d in self.datasets:
            if i < len(d):
                return d[i]
            i -= len(d)
        raise IndexError(i)

    def __len__(self):
        return sum(len(d) for d in self.datasets)
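# Usage sketch: ConcatDataset(teacher_data, Labeled_dataset) exposes
# len(teacher_data) + len(Labeled_dataset) samples, each a plain (image, label) pair,
# so it can be wrapped by a standard DataLoader (see merged_data_loader below).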


def make_dir():
    # Create one output directory per class for the teacher-labeled data.
    for cls in ('Double', 'Flip', 'Scratch', 'Leave', 'Normal', 'Empty'):
        os.makedirs(f'../data/Fourth_data/teacher_data/{cls}', exist_ok=True)


parser = argparse.ArgumentParser(description='Process make noisy student model')
parser.add_argument('--checkpoint_path', type=str, help='checkpoint path')
parser.add_argument('--size', type=int, help='resize integer of input')
parser.add_argument('--batch_size', type=int, default=256, help='set batch size')
parser.add_argument('--teacher_checkpoint_path', type=str, help='teacher first checkpoint path')
parser.add_argument('--Labeled_dataset_path', default='../data/Fourth_data/noisy_data/Labeled', type=str, help='path of dataset')
parser.add_argument('--Unlabeled_dataset_path', default='../data/Fourth_data/noisy_data/Unlabeled', type=str, help='path of unlabeled dataset')
parser.add_argument('--num_workers', default=8, type=int, help='number of data loader workers')
parser.add_argument('--epochs', default=350, type=int, help='epoch')
parser.add_argument('--finetune_epochs', default=2, type=int, help='finetuning epochs')
parser.add_argument('--data_save_path', default='../data/Fourth_data/teacher_data', type=str, help='teacher save unlabeled data in this path')
args = parser.parse_args()

print(args)

# RandAugment hyperparameters from the Noisy Student paper (https://arxiv.org/pdf/1911.04252.pdf)
Aug_number = 2
Aug_magnitude = 27

# my customized network: student block counts, one progressively larger student per entry
blocks = [4,5,6,7,8]

# data loader parameters
kwargs = {'num_workers': args.num_workers, 'pin_memory': True}

Labeled_mean, Labeled_std = get_params(args.Labeled_dataset_path, args.size)
Unlabeled_mean, Unlabeled_std = get_params(args.Unlabeled_dataset_path, args.size)
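
# NOTE (assumption): get_params from get_mean_std.py is taken to return per-channel
# mean/std tensors for the dataset; only channel 0 is used below, i.e. single-channel
# (grayscale) input is assumed.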

transform_labeled = transforms.Compose([
    transforms.Resize((args.size, args.size)),
    transforms.RandomCrop(args.size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(Labeled_mean[0].item(), Labeled_std[0].item())
])

# The teacher takes raw data and predicts on it, so no separate augmentation (RandAugment) is needed here.
transform_unlabeled = transforms.Compose([
    transforms.Resize((args.size, args.size)),
    transforms.RandomCrop(args.size, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(Unlabeled_mean[0].item(), Unlabeled_std[0].item())
])

# Add RandAugment with hyperparameters N and M
transform_labeled.transforms.insert(0, RandAugment(Aug_number, Aug_magnitude))
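
# NOTE (assumption): this RandAugment implementation operates on PIL images, which is why
# it is inserted at position 0, ahead of ToTensor, instead of being appended after Normalize.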

# set dataset
Labeled_dataset = datasets.ImageFolder(args.Labeled_dataset_path, transform_labeled)
Unlabeled_dataset = MyImageFolder(args.Unlabeled_dataset_path, transform_unlabeled)

labeled_data_loader = torch.utils.data.DataLoader(
  Labeled_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

unlabeled_data_loader = torch.utils.data.DataLoader(
  Unlabeled_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

# The noisy teacher is configured smaller than the student, with dropout set to 0.
noisy_teacher_model = mobilenetv3(n_class=2, dropout=0.0, blocknum=4)
checkpoint = torch.load(args.teacher_checkpoint_path)
noisy_teacher_model.load_state_dict(checkpoint['state_dict'])

# make loss function
criterion = nn.CrossEntropyLoss()
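
# CrossEntropyLoss applies log-softmax internally, so the model outputs passed to it must be
# raw logits, and the call order is criterion(output, target).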

# make class directory
make_dir()

classes = os.listdir(args.data_save_path)
classes.sort()
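
# NOTE: ImageFolder derives labels from sorted class-directory names, so this sorted listing
# of data_save_path must match the class order the teacher was trained with.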

for block in blocks:
  # The noisy student is configured larger than the teacher, with dropout 0.5 as given in the paper.
  noisy_student_model = mobilenetv3(n_class=2, dropout=0.5, blocknum=block, stochastic=True)

  noisy_student_model.cuda()
  noisy_teacher_model.cuda()
  criterion.cuda()

  # Optimizer roughly follows the official code: lr = 0.128, RMSprop decay (alpha) 0.9,
  # momentum 0.9. The original weight_decay=0.9 looks like a typo for the RMSprop decay;
  # the actual weight decay should be small (the official EfficientNet code uses 1e-5).
  optimizer = torch.optim.RMSprop(noisy_student_model.parameters(), lr=0.128, alpha=0.9,
                                  momentum=0.9, weight_decay=1e-5)

  # Exponential LR schedule like the TF official code: decay 0.97 every ~2.4 epochs
  # (approximated as every 2 epochs below).
  scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.97)

  for epoch in range(args.epochs):
    # Pseudo-labeling: the teacher assigns a class to every unlabeled image.
    noisy_teacher_model.eval()
    with torch.no_grad():
      for data in unlabeled_data_loader:
        (unlabeled_input, _), (path, _) = data

        unlabeled_input = unlabeled_input.cuda()

        output = noisy_teacher_model(unlabeled_input)

        prob = F.softmax(output, dim=1)

        # Copy each image into the directory of its teacher-predicted class.
        for i, p in enumerate(prob):
          indices = torch.topk(p, 1).indices.tolist()

          img = cv2.imread(path[i])

          cv2.imwrite(f"{args.data_save_path}/{classes[indices[0]]}/{path[i].split('/')[-1]}", img)

    # Rebuild a dataset and loader over the data the teacher has just labeled.
    transform_teacher_data = transforms.Compose([
      transforms.Resize((args.size, args.size)),
      transforms.RandomCrop(args.size, padding=4),
      transforms.RandomHorizontalFlip(),
      transforms.ToTensor(),
      transforms.Normalize(Unlabeled_mean[0].item(), Unlabeled_std[0].item())
    ])
    transform_teacher_data.transforms.insert(0, RandAugment(Aug_number, Aug_magnitude))

    teacher_data = datasets.ImageFolder(args.data_save_path, transform_teacher_data)

    # Merge the teacher-labeled data (first) with the real labeled data (second) into one
    # training set; ConcatDataset takes the datasets themselves, not DataLoaders.
    merged_dataset = ConcatDataset(teacher_data, Labeled_dataset)

    merged_data_loader = torch.utils.data.DataLoader(
      merged_dataset, batch_size=args.batch_size, shuffle=True, **kwargs)

    # For now the unlabeled data gets hard labels from the teacher. TODO: soft labeling.
    noisy_student_model.train()
    for i, (input, target) in enumerate(merged_data_loader):
      input = input.cuda()
      target = target.cuda()

      output = noisy_student_model(input)

      # CrossEntropyLoss takes (logits, targets), in that order.
      loss = criterion(output, target)

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    # The paper steps the LR every 2.4 epochs; here we approximate with every 2 epochs,
    # stepping once at the epoch boundary rather than once per batch.
    if epoch % 2 == 0:
      scheduler.step()
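
    # A minimal sketch of the soft-labeling TODO above (an assumption, not wired in):
    # train the student against the teacher's full softmax distribution instead of argmax labels.
    #
    #   with torch.no_grad():
    #       soft_target = F.softmax(noisy_teacher_model(input), dim=1)
    #   log_prob = F.log_softmax(noisy_student_model(input), dim=1)
    #   loss = -(soft_target * log_prob).sum(dim=1).mean()   # soft cross-entropy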

  # Iterative training: once fully trained, the student becomes the teacher for the next,
  # larger student (hence this runs after the epoch loop, once per block size).
  noisy_teacher_model = noisy_student_model
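
  # A minimal sketch (an assumption: nothing above persists the trained student, and the
  # parsed --checkpoint_path argument is otherwise unused). Saving each round's student in
  # the same {'state_dict': ...} format the teacher checkpoint is loaded from might look like:
  #
  #   torch.save({'state_dict': noisy_student_model.state_dict()},
  #              f'../checkpoints/noisy_student_block{block}.pth')  # hypothetical path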