[Generative] Class-Balancing Diffusion Models (CVPR’24)

2024. 10. 14. 22:43Developers 공간 [SOTA]

728x90
반응형
<구성>
0. Before Start...
   a. Long-tail distribution
   b. Long-tail for GAN
   c. Data Optimization
   d. F-Beta Score for Generative Models
1. Problem
 
2. Approach 
   a. Concept
   b. Training
3. Results
   a. Setting
   b. Results
   c. Ablation Study

글효과 분류1 : 논문 내 참조 및 인용

글효과 분류2 : 폴더/파일

글효과 분류3 : 용어설명

글효과 분류4 : 글 내 참조

글효과 분류5 : 글 내 참조2

글효과 분류6 : 글 내 참조3


0. Before Start...

 

최근 다양한 모델이 등장함에도 불구하고 지속적으로 데이터로 인해 생기는 문제는 다양합니다.

  • Imbalanced Class 데이터셋 : class 간의 skewed distribution을 가지는 데이터셋으로 인해 생기는 문제
  • Large-Scale 데이터셋 : 데이터가 너무 많아 생기는 문제
  • Non-Labeled 데이터셋 : 데이터 라벨링의 어려움으로 인해 생기는 문제

이 논문을 소개하기에 앞서 이런 배경에 대해 간단히 소개하려고 합니다.


a. Long-tail Distribution

 

Long-tail distributionLong-tailed recognition과 같은 Discriminative 모델에서 잘 알려진 dillemma입니다.

[Long-tailed Recognition]

 

참고로 보통 discriminative 모델에서는 데이터가 (High 차원→Low 차원)으로 매핑되어 사용되지만, generative 모델에서는 반대로 (Low 차원→High 차원)으로 매핑되어 사용되기 때문에 아래와 같은 방법들이 직접적으로 적용되지는 않습니다.

 

이 챕터에서는 Discriminative 모델을 해결하기 위한 세가지의 paradigm을 소개하겠습니다.


1. Class Re-balancing

 

가장 흔하게 직관적으로 사용하는 방법입니다.

 

기존에 데이터셋은 지프의 법칙을 따르는 분포(Zipfian Distribution)의 라벨 분포를 가지는 경향이 있다는 것을 확인했습니다.

** Zipf's Law(지프의 법칙) : 수학적 통계를 바탕으로 밝혀진 경험적 법칙으로, 많은 종류의 정보들이 아래식과 그림과 같은 Zipfian Distribution(혹은 Yule Distribution)에 가까운 경향을 보인다는 것입니다.

일반적으로 "모든 단어의 사용 빈도가 해당 단어의 순위에 반비례"하며, 아래 식의 N은 전체 인덱스의 개수, k는 인덱스, s는 분포특성값 입니다.

$$f(k;s,N)=\frac{1/k^s}{\sum^N_{n=1}(1/n^s)}$$

[zipfian distribution, x축 k, y축 probability]

 

따라서 하기 논문은 아래와 같은 replication factor $r(h)$에 따라 샘플링을 하는 Re-Sampling을 진행합니다.

** [RS-SQRT] Exploring the limits of weakly supervised pretraining (ECCV’18)

** h는 해당 라벨을 의미하며, f(h)는 해당 라벨이 등장할 확률, t는 Threshold입니다.

$$\begin{aligned}
r(h)&=max(1,\phi(\frac{t}{f(h)}))&\\
&\phi(x)=x(uniform\ sampling)\\
&\phi(x)=\sqrt{x}(sqaure\ root\ sampling)
\end{aligned}$$

 

이렇게 데이터를 Re-sampling하는 방법외에도 objective function에 class별 frequency를 반영하는 경우도 있습니다.

** Long-tail learning via logit adjustment (ICLR’21)

 

2. Information Augmentation

** Deep representation learning on long-tailed data: A learnable embedding augmentation perspective (CVPR’20)

** A simple but effective module for learning imbalanced datasets (CVPR’21)

