fix bug 
The ZhipuAIEmbeddings class is not working.

Co-authored-by: xu yandong <shaonian@acsx1.onexmail.com>
This commit is contained in:
xyd 2024-06-20 21:04:50 +08:00 committed by GitHub
parent ad7f2ec67d
commit 9b3a025f9c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -26,7 +26,7 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
# query_result = embeddings.embed_query(texts) # query_result = embeddings.embed_query(texts)
""" """
_client: Any = Field(default=None, exclude=True) #: :meta private: client: Any = Field(default=None, exclude=True) #: :meta private:
model: str = Field(default="embedding-2") model: str = Field(default="embedding-2")
"""Model name""" """Model name"""
api_key: str api_key: str
@ -39,7 +39,7 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
try: try:
from zhipuai import ZhipuAI from zhipuai import ZhipuAI
values["_client"] = ZhipuAI(api_key=values["api_key"]) values["client"] = ZhipuAI(api_key=values["api_key"])
except ImportError: except ImportError:
raise ImportError( raise ImportError(
"Could not import zhipuai python package." "Could not import zhipuai python package."
@ -71,6 +71,6 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
A list of embeddings for each document in the input list. A list of embeddings for each document in the input list.
Each embedding is represented as a list of float values. Each embedding is represented as a list of float values.
""" """
resp = self._client.embeddings.create(model=self.model, input=texts) resp = self.client.embeddings.create(model=self.model, input=texts)
embeddings = [r.embedding for r in resp.data] embeddings = [r.embedding for r in resp.data]
return embeddings return embeddings