본문 바로가기
Deep Learning/GAN

[GAN-①] GAN 기본 컨셉

by 룰루셩 2021. 11. 25.

ETRI AI 아카데미의 GAN 수업을 들을 기회가 생겼다.  

지난 4일간 ETRI의 조영주 연구원님의 GAN 수업을 수강했고, 잊어버리기 전에 배운 내용을 복습하며 정리해두고자 한다.

수업을 들을 때는 어느정도 이해가 되고 알겠다고 생각했었는데 다시 집에와서 공부하니 모르겠는 부분도 있다..ㅎㅎㅎ

수업에서 배운 이론 내용을 블로그에 정리해두고 STARGANv2 코드를 깃허브에 설명과 함께 적어두는 것이 목표이다.

복습하면서 이해가 잘 안가는 부분이나 개념들에 대해서는 구글링하며 정리해 둘 예정이다.


 

비지도학습 (Unsupervised Learning)

- GAN은 비지도학습 중 하나이다.

- 비지도학습: latent vector로 임베딩 하는 과정.. 

- latent vector로 바꿔서 이 vector끼리 차원에 뿌리면 그걸 SVM이든 뭐든 사용해서 그룹핑하는 것을 말한다. (군집화 기반 비지도학습)

  ※ latent vector: 차원이 줄어든 채로 데이터를 잘 설명할 수 있는 전체 공간에서의 벡터

- Unsupervised vs GAN

   Usupervised: image가 들어왔을때 latent code를 나오게 하는 것

   GAN: latent code가 들어왔을때 image를 만들어내는 것

 

 

GAN이 하고자하는 것은? 실제 데이터의 분포를 잘 근사하는 모델을 만드는 것!

- GAN은 원 데이터가 가지고 있는 확률 분포를 추정하도록 하고, 인공신경망이 그 분포를 만들어내는 것을 말한다.

- 즉, GAN이 하고자 하는 것은 실제 이 데이터의 분포를 잘 근사하는 모델을 만드는 것이다!!

 

GAN에서 다루고자하는 모든 데이터는 확률분포를 가지고 있는 랜덤변수(random variable)이다.

  ※ random variable: 측정할때마다 다른 변수가 나온다. (2차방정식의 x와 다른 개념, 여기서 x는 특정한 수임)

                                 다른 변수가 나오지만, 특정한 확률분포를 따르는 숫자를 생성한다.

                                 랜덤 변수에 대한 확률분포를 안다는 것은 데이터에 대한 전부를 안다는 것과 같은 이야기!!                                         확률분포를 알면 그 데이터의 예측 기대값, 데이터의 분산을 알 수 있어서 데이터의 통계적 특성을 바로 분석할 수 있다.

주어진 확률분포를 따르도록 데이터를 임의생성하면 그 데이터는 확률분포를 구할 때 사용한 원데이터와 유사한 값을 가진다.

→ GAN과 같은 비지도학습이 가능한 머신러닝 알고리즘으로 데이터에 대한 확률분포를 모델링할 수 있게 되면, 원데이터와 확률분포를 공유하는 무한히 많은 새로운 데이터를 생성할 수 있음을 의미한다.

 

G(생성자)와 D(판별자)가 경쟁적으로 학습 → 진짜/가짜 구별 못하도록 하는게 최종 목표 → G는 원데이터의 확률분포를 알아내려고 노력하여 학습이 종료된 이후에는 원데이터의 확률 분포를 따르는 새로운 데이터를 만들어낸다.

→ D는 더이상 분류해도 의미가 없는 0.5라는 확률값을 뱉어내게 된다. (진짜/가짜 맞출 확률 0.5이면 D가 의미가 없어진다.)

 

 

GAN 학습 구조

분류모델(D)을 먼저 학습시키고 생성모델(G)을 학습시키는 과정을 서로 주고받으면서 반복

생성 모델과 분류 모델의 학습 속도를 맞추기 위해 둘을 번갈아 가면서 최적화한다. 이론적으로는 분류모델(D)가 훈련데이터를 완전히 학습하고 난 뒤에 생성모델(G)을 훈련해도 될 것 같지만, 그러면 분류모델은 생성모델이 만들어낸 데이터를 진짜 데이터라고 속지 않기 때문에 생성모델은 계속해서 큰 손실을 만들어내고 학습되지 않는다. 따라서 분류모델과 생성 모델의 훈련은 균형된 속도로 훈련하되, 분류모델이 조금 앞서가도록 학습해야한다.

 

