← 문제 목록/Attention Pooling (learnable query) [medium]
문제 해설

Attention Pooling (learnable query) [medium]

신경망 · medium

preview

Attention Pooling [medium]

v1 masked meanuniform 가중치. 실무 센텐스 인코더 (USE, Sentence-T5, Gemma embedding) 는 학습된 query vector 로 토큰을 가중 평균:

outb=lwb,lxb,l,wb=softmax ⁣(xbqTd)mask\text{out}_b = \sum_l w_{b,l} \cdot x_{b,l}, \quad w_{b} = \text{softmax}\!\left( \frac{x_b q^T}{\sqrt{d}} \right)_{\text{mask}}

scaled dot-product attention 에서 query 가 토큰당 하나가 아닌 글로벌 학습 벡터 로 고정된 버전.

구현 포인트

  1. scores = x @ q / sqrt(d) shape (B, L).
  2. Pad 위치에 -inf 대입 → softmax 후 0.
  3. weights = softmax(scores, axis=-1) shape (B, L).
  4. out = weights[..., None] · x 합 → (B, d).

수치 안정성: softmax 전 max 빼기.

왜 이게 강한가

  • 중요한 토큰 (감정어, 핵심 명사) 이 자동으로 더 큰 weight 를 받음.
  • Uniform mean 은 stop-word 가 희석 → 문장 벡터 품질 저하.
  • 단 하나의 벡터 q 만 추가 파라미터 (d 개).

과제

함수 attention_pool(x, mask, q) 를 완성하세요.

  • x shape (B, L, d), mask (B, L) bool, q (d,).
  • 반환 shape (B, d).
  • 전부 pad 인 batch 는 없다고 가정.

테스트 케이스

#이름검증
1shape (B, d)
2pad 제외: pad 값 변경해도 출력 불변
3softmax 가중치 합 = 1 (암묵)직접 검증
4q=0 → masked mean (v1) 과 동일모든 score 0
5다른 q → 다른 출력
6한 토큰이 지배 (큰 score)그 토큰 ≈ 출력
7알려진 toy 값
코드 작성
Loading...
실행 결과

코드를 작성하고 Run 을 눌러보세요.