[PyTorch] 모델의 state_dict 다루기

2024. 7. 14. 01:51Developers 공간 [Shorts]/Software Basic

728x90
반응형
<분류>
A. 수단
- OS/Platform/Tool : Linux, Kubernetes(k8s), Docker, AWS
- Package Manager : node.js, yarn, brew, 
- Compiler/Transpillar : React, Nvcc, gcc/g++, Babel, Flutter

- Module Bundler  : React, Webpack, Parcel

B. 언어
- C/C++, python, Javacsript, Typescript, Go-Lang, CUDA, Dart, HTML/CSS

C. 라이브러리 및 프레임워크 및 SDK
- OpenCV, OpenCL, FastAPI, PyTorch, Tensorflow, Nsight

 


1. What? (현상)

PyTorch 모델을 활용하다보면 모델의 구조와 관련된 state_dict를 접하게 됩니다.

 

이는 torch.nn.Module 형태의 모델의 각 layer마다 Weight, Gradient 같은 Tensor를 dictionary형태로 저장해둔 객체입니다.

 

이를 활용하면 Weight의 일부만 사용할수도 있고, 학습된 모델의 정보를 추출하는데 사용할 수도 있습니다.

 

이번 글에서는 기본적으로 state_dict를 다루는 방법을 살펴보고 이를 통해 Weight가 얼마나 달라졌는지를 비교하는 방법과 학습에 활용하기 위해 불러오는 방법을 살펴보겠습니다.


2. Why? (원인)

  • X

3. How? (해결책)

 

0. 기본적인 state_dict의 형태

 

기본적으로 state_dict는 dictionary인데 "diffusion.pretransform.model.layer1"과 같은 이름으로 되어있습니다.

 

따라서 key-value형태로 state_dictB[key]와 같이 사용합니다.

 

1. state_dict 간 비교하기

 

state_dict가 두개 주어졌을 때, mse를 더한 뒤 평균을 구하는 함수를 소개하겠습니다.

import torch
from torch.nn.functional import mse_loss

def compare_state_dict(state_dictA, state_dictB):
    compared_result = {}
    compared_count = {}
    for key in state_dictA:
        if key in state_dictB :
            if state_dictA[key].shape == state_dictB[key].shape:
                if isinstance(state_dictA[key], torch.nn.Parameter):
                    state_dictA[key] = state_dictA[key].data
                if isinstance(state_dictB[key], torch.nn.Parameter):
                    state_dictB[key] = state_dictB[key].data

                compare_key = '.'.join(key.split('.')[:3])
                if compare_key in compared_result:
                    compared_result[compare_key] =+ mse_loss(state_dictA[key], state_dictB[key])
                    compared_count[compare_key] +=1 
                else:
                    compared_result[compare_key] = mse_loss(state_dictA[key], state_dictB[key])
                    compared_count[compare_key] =1
            else:
                print("[TK][Warning] weight Not match(Size) : {}".format(state_dictA[key].shape)) # Nothing
        else:
            print("[TK][Warning] weight Not match(Key) : {}".format(key))

    for key in compared_result :
        loss = compared_result[key]
        count = compared_count[key]
        loss_mean = loss/count
        # TODO : Summation of MSE but change metric
        print("[TK][Info] Compared : {} ({})".format(loss_mean, key))
        if loss_mean==0.0:
			print("[TK][Warning] Critical!! : {}".format(key))

 

그럼 이 함수를 실행해보겠습니다. state_dict를 비교하기로 했으므로, 아래 모델을 불러오는 get_model_from_ckpt() 함수는 이후에 설명하겠습니다.

targetA="./A/epoch=622-step=160000.ckpt"
targetB = "./B/model.safetensors"
target_config = "./C/config.json"

A, _=get_model_from_ckpt(target_config,targetA)
B, _=get_model_from_ckpt(target_config,targetB)

compare_state_dict(A.state_dict(),B.state_dict())

 

2. state_dict를 활용해 weight를 불러오는 방법

 

PyTorch에는 기존의 체크포인트의 weight를 불러오는 방법이 아래와 같이 이미 있습니다.

  • modelA = torch.load("abc.pt") : .pt 또는 .pth 또는 .ckpt 와 같은 체크포인트를 불러와 모델을 만들어냅니다.
  • modelB.load_state_dict(modelA, strict=False) : 이미 만들어진 모델의 state_dict를 다른 모델에 넣어줍니다.

