|
|
@ -4,6 +4,7 @@ https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb"""
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import concurrent.futures
|
|
|
|
from typing import Any, List, Optional
|
|
|
|
from typing import Any, List, Optional
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
import numpy as np
|
|
|
@ -14,7 +15,8 @@ from langchain.schema import BaseRetriever, Document
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
|
|
|
|
def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray:
|
|
|
|
return np.array([embeddings.embed_query(split) for split in contexts])
|
|
|
|
with concurrent.futures.ThreadPoolExecutor() as executor:
|
|
|
|
|
|
|
|
return np.array(list(executor.map(embeddings.embed_query, contexts)))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SVMRetriever(BaseRetriever, BaseModel):
|
|
|
|
class SVMRetriever(BaseRetriever, BaseModel):
|
|
|
|