** Feature transfer learning for face recognition with under-represented data (CVPR’19)

 

class frequency가 높은 head class의 정보를 이용해, class frequency가 낮은 tail classaugmentation하는 방법입니다.

[augmentation 예시]

 

3. Module Improvement

** Learning deep representation for imbalanced classification (CVPR’16)

** Factors in finetuning deep model for object detection with long-tail distribution (CVPR’16)

 

새로운 네트워크의 구조나 방법을 제안해 해당 문제를 해결하는 방법입니다.

[새로운 모델 구조를 활용해 Long-tailed dataset문제를 해결하는 방법]


b. Long-tail for GAN

 

GAN은 적은 데이터셋에서 학습할 때 discriminator의 overfitting으로 인한 문제가 있어 다양한 방법이 시도되었고, 그 중 두가지를 아래 소개합니다.

 

1. Regularization 방법

** Self-supervised dense consistency regularization for image-to-image translation (CVPR’22)

** Stabilizing training of generative adversarial networks through regularization (NIPS’17)

** Regularizing generative adversarial networks under limited data (CVPR’21)

 

기존 Loss에 다양한 방법으로 auxiliary loss를 추가해 성능을 개선하는 방법입니다.

[GAN에서 Regularization으로 해결하는 방법]

 

2. Data Augmentation 방법

** [DiffAug] Differentiable augmentation for data-efficient gan training (NIPS’20)

** [ADA] Training generative adversarial networks with limited data (NIPS’20)

 

GAN에서 augmentation하는 경우 정보가 generator로 많이 leak하는 현상이 발생해 정밀한 방법이 필요해 DiffAug(Differential Augmentation)ADA(Adaptive Augmentation)가 제안되기도 했습니다.

 

단순한 augmentation T를 실제 데이터에 적용하는 “Augment Reals Only” 방법은 실제 이미지의 distribution을 바꿔버릴 수 있는 문제가 있었습니다.

 

따라서 DiffAug는 real sample과 fake sample 모두를 augment하는 방법입니다.

[DiffAugment 기법]

 


이외에도 학습데이터에 augmentation이 진행되면 생성된 이미지에도 이것들이 inherited되는 경우가 많아, 기존에 bCR(balanced Consistency Regularization)이란 방법은 이것을 해결하기 위해 regularization term을 discriminator loss에 추가해 생성된 이미지와 실제 이미지들의 consistency를 추구했습니다.

 

하지만 여전히 generator는 자유롭게 생성이 가능하기 때문에 augmentation된 결과를 생성하는 leaking augmentation이 발생했습니다.

 

따라서 ADA는 아래 그림과 같이, 따로 loss term을 만들지 않고 augmented된 이미지 자체를 활용해 loss term을 만들되, Generator와 Discriminator 모두에 대해 loss evaluation을 진행합니다. 

[ADA 기법]


c. Data Optimization

 

앞서 언급한 Long-tail Distribution 데이터셋을 학습하기 위한 문제 말고도, Large-scale의 데이터셋을 학습하기 위해 굉장히 큰 training cost를 해결하기 위한 노력도 있습니다.

 

