community: add model_name param valid for GPT4AllEmbeddings (#23867)

Description: add model_name param valid for GPT4AllEmbeddings

Issue: #23863 #22819

---------

Co-authored-by: gongwn1 <gongwn1@lenovo.com>
This commit is contained in:
wenngong 2024-07-05 22:46:34 +08:00 committed by GitHub
parent a4eb6d0fb1
commit b1e90b3075
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,21 +22,20 @@ class GPT4AllEmbeddings(BaseModel, Embeddings):
)
"""
model_name: str
model_name: Optional[str] = None
n_threads: Optional[int] = None
device: Optional[str] = "cpu"
gpt4all_kwargs: Optional[dict] = {}
client: Any #: :meta private:
@root_validator()
@root_validator(pre=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that GPT4All library is installed."""
try:
from gpt4all import Embed4All
values["client"] = Embed4All(
model_name=values["model_name"],
model_name=values.get("model_name"),
n_threads=values.get("n_threads"),
device=values.get("device"),
**values.get("gpt4all_kwargs"),