2025. 5. 27. 23:13ㆍDevelopers 공간 [Shorts]/Software Basic
<분류>
A. 수단
- OS/Platform/Tool : Linux, Kubernetes(k8s), Docker, AWS
- Package Manager : node.js, yarn, brew,
- Compiler/Transpillar : React, Nvcc, gcc/g++, Babel, Flutter
- Module Bundler : React, Webpack, Parcel
B. 언어
- C/C++, python, Javacsript, Typescript, Go-Lang, CUDA, Dart, HTML/CSS
C. 라이브러리 및 프레임워크 및 SDK
- OpenCV, OpenCL, FastAPI, PyTorch, Tensorflow, Nsight
1. What? (현상)
이번 글에서는 PyTorch에서 모델이 활용하는 일부 혹은 전체의 Layer를 freeze 하는 방법을 살펴보겠습니다.
Layer를 Freeze하는 방법은 아래와 같이 다섯 가지가 있습니다.
1. detach()
해당 기준 앞의 그래프를 뒤로부터 끊는 방법입니다. 따라서 해당 그래프 앞의 그래프는 아예 autograd에서 분리되기 때문에 gradient가 계산되지 않습니다.
아래와 같이 layer에 적용하는 경우 layer의 output을 기준으로 그래프를 끊지만, Tensor에 적용하면 해당 Tensor를 기준으로 그래프를 끊습니다.
x2 = self.b2(x).detach()
2. requires_grad
이 속성이 False이면 .backward() 시에 gradient가 계산되지 않도록 하는 방법입니다. 모델 학습 파라미터 중 일부만 freeze할 때 주로 사용합니다.
이를 적용하는 방법은 아래와 같이 여러가지가 존재합니다.
# Method1. 선언시
x = torch.randn(5, requires_grad=True)
# Method2. requires_grad 할당
model.b1.linear.weight.requires_grad = True
# Mthod3. requires_grad_ 메소드
model.b2.requires_grad_(True) # 특정 레이어만 활성화
- torch.Tensor : Activation등을 의미하는 텐서로, 위 세가지 모두 가능합니다.
- torch.nn.Parameter : model.layer.weight와 같은 파라미터 값으로, 위 Method2와 Method3 가능합니다.
- torch.nn.Module : model.layer와 같은 레이어 모듈로, 위 Method2와 Method3 가능합니다.
3. torch.no_grad()
해당 블록 내에서 모든 연산이 autograd에 등록되지 않도록 하는 방법입니다. 이때는 위 requires_grad==True인 텐서도 추적하지 않습니다.
with torch.no_grad():
x2 = self.b2(x)
4. torch.inference_mode()
PyTorch 1.10 이후 부터 추가된 기능으로, no_grad() 보다 더 빠르고 메모리 사용이 적습니다.
with torch.inference_mode():
x = self.b3(x)
5. model.eval()
오직 Dropout, BatchNorm 등의 내부 동작을 training 모드와 다르게 기능이 비활성화하는 기능입니다.
model.b3.eval()
그럼 이들이 어떻게 동작하는지 실험해보겠습니다.
2. Why? (원인)
- X
3. How? (해결책)
먼저 실험을 진행한 방법을 살펴보겠습니다. 모델은 아래와 같이 선언해주었습니다.
import torch
import torch.nn as nn
class B(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(10, 10)
def forward(self, x):
return self.linear(x)
class A(torch.nn.Module):
def __init__(self):
super().__init__()
self.b1 = B()
self.b2 = B()
self.b3 = B()
self.dropout = nn.Dropout(p=0.5)
def forward(self, x):
x = self.b1(x)
x = self.b2(x)
x = self.b3(x)
x = self.dropout(x)
return x
이제 아래와 같은 코드를 통해 아래와 같은 다양한 정보를 확인해보겠습니다.
- requires_grad 파라미터가 어떻게 셋팅되었는지
- grad weight가 backward() 이후에 어떻게 계산되었는지
class A(torch.nn.Module):
def __init__(self):
super().__init__()
self.b1 = B()
self.b2 = B()
self.b3 = B()
self.dropout = nn.Dropout(p=0.5)
def forward(self, x):
x = self.b1(x)
print(f"After b1: requires_grad={x.requires_grad}, grad_fn={x.grad_fn}")
print(f"After b1: requires_grad_weight={self.b1.linear.weight.requires_grad}")
x = self.b2(x)
print(f"After b2: requires_grad={x.requires_grad}, grad_fn={x.grad_fn}")
print(f"After b2: requires_grad_weight={self.b2.linear.weight.requires_grad}")
x = self.b3(x)
print(f"After b3: requires_grad={x.requires_grad}, grad_fn={x.grad_fn}")
print(f"After b3: requires_grad_weight={self.b3.linear.weight.requires_grad}")
x = self.dropout(x)
print(f"After b4: requires_grad={x.requires_grad}, grad_fn={x.grad_fn}")
print(f"After b4: requires_grad_weight={x}")
return x
model = A()
#model.b1.linear.weight.requires_grad = False
#model.b2.linear.weight.requires_grad = False
#model.b3.linear.weight.requires_grad= False
#model.b1.requires_grad_(False)
#model.b2.requires_grad_(False)
#model.b3.requires_grad_(False)
x = torch.randn(1, 10)
output = model(x)
print(f"Output: requires_grad={output.requires_grad}, grad_fn={output.grad_fn}")
loss = output.sum()
loss.backward()
for name, param in model.named_parameters():
print(f"{name}: grad={param.grad}")
1. detach() in forward()
아래 실험에서는 주로 레이어를 기준으로 적용해서 실험해보겠습니다.
결과적으로 해당 layer에 적용하는 경우 해당 layer의 output을 기준으로 autograd를 끊는 것을 확인할 수 있습니다.
** 비어있는 곳은 모두 정상 동작했음을 의미합니다.
Apply to | b1 | b2 | b3 | dropout | |
b1(x).detach() |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
False | ||||
b2(x).detach() |
Grad (Weight) |
Weight : None Bias : None |
Weight : None Bias : None |
||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
False | ||||
b3(x).detach() |
Grad (Weight) |
ERROR | ERROR | ERROR | ERROR |
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
False | False | |||
x3= b3(x).detach() x = dropout(x)+x3 |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
|||||
x2= b2(x).detach() x = b3(x)+x2 |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
** ERROR : (.backward() 호출 시) RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
2. requires_grad
nn.Parameter와 nn.Module 두가지에 적용하는 경우를 각각 살펴보겠습니다.
결과적으로 두가지 모두 requires_grad==False로 만들어주면 "그 weight는 미분하지 않겠다"는 의미일 뿐, forward 연산을 통해 나오는 출력 텐서의 추적 여부를 결정하지는 않습니다.
따라서 그래프는 그대로 생성되고 해당 파라미터만 gradient를 저장하지 않습니다.
다만 nn.Module을 한 경우에는 내부에 있는 모든 모듈의 파라미터들을 Disable해주었네요.
a. nn.Parameter 에만 requires_grad 적용
Apply to | b1 | b2 | b3 | dropout | |
b1.weight.requires_grad=False |
Grad (Weight) |
Weight : None |
|||
Requires_grad (Weight) |
False | 없음 | |||
Requires_grad (Output) |
|||||
b2.weight.requires_grad=False |
Grad (Weight) |
Weight : None | |||
Requires_grad (Weight) |
False | 없음 | |||
Requires_grad (Output) |
|||||
b3.weight.requires_grad=False |
Grad (Weight) |
Weight : None | |||
Requires_grad (Weight) |
False | 없음 | |||
Requires_grad (Output) |
False | ||||
x3= b3(x) x = dropout(x)+x3 b3.weight.requires_grad=False |
Grad (Weight) |
Weight : None | |||
Requires_grad (Weight) |
False | 없음 | |||
Requires_grad (Output) |
|||||
x2= b2(x) x = b3(x)+x2 b2.weight.requires_grad=False |
Grad (Weight) |
Weight : None | |||
Requires_grad (Weight) |
False | 없음 | |||
Requires_grad (Output) |
b. nn.Module 전체에 requires_grad 적용
Apply to | b1 | b2 | b3 | dropout | |
b1.requires_grad=False |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
False | 없음 | |||
Requires_grad (Output) |
False | ||||
b2.requires_grad=False |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
False | 없음 | |||
Requires_grad (Output) |
|||||
b3.requires_grad=False |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
False | 없음 | |||
Requires_grad (Output) |
False | ||||
3= b3(x).requires_grad=False x = dropout(x)+x3 |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
False | 없음 | |||
Requires_grad (Output) |
|||||
x2= b2(x).requires_grad=False x = b3(x)+x2 |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
False | 없음 | |||
Requires_grad (Output) |
3. torch.no_grad() in forward()
결과적으로 requires_grad가 해당 파라미터만 Gradient를 저장하지 않도록 했던 것과 다르게 그래프 자체를 아예 끊어버립니다.
즉, 출력 텐서의 추적 여부도 결정함으로써 graph를 아예 끊어버리게 됩니다.
Apply to | b1 | b2 | b3 | dropout | |
[no_grad()] b1(x) |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
False | ||||
[no_grad()] b2(x) |
Grad (Weight) |
Weight : None Bias : None |
Weight : None Bias : None |
||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
False | ||||
[no_grad()] b3(x) |
Grad (Weight) |
ERROR | ERROR | ERROR | ERROR |
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
False | False | |||
[no_grad()] x3= b3(x) x = dropout(x)+x3 |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
|||||
[no_grad()] x2= b2(x) x = b3(x)+x2 |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
** ERROR : (.backward() 호출 시) RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
4. torch.inference_mode() in forward()
이는 no_grad의 개선버전이므로 no_grad와 비슷하게 gradient의 연산 추적 자체를 끊어버립니다.
근데 결과를 보면 조금 더 엄격한 것 같습니다.
Apply to | b1 | b2 | b3 | dropout | |
[inference()] b1(x) |
Grad (Weight) |
||||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
False | ERROR2 | ERROR2 | ERROR2 | |
[inference()] b2(x) |
Grad (Weight) |
||||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
False | ERROR2 | ERROR2 | ||
[inference()] b3(x) |
Grad (Weight) |
ERROR | ERROR | ERROR | ERROR |
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
False | False | |||
[inference()] x3= b3(x) x = dropout(x)+x3 |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
|||||
[inference()] x2= b2(x) x = b3(x)+x2 |
Grad (Weight) |
Weight : None Bias : None |
|||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
** ERROR : (.backward() 호출 시) RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
** ERROR2 : (.forward() 호출 시) RuntimeError: Inference tensors cannot be saved for backward. To work around you can make a clone to get a normal tensor and use it in autograd
5. eval()
해당 레이어를 eval()처리를 해주면 Dropout, BatchNorm 등의 레이어를 비활성화해주는 것을 확인할 수 있습니다.
따라서 다른 레이어는 영향을 안받지만 dropout는 적용하는 경우 비활성화되는 것을 확인할 수 있습니다.
** 비어있는 곳은 모두 정상 동작했음을 의미합니다.
Apply to | b1 | b2 | b3 | dropout | |
b1.eval() |
Dropout | ||||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
|||||
b2.eval() |
Dropout | ||||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
|||||
b3.eval() |
Dropout | ||||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
|||||
dropout.eval() |
Dropout | 적용되지 않음 | |||
Requires_grad (Weight) |
없음 | ||||
Requires_grad (Output) |
https://docs.pytorch.org/docs/stable/notes/autograd.html
https://docs.pytorch.org/docs/1.9.0/notes/autograd.html#inference-mode
'Developers 공간 [Shorts] > Software Basic' 카테고리의 다른 글
[Python] parameter & argument 몇가지 특징 기록 (0) | 2025.05.01 |
---|---|
[Git] git에서 add/commit/push를 했는데 다시 돌리고 싶다 (0) | 2025.04.29 |
[PyTorch] Tensor와의 기본연산시 Broadcasting 문제 (0) | 2025.04.29 |
[Bash] 여러개의 Disk를 모아서 mount하기 (0) | 2025.04.28 |
[PyTorch] scaled dot product에서 Attention Map 디스플레이 하기 (0) | 2025.04.10 |