[PyTorch] scaled dot product에서 Attention Map 디스플레이 하기

2025. 4. 10. 00:07Developers 공간 [Shorts]/Software Basic

728x90
반응형
<분류>
A. 수단
- OS/Platform/Tool : Linux, Kubernetes(k8s), Docker, AWS
- Package Manager : node.js, yarn, brew, 
- Compiler/Transpillar : React, Nvcc, gcc/g++, Babel, Flutter

- Module Bundler  : React, Webpack, Parcel

B. 언어
- C/C++, python, Javacsript, Typescript, Go-Lang, CUDA, Dart, HTML/CSS

C. 라이브러리 및 프레임워크 및 SDK
- OpenCV, OpenCL, FastAPI, PyTorch, Tensorflow, Nsight

 


1. What? (현상)

 

Text-To-Image Diffusion 논문을 보다보면 feature를 visualize해서 네트워크 내 어떤 feature에 집중해서 처리하고 있는지를 확인하는 과정이 있기도 합니다.

 

예를 들어 아래 같은 경우는 각 decoder layer에서 나온 spatial feature들을 Top-3 leading components에 대해 PCA를 적용해 실제 이미지와 비교한 결과입니다. 이는 각 layer가 어떤 feature에 집중하고 있는지를 살펴볼 수 있겠죠.

** Plug-and-play diffusion features for text-driven image-to-image translation(CVPR'23)

[Decoder Layer feature 분석]

 

아래는 Self-Attention Map을 display해서 layer 별로 어떤 feature에 집중하고 있는지를 살펴보는 과정입니다.

** Visual Style Prompting with Swapping Self-Attention (arxiv'24)

[Self-Attention Map 분석]

 

위에서는 각 layer에서 나온 feature 자체를 살펴보고 어떤 것에 집중하는지를 살펴본다면, 아래 같은 경우 Text Condition에 따른 Cross-attention map을 비교하는 방법을 통해, Condition이 어디에 집중해서 생성하게 되는지를 살펴보기도 합니다.

** Prompt-to-prompt image editing with cross attention control(arxiv'22)

[Cross-Attention Map 분석]

 

이번 글에서는 위와 같은 Cross Attention map을 display하기 위해 scaled_dot_product_attention()함수를 직접 customize하는 방법을 보이려고 합니다.


2. Why? (원인)

  • X

3. How? (해결책)

 

아시다 시피 Attention은 아래와 같습니다.

$$Attention(Q,K,V)={\color{red}softmax(\frac{QK^\top}{\sqrt{d_k}})}V$$

이중 attention map은 $softmax(\frac{QK^\top}{\sqrt{d_k}})$입니다. 즉, Key와 Query간의 관계에 관련된 Weight Map이죠.

 

일단 기존에 아래와 같이 Key, Query, Value들에 대한 위 Attention결과는 아래와 같이 구성될 것입니다.

import torch.nn.functional as F
out = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=mask,
    is_causal=causal
)

 

이를 직접 만든 custom_scaled_dot_product_attention()으로 대체할 계획입니다.

out = custom_scaled_dot_product_attention(
     q, k, v, 
     attn_mask=mask,
     is_causal=causal
 )

 

그럼 이번엔 custom_scaled_dot_product_attention()함수를 기존과 똑같이 기능하도록 구현해보겠습니다.

import torch
import torch.nn.functional as F
import typing as tp

def custom_scaled_dot_product_attention(q, k, v, 
                    attn_mask=None, 
                    dropout_p=0.0, 
                    is_causal=False):
    """
    q : (batch, num_heads, seq_lenA, head_dim)
    k, v : (batch, num_heads, seq_lenB, head_dim)
    attn_mask: (batch, seq_len, seq_len)
    """

    d_k = q.size(-1)
    # 1. QK^T / sqrt(d_k)
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=q.dtype, device=q.device))

    # 2. causal mask 
    if is_causal:
        seq_len = q.size(-2)
        causal_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=q.device))
        causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)  # (..., 1, seq_len, seq_len)
        attn_scores = attn_scores.masked_fill(~causal_mask, float('-inf'))

    # 3. attention mask 
    if attn_mask is not None:
        attn_scores += attn_mask

    # 4. softmax (attention weights)
    attn_weights = F.softmax(attn_scores, dim=-1)

    # 5. dropout
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)

    # 6. attention output
    output = torch.matmul(attn_weights, v)

    return output

 

이제 특정상황에 앞서 구했던 attn_weights를 저장하도록 위 함수 내에 함수를 만들어보겠습니다.

아래는 save_ix라는 layer의 index를 활용해서 주어지면 해당 layer의 weight map을 저장하도록 하려고 합니다.

이 때 사이즈는 아래와 같을 것입니다.

  • query의 사이즈 : (batch, num_heads, seq_lenA, head_dim)
  • key와 value의 사이즈 : (batch, num_heads, seq_lenB, head_dim)
  • attn_weights의 사이즈 : (batch, num_heads, seq_lenA, seq_lenB)
import torch
import torch.nn.functional as F
import typing as tp

def custom_scaled_dot_product_attention(q, k, v, 
                    attn_mask=None, 
                    dropout_p=0.0, 
                    is_causal=False, 
                    save_ix=-1, 
                    time_step=-1,
                    mode:tp.Literal["tensor", "numpy", "heatmap", "pca"]="pca" ):

    # ...

    if save_ix !=-1:
        save_tensors(attn_weights, save_ix, time_step, mode=mode)

    return output

 

