community[minor]: implement huggingface show_progress consistently (#22682)

- **Description:** Adds a `show_progress` field to both `HuggingFaceInstructEmbeddings`
and `HuggingFaceBgeEmbeddings`, so the option is exposed more consistently across the Hugging Face embeddings classes.
- **Issue:** `show_progress` was not implemented consistently in the Hugging Face
embeddings classes. Previously, the progress bar could be set via
`encode_kwargs['show_progress_bar']`; that key is now deprecated in favor of `show_progress`.
 - **Dependencies:** None
 - **Twitter handle:** @jonzeolla
pull/22831/head
JonZeolla 3 weeks ago committed by GitHub
parent 74e705250f
commit 32ba8cfab0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,7 +1,8 @@
import warnings
from typing import Any, Dict, List, Optional
import requests
from langchain_core._api import deprecated
from langchain_core._api import deprecated, warn_deprecated
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, SecretStr
@ -154,6 +155,8 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
"""Instruction to use for embedding documents."""
query_instruction: str = DEFAULT_QUERY_INSTRUCTION
"""Instruction to use for embedding query."""
show_progress: bool = False
"""Whether to show a progress bar."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
@ -167,6 +170,20 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
except ImportError as e:
raise ImportError("Dependencies for InstructorEmbedding not found.") from e
if "show_progress_bar" in self.encode_kwargs:
warn_deprecated(
since="0.2.5",
removal="0.4.0",
name="encode_kwargs['show_progress_bar']",
alternative=f"the show_progress method on {self.__class__.__name__}",
)
if self.show_progress:
warnings.warn(
"Both encode_kwargs['show_progress_bar'] and show_progress are set;"
"encode_kwargs['show_progress_bar'] takes precedence"
)
self.show_progress = self.encode_kwargs.pop("show_progress_bar")
class Config:
"""Configuration for this pydantic object."""
@ -182,7 +199,11 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
List of embeddings, one for each text.
"""
instruction_pairs = [[self.embed_instruction, text] for text in texts]
embeddings = self.client.encode(instruction_pairs, **self.encode_kwargs)
embeddings = self.client.encode(
instruction_pairs,
show_progress_bar=self.show_progress,
**self.encode_kwargs,
)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
@ -195,7 +216,11 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
Embeddings for the text.
"""
instruction_pair = [self.query_instruction, text]
embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0]
embedding = self.client.encode(
[instruction_pair],
show_progress_bar=self.show_progress,
**self.encode_kwargs,
)[0]
return embedding.tolist()
@ -252,6 +277,8 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
"""Instruction to use for embedding query."""
embed_instruction: str = ""
"""Instruction to use for embedding document."""
show_progress: bool = False
"""Whether to show a progress bar."""
def __init__(self, **kwargs: Any):
"""Initialize the sentence_transformer."""
@ -268,9 +295,24 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
self.client = sentence_transformers.SentenceTransformer(
self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
)
if "-zh" in self.model_name:
self.query_instruction = DEFAULT_QUERY_BGE_INSTRUCTION_ZH
if "show_progress_bar" in self.encode_kwargs:
warn_deprecated(
since="0.2.5",
removal="0.4.0",
name="encode_kwargs['show_progress_bar']",
alternative=f"the show_progress method on {self.__class__.__name__}",
)
if self.show_progress:
warnings.warn(
"Both encode_kwargs['show_progress_bar'] and show_progress are set;"
"encode_kwargs['show_progress_bar'] takes precedence"
)
self.show_progress = self.encode_kwargs.pop("show_progress_bar")
class Config:
"""Configuration for this pydantic object."""
@ -286,7 +328,9 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
List of embeddings, one for each text.
"""
texts = [self.embed_instruction + t.replace("\n", " ") for t in texts]
embeddings = self.client.encode(texts, **self.encode_kwargs)
embeddings = self.client.encode(
texts, show_progress_bar=self.show_progress, **self.encode_kwargs
)
return embeddings.tolist()
def embed_query(self, text: str) -> List[float]:
@ -300,7 +344,9 @@ class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
"""
text = text.replace("\n", " ")
embedding = self.client.encode(
self.query_instruction + text, **self.encode_kwargs
self.query_instruction + text,
show_progress_bar=self.show_progress,
**self.encode_kwargs,
)
return embedding.tolist()
@ -353,7 +399,9 @@ class HuggingFaceInferenceAPIEmbeddings(BaseModel, Embeddings):
Example:
.. code-block:: python
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.embeddings import (
HuggingFaceInferenceAPIEmbeddings,
)
hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
api_key="your_api_key",

Loading…
Cancel
Save