mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
community: add model_name param valid for GPT4AllEmbeddings (#23867)
Description: add model_name param valid for GPT4AllEmbeddings Issue: #23863 #22819 --------- Co-authored-by: gongwn1 <gongwn1@lenovo.com>
This commit is contained in:
parent
a4eb6d0fb1
commit
b1e90b3075
@ -22,21 +22,20 @@ class GPT4AllEmbeddings(BaseModel, Embeddings):
|
||||
)
|
||||
"""
|
||||
|
||||
model_name: str
|
||||
model_name: Optional[str] = None
|
||||
n_threads: Optional[int] = None
|
||||
device: Optional[str] = "cpu"
|
||||
gpt4all_kwargs: Optional[dict] = {}
|
||||
client: Any #: :meta private:
|
||||
|
||||
@root_validator()
|
||||
@root_validator(pre=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that GPT4All library is installed."""
|
||||
|
||||
try:
|
||||
from gpt4all import Embed4All
|
||||
|
||||
values["client"] = Embed4All(
|
||||
model_name=values["model_name"],
|
||||
model_name=values.get("model_name"),
|
||||
n_threads=values.get("n_threads"),
|
||||
device=values.get("device"),
|
||||
**values.get("gpt4all_kwargs"),
|
||||
|
Loading…
Reference in New Issue
Block a user