From 32ba8cfab0694870996cb4c8c45d7a1adfb681ba Mon Sep 17 00:00:00 2001 From: JonZeolla Date: Wed, 12 Jun 2024 13:30:56 -0400 Subject: [PATCH] community[minor]: implement huggingface show_progress consistently (#22682) - **Description:** This implements `show_progress` more consistently (i.e. it is also added to the `HuggingFaceBgeEmbeddings` object). - **Issue:** This implements `show_progress` more consistently in the embeddings huggingface classes. Previously this could have been set via `encode_kwargs`. - **Dependencies:** None - **Twitter handle:** @jonzeolla --- .../embeddings/huggingface.py | 60 +++++++++++++++++-- 1 file changed, 54 insertions(+), 6 deletions(-) diff --git a/libs/community/langchain_community/embeddings/huggingface.py b/libs/community/langchain_community/embeddings/huggingface.py index 8562e0a5df..f5aedbebb6 100644 --- a/libs/community/langchain_community/embeddings/huggingface.py +++ b/libs/community/langchain_community/embeddings/huggingface.py @@ -1,7 +1,8 @@ +import warnings from typing import Any, Dict, List, Optional import requests -from langchain_core._api import deprecated +from langchain_core._api import deprecated, warn_deprecated from langchain_core.embeddings import Embeddings from langchain_core.pydantic_v1 import BaseModel, Extra, Field, SecretStr @@ -154,6 +155,8 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): """Instruction to use for embedding documents.""" query_instruction: str = DEFAULT_QUERY_INSTRUCTION """Instruction to use for embedding query.""" + show_progress: bool = False + """Whether to show a progress bar.""" def __init__(self, **kwargs: Any): """Initialize the sentence_transformer.""" @@ -167,6 +170,20 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): except ImportError as e: raise ImportError("Dependencies for InstructorEmbedding not found.") from e + if "show_progress_bar" in self.encode_kwargs: + warn_deprecated( + since="0.2.5", + removal="0.4.0", + name="encode_kwargs['show_progress_bar']", + alternative=f"the show_progress method on {self.__class__.__name__}", + ) + if self.show_progress: + warnings.warn( + "Both encode_kwargs['show_progress_bar'] and show_progress are set;" + "encode_kwargs['show_progress_bar'] takes precedence" + ) + self.show_progress = self.encode_kwargs.pop("show_progress_bar") + class Config: """Configuration for this pydantic object.""" @@ -182,7 +199,11 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): List of embeddings, one for each text. """ instruction_pairs = [[self.embed_instruction, text] for text in texts] - embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs) + embeddings = self.client.encode( + instruction_pairs, + show_progress_bar=self.show_progress, + **self.encode_kwargs, + ) return embeddings.tolist() def embed_query(self, text: str) -> List[float]: @@ -195,7 +216,11 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings): Embeddings for the text. """ instruction_pair = [self.query_instruction, text] - embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0] + embedding = self.client.encode( + [instruction_pair], + show_progress_bar=self.show_progress, + **self.encode_kwargs, + )[0] return embedding.tolist() @@ -252,6 +277,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): """Instruction to use for embedding query.""" embed_instruction: str = "" """Instruction to use for embedding document.""" + show_progress: bool = False + """Whether to show a progress bar.""" def __init__(self, **kwargs: Any): """Initialize the sentence_transformer.""" @@ -268,9 +295,24 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): self.client = sentence_transformers.SentenceTransformer( self.model_name, cache_folder=self.cache_folder, **self.model_kwargs ) + if "-zh" in self.model_name: self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH + if "show_progress_bar" in self.encode_kwargs: + warn_deprecated( + since="0.2.5", + removal="0.4.0", + name="encode_kwargs['show_progress_bar']", + alternative=f"the show_progress method on {self.__class__.__name__}", + ) + if self.show_progress: + warnings.warn( + "Both encode_kwargs['show_progress_bar'] and show_progress are set;" + "encode_kwargs['show_progress_bar'] takes precedence" + ) + self.show_progress = self.encode_kwargs.pop("show_progress_bar") + class Config: """Configuration for this pydantic object.""" @@ -286,7 +328,9 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): List of embeddings, one for each text. """ texts = [self.embed_instruction + t.replace("\n", " ") for t in texts] - embeddings = self.client.encode(texts, **self.encode_kwargs) + embeddings = self.client.encode( + texts, show_progress_bar=self.show_progress, **self.encode_kwargs + ) return embeddings.tolist() def embed_query(self, text: str) -> List[float]: @@ -300,7 +344,9 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings): """ text = text.replace("\n", " ") embedding = self.client.encode( - self.query_instruction + text, **self.encode_kwargs + self.query_instruction + text, + show_progress_bar=self.show_progress, + **self.encode_kwargs, ) return embedding.tolist() @@ -353,7 +399,9 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings): Example: .. code-block:: python - from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings + from langchain_community.embeddings import ( + HuggingFaceInferenceAPIEmbeddings, + ) hf_embeddings = HuggingFaceInferenceAPIEmbeddings( api_key="your_api_key",