diff --git a/langchain/embeddings/huggingface.py b/langchain/embeddings/huggingface.py index cf59a300d4..a32aef6561 100644 --- a/langchain/embeddings/huggingface.py +++ b/langchain/embeddings/huggingface.py @@ -15,7 +15,8 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): .. code-block:: python from langchain.embeddings import HuggingFaceEmbeddings - huggingface = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") + model_name = "sentence-transformers/all-mpnet-base-v2" + huggingface = HuggingFaceEmbeddings(model_name=model_name) """ client: Any #: :meta private: @@ -23,6 +24,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): """Model name to use.""" def __init__(self, **kwargs: Any): + """Initialize the sentence_transformer.""" super().__init__(**kwargs) try: import sentence_transformers @@ -40,7 +42,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): extra = Extra.forbid def embed_documents(self, texts: List[str]) -> List[List[float]]: - """Computes doc embeddings using a HuggingFace transformer model + """Compute doc embeddings using a HuggingFace transformer model. Args: texts: The list of texts to embed. @@ -53,7 +55,7 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): return embeddings def embed_query(self, text: str) -> List[float]: - """Computes query embeddings using a HuggingFace transformer model + """Compute query embeddings using a HuggingFace transformer model. Args: text: The text to embed.