이 또한 데이터 사용법을 개선한 학습방법이므로 이를 해결하기 위한 아래 4가지 방법들을 소개하겠습니다.

 

  1. Data Distillation & Coreset Selection
    • informative한 작은 데이터셋을 선택하거나(Coreset Selection) informative한 작은 데이터셋을 다른 모델로 합성(Data Distillation)함으로써 샘플 자체의 수를 줄여 학습속도를 개선하는 방법입니다.
    • 하지만 샘플 수를 줄이기 위해 오히려 추가적인 cost가 발생하며, lossless 성능을 도달하기 어렵습니다.
  2. Weighted Sampling Method
    • 샘플간의 사용 빈도수(frequency)를 조절해 convergence speed를 개선하는 방법입니다. 
    • 하지만 모델과 데이터셋에 굉장히 민감합니다.
  3. LARS & LAMB
    ** [LARS] Large batch training of convolutional networks (arxiv'17)
    ** [LAMB] Large batch optimization for deep learning : Training bert in 76 minutes (arxiv'19)
    • 병렬화를 최대한 활용해 엄청큰 batch size를 활용해 학습하는 방법입니다.
    • 하지만 이를 위해 추가적인 training cost가 발생하기도 합니다.
  4. Data Pruning
    • Static Pruning : Loss를 기반으로 샘플을 정렬해 덜 중요한 샘플을 제거하며 학습 속도를 향상하는 방법입니다. 데이터셋이 클수록 이를 측정하기 위한 overhead가 발생해 오히려 더 느려지는 경향이 있으며, biased gradient expectation이라는 실제 최적화 방향이 아닌 다른 방향으로 학습되는 경향이 있습니다.
    • Dynamic Pruning : Static Pruning과 다르게 실제 score를 업데이트해가며 학습하는 방법입니다.
      ** InfoBatch: Lossless Training Speed Up by Unbiased Dynamic Data Pruning

[Static Pruning과 Dynamic Pruning]


1. Problem 

 

기존의 DM은 데이터가 uniformly distributed하다는 가정하에 학습을 진행하지만, 실제로 학습데이터는 굉장히 skewed되어 있어 Long-tailed Dataset이라고 불립니다.

[Long-tailed Recognition]

 

이런 데이터셋으로 학습을 하면 Unconditional DMlow-quality 이미지를 생성하기도 하며, Conditional DMtail class에서보다 head class에서 만족스러운 성능을 보이는 경향이 있습니다.

 

기존에 이런 imbalanced class 분포의 데이터를 활용해 GAN을 학습하는 방법은 제안되었으나, 본 논문에서는 DM을 학습할 방법으로 CBDM(Class-Balancing Diffusion Model)으로 제안합니다.


2. Approach 

 

해결방법을 소개하기 전에, 용어는 아래와 같습니다.

  • $\boldsymbol{x}$ : 데이터
  • $y$ : 데이터의 Class Label
  • $q(\boldsymbol{x},y)$ : 학습할 데이터의 분포
  • $p_\theta(\boldsymbol{x},y)$ : 모델이 예측할 joint 분포 
  • $r=\frac{q(\boldsymbol{x},y)}{p_\theta(\boldsymbol{x},y)}=\frac{q(\boldsymbol{x}|y)}{p_\theta(\boldsymbol{x}|y)}\cdot{}\frac{q(y)}{p_\theta(y)}$: Density Ratio로, 위 두 분포의 차이를 의미합니다.
    • $q(y)$ : Class Label의 분포
    • $p_\theta(y)$ : Class Label에 대한 모델의 prior로, 일반적인 DM은 uniform하게 학습되었다고 가정됩니다.
    • skewed된 데이터셋을 활용하면 $q(y)$와 $p_\theta(y)$가 같지 않습니다.

a. Concept

 

먼저 class 빈도에 따른 성능을 아래 그림에서 살펴보면, class의 빈도가 낮은 곳에서는 FID가 높은(안 좋은)을 볼 수 있습니다. 

[Class 빈도에 따른 DM의 FID]


즉, 아래 그림에서 각 sampling process 과정의 왼쪽 분포를 보면 실제 tail class로 생성한 분포는 (head class에 비해) 실제 GT분포의 모형을 잘 못따라가며, mode가 잘 커버되지 않는 것을 볼 수 있습니다.

** mode : 가장 빈번하게 등장하는 값을 의미합니다.

[Sampling Process에서 step별 Adjust 전후의 분포]


결과적으로 아래 그림을 보면 본논문에서 제안한 CBDM이 더 다양한 결과를 도출해내는 것을 볼 수 있습니다.

  • 위 두 줄 : class 70에서 noisy image를 DDPM을 활용해 recover
  • 아래 두 줄 : class 86에서 noisy image를 CBDM을 활용해 recover

[DDPM과 CBDM의 결과]

 

보통 이런 문제를 해결하기 위해 가장 직관적인 접근은 Class Label 분포 prior를 adjust하기 위한 Class-balanced Re-Sampling입니다.

 

하지만 이 방법은 오히려 성능이 안좋아지는 경우가 있고, DM은 step-by-step sampling과 같은 특징을 갖고 있기 때문에 이런 distribution을 조금더 soft하게 적용시킬 수 있습니다.

 

먼저 아래와 같은 용어를 가정해봅시다

  • $p_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,y)$ : Class-Imbalanced 데이터로 학습된 conditional transfer probability
  • $p^*_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,y)$ : Optimum하게 학습된 conditional transfer probability 

 