save_tensors()를 한번 보겠습니다. 첫번째 batch에 대해서만 적용할 것이고, 아래와 같은 4가지 버전의 상황이 있습니다.

  • tensor 모드 : tensor객체를 저장
  • numpy 모드 : numpy객체를 저장
  • heatmap 모드 : [query, key] weight를 query를 기준으로 key값으로 그대로 나타내보려고 합니다.
    ** head간에는 평균값을 사용하려고 합니다.
  • pca 모드 : [query, key] weight를 query를 기준으로 key값들의 pca결과를 나타내보려고 합니다.
    ** head간에는 평균값을 사용하려고 합니다.
from pathlib import Path
import numpy as np

def save_tensors(attn_weights, save_ix, time_step, mode:tp.Literal["tensor", "numpy", "heatmap", "pca"]="pca"):
    time_step_pos = time_step[0]
    time_step_pos = int(time_step_pos * 10000.0)

    b_size, h_num, q_len, k_len = attn_weights.shape

    # [B, #heads, query_len, key_len] > [#heads, query_len, key_len] > [query_len, key_len]
    attn_weights_pos_avg = attn_weights.mean(dim=1)[0].detach().cpu()

    Path("./output_test").mkdir(parents=True, exist_ok=True)
    if mode=="tensor":
        filename = f'output_test/crossattn_tensor_{time_step_pos:04d}_layer{save_ix:03d}.pt'
        torch.save(attn_weights_pos_avg, filename)
        return 
    attn_weights_pos_avg = attn_weights_pos_avg.numpy() 
    if mode=="numpy":
        filename = f'output_test/crossattn_numpy_{time_step_pos:04d}_layer{save_ix:03d}.npy'
        np.save(filename, attn_weights_pos_avg)
    elif mode=="heatmap":
        filename = f'output_test/crossattn_heatmap_{time_step_pos:04d}_layer{save_ix:03d}.png'
        title = f'Mean AttentionMap : Heatmap (step 0.{time_step_pos:04d}, layer {save_ix:02d})'
        save_attention_heatmap(attn_weights_pos_avg, save_path=filename, title=title)
    elif mode=="pca":
        filename = f'output_test/crossattn_pca_{time_step_pos:04d}_layer{save_ix:03d}.png'
        title = f'Mean AttentionMap : Text PCA-Reduced (step 0.{time_step_pos:04d}, layer {save_ix:02d})'
        save_attention_pca(attn_weights_pos_avg, save_path=filename, title=title)
더보기

-------------------------------------------------------------

<다른 PyTorch 객체를 저장하기>

 

위에서 구현한 save_tensors()를 기존 latent에 적용해서 저장할 수도 있습니다.

from einops import rearrange
from .common.tk_attention import save_tensors

# latents : [Batch, latents, query]
latents_save = rearrange(latents, 'b l q -> b 1 q l')
save_tensors(latents_save, 100, torch.tensor([0.0]))

------------------------------------------------------------- 

 

save_attention_heatmap()save_attention_pca()는 각각 heatmap과 pca를 저장하는 함수로 matplotlib를 활용해 구현하면 됩니다.

아래는 예시로 필자가 구현한 함수들을 보입니다. 필자는 오디오와 텍스트에 대한 Attention Map에 적용해보았습니다.

import matplotlib.pyplot as plt

def save_attention_heatmap(attn_map, save_path="attn_heatmap.png", title="Mean AttentionMap : Heatmap"):
    """
    attn_map: (q_len, k_len)
    Saves mean attention heatmap as PNG.
    """
    q_len, k_len = attn_map.shape

    attn_map = attn_map.T # [k, q]

    plt.figure(figsize=(10, 4))
    plt.title(title)

    plt.imshow(attn_map, cmap='magma', aspect="auto")
    plt.colorbar(label="Attention Weight")

    plt.xlabel("Query(Audio Token) Index", fontsize=10)
    plt.ylabel("Key Text Token Index", fontsize=10)
    plt.ylim((0, k_len))
    plt.xlim((0, q_len))

    plt.tight_layout()
    plt.savefig(save_path, dpi=150)
    plt.close()

[heatmap 예시 결과]

 

import matplotlib.pyplot as plt
from sklearn.decomposition import PCA

def save_attention_pca(attn_map, save_path="attn_pca_heatmap.png", title="Mean AttentionMap : Text PCA-Reduced"):
    """
    attn_map: (q_len, k_len)
    Saves PCA-reduced attention map as PNG.
    """

    q_len, k_len = attn_map.shape

    # PCA over keys for each query
    pca = PCA(n_components=1)
    reduced = pca.fit_transform(attn_map)[:,0] # (q, 1)

    plt.figure(figsize=(10, 4))
    plt.plot(range(q_len), reduced, color='tomato', linewidth=2)
    plt.title(title)

    plt.axhline(0, color='black', linestyle='--', linewidth=1)
    plt.grid(alpha=0.3)

    plt.xlabel("Query(Audio Token) Index")
    plt.ylabel("Key Text 1st Principal Component")
    plt.ylim((-0.5, 0.5))
    plt.xlim((0, q_len))

    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

[PCA 예시 결과]

 

 


 

728x90
반응형