이 글은 필자가 주재걸 교수님의 2018-1 Image Generation and Translation 강의를 듣고 추가 블로그 참고1, 참고2를 읽으면서 정리차 작성한 글입니다.
강의를 듣게 된 경위: 이용해야 할 개념에서 Image와 같이 2D 개념에 적용한 예시가 없어서 이를 2D에 적용하면서 나타낼 수 있는 문제점들이나 인사이트들을 다시 확용하고자 image generation task에 연결하고자 듣게 되었다.
🎈 Conditional Variational AutoEncoder (CVAE)
- VAE + Condition: VAE 구조에서 Label 정보를 추가해서 더 높은 정확도를 제공하는 VAE를 제공한다.
- Model structure
- VAE구조에 시작부분, Representation 부분에 정보를 넣는 구조를 말한다.
🎈 Conditional Variational AutoEncoder (CVAE) 코드 구현 Github Notebook
Conditional CVAE는 CVAE 구조에 CNN 구조를 덧입힌 구조를 말한다. 본 코드는 MNIST 이미지 데이터셋에 CVAE 구조를 적용한 예시를 보겠습니다.
우선 MNIST 데이터셋을 불러옵니다.
# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
bs = 100
# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)
# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)
CVAE 클래스를 다음과 같이 작성합니다.
- MNIST 의 28x28 픽셀이 하나하나 Linear 하게 들어가서 x_dim이 784인 점
- Label이 시작구조에 추가되므로
- self.fc1 = nn.Linear(x_dim + c_dim, h_dim)
- encoder 함수에서 torch.cat([x,c],1)이 적용됩니다.
- Label이 representation 구조에 추가되므로
- self.fc4 = nn.Linear(z_dim + c_dim, h_dim2)
- decoder 함수에서 torch.cat([z,c],1)이 적용됩니다.
class CVAE(nn.Module):
def __init__(self, x_dim, h_dim1, h_dim2, z_dim, c_dim):
super(CVAE, self).__init__()
# encoder part
self.fc1 = nn.Linear(x_dim + c_dim, h_dim1)
self.fc2 = nn.Linear(h_dim1, h_dim2)
self.fc31 = nn.Linear(h_dim2, z_dim)
self.fc32 = nn.Linear(h_dim2, z_dim)
# decoder part
self.fc4 = nn.Linear(z_dim + c_dim, h_dim2)
self.fc5 = nn.Linear(h_dim2, h_dim1)
self.fc6 = nn.Linear(h_dim1, x_dim)
def encoder(self, x, c):
concat_input = torch.cat([x, c], 1)
h = F.relu(self.fc1(concat_input))
h = F.relu(self.fc2(h))
return self.fc31(h), self.fc32(h)
def sampling(self, mu, log_var):
std = torch.exp(0.5*log_var)
eps = torch.randn_like(std)
return eps.mul(std).add(mu) # return z sample
def decoder(self, z, c):
concat_input = torch.cat([z, c], 1)
h = F.relu(self.fc4(concat_input))
h = F.relu(self.fc5(h))
return F.sigmoid(self.fc6(h))
def forward(self, x, c):
mu, log_var = self.encoder(x.view(-1, 784), c)
z = self.sampling(mu, log_var)
return self.decoder(z, c), mu, log_var
# build model
cond_dim = train_loader.dataset.train_labels.unique().size(0)
cvae = CVAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2, c_dim=cond_dim)
#if torch.cuda.is_available():
# cvae.cuda()
다음과 같이 모델이 적용디 되면서 다음과 같이 CVAE 구조는 적용됩니다.
CVAE(
(fc1): Linear(in_features=794, out_features=512, bias=True)
(fc2): Linear(in_features=512, out_features=256, bias=True)
(fc31): Linear(in_features=256, out_features=2, bias=True)
(fc32): Linear(in_features=256, out_features=2, bias=True)
(fc4): Linear(in_features=12, out_features=256, bias=True)
(fc5): Linear(in_features=256, out_features=512, bias=True)
(fc6): Linear(in_features=512, out_features=784, bias=True)
)
'AI, Deep Learning Basics > Computer Vision' 카테고리의 다른 글
[기초] 이미지 classification 기본 모델: VGG, GoogLeNet, ResNet (0) | 2022.01.16 |
---|---|
[Generative Model] Variational AutoEncoder 3. Variational Inference (0) | 2022.01.08 |
[Generative Model] Variational AutoEncoder 1. Basic: AE, DAE, VAE (0) | 2021.12.06 |
[Computer Vision] Image, Video 분야 subtask 및 데이터 종류 정리 (0) | 2021.12.01 |
[Basic] 3x3 Conv, 1x1 Conv 하는 이유(FCN vs. FC Layer vs. FPN) (0) | 2021.11.20 |