diff --git a/libs/langchain/langchain/utils/math.py b/libs/langchain/langchain/utils/math.py index 76c6ed3d8f..ae3c291934 100644 --- a/libs/langchain/langchain/utils/math.py +++ b/libs/langchain/langchain/utils/math.py @@ -46,11 +46,11 @@ def cosine_similarity_top_k( if len(X) == 0 or len(Y) == 0: return [], [] score_array = cosine_similarity(X, Y) - sorted_idxs = score_array.flatten().argsort()[::-1] - top_k = top_k or len(sorted_idxs) - top_idxs = sorted_idxs[:top_k] score_threshold = score_threshold or -1.0 - top_idxs = top_idxs[score_array.flatten()[top_idxs] > score_threshold] - ret_idxs = [(x // score_array.shape[1], x % score_array.shape[1]) for x in top_idxs] - scores = score_array.flatten()[top_idxs].tolist() - return ret_idxs, scores + score_array[score_array < score_threshold] = 0 + top_k = min(top_k or len(score_array), np.count_nonzero(score_array)) + top_k_idxs = np.argpartition(score_array, -top_k, axis=None)[-top_k:] + top_k_idxs = top_k_idxs[np.argsort(score_array.ravel()[top_k_idxs])][::-1] + ret_idxs = np.unravel_index(top_k_idxs, score_array.shape) + scores = score_array.ravel()[top_k_idxs].tolist() + return list(zip(*ret_idxs)), scores # type: ignore