네이버 부스트캠프 🔗/⭐주간 학습 정리

[네이버 부스트 캠프 AI Tech]Conditional Generative Model

Dobby98 2023. 4. 3. 15:53

본 글은 네이버 부스트 캠프 AI Tech 기간동안

개인적으로 배운 내용들을 주단위로 정리한 글입니다

 

본 글의 내용은 새롭게 알게 된 내용을 중심으로 정리하였고

복습 중요도를 선정해서 정리하였습니다

 

✅ Week 5

  1. Conditional generative model
  2. Image translation GANs
    1. Pix2Pix
    2. Cycle GAN
    3. Perceptual loss

1. Conditional generative model

우리가 이전 시간에 배웠던 Gnerative Model을 한번 떠올려보자

간단히 Gnerative model의 output 또는 predict를 표현하자면

random sample 된 값이라고 할 수 있다

 

하지만 오늘 살펴볼 Conditional generative model은 조금 다르다

바로 '조건'이 주어지기 때문이다

 

쉽게 말해서 일반적인 Genrative model에게 가방을 그려달라고 하면 랜덤적인 가방을 그려준다

하지만  Conditional generative model의 경우 우리가 원하는 특정한 조건을 주어서 -예를 들면 색상이나 모양등

우리가 원하는 가방을 생성하는 모델이라고 할 수 있다

 

즉, 생성에 조건을 달아주는 것이다

 

이러한 방법은 다양한 task에서 활용된다

 

 

특히, GPT와 같이 우리가 원하는 제목을 주고 

이 제목에 대한 목차나 짧은 글을 쓰라고 하는 것도  Conditional generative model중 하나이다

 

간단하게 두 모델의 차이점을 비교하면 아래 사진과 같다

즉, 기존의 GAN에서 C라는 특정한 조건을 부여하는 것이다

 

이러한 방법이 대표적으로 활용되는 곳에는  Super resolution이있다

 


Super Resolution

간단히 말해서 이 task는 이미지의 해상도를 개선하는 것이다

즉, Upsampling과 비슷한 개념이다

 

물론 이러한 방법에는 GAN방식이 아니더라도 다른 방식이 활용될 수 있다

 

대표적인 방법이 기존 회귀 모델인데

이 방법은 단점이 명확하다

 

바로 결과값이 blur하게 나온다는 것이다

왜냐하면 이러한 방법은 평균값을 활용하기 때문에  - 평균값이 가장  loss가 적기 때문에

좋은 결과 값을 만들어 내지 못한다

 

따라서 여기에 GAN 모델, 즉, GAN loss를 사용하는 모델을 활용하면

다음과 같이 더 좋은 성능을 낼 수 있다

만약 출력이 1, 2 라면 위의 회귀 모델의 경우 L1 loss를 사용하기 때문에 애매한 결과물이 나오지만

GAN의 경우 1, 2 에 가까운 결과를 만들 수 있기 때문이다

 

실제로 output을 비교하면 다음과 같다

SRResNet- > 기존 회귀

SRGAN -> GAN 모델

 


2.Image Translation GAN

이러한 GAN모델은 다양한 Image Translation 영역에서 많이 활용되고 있다

지금부터 대표적인 모델들을 한번 살펴보자

 

2.1 Pix2Pix

pxi2pix는 앞에서 살펴본 L1 loss와 GAN loss를 모두 활용한다

실제로 GAN의 경우 학습을 시키기기 매우 힘들다

 

따라서 L1 loss를 추가적으로 사용해서 학습을 좀더 수월하게 만들어 주는 것이다

하지만 이러한 pix2pix는 지도학습이다

따라서 조건에 대한 결과값이 무조건 적으로 필요하다

결국 데이터셋이 pair해야하는 것이다

 

이러한 한계를 어느 정도 극복한 모델이 바로 Cycle GAN이다


2.2 CycleGAN

Clycle GAN은 Unpaired 한 Data에서도 학습이 가능하다

이름에서 알 수 있듯이 X -> Y -> X의 과정을 통해서 회전하면서 학습한다

 

이를 위해서 Clycle GAN의 loss는 다음과 같은 loss를 활용한다

GAN loss + Cycle-consistency loss

쉽게 말해서 GAN loss와 Cycle - consistency loss를 합친것으로 

GAN loss의 경우 X -> Y로 갈 때 만들어진 Dx가 실제 Y와 같은지를 체크하고

Y-> X로 갈 때 만들어진 Dy도 체크한다

 

하지만 이러한 방법은 Generator가 하나의 결과에만 최적화가 되기 때문에 여러 input이 하나의 결과로 수렴하게 된다

 Mode Collapse가 발생하게 되는 것이다

 

이를 해결하기 위해서 Cycle-consistency loss을 추가적으로 활용한다

X -> Y로 갔을 때 만들어지는 이미지를 다시 X로 되돌려 얼마나 X가 보존되는지를 체크하고

Y -> X로 갔을 때 만들어지는 이미지도 같은 방법으로 체크한다

 

즉, 이런방법을 활용해서 원본의 내용을 유지하는 것이다

 

물론 Cycle GAN은 지도 학습이 불가능할 때 많이 활용되는 방법이다

 

만약 지도학습이 충분이 가능할 만큼 많은 데이터의 양을 보유하고 있다면 

Cycle GAN 보다는 GAN 모델을 많이 사용한다

 


2.3 Perceptual Loss

하지만 이러한 GAN 모델은 매우매우 학습이 어렵다는 치명적인 단점이 존재한다

Generator 또는 Discriminator 둘중 하나의 모델만 망가진다면

전체 모델의 학습이 어려워 지기 때문이다

 

이를해결하기 위해서 preceptual loss를 활용해서  모델을 학습시키기도한다

하지만 이를 위해서는

pretrained된 모델이 필요하다

 

하지만 기존 network처럼 forward, backward를 활용하기 때문에 학습이 쉽다

 

모델의 학습과정은 다음과 같다

 

우선 backbone 모델을 이용해서 이미지를 생성하고 - 이 backbone은 backward시 가중치 업데이트 ok

이렇게 만들어진 Y'을 pretrained된 backbone - 이 backbone은 가중치 업데이트 no 모델에 넣는다

이때 Y' 뿐만아니라 Y' style target, Y' content target도 병렬적으로 넣어서

Feature Reconstruction Loss와 Style Reconstruction Loss를 구해준다

 

 

Feature Reconstruction Loss

 Y' 와 Y' content가 pre trained network를 지나서 나온 값들을 활용해서 L2 loss를 구한것이

Feature Reconstruction loss이다

 

이 loss는 원본 이미지의 모양을 유지하게 해준다

 

 

Style Reconstruction Loss

Style Reconstruction Loss의 경우에는 Y'과 Y' Style target에   Gram matrices를 적용하는데

이 Gram matrices의 경우 이미지 전체의 통계적 특성을 나타낸다

결국 적용된 Y'과 Y' Style target을 활용하여서 Style Reconstruction Loss를 만든다

 

Style Reconstruction Loss 의 경우 원하는 스타일에 얼마나 맞는지를 계산해준다

 

 

코드 참고 자료 :

 

Perceptual loss

카메라 이미지 품질 향상 AI 경진대회

dacon.io