ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • Projected GANs Converge Faster
    논문리뷰 2022. 2. 25. 04:55

    Projected GANs Converge Faster (NerIPS 2021)

     

    Intro


    본 논문은, GAN의 D가 pretrained model의 deep layer들에 존재하는 feature를 온전히 활용하지 못한다는 단점을 해결코자 했다.

     

    해결책으로 해당 feature들을 채널별로, 해상도별로 mixing하는 작업을 했고 이미지 퀄리티, sample efficiency, 수렴속도면에서 이득을 볼 수 있었다.

     

    또한 22개의 벤치마크 데이타셋에서 FID SOTA를 달성했고, 고해상도 이미지도 잘만들고, 수렴속도도 엄청나게 감소시켰다.

     

     

    기존의 D는, 가지 역할을 했었다.

     

    첫째로, input으로 들어오는 real, fake image를 meaningful한 space로 projected해서 input의 representation에대해 배웠다.(conv해서 featuremap 만든거)

     

    그리고,  해당 representation을 토대로 discriminate를 진행했다.

     

    이 과정에서, 여러가지의 standard regularization(gradient penalty 등) 등이 사용되었지만, GAN학습을 안정화 시키는데에는 어려움이 있었다.

     

    최근 비전, clip과 같은 text에 대한 pretrained representation의 활용성이 높아짐에 따라, 본 논문은 GAN에 이러한 지식을 적용시켰다,

     

    나이브하게 성능 잘나오는 Vit pretrained 모델 쓰면 성능이 잘 안나오고(D가 너무 우세), G에서 vanishing gradient가 발생한다.

     

    본논문은 

     

    1. feature pyramid구조의 multiscale discriminator를 통해 다양한 scale에서 loss를 쟤고,

     

    2. pretrained 모델의 deep layer를 좀 더 잘 활용하고자 random projection을 활용했다.

     

     

    본 논문 이전에 multiple discriminator가 어떻게 사용되었냐면,

     

    RGB image를 resolution만 다르게 여러개 만든 다음 D에 집어넣었다.

     

    반면 본 논문은 feature map을 scale별로 다르게 discriminator를 주는 방법을 제안한다.

    (feature map은 scale만 다른게 아니라 얼마나 conv 통과했는지도 다를 것)

     

    2. Methods


    2-1 Projected GAN

     

    Feature projector (ResNet + random projection? or 걍 random projection?)은 freezed이고, G와 multiscale D를 학습시킨다.

     

    이 때 feature projector는 미분가능하고, 해당 이미지에대한 중요한 정보를 가지고있어야한다.

     

    (우리는 또한 random init과 같이 어려운 representation을 쉬운 representation으로 바꾸고자 하는 노력을 진행했다.)

     

    Consistency

     

    prokected GAN은 더 이상 이미지 레벨단에서 true dataset의 확률 분포를 optimize한게 아니라, true dataset의 feature space상의 distribution의 확률분포와 매칭되도록 optimize한다.

     

    optimal한 D 와 G는 각각 다음의 조건일 때 만족된다고 한다.

     

    2-2 Multi-Scale Discriminators

    본 논문의 D는, pretrained 모델에서 뽑은  4개의 scale에서 featuremap을 가져온다.(fixed efficientnet 활용)

     

     

    해당 scale에서 각 상응하는 multiscale discriminator가 conv(spectral norm 포함)을 거쳐 값을 산출한다.

     

    이 때 각 스케일 모두 최종 사이즈를 (4, 4) 가 될 때 까지 conv해주고 (4, 4)에서 logit을 뽑았을 때 결과물이 가장 좋았다고 한다.

     

    4개의 discriminator에서 각각 logit을 뽑아내고 더하여, loss로 활용한다.

     

    2-3 Random Projections

    (왜 소제목이 random projection인지 모르겠다.. random projection은 Johnson-Linderstrauss Lemma를 통해 벡터간의 거리를 보존하며 고차원 -> 저차원으로 바꾸는 개념이라고 함)

     

    Deep layer의 featuremap일수록, 해당 represenatation에 대한 정확한 이해는 네트워크 역시 어려워한다.

     

    저자들은 discriminator가 해당 deep layer featuremap(D 맨 끝단 semantic한 정보 담는 부분)에서, 일부분에만 집중한다고 가정했다. (그니까 대부분의 conv layer의 weight는 0이고 일부만 activated 되어있지 않을까 생각됨)

     

    그리고 그 가정이 맞다면, 해당 두 가지 방법을 사용하면 성능 향상을 이뤄낼 수 있지 않을까? 생각했다.

     

    두 가지 방법은 CCM, CSM인데, 신기한 점은 해당 두 방법의 파라미터들이 random init된 이후에, fixed 되고 not trained된다고 한다.

    kaiming init

     

    CCM

     

     

    경험적으로, 저자들은 지켜져야 할 두가지 특성을 찾았다고 한다.

     

    1. Efficient net을 통과한 featuremap을  CCM을 하고 나서도 information을 잘 preserving 할 수 있어야하고,

     

    2. 쉽게 invertable하면 안된다고 한다.(왜??)

     

    CCM은, Cross Channel Mixing 즉 채널간 믹싱이다.

     

    input, output channel수 같게 1by1 conv  때리는게 일반적이고, 해당 input이 가지고 있었던 정보를 잘 가지고 있다.

     

    또한 실험적으로, 저자들은 output channel 수를 늘렸을 때 information preserving을 더 잘해서 better performance를 얻을 수 있었다고 한다.

     

     

    CSM

     

     

    scale간의 feature도 섞어주기 위해, CCM을 거쳐 나온 output을 3x3 conv, bilinear interpolation을 거쳐 다음 scale의 CCM output과 섞어준다(채널방향으로 concat해주는 듯).

     

    U net 구조랑 비슷하고, 한 scale에 conv 한 번 밖에 안썼다고 한다.

     

    Pretrained Feature network

     

    EfficientNet (imagenet pretrained)

    ResNet

    R50-CLIP(clip contrastive image - text loss로 학습됨)

    ViT

    Deit

    (inception network는 사용하지 않음.. FID와 strong correlation이 있기에

     

     

    3. Ablation study


    실험은 LSUN-Church에서 진행되었다고 한다.

    (126k img, 적당한 complexity, res 256)

     

    G의 구조로는 FastGAN 구조를 사용했다.

     

    GAN에서 딥 모델을 만들 때 사용하는 skip connection (element wise하게 addition)대신 FastGAN은 channel wise multiplication을 사용한다고 한다.

     

    FastGAN(TOWARDS FASTER AND STABILIZED GAN TRAINING FOR HIGH-FIDELITY FEW-SHOT IMAGE SYNTHESIS, ICLR 2021)

    ---

    x_low를 다른 이미지의 것으로 생성하면 스타일만 바뀐 결과를 생성할 수 있도록 한다.

     

     

    ---

     

    loss는 hinge loss를, bs = 64, 1M iter 돌 때까지 training.

     

    현재 존재하는 discriminator augmentation 기법들을 쓸수록 성능이 증가했고, 쓸 수록 SOTA와 가까워졌다고 한다.

     

     

     

    3.1 Which feature network layers are most informative?

     

    FD는 특정 feature space scale에서 Ture data의 featuremap과 generated data의 Frechet Distance를 사용했다.

    (FD_i = layer i scale)

    multi 안쓴 standard RGB discriminator를 1로 잡고, 

    로 비교했다.

    (L1, L2, L3 = EfficientNet 통과한 output을 input으로 받는 D를 학습!!)

    L1 -> 위쪽 (res 크고 ch 작음)

    L4 -> 아래쪽(res 작고 ch 큼, deep)

    EfficientNet L1,L2 이렇게 모델명이 아님!!

     

     

    Perceptual discriminator도 사용했다.  각 feature map들이 똑같은 discriminator의 different layer로 들어가서 single logit을 뽑았다는데.. 논문을 읽어봐야 가능할 것 같다.

     

    그냥 해당 scale에서 perceptual loss 쟀다고 생각. 저 초록색 3개 행 끼리 비교

     

    perceptual loss

     

    이제 비교 시작

     

    결과를 보면, 뒤쪽 layer에 discriminator를 달면 , 오히려 성능이 안좋게 나오는 걸 볼 수 있다.

     

    이에따라 더욱 semantic에 관련된 feature일수록 adversarial loss에 따라 training되지는 않는다고 생각했다.

     

    shallow한 feature을 D 생략하면 성능이 떨어지는걸 보아, 해당 부분에 image의 대부분의 정보가 있다는 것을 알 수 있다.

     

    perceptual discriminator 다시 보기..

     

     

    3.2 How can we best utilize the pretrained features?

     

    CCM은 대부분의 FD를 낮춘경향, 즉 mixing channel은 효과적이었다는 것을 보여준다.

     

    CSM은 deeper layer에서 추가될때  효과가 더욱 컸던걸로 보아 CSM이 semantic feature활용을 더 잘 할 수 있게 되었다는것을 볼 수 있다.

     

    신기하게도, 4개의 scale에 사용했을때 성능이 가장 좋았다.

     

     

    3.3 Which feature network architecture is most effective?

    ​다양한 perceptual feature network architecture들을 사용했다.

     

    10만장이 각 네트워크에 대해 G, D를 학습하는데 사용되었다.

     

    놀라운 사실은. 이미지넷 acc와 해당 모델을 feature 뽑는 backbone으로 썼을 때 FID가 correlation이 없었다는 것이다.

     

    오히려, 모델이 작을수록 FID가 낮은 경향을 보였다.

     

    또 Res50, Res50CLIP 성능차이를 볼 때, FID를 낮추는데 imagenet feature acc가 엄청나게 중요하진 않다는 것을 알 수 있다.

    (Res50 : 4억 pair image-text objective를 목표로  학습) 

     

    저자들은 최종적으로 EfficientNet lite1을 백본으로 사용했다.

     

     

    4. Result


    4-1 Comparison to State-of-the-Art

     

    data의 전체 크기에 따라, resolution의 크기에 따라, complexity에 따라 등 다양한 실험을 진행했다.

     

    50k img 생성후 FID 비교로 성능 측정했고,

     

    몇 장이 있어야 converge하는지도 측정했다.

     

    각 experiment는 V100으로 100hours정도 걸림

     

     

    각 모델별 성능을 측정할 때, 각 모델에 맞는 aug strategy를 사용했다.(differentiable data-augmentation, adaptive discriminator augmentation)

     

    ProjectedGAN은 모든 실험에서 동일 lr, bs를 사용했다. (projected gan은 bs에 민감.)

     

    baseline 논문들도 최대한 성능이 잘 나오도록 하이퍼파라미터들을 sensitive하게 초이스 해주었다고 한다.

    (sg2ada는 lr, R1 penalty에 민감하다고함)

     

     

    시간으로 치면, sg2ada로 5일 걸릴게 3시간이 걸렸다고 한다.

     

    sample effiency가 상당히 높다.(알고리즘적 효율같은 느낌.number of action it takes)

     

    sg2ada로 1M장이 필요한게 0.1장이면 충분했다고 한다.

     

    본 이미지는 fixed latent로 이미지를 생성한 모습.

     

    신기하게도, fixed latent인데 생성되는 이미지가 상당히 다양한걸 볼 수 있다.

     

    이 이유는  training시 이미지들은 엄청난 perceptual change를 겪는데 이에 기인한것이라고 한다.

     

    실제 non projection case시에는 이 정도로 이미지가 변덕스럽지 않다고 한다.

     

    multiscale D의 random projection 부분이 일반적인 D에 비해 semantic적인 feedback을 상당히 많이 줄 것 이라 한다.

     

    이러한 feedback이 학습시에 stochasticity (확률적 다양성)을 주어 converge, performance적 성능 향상을 준것이라고도 볼 수 있다.

     

    4-2 Quantitative Results

    (256, 256)으로 모든 dataset 통일 

     

    large dataset에서 sg2ada는 projectedGANs 성능 낼라면(3.39) 이미지 80배 필요함

     

    별 : SOTA 깼을 때 필요헀던 이미지 장수

     

     

     

    Discussion

     

    AFHQ에서 머리만 떠다니는고, 객체 생성은 기깔나는데 배경이 blurry하거나 빈걸 종종 볼 수 있는데, 이는 classification model의 성능에서 비롯된것이라 볼 수 있다.

     

    classification에서 배경제거 해버려도 acc는 거의 안줄어드는데, 이를 통해 background에는 별로 신경을 안쓰고 이 특성이 featuremap 만들때도 나타난 것 같다.

     

    무조건 deep한 pretrained 모델 보다는 compact한게 Projected GAN training 하는데에는 더 좋다.

     

     

     

     

     

Designed by Tistory.