하지만 아래와 같이 state_dict를 하나하나 불러오는 것으로 구현하면, 일부 weight를 선택해서 불러오게 구현할 수도 있겠죠?

 

create_model_from_config()라는 직접만든 함수를 활용해 config_path를 활용해 어떤 torch.nn.Module로 정의된 모델을 만들 수 있다고 가정하겠습니다.

 

이 때, 위에서 보인 체크포인트들과 .safetensor형태의 weight 두가지의 체크포인트를 활용해 불러오는 두가지를 분리해 구현했습니다.

import torch
import json
from safetensors.torch import load_file

def get_model_from_ckpt(config_path, ckpt_path):        

    with open(config_path) as f:
        model_config = json.load(f)
    model = create_model_from_config(model_config)

    if ckpt_path.endswith(".safetensors"):      
        state_dict = load_file(ckpt_path)
        copy_state_dict(model, state_dict, first_remove=False)
    else:
        state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
        copy_state_dict(model, state_dict, first_remove=True)
		
    return model, model_config

 

위 두 모델에 대해서 first_remove를 다르게 주는 이유는, 필자의 경우 아래와 같이 두가지가 달랐기 때문입니다. 

  • safetensor : safetensor는 저장된 모델이 firstmodule.model.layer...라고 저장되어있다면
  • 체크포인트 : 직접 저장된 체크포인트를 이름을 포함해 mymodel.firstmodule.model.layer...라고 저장되어있어

각각의 경우에 대해 다르게 옵션을 주기 위해 구현했습니다.

 

그럼 state_dict를 복사하는 함수는 아래와 같습니다. 

def copy_state_dict(model, state_dict, first_remove=False, print_remain=False):
    if print_remain:
        remove_test=[]
    model_state_dict = model.state_dict()
    for key in state_dict:
        if first_remove:
            model_key = '.'.join(key.split('.')[1:])
        else:
            model_key=key
        if model_key in model_state_dict :
            if state_dict[key].shape == model_state_dict[model_key].shape:
                if isinstance(state_dict[key], torch.nn.Parameter):
                    state_dict[key] = state_dict[key].data
                model_state_dict[model_key] = state_dict[key]
                if print_remain:
                    remove_test.append(model_key) #temp
            else:
                print("[TK][Warning] weight Not match(Size) : {}".format(state_dict[key].shape)) # Nothing
        else:
            print("[TK][Warning] weight Not match(Key) : {}".format(key))
    if print_remain:
        count = 0
        for key in model_state_dict:
            if key not in remove_test:
                count +=1
                print("[TK][Warning] Remained : {}".format(key))
        if count==0:
            print("[TK][Info] Successfully Loaded")
    model.load_state_dict(model_state_dict, strict=False)

 

그럼 이제 실행해보겠습니다. 위는 local에 모델의 config와 checkpoint path가 주어질 때에 대해 구현했는데, huggingface hub의 이름이 pretrained_name이라는 이름으로 주어졌을 때의 경우도 구현했습니다.

import json
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import torch

if pretrained_name is not None: # Given Huggingface Name
	model_config_path = hf_hub_download(pretrained_name, filename="model_config.json", repo_type='model')
    with open(model_config_path) as f:
        model_config = json.load(f)
    model = create_model_from_config(model_config)
        
    try:
        model_ckpt_path = hf_hub_download(pretrained_name, filename="model.safetensors", repo_type='model')
    except Exception as e:
        model_ckpt_path = hf_hub_download(pretrained_name, filename="model.ckpt", repo_type='model')
        
    if model_ckpt_path.endswith(".safetensors"):
        state_dict = load_file(model_ckpt_path)
    else:
        state_dict = torch.load(model_ckpt_path, map_location="cpu")["state_dict"]
        
    model.load_state_dict(state_dict)

elif model_config_path is not None and ckpt_path is not None: # Given Local path
    model, model_config = get_model_from_ckpt(model_config_path, ckpt_path)

 

여기까지 state_dict를 다루는 것에 대해 설명했는데, 이를 활용해 각각의 state_dict.keys()의 값에 따라 불러오고 싶은 weight만 불러오거나, weight에서 정보를 얻어내는 방법을 다양하게 구현할 수 있습니다.


 

728x90
반응형