본문 바로가기
[GPUaaS]/GPUmgt

[분산 학습] NaN 반드시 알아야 하는 개념 !!

by METAVERSE STORY 2026. 2. 21.
반응형

 

 

 


1️⃣ NaN 이란?

NaN = Not a Number
👉 “숫자가 아니다” 라는 뜻입니다.

컴퓨터에서 계산 결과가 정상적인 숫자로 표현될 수 없을 때 나오는 값입니다.

예시:

 
 
0 / 0
sqrt(-1)   # 실수 범위에서
log(-5)
 

이런 연산을 하면 결과가 NaN이 됩니다.

 

 


2️⃣ 딥러닝 학습에서 NaN이 의미하는 것

학습 중 NaN이 발생한다는 건:

🔥 모델 계산이 망가졌다는 의미

주로 이런 상황입니다:

  • Loss가 갑자기 nan
  • gradient가 nan
  • weight 값이 nan
  • 전체 학습이 멈춤 (watchdog 발생 가능)

 

 


3️⃣ 대규모 노드 분산 학습에서 왜 더 위험한가?

노드가 많을수록 위험도가 커집니다.

이유

  1. 하나의 GPU에서 NaN 발생
  2. AllReduce 통신으로 전체 노드에 전파
  3. 64노드 전체가 오염
  4. 학습 즉시 중단

특히:

  • NVIDIA NCCL AllReduce
  • PyTorch DDP
  • TensorFlow MirroredStrategy

같은 분산 통신 구조에서는 NaN이 전염됩니다.

 

 


4️⃣ NaN이 발생하는 주요 원인

초보자 기준 가장 많이 발생하는 순서입니다.


① Learning Rate 너무 큼 (가장 흔함)

학습률이 너무 크면:

  • weight 값이 폭발
  • gradient explosion
  • 결국 NaN

특히 64노드에서 batch size 증가 시 LR scaling 잘못하면 발생합니다.


② Gradient Explosion

역전파 시 gradient 값이 너무 커지는 경우

 
 
1e3 → 1e6 → 1e12 → inf → nan
 

 


③ FP16 / BF16 혼합정밀 문제

AMP 사용 시:

  • underflow
  • overflow

loss scaling 실패 시 NaN 발생


④ 데이터 문제

  • 입력값에 이미 NaN 포함
  • 너무 큰 값 (예: 1e30)
  • 0으로 나누는 연산 발생

⑤ 로그 / 제곱근 수식 문제

예:

 
 
log(0)
sqrt(negative)
 

 

 


5️⃣ NaN이 발생하면 어떤 증상이 보이나?

학습 로그 예시

 
 
loss: 2.31
loss: 1.89
loss: 0.94
loss: nan
loss: nan
loss: nan
 

또는

 
 
RuntimeError: Function 'XXXBackward' returned nan values
 

또는

 
 
CUDA error: device-side assert triggered
 

 

 


6️⃣ 왜 Watchdog이 걸릴 수도 있나?

64노드에서:

  1. 한 노드에서 NaN 발생
  2. 통신 sync 대기
  3. 다른 노드는 정상값 계산
  4. AllReduce mismatch
  5. 일부 프로세스 hang
  6. watchdog timeout

그래서 NaN → 통신 스턱 → 워치독 흐름이 자주 나옵니다.

 

 


7️⃣ 상황 해석

"대규모 노드로 학습 중이며 아직 NaN 안 나옴"

이건 매우 좋은 신호입니다 ✅

즉:

  • LR 스케일링 정상
  • AMP 안정적
  • 통신 문제 없음
  • 데이터 정합성 OK

 

 


8️⃣ NaN 사전 체크 방법 (실전용)

PyTorch에서 체크

 
 
 
if torch.isnan(loss):
    print("NaN detected")

 

 

Gradient 체크

 
 
for p in model.parameters():
    if torch.isnan(p.grad).any():
        print("NaN in gradient")
 

 

 


9️⃣ 대규모 노드 환경에서 권장 안정화 방법

✔ Learning rate 재검증
✔ Grad clipping 적용
✔ AMP 사용 시 dynamic loss scaling
✔ input normalize 확인
✔ 처음 몇 step 모니터링 강화

 

 


🔟 정리 (초보자용 한줄 요약)

NaN은 “계산이 터졌다”는 신호이며
분산 학습에서는 한 GPU의 NaN이 전체 노드를 멈출 수 있다.

 

 

 

반응형

댓글