|
|
|
@ -26,7 +26,7 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
# query_result = embeddings.embed_query(texts)
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
_client: Any = Field(default=None, exclude=True) #: :meta private:
|
|
|
|
|
client: Any = Field(default=None, exclude=True) #: :meta private:
|
|
|
|
|
model: str = Field(default="embedding-2")
|
|
|
|
|
"""Model name"""
|
|
|
|
|
api_key: str
|
|
|
|
@ -39,7 +39,7 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
try:
|
|
|
|
|
from zhipuai import ZhipuAI
|
|
|
|
|
|
|
|
|
|
values["_client"] = ZhipuAI(api_key=values["api_key"])
|
|
|
|
|
values["client"] = ZhipuAI(api_key=values["api_key"])
|
|
|
|
|
except ImportError:
|
|
|
|
|
raise ImportError(
|
|
|
|
|
"Could not import zhipuai python package."
|
|
|
|
@ -71,6 +71,6 @@ class ZhipuAIEmbeddings(BaseModel, Embeddings):
|
|
|
|
|
A list of embeddings for each document in the input list.
|
|
|
|
|
Each embedding is represented as a list of float values.
|
|
|
|
|
"""
|
|
|
|
|
resp = self._client.embeddings.create(model=self.model, input=texts)
|
|
|
|
|
resp = self.client.embeddings.create(model=self.model, input=texts)
|
|
|
|
|
embeddings = [r.embedding for r in resp.data]
|
|
|
|
|
return embeddings
|
|
|
|
|