[Generative Model] Variational AutoEncoder 2. Application: Conditional VAE, Convolutional CVAE 코드 구현
AI, Deep Learning Basics/Computer Vision

[Generative Model] Variational AutoEncoder 2. Application: Conditional VAE, Convolutional CVAE 코드 구현

 

이 글은 필자가 주재걸 교수님의 2018-1 Image Generation and Translation 강의를 듣고 추가 블로그 참고1, 참고2를 읽으면서 정리차 작성한 글입니다. 
강의를 듣게 된 경위: 이용해야 할 개념에서 Image와 같이 2D 개념에 적용한 예시가 없어서 이를 2D에 적용하면서 나타낼 수 있는 문제점들이나 인사이트들을 다시 확용하고자 image generation task에 연결하고자 듣게 되었다.

 

🎈 Conditional Variational AutoEncoder (CVAE)

  1. VAE + Condition: VAE 구조에서 Label 정보를 추가해서 더 높은 정확도를 제공하는 VAE를 제공한다.
  2. Model structure
    • VAE구조에 시작부분, Representation 부분에 정보를 넣는 구조를 말한다.

🎈 Conditional Variational AutoEncoder (CVAE) 코드 구현 Github Notebook

Conditional CVAE는 CVAE 구조에 CNN 구조를 덧입힌 구조를 말한다. 본 코드는 MNIST 이미지 데이터셋에 CVAE 구조를 적용한 예시를 보겠습니다.

우선 MNIST 데이터셋을 불러옵니다.

# prerequisites
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image

bs = 100
# MNIST Dataset
train_dataset = datasets.MNIST(root='./mnist_data/', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root='./mnist_data/', train=False, transform=transforms.ToTensor(), download=False)

# Data Loader (Input Pipeline)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=bs, shuffle=False)

CVAE 클래스를 다음과 같이 작성합니다.

  •  MNIST 의 28x28 픽셀이 하나하나 Linear 하게 들어가서 x_dim이 784인 점
  • Label이 시작구조에 추가되므로
    • self.fc1 = nn.Linear(x_dim + c_dim, h_dim)
    • encoder 함수에서 torch.cat([x,c],1)이 적용됩니다.
  • Label이 representation 구조에 추가되므로
    • self.fc4 = nn.Linear(z_dim + c_dim, h_dim2)
    • decoder 함수에서 torch.cat([z,c],1)이 적용됩니다.
class CVAE(nn.Module):
    def __init__(self, x_dim, h_dim1, h_dim2, z_dim, c_dim):
        super(CVAE, self).__init__()
        
        # encoder part
        self.fc1 = nn.Linear(x_dim + c_dim, h_dim1)
        self.fc2 = nn.Linear(h_dim1, h_dim2)
        self.fc31 = nn.Linear(h_dim2, z_dim)
        self.fc32 = nn.Linear(h_dim2, z_dim)
        # decoder part
        self.fc4 = nn.Linear(z_dim + c_dim, h_dim2)
        self.fc5 = nn.Linear(h_dim2, h_dim1)
        self.fc6 = nn.Linear(h_dim1, x_dim)
    
    def encoder(self, x, c):
        concat_input = torch.cat([x, c], 1)
        h = F.relu(self.fc1(concat_input))
        h = F.relu(self.fc2(h))
        return self.fc31(h), self.fc32(h)
    
    def sampling(self, mu, log_var):
        std = torch.exp(0.5*log_var)
        eps = torch.randn_like(std)
        return eps.mul(std).add(mu) # return z sample
    
    def decoder(self, z, c):
        concat_input = torch.cat([z, c], 1)
        h = F.relu(self.fc4(concat_input))
        h = F.relu(self.fc5(h))
        return F.sigmoid(self.fc6(h))
    
    def forward(self, x, c):
        mu, log_var = self.encoder(x.view(-1, 784), c)
        z = self.sampling(mu, log_var)
        return self.decoder(z, c), mu, log_var

# build model
cond_dim = train_loader.dataset.train_labels.unique().size(0)
cvae = CVAE(x_dim=784, h_dim1=512, h_dim2=256, z_dim=2, c_dim=cond_dim)
#if torch.cuda.is_available():
#    cvae.cuda()

 

다음과 같이 모델이 적용디 되면서 다음과 같이 CVAE 구조는 적용됩니다.

CVAE(
  (fc1): Linear(in_features=794, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc31): Linear(in_features=256, out_features=2, bias=True)
  (fc32): Linear(in_features=256, out_features=2, bias=True)
  (fc4): Linear(in_features=12, out_features=256, bias=True)
  (fc5): Linear(in_features=256, out_features=512, bias=True)
  (fc6): Linear(in_features=512, out_features=784, bias=True)
)