langchain/libs/community/langchain_community/llms/baidu_qianfan_endpoint.py
Liu Jun b0c48dc983
community[patch]: make ak and sk optional in qianfan endpoint (#14835)
- **Description:** The Qianfan SDK offers multiple authentication
methods, but the Qianfan endpoint classes in LangChain currently support
authentication only through AK and SK. To accommodate users who wish to
use alternative authentication methods, this pull request makes AK and
SK optional. This change should not impact existing users, while
allowing other authentication methods to be configured as described in
the Qianfan SDK documentation (see the usage sketch below).
2023-12-20 00:49:33 -05:00
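
A minimal sketch of what this patch enables: constructing the endpoint
without explicit AK/SK and deferring to whatever credentials the Qianfan SDK
resolves on its own. The environment-variable names below are an assumption
based on the SDK's documented IAM-style auth, not something this patch adds;
substitute whichever auth method you actually use.

    import os

    from langchain_community.llms import QianfanLLMEndpoint

    # Hypothetical setup: no qianfan_ak/qianfan_sk passed; the qianfan SDK
    # resolves credentials itself, e.g. from these env vars.
    os.environ["QIANFAN_ACCESS_KEY"] = "your_iam_access_key"
    os.environ["QIANFAN_SECRET_KEY"] = "your_iam_secret_key"

    llm = QianfanLLMEndpoint(model="ERNIE-Bot-turbo")
    print(llm("hello"))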


from __future__ import annotations

import logging
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
)

from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

logger = logging.getLogger(__name__)


class QianfanLLMEndpoint(LLM):
"""Baidu Qianfan hosted open source or customized models.
To use, you should have the ``qianfan`` python package installed, and
the environment variable ``qianfan_ak`` and ``qianfan_sk`` set with
your API key and Secret Key.
ak, sk are required parameters which you could get from
https://cloud.baidu.com/product/wenxinworkshop
Example:
.. code-block:: python
from langchain_community.llms import QianfanLLMEndpoint
qianfan_model = QianfanLLMEndpoint(model="ERNIE-Bot",
endpoint="your_endpoint", qianfan_ak="your_ak", qianfan_sk="your_sk")
"""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)

    client: Any

    qianfan_ak: Optional[str] = None
    qianfan_sk: Optional[str] = None

    streaming: Optional[bool] = False
    """Whether to stream the results or not."""

    model: str = "ERNIE-Bot-turbo"
    """Model name.
    Available preset models are listed at
    https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Nlks5zkzu

    Each preset model maps to an endpoint.
    `model` is ignored if `endpoint` is set.
    """

    endpoint: Optional[str] = None
    """Endpoint of the Qianfan LLM, required if a custom model is used."""

    request_timeout: Optional[int] = 60
    """Request timeout for chat HTTP requests."""

    top_p: Optional[float] = 0.8
    temperature: Optional[float] = 0.95
    penalty_score: Optional[float] = 1
    """Model params, only supported in ERNIE-Bot and ERNIE-Bot-turbo.
    For other models, passing these params will not affect the result.
    """

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        values["qianfan_ak"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "qianfan_ak",
                "QIANFAN_AK",
                default="",
            )
        )
        values["qianfan_sk"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "qianfan_sk",
                "QIANFAN_SK",
                default="",
            )
        )
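
        # Forward ak/sk to the SDK only when they were actually provided, so
        # the qianfan SDK can fall back to its other authentication methods
        # (the point of the patch above).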
        params = {
            "model": values["model"],
        }
        if values["qianfan_ak"].get_secret_value() != "":
            params["ak"] = values["qianfan_ak"].get_secret_value()
        if values["qianfan_sk"].get_secret_value() != "":
            params["sk"] = values["qianfan_sk"].get_secret_value()
        if values["endpoint"] is not None and values["endpoint"] != "":
            params["endpoint"] = values["endpoint"]

        try:
            import qianfan

            values["client"] = qianfan.Completion(**params)
        except ImportError:
            raise ImportError(
                "qianfan package not found, please install it with "
                "`pip install qianfan`"
            )
        return values

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        return {
            **{"endpoint": self.endpoint, "model": self.model},
            **super()._identifying_params,
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "baidu-qianfan-endpoint"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Qianfan API."""
        normal_params = {
            "model": self.model,
            "endpoint": self.endpoint,
            "stream": self.streaming,
            "request_timeout": self.request_timeout,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_score": self.penalty_score,
        }
        # Caller-supplied model_kwargs take precedence over the presets above.
        return {**normal_params, **self.model_kwargs}

    def _convert_prompt_msg_params(
        self,
        prompt: str,
        **kwargs: Any,
    ) -> dict:
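        # Normalize the LangChain-style ``streaming`` kwarg to the ``stream``
        # name the qianfan SDK expects.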
if "streaming" in kwargs:
kwargs["stream"] = kwargs.pop("streaming")
return {
**{"prompt": prompt, "model": self.model},
**self._default_params,
**kwargs,
}

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
"""Call out to an qianfan models endpoint for each generation with a prompt.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = qianfan_model("Tell me a joke.")
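
                # Streaming variant (a sketch; assumes the standard
                # ``stream`` method LangChain LLMs expose):
                for chunk in qianfan_model.stream("Tell me a joke."):
                    print(chunk)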
"""
        if self.streaming:
            # Drain the stream and return the accumulated text as one string.
            completion = ""
            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion
        params = self._convert_prompt_msg_params(prompt, **kwargs)
        response_payload = self.client.do(**params)
        return response_payload["result"]

    async def _acall(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        if self.streaming:
            completion = ""
            async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
                completion += chunk.text
            return completion
        params = self._convert_prompt_msg_params(prompt, **kwargs)
        response_payload = await self.client.ado(**params)
        return response_payload["result"]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        # Force streaming mode regardless of the instance-level flag.
        params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
        for res in self.client.do(**params):
            if res:
                chunk = GenerationChunk(text=res["result"])
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text)

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        params = self._convert_prompt_msg_params(prompt, **{**kwargs, "stream": True})
        async for res in await self.client.ado(**params):
            if res:
                chunk = GenerationChunk(text=res["result"])
                yield chunk
                if run_manager:
                    await run_manager.on_llm_new_token(chunk.text)