mirror of https://github.com/hwchase17/langchain
Add Baidu Qianfan endpoint for LLM (#10496)
- Description: Baidu AI Cloud's [Qianfan Platform](https://cloud.baidu.com/doc/WENXINWORKSHOP/index.html) is an all-in-one platform for large-model development and service deployment, catering to enterprise developers in China. The Qianfan Platform offers a wide range of resources, including the Wenxin Yiyan model (ERNIE-Bot) and various third-party open-source models.
- Issue: none
- Dependencies: qianfan
- Tag maintainer: @baskaryan
- Twitter handle:

Co-authored-by: Bagatur <baskaryan@gmail.com>
parent 0a0276bcdb
commit adabdfdfc7
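For orientation, a minimal usage sketch of the two classes this commit adds, distilled from the docstrings and integration tests in the diff below. The credential strings are placeholders, and, as the validators below show, the keys can instead be supplied via the QIANFAN_AK and QIANFAN_SK environment variables.

# Minimal sketch, assuming `pip install qianfan` and real credentials in place
# of the placeholders below (or exported as QIANFAN_AK / QIANFAN_SK).
from langchain.chat_models import QianfanChatEndpoint
from langchain.llms import QianfanLLMEndpoint
from langchain.schema import HumanMessage

chat = QianfanChatEndpoint(qianfan_ak="your_ak", qianfan_sk="your_sk")
print(chat(messages=[HumanMessage(content="Hello")]).content)

llm = QianfanLLMEndpoint(qianfan_ak="your_ak", qianfan_sk="your_sk")
print(llm("write a joke"))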
langchain/chat_models/baidu_qianfan_endpoint.py
@@ -0,0 +1,293 @@
from __future__ import annotations

import logging
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Mapping,
    Optional,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    BaseMessageChunk,
    ChatMessage,
    FunctionMessage,
    HumanMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


def _convert_resp_to_message_chunk(resp: Mapping[str, Any]) -> BaseMessageChunk:
    """Convert a raw Qianfan streaming response payload to an AIMessageChunk."""
    return AIMessageChunk(
        content=resp["result"],
        role="assistant",
    )


def convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a LangChain message to the dict format expected by Qianfan."""
    message_dict: Dict[str, Any]
    if isinstance(message, ChatMessage):
        message_dict = {"role": message.role, "content": message.content}
    elif isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
        if "function_call" in message.additional_kwargs:
            message_dict["functions"] = message.additional_kwargs["function_call"]
            # If function call only, content is None not empty string
            if message_dict["content"] == "":
                message_dict["content"] = None
    elif isinstance(message, FunctionMessage):
        message_dict = {
            "role": "function",
            "content": message.content,
            "name": message.name,
        }
    else:
        raise TypeError(f"Got unknown type {message}")

    return message_dict


class QianfanChatEndpoint(BaseChatModel):
    """Baidu Qianfan chat models.

    To use, you should have the ``qianfan`` python package installed and
    the environment variables ``QIANFAN_AK`` and ``QIANFAN_SK`` set with your
    API Key and Secret Key.

    ``qianfan_ak`` and ``qianfan_sk`` are required parameters,
    which you can obtain from https://cloud.baidu.com/product/wenxinworkshop

    Example:
        .. code-block:: python

            from langchain.chat_models import QianfanChatEndpoint

            qianfan_chat = QianfanChatEndpoint(
                model="ERNIE-Bot",
                endpoint="your_endpoint",
                qianfan_ak="your_ak",
                qianfan_sk="your_sk",
            )
    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    client: Any

    qianfan_ak: Optional[str] = None
    qianfan_sk: Optional[str] = None

    streaming: Optional[bool] = False
    """Whether to stream the results or not."""

    request_timeout: Optional[int] = 60
    """Request timeout for chat HTTP requests."""

    top_p: Optional[float] = 0.8
    temperature: Optional[float] = 0.95
    penalty_score: Optional[float] = 1
    """Model params, only supported by ERNIE-Bot and ERNIE-Bot-turbo.
    For other models, passing these params has no effect on the result.
    """

    model: str = "ERNIE-Bot-turbo"
    """Model name.
    Available preset models are listed at
    https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu

    Each preset model maps to an endpoint;
    `model` is ignored if `endpoint` is set.
    """

    endpoint: Optional[str] = None
    """Endpoint of the Qianfan LLM, required when a custom model is used."""

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        values["qianfan_ak"] = get_from_dict_or_env(
            values,
            "qianfan_ak",
            "QIANFAN_AK",
        )
        values["qianfan_sk"] = get_from_dict_or_env(
            values,
            "qianfan_sk",
            "QIANFAN_SK",
        )
        params = {
            "ak": values["qianfan_ak"],
            "sk": values["qianfan_sk"],
            "model": values["model"],
            "stream": values["streaming"],
        }
        if values["endpoint"] is not None and values["endpoint"] != "":
            params["endpoint"] = values["endpoint"]
        try:
            import qianfan

            values["client"] = qianfan.ChatCompletion(**params)
        except ImportError:
            raise ValueError(
                "qianfan package not found, please install it with "
                "`pip install qianfan`"
            )
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **{"endpoint": self.endpoint, "model": self.model},
            **super()._identifying_params,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of chat_model."""
        return "baidu-qianfan-chat"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Qianfan API."""
        normal_params = {
            "stream": self.streaming,
            "request_timeout": self.request_timeout,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_score": self.penalty_score,
        }

        return {**normal_params, **self.model_kwargs}

    def _convert_prompt_msg_params(
        self,
        messages: List[BaseMessage],
        **kwargs: Any,
    ) -> dict:
        # Later entries win: per-call kwargs override the instance defaults.
        return {
            **{"messages": [convert_message_to_dict(m) for m in messages]},
            **self._default_params,
            **kwargs,
        }

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to a Qianfan chat model endpoint for each generation.

        Args:
            messages: The messages to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The ChatResult generated by the model.

        Example:
            .. code-block:: python

                response = qianfan_chat(
                    messages=[HumanMessage(content="Tell me a joke.")]
                )
        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
            lc_msg = AIMessage(content=completion, additional_kwargs={})
            gen = ChatGeneration(
                message=lc_msg,
                generation_info=dict(finish_reason="finished"),
            )
            return ChatResult(
                generations=[gen],
                llm_output={"token_usage": {}, "model_name": self.model},
            )
        params = self._convert_prompt_msg_params(messages, **kwargs)
        response_payload = self.client.do(**params)
        lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
        gen = ChatGeneration(
            message=lc_msg,
            generation_info=dict(finish_reason="finished"),
        )
        token_usage = response_payload.get("usage", {})
        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=[gen], llm_output=llm_output)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            completion = ""
            async for chunk in self._astream(messages, stop, run_manager, **kwargs):
                completion += chunk.text
            lc_msg = AIMessage(content=completion, additional_kwargs={})
            gen = ChatGeneration(
                message=lc_msg,
                generation_info=dict(finish_reason="finished"),
            )
            return ChatResult(
                generations=[gen],
                llm_output={"token_usage": {}, "model_name": self.model},
            )
        params = self._convert_prompt_msg_params(messages, **kwargs)
        response_payload = await self.client.ado(**params)
        lc_msg = AIMessage(content=response_payload["result"], additional_kwargs={})
        generations = []
        gen = ChatGeneration(
            message=lc_msg,
            generation_info=dict(finish_reason="finished"),
        )
        generations.append(gen)
        token_usage = response_payload.get("usage", {})
        llm_output = {"token_usage": token_usage, "model_name": self.model}
        return ChatResult(generations=generations, llm_output=llm_output)

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        params = self._convert_prompt_msg_params(messages, **kwargs)
        for res in self.client.do(**params):
            if res:
                chunk = ChatGenerationChunk(
                    text=res["result"],
                    message=_convert_resp_to_message_chunk(res),
                    generation_info={"finish_reason": "finished"},
                )
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        params = self._convert_prompt_msg_params(messages, **kwargs)
        async for res in await self.client.ado(**params):
            if res:
                chunk = ChatGenerationChunk(
                    text=res["result"], message=_convert_resp_to_message_chunk(res)
                )
                yield chunk
                if run_manager:
                    await run_manager.on_llm_new_token(chunk.text)
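
Because `_stream` forwards each token to the run manager via `on_llm_new_token`, streamed output can be observed through callbacks. A sketch adapted from `test_stream` in the integration tests further down; `StreamingStdOutCallbackHandler` is only an illustrative stand-in for whichever handler you actually use:

# Streaming sketch, adapted from test_stream below; assumes valid credentials.
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import QianfanChatEndpoint
from langchain.schema import HumanMessage

chat = QianfanChatEndpoint(streaming=True)
response = chat(
    messages=[HumanMessage(content="Who are you?")],
    stream=True,
    callbacks=CallbackManager([StreamingStdOutCallbackHandler()]),
)
print(response.content)  # full completion, accumulated from the streamed chunks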
langchain/llms/baidu_qianfan_endpoint.py
@@ -0,0 +1,217 @@
from __future__ import annotations

import logging
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
)

from langchain.callbacks.manager import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain.llms.base import LLM
from langchain.pydantic_v1 import Field, root_validator
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class QianfanLLMEndpoint(LLM):
    """Baidu Qianfan hosted open source or customized models.

    To use, you should have the ``qianfan`` python package installed and
    the environment variables ``QIANFAN_AK`` and ``QIANFAN_SK`` set with
    your API Key and Secret Key.

    ``qianfan_ak`` and ``qianfan_sk`` are required parameters,
    which you can obtain from https://cloud.baidu.com/product/wenxinworkshop

    Example:
        .. code-block:: python

            from langchain.llms import QianfanLLMEndpoint

            qianfan_model = QianfanLLMEndpoint(
                model="ERNIE-Bot",
                endpoint="your_endpoint",
                qianfan_ak="your_ak",
                qianfan_sk="your_sk",
            )
    """

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    client: Any

    qianfan_ak: Optional[str] = None
    qianfan_sk: Optional[str] = None

    streaming: Optional[bool] = False
    """Whether to stream the results or not."""

    model: str = "ERNIE-Bot-turbo"
    """Model name.
    Available preset models are listed at
    https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu

    Each preset model maps to an endpoint;
    `model` is ignored if `endpoint` is set.
    """

    endpoint: Optional[str] = None
    """Endpoint of the Qianfan LLM, required when a custom model is used."""

    request_timeout: Optional[int] = 60
    """Request timeout for chat HTTP requests."""

    top_p: Optional[float] = 0.8
    temperature: Optional[float] = 0.95
    penalty_score: Optional[float] = 1
    """Model params, only supported by ERNIE-Bot and ERNIE-Bot-turbo.
    For other models, passing these params has no effect on the result.
    """

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        values["qianfan_ak"] = get_from_dict_or_env(
            values,
            "qianfan_ak",
            "QIANFAN_AK",
        )
        values["qianfan_sk"] = get_from_dict_or_env(
            values,
            "qianfan_sk",
            "QIANFAN_SK",
        )

        params = {
            "ak": values["qianfan_ak"],
            "sk": values["qianfan_sk"],
            "model": values["model"],
        }
        if values["endpoint"] is not None and values["endpoint"] != "":
            params["endpoint"] = values["endpoint"]
        try:
            import qianfan

            values["client"] = qianfan.Completion(**params)
        except ImportError:
            raise ValueError(
                "qianfan package not found, please install it with "
                "`pip install qianfan`"
            )
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **{"endpoint": self.endpoint, "model": self.model},
            **super()._identifying_params,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "baidu-qianfan-endpoint"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Qianfan API."""
        normal_params = {
            "stream": self.streaming,
            "request_timeout": self.request_timeout,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_score": self.penalty_score,
        }

        return {**normal_params, **self.model_kwargs}

    def _convert_prompt_msg_params(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> dict:
        # Later entries win: per-call kwargs override the instance defaults.
        return {
            **{"prompt": prompt, "model": self.model},
            **self._default_params,
            **kwargs,
        }

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a Qianfan model endpoint for each generation with a prompt.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = qianfan_model("Tell me a joke.")
        """
        if self.streaming:
            completion = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion
        params = self._convert_prompt_msg_params(prompt, **kwargs)
        response_payload = self.client.do(**params)

        return response_payload["result"]

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if self.streaming:
            completion = ""
            async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion

        params = self._convert_prompt_msg_params(prompt, **kwargs)
        response_payload = await self.client.ado(**params)

        return response_payload["result"]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)

        for res in self.client.do(**params):
            if res:
                chunk = GenerationChunk(text=res["result"])
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **kwargs)
        async for res in await self.client.ado(**params):
            if res:
                chunk = GenerationChunk(text=res["result"])
                yield chunk
                if run_manager:
                    await run_manager.on_llm_new_token(chunk.text)
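
Both classes implement async variants (`_acall`/`_astream` here, `_agenerate`/`_astream` above) on top of `qianfan`'s `ado` coroutine. A sketch mirroring `test_qianfan_aio` in the tests below:

# Async streaming sketch, mirroring test_qianfan_aio below; assumes valid credentials.
import asyncio

from langchain.llms import QianfanLLMEndpoint


async def main() -> None:
    llm = QianfanLLMEndpoint(streaming=True)
    async for token in llm.astream("hi qianfan."):
        print(token, end="", flush=True)


asyncio.run(main())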
@@ -0,0 +1,85 @@
"""Test Baidu Qianfan Chat Endpoint."""

from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.baidu_qianfan_endpoint import QianfanChatEndpoint
from langchain.schema import (
    AIMessage,
    BaseMessage,
    ChatGeneration,
    HumanMessage,
    LLMResult,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


def test_default_call() -> None:
    """Test a call with the default model (`ERNIE-Bot-turbo`)."""
    chat = QianfanChatEndpoint()
    response = chat(messages=[HumanMessage(content="Hello")])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_model() -> None:
    """Test that the model kwarg works."""
    chat = QianfanChatEndpoint(model="BLOOMZ-7B")
    response = chat(messages=[HumanMessage(content="Hello")])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_endpoint() -> None:
    """Test user-deployed custom models, such as some open source models."""
    chat = QianfanChatEndpoint(endpoint="qianfan_bloomz_7b_compressed")
    response = chat(messages=[HumanMessage(content="Hello")])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_multiple_history() -> None:
    """Test that multiple history messages work."""
    chat = QianfanChatEndpoint()

    response = chat(
        messages=[
            HumanMessage(content="Hello."),
            AIMessage(content="Hello!"),
            HumanMessage(content="How are you doing?"),
        ]
    )
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)


def test_stream() -> None:
    """Test that streaming works."""
    chat = QianfanChatEndpoint(streaming=True)
    callback_handler = FakeCallbackHandler()
    callback_manager = CallbackManager([callback_handler])
    response = chat(
        messages=[
            HumanMessage(content="Hello."),
            AIMessage(content="Hello!"),
            HumanMessage(content="Who are you?"),
        ],
        stream=True,
        callbacks=callback_manager,
    )
    assert callback_handler.llm_streams > 0
    assert isinstance(response.content, str)


def test_multiple_messages() -> None:
    """Test that batched generation over multiple message lists works."""
    chat = QianfanChatEndpoint()
    message = HumanMessage(content="Hi, how are you.")
    response = chat.generate([[message], [message]])

    assert isinstance(response, LLMResult)
    assert len(response.generations) == 2
    for generations in response.generations:
        assert len(generations) == 1
        for generation in generations:
            assert isinstance(generation, ChatGeneration)
            assert isinstance(generation.text, str)
            assert generation.text == generation.message.content
@@ -0,0 +1,25 @@
"""Test Baidu Qianfan Embedding Endpoint."""
from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint


def test_embedding_multiple_documents() -> None:
    documents = ["foo", "bar"]
    embedding = QianfanEmbeddingsEndpoint()
    output = embedding.embed_documents(documents)
    assert len(output) == 2
    assert len(output[0]) == 384
    assert len(output[1]) == 384


def test_embedding_query() -> None:
    query = "foo"
    embedding = QianfanEmbeddingsEndpoint()
    output = embedding.embed_query(query)
    assert len(output) == 384


def test_model() -> None:
    documents = ["hi", "qianfan"]
    embedding = QianfanEmbeddingsEndpoint(model="Embedding-V1")
    output = embedding.embed_documents(documents)
    assert len(output) == 2
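
The `QianfanEmbeddingsEndpoint` class exercised above ships with this change as well. A minimal usage sketch based only on these tests (the 384-dimension vectors are what the tests assert; credentials are assumed to be set via QIANFAN_AK / QIANFAN_SK):

# Embedding sketch based on the tests above.
from langchain.embeddings.baidu_qianfan_endpoint import QianfanEmbeddingsEndpoint

embedding = QianfanEmbeddingsEndpoint(model="Embedding-V1")
doc_vectors = embedding.embed_documents(["foo", "bar"])  # one vector per document
query_vector = embedding.embed_query("foo")  # a single 384-dim vector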
@@ -0,0 +1,37 @@
"""Test Baidu Qianfan LLM Endpoint."""
from typing import Generator

import pytest

from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint
from langchain.schema import LLMResult


def test_call() -> None:
    """Test a valid call to Qianfan."""
    llm = QianfanLLMEndpoint()
    output = llm("write a joke")
    assert isinstance(output, str)


def test_generate() -> None:
    """Test a valid generate call to Qianfan."""
    llm = QianfanLLMEndpoint()
    output = llm.generate(["write a joke"])
    assert isinstance(output, LLMResult)
    assert isinstance(output.generations, list)


def test_generate_stream() -> None:
    """Test a valid streaming call to Qianfan."""
    llm = QianfanLLMEndpoint()
    output = llm.stream("write a joke")
    assert isinstance(output, Generator)


@pytest.mark.asyncio
async def test_qianfan_aio() -> None:
    llm = QianfanLLMEndpoint(streaming=True)

    async for token in llm.astream("hi qianfan."):
        assert isinstance(token, str)