Support multi-GPU inference for HuggingFaceEmbeddings (#4732)

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/9020/head^2
胡亮 1 year ago committed by GitHub
parent 8aab39e3ce
commit 7edf4ca396
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -47,6 +47,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
"""Key word arguments to pass to the model."""
encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Key word arguments to pass when calling the `encode` method of the model."""
multi_process: bool = False
"""Run encode() on multiple GPUs."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
@ -78,8 +80,16 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
import sentence_transformers
texts = list(map(lambda x: x.replace("\n", " "), texts))
embeddings = self.client.encode(texts, **self.encode_kwargs)
if self.multi_process:
pool = self.client.start_multi_process_pool()
embeddings = self.client.encode_multi_process(texts, pool)
sentence_transformers.SentenceTransformer.stop_multi_process_pool(pool)
else:
embeddings = self.client.encode(texts, **self.encode_kwargs)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
@ -91,9 +101,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text.
"""
text = text.replace("\n", " ")
embedding = self.client.encode(text, **self.encode_kwargs)
return embedding.tolist()
return self.embed_documents([text])[0]
class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):

Loading…
Cancel
Save