@ -4,7 +4,7 @@ import logging
from typing import Any , Dict , List , Optional
from langchain_core . embeddings import Embeddings
from langchain_core . pydantic_v1 import BaseModel , root_validator
from langchain_core . pydantic_v1 import BaseModel , Field, root_validator
from langchain_core . utils import convert_to_secret_str , get_from_dict_or_env
logger = logging . getLogger ( __name__ )
@ -41,8 +41,12 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
client : Any
""" Qianfan client """
max_retries : int = 5
""" Max reties times """
init_kwargs : Dict [ str , Any ] = Field ( default_factory = dict )
""" init kwargs for qianfan client init, such as `query_per_second` which is
associated with qianfan resource object to limit QPS """
model_kwargs : Dict [ str , Any ] = Field ( default_factory = dict )
""" extra params for model invoke using with `do`. """
@root_validator ( )
def validate_environment ( cls , values : Dict ) - > Dict :
@ -88,6 +92,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
import qianfan
params = {
* * values . get ( " init_kwargs " , { } ) ,
" model " : values [ " model " ] ,
}
if values [ " qianfan_ak " ] . get_secret_value ( ) != " " :
@ -125,7 +130,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
]
lst = [ ]
for chunk in text_in_chunks :
resp = self . client . do ( texts = chunk )
resp = self . client . do ( texts = chunk , * * self . model_kwargs )
lst . extend ( [ res [ " embedding " ] for res in resp [ " data " ] ] )
return lst
@ -140,7 +145,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
]
lst = [ ]
for chunk in text_in_chunks :
resp = await self . client . ado ( texts = chunk )
resp = await self . client . ado ( texts = chunk , * * self . model_kwargs )
for res in resp [ " data " ] :
lst . extend ( [ res [ " embedding " ] ] )
return lst