# getAugmented_saveimg.py
import os
import fire
import json
from pprint import pprint
import pickle
import random

import torch
import torch.nn as nn
from torchvision.utils import save_image
from torch.utils.tensorboard import SummaryWriter

from utils import *

# Usage:
# python getAugmented_saveimg.py --model_path='logs/April_26_00:55:16__resnet50__None/'

def eval(model_path):
    """Apply a saved augmentation policy to the training split and save PNGs.

    Reads the run configuration from ``<model_path>/kwargs.json`` and the
    pickled list of augmentation transforms from
    ``<model_path>/augmentation.cp``. The transforms are wrapped in
    ``transforms.RandomChoice``, which picks one transform at random each
    time it is called. That transform is handed to ``get_dataset`` for the
    ``'train'`` split (presumably applied per sample — confirm in utils).
    Each item yielded by the loader is written to
    ``<model_path>/augmented_imgs/aug_<i>.png``.

    Args:
        model_path: Directory of a finished training run, e.g.
            ``'logs/April_26_00:55:16__resnet50__None/'``.

    Note:
        The name shadows the ``eval`` builtin. It is kept because this
        function is the entry point handed to ``fire.Fire``.
    """
    print('\n[+] Parse arguments')
    kwargs_path = os.path.join(model_path, 'kwargs.json')
    # Context manager so the config file handle is closed promptly.
    with open(kwargs_path) as f:
        kwargs = json.load(f)
    args, _ = parse_args(kwargs)
    pprint(args)

    cp_path = os.path.join(model_path, 'augmentation.cp')

    print('\n[+] Load transform')
    # SECURITY: pickle.load can execute arbitrary code. Only load
    # checkpoints produced by a trusted training run.
    with open(cp_path, 'rb') as f:
        aug_transform_list = pickle.load(f)

    # One transform from the saved list is drawn at random per call.
    transform = transforms.RandomChoice(aug_transform_list)

    print('\n[+] Load dataset')
    dataset = get_dataset(args, transform, 'train')
    loader = get_aug_dataloader(args, dataset)

    print('\n[+] Save 1 random policy')
    save_dir = os.path.join(model_path, 'augmented_imgs')
    # exist_ok lets the script be re-run without raising FileExistsError.
    os.makedirs(save_dir, exist_ok=True)

    for i, (image, _) in enumerate(loader):
        # NOTE(review): assumes each loader item holds exactly 240*240
        # elements (one single-channel 240x240 image); view() raises
        # otherwise. Confirm against get_aug_dataloader's batch size and
        # the dataset's image size.
        image = image.view(240, 240)
        save_image(image, os.path.join(save_dir, 'aug_' + str(i) + '.png'))

        # Report the true number of images written so far (i is 0-based).
        saved = i + 1
        if saved % 100 == 0:
            print("\n saved images: ", saved)

    print('\n[+] Finished to save')

if __name__ == '__main__':
    # Expose eval() as a command-line interface: python-fire maps
    # --model_path=... onto the function's model_path parameter.
    fire.Fire(component=eval)