From 7edf4ca396739588d9053dc15a8d2f649bad6060 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E4=BA=AE?= <1271446412@qq.com> Date: Sat, 12 Aug 2023 06:55:44 +0800 Subject: [PATCH] Support multi gpu inference for HuggingFaceEmbeddings (#4732) Co-authored-by: Bagatur --- .../langchain/embeddings/huggingface.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/libs/langchain/langchain/embeddings/huggingface.py b/libs/langchain/langchain/embeddings/huggingface.py index b2d4a18333..72243705d2 100644 --- a/libs/langchain/langchain/embeddings/huggingface.py +++ b/libs/langchain/langchain/embeddings/huggingface.py @@ -47,6 +47,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): """Key word arguments to pass to the model.""" encode_kwargs: Dict[str, Any] = Field(default_factory=dict) """Key word arguments to pass when calling the `encode` method of the model.""" + multi_process: bool = False + """Run encode() on multiple GPUs.""" def __init__(self, **kwargs: Any): """Initialize the sentence_transformer.""" @@ -78,8 +80,16 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): Returns: List of embeddings, one for each text. """ + import sentence_transformers + texts = list(map(lambda x: x.replace("\n", " "), texts)) - embeddings = self.client.encode(texts, **self.encode_kwargs) + if self.multi_process: + pool = self.client.start_multi_process_pool() + embeddings = self.client.encode_multi_process(texts, pool) + sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool) + else: + embeddings = self.client.encode(texts, **self.encode_kwargs) + return embeddings.tolist() def embed_query(self, text: str) -> List[float]: @@ -91,9 +101,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text. """ - text = text.replace("\n", " ") - embedding = self.client.encode(text, **self.encode_kwargs) - return embedding.tolist() + return self.embed_documents([text])[0] class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):