community:qianfan endpoint support init params & remove useless params definition (#15381)

- **Description:**
- support custom kwargs in object initialization. For instance, QPS
differs across multiple objects (chat/completion/embedding with diverse
models), for which a global env var is not a good configuration choice.
  - **Issue:** no
  - **Dependencies:** no
  - **Twitter handle:** no

@baskaryan PTAL
pull/14852/head
NuODaniel 6 months ago committed by GitHub
parent 26f84b74d0
commit 7773943a51
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -83,7 +83,12 @@ class QianfanChatEndpoint(BaseChatModel):
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
"""
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""init kwargs for qianfan client init, such as `query_per_second` which is
associated with qianfan resource object to limit QPS"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""extra params for model invoke using with `do`."""
client: Any
@ -134,6 +139,7 @@ class QianfanChatEndpoint(BaseChatModel):
)
)
params = {
**values.get("init_kwargs", {}),
"model": values["model"],
"stream": values["streaming"],
}

@ -4,7 +4,7 @@ import logging
from typing import Any, Dict, List, Optional
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
logger = logging.getLogger(__name__)
@ -41,8 +41,12 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
client: Any
"""Qianfan client"""
max_retries: int = 5
"""Max reties times"""
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""init kwargs for qianfan client init, such as `query_per_second` which is
associated with qianfan resource object to limit QPS"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""extra params for model invoke using with `do`."""
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@ -88,6 +92,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
import qianfan
params = {
**values.get("init_kwargs", {}),
"model": values["model"],
}
if values["qianfan_ak"].get_secret_value() != "":
@ -125,7 +130,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
]
lst = []
for chunk in text_in_chunks:
resp = self.client.do(texts=chunk)
resp = self.client.do(texts=chunk, **self.model_kwargs)
lst.extend([res["embedding"] for res in resp["data"]])
return lst
@ -140,7 +145,7 @@ class QianfanEmbeddingsEndpoint(BaseModel, Embeddings):
]
lst = []
for chunk in text_in_chunks:
resp = await self.client.ado(texts=chunk)
resp = await self.client.ado(texts=chunk, **self.model_kwargs)
for res in resp["data"]:
lst.extend([res["embedding"]])
return lst

@ -40,7 +40,12 @@ class QianfanLLMEndpoint(LLM):
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
"""
init_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""init kwargs for qianfan client init, such as `query_per_second` which is
associated with qianfan resource object to limit QPS"""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""extra params for model invoke using with `do`."""
client: Any
@ -91,6 +96,7 @@ class QianfanLLMEndpoint(LLM):
)
params = {
**values.get("init_kwargs", {}),
"model": values["model"],
}
if values["qianfan_ak"].get_secret_value() != "":

@ -217,3 +217,18 @@ def test_functions_call() -> None:
chain = prompt | chat.bind(functions=_FUNCTIONS)
resp = chain.invoke({})
assert isinstance(resp, AIMessage)
def test_rate_limit() -> None:
    """Verify `init_kwargs` propagates `query_per_second` to the qianfan client.

    Constructs a chat endpoint with a QPS limit of 2, checks the limit reached
    the underlying rate limiter, then batches three prompts and confirms each
    response is a `BaseMessage` with string content.
    """
    chat = QianfanChatEndpoint(model="ERNIE-Bot", init_kwargs={"query_per_second": 2})
    # The limit must land on the SDK's internal sync rate limiter.
    assert chat.client._client._rate_limiter._sync_limiter._query_per_second == 2
    prompts = [
        [HumanMessage(content="Hello")],
        [HumanMessage(content="who are you")],
        [HumanMessage(content="what is baidu")],
    ]
    for message in chat.batch(prompts):
        assert isinstance(message, BaseMessage)
        assert isinstance(message.content, str)

@ -25,3 +25,15 @@ def test_model() -> None:
embedding = QianfanEmbeddingsEndpoint(model="Embedding-V1")
output = embedding.embed_documents(documents)
assert len(output) == 2
def test_rate_limit() -> None:
    """Verify `init_kwargs` propagates `query_per_second` to the embeddings client.

    Constructs an embeddings endpoint with a QPS limit of 2, checks the limit
    reached the underlying rate limiter, then embeds two documents and confirms
    each vector has the expected 384 dimensions.
    """
    embedding = QianfanEmbeddingsEndpoint(
        model="Embedding-V1", init_kwargs={"query_per_second": 2}
    )
    # The limit must land on the SDK's internal sync rate limiter.
    assert embedding.client._client._rate_limiter._sync_limiter._query_per_second == 2
    vectors = embedding.embed_documents(["foo", "bar"])
    assert len(vectors) == 2
    for vec in vectors:
        assert len(vec) == 384

@ -33,3 +33,11 @@ async def test_qianfan_aio() -> None:
async for token in llm.astream("hi qianfan."):
assert isinstance(token, str)
def test_rate_limit() -> None:
    """Verify `init_kwargs` propagates `query_per_second` to the LLM client.

    Constructs an LLM endpoint with a QPS limit of 2, checks the limit reached
    the underlying rate limiter, then runs one generation and confirms the
    result shape.
    """
    llm = QianfanLLMEndpoint(model="ERNIE-Bot", init_kwargs={"query_per_second": 2})
    # The limit must land on the SDK's internal sync rate limiter.
    assert llm.client._client._rate_limiter._sync_limiter._query_per_second == 2
    result = llm.generate(["write a joke"])
    assert isinstance(result, LLMResult)
    assert isinstance(result.generations, list)

Loading…
Cancel
Save