1. 분류모델 학습

  ① 진짜 데이터 입력 → 네트워크가 진짜로 분류하도록 학습

  ② 가짜 데이터 입력 → 네트워크가 가짜로 분류하도록 학습

  이 과정에서 분류 모델은 진짜를 진짜로, 가짜를 가짜로 분류할 수 있게 된다.

 

2. 생성모델 학습 (학습된 분류 모델을 속이는 방향으로)

  생성모델에서 만든 가짜 데이터를 판별 모델에 입력 → 가짜 데이터를 진짜로 분류할만큼 유사한 데이터 만들어내도록 생성모델 학습

 

경쟁적으로 발전시키는 구조를 이루고 있다.

 

 

BCE loss (Binary Cross Entropy) 또는 Adversarial loss

$x\sim p_{data}(x)$: 실제 데이터에 대한 확률분포에서 샘플링한 데이터

$z\sim p_{z}(z)$: 가우시안 분포를 사용하는 임의의 노이즈에서 샘플링한 데이터 (latent vector)

 

: real data x를 분류모델(D)에 넣었을 때 나오는 결과를 log 취했을 때 얻는 기댓값

 

 

: fake data z를 생성모델(G)에 넣었을때 나오는 결과를 분류모델(D)에 넣었을 때

 그 결과를 log(1-결과) 취했을때 얻는 기댓값

 

 

분류 모델(D) 입장에서 V(D, G)의 이상적인 결과

첫번째 항: D(X)가 진짜를 진짜라고 판별하여 1이어서 첫번째 항은 0이 되어 사라진다.(D(x)=0과 1 사이의 값이니까 log를 취했을때의  최대값이 0이 되는 것)

두번째 항: G(z)가 생성해내는 가짜 이미지를 구별해낼 수 있으므로 D(G(z))는 0이 된다. → $log(1-0)=log1=0$

전체식 V(D,G)는 0이 된다.

D의 입장에서 얻을 수 있는 이상적인 결과, 최댓값은 '0'이다. 

 

생성 모델(G) 입장에서 V(D, G)의 이상적인 결과

G가 D가 구별해내지 못할 정도 잘 생성하는 게 G 입장에서는 Best!

첫번째 항: D가 구별해내는 것에 대한 항, G의 성능에 의해 결정 X → 무시

두번째 항: G가 생성해낸 데이터 → D를 속일 수 있는 데이터라고 가정 → D(G(z)) = 1

              $log(1-1) = log(0) = - \infty$

G 입장에서 얻을 수 있는 이상적인 결과, 최솟값은 '$-\infty$' 이다.

 

그래서 저 맨 왼쪽에 있는 min max가 의미하는 것은 D는 최대가 되려고 하고 G는 최소가 되려고 한다는 의미이다!

 

- D는 training data의 샘플과 G의 샘플이 진짜인지, 가짜인지 올바른 라벨을 분류할 확률을 최대화하기 위해 학습!

   D는 진짜/가짜 판단하므로 Binary cross entropy loss를 사용한다.

- G는 log(1-(D(G(z)))를 최소화 (D(G(z))를 최대화)하기 위해 학습하는 것이 V(D, G)를 최소화 하는 것 

 

 

 

 

 

 

 

[참고]

[1] ETRI AI 아카데미 조영주 연구원님 E6001 GAN 강의

[2] https://www.samsungsds.com/kr/insights/Generative-adversarial-network-AI.html

[3] https://www.samsungsds.com/kr/insights/Generative-adversarial-network-AI-2.html

[4] https://m.blog.naver.com/PostView.naver?blogId=euleekwon&logNo=221558014002&targetKeyword=&targetRecommendationCode=1 

[5] Do it! 딥러닝 교과서 

'Deep Learning > GAN' 카테고리의 다른 글

[GAN-②] DCGAN  (0) 2021.12.16

댓글