mirror of https://github.com/hwchase17/langchain
Encapsulate alicloud pai-eas access method for chatmodels and llms (#11852)
### Description: To provide an eas llm service access methods in this pull request by impletementing `PaiEasEndpoint` and `PaiEasChatEndpoint` classes in `langchain.llms` and `langchain.chat_models` modules. Base on this pr, langchain users can build up a chain to call remote eas llm service and get the llm inference results. ### About EAS Service EAS is a Alicloud product on Alibaba Cloud Machine Learning Platform for AI which is short for AliCloud PAI. EAS provides model inference deployment services for the users. We build up a llm inference services on EAS with a general llm docker images. Therefore, end users can quickly setup their llm remote instances to load majority of the hugginface llm models, and serve as a backend for most of the llm apps. ### Dependencies This pr does't involve any new dependencies. --------- Co-authored-by: 子洪 <gaoyihong.gyh@alibaba-inc.com>pull/11833/head
parent
1da6d92369
commit
f818ec49b8
@ -0,0 +1,121 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# AliCloud PAI EAS\n",
|
||||
"Machine Learning Platform for AI of Alibaba Cloud is a machine learning or deep learning engineering platform intended for enterprises and developers. It provides easy-to-use, cost-effective, high-performance, and easy-to-scale plug-ins that can be applied to various industry scenarios. With over 140 built-in optimization algorithms, Machine Learning Platform for AI provides whole-process AI engineering capabilities including data labeling (PAI-iTAG), model building (PAI-Designer and PAI-DSW), model training (PAI-DLC), compilation optimization, and inference deployment (PAI-EAS). PAI-EAS supports different types of hardware resources, including CPUs and GPUs, and features high throughput and low latency. It allows you to deploy large-scale complex models with a few clicks and perform elastic scale-ins and scale-outs in real time. It also provides a comprehensive O&M and monitoring system."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Setup Eas Service\n",
|
||||
"\n",
|
||||
"One who want to use eas llms must set up eas service first. When the eas service is launched, eas_service_rul and eas_service token can be got. Users can refer to https://www.alibabacloud.com/help/en/pai/user-guide/service-deployment/ for more information. Try to set environment variables to init eas service url and token:\n",
|
||||
"\n",
|
||||
"```base\n",
|
||||
"export EAS_SERVICE_URL=XXX\n",
|
||||
"export EAS_SERVICE_TOKEN=XXX\n",
|
||||
"```\n",
|
||||
"or run as follow codes:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from langchain.chat_models.base import HumanMessage\n",
|
||||
"from langchain.chat_models import PaiEasChatEndpoint\n",
|
||||
"os.environ[\"EAS_SERVICE_URL\"] = \"Your_EAS_Service_URL\"\n",
|
||||
"os.environ[\"EAS_SERVICE_TOKEN\"] = \"Your_EAS_Service_Token\"\n",
|
||||
"chat = PaiEasChatEndpoint(\n",
|
||||
" eas_service_url=os.environ[\"EAS_SERVICE_URL\"], \n",
|
||||
" eas_service_token=os.environ[\"EAS_SERVICE_TOKEN\"]\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Run Chat Model\n",
|
||||
"You can use the default settings to call eas service as follows:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"output = chat([HumanMessage(content=\"write a funny joke\")])\n",
|
||||
"print(\"output:\", output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Or, call eas service with new inference params:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"kwargs = {\"temperature\": 0.8, \"top_p\": 0.8, \"top_k\": 5}\n",
|
||||
"output = chat([HumanMessage(content=\"write a funny joke\")], **kwargs)\n",
|
||||
"print(\"output:\", output)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Or, run a stream call to get a stream response:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"outputs = chat.stream([HumanMessage(content=\"hi\")], streaming=True)\n",
|
||||
"for output in outputs:\n",
|
||||
" print(\"stream output:\", output)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,93 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# AliCloud PAI EAS\n",
|
||||
"Machine Learning Platform for AI of Alibaba Cloud is a machine learning or deep learning engineering platform intended for enterprises and developers. It provides easy-to-use, cost-effective, high-performance, and easy-to-scale plug-ins that can be applied to various industry scenarios. With over 140 built-in optimization algorithms, Machine Learning Platform for AI provides whole-process AI engineering capabilities including data labeling (PAI-iTAG), model building (PAI-Designer and PAI-DSW), model training (PAI-DLC), compilation optimization, and inference deployment (PAI-EAS). PAI-EAS supports different types of hardware resources, including CPUs and GPUs, and features high throughput and low latency. It allows you to deploy large-scale complex models with a few clicks and perform elastic scale-ins and scale-outs in real time. It also provides a comprehensive O&M and monitoring system."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms.pai_eas_endpoint import PaiEasEndpoint\n",
|
||||
"from langchain.prompts import PromptTemplate\n",
|
||||
"from langchain.chains import LLMChain\n",
|
||||
"\n",
|
||||
"template = \"\"\"Question: {question}\n",
|
||||
"\n",
|
||||
"Answer: Let's think step by step.\"\"\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"One who want to use eas llms must set up eas service first. When the eas service is launched, eas_service_rul and eas_service token can be got. Users can refer to https://www.alibabacloud.com/help/en/pai/user-guide/service-deployment/ for more information,"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"os.environ[\"EAS_SERVICE_URL\"] = \"Your_EAS_Service_URL\"\n",
|
||||
"os.environ[\"EAS_SERVICE_TOKEN\"] = \"Your_EAS_Service_Token\"\n",
|
||||
"llm = PaiEasEndpoint(eas_service_url=os.environ[\"EAS_SERVICE_URL\"], eas_service_token=os.environ[\"EAS_SERVICE_TOKEN\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"' Thank you for asking! However, I must respectfully point out that the question contains an error. Justin Bieber was born in 1994, and the Super Bowl was first played in 1967. Therefore, it is not possible for any NFL team to have won the Super Bowl in the year Justin Bieber was born.\\n\\nI hope this clarifies things! If you have any other questions, please feel free to ask.'"
|
||||
]
|
||||
},
|
||||
"execution_count": 10,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
|
||||
"\n",
|
||||
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
|
||||
"llm_chain.run(question)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
},
|
||||
"orig_nbformat": 4
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -0,0 +1,324 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from functools import partial
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import root_validator
|
||||
from langchain.schema import ChatGeneration, ChatResult
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain.schema.output import ChatGenerationChunk
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PaiEasChatEndpoint(BaseChatModel):
|
||||
"""Eas LLM Service chat model API.
|
||||
|
||||
To use, must have a deployed eas chat llm service on AliCloud. One can set the
|
||||
environment variable ``eas_service_url`` and ``eas_service_token`` set with your eas
|
||||
service url and service token.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import PaiEasChatEndpoint
|
||||
eas_chat_endpoint = PaiEasChatEndpoint(
|
||||
eas_service_url="your_service_url",
|
||||
eas_service_token="your_service_token"
|
||||
)
|
||||
"""
|
||||
|
||||
"""PAI-EAS Service URL"""
|
||||
eas_service_url: str
|
||||
|
||||
"""PAI-EAS Service TOKEN"""
|
||||
eas_service_token: str
|
||||
|
||||
"""PAI-EAS Service Infer Params"""
|
||||
max_new_tokens: Optional[int] = 512
|
||||
temperature: Optional[float] = 0.8
|
||||
top_p: Optional[float] = 0.1
|
||||
top_k: Optional[int] = 10
|
||||
do_sample: Optional[bool] = False
|
||||
use_cache: Optional[bool] = True
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
"""Enable stream chat mode."""
|
||||
streaming: bool = False
|
||||
|
||||
"""Key/value arguments to pass to the model. Reserved for future use"""
|
||||
model_kwargs: Optional[dict] = None
|
||||
|
||||
version: Optional[str] = "2.0"
|
||||
|
||||
timeout: Optional[int] = 5000
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["eas_service_url"] = get_from_dict_or_env(
|
||||
values, "eas_service_url", "EAS_SERVICE_URL"
|
||||
)
|
||||
values["eas_service_token"] = get_from_dict_or_env(
|
||||
values, "eas_service_token", "EAS_SERVICE_TOKEN"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
"eas_service_url": self.eas_service_url,
|
||||
"eas_service_token": self.eas_service_token,
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "pai_eas_chat_endpoint"
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Cohere API."""
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"temperature": self.temperature,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"stop_sequences": [],
|
||||
"do_sample": self.do_sample,
|
||||
"use_cache": self.use_cache,
|
||||
}
|
||||
|
||||
def _invocation_params(
|
||||
self, stop_sequences: Optional[List[str]], **kwargs: Any
|
||||
) -> dict:
|
||||
params = self._default_params
|
||||
if self.model_kwargs:
|
||||
params.update(self.model_kwargs)
|
||||
if self.stop_sequences is not None and stop_sequences is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop_sequences is not None:
|
||||
params["stop"] = self.stop_sequences
|
||||
else:
|
||||
params["stop"] = stop_sequences
|
||||
return {**params, **kwargs}
|
||||
|
||||
def format_request_payload(
|
||||
self, messages: List[BaseMessage], **model_kwargs: Any
|
||||
) -> dict:
|
||||
prompt: Dict[str, Any] = {}
|
||||
user_content: List[str] = []
|
||||
assistant_content: List[str] = []
|
||||
|
||||
for message in messages:
|
||||
"""Converts message to a dict according to role"""
|
||||
if isinstance(message, HumanMessage):
|
||||
user_content = user_content + [message.content]
|
||||
elif isinstance(message, AIMessage):
|
||||
assistant_content = assistant_content + [message.content]
|
||||
elif isinstance(message, SystemMessage):
|
||||
prompt["system_prompt"] = message.content
|
||||
elif isinstance(message, ChatMessage) and message.role in [
|
||||
"user",
|
||||
"assistant",
|
||||
"system",
|
||||
]:
|
||||
if message.role == "system":
|
||||
prompt["system_prompt"] = message.content
|
||||
elif message.role == "user":
|
||||
user_content = user_content + [message.content]
|
||||
elif message.role == "assistant":
|
||||
assistant_content = assistant_content + [message.content]
|
||||
else:
|
||||
supported = ",".join([role for role in ["user", "assistant", "system"]])
|
||||
raise ValueError(
|
||||
f"""Received unsupported role.
|
||||
Supported roles for the LLaMa Foundation Model: {supported}"""
|
||||
)
|
||||
prompt["prompt"] = user_content[len(user_content) - 1]
|
||||
history = [
|
||||
history_item
|
||||
for _, history_item in enumerate(zip(user_content[:-1], assistant_content))
|
||||
]
|
||||
|
||||
prompt["history"] = history
|
||||
|
||||
return {**prompt, **model_kwargs}
|
||||
|
||||
def _format_response_payload(
|
||||
self, output: bytes, stop_sequences: Optional[List[str]]
|
||||
) -> str:
|
||||
"""Formats response"""
|
||||
try:
|
||||
text = json.loads(output)["response"]
|
||||
if stop_sequences:
|
||||
text = enforce_stop_tokens(text, stop_sequences)
|
||||
return text
|
||||
except Exception as e:
|
||||
if isinstance(e, json.decoder.JSONDecodeError):
|
||||
return output.decode("utf-8")
|
||||
raise e
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
message = AIMessage(content=output_str)
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
def _call(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
params = self._invocation_params(stop, **kwargs)
|
||||
|
||||
request_payload = self.format_request_payload(messages, **params)
|
||||
response_payload = self._call_eas(request_payload)
|
||||
generated_text = self._format_response_payload(response_payload, params["stop"])
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(generated_text)
|
||||
|
||||
return generated_text
|
||||
|
||||
def _call_eas(self, query_body: dict) -> Any:
|
||||
"""Generate text from the eas service."""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"{self.eas_service_token}",
|
||||
}
|
||||
|
||||
# make request
|
||||
response = requests.post(
|
||||
self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Request failed with status code {response.status_code}"
|
||||
f" and message {response.text}"
|
||||
)
|
||||
|
||||
return response.text
|
||||
|
||||
def _call_eas_stream(self, query_body: dict) -> Any:
|
||||
"""Generate text from the eas service."""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"{self.eas_service_token}",
|
||||
}
|
||||
|
||||
# make request
|
||||
response = requests.post(
|
||||
self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Request failed with status code {response.status_code}"
|
||||
f" and message {response.text}"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _convert_chunk_to_message_message(
|
||||
self,
|
||||
chunk: str,
|
||||
) -> AIMessageChunk:
|
||||
data = json.loads(chunk.encode("utf-8"))
|
||||
return AIMessageChunk(content=data.get("response", ""))
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
params = self._invocation_params(stop, **kwargs)
|
||||
|
||||
request_payload = self.format_request_payload(messages, **params)
|
||||
request_payload["use_stream_chat"] = True
|
||||
|
||||
response = self._call_eas_stream(request_payload)
|
||||
for chunk in response.iter_lines(
|
||||
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
||||
):
|
||||
if chunk:
|
||||
content = self._convert_chunk_to_message_message(chunk)
|
||||
|
||||
# identify stop sequence in generated text, if any
|
||||
stop_seq_found: Optional[str] = None
|
||||
for stop_seq in params["stop"]:
|
||||
if stop_seq in content.content:
|
||||
stop_seq_found = stop_seq
|
||||
|
||||
# identify text to yield
|
||||
text: Optional[str] = None
|
||||
if stop_seq_found:
|
||||
content.content = content.content[
|
||||
: content.content.index(stop_seq_found)
|
||||
]
|
||||
|
||||
# yield text, if any
|
||||
if text:
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(content.content)
|
||||
yield ChatGenerationChunk(message=content)
|
||||
|
||||
# break if stop sequence found
|
||||
if stop_seq_found:
|
||||
break
|
||||
|
||||
async def _agenerate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
stream: Optional[bool] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if stream if stream is not None else self.streaming:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
async for chunk in self._astream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
generation = chunk
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation])
|
||||
|
||||
func = partial(
|
||||
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
)
|
||||
return await asyncio.get_event_loop().run_in_executor(None, func)
|
@ -0,0 +1,240 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, Iterator, List, Mapping, Optional
|
||||
|
||||
import requests
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.pydantic_v1 import root_validator
|
||||
from langchain.schema.output import GenerationChunk
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PaiEasEndpoint(LLM):
|
||||
"""Langchain LLM class to help to access eass llm service.
|
||||
|
||||
To use this endpoint, must have a deployed eas chat llm service on PAI AliCloud.
|
||||
One can set the environment variable ``eas_service_url`` and ``eas_service_token``.
|
||||
The environment variables can set with your eas service url and service token.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
|
||||
eas_chat_endpoint = PaiEasChatEndpoint(
|
||||
eas_service_url="your_service_url",
|
||||
eas_service_token="your_service_token"
|
||||
)
|
||||
"""
|
||||
|
||||
"""PAI-EAS Service URL"""
|
||||
eas_service_url: str
|
||||
|
||||
"""PAI-EAS Service TOKEN"""
|
||||
eas_service_token: str
|
||||
|
||||
"""PAI-EAS Service Infer Params"""
|
||||
max_new_tokens: Optional[int] = 512
|
||||
temperature: Optional[float] = 0.95
|
||||
top_p: Optional[float] = 0.1
|
||||
top_k: Optional[int] = 0
|
||||
stop_sequences: Optional[List[str]] = None
|
||||
|
||||
"""Enable stream chat mode."""
|
||||
streaming: bool = False
|
||||
|
||||
"""Key/value arguments to pass to the model. Reserved for future use"""
|
||||
model_kwargs: Optional[dict] = None
|
||||
|
||||
version: Optional[str] = "2.0"
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["eas_service_url"] = get_from_dict_or_env(
|
||||
values, "eas_service_url", "EAS_SERVICE_URL"
|
||||
)
|
||||
values["eas_service_token"] = get_from_dict_or_env(
|
||||
values, "eas_service_token", "EAS_SERVICE_TOKEN"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "pai_eas_endpoint"
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling Cohere API."""
|
||||
return {
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"temperature": self.temperature,
|
||||
"top_k": self.top_k,
|
||||
"top_p": self.top_p,
|
||||
"stop_sequences": [],
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
"eas_service_url": self.eas_service_url,
|
||||
"eas_service_token": self.eas_service_token,
|
||||
**_model_kwargs,
|
||||
}
|
||||
|
||||
def _invocation_params(
|
||||
self, stop_sequences: Optional[List[str]], **kwargs: Any
|
||||
) -> dict:
|
||||
params = self._default_params
|
||||
if self.stop_sequences is not None and stop_sequences is not None:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
elif self.stop_sequences is not None:
|
||||
params["stop"] = self.stop_sequences
|
||||
else:
|
||||
params["stop"] = stop_sequences
|
||||
if self.model_kwargs:
|
||||
params.update(self.model_kwargs)
|
||||
return {**params, **kwargs}
|
||||
|
||||
@staticmethod
|
||||
def _process_response(
|
||||
response: Any, stop: Optional[List[str]], version: Optional[str]
|
||||
) -> str:
|
||||
if version == "1.0":
|
||||
text = response
|
||||
else:
|
||||
text = response["response"]
|
||||
|
||||
if stop:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return "".join(text)
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
params = self._invocation_params(stop, **kwargs)
|
||||
prompt = prompt.strip()
|
||||
response = None
|
||||
try:
|
||||
if self.streaming:
|
||||
completion = ""
|
||||
for chunk in self._stream(prompt, stop, run_manager, **params):
|
||||
completion += chunk.text
|
||||
return completion
|
||||
else:
|
||||
response = self._call_eas(prompt, params)
|
||||
_stop = params.get("stop")
|
||||
return self._process_response(response, _stop, self.version)
|
||||
except Exception as error:
|
||||
raise ValueError(f"Error raised by the service: {error}")
|
||||
|
||||
def _call_eas(self, prompt: str = "", params: Dict = {}) -> Any:
|
||||
"""Generate text from the eas service."""
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"{self.eas_service_token}",
|
||||
}
|
||||
if self.version == "1.0":
|
||||
body = {
|
||||
"input_ids": f"{prompt}",
|
||||
}
|
||||
else:
|
||||
body = {
|
||||
"prompt": f"{prompt}",
|
||||
}
|
||||
|
||||
# add params to body
|
||||
for key, value in params.items():
|
||||
body[key] = value
|
||||
|
||||
# make request
|
||||
response = requests.post(self.eas_service_url, headers=headers, json=body)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception(
|
||||
f"Request failed with status code {response.status_code}"
|
||||
f" and message {response.text}"
|
||||
)
|
||||
|
||||
try:
|
||||
return json.loads(response.text)
|
||||
except Exception as e:
|
||||
if isinstance(e, json.decoder.JSONDecodeError):
|
||||
return response.text
|
||||
raise e
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[GenerationChunk]:
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
|
||||
headers = {
|
||||
"User-Agent": "Test Client",
|
||||
"Authorization": f"{self.eas_service_token}",
|
||||
}
|
||||
|
||||
if self.version == "1.0":
|
||||
pload = {"input_ids": prompt, **invocation_params}
|
||||
response = requests.post(
|
||||
self.eas_service_url, headers=headers, json=pload, stream=True
|
||||
)
|
||||
|
||||
res = GenerationChunk(text=response.text)
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(res.text)
|
||||
|
||||
# yield text, if any
|
||||
yield res
|
||||
else:
|
||||
pload = {"prompt": prompt, "use_stream_chat": "True", **invocation_params}
|
||||
|
||||
response = requests.post(
|
||||
self.eas_service_url, headers=headers, json=pload, stream=True
|
||||
)
|
||||
|
||||
for chunk in response.iter_lines(
|
||||
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
|
||||
):
|
||||
if chunk:
|
||||
data = json.loads(chunk.decode("utf-8"))
|
||||
output = data["response"]
|
||||
# identify stop sequence in generated text, if any
|
||||
stop_seq_found: Optional[str] = None
|
||||
for stop_seq in invocation_params["stop"]:
|
||||
if stop_seq in output:
|
||||
stop_seq_found = stop_seq
|
||||
|
||||
# identify text to yield
|
||||
text: Optional[str] = None
|
||||
if stop_seq_found:
|
||||
text = output[: output.index(stop_seq_found)]
|
||||
else:
|
||||
text = output
|
||||
|
||||
# yield text, if any
|
||||
if text:
|
||||
res = GenerationChunk(text=text)
|
||||
yield res
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(res.text)
|
||||
|
||||
# break if stop sequence found
|
||||
if stop_seq_found:
|
||||
break
|
@ -0,0 +1,82 @@
|
||||
"""Test AliCloud Pai Eas Chat Model."""
|
||||
import os
|
||||
|
||||
from langchain.callbacks.manager import CallbackManager
|
||||
from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
HumanMessage,
|
||||
LLMResult,
|
||||
)
|
||||
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
||||
|
||||
|
||||
def test_pai_eas_call() -> None:
|
||||
chat = PaiEasChatEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
)
|
||||
response = chat(messages=[HumanMessage(content="Say foo:")])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_multiple_history() -> None:
|
||||
"""Tests multiple history works."""
|
||||
chat = PaiEasChatEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
)
|
||||
|
||||
response = chat(
|
||||
messages=[
|
||||
HumanMessage(content="Hello."),
|
||||
AIMessage(content="Hello!"),
|
||||
HumanMessage(content="How are you doing?"),
|
||||
]
|
||||
)
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_stream() -> None:
|
||||
"""Test that stream works."""
|
||||
chat = PaiEasChatEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
streaming=True,
|
||||
)
|
||||
callback_handler = FakeCallbackHandler()
|
||||
callback_manager = CallbackManager([callback_handler])
|
||||
response = chat(
|
||||
messages=[
|
||||
HumanMessage(content="Hello."),
|
||||
AIMessage(content="Hello!"),
|
||||
HumanMessage(content="Who are you?"),
|
||||
],
|
||||
stream=True,
|
||||
callbacks=callback_manager,
|
||||
)
|
||||
assert callback_handler.llm_streams > 0
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_multiple_messages() -> None:
|
||||
"""Tests multiple messages works."""
|
||||
chat = PaiEasChatEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
)
|
||||
message = HumanMessage(content="Hi, how are you.")
|
||||
response = chat.generate([[message], [message]])
|
||||
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
assert len(generations) == 1
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
@ -0,0 +1,59 @@
|
||||
"""Test PaiEasEndpoint API wrapper."""
|
||||
import os
|
||||
from typing import Generator
|
||||
|
||||
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
|
||||
|
||||
|
||||
def test_pai_eas_v1_call() -> None:
|
||||
"""Test valid call to PAI-EAS Service."""
|
||||
llm = PaiEasEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
version="1.0",
|
||||
)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_pai_eas_v2_call() -> None:
|
||||
llm = PaiEasEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
version="2.0",
|
||||
)
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_pai_eas_v1_streaming() -> None:
|
||||
"""Test streaming call to PAI-EAS Service."""
|
||||
llm = PaiEasEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
version="1.0",
|
||||
)
|
||||
generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
|
||||
stream_results_string = ""
|
||||
assert isinstance(generator, Generator)
|
||||
|
||||
for chunk in generator:
|
||||
assert isinstance(chunk, str)
|
||||
stream_results_string = chunk
|
||||
assert len(stream_results_string.strip()) > 1
|
||||
|
||||
|
||||
def test_pai_eas_v2_streaming() -> None:
|
||||
llm = PaiEasEndpoint(
|
||||
eas_service_url=os.getenv("EAS_SERVICE_URL"),
|
||||
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
|
||||
version="2.0",
|
||||
)
|
||||
generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
|
||||
stream_results_string = ""
|
||||
assert isinstance(generator, Generator)
|
||||
|
||||
for chunk in generator:
|
||||
assert isinstance(chunk, str)
|
||||
stream_results_string = chunk
|
||||
assert len(stream_results_string.strip()) > 1
|
Loading…
Reference in New Issue