본 논문은 Imbalanced 데이터로 학습된 모델은 아래와 같이 매 reverse step t마다의 correction을 통해 나타낼 수 있다고 주장합니다. (Proposition 1)

$$p^*_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,y)=p_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,y){\color{red}\underbrace{\frac{p_\theta(\boldsymbol{x}_{t-1})}{p^*_\theta(\boldsymbol{x}_{t-1})}\frac{q^*(\boldsymbol{x}_t)}{q(\boldsymbol{x}_t)}}_{\text{Adjustment Schema}}}$$

 

근데, 위 식의 Adjustment는 모든 sampling에서 실제로 구하기는 어렵습니다.

** 논문에 나와 있지는 않지만, $p_\theta^*$이나 $q^*$과 같은 optimum distribution을 가지고 있지 않기 때문일 것입니다.

 

그래서 먼저 위 식의 모델과 관련없는 부분인 $\frac{q^*(\boldsymbol{x}_t)}{q(\boldsymbol{x}_t)}$를 없애고, $\frac{p_\theta(\boldsymbol{x}_{t-1})}{p^*_\theta(\boldsymbol{x}_{t-1})}$ 부분에 대해서는 아래와 같이 time step t에 대한 $\mathcal{L}_r$라는 upper bound를 통해 구현해줍니다. (Proposition2)

$$\begin{aligned}
\mathcal{L}^*_{DM}&=\sum^{T}_{t=1}\mathcal{L}^*_{t-1}\\
\sum_{t\geq 1}\mathcal{L}^*_{t-1}&=\sum_{t\geq 1}{\color{blue}[}D_{KL}[q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0)||p^*_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,y)]{\color{blue}]}\\
&\leq \sum_{t\geq 1}{\color{blue}[}D_{KL}[\underbrace{q(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,\boldsymbol{x}_0)||p_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,y)}_{DM\ loss\  \mathcal{L}_{DM}}]\\
&+\underbrace{{\color{red}t\mathbb{E}_{y'}[}D_{KL}[{\color{red}p_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t)}||p_\theta(\boldsymbol{x}_{t-1}|\boldsymbol{x}_t,{\color{red}y'})]{\color{red}]}}_{Distribution\ Adjustment\ loss\ \mathcal{L}_r}{\color{blue}]}
\end{aligned}
$$

$\mathcal{L}_{DM}$은 일반적인 DDPM Loss이며 $\mathcal{L}_r$은 regularization term입니다.

이 때, $\mathcal{L}_r$은 랜덤 Class $y’$에 대한 결과와 모델의 unconditional output의미 동일성을 유지해주는 역할을 합니다.

 

즉, $y'\sim q^*_y$는 랜덤하게 모든 클래스에서 class를 샘플하므로 이전에 잘 선택되지 않던 tail sample에 대한 확률을 보완해주고, 이렇게 함으로써 head class에 대해 overfitting되는 결과를 줄임과 동시에 tail class에 대한 다양성을 크게해줄 수 있습니다.


b. Training

 

CBDM은 Adjusted Transfer Probability로 구현되는데, MSE형태의 추가적인 regularizer를 통해 구현됩니다. 이를 자세히 살펴보겠습니다.

 

먼저, 앞서 설명한 adjustment loss $\mathcal{L}_r$는 아래와 같이 구해냅니다.

