네이버 부스트캠프 🔗/⭐주간 학습 정리

[네이버 부스트 캠프 AI Tech] VAE 직접 구현하기 by Pytorch

Dobby98 2023. 3. 24. 17:50

본 글은 네이버 부스트 캠프 AI Tech 기간동안

개인적으로 배운 내용들을 주단위로 정리한 글입니다

 

본 글의 내용은 새롭게 알게 된 내용을 중심으로 정리하였고

복습 중요도를 선정해서 정리하였습니다

 

이번주에 정해진 진도를 이미 목요일날 다 나갔기 때문에

오늘은 특별히 이번주 내용중 하나였던 VAE - Variational Auto-Encoder를 pytroch를 활용해서 구현하는 

글을 작성했습니다


✅ Week 3

목차

  1. 간략한 VAE 소개
  2. pytorch를 활용해 VAE 구현하기
    1. model
    2. train

✅ 1. 간략한 VAE 소개하기

이번주 Generative Model 시간에 간단하게 VAE, Variational Auto-Encoder를 알아보았다

 

 

[네이버 부스트 캠프 AI Tech] Generative Model

본 글은 네이버 부스트 캠프 AI Tech 기간동안 개인적으로 배운 내용들을 주단위로 정리한 글입니다 본 글의 내용은 새롭게 알게 된 내용을 중심으로 정리하였고 복습 중요도를 선정해서 정리하

eumgill98.tistory.com

전체의 흐름으로 설명을 하면

Input Image X의 특징을 Encoder로 추출하여 와 시그마을 추출하여서 Latent Vector Z에 담고 

이 Latent Vector Z를 통해 Decoder에서 비슷한 분포를 가진 새로운 Image를 생성하는 흐름으로 흘러간다

여기서 뮤와 시그마는 평균과 분산값이다

 

수학적인 부분은 논문리뷰나 개념을 다루는 글에서 자세히 다루어 보겠다

참고 :

 

(2) VAE (Variational AutoEncoder)

참고. VAE는 이해하기가 수학적으로 조금 까다롭습니다. (정말 어렵습니다 ㅠ) 수학적인 내용을 최소화하면서 설명하나, 기본적인 통계지식은 조금 필요합니다. (여기서는 KL Divergence를 사용하여

itrepo.tistory.com

🔥VAE의 구조

이제 위의 과정을 코드로 구현해보자


✅ 2. pytorch를 활용해 VAE 구현하기

전체 코드는 아래 Github 주소를 참고해주세요!!

 

GitHub - Eumgill98/DL: 딥러닝 모델 직접구현

딥러닝 모델 직접구현. Contribute to Eumgill98/DL development by creating an account on GitHub.

github.com

 


⭐Model

model.py 전체코드

import torch
from torch import nn
import torch.nn.functional as F

#과정 : Input -> Hidden dim -> mean, std -> Parametrization trick -> Decoder -> Output
class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        
        #encdoer
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        #decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)


        self.relu = nn.ReLU()
    
    def encoder(self, x):
        #q_phi(z|x)
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)

        return mu, sigma

    def decoder(self, z):
        h = self.relu(self.z_2hid(z))

        return torch.sigmoid(self.hid_2img(h)) # mnist date -> binary
    

    def forward(self, x):
        mu, sigma = self.encoder(x)
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma*epsilon
        x_reconstructed = self.decoder(z_reparametrized)

        return x_reconstructed, mu, sigma

 

 

나눠서 살펴보자

 

class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        
        #encdoer
        self.img_2hid = nn.Linear(input_dim, h_dim)
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        #decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)


        self.relu = nn.ReLU()

우선 VariationalAutoEncoder 라는 클래스를 선언해준다

그리고 init 부분에 우리가 사용할 변수들을 선언해 준다

 

encoder 부분

img_2hid : Encoder 첫 번째 hidden layer

hid_2mu : img_2hid를 input으로 받아서 mu를 output하는 레이어

hid_2sigma : img_2hid를 input으로 받아서 sigma를 output하는 레이어

 

decoder 부분

z_2hid : z 벡터에서 input을 받는 decoder  첫 번째 hidden layer

hid_2img : hidden layer의 값을 이미지로 출력하는 layer

 

relu : 사용할 relu 함수


 

Encoder part

    def encoder(self, x):
        #q_phi(z|x)
        h = self.relu(self.img_2hid(x))
        mu, sigma = self.hid_2mu(h), self.hid_2sigma(h)

        return mu, sigma

encoder 부분은 input X값을 먼저 img_2hid에 입력해주고 이를 통과한 output에 relu를 덮어서 h를 지정해준다

그리고 hid_2mu()와 hid_2sigma()에 h를 input으로 넣어서 mu와 sigma를 생성하고 이를 리턴하게 한다


 

Decoder part

    def decoder(self, z):
        h = self.relu(self.z_2hid(z))

        return torch.sigmoid(self.hid_2img(h)) # mnist date -> binary

decoder 부분은 더 간단하다

z 벡터를 input으로 받아서 z_2hid레이어와 relu를 통과해서 h를 생성하고 

이렇게 생성된 h를 hid_2img를 지나서 -

MINST의 경우 2진 분류기이기 때문에 sigmoid 함수를 지나게 해서 이미지를 생성해준다


 

Forward part

    def forward(self, x):
        mu, sigma = self.encoder(x)
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma*epsilon
        x_reconstructed = self.decoder(z_reparametrized)

        return x_reconstructed, mu, sigma

model의 전체 forward 흐름이다 

