
v1 pairwise distances 로 (N, M) 거리 행렬을 얻었다면, 보통 다음 단계는 각 query 에 대해 k 최근접 을 뽑는 것. 전체 argsort 는 per row; argpartition 은 per row — 정확도는 같지만 더 빠름.
D = pairwise Euclidean distances shape (N, M).idx_unsorted = np.argpartition(D, k, axis=1)[:, :k].함수 knn_search(X, Y, k) 를 완성하세요.
X shape (N, d) query, Y shape (M, d) database, k 정수.(idx, dists) 튜플.
idx shape (N, k): 각 row 는 query i 의 k 개 nearest Y 인덱스 (거리 오름차순).dists shape (N, k): 해당 거리 (오름차순).D = pairwise_euclidean(X, Y) # (N, M)
idx_unsorted = np.argpartition(D, k, axis=1)[:, :k]
dist_unsorted = np.take_along_axis(D, idx_unsorted, axis=1)
order = np.argsort(dist_unsorted, axis=1) # 거리 기준 재정렬
idx = np.take_along_axis(idx_unsorted, order, axis=1)
dists = np.take_along_axis(dist_unsorted, order, axis=1)
return idx, dists
| # | 이름 | 검증 |
|---|---|---|
| 1 | shape (N, k) × 2 | |
| 2 | 거리 오름차순 | |
| 3 | k=1 → 최근접 1개 | |
| 4 | X ⊂ Y 일 때 자기 자신이 1등 (거리 0) | |
| 5 | idx ∈ [0, M) | |
| 6 | 손계산 toy |
코드를 작성하고 Run 을 눌러보세요.