[Generative Model] Variational AutoEncoder 3. Variational Inference
AI, Deep Learning Basics/Computer Vision

[Generative Model] Variational AutoEncoder 3. Variational Inference

 

 

이 글은 VAE 모델을 학습하는 데 있어 이미지를 생성하는 Decoder 쪽에서 일어나는 Variational Inference 부분을 수식을 통해 자세히 이해하고자 만든 글입니다. 참고자료는 블로그, 블로그2, 블로그3 입니다. 
-220319. z에 관한 설명, 수식 전체적으로 latex 처리

🔦 시작하기 전 단어 정리

  1. Variational Inference/Varational Bayesain Method: 변분 추론
  2. KL-Divergence: 두 확률분포의 차이를 나타내는 지표
  3. ELBO(Evidence LowerBOund): Loss function을 추론하며 나타나는, 함수가 학습되면서 학습할 방향을 지정한다.

 

🧨 VAE Concept Recap

image from Lil'log

  1. Probabilistic Encoder $p(z|x) \approx q_\phi (z|x) $
  2. Latent Variable z를 추출, Sampling된 Latent vector z를 제공한다. $$z \sim N(\mu, \sigma)$$ $$z = \mu(x) + \sigma(x) \times \epsilon, \epsilon \sim \mathcal{N}(0, I)$$
  3. Probabilistic Decoder(Goal: get optimal $\theta$!) $$p(x|z) = \mathcal{N}(x|f_\mu(z), f_\theta(z)^2 \times I)$$ $$ \approx p_\theta(x|z)$$

 

🧨 ELBO Loss function

🔥 과정 1. 

$D_{KL}(q_\phi (z|x) || p(z|x))$ 관점에서 계산을 해본다면 다음이 유도된다.

https://en.wikipedia.org/wiki/Variational_autoencoder

 즉 , 

$$log(p_\theta(x)) -D_{KL}(q_\phi (z|x) || p_\theta(z|x)))= E_{z \sim q_\phi(z|x)} (log(p_\theta (x|z))) -D_{KL}(q_\phi (z|x) || p_\theta(z))$$

우리가 결국에 구하고자 하는 값은 MLE 관점에 따라  $\arg \max_{\theta} \log(p_\theta)$ 이므로 이를 바탕으로 전체 식을 정리해 본다면... 

$$log(p_\theta(x)) = D_{KL}(q_\phi (z|x) || p_\theta(z|x)) + E_{z \sim q_\phi(z|x)} (log(p_\theta (x|z))) -D_{KL}(q_\phi (z|x) || p_\theta(z)) $$

🔥 과정 2. 

이 때 각각을 본다면,

  • 높을수록 좋은 1.$\log(p_\theta(x))$
  • 높을수록 좋은(높아야 하는 값) 2. 추론된 $z$로 부터 $x를 추론하는 과정 $ E_{z \sim q_\phi(z|x)} (log(p_\theta (x|z)))$
  • 낮을수록 좋은 $ D_{KL}(q_\phi (z|x) || p_\theta(z) ) $
  • 구할 수 없는 값이지만 KL divergence의 특성에 의해 0보다 크거나 같다. $ D_{KL}(q_\phi (z|x) || p_\theta(z|x)) \geq 0$

결국에 

$$log(p_\theta(x)) \geq E_{z \sim q_\phi(z|x)} (log(p_\theta (x|z))) -D_{KL}(q_\phi (z|x) || p_\theta(z)) $$ 이므로

$$\uparrow [ E_{z \sim q_\phi(z|x)} (log(p_\theta (x|z))) -D_{KL}(q_\phi (z|x) || p_\theta(z)) ]$$

$$\downarrow [ -E_{z \sim q_\phi(z|x)} (log(p_\theta (x|z))) +D_{KL}(q_\phi (z|x) || p_\theta(z)) ]$$

🔥 과정 3. 

결국에 Loss function은, 

$$L_{\theta, \phi} = -E_{z \sim q_\phi(z|x)} (log(p_\theta (x|z))) +D_{KL}(q_\phi (z|x) || p_\theta(z)) = -p + \textrm{KL  Divergence} $$

이를 바탕으로 하는 $$\theta^*, \Phi^* = argmin_{\theta, \Phi} L_{\theta, \Phi} $$

을 구하는 것이 VAE의 최종 목표이다.