먼저 encoder에 x를 투입해서 mu, sigma를 리턴으로 받고 

epsilon을 sigma 사이즈 만큼 랜덤적으로 생성해준다

 

그리고 z_reparametrizer 를 구성해주는데 방법은 mu와 epsilon을 곱한 sigma를 더해주어서 z vector의 값을 만들어준다

그리고 이렇게 만든 z를 decoder에 input으로 넣어서 이미지를 x_reconstructed라는 변수 output으로 받아서

x_reconstructed, mu, sigma를 리턴해준다

 

⭐Train

train.py 전체코드

import torch
import torchvision.datasets as datasets
from tqdm import tqdm
from torch import nn, optim
from model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

#configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20

NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4 #Karpathy constant

#Dataset Loading
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

#Model
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr = LR_RATE)
loss_fn = nn.BCELoss(reduction='sum') 


#train
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    for i, (x, _) in loop:
        #Forward pass
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)

        #compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        ki_div = -torch.sum(1 + torch.log(sigma.pow(2)) -mu.pow(2) -sigma.pow(2))

        #backprop
        loss = reconstruction_loss + ki_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())



#inference 
model = model.to("cpu")
def inference(digit, num_examples=1):

    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1

        if idx == 10:
            break

    encoding_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encoder(images[d].view(1, 784))
        encoding_digit.append((mu, sigma))

    mu, sigma = encoding_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decoder(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f'generated_{digit}_ex{example}.png')


for idx in range(10):
    inference(idx, num_examples=1)

train.py는 크게 2개의 부분으로 구성되어있다

하나는 train 을 담당하는 부분이고

다른하나는 inference를 담당하는 부분이다

 

import torch
import torchvision.datasets as datasets
from tqdm import tqdm
from torch import nn, optim
from model import VariationalAutoEncoder
from torchvision import transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader

#configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_DIM = 784
H_DIM = 200
Z_DIM = 20

NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4 #Karpathy constant

#Dataset Loading
dataset = datasets.MNIST(root="dataset/", train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

#Model
model = VariationalAutoEncoder(INPUT_DIM, H_DIM, Z_DIM).to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr = LR_RATE)
loss_fn = nn.BCELoss(reduction='sum')

이 부분은 일반적인 모델 학습과 비슷한 부분이기 때문에 생략하도록 하겠습니다

 

 

train

#train
for epoch in range(NUM_EPOCHS):
    loop = tqdm(enumerate(train_loader))
    for i, (x, _) in loop:
        #Forward pass
        x = x.to(DEVICE).view(x.shape[0], INPUT_DIM)
        x_reconstructed, mu, sigma = model(x)

        #compute loss
        reconstruction_loss = loss_fn(x_reconstructed, x)
        ki_div = -torch.sum(1 + torch.log(sigma.pow(2)) -mu.pow(2) -sigma.pow(2))

        #backprop
        loss = reconstruction_loss + ki_div
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loop.set_postfix(loss=loss.item())

설정해놓은 epoch만큼 loop가 돌아가고

tqdm의 경우 현재 돌아가고 있는 loop를 출력하는 코드이다

 

그리고 loop안에서 받는 x값만을 받아서 input의 형태로 변환해주고

이를 model에 넣어서 x_reconstructed, mu, sigma를 리턴으로 받아준다

 

그리고 loss를 계산하는 부분은 앞에서 지정한 loss_fn = Binary Entropy에 x,reconstructed - x'과 x를 입력해주어서 loss를 구하고

앞에서 리턴한 mu와 sigma를 활용해서 ki_div를 구해주는데

이부분은 논문을 참고하길 바란다

간단하게 설명하자면 backward를 가능하게 하기위해서 해주는 trick이다

 

그리고 backward 부분으로 

계산한 loss와 ki_div를 합처서 loss를 계산해주고 

이를 통해서 backward를 실행한다 

그리고 가중치를 업데이트 해주고

loop에 현재 loop의 loss를 출력한다

 

inference

model = model.to("cpu")
def inference(digit, num_examples=1):

    images = []
    idx = 0
    for x, y in dataset:
        if y == idx:
            images.append(x)
            idx += 1

        if idx == 10:
            break

    encoding_digit = []
    for d in range(10):
        with torch.no_grad():
            mu, sigma = model.encoder(images[d].view(1, 784))
        encoding_digit.append((mu, sigma))

    mu, sigma = encoding_digit[digit]
    for example in range(num_examples):
        epsilon = torch.randn_like(sigma)
        z = mu + sigma * epsilon
        out = model.decoder(z)
        out = out.view(-1, 1, 28, 28)
        save_image(out, f'generated_{digit}_ex{example}.png')


for idx in range(10):
    inference(idx, num_examples=1)

이제 inference부분이다

 

우선 가지고 있는 image 중 0~9에 해당되는 이미지를 한장씩 랜덤적으로 추출한다 - 밑에 for 문 인풋이 0~9

그리고 이를 model의 encoder에 넣어서 mu와 sigma를 구하고

 

이렇게 구한 mu와 sigma를 활용해서 z를 구하여 이를 통해서 decoder에 넣어서 새롭게 생성한 이미지를 저장한다

 


위의 코드를 활용하여 생성한 이미지는 아래와 같다

 

 

 

 

 


Referance

 

 

 

Paper summary: Variational autoencoders with PyTorch implementation

Variational autoencoders (VAEs) act as foundation building blocks in current state-of-the-art text-to-image generators such as DALL-E and…

sannaperzon.medium.com