diff --git a/langchain/embeddings/huggingface.py b/langchain/embeddings/huggingface.py index bc41fcf3bd..217b6a8448 100644 --- a/langchain/embeddings/huggingface.py +++ b/langchain/embeddings/huggingface.py @@ -1,5 +1,5 @@ """Wrapper around HuggingFace embedding models.""" -from typing import Any, List +from typing import Any, List, Optional from pydantic import BaseModel, Extra @@ -29,6 +29,9 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): client: Any #: :meta private: model_name: str = DEFAULT_MODEL_NAME """Model name to use.""" + cache_folder: Optional[str] = None + """Path to store models. + Can be also set by SENTENCE_TRANSFORMERS_HOME enviroment variable.""" def __init__(self, **kwargs: Any): """Initialize the sentence_transformer.""" @@ -36,7 +39,9 @@ class HuggingFaceEmbeddings(BaseModel, Embeddings): try: import sentence_transformers - self.client = sentence_transformers.SentenceTransformer(self.model_name) + self.client = sentence_transformers.SentenceTransformer( + self.model_name, self.cache_folder + ) except ImportError: raise ValueError( "Could not import sentence_transformers python package. "