🧨 Methodologies of Variational Inference

  1. Variational inference with MC Sampling(MCMC)
  2. Stochastic Variational Inference
  3. Variational EM algorithm

🧨 Loss function 최종 정리

  • Reconstuction loss: AE의 입력 X와 출력 X'의 차이에 대한 loss, 하지만 Decoder(또는 Generator)에서 학습시 사용하는 z를 (z_mean을 기준으로 z_var 만큼) 랜덤하게 흔들어 준다는 것이 VAE의 특징이다. 데이터 x 하나하나 별로 나오는 각 z 하나하나를 학습 할때마다 (매 에폭 마다) 랜덤 하게 흔들어 준다. 그래서 z 공간을 더 많이 학습된 공간으로 채워 준다. 결국 Decoder 입장에서 z 공간이 훨씬 더 믿을 만한 (유용한) 공간이 된다. Encoder 입장에서는 특정한 하나의 값 z를 결과로 내지 않고, 확률적인 값 z를 내는 것과 같이 된다. 
  • Regularization lossx 하나하나에 대한 z 하나하나를 확률적으로 보고 학습 하겠다는 것이다. z_mean과 z_var을 합쳐서 생각하는 것은 어렵지만, 각각을 따로 생각 해서 해석 하면 그 내용은 간단하다. 
    • 결론: AE와 다른 점은 z 하나하나를 "모으고", "흔들어서", z 공간을 더 촘촘히 채워줘서 더 유용한 z 공간을 만든다는 것이다. 즉, regularization loss은 결국 우리가 다루기 편한 prior와 인코딩 네트워크로 추정하는 posterior 랑 같게 만드는 역할입니다. 학습이 다 되고 decoder만 사용하여 generation 할 때 prior로 할 수 있는건 이 때문이죠. 이는 기존 오토인코더에서는 불가능한 부분입니다. 기존 오토인코더에서 학습되는 유의미한 latent space (z space)를 컨트롤할 수 없기 때문이다.
    • z_mean 부분: z_mean^2을 낮추라는 것이다. 즉, z_mean을 원점(0)으로부터 멀리하지 말라는 뜻으로, 각 데이터의 z_mean들을 원점쪽으로 모아 주는 역할을 한다. 첨부한 그림으로 보면 각 보라색 선들을 최대한 짧게 해 보라는 것이다. reconstruction_loss를 낮추려면 z_mean들을 퍼뜨리는게 유리한데, 이 regularizaion이 있으므로 마음껏 퍼뜨릴 수는 없게 되는 것이다. 그래서 (1.에서 얘기한 것과 같이) z 공간이 더 촘촘하게 채워지게 된다.
    • z_var 부분: z_var - log_z_var 그래프를 그려보면 바로 이해할 수 있다. 다음과 같은 그래프 툴을 통해 그려볼 수 있다. x-ln(x)를 그려보면 된다. 여기서는 log가 아니라 ln으로 입력해야 한다. https://www.desmos.com/calculator
    • z_var - log_z_var는 z_var가 1일때 최소화 된다. 즉, z_var를 1에 가깝게, 1보다 너무 크지도 작지도 않게 하라는 뜻이다. 첨부한 그림으로 보면, 초록색 선들을 최대한 1에 맞추라는 것이다. 즉, 초록색 원영역을 적당한 크기로 유지하라는 것이다. 이 부분을 통해서 앞의 1. reconstruction_loss에서 "랜덤하게 흔드는" 것을 어느 정도는 적당히 흔들어 주라는 뜻이 된다. 그래야 거기서 얘기한 "공간 채우기" 효과를 낼 수 있기 때문이다
  • 추가 정리 (나중에 해석 필요)
    • Methodologies of Variational Inference 
      [유도방식2] log(P(X)) = ELBO + KL(Q(Z|X),P(Z|X)) 이고, Q(Z|X)와
      P(Z|X)를 같게 만드는게 목적이었는데 (KL최소화), training sample X에 대해서 log(P(X))는 고정값이니 KL최소화하는 것과 ELBO를 최대화하는 것은 같은 작업. log(P(X)) >= ELBO
      -ELBO = reconstruction error + KL(Q(Z|X), P(Z))
       
      1. Variational inference with MC Sampling
      2. Stochastic Variational Inference
      3. Variational EM algorithm
    • [유도방식1] log(P(X)) >= ELBO (jensen's inequality 써서 바로 유도)