** CFG를 위해 10%의 condition $y$는 None으로 주었다고 합니다.

$$\mathcal{L}_r(\boldsymbol{x}_t,y,t)=\frac{1}{|\mathcal{Y}|}\sum_{y'\in \mathcal{Y}}[t\|\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_t,y)-\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_t,y') \|^2]$$

 

위와 같은 Loss와 함께 기존의 conditional DM에 적용할 수 있는데, regularization weight $\tau$와 함께 적용됩니다. 이 $\tau$는 앞서 설명한 density ratio $\frac{p_\theta(\boldsymbol{x}_{t-1})}{p^*_\theta(\boldsymbol{x}_{t-1})}$의 sharpness를 결정합니다.

$$\begin{aligned}
\mathcal{L}_{CBDM}(\boldsymbol{x},y,t,\boldsymbol{\epsilon})=&\|\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_t,y)-\boldsymbol{\epsilon} \|^2+&Ordinary\ DM\ loss\  \mathcal{L}_{DM}\\
{\color{red}\tau}\frac{t}{|\mathcal{Y}|}\sum_{y'\in \mathcal{Y}}{\color{blue}(}&\|\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_t,y)-\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_t,y') \|^2{\color{blue})}&Distribution\ Adjustment\ loss\ \mathcal{L}_r
\end{aligned}$$

 

하지만 위 식을 그대로 사용한다면, 모델의 output이 condition $y$와 무관하게 생성해 전체적으로 conditional generation의 성능을 떨어트리는 model collapse가 일어날 수 있습니다.

 

따라서 하기 논문을 참고해 아래 식과 같이 SG(Stop Gradient) Operation을 추가해주었습니다. 이 때 commitment weight $\gamma$는 본 논문에서 $\frac{1}{4}$로 셋팅되었습니다.

** Exploring simple siamese representation learning (CVPR’21)

** Neural discrete representation learning (NIPS’17)

