← 문제 목록/Orthogonal Initialization (직교 초기화) [medium]
문제 해설

Orthogonal Initialization (직교 초기화) [medium]

신경망 · medium

preview

Orthogonal Initialization [medium]

v1 He init 은 iid Gaussian. 이번엔 직교(orthogonal) 초기화:

WTW=g2I(if fan_infan_out)W^T W = g^2 I \quad (\text{if fan\_in} \geq \text{fan\_out})

Saxe et al. (2013) 이 deep linear nets 분석에서 도입. RNN/Transformer에서 hidden state transition 에 자주 사용.

왜 직교인가

  • 직교 행렬은 singular values 가 모두 1 → 곱셈 시 norm 보존.
  • Deep RNN 에서 gradient explode/vanish 완화 (spectral radius = 1).
  • 활성화 분산이 layer 간 그대로 전파됨.

알고리즘 (QR decomposition)

  1. A = rng.normal(shape=(max(fi,fo), max(fi,fo))) 또는 필요한 직사각형.
  2. Q, R = np.linalg.qr(A).
  3. Sign 보정: Q *= np.sign(np.diag(R)) — QR 유일성 확보.
  4. W = Q[:fan_in, :fan_out] 를 잘라내고 gain 곱.

더 일반적으로:

  • fan_in >= fan_out: A shape (fan_in, fan_out), W = Q
  • fan_in < fan_out: A shape (fan_out, fan_in), W = Q.T (rows orthonormal)

과제

함수 orthogonal_init(fan_in, fan_out, seed, gain=1.0) 를 완성하세요.

  • 반환 shape: (fan_in, fan_out).
  • 직교성: fan_in >= fan_out 이면 W.T @ W = gain^2 · I; 반대면 W @ W.T = gain^2 · I.

테스트 케이스

#이름검증
1shape (fan_in, fan_out)
2직교성 (tall: W^T W = g^2 I)
3직교성 (wide: W W^T = g^2 I)
4gain 스케일링gain=2 → 직교 2배
5시드 재현성
6다른 시드 → 다른 W
7Norm 보존: ‖Wx‖ = g·‖x‖tall 케이스
코드 작성
Loading...
실행 결과

코드를 작성하고 Run 을 눌러보세요.