Improved Techniques for Training GANs

1. Introduction

Generative adversarial networks(GANs)은 game 이론에 근거한 생성 모델이다.

(게임이론의 핵심이라 할 수 있는 내쉬 균형은 각자가 상대방 대응에 따라, 최선의 선택을 하고 자신의 선택을 바꾸지 않는 균형 상태를 말한다.

factorio thumbnail

출처: https://itwiki.kr/w/GAN

하지만, GANs 은 game 이론의 내쉬 균형을 찾기 보다는 비용함수의 low value를 착기 위해 gradient descent 방식을 사용한다.

즉, 이는 non-convex 비용함수와 고차원의 파라미터 공간에서 수렴을 보장하지 못한다.

본 논문에서는 GAN이 수렴이기 위한 5가지의 기술을 소개한다.

2. Toward Convergent GAN Training

5가지 기술은 다음과 같다.

  • Feature Matching
  • Minibatch discrimination
  • Historical averaging
  • One-sided label smoothing
  • Virtual batch normalization

2.1 Feature Matching

기존 GAN의 generator의 목적함수를 아래의 것으로 사용한다.

factorio thumbnail

여기서 f(x)는 discriminator의 중간 층 activations 이다.

이것으로 Generator에 새로운 목적함수 지정하여 오버트레이닝을 방지하고, GAN의 불안정성(insatbility)을 해결할 수 있다.

식을 좀더 이해해보면 Discriminator 중간층의 output이 생성에 필요한 하나의 특징(feature)이며, 이것이 random sampling된 z에 대해 분포가 비슷한지(matching) 살펴보는 것이다. 즉, Generator에서 생성한 분포가 실제 데이터의 분포를 matching 시키려고 한다. 단순하게 진짜/가짜를 나누는 방식이 아닌, 진짜와 같은 feature를 가지고 있도록 훈련을 진행하는 것이다. 저자는 G가 목표하는 통계치에 도달하는지는 확신할 수 없지만, 경험적으로 불안정한 GAN에 대해 효과적이라고 이야기 하고 있다.

2.2 Minibatch discriminator

다음으로는 minibatch discriminator이다.

여기서 minibatch는 다음과 같은 연산을 추가로 수행한다.

factorio thumbnail

GAN이 실패하는 경우 중 하나는 Generator가 동일한(유사한) 출력을 하게 parameter가 세팅되는 경우다. Generator는 Discriminator를 속이기만 하면 되기 때문에 이런 일이 발생할 수 있다. 논문에서는 이 문제는 Discriminator가 각 example을 개별로 처리하기 때문에 출력간의 관계를 고려하지 않기 때문이라고 이야기하고 있다. 그래서 배치(batch) 안에서 다른 데이터간의 관계를 고려하도록 설계하는 방법이다.

개별 샘플이 minibatch내의 다른 샘플들과의 유사도(L1 norm)를 계산하여 합치고, 이를 판별에서 추가적인 정보(side information)로 사용하게 됩니다.

Minibatch discriminator 적용 전/후의 이미지다

factorio thumbnail

factorio thumbnail

출처 : https://wikidocs.net/152501

기존에 discriminator와는 다르게 minibatch discriminator에서는 추가 정보( o(x) )를 사용하는 것을 볼 수 있다.

2.3. Historical averaging

이 기술은 dicrimnator와 generator의 손실함수 모두에 다음을 추가한다. θ 는 모델 파라미터를 나타내고, θ[i]는 i번째 학습 과정에서의 파라미터를 의미한다.

factorio thumbnail

기존 loss함수에 과거 Parameter의 Average값을 패널티 텀으로 추가하면 수렴할때 나타나는 파라미터의 급격한 변동을 막아 학습이 거듭될수록 수렴할 수 있다고 한다. 아래 그림(왼쪽)과 같이 Loss가 수렴되지 않을 경우, 적용하면 아래(오른쪽) 그림과 같이, 수렴되는 효과를 볼 수 있다고 한다.

factorio thumbnail

출처 : https://wikidocs.net/152501

2.4. One-sided label smoothing

factorio thumbnail

이 기술은 Label smoothing은 0과 1 타겟 대신 0.9, 0.1 등의 smoothed value로 classifier을 훈련하는 방법이다. 신경망은 정수보다는 실수를 학습하는데 유리하다고 하며, 구체적으로는 positive target을 α, negative target을 β 로 두는데, 여기서 negative data가 더 좋은 방향으로 생성되는 것을 위해 β는 0으로 둔다. (One-sided)

2.5. Virtual batch normalization

Batch normalization은 신경망 최적화에 큰 도움을 주었지만, 아래와 같이 Strong Intra batch correlation이 발생할 수 있다. 즉, 같은 batch내의 다른 입력값에 영향을 많이 받아 유사한 샘플을 만들어낸다. 이를 방지하기 위해 Virtual Batch Normalization (VBN) 을 제안했다.

Mini batch의 다른 값들의 영향을 많이 받는 것을 방지하기 위해 고정된 배치(reference batch)를 이용한다. reference batch는 학습 초기에 한번 선별되어 학습이 진행되는 동안 변하지 않고, 이 값을 이용하여 normalize를 수행한다. 하지만 2개의 minibatch를 연산량이 많기 때문에 generator에서만 적용된다.

3. Assessment of image quality

기존의 생성모델으 성능을 평가하는데 있어 주로 pit/dim이나 NLL이 사용되었다. 하지만 앞선 지표들의 값이 우수하다고 하더라도 실제 생성되는 이미지의 질은 좋지 않은 경우도 매우 많았다. 이러한 한계를 극복하기 위해 본 논문에서는 생성된 이미지가 실제 같은지에 대해 인간의 판단이 필요했고 Amazon Mechanical Turk와 같은 Crowd Sourcing을 통해 생성 이미지를 평가했다. 하지만 사람마다 기준이 달라 객관적인 기준이 없다는 점과 판단 결과에 대해 피드백을 주었을 떄 그 피드백에 의해 평가가 급격하게 변경되는 경우가 많이 있었다. 따라서 본 논문에서는 모든 생성모델에 객관적으로 사용될 수 있는 방법인 Inception Score를 제안한다.

3.1 Inception Model

Inception Model은 https://arxiv.org/pdf/1512.00567.pdf 에서 제안된 모델로 ImageNet 챌린지에서 우수한 성적을 거둔 모델이다. Inception model은 ImageNet 데이터에 pre-trained되었고 이미지 샘플 $(x)$ , label $(y)$ 가 있을 때 $p(y \vert x)$ 를 계산하기 위해 사용된다.

Inception Model Inception Model

3.2 Inception Score

Inception Score는 아래 식과 같이 구할 수 있다.

\[Inception Score(IS) = exp⁡(𝔼_𝑥 (𝐾𝐿(𝑝(𝑦│𝑥) | | 𝑝(𝑦))))\]

IS를 활용하여 모델의 2가지 성능을 평가할 수 있는데 첫번째는 생성된 이미지의 품질이고 두번째는 샘플 생성의 다양성이다. $p(y \vert x)$를 최대한 작은 entropy(아래 왼쪽 그림)를 가지도록 학습하여 이미지의 label을 정확히 판단할 수 있도록 하여 생성되는 이미지 샘플들의 성능을 높이게 된다. 그리고 $p(y)$ 가 최대한 큰 entropy(아래 오른쪽 그림)를 가지도록 모델을 학습하여 다양한 label의 이미지 샘플을 생성할 수 있게 한다. 즉, $p(y \vert x)$와 $p(y)$의 두 값의 entropy 차이가 커질수록 IS 값이 커지게 되어 잘 생성된 이미지라 할 수 있다. IS는 인간의 성능 평가만큼 정확한 metric인데 이에 대해서는 Experiment에서 자세히 다루겠다.

Entropy image (출처 : https://gaussian37.github.io/ml-concept-basic_information_theory/)

4. Semi-supervised learning

대부분의 classifier들은 데이터 $x$에 대해 K개의 class가 있을 때 input $x$를 어느 모델에 통과시켜 나온 K 차원의 벡터에 SOFTMAX를 취해 확률값으로 변환시킨 후 가장 높은 확률을 가지는 class로 분류하는 방식이다. 그리고 suprvised learning에서 분류모델들은 label과 모델의 예측값의 cross-entropy 값을 최소화하면서 학습하게 된다.

본 논문에서는 기존의 classifier의 학습에 기반하여 semi-supervised learning을 제안한다. 데이터셋에 “generated” label을 K+1번째 class로 추가해서 classifier를 학습합니다. $p_{model}(y = K + 1 \vert x)$의 값은 $x$가 fake일 확률을 의미하게 되고 이는 vanillaGAN에서의 $1 - D(x)$와 같게 된다. 그리고 이와 같은 사실들을 이용해 GAN의 objective function을 supervised learning과 같은 형태로 볼 수 있게 된다. 본 논문에서는 아래 두 loss (supervised loss, unsupervised loss)를 이용해 semi-supervised learning을 수행한다.

\[L = L_{supervised} + L_{unsupervised}\]

생성 이미지 $(p_{model})$ 와 관련이 없이 데이터셋에서 real 데이터들에 대한 classifier 학습을 진행한다.

\[L_{supervised} = -E_{x, y \sim p_{data}(x, y)}log(p_{model}(y \vert x, y < K + 1))//\]

$D(x) = 1 - p_{model}(y = K + 1 \vert x)$ 라 정의하면 일반적인 GAN의 loss와 동일한 것을 확인할 수 있다.

\[L_{unsupervised} = -[E_{x \sim p_{data}(x)}log[1 - p_{model}(y = K + 1 \vert x)] + E_{x \sim G}log[p_{model}(y = K+1 \vert x)]]\]

4.1 Importance of labels for image quality

위에서 제안한 semi-supervised learning은 SOTA를 달서했을 뿐만 아니라 생성되는 이미지 샘플의 질에 대한 인간의 평가도 향상시켰다. 그 이유로는 인간의 시각은 물체를 분류할 때 특징이 있는 일부분에 강하게 반응을 하는데 이와 같은 특성을 discriminator가 이미지 class classification task를 수행함으로써 구현했기 때문이다. 그리고 discriminator가 학습한 특징들을 활용해 generator가 좋은 이미지를 생성할 수 있게 된다.

4.2 Code Implementation

아래는 Inception Score를 구하는 코드로 생성된 이미지의 품질에 대해 테스트해볼 수 있다. get_inception_score 함수의 첫번째 인자로 이미지 리스트를 전달하면 Inception score를 반환한다.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path
import sys
import tarfile

import numpy as np
from six.moves import urllib
import tensorflow as tf
import glob
import scipy.misc
import math
import sys

MODEL_DIR = 'your_model_dir' # inception_v3
DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
softmax = None

def get_inception_score(images, splits=10):
  assert(type(images) == list)
  assert(type(images[0]) == np.ndarray)
  assert(len(images[0].shape) == 3)
  assert(np.max(images[0]) > 10)
  assert(np.min(images[0]) >= 0.0)
  inps = []
  for img in images:
    img = img.astype(np.float32)
    inps.append(np.expand_dims(img, 0))
  bs = 1
  with tf.Session() as sess:
    preds = []
    n_batches = int(math.ceil(float(len(inps)) / float(bs)))
    for i in range(n_batches):
        sys.stdout.write(".")
        sys.stdout.flush()
        inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
        inp = np.concatenate(inp, 0)
        pred = sess.run(softmax, {'ExpandDims:0': inp})
        preds.append(pred)
    preds = np.concatenate(preds, 0)
    scores = []
    for i in range(splits):
      part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
      kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
      kl = np.mean(np.sum(kl, 1))
      scores.append(np.exp(kl))
    return np.mean(scores), np.std(scores)

6. Experiments

Semi-supervised learning을 MNIST, CIFAR-10, SVHN 3개의 데이터셋에 대해 학습했고 3개의 데이터셋에 ImageNet 데이터셋을 추가하여 샘플 generation을 수행했다.

6.1 MNIST

아래 그림은 본 논문에서 제안한 feature matching과 minibatch discrimination 방버을 각각 적용하여 MNIST 데이터셋으로 모델을 학습하고 이미지 샘플을 생성한 결과이다. 생성된 이미지 샘플들을 볼 떄 feature matching보다 minibatch discrimination 방법을 적용한 경우 이미지 샘플들의 질이 더 좋은 것을 확인할 수 있었다. 실제로 MTurk에서 2000명의 annotator들을 대상으로 fake detection 실험을 진행했을 때 단 52.4%의 샘플들만을 구별할 수 있었다고 한다.

factorio thumbnail

하지만 classification task에서는 minibatch discrimination이 아닌 feature matching 방법이 더 좋은 성능을 보여줬다.

factorio thumbnail

6.2 CIFAR-10

MNIST와 마찬가지로 실험을 진행했다. 간단한 숫자 데이터인 MNIST와 다르게 CIFAR-10 데이터셋은 실제 물체들의 이미지이기 떄문에 MNIST에서 보다 human annotator의 fake detection 정확도는 78.7%로 높지만 여기서 주목할 점은 IS의 metric으로서의 타당성이다. 모든 fake 이미지가 아닌 IS가 상위 1%인 생성 샘플들에 대해서 fake detection을 진행했을 때는 약 7%p 떨어진 71.4%의 정확도를 보였다. 즉, IS가 높은 이미지들은 IS가 낮은 이미지들 보다 더 실제 같은 이미지라는 것을 의미한다.

factorio thumbnail

factorio thumbnail

추가로 본 논문에서 제안된 방법들이 실제로 IS를 높일 수 있다는 것을 보이기 위해 ablation study를 진행했다.

factorio thumbnail

  • VBN : Virtual Batch Normalization
  • BN : Batch Normalization
  • -L : 학습 과정에서 label 제거한 경우
  • +HA : historical averaging 추가
  • -LS : label smoothing 제거
  • -MBF : minibatch features

6.3 SVHN

CIFAR-10과 동일하게 실험을 진행했고 결과는 다음과 같다.

factorio thumbnail

6.4 ImageNet

ImageNet 데이터셋은 128x128의 고해상도이면서 1,000개의 카테고리를 가진 큰 데이터셋이다. DCGAN과 같은 일반적인 GAN 모델들은 ImageNet처럼 카테고리가 매우 많은 데이터셋에 대해 분포의 entropy를 훨씬 낮게 평가하기 떄문에 문제가 있었다. 하지만 본 논문에서 제안한 method들을 활용한 경우 아래 오른쪽 이미지와 같이 object들의 특징들을(털, 눈, 코 등) 잘 학습하는 것을 볼 수 있다.

factorio thumbnail

7. Conclusion

기존의 GAN 관련 연구에서 안정적인 학습의 어려움과 객관적인 평가지표가 부족하다는 매우 큰 단점들이 있었다. 하지만 본 논문에서 제안한 학습 방법들로 GAN 모델의 학습을 안정화시켜 기존에는 학습이 불가능했던 모델들을 학습시켰으며 Inception Score를 정의하여 모델간의 객관적인 비교를 가능하게 했다. 또한 추후에 Fréchet Inception Distance (FID) 와 같이 더 일관된 평가 지표가 나올 수 있는 기반이 된 점이 중요한 contribution이다.

8. Opinion

본 논문에 대한 몇가지 장단점은 다음과 같다.

8.1 Advantage

1) 현재 이미지 생성 평가지표로 많이 활용되는 IS score를 정의 2) Object의 feature를 정확하게 capture하여 이미지로 생성

8.2 Disdvantage

1) 이론적 배경 없이 실험을 통한 휴리스틱한 접근 2) Feature matching과 minibatch의 성능차이에 대한 근거 부재 3) Object의 feature는 잘 capture하지만 feature들 간의 관계를 이미지 생성에 반영하지 못함