$$\begin{aligned}
\mathcal{L}_{CBDM}(\boldsymbol{x},y,t,\boldsymbol{\epsilon})=&\|\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_t,y)-\boldsymbol{\epsilon} \|^2+&Ordinary\ DM\ loss\  \mathcal{L}_{DM}\\
\tau\frac{t}{|\mathcal{Y}|}\sum_{y'\in \mathcal{Y}}{\color{blue}(}&\|\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_t,y)-{\color{red}sg(}\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_t,y'){\color{red})} \|^2+&Distribution\ Adjustment\ loss\ \mathcal{L}_r\\
{\color{red}\gamma} &\|{\color{red}sg(}\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_t,y){\color{red})}-\boldsymbol{\epsilon}_\theta(\boldsymbol{x}_t,y') \|^2{\color{blue})}&+Commitment\ loss\ \mathcal{L}_{rc}
\end{aligned}$$

 

또한 샘플링 클래스 셋인 $\mathcal{Y}$를 선택하는 것 또한 중요한데, 아래와 같은 두 가지 중에 선택해 구현했으며 두가지 모두 잘 동작했다고 합니다. 뒤 실험에서 다시 다룰 예정입니다.

  1. 가장 기본적으로 Label Distribution을 Class-balanced label distribution으로 바꾸어 $\mathcal{Y}^{bal}$을 만들어내는 방법
  2. stabilized training만을 위해 기존보다 상대적으로 적게 class-imbalanced한 distribution으로 만들어주는 방법

결과적으로 알고리즘을 정리하면 아래와 같습니다.

[CBDM 학습 알고리즘]


3. Results 

 

이제 CBDM을 활용했을 때의 결과를 살펴보겠습니다.


a. Setting

 

먼저 실험을 진행하기 위해 한 학습 셋팅은 아래와 같습니다.

  • Dataset CIFAR10CIFAR100CIFAR10LT(Long-tailed 버전), CIFAR100LT(Long-tailed 버전)
    • 이 때 LT버전은 하기 논문의 방법과 같이 exponential한 class 분포를 가지며 이때 imbalance factor는 0.01입니다.
      ** Learning imbalanced datasets with labeldistribution-aware margin loss(NIPS’19)
    • class의 index가 증가할 수록 그 class의 데이터가 적습니다.
  • Model : DDPMEDMDiffusion-ViT
    • DDPM : $\beta_1=10^{-4},\beta_T=0.02, T=1,000$
  • Optimizer : Adam
  • Learning Rate : 초기 0.0002 (5,000 epochs warmup)
  • Epochs : Data 별로 다르게 적용

 

다음으로 평가를 위해 대조군인 baselinemetric은 아래와 같습니다. 

  • DM-based Baseline : DM에 기존 방법들을 적용해 baseline으로 활용합니다.
    • Classic Methods (Re-Sampling)
      • RS(Re-Sampling) : uniform class distribution으로 만들어 사용합니다.
      • RS-SQRT(Soft Resampling Method) : 
class frequency의 sqaure root의 확률로 샘플을 만들어 사용합니다.
    • Augmentation Methods
      • DiffAug : 해당 기법의 default threshold에 따라 학습 이미지에 적용됩니다.
      • ADA : 해당 기법을 이미지 뿐 아니라, augmentation pipelie 자체도 encoding해 condition으로 활용합니다.
  • GAN-based baseline : long-tail distribution을 다룬 GAN들도 baseline으로 활용합니다.
    • CBGAN 
    • gSR(group spectral regularization) GAN 
    • SNGAN 
  • Metric : 다양한 metric을 활용합니다.
    • Diversity & Fidelity : FID 
    • Diversity : Recall, $F_8$
    • Fidelity : IS, $F_{1/8}$

특히 Sampling 과정에서는 CFG가 활용되었는데, guidance strength $\omega$는 아래와 같이 데이터마다 다르게 적용했습니다. 

  •  CIFAR10, : 1.6
  • CIFAR100, : 0.8
  • CIFAR10LT(Long-tailed 버전),  : 1.0
  • CIFAR100LT(Long-tailed 버전) : 0.8

b. Results

 

먼저 정량적인 결과는 아래와 같습니다.

[정량적인 결과 비교]

전체적으로, 직접적인 baseline인 DDPM보다 모든 데이터셋에서 IS를 제외하고 성능이 좋았으며, diversity와 fidelity 모든 면에서 클래스가 많을 수록 더 성능향상이 두드러졌습니다.

 

추가적으로 Re-SamplingAugmentation을 진행했을 때, ADA를 제외하고는 DDPM과 결합했을 때 성능이 오히려 감소되는 현상이 있었습니다. ADA augmentationDDPM, CBDM과 더해졌을 때 모두 성능이 향상되는 것을 확인할 수 있었습니다. 

 

정량적 결과를 보았으니, 이번엔 정성적으로 DDPMCBDM을 아래 그림으로 살펴보겠습니다. 아래 그림은 상대적으로 mild tail class인 62와 확실한 tail class인 94에 대해 생성한 결과를 비교한 결과입니다.

[정성적인 결과 비교]

결과를 보면 CBDM이 더욱 다양한 이미지를 생성하며, 94 class에 대해서는 (b)를 보니 색감과 질감이 더욱 다양한 것을 확인할 수 있습니다. 이와 반대로 DDPM은 학습 데이터와 비슷한 결과만 생성했습니다.

 

이번엔 각 클래스에 대한 condition이 잘 되었는지를 보기 위해 FID 성능 향상치를 각 클래스에 대한 case-by-case로 살펴보겠습니다. 아래 그림을 보면 DDPM, CBDM의 class 40~의 tail class들에서 확실한 향상을 볼 수 있습니다.

[각 클래스 별 FID 감소치]

 

다음으로 label set $\mathcal{Y}$를 어떻게 정하는 것이 좋을지 에 대해 살펴보겠습니다. 앞서 아래와 같이 두개의 셋팅을 했었다는 것을 먼저 상기해봅시다.

 

  1. 가장 기본적으로 Label Distribution을 Class-balanced label distribution으로 바꾸어 $\mathcal{Y}^{bal}$을 만들어내는 방법
  2. stabilized training만을 위해 기존보다 상대적으로 적게 class-imbalanced한 distribution으로 만들어주는 방법

 

이를 참고해 label set $\mathcal{Y}$는 아래와 같이 세가지로 구성합니다.

  • $\mathcal{Y}^{train}$ : 기존 training set과 비슷한 label distribution으로 구성합니다.
  • $\mathcal{Y}^{bal}$ (위 방법1) : 완전히 uniform한 label distribution으로 구성합니다.
  • $\mathcal{Y}^{sqrt}$ (위 방법2) : 기존 class frequency의 sqaure root의 비율로 상대적으로 덜 class-imbalanced한 distribution으로 구성합니다.

또한, 학습 메커니즘 또한 아래와 같이 세가지로 나누어 구성합니다.

  • PT(LT) : 기존 소개된 CBDM방식으로 scratch부터 학습합니다.
  • PT(LT)+FT(LT) : Long-tailed 데이터로 학습된 모델을 Long-tailed 데이터로 fine-tuning하는 방식으로 진행합니다.
  • PT+FT(LT) : 학습된 모델을 적은 Long-tailed 데이터로 fine-tuning하는 방식으로 진행합니다.

위 결과는 아래와 같습니다.

** 아래에서 label set $\mathcal{Y}$에 -로 표시된 것은 DDPM의 결과입니다.

[다양한 label set $\mathcal{Y}$에 대한 학습 결과]

CBDM은 모든 configuration에 대해 성능이 좋았습니다.

 

이 때, $\mathcal{Y}^{bal}$로 label set을 구성하는 것이 diversity($F_8$, Recall)는 더 좋고 fidelity(IS)는 많이 좋지 않았으며, 이는 학습의 stability를 방해했기 때문이고 다른 데이터셋에 적용해보았을 때는 더 안 좋은 결과가 나왔다고 합니다.

 

또한 $\mathcal{Y}^{sqrt}$로 label set을 구성하는 것은 diversity는 유지하면서도 fidelity(IS)가 좋았으며, FID가 가장 좋았다고 합니다.

 

학습 메커니즘의 경우 pre-trained에 fine-tuning을 추가로 하는 것이 더 안정적으로 학습이 가능했으며, $\mathcal{Y}^{bal}$을 사용할 때가 더 큰 성능향상을 보였다고 합니다.

 

 

 

이번엔 ResNet-32이라는 Classification Model의 long-tailed 데이터 학습으로 인한 문제우리가 만든 생성 모델의 생성 데이터로 향상시킬 수 있을지를 살펴보았습니다.

 

아래 표는 DDPMCBDM으로 각 데이터를 augment해 학습한 결과입니다.

[생성 데이터로 Classifier 개선]

예상한 바와 같이 long-tailed 데이터를 그냥 쓰는 것은 굉장히 성능이 떨어졌으며, 생성 데이터로 augmentation을 진행하면 성능이 향상하는 것을 볼 수 있었습니다.

 

특히 CBDM이 더 좋은 성능향상을 보였으며, recall에서 이가 두드려졌는데 이는 생성된 이미지가 DDPM보다 더 diversity를 가짐을 의미합니다.

 

 

 

마지막으로 SOTA 생성 모델 중 대표적으로 long-tailed를 다룬 기법들과 비교해보았습니다. 주로 GAN 기반의 방법이므로, 이들은 학습 데이터가 skewed되어 있을 때 더 심각한 문제가 발생하곤 합니다.

[Long-tailed SOTA 생성 모델과의 비교]


c. Ablation Study

 

이번엔 CBDM모델의 Ablation Study를 진행해보겠습니다.

 

먼저, score-matching 방법 등의 다른 DM backbone을 사용할 때를 살펴보았습니다. 아래는 CIFAR100LT 데이터셋에서 다른 DM BackboneCBDM을 적용할 경우의 결과를 나타냅니다.

[다른 DM Backbone에 CBDM을 적용한 경우]

CBDM은 두 backbone 모두에서 성능향상을 보였으며, 이는 다른 backbone들과의 compatibility를 보임과 동시에 CBDMlong-tailed 데이터를 다룰 수 있는 효율성을 나타냅니다.

 

 

이번엔 sampling할 때 DDPM reverse sampling이나 SDE방법이 아닌, DDIM과 같은 deterministic ODE방법을 사용해도 적용이 가능한지 테스트해보았습니다.

 

아래 표는 DDPM reverse step의 1/10 개수의 DDIM step으로 적용했을 때 성능 변화를 나타냈습니다. 결과적으로, IS를 제외하고 큰 변화가 없었습니다.

[DDIM 적용시 성능 변화]

 

이번엔 regularization weight $\tau$의 영향에 대해 조사해봅니다. 실제 CBDM에서 $\tau_0$는 0.001로 셋팅 되어 모든 step에서 weight가 1.0을 넘지 않게 되어있습니다.

 

아래 그림은 weight의 서로 다른 scale에 대해 FID와 IS 결과를 나타냅니다.

[$\tau$에 따른 weight scale 별 FID와 IS]

결과적으로 weight를 곱했을 때의 scale이 1.0 이하일 때 가장 좋았으며, 이는 즉 $\tau$가 너무 크거나 작으면 안되는 것을 의미한다는 것을 의미합니다. 

 

 

 

또한 추가적으로 commitment loss $\mathcal{L}_{rc}$가 있는 것이 나을지에 대한 테스트도 진행했는데, 안 쓸때는 FID가 8.84이지만 사용하니 8.30으로 확실히 개선되었다고 합니다.

 

 

 

이번엔 sampling 시점에 guidance strength $\omega$에 따른 결과를 살펴보려고 합니다. 아래 그림은 $\omega$를 0.0~2.0으로 변화하며 FID와 IS를 살펴본 결과입니다. 

[guidance scale에 따른 FID와 IS]

 

위 그림에서 DDPM보다 CBDM을 활용할 때, FID는 guidance scale로 인해 FID가 상당히 줄어드는  것을 볼 수 있지만, IS의 경우는 DDPM보다 크게 개선이 잘 안되는 것을 볼 수 있습니다. 하지만 IS도 결국 2.0에서는 같은 값에 도달했습니다.

 

 

 

마지막으로 regularization weight $\tau$와 guidance strength $\omega$ 간의 조율을 살펴봅니다.

 

CFG 논문에 따르면 강한 guidance strength는 overfitting으로 인해 fidelity는 높아지지만 diversity가 떨어진다는 것을 보였습니다. 즉, 둘간의 trade-off가 존재한다는 것이죠.

 

이와 반대로 regularization weight $\tau$는 모델의 class간 정보 교환을 통해 diversity를 향상시키는 기능을 가지고 있습니다.

 

따라서 이를 보기 위한 아래 그림은 body class 53에 대해 다양한 guidance strengthregularization weight에 따른 결과를 정성적으로 보인 결과입니다.

[fidelity와 diversity의 tradeoff 실험 과정]

 

결과에 앞서 DDPM모델은 guidance strength가 늘어나면 realistic한 결과를 생성할 수 있지만, 학습 데이터와 거의 겹치는 이미지를 생성하는 경향이 있습니다.

 

반대로 CBDM에서는 guidance strength가 늘어날 때 image content가 급격하게 변해 guided되지 않은 결과를 내는 것이 아니라, 오히려 기존의 컨텐츠를 해당 class로 가깝게 refine하는 결과를 냈습니다.

이는 즉, CBDM이 un-guided class의 다양성을 내포하고 있기 때문에 효율적으로 guide term이 적용된다는 것을 의미한다고 합니다.


 

 

 

 

 

728x90
반응형