|
|
|
@ -47,6 +47,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
"""Key word arguments to pass to the model."""
|
|
|
|
|
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
|
"""Key word arguments to pass when calling the `encode` method of the model."""
|
|
|
|
|
multi_process: bool = False
|
|
|
|
|
"""Run encode() on multiple GPUs."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, **kwargs: Any):
|
|
|
|
|
"""Initialize the sentence_transformer."""
|
|
|
|
@ -78,8 +80,16 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
Returns:
|
|
|
|
|
List of embeddings, one for each text.
|
|
|
|
|
"""
|
|
|
|
|
import sentence_transformers
|
|
|
|
|
|
|
|
|
|
texts = list(map(lambda x: x.replace("\n", " "), texts))
|
|
|
|
|
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
|
|
|
|
if self.multi_process:
|
|
|
|
|
pool = self.client.start_multi_process_pool()
|
|
|
|
|
embeddings = self.client.encode_multi_process(texts, pool)
|
|
|
|
|
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
|
|
|
|
|
else:
|
|
|
|
|
embeddings = self.client.encode(texts, **self.encode_kwargs)
|
|
|
|
|
|
|
|
|
|
return embeddings.tolist()
|
|
|
|
|
|
|
|
|
|
def embed_query(self, text: str) -> List[float]:
|
|
|
|
@ -91,9 +101,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
Returns:
|
|
|
|
|
Embeddings for the text.
|
|
|
|
|
"""
|
|
|
|
|
text = text.replace("\n", " ")
|
|
|
|
|
embedding = self.client.encode(text, **self.encode_kwargs)
|
|
|
|
|
return embedding.tolist()
|
|
|
|
|
return self.embed_documents([text])[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
|
|
|
|
|