← 문제 목록/Weighted Mean Pool (attention pool 기본)
문제 해설

Weighted Mean Pool (attention pool 기본)

NLP · easy

preview

Weighted Mean Pool

98번 Masked Mean Pool 은 mask {0, 1} 로 유효 위치만 평균. Weighted 버전은 임의의 실수 weight (학습 가능, 예: attention 가중치) 로 평균:

pool(x,w)b=lwb,lxb,l,:lwb,l\text{pool}(x, w)_b = \frac{\sum_{l} w_{b,l} \cdot x_{b,l,:}}{\sum_l w_{b,l}}

  • x shape (B, L, d), w shape (B, L) 양의 실수.
  • 마스크는 w{0,1}w \in \{0, 1\} 의 특수 경우.
  • Attention pool: w=softmax-v1(s)w = \text{softmax-v1}(s) 로 만들어 쓰면 그대로 attention pooling.

어디에 쓰이나

  • Sentence BERT pooling 변형.
  • Graph Neural Network 의 weighted readout.
  • Set Transformer 의 attention pool.

수치 안정

  • ww 가 전부 0 이면 0 나누기 발생 → clip(min=1e-9).
  • ww 가 negative 도 이론적으로 가능하지만 이 문제에선 w0w \ge 0 가정.

과제

함수 weighted_mean_pool(x, w) 를 완성하세요.

  • x (B, L, d), w (B, L), 반환 (B, d).
  • 힌트: num = (x * w[..., None]).sum(axis=1); den = w.sum(axis=1, keepdims=True).clip(min=1e-9).

테스트 케이스

#이름검증
1shape (B, d)
2w 전부 1 → 평범한 mean
3one-hot w → 해당 위치 그대로w[i]=1, 나머지 0x[:,i,:]
4softmax-v1 w → attention pool
5w 스케일 불변w*2 → 동일 결과
코드 작성
Loading...
실행 결과

코드를 작성하고 Run 을 눌러보세요.