Showing 18 changed files with 1420 additions and 0 deletions
source code/adda_mixup/.gitignore
0 → 100644
| 1 | +# # Created by .ignore support plugin (hsz.mobi) | ||
| 2 | +# ### Example user template template | ||
| 3 | +# ### Example user template | ||
| 4 | + | ||
| 5 | +# # IntelliJ project files | ||
| 6 | +# .idea | ||
| 7 | +# *.iml | ||
| 8 | +# out | ||
| 9 | +# gen | ||
| 10 | +# ### Python template | ||
| 11 | +# # Byte-compiled / optimized / DLL files | ||
| 12 | +__pycache__/ | ||
| 13 | +# *.py[cod] | ||
| 14 | +# *$py.class | ||
| 15 | + | ||
| 16 | +# # C extensions | ||
| 17 | +# *.so | ||
| 18 | + | ||
| 19 | +# # Distribution / packaging | ||
| 20 | +# .Python | ||
| 21 | +# build/ | ||
| 22 | +# develop-eggs/ | ||
| 23 | +# dist/ | ||
| 24 | +# downloads/ | ||
| 25 | +# eggs/ | ||
| 26 | +# .eggs/ | ||
| 27 | +# lib/ | ||
| 28 | +# lib64/ | ||
| 29 | +# parts/ | ||
| 30 | +# sdist/ | ||
| 31 | +# var/ | ||
| 32 | +# wheels/ | ||
| 33 | +# pip-wheel-metadata/ | ||
| 34 | +# share/python-wheels/ | ||
| 35 | +# *.egg-info/ | ||
| 36 | +# .installed.cfg | ||
| 37 | +# *.egg | ||
| 38 | +# MANIFEST | ||
| 39 | + | ||
| 40 | +# # PyInstaller | ||
| 41 | +# # Usually these files are written by a python script from a template | ||
| 42 | +# # before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
| 43 | +# *.manifest | ||
| 44 | +# *.spec | ||
| 45 | + | ||
| 46 | +# # Installer logs | ||
| 47 | +# pip-log.txt | ||
| 48 | +# pip-delete-this-directory.txt | ||
| 49 | + | ||
| 50 | +# # Unit test / coverage reports | ||
| 51 | +# htmlcov/ | ||
| 52 | +# .tox/ | ||
| 53 | +# .nox/ | ||
| 54 | +# .coverage | ||
| 55 | +# .coverage.* | ||
| 56 | +# .cache | ||
| 57 | +# nosetests.xml | ||
| 58 | +# coverage.xml | ||
| 59 | +# *.cover | ||
| 60 | +# .hypothesis/ | ||
| 61 | +# .pytest_cache/ | ||
| 62 | + | ||
| 63 | +# # # Translations | ||
| 64 | +# *.mo | ||
| 65 | +# *.pot | ||
| 66 | + | ||
| 67 | +# # Django stuff: | ||
| 68 | +# *.log | ||
| 69 | +# local_settings.py | ||
| 70 | +# db.sqlite3 | ||
| 71 | + | ||
| 72 | +# # Flask stuff: | ||
| 73 | +# instance/ | ||
| 74 | +# .webassets-cache | ||
| 75 | + | ||
| 76 | +# # Scrapy stuff: | ||
| 77 | +# .scrapy | ||
| 78 | + | ||
| 79 | +# # Sphinx documentation | ||
| 80 | +# docs/_build/ | ||
| 81 | + | ||
| 82 | +# # PyBuilder | ||
| 83 | +# target/ | ||
| 84 | + | ||
| 85 | +# # Jupyter Notebook | ||
| 86 | +# .ipynb_checkpoints | ||
| 87 | + | ||
| 88 | +# # IPython | ||
| 89 | +# profile_default/ | ||
| 90 | +# ipython_config.py | ||
| 91 | + | ||
| 92 | +# # pyenv | ||
| 93 | +# .python-version | ||
| 94 | + | ||
| 95 | +# # pipenv | ||
| 96 | +# # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. | ||
| 97 | +# # However, in case of collaboration, if having platform-specific dependencies or dependencies | ||
| 98 | +# # having no cross-platform support, pipenv may install dependencies that don't work, or not | ||
| 99 | +# # install all needed dependencies. | ||
| 100 | +# #Pipfile.lock | ||
| 101 | + | ||
| 102 | +# # celery beat schedule file | ||
| 103 | +# celerybeat-schedule | ||
| 104 | + | ||
| 105 | +# # SageMath parsed files | ||
| 106 | +# *.sage.py | ||
| 107 | + | ||
| 108 | +# # Environments | ||
| 109 | +# .env | ||
| 110 | +# .venv | ||
| 111 | +# env/ | ||
| 112 | +# venv/ | ||
| 113 | +# ENV/ | ||
| 114 | +# env.bak/ | ||
| 115 | +# venv.bak/ | ||
| 116 | + | ||
| 117 | +# # Spyder project settings | ||
| 118 | +# .spyderproject | ||
| 119 | +# .spyproject | ||
| 120 | + | ||
| 121 | +# # Rope project settings | ||
| 122 | +# .ropeproject | ||
| 123 | + | ||
| 124 | +# # mkdocs documentation | ||
| 125 | +# /site | ||
| 126 | + | ||
| 127 | +# # mypy | ||
| 128 | +# .mypy_cache/ | ||
| 129 | +# .dmypy.json | ||
| 130 | +# dmypy.json | ||
| 131 | + | ||
| 132 | +# # Pyre type checker | ||
| 133 | +# .pyre/ | ||
| 134 | + | ||
| 135 | +# ### JetBrains template | ||
| 136 | +# # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm | ||
| 137 | +# # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 | ||
| 138 | + | ||
| 139 | +# # User-specific stuff | ||
| 140 | +# .idea/**/workspace.xml | ||
| 141 | +# .idea/**/tasks.xml | ||
| 142 | +# .idea/**/usage.statistics.xml | ||
| 143 | +# .idea/**/dictionaries | ||
| 144 | +# .idea/**/shelf | ||
| 145 | + | ||
| 146 | +# # Generated files | ||
| 147 | +# .idea/**/contentModel.xml | ||
| 148 | + | ||
| 149 | +# # Sensitive or high-churn files | ||
| 150 | +# .idea/**/dataSources/ | ||
| 151 | +# .idea/**/dataSources.ids | ||
| 152 | +# .idea/**/dataSources.local.xml | ||
| 153 | +# .idea/**/sqlDataSources.xml | ||
| 154 | +# .idea/**/dynamic.xml | ||
| 155 | +# .idea/**/uiDesigner.xml | ||
| 156 | +# .idea/**/dbnavigator.xml | ||
| 157 | + | ||
| 158 | +# # Gradle | ||
| 159 | +# .idea/**/gradle.xml | ||
| 160 | +# .idea/**/libraries | ||
| 161 | + | ||
| 162 | +# # Gradle and Maven with auto-import | ||
| 163 | +# # When using Gradle or Maven with auto-import, you should exclude module files, | ||
| 164 | +# # since they will be recreated, and may cause churn. Uncomment if using | ||
| 165 | +# # auto-import. | ||
| 166 | +# # .idea/modules.xml | ||
| 167 | +# # .idea/*.iml | ||
| 168 | +# # .idea/modules | ||
| 169 | +# # *.iml | ||
| 170 | +# # *.ipr | ||
| 171 | + | ||
| 172 | +# # CMake | ||
| 173 | +# cmake-build-*/ | ||
| 174 | + | ||
| 175 | +# # Mongo Explorer plugin | ||
| 176 | +# .idea/**/mongoSettings.xml | ||
| 177 | + | ||
| 178 | +# # File-based project format | ||
| 179 | +# *.iws | ||
| 180 | + | ||
| 181 | +# # IntelliJ | ||
| 182 | +# out/ | ||
| 183 | + | ||
| 184 | +# # mpeltonen/sbt-idea plugin | ||
| 185 | +# .idea_modules/ | ||
| 186 | + | ||
| 187 | +# # JIRA plugin | ||
| 188 | +# atlassian-ide-plugin.xml | ||
| 189 | + | ||
| 190 | +# # Cursive Clojure plugin | ||
| 191 | +# .idea/replstate.xml | ||
| 192 | + | ||
| 193 | +# # Crashlytics plugin (for Android Studio and IntelliJ) | ||
| 194 | +# com_crashlytics_export_strings.xml | ||
| 195 | +# crashlytics.properties | ||
| 196 | +# crashlytics-build.properties | ||
| 197 | +# fabric.properties | ||
| 198 | + | ||
| 199 | +# # Editor-based Rest Client | ||
| 200 | +# .idea/httpRequests | ||
| 201 | + | ||
| 202 | +# # Android studio 3.1+ serialized cache file | ||
| 203 | +# .idea/caches/build_file_checksums.ser | ||
| 204 | + | ||
| 205 | +./data/ | ||
| 206 | +data/ | ||
| 207 | +# .idea/ | ||
| 208 | +generated/ |
source code/adda_mixup/README.md
0 → 100644
| 1 | +# PyTorch-ADDA-mixup | ||
| 2 | +A PyTorch implementation of Adversarial Discriminative Domain Adaptation (ADDA) with mixup added. | ||
| 3 | + | ||
| 4 | +Mixing up the target domain and the source domain gave confirmed improved performance (see the sketch below). | ||
| 5 | + | ||
| 6 | +# Usage | ||
| 7 | +It works on MNIST -> USPS, SVHN -> MNIST, USPS -> MNIST, and MNIST -> MNIST-M. | ||
| 8 | +Only 10,000 samples of the full training data were used (USPS excluded). | ||
| 9 | + | ||
| 10 | +<pre> | ||
| 11 | +<code> | ||
| 12 | +python main.py | ||
| 13 | +</code> | ||
| 14 | +</pre> | ||
| 15 | + | ||
| 16 | +## ADDA | ||
| 17 | +This repo is based on https://github.com/corenel/pytorch-adda and https://github.com/Fujiki-Nakamura/ADDA.PyTorch | ||
| 18 | + | ||
| 19 | + | ||
| 20 | + | ||
| 21 | +## Reference | ||
| 22 | +https://arxiv.org/abs/1702.05464 | ||
| 23 | + | ||
| 24 | + | ||
| 25 | +# |
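The mixup referenced in this README is applied to the input images only (labels keep their source values); it mirrors `utils.mixup_data` from this diff. Below is a minimal, self-contained sketch with dummy tensors; `lam_min`/`lam_max` stand in for `params.lammin`/`params.lammax`, and the 0.0–0.3 range is only an example (the presets commented in `params.py` mention lam values of 0.1–0.3).

<pre>
<code>
# Minimal sketch of the input mixup used in this repo (cf. utils.mixup_data).
# lam_min / lam_max stand in for params.lammin / params.lammax; tensors are dummies.
import torch

lam_min, lam_max = 0.0, 0.3           # example mixing-ratio range for target images
source = torch.randn(128, 3, 28, 28)  # a batch of source-domain images
target = torch.randn(128, 3, 28, 28)  # a batch of target-domain images

lam = (lam_max - lam_min) * torch.rand(1) + lam_min  # lam ~ U(lam_min, lam_max)
mixed_source = (1 - lam) * source + lam * target     # labels stay the source labels
</code>
</pre>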
source code/adda_mixup/core/__init__.py
0 → 100644
source code/adda_mixup/core/adapt.py
0 → 100644
| 1 | +import os | ||
| 2 | + | ||
| 3 | +import torch | ||
| 4 | +import torch.optim as optim | ||
| 5 | +from torch import nn | ||
| 6 | +from core import test | ||
| 7 | +import params | ||
| 8 | +from utils import make_cuda, mixup_data | ||
| 9 | + | ||
| 10 | + | ||
| 11 | + | ||
| 12 | +def train_tgt(source_cnn, target_cnn, critic, | ||
| 13 | + src_data_loader, tgt_data_loader): | ||
| 14 | + """Train encoder for target domain.""" | ||
| 15 | + #################### | ||
| 16 | + # 1. setup network # | ||
| 17 | + #################### | ||
| 18 | + | ||
| 19 | + source_cnn.eval() | ||
| 20 | + target_cnn.encoder.train() | ||
| 21 | + critic.train() | ||
| 22 | + isbest = 0 | ||
| 23 | + # setup criterion and optimizer | ||
| 24 | + criterion = nn.CrossEntropyLoss() | ||
| 25 | + #target encoder | ||
| 26 | + optimizer_tgt = optim.Adam(target_cnn.parameters(), | ||
| 27 | + lr=params.adp_c_learning_rate, | ||
| 28 | + betas=(params.beta1, params.beta2), | ||
| 29 | + weight_decay=params.weight_decay | ||
| 30 | + ) | ||
| 31 | + #Discriminator | ||
| 32 | + optimizer_critic = optim.Adam(critic.parameters(), | ||
| 33 | + lr=params.d_learning_rate, | ||
| 34 | + betas=(params.beta1, params.beta2), | ||
| 35 | + weight_decay=params.weight_decay | ||
| 36 | + | ||
| 37 | + | ||
| 38 | + ) | ||
| 39 | + | ||
| 40 | + #################### | ||
| 41 | + # 2. train network # | ||
| 42 | + #################### | ||
| 43 | + len_data_loader = min(len(src_data_loader), len(tgt_data_loader)) | ||
| 44 | + | ||
| 45 | + for epoch in range(params.num_epochs): | ||
| 46 | + # zip source and target data pair | ||
| 47 | + data_zip = enumerate(zip(src_data_loader, tgt_data_loader)) | ||
| 48 | + for step, ((images_src, _), (images_tgt, _)) in data_zip: | ||
| 49 | + | ||
| 50 | + # make images variable | ||
| 51 | + images_src = make_cuda(images_src) | ||
| 52 | + images_tgt = make_cuda(images_tgt) | ||
| 53 | + | ||
| 54 | + | ||
| 55 | + | ||
| 56 | + | ||
| 57 | + ########################### | ||
| 58 | + # 2.1 train discriminator # | ||
| 59 | + ########################### | ||
| 60 | + | ||
| 61 | + # zero gradients for optimizer | ||
| 62 | + optimizer_critic.zero_grad() | ||
| 63 | + | ||
| 64 | + # extract and concat features | ||
| 65 | + feat_src = source_cnn.encoder(images_src) | ||
| 66 | + feat_tgt = target_cnn.encoder(images_tgt) | ||
| 67 | + feat_concat = torch.cat((feat_src, feat_tgt), 0) | ||
| 68 | + | ||
| 69 | + # predict on discriminator | ||
| 70 | + pred_concat = critic(feat_concat.detach()) | ||
| 71 | + | ||
| 72 | + # prepare real and fake label | ||
| 73 | + label_src = make_cuda(torch.zeros(feat_src.size(0)).long()) | ||
| 74 | + label_tgt = make_cuda(torch.ones(feat_tgt.size(0)).long()) | ||
| 75 | + label_concat = torch.cat((label_src, label_tgt), 0) | ||
| 76 | + | ||
| 77 | + # compute loss for critic | ||
| 78 | + loss_critic = criterion(pred_concat, label_concat) | ||
| 79 | + loss_critic.backward() | ||
| 80 | + | ||
| 81 | + # optimize critic | ||
| 82 | + optimizer_critic.step() | ||
| 83 | + | ||
| 84 | + pred_cls = torch.squeeze(pred_concat.max(1)[1]) | ||
| 85 | + acc = (pred_cls == label_concat).float().mean() | ||
| 86 | + | ||
| 87 | + | ||
| 88 | + ############################ | ||
| 89 | + # 2.2 train target encoder # | ||
| 90 | + ############################ | ||
| 91 | + | ||
| 92 | + # zero gradients for optimizer | ||
| 93 | + optimizer_critic.zero_grad() | ||
| 94 | + optimizer_tgt.zero_grad() | ||
| 95 | + | ||
| 96 | + # extract target features | ||
| 97 | + feat_tgt = target_cnn.encoder(images_tgt) | ||
| 98 | + | ||
| 99 | + # predict on discriminator | ||
| 100 | + pred_tgt = critic(feat_tgt) | ||
| 101 | + | ||
| 102 | + # prepare fake labels | ||
| 103 | + label_tgt = make_cuda(torch.zeros(feat_tgt.size(0)).long()) | ||
| 104 | + | ||
| 105 | + # compute loss for target encoder | ||
| 106 | + loss_tgt = criterion(pred_tgt, label_tgt) | ||
| 107 | + loss_tgt.backward() | ||
| 108 | + | ||
| 109 | + # optimize target encoder | ||
| 110 | + optimizer_tgt.step() | ||
| 111 | + ####################### | ||
| 112 | + # 2.3 print step info # | ||
| 113 | + ####################### | ||
| 114 | + if epoch % 10 == 0 and (step + 1) % len_data_loader == 0: | ||
| 115 | + print("Epoch [{}/{}] Step [{}/{}]:" | ||
| 116 | + "d_loss={:.5f} g_loss={:.5f} acc={:.5f}" | ||
| 117 | + .format(epoch, | ||
| 118 | + params.num_epochs, | ||
| 119 | + step + 1, | ||
| 120 | + len_data_loader, | ||
| 121 | + loss_critic.item(), | ||
| 122 | + loss_tgt.item(), | ||
| 123 | + acc.item())) | ||
| 124 | + | ||
| 125 | + | ||
| 126 | + torch.save(critic.state_dict(), os.path.join( | ||
| 127 | + params.model_root, | ||
| 128 | + "ADDA-critic-final.pt")) | ||
| 129 | + torch.save(target_cnn.state_dict(), os.path.join( | ||
| 130 | + params.model_root, | ||
| 131 | + "ADDA-target_cnn-final.pt")) | ||
| 132 | + return target_cnn | ||
| ... | \ No newline at end of file | ||
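For context, `adapt.train_tgt` above is the plain-ADDA adaptation step; `main.py` in this diff calls the mixup variant (`mixup.train_tgt`) instead. The following is only an illustrative sketch of how this function could be wired up, assuming a CUDA device, a `source_cnn` already pretrained on the source domain, and an existing `generated/models` directory for the final checkpoints.

<pre>
<code>
# Hypothetical wiring for the plain-ADDA adaptation step (a sketch, not part of the repo).
import params
from core import adapt, test
from models.models import CNN, Discriminator
from utils import get_data_loader

src_loader = get_data_loader(params.src_dataset, adp=False, size=10000)
tgt_loader = get_data_loader(params.tgt_dataset, adp=False, size=10000)

source_cnn = CNN(in_channels=3).to("cuda")            # assumed to be pretrained already
target_cnn = CNN(in_channels=3, target=True).to("cuda")
target_cnn.load_state_dict(source_cnn.state_dict())   # init target encoder from source
critic = Discriminator().to("cuda")

target_cnn = adapt.train_tgt(source_cnn, target_cnn, critic, src_loader, tgt_loader)
test.eval_tgt(target_cnn, get_data_loader(params.tgt_dataset, train=False))
</code>
</pre>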
source code/adda_mixup/core/mixup.py
0 → 100644
| 1 | +import torch.nn as nn | ||
| 2 | +import torch | ||
| 3 | +import torch.optim as optim | ||
| 4 | + | ||
| 5 | +import params | ||
| 6 | +from utils import make_cuda, save_model, LabelSmoothingCrossEntropy,mixup_data,mixup_criterion | ||
| 7 | +from random import randint | ||
| 8 | +import sys | ||
| 9 | + | ||
| 10 | +from torch.utils.data import Dataset,DataLoader | ||
| 11 | +import os | ||
| 12 | +from core import test | ||
| 13 | +from utils import make_cuda, mixup_data | ||
| 14 | + | ||
| 15 | + | ||
| 16 | + | ||
| 17 | +class CustomDataset(Dataset): | ||
| 18 | + def __init__(self,img,label): | ||
| 19 | + self.x_data = img | ||
| 20 | + self.y_data = label | ||
| 21 | + def __len__(self): | ||
| 22 | + return len(self.x_data) | ||
| 23 | + | ||
| 24 | + def __getitem__(self, idx): | ||
| 25 | + x = self.x_data[idx] | ||
| 26 | + y = self.y_data[idx] | ||
| 27 | + return x, y | ||
| 28 | + | ||
| 29 | + | ||
| 30 | +def train_src(model, source_data_loader,target_data_loader,valid_loader): | ||
| 31 | + """Train classifier for source domain.""" | ||
| 32 | + #################### | ||
| 33 | + # 1. setup network # | ||
| 34 | + #################### | ||
| 35 | + | ||
| 36 | + model.train() | ||
| 37 | + | ||
| 38 | + | ||
| 39 | + target_data_loader = list(target_data_loader) | ||
| 40 | + | ||
| 41 | + # setup criterion and optimizer | ||
| 42 | + optimizer = optim.Adam( | ||
| 43 | + model.parameters(), | ||
| 44 | + lr=params.pre_c_learning_rate, | ||
| 45 | + betas=(params.beta1, params.beta2), | ||
| 46 | + weight_decay=params.weight_decay | ||
| 47 | + ) | ||
| 48 | + | ||
| 49 | + | ||
| 50 | + | ||
| 51 | + if params.labelsmoothing: | ||
| 52 | + criterion = LabelSmoothingCrossEntropy(smoothing= params.smoothing) | ||
| 53 | + else: | ||
| 54 | + criterion = nn.CrossEntropyLoss() | ||
| 55 | + | ||
| 56 | + | ||
| 57 | + #################### | ||
| 58 | + # 2. train network # | ||
| 59 | + #################### | ||
| 60 | + len_data_loader = min(len(source_data_loader), len(target_data_loader)) | ||
| 61 | + | ||
| 62 | + for epoch in range(params.num_epochs_pre+1): | ||
| 63 | + data_zip = enumerate(zip(source_data_loader, target_data_loader)) | ||
| 64 | + for step, ((images, labels), (images_tgt, _)) in data_zip: | ||
| 65 | + # make images and labels variable | ||
| 66 | + images = make_cuda(images) | ||
| 67 | + labels = make_cuda(labels.squeeze_()) | ||
| 68 | + # zero gradients for optimizer | ||
| 69 | + optimizer.zero_grad() | ||
| 70 | + target=target_data_loader[randint(0, len(target_data_loader)-1)] | ||
| 71 | + images, lam = mixup_data(images,target[0]) | ||
| 72 | + | ||
| 73 | + # compute loss for critic | ||
| 74 | + preds = model(images) | ||
| 75 | + # loss = mixup_criterion(criterion, preds, labels, labels_tgt, lam) | ||
| 76 | + loss = criterion(preds, labels) | ||
| 77 | + | ||
| 78 | + # optimize source classifier | ||
| 79 | + loss.backward() | ||
| 80 | + optimizer.step() | ||
| 81 | + | ||
| 82 | + | ||
| 83 | + | ||
| 84 | + # # eval model on test set | ||
| 85 | + if ((epoch) % params.eval_step_pre == 0): | ||
| 86 | + print(f"Epoch [{epoch}/{params.num_epochs_pre}]",end='') | ||
| 87 | + if valid_loader is not None: | ||
| 88 | + test.eval_tgt(model, valid_loader) | ||
| 89 | + else: | ||
| 90 | + test.eval_tgt(model, source_data_loader) | ||
| 91 | + | ||
| 92 | + # save model parameters | ||
| 93 | + if ((epoch + 1) % params.save_step_pre == 0): | ||
| 94 | + save_model(model, "our-source_cnn-{}.pt".format(epoch + 1)) | ||
| 95 | + | ||
| 96 | + # # save final model | ||
| 97 | + save_model(model, "our-source_cnn-final.pt") | ||
| 98 | + | ||
| 99 | + return model | ||
| 100 | + | ||
| 101 | + | ||
| 102 | + | ||
| 103 | + | ||
| 104 | +def train_tgt(source_cnn, target_cnn, critic, | ||
| 105 | + src_data_loader, tgt_data_loader,valid_loader): | ||
| 106 | + """Train encoder for target domain.""" | ||
| 107 | + #################### | ||
| 108 | + # 1. setup network # | ||
| 109 | + #################### | ||
| 110 | + | ||
| 111 | + source_cnn.eval() | ||
| 112 | + target_cnn.encoder.train() | ||
| 113 | + critic.train() | ||
| 114 | + isbest = 0 | ||
| 115 | + # setup criterion and optimizer | ||
| 116 | + criterion = nn.CrossEntropyLoss() | ||
| 117 | + #target encoder | ||
| 118 | + optimizer_tgt = optim.Adam(target_cnn.parameters(), | ||
| 119 | + lr=params.adp_c_learning_rate, | ||
| 120 | + betas=(params.beta1, params.beta2), | ||
| 121 | + weight_decay=params.weight_decay | ||
| 122 | + ) | ||
| 123 | + #Discriminator | ||
| 124 | + optimizer_critic = optim.Adam(critic.parameters(), | ||
| 125 | + lr=params.d_learning_rate, | ||
| 126 | + betas=(params.beta1, params.beta2), | ||
| 127 | + weight_decay=params.weight_decay | ||
| 128 | + | ||
| 129 | + | ||
| 130 | + ) | ||
| 131 | + | ||
| 132 | + #################### | ||
| 133 | + # 2. train network # | ||
| 134 | + #################### | ||
| 135 | + data_len = min(len(src_data_loader), len(tgt_data_loader)) | ||
| 136 | + | ||
| 137 | + for epoch in range(params.num_epochs+1): | ||
| 138 | + # zip source and target data pair | ||
| 139 | + data_zip = enumerate(zip(src_data_loader, tgt_data_loader)) | ||
| 140 | + for step, ((images_src, _), (images_tgt, _)) in data_zip: | ||
| 141 | + | ||
| 142 | + # make images variable | ||
| 143 | + images_src = make_cuda(images_src) | ||
| 144 | + images_tgt = make_cuda(images_tgt) | ||
| 145 | + | ||
| 146 | + ########################### | ||
| 147 | + # 2.1 train discriminator # | ||
| 148 | + ########################### | ||
| 149 | + | ||
| 150 | + # mixup data | ||
| 151 | + images_src, _ = mixup_data(images_src,images_tgt) | ||
| 152 | + | ||
| 153 | + # zero gradients for optimizer | ||
| 154 | + optimizer_critic.zero_grad() | ||
| 155 | + | ||
| 156 | + # extract and concat features | ||
| 157 | + feat_src = source_cnn.encoder(images_src) | ||
| 158 | + feat_tgt = target_cnn.encoder(images_tgt) | ||
| 159 | + feat_concat = torch.cat((feat_src, feat_tgt), 0) | ||
| 160 | + | ||
| 161 | + # predict on discriminator | ||
| 162 | + pred_concat = critic(feat_concat.detach()) | ||
| 163 | + | ||
| 164 | + # prepare real and fake label | ||
| 165 | + label_src = make_cuda(torch.zeros(feat_src.size(0)).long()) | ||
| 166 | + label_tgt = make_cuda(torch.ones(feat_tgt.size(0)).long()) | ||
| 167 | + label_concat = torch.cat((label_src, label_tgt), 0) | ||
| 168 | + | ||
| 169 | + # compute loss for critic | ||
| 170 | + loss_critic = criterion(pred_concat, label_concat) | ||
| 171 | + loss_critic.backward() | ||
| 172 | + | ||
| 173 | + # optimize critic | ||
| 174 | + optimizer_critic.step() | ||
| 175 | + | ||
| 176 | + pred_cls = torch.squeeze(pred_concat.max(1)[1]) | ||
| 177 | + acc = (pred_cls == label_concat).float().mean() | ||
| 178 | + | ||
| 179 | + | ||
| 180 | + ############################ | ||
| 181 | + # 2.2 train target encoder # | ||
| 182 | + ############################ | ||
| 183 | + | ||
| 184 | + # zero gradients for optimizer | ||
| 185 | + optimizer_critic.zero_grad() | ||
| 186 | + optimizer_tgt.zero_grad() | ||
| 187 | + | ||
| 188 | + # extract target features | ||
| 189 | + feat_tgt = target_cnn.encoder(images_tgt) | ||
| 190 | + | ||
| 191 | + # predict on discriminator | ||
| 192 | + pred_tgt = critic(feat_tgt) | ||
| 193 | + | ||
| 194 | + # prepare fake labels | ||
| 195 | + label_tgt = make_cuda(torch.zeros(feat_tgt.size(0)).long()) | ||
| 196 | + | ||
| 197 | + # compute loss for target encoder | ||
| 198 | + loss_tgt = criterion(pred_tgt, label_tgt) | ||
| 199 | + loss_tgt.backward() | ||
| 200 | + | ||
| 201 | + # optimize target encoder | ||
| 202 | + optimizer_tgt.step() | ||
| 203 | + ####################### | ||
| 204 | + # 2.3 print step info # | ||
| 205 | + ####################### | ||
| 206 | + if epoch % 10 == 0 and (step + 1) % data_len == 0: | ||
| 207 | + print("Epoch [{}/{}] Step [{}/{}]:" | ||
| 208 | + "d_loss={:.5f} g_loss={:.5f} acc={:.5f}" | ||
| 209 | + .format(epoch , | ||
| 210 | + params.num_epochs, | ||
| 211 | + step + 1, | ||
| 212 | + data_len, | ||
| 213 | + loss_critic.item(), | ||
| 214 | + loss_tgt.item(), | ||
| 215 | + acc.item())) | ||
| 216 | + if valid_loader is not None: | ||
| 217 | + test.eval_tgt(target_cnn,valid_loader) | ||
| 218 | + | ||
| 219 | + torch.save(critic.state_dict(), os.path.join( | ||
| 220 | + params.model_root, | ||
| 221 | + "our-critic-final.pt")) | ||
| 222 | + torch.save(target_cnn.state_dict(), os.path.join( | ||
| 223 | + params.model_root, | ||
| 224 | + "our-target_cnn-final.pt")) | ||
| 225 | + return target_cnn | ||
| ... | \ No newline at end of file | ||
source code/adda_mixup/core/pretrain.py
0 → 100644
| 1 | +import torch.nn as nn | ||
| 2 | +import torch.optim as optim | ||
| 3 | +import torch | ||
| 4 | +import params | ||
| 5 | +from utils import make_cuda, save_model, LabelSmoothingCrossEntropy,mixup_data | ||
| 6 | +from random import * | ||
| 7 | +import sys | ||
| 8 | + | ||
| 9 | +def train_src(model, source_data_loader): | ||
| 10 | + """Train classifier for source domain.""" | ||
| 11 | + #################### | ||
| 12 | + # 1. setup network # | ||
| 13 | + #################### | ||
| 14 | + | ||
| 15 | + model.train() | ||
| 16 | + | ||
| 17 | + | ||
| 18 | + | ||
| 19 | + # setup criterion and optimizer | ||
| 20 | + optimizer = optim.Adam( | ||
| 21 | + model.parameters(), | ||
| 22 | + lr=params.pre_c_learning_rate, | ||
| 23 | + betas=(params.beta1, params.beta2), | ||
| 24 | + weight_decay=params.weight_decay | ||
| 25 | + ) | ||
| 26 | + | ||
| 27 | + | ||
| 28 | + if params.labelsmoothing: | ||
| 29 | + criterion = LabelSmoothingCrossEntropy(smoothing= params.smoothing) | ||
| 30 | + else: | ||
| 31 | + criterion = nn.CrossEntropyLoss() | ||
| 32 | + | ||
| 33 | + | ||
| 34 | + #################### | ||
| 35 | + # 2. train network # | ||
| 36 | + #################### | ||
| 37 | + | ||
| 38 | + for epoch in range(params.num_epochs_pre): | ||
| 39 | + for step, (images, labels) in enumerate(source_data_loader): | ||
| 40 | + # make images and labels variable | ||
| 41 | + images = make_cuda(images) | ||
| 42 | + labels = make_cuda(labels.squeeze_()) | ||
| 43 | + # zero gradients for optimizer | ||
| 44 | + optimizer.zero_grad() | ||
| 45 | + | ||
| 46 | + # compute loss for critic | ||
| 47 | + preds = model(images) | ||
| 48 | + loss = criterion(preds, labels) | ||
| 49 | + | ||
| 50 | + # optimize source classifier | ||
| 51 | + loss.backward() | ||
| 52 | + optimizer.step() | ||
| 53 | + | ||
| 54 | + | ||
| 55 | + | ||
| 56 | + # # eval model on test set | ||
| 57 | + if ((epoch ) % params.eval_step_pre == 0): | ||
| 58 | + print(f"Epoch [{epoch}/{params.num_epochs_pre}]",end='') | ||
| 59 | + eval_src(model, source_data_loader) | ||
| 60 | + | ||
| 61 | + # save model parameters | ||
| 62 | + if ((epoch + 1) % params.save_step_pre == 0): | ||
| 63 | + save_model(model, "ADDA-source_cnn-{}.pt".format(epoch + 1)) | ||
| 64 | + | ||
| 65 | + # # save final model | ||
| 66 | + save_model(model, "ADDA-source_cnn-final.pt") | ||
| 67 | + | ||
| 68 | + return model | ||
| 69 | + | ||
| 70 | +def eval_src(model, data_loader): | ||
| 71 | + """Evaluate classifier for source domain.""" | ||
| 72 | + # set eval state for Dropout and BN layers | ||
| 73 | + model.eval() | ||
| 74 | + with torch.no_grad(): | ||
| 75 | + # init loss and accuracy | ||
| 76 | + loss = 0 | ||
| 77 | + acc = 0 | ||
| 78 | + | ||
| 79 | + # evaluate network | ||
| 80 | + for (images, labels) in data_loader: | ||
| 81 | + | ||
| 82 | + images = make_cuda(images) | ||
| 83 | + labels = make_cuda(labels).squeeze_() | ||
| 84 | + | ||
| 85 | + preds = model(images) | ||
| 86 | + | ||
| 87 | + pred_cls = preds.data.max(1)[1] | ||
| 88 | + acc += pred_cls.eq(labels.data).cpu().sum().item() | ||
| 89 | + | ||
| 90 | + acc /= len(data_loader.dataset) | ||
| 91 | + | ||
| 92 | + print("Avg Accuracy = {:.2%}".format(acc)) | ||
| 93 | + | ||
| 94 | + |
source code/adda_mixup/core/test.py
0 → 100644
| 1 | +#!/usr/bin/env python3 | ||
| 2 | +# -*- coding: utf-8 -*- | ||
| 3 | +""" | ||
| 4 | +Created on Wed Dec 5 15:03:50 2018 | ||
| 5 | + | ||
| 6 | +@author: gaoyi | ||
| 7 | +""" | ||
| 8 | + | ||
| 9 | +import torch | ||
| 10 | +import torch.nn as nn | ||
| 11 | + | ||
| 12 | +from utils import make_cuda | ||
| 13 | + | ||
| 14 | + | ||
| 15 | +def eval_tgt(model, data_loader): | ||
| 16 | + """Evaluation for target encoder by source classifier on target dataset.""" | ||
| 17 | + # set eval state for Dropout and BN layers | ||
| 18 | + model.eval() | ||
| 19 | + # init loss and accuracy | ||
| 20 | + loss = 0 | ||
| 21 | + acc = 0 | ||
| 22 | + with torch.no_grad(): | ||
| 23 | + # evaluate network | ||
| 24 | + for (images, labels) in data_loader: | ||
| 25 | + images = make_cuda(images) | ||
| 26 | + labels = make_cuda(labels).squeeze_() | ||
| 27 | + | ||
| 28 | + preds = model(images) | ||
| 29 | + _, preds = torch.max(preds.data, 1) | ||
| 30 | + acc += (preds == labels).float().sum()/images.shape[0] | ||
| 31 | + | ||
| 32 | + acc /= len(data_loader) | ||
| 33 | + | ||
| 34 | + print("Avg Accuracy = {:.2%}".format(acc)) | ||
| ... | \ No newline at end of file | ||
source code/adda_mixup/dataset/__init__.py
0 → 100644
source code/adda_mixup/dataset/customdata.py
0 → 100644
| 1 | +from torchvision import transforms, datasets | ||
| 2 | +import torch | ||
| 3 | +import params | ||
| 4 | + | ||
| 5 | +def get_custom(train,adp=False,size = 0): | ||
| 6 | + | ||
| 7 | + pre_process = transforms.Compose([transforms.Resize(params.image_size), | ||
| 8 | + transforms.ToTensor(), | ||
| 9 | + # transforms.Normalize((0.5),(0.5)), | ||
| 10 | + ]) | ||
| 11 | + custom_dataset = datasets.ImageFolder( | ||
| 12 | + root = params.custom_dataset_root , | ||
| 13 | + transform = pre_process, | ||
| 14 | + ) | ||
| 15 | + length = len(custom_dataset) | ||
| 16 | + train_set, val_set = torch.utils.data.random_split(custom_dataset, [int(length*0.9), length-int(length*0.9)]) | ||
| 17 | + | ||
| 18 | + if train: | ||
| 19 | + train_set,_ = torch.utils.data.random_split(train_set, [size,len(train_set)-size]) | ||
| 20 | + | ||
| 21 | + | ||
| 22 | + | ||
| 23 | + custom_data_loader = torch.utils.data.DataLoader( | ||
| 24 | + train_set if train else val_set, | ||
| 25 | + batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 26 | + shuffle=True, | ||
| 27 | + drop_last=True | ||
| 28 | + | ||
| 29 | + ) | ||
| 30 | + | ||
| 31 | + return custom_data_loader | ||
| ... | \ No newline at end of file | ||
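A note on the `custom` dataset used as the default `src_dataset`/`tgt_dataset` in `params.py`: `get_custom` builds on `torchvision.datasets.ImageFolder`, which expects one subfolder per class under `params.custom_dataset_root`. A small usage sketch, assuming such a folder layout exists and that the `dataset` package re-exports `get_custom` (as the `utils.py` import suggests):

<pre>
<code>
# Sketch only: ImageFolder expects data/custom/<class_name>/<image files>, and class
# indices are assigned from the alphabetically sorted subfolder names.
from dataset import get_custom

train_loader = get_custom(train=True, adp=False, size=10000)  # 90%/10% split, then subsample
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # e.g. [128, 3, 28, 28] and [128] for RGB 28x28 inputs
</code>
</pre>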
source code/adda_mixup/dataset/mnist.py
0 → 100644
| 1 | + | ||
| 2 | + | ||
| 3 | +import torch | ||
| 4 | +from torchvision import datasets, transforms | ||
| 5 | +import torch.utils.data as data_utils | ||
| 6 | + | ||
| 7 | +import params | ||
| 8 | + | ||
| 9 | + | ||
| 10 | +def get_mnist(train,adp = False,size = 0): | ||
| 11 | + """Get MNIST dataset loader.""" | ||
| 12 | + # image pre-processing | ||
| 13 | + pre_process = transforms.Compose([transforms.Resize(params.image_size), | ||
| 14 | + transforms.ToTensor(), | ||
| 15 | +# transforms.Normalize((0.5),(0.5)), | ||
| 16 | + transforms.Lambda(lambda x: x.repeat(3, 1, 1)), | ||
| 17 | + | ||
| 18 | + | ||
| 19 | + ]) | ||
| 20 | + | ||
| 21 | + | ||
| 22 | + | ||
| 23 | + | ||
| 24 | + # dataset and data loader | ||
| 25 | + mnist_dataset = datasets.MNIST(root=params.mnist_dataset_root, | ||
| 26 | + train=train, | ||
| 27 | + transform=pre_process, | ||
| 28 | + | ||
| 29 | + download=True) | ||
| 30 | + if train: | ||
| 31 | + # perm = torch.randperm(len(mnist_dataset)) | ||
| 32 | + # indices = perm[:10000] | ||
| 33 | + mnist_dataset,_ = data_utils.random_split(mnist_dataset, [size,len(mnist_dataset)-size]) | ||
| 34 | + # size = len(mnist_dataset) | ||
| 35 | + # train, valid = data_utils.random_split(mnist_dataset,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)]) | ||
| 36 | + # train_loader = torch.utils.data.DataLoader( | ||
| 37 | + # dataset=train, | ||
| 38 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 39 | + # shuffle=True, | ||
| 40 | + # drop_last=True) | ||
| 41 | + # valid_loader = torch.utils.data.DataLoader( | ||
| 42 | + # dataset=valid, | ||
| 43 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 44 | + # shuffle=True, | ||
| 45 | + # drop_last=True) | ||
| 46 | + | ||
| 47 | + # return train_loader,valid_loader | ||
| 48 | + | ||
| 49 | + mnist_data_loader = torch.utils.data.DataLoader( | ||
| 50 | + dataset=mnist_dataset, | ||
| 51 | + batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 52 | + shuffle=True, | ||
| 53 | + drop_last=True) | ||
| 54 | + return mnist_data_loader | ||
| ... | \ No newline at end of file | ||
source code/adda_mixup/dataset/mnist_m.py
0 → 100644
| 1 | +import torch.utils.data as data | ||
| 2 | +from PIL import Image | ||
| 3 | +import os | ||
| 4 | +import params | ||
| 5 | +from torchvision import transforms | ||
| 6 | +import torch | ||
| 7 | + | ||
| 8 | +import torch.utils.data as data_utils | ||
| 9 | + | ||
| 10 | + | ||
| 11 | +class GetLoader(data.Dataset): | ||
| 12 | + def __init__(self, data_root, data_list, transform=None): | ||
| 13 | + self.root = data_root | ||
| 14 | + self.transform = transform | ||
| 15 | + | ||
| 16 | + f = open(data_list, 'r') | ||
| 17 | + data_list = f.readlines() | ||
| 18 | + f.close() | ||
| 19 | + | ||
| 20 | + self.n_data = len(data_list) | ||
| 21 | + | ||
| 22 | + self.img_paths = [] | ||
| 23 | + self.img_labels = [] | ||
| 24 | + | ||
| 25 | + for data_ in data_list: | ||
| 26 | + self.img_paths.append(data_[:-3]) | ||
| 27 | + self.img_labels.append(data_[-2]) | ||
| 28 | + | ||
| 29 | + def __getitem__(self, item): | ||
| 30 | + img_paths, labels = self.img_paths[item], self.img_labels[item] | ||
| 31 | + imgs = Image.open(os.path.join(self.root, img_paths)).convert('RGB') | ||
| 32 | + | ||
| 33 | + if self.transform is not None: | ||
| 34 | + imgs = self.transform(imgs) | ||
| 35 | + labels = int(labels) | ||
| 36 | + | ||
| 37 | + return imgs, labels | ||
| 38 | + | ||
| 39 | + def __len__(self): | ||
| 40 | + return self.n_data | ||
| 41 | + | ||
| 42 | + | ||
| 43 | +def get_mnist_m(train,adp=False,size= 0 ): | ||
| 44 | + | ||
| 45 | + if train == True: | ||
| 46 | + mode = 'train' | ||
| 47 | + else: | ||
| 48 | + mode = 'test' | ||
| 49 | + | ||
| 50 | + train_list = os.path.join(params.mnist_m_dataset_root, 'mnist_m_{}_labels.txt'.format(mode)) | ||
| 51 | + # image pre-processing | ||
| 52 | + pre_process = transforms.Compose([ | ||
| 53 | + transforms.Resize(params.image_size), | ||
| 54 | + # transforms.Grayscale(3), | ||
| 55 | + | ||
| 56 | + transforms.ToTensor(), | ||
| 57 | + | ||
| 58 | +# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
| 59 | +# transforms.Grayscale(1), | ||
| 60 | + ] | ||
| 61 | + ) | ||
| 62 | + | ||
| 63 | + dataset_target = GetLoader( | ||
| 64 | + data_root=os.path.join(params.mnist_m_dataset_root, 'mnist_m_{}'.format(mode)), | ||
| 65 | + data_list=train_list, | ||
| 66 | + transform=pre_process) | ||
| 67 | + | ||
| 68 | + if train: | ||
| 69 | + # perm = torch.randperm(len(dataset_target)) | ||
| 70 | + # indices = perm[:10000] | ||
| 71 | + dataset_target,_ = data_utils.random_split(dataset_target, [size,len(dataset_target)-size]) | ||
| 72 | + # size = len(dataset_target) | ||
| 73 | + # train, valid = data_utils.random_split(dataset_target,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)]) | ||
| 74 | + # train_loader = torch.utils.data.DataLoader( | ||
| 75 | + # dataset=train, | ||
| 76 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 77 | + # shuffle=True, | ||
| 78 | + # drop_last=True) | ||
| 79 | + # valid_loader = torch.utils.data.DataLoader( | ||
| 80 | + # dataset=valid, | ||
| 81 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 82 | + # shuffle=True, | ||
| 83 | + # drop_last=True) | ||
| 84 | + | ||
| 85 | + # return train_loader,valid_loader | ||
| 86 | + | ||
| 87 | + | ||
| 88 | + dataloader = torch.utils.data.DataLoader( | ||
| 89 | + dataset=dataset_target, | ||
| 90 | + batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 91 | + | ||
| 92 | + shuffle=True, | ||
| 93 | + drop_last=True) | ||
| 94 | + | ||
| 95 | + return dataloader | ||
| ... | \ No newline at end of file | ||
source code/adda_mixup/dataset/svhn.py
0 → 100644
| 1 | +import torch | ||
| 2 | +from torchvision import datasets, transforms | ||
| 3 | + | ||
| 4 | +import params | ||
| 5 | + | ||
| 6 | +import torch.utils.data as data_utils | ||
| 7 | + | ||
| 8 | +def get_svhn(train,adp=False,size=0): | ||
| 9 | + """Get SVHN dataset loader.""" | ||
| 10 | + # image pre-processing | ||
| 11 | + pre_process = transforms.Compose([ | ||
| 12 | + transforms.Resize(params.image_size), | ||
| 13 | + transforms.Grayscale(3), | ||
| 14 | + transforms.ToTensor(), | ||
| 15 | + # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), | ||
| 16 | + | ||
| 17 | + ]) | ||
| 18 | + | ||
| 19 | + | ||
| 20 | + # dataset and data loader | ||
| 21 | + svhn_dataset = datasets.SVHN(root=params.svhn_dataset_root, | ||
| 22 | + split='train' if train else 'test', | ||
| 23 | + transform=pre_process, | ||
| 24 | + download=True) | ||
| 25 | + if train: | ||
| 26 | + # perm = torch.randperm(len(svhn_dataset)) | ||
| 27 | + # indices = perm[:10000] | ||
| 28 | + svhn_dataset,_ = data_utils.random_split(svhn_dataset, [size,len(svhn_dataset)-size]) | ||
| 29 | + # size = len(svhn_dataset) | ||
| 30 | + # train, valid = data_utils.random_split(svhn_dataset,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)]) | ||
| 31 | + | ||
| 32 | + # train_loader = torch.utils.data.DataLoader( | ||
| 33 | + # dataset=train, | ||
| 34 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 35 | + | ||
| 36 | + # shuffle=True, | ||
| 37 | + # drop_last=True) | ||
| 38 | + | ||
| 39 | + # valid_loader = torch.utils.data.DataLoader( | ||
| 40 | + # dataset=valid, | ||
| 41 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 42 | + | ||
| 43 | + # shuffle=True, | ||
| 44 | + # drop_last=True) | ||
| 45 | + # return train_loader,valid_loader | ||
| 46 | + | ||
| 47 | + svhn_data_loader = torch.utils.data.DataLoader( | ||
| 48 | + dataset=svhn_dataset, | ||
| 49 | + batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 50 | + | ||
| 51 | + shuffle=True, | ||
| 52 | + drop_last=True) | ||
| 53 | + | ||
| 54 | + return svhn_data_loader | ||
| ... | \ No newline at end of file | ||
source code/adda_mixup/dataset/usps.py
0 → 100644
| 1 | +import torch | ||
| 2 | +from torchvision import datasets, transforms | ||
| 3 | +from torch.utils.data.dataset import random_split | ||
| 4 | + | ||
| 5 | +import params | ||
| 6 | +import torch.utils.data as data_utils | ||
| 7 | + | ||
| 8 | + | ||
| 9 | +def get_usps(train,adp=False,size=0): | ||
| 10 | + """Get usps dataset loader.""" | ||
| 11 | + # image pre-processing | ||
| 12 | + pre_process = transforms.Compose([transforms.Resize(params.image_size), | ||
| 13 | + transforms.ToTensor(), | ||
| 14 | + # transforms.Normalize((0.5),(0.5)), | ||
| 15 | + transforms.Lambda(lambda x: x.repeat(3, 1, 1)), | ||
| 16 | + # transforms.Grayscale(1), | ||
| 17 | + | ||
| 18 | + | ||
| 19 | + | ||
| 20 | + ]) | ||
| 21 | + | ||
| 22 | + | ||
| 23 | + # dataset and data loader | ||
| 24 | + usps_dataset = datasets.USPS(root=params.usps_dataset_root, | ||
| 25 | + train=train, | ||
| 26 | + transform=pre_process, | ||
| 27 | + download=True) | ||
| 28 | + | ||
| 29 | + | ||
| 30 | + if train: | ||
| 31 | + usps_dataset, _ = data_utils.random_split(usps_dataset, [size,len(usps_dataset)-size]) | ||
| 32 | + # size = len(usps_dataset) | ||
| 33 | + # train, valid = data_utils.random_split(usps_dataset,[size-int(size*params.train_val_ratio),int(size*params.train_val_ratio)]) | ||
| 34 | + # train_loader = torch.utils.data.DataLoader( | ||
| 35 | + # dataset=train, | ||
| 36 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 37 | + # shuffle=True, | ||
| 38 | + # drop_last=True) | ||
| 39 | + # valid_loader = torch.utils.data.DataLoader( | ||
| 40 | + # dataset=valid, | ||
| 41 | + # batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 42 | + # shuffle=True, | ||
| 43 | + # drop_last=True) | ||
| 44 | + # return train_loader,valid_loader | ||
| 45 | + | ||
| 46 | + usps_data_loader = torch.utils.data.DataLoader( | ||
| 47 | + dataset=usps_dataset, | ||
| 48 | + batch_size= params.adp_batch_size if adp else params.batch_size, | ||
| 49 | + | ||
| 50 | + shuffle=True, | ||
| 51 | + drop_last=True) | ||
| 52 | + return usps_data_loader | ||
| ... | \ No newline at end of file | ||
source code/adda_mixup/main.py
0 → 100644
| 1 | + | ||
| 2 | +import params | ||
| 3 | +from utils import get_data_loader, init_model, init_random_seed,mixup_data | ||
| 4 | +from core import pretrain , adapt , test,mixup | ||
| 5 | +import torch | ||
| 6 | +from models.models import * | ||
| 7 | +import numpy as np | ||
| 8 | +import sys | ||
| 9 | + | ||
| 10 | + | ||
| 11 | + | ||
| 12 | +if __name__ == '__main__': | ||
| 13 | + # init random seed | ||
| 14 | + init_random_seed(params.manual_seed) | ||
| 15 | + print(f"Is cuda available? {torch.cuda.is_available()}") | ||
| 16 | + | ||
| 17 | + | ||
| 18 | + | ||
| 19 | + #set loader | ||
| 20 | + print("src data loader....") | ||
| 21 | + src_data_loader = get_data_loader(params.src_dataset,adp=False,size = 10000) | ||
| 22 | + src_data_loader_eval = get_data_loader(params.src_dataset,train=False) | ||
| 23 | + print("tgt data loader....") | ||
| 24 | + tgt_data_loader = get_data_loader(params.tgt_dataset,adp=False,size = 50000) | ||
| 25 | + tgt_data_loader_eval = get_data_loader(params.tgt_dataset, train=False) | ||
| 26 | + print(f"src data size : {len(src_data_loader.dataset)}") | ||
| 27 | + print(f"tgt data size : {len(tgt_data_loader.dataset)}") | ||
| 28 | + | ||
| 29 | + | ||
| 30 | + print("start training") | ||
| 31 | + source_cnn = CNN(in_channels=3).to("cuda") | ||
| 32 | + target_cnn = CNN(in_channels=3, target=True).to("cuda") | ||
| 33 | + discriminator = Discriminator().to("cuda") | ||
| 34 | + | ||
| 35 | + source_cnn = mixup.train_src(source_cnn, src_data_loader,tgt_data_loader,None) | ||
| 36 | + # source_cnn.load_state_dict(torch.load("./generated/models/our-source_cnn-final.pt")) | ||
| 37 | + | ||
| 38 | + test.eval_tgt(source_cnn, tgt_data_loader_eval) | ||
| 39 | + | ||
| 40 | + target_cnn.load_state_dict(source_cnn.state_dict()) | ||
| 41 | + | ||
| 42 | + tgt_encoder = mixup.train_tgt(source_cnn, target_cnn, discriminator, | ||
| 43 | + src_data_loader,tgt_data_loader,None) | ||
| 44 | + | ||
| 45 | + | ||
| 46 | + | ||
| 47 | + print("=== Evaluating classifier for encoded target domain ===") | ||
| 48 | + print(f"mixup : {params.lammax} {params.src_dataset} -> {params.tgt_dataset} ") | ||
| 49 | + print("Eval | source_cnn | src_data_loader_eval") | ||
| 50 | + test.eval_tgt(source_cnn, src_data_loader_eval) | ||
| 51 | + print(">>> Eval | source_cnn | tgt_data_loader_eval <<<") | ||
| 52 | + test.eval_tgt(source_cnn, tgt_data_loader_eval) | ||
| 53 | + print(">>> Eval | target_cnn | tgt_data_loader_eval <<<") | ||
| 54 | + test.eval_tgt(target_cnn, tgt_data_loader_eval) | ||
| 55 | + | ||
| 56 | + | ||
| 57 | + | ||
| 58 | + |
source code/adda_mixup/models/__init__.py
0 → 100644
File mode changed
source code/adda_mixup/models/models.py
0 → 100644
| 1 | +from torch import nn | ||
| 2 | +import torch.nn.functional as F | ||
| 3 | +import params | ||
| 4 | + | ||
| 5 | +class Encoder(nn.Module): | ||
| 6 | + def __init__(self, in_channels=1, h=256, dropout=0.5): | ||
| 7 | + super(Encoder, self).__init__() | ||
| 8 | + self.conv1 = nn.Conv2d(in_channels, 20, kernel_size=5, stride=1) | ||
| 9 | + self.conv2 = nn.Conv2d(20, 50, kernel_size=5, stride=1) | ||
| 10 | + self.bn1 = nn.BatchNorm2d(20) | ||
| 11 | + self.bn2 = nn.BatchNorm2d(50) | ||
| 12 | + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) | ||
| 13 | + self.relu = nn.ReLU() | ||
| 14 | + self.dropout =nn.Dropout2d(p= dropout) | ||
| 15 | + # self.dropout = nn.Dropout(dropout) | ||
| 16 | + self.fc1 = nn.Linear(800, 500) | ||
| 17 | + | ||
| 18 | + # for m in self.modules(): | ||
| 19 | + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | ||
| 20 | + # nn.init.kaiming_normal_(m.weight) | ||
| 21 | + | ||
| 22 | + def forward(self, x): | ||
| 23 | + bs = x.size(0) | ||
| 24 | + x = self.pool(self.relu(self.bn1(self.conv1(x)))) | ||
| 25 | + x = self.pool(self.relu(self.bn2(self.dropout(self.conv2(x))))) | ||
| 26 | + x = x.view(bs, -1) | ||
| 27 | + # x = self.dropout(x) | ||
| 28 | + x = self.fc1(x) | ||
| 29 | + return x | ||
| 30 | + | ||
| 31 | + | ||
| 32 | +class Classifier(nn.Module): | ||
| 33 | + def __init__(self, n_classes, dropout=0.5): | ||
| 34 | + super(Classifier, self).__init__() | ||
| 35 | + self.l1 = nn.Linear(500, n_classes) | ||
| 36 | + | ||
| 37 | + # for m in self.modules(): | ||
| 38 | + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | ||
| 39 | + # nn.init.kaiming_normal_(m.weight) | ||
| 40 | + | ||
| 41 | + def forward(self, x): | ||
| 42 | + x = self.l1(x) | ||
| 43 | + return x | ||
| 44 | + | ||
| 45 | + | ||
| 46 | +class CNN(nn.Module): | ||
| 47 | + def __init__(self, in_channels=1, n_classes=10, target=False): | ||
| 48 | + super(CNN, self).__init__() | ||
| 49 | + self.encoder = Encoder(in_channels=in_channels) | ||
| 50 | + self.classifier = Classifier(n_classes) | ||
| 51 | + if target: | ||
| 52 | + for param in self.classifier.parameters(): | ||
| 53 | + param.requires_grad = False | ||
| 54 | + | ||
| 55 | + def forward(self, x): | ||
| 56 | + x = self.encoder(x) | ||
| 57 | + x = self.classifier(x) | ||
| 58 | + return x | ||
| 59 | + | ||
| 60 | + | ||
| 61 | +class Discriminator(nn.Module): | ||
| 62 | + def __init__(self, h=500): | ||
| 63 | + super(Discriminator, self).__init__() | ||
| 64 | + self.l1 = nn.Linear(500, h) | ||
| 65 | + self.l2 = nn.Linear(h, h) | ||
| 66 | + self.l3 = nn.Linear(h, 2) | ||
| 67 | + # self.slope =params.slope | ||
| 68 | + | ||
| 69 | + self.relu = nn.ReLU() | ||
| 70 | + | ||
| 71 | + # for m in self.modules(): | ||
| 72 | + # if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): | ||
| 73 | + # nn.init.kaiming_normal_(m.weight) | ||
| 74 | + | ||
| 75 | + def forward(self, x): | ||
| 76 | + x = self.relu(self.l1(x)) | ||
| 77 | + x = self.relu(self.l2(x)) | ||
| 78 | + x = self.l3(x) | ||
| 79 | + return x |
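A quick shape check for the networks above (a standalone sketch, not part of the repo): with `params.image_size = 28`, a 3x28x28 input gives 24x24 then 12x12 after the first conv/pool stage and 8x8 then 4x4 after the second, so the flattened feature size is 50*4*4 = 800, matching `fc1`.

<pre>
<code>
# Sanity-check the tensor shapes of Encoder/Classifier/Discriminator (CPU is fine here).
import torch
from models.models import CNN, Discriminator

x = torch.randn(4, 3, 28, 28)   # dummy batch of four 3-channel 28x28 images
cnn = CNN(in_channels=3)
disc = Discriminator()

feats = cnn.encoder(x)          # torch.Size([4, 500])
logits = cnn.classifier(feats)  # torch.Size([4, 10])
domain = disc(feats)            # torch.Size([4, 2])
print(feats.shape, logits.shape, domain.shape)
</code>
</pre>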
source code/adda_mixup/params.py
0 → 100644
| 1 | +import torch | ||
| 2 | + | ||
| 3 | +# params for dataset and data loader | ||
| 4 | +data_root = "data" | ||
| 5 | +image_size = 28 | ||
| 6 | + | ||
| 7 | +#restore | ||
| 8 | +model_root = 'generated\\models' | ||
| 9 | + | ||
| 10 | + | ||
| 11 | +# params for target dataset | ||
| 12 | +# 'mnist_m', 'usps', 'svhn' "custom" | ||
| 13 | + | ||
| 14 | +#dataset root | ||
| 15 | +mnist_dataset_root = data_root | ||
| 16 | +mnist_m_dataset_root = data_root+'\\mnist_m' | ||
| 17 | +usps_dataset_root = data_root+'\\usps' | ||
| 18 | +svhn_dataset_root = data_root+'\\svhn' | ||
| 19 | +custom_dataset_root = data_root+'\\custom\\' | ||
| 20 | + | ||
| 21 | +# params for training network | ||
| 22 | +num_gpu = 1 | ||
| 23 | + | ||
| 24 | +log_step_pre = 10 | ||
| 25 | +log_step = 10 | ||
| 26 | +eval_step_pre = 10 | ||
| 27 | + | ||
| 28 | +##epoch | ||
| 29 | +save_step_pre = 100 | ||
| 30 | +manual_seed = 1234 | ||
| 31 | + | ||
| 32 | +d_input_dims = 500 | ||
| 33 | +d_hidden_dims = 500 | ||
| 34 | +d_output_dims = 2 | ||
| 35 | +d_model_restore = 'generated\\models\\ADDA-critic-final.pt' | ||
| 36 | + | ||
| 37 | +## source / target datasets | ||
| 38 | +src_dataset = 'custom' | ||
| 39 | +tgt_dataset = 'custom' | ||
| 40 | + | ||
| 41 | + | ||
| 42 | +# params for optimizing models | ||
| 43 | +# # lam 0.3 | ||
| 44 | +#mnist -> custom | ||
| 45 | +num_epochs_pre = 20 | ||
| 46 | +num_epochs = 50 | ||
| 47 | +batch_size = 128 | ||
| 48 | +adp_batch_size = 128 | ||
| 49 | +pre_c_learning_rate = 2e-4 | ||
| 50 | +adp_c_learning_rate = 1e-4 | ||
| 51 | +d_learning_rate = 1e-4 | ||
| 52 | +beta1 = 0.5 | ||
| 53 | +beta2 = 0.999 | ||
| 54 | +weight_decay = 0 | ||
| 55 | + | ||
| 56 | + | ||
| 57 | +# #usps -> custom | ||
| 58 | +# #lam 0.1 | ||
| 59 | +# num_epochs_pre = 5 | ||
| 60 | +# num_epochs = 20 | ||
| 61 | +# batch_size = 256 | ||
| 62 | +# pre_c_learning_rate = 1e-4 | ||
| 63 | +# adp_c_learning_rate = 2e-5 | ||
| 64 | +# d_learning_rate = 1e-5 | ||
| 65 | +# beta1 = 0.5 | ||
| 66 | +# beta2 = 0.999 | ||
| 67 | +# weight_decay = 2e-4 | ||
| 68 | + | ||
| 69 | +# #mnist_m -> custom | ||
| 70 | +# #lam 0.1 | ||
| 71 | +# num_epochs_pre = 30 | ||
| 72 | +# num_epochs = 50 | ||
| 73 | +# batch_size = 256 | ||
| 74 | +# adp_batch_size = 256 | ||
| 75 | +# pre_c_learning_rate = 1e-3 | ||
| 76 | +# adp_c_learning_rate = 1e-4 | ||
| 77 | +# d_learning_rate = 1e-4 | ||
| 78 | +# beta1 = 0.5 | ||
| 79 | +# beta2 = 0.999 | ||
| 80 | +# weight_decay = 2e-4 | ||
| 81 | + | ||
| 82 | +# # params for optimizing models | ||
| 83 | +#lam 0.3 | ||
| 84 | +# #mnist -> mnist_m | ||
| 85 | +# num_epochs_pre = 50 | ||
| 86 | +# num_epochs = 100 | ||
| 87 | +# batch_size = 256 | ||
| 88 | +# adp_batch_size = 256 | ||
| 89 | +# pre_c_learning_rate = 2e-4 | ||
| 90 | +# adp_c_learning_rate = 2e-4 | ||
| 91 | +# d_learning_rate = 2e-4 | ||
| 92 | +# beta1 = 0.5 | ||
| 93 | +# beta2 = 0.999 | ||
| 94 | +# weight_decay = 0 | ||
| 95 | + | ||
| 96 | +# # source 10000 target 50000 | ||
| 97 | +# # params for optimizing models | ||
| 98 | +# #svhn -> mnist | ||
| 99 | +# num_epochs_pre = 20 | ||
| 100 | +# num_epochs = 30 | ||
| 101 | +# batch_size = 128 | ||
| 102 | +# adp_batch_size = 128 | ||
| 103 | +# pre_c_learning_rate = 2e-4 | ||
| 104 | +# adp_c_learning_rate = 1e-4 | ||
| 105 | +# d_learning_rate = 1e-4 | ||
| 106 | +# beta1 = 0.5 | ||
| 107 | +# beta2 = 0.999 | ||
| 108 | +# weight_decay = 2.5e-4 | ||
| 109 | + | ||
| 110 | +# # mnist->usps | ||
| 111 | +# num_epochs_pre = 50 | ||
| 112 | +# num_epochs = 100 | ||
| 113 | +# batch_size = 256 | ||
| 114 | +# adp_batch_size = 256 | ||
| 115 | +# pre_c_learning_rate = 2e-4 | ||
| 116 | +# adp_c_learning_rate = 2e-4 | ||
| 117 | +# d_learning_rate = 2e-4 | ||
| 118 | +# beta1 = 0.5 | ||
| 119 | +# beta2 = 0.999 | ||
| 120 | +# weight_decay =0 | ||
| 121 | + | ||
| 122 | + | ||
| 123 | +# # usps->mnist | ||
| 124 | +# num_epochs_pre = 50 | ||
| 125 | +# num_epochs = 100 | ||
| 126 | +# batch_size = 256 | ||
| 127 | +# pre_c_learning_rate = 2e-4 | ||
| 128 | +# adp_c_learning_rate = 2e-4 | ||
| 129 | +# d_learning_rate =2e-4 | ||
| 130 | +# beta1 = 0.5 | ||
| 131 | +# beta2 = 0.999 | ||
| 132 | +# weight_decay =0 | ||
| 133 | + | ||
| 134 | + | ||
| 135 | + | ||
| 136 | +# | ||
| 137 | +use_load = False | ||
| 138 | +train =False | ||
| 139 | + | ||
| 140 | +#ratio mix target | ||
| 141 | +lammax = 0.0 | ||
| 142 | +lammin = 0.0 | ||
| 143 | + | ||
| 144 | + | ||
| 145 | +labelsmoothing = False | ||
| 146 | +smoothing = 0.3 | ||
| ... | \ No newline at end of file | ||
source code/adda_mixup/utils.py
0 → 100644
| 1 | + | ||
| 2 | + | ||
| 3 | +import os | ||
| 4 | +import random | ||
| 5 | +import torch | ||
| 6 | +import torch.backends.cudnn as cudnn | ||
| 7 | +from torch.autograd import Variable | ||
| 8 | +import params | ||
| 9 | +from dataset import get_mnist, get_mnist_m, get_usps,get_svhn,get_custom | ||
| 10 | +import numpy as np | ||
| 11 | +import itertools | ||
| 12 | +import torch.nn.functional as F | ||
| 13 | +import params | ||
| 14 | + | ||
| 15 | +def make_cuda(tensor): | ||
| 16 | + """Use CUDA if it's available.""" | ||
| 17 | + if torch.cuda.is_available(): | ||
| 18 | + tensor = tensor.cuda() | ||
| 19 | + return tensor | ||
| 20 | + | ||
| 21 | + | ||
| 22 | +def denormalize(x, std, mean): | ||
| 23 | + """Invert normalization, and then convert array into image.""" | ||
| 24 | + out = x * std + mean | ||
| 25 | + return out.clamp(0, 1) | ||
| 26 | + | ||
| 27 | + | ||
| 28 | +def init_weights(layer): | ||
| 29 | + """Init weights for layers w.r.t. the original paper.""" | ||
| 30 | + layer_name = layer.__class__.__name__ | ||
| 31 | + if layer_name.find("Conv") != -1: | ||
| 32 | + layer.weight.data.normal_(0.0, 0.02) | ||
| 33 | + elif layer_name.find("BatchNorm") != -1: | ||
| 34 | + layer.weight.data.normal_(1.0, 0.02) | ||
| 35 | + layer.bias.data.fill_(0) | ||
| 36 | + | ||
| 37 | + | ||
| 38 | +def init_random_seed(manual_seed): | ||
| 39 | + """Init random seed.""" | ||
| 40 | + seed = None | ||
| 41 | + if manual_seed is None: | ||
| 42 | + seed = random.randint(1, 10000) | ||
| 43 | + else: | ||
| 44 | + seed = manual_seed | ||
| 45 | + #for REPRODUCIBILITY | ||
| 46 | + torch.backends.cudnn.deterministic = True | ||
| 47 | + torch.backends.cudnn.benchmark = False | ||
| 48 | + print("use random seed: {}".format(seed)) | ||
| 49 | + random.seed(seed) | ||
| 50 | + torch.manual_seed(seed) | ||
| 51 | + np.random.seed(seed) | ||
| 52 | + | ||
| 53 | + if torch.cuda.is_available(): | ||
| 54 | + torch.cuda.manual_seed_all(seed) | ||
| 55 | + | ||
| 56 | + | ||
| 57 | +def get_data_loader(name,train=True,adp=False,size = 0): | ||
| 58 | + """Get data loader by name.""" | ||
| 59 | + if name == "mnist": | ||
| 60 | + return get_mnist(train,adp,size) | ||
| 61 | + elif name == "mnist_m": | ||
| 62 | + return get_mnist_m(train,adp,size) | ||
| 63 | + elif name == "usps": | ||
| 64 | + return get_usps(train,adp,size) | ||
| 65 | + elif name == "svhn": | ||
| 66 | + return get_svhn(train,adp,size) | ||
| 67 | + elif name == "custom": | ||
| 68 | + return get_custom(train,adp,size) | ||
| 69 | + | ||
| 70 | +def init_model(net, restore=None): | ||
| 71 | + """Init models with cuda and weights.""" | ||
| 72 | + # init weights of model | ||
| 73 | + # net.apply(init_weights) | ||
| 74 | + | ||
| 75 | + print(f'restore file : {restore}') | ||
| 76 | + # restore model weights | ||
| 77 | + if restore is not None and os.path.exists(restore): | ||
| 78 | + net.load_state_dict(torch.load(restore)) | ||
| 79 | + net.restored = True | ||
| 80 | + print("Restore model from: {}".format(os.path.abspath(restore))) | ||
| 81 | + | ||
| 82 | + # check if cuda is available | ||
| 83 | + if torch.cuda.is_available(): | ||
| 84 | + net.cuda() | ||
| 85 | + | ||
| 86 | + return net | ||
| 87 | + | ||
| 88 | +def save_model(net, filename): | ||
| 89 | + """Save trained model.""" | ||
| 90 | + if not os.path.exists(params.model_root): | ||
| 91 | + os.makedirs(params.model_root) | ||
| 92 | + torch.save(net.state_dict(), | ||
| 93 | + os.path.join(params.model_root, filename)) | ||
| 94 | + print("save pretrained model to: {}".format(os.path.join(params.model_root, | ||
| 95 | + filename))) | ||
| 96 | + | ||
| 97 | +class LabelSmoothingCrossEntropy(torch.nn.Module): | ||
| 98 | + def __init__(self,smoothing): | ||
| 99 | + super(LabelSmoothingCrossEntropy, self).__init__() | ||
| 100 | + self.smoothing = smoothing | ||
| 101 | + def forward(self, y, targets): | ||
| 102 | + confidence = 1. - self.smoothing | ||
| 103 | + log_probs = F.log_softmax(y, dim=-1) # log-probabilities of the predictions | ||
| 104 | + true_probs = torch.zeros_like(log_probs) | ||
| 105 | + true_probs.fill_(self.smoothing / (y.shape[1] - 1)) | ||
| 106 | + true_probs.scatter_(1, targets.data.unsqueeze(1), confidence) # set the true-class probability to confidence | ||
| 107 | + return torch.mean(torch.sum(true_probs * -log_probs, dim=-1)) # negative log likelihood | ||
| 108 | + | ||
| 109 | +#mixup only data, not label | ||
| 110 | +def mixup_data(source, target): | ||
| 111 | + lam_max = params.lammax | ||
| 112 | + lam_min = params.lammin | ||
| 113 | + lam = (lam_max - lam_min) * torch.rand((1)) + lam_min | ||
| 114 | + lam = lam.cuda() | ||
| 115 | + target = target.cuda() | ||
| 116 | + mixed_source = (1 - lam) * source + lam * target | ||
| 117 | + | ||
| 118 | + | ||
| 119 | + return mixed_source, lam | ||
| 120 | + | ||
| 121 | + | ||
| 122 | + | ||
| 123 | + | ||
| 124 | +def mixup_criterion(criterion, pred, y_a, y_b, lam): | ||
| 125 | + return 0.9* criterion(pred, y_a) + 0.1 * criterion(pred, y_b) | ||
| ... | \ No newline at end of file | ||
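An illustrative use of the two helpers above, `mixup_data` and `LabelSmoothingCrossEntropy` (a sketch with dummy tensors; it assumes a CUDA device, since `mixup_data` moves its operands to the GPU):

<pre>
<code>
# Dummy data only: the logits stand in for model(mixed); shapes follow the 3x28x28 setting.
import torch
from utils import mixup_data, LabelSmoothingCrossEntropy

src = torch.randn(8, 3, 28, 28).cuda()
tgt = torch.randn(8, 3, 28, 28).cuda()
labels = torch.randint(0, 10, (8,)).cuda()

mixed, lam = mixup_data(src, tgt)          # images are mixed, labels are not
logits = torch.randn(8, 10).cuda()         # stand-in for model(mixed)
criterion = LabelSmoothingCrossEntropy(smoothing=0.3)
loss = criterion(logits, labels)
print(lam.item(), loss.item())
</code>
</pre>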