Encapsulate alicloud pai-eas access method for chatmodels and llms (#11852)

### Description:
This pull request provides access methods for the EAS LLM service by
implementing the `PaiEasEndpoint` and `PaiEasChatEndpoint` classes in the
`langchain.llms` and `langchain.chat_models` modules. Based on this PR,
LangChain users can build a chain that calls a remote EAS LLM service and
retrieves the inference results.
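
For example (a minimal sketch mirroring the notebook added in this PR; the URL and token values are placeholders for an already-deployed service):

```python
import os

from langchain.chains import LLMChain
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
from langchain.prompts import PromptTemplate

# Placeholder credentials for an already-deployed PAI-EAS service.
os.environ["EAS_SERVICE_URL"] = "Your_EAS_Service_URL"
os.environ["EAS_SERVICE_TOKEN"] = "Your_EAS_Service_Token"

llm = PaiEasEndpoint(
    eas_service_url=os.environ["EAS_SERVICE_URL"],
    eas_service_token=os.environ["EAS_SERVICE_TOKEN"],
)
prompt = PromptTemplate(
    template="Question: {question}\n\nAnswer: Let's think step by step.",
    input_variables=["question"],
)
chain = LLMChain(prompt=prompt, llm=llm)
print(chain.run("What NFL team won the Super Bowl in the year Justin Bieber was born?"))
```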

### About EAS Service
EAS is an AliCloud product on the Alibaba Cloud Machine Learning Platform for
AI (AliCloud PAI for short). EAS provides model inference deployment services
to users. We build LLM inference services on EAS with a general-purpose LLM
Docker image, so end users can quickly set up remote LLM instances that load
the majority of Hugging Face LLM models and serve as a backend for most LLM
applications.
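
At the HTTP level, the wrappers simply POST a JSON body to the service URL and pass the service token in the `Authorization` header, as the `_call_eas` helpers below do. A rough sketch of that request against a v2.0 service (field names mirror the payload-building code in this PR; the URL, token, and prompt values are placeholders):

```python
import requests

# Placeholder values; a real deployment provides its own URL and token.
EAS_SERVICE_URL = "Your_EAS_Service_URL"
EAS_SERVICE_TOKEN = "Your_EAS_Service_Token"

headers = {
    "Content-Type": "application/json",
    "Authorization": EAS_SERVICE_TOKEN,
}
# A v2.0 service accepts a "prompt" field plus optional sampling parameters.
body = {
    "prompt": "write a funny joke",
    "temperature": 0.8,
    "top_p": 0.8,
    "top_k": 5,
}
response = requests.post(EAS_SERVICE_URL, headers=headers, json=body, timeout=60)
response.raise_for_status()
# The generated text comes back under the "response" key.
print(response.json()["response"])
```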

### Dependencies
This PR doesn't introduce any new dependencies.

---------

Co-authored-by: 子洪 <gaoyihong.gyh@alibaba-inc.com>

@ -0,0 +1,121 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AliCloud PAI EAS\n",
"Machine Learning Platform for AI of Alibaba Cloud is a machine learning or deep learning engineering platform intended for enterprises and developers. It provides easy-to-use, cost-effective, high-performance, and easy-to-scale plug-ins that can be applied to various industry scenarios. With over 140 built-in optimization algorithms, Machine Learning Platform for AI provides whole-process AI engineering capabilities including data labeling (PAI-iTAG), model building (PAI-Designer and PAI-DSW), model training (PAI-DLC), compilation optimization, and inference deployment (PAI-EAS). PAI-EAS supports different types of hardware resources, including CPUs and GPUs, and features high throughput and low latency. It allows you to deploy large-scale complex models with a few clicks and perform elastic scale-ins and scale-outs in real time. It also provides a comprehensive O&M and monitoring system."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup Eas Service\n",
"\n",
"One who want to use eas llms must set up eas service first. When the eas service is launched, eas_service_rul and eas_service token can be got. Users can refer to https://www.alibabacloud.com/help/en/pai/user-guide/service-deployment/ for more information. Try to set environment variables to init eas service url and token:\n",
"\n",
"```base\n",
"export EAS_SERVICE_URL=XXX\n",
"export EAS_SERVICE_TOKEN=XXX\n",
"```\n",
"or run as follow codes:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from langchain.chat_models.base import HumanMessage\n",
"from langchain.chat_models import PaiEasChatEndpoint\n",
"os.environ[\"EAS_SERVICE_URL\"] = \"Your_EAS_Service_URL\"\n",
"os.environ[\"EAS_SERVICE_TOKEN\"] = \"Your_EAS_Service_Token\"\n",
"chat = PaiEasChatEndpoint(\n",
" eas_service_url=os.environ[\"EAS_SERVICE_URL\"], \n",
" eas_service_token=os.environ[\"EAS_SERVICE_TOKEN\"]\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run Chat Model\n",
"You can use the default settings to call eas service as follows:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"output = chat([HumanMessage(content=\"write a funny joke\")])\n",
"print(\"output:\", output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or, call eas service with new inference params:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"kwargs = {\"temperature\": 0.8, \"top_p\": 0.8, \"top_k\": 5}\n",
"output = chat([HumanMessage(content=\"write a funny joke\")], **kwargs)\n",
"print(\"output:\", output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Or, run a stream call to get a stream response:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"outputs = chat.stream([HumanMessage(content=\"hi\")], streaming=True)\n",
"for output in outputs:\n",
" print(\"stream output:\", output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -0,0 +1,93 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# AliCloud PAI EAS\n",
"Machine Learning Platform for AI of Alibaba Cloud is a machine learning or deep learning engineering platform intended for enterprises and developers. It provides easy-to-use, cost-effective, high-performance, and easy-to-scale plug-ins that can be applied to various industry scenarios. With over 140 built-in optimization algorithms, Machine Learning Platform for AI provides whole-process AI engineering capabilities including data labeling (PAI-iTAG), model building (PAI-Designer and PAI-DSW), model training (PAI-DLC), compilation optimization, and inference deployment (PAI-EAS). PAI-EAS supports different types of hardware resources, including CPUs and GPUs, and features high throughput and low latency. It allows you to deploy large-scale complex models with a few clicks and perform elastic scale-ins and scale-outs in real time. It also provides a comprehensive O&M and monitoring system."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms.pai_eas_endpoint import PaiEasEndpoint\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.chains import LLMChain\n",
"\n",
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"One who want to use eas llms must set up eas service first. When the eas service is launched, eas_service_rul and eas_service token can be got. Users can refer to https://www.alibabacloud.com/help/en/pai/user-guide/service-deployment/ for more information,"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"os.environ[\"EAS_SERVICE_URL\"] = \"Your_EAS_Service_URL\"\n",
"os.environ[\"EAS_SERVICE_TOKEN\"] = \"Your_EAS_Service_Token\"\n",
"llm = PaiEasEndpoint(eas_service_url=os.environ[\"EAS_SERVICE_URL\"], eas_service_token=os.environ[\"EAS_SERVICE_TOKEN\"])"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"' Thank you for asking! However, I must respectfully point out that the question contains an error. Justin Bieber was born in 1994, and the Super Bowl was first played in 1967. Therefore, it is not possible for any NFL team to have won the Super Bowl in the year Justin Bieber was born.\\n\\nI hope this clarifies things! If you have any other questions, please feel free to ask.'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"\n",
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
"llm_chain.run(question)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

@ -38,6 +38,7 @@ from langchain.chat_models.minimax import MiniMaxChat
from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway
from langchain.chat_models.ollama import ChatOllama
from langchain.chat_models.openai import ChatOpenAI
from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
from langchain.chat_models.vertexai import ChatVertexAI
from langchain.chat_models.yandex import ChatYandexGPT
@ -63,6 +64,7 @@ __all__ = [
"ErnieBotChat",
"ChatJavelinAIGateway",
"ChatKonko",
"PaiEasChatEndpoint",
"QianfanChatEndpoint",
"ChatFireworks",
"ChatYandexGPT",

@ -0,0 +1,324 @@
import asyncio
import json
import logging
from functools import partial
from typing import Any, AsyncIterator, Dict, List, Optional
import requests
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain.chat_models.base import BaseChatModel
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import root_validator
from langchain.schema import ChatGeneration, ChatResult
from langchain.schema.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain.schema.output import ChatGenerationChunk
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class PaiEasChatEndpoint(BaseChatModel):
"""Eas LLM Service chat model API.
To use, must have a deployed eas chat llm service on AliCloud. One can set the
environment variable ``eas_service_url`` and ``eas_service_token`` set with your eas
service url and service token.
Example:
.. code-block:: python
from langchain.chat_models import PaiEasChatEndpoint
eas_chat_endpoint = PaiEasChatEndpoint(
eas_service_url="your_service_url",
eas_service_token="your_service_token"
)
"""
"""PAI-EAS Service URL"""
eas_service_url: str
"""PAI-EAS Service TOKEN"""
eas_service_token: str
"""PAI-EAS Service Infer Params"""
max_new_tokens: Optional[int] = 512
temperature: Optional[float] = 0.8
top_p: Optional[float] = 0.1
top_k: Optional[int] = 10
do_sample: Optional[bool] = False
use_cache: Optional[bool] = True
stop_sequences: Optional[List[str]] = None
"""Enable stream chat mode."""
streaming: bool = False
"""Key/value arguments to pass to the model. Reserved for future use"""
model_kwargs: Optional[dict] = None
version: Optional[str] = "2.0"
timeout: Optional[int] = 5000
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["eas_service_url"] = get_from_dict_or_env(
values, "eas_service_url", "EAS_SERVICE_URL"
)
values["eas_service_token"] = get_from_dict_or_env(
values, "eas_service_token", "EAS_SERVICE_TOKEN"
)
return values
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
"eas_service_url": self.eas_service_url,
"eas_service_token": self.eas_service_token,
**{"model_kwargs": _model_kwargs},
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "pai_eas_chat_endpoint"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Cohere API."""
return {
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
"stop_sequences": [],
"do_sample": self.do_sample,
"use_cache": self.use_cache,
}
def _invocation_params(
self, stop_sequences: Optional[List[str]], **kwargs: Any
) -> dict:
params = self._default_params
if self.model_kwargs:
params.update(self.model_kwargs)
if self.stop_sequences is not None and stop_sequences is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop_sequences is not None:
params["stop"] = self.stop_sequences
else:
params["stop"] = stop_sequences
return {**params, **kwargs}
def format_request_payload(
self, messages: List[BaseMessage], **model_kwargs: Any
) -> dict:
prompt: Dict[str, Any] = {}
user_content: List[str] = []
assistant_content: List[str] = []
for message in messages:
"""Converts message to a dict according to role"""
if isinstance(message, HumanMessage):
user_content = user_content + [message.content]
elif isinstance(message, AIMessage):
assistant_content = assistant_content + [message.content]
elif isinstance(message, SystemMessage):
prompt["system_prompt"] = message.content
elif isinstance(message, ChatMessage) and message.role in [
"user",
"assistant",
"system",
]:
if message.role == "system":
prompt["system_prompt"] = message.content
elif message.role == "user":
user_content = user_content + [message.content]
elif message.role == "assistant":
assistant_content = assistant_content + [message.content]
else:
supported = ", ".join(["user", "assistant", "system"])
raise ValueError(
f"Received unsupported role. Supported roles: {supported}"
)
prompt["prompt"] = user_content[len(user_content) - 1]
history = [
history_item
for _, history_item in enumerate(zip(user_content[:-1], assistant_content))
]
prompt["history"] = history
return {**prompt, **model_kwargs}
def _format_response_payload(
self, output: bytes, stop_sequences: Optional[List[str]]
) -> str:
"""Formats response"""
try:
text = json.loads(output)["response"]
if stop_sequences:
text = enforce_stop_tokens(text, stop_sequences)
return text
except Exception as e:
if isinstance(e, json.decoder.JSONDecodeError):
return output.decode("utf-8")
raise e
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
def _call(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
params = self._invocation_params(stop, **kwargs)
request_payload = self.format_request_payload(messages, **params)
response_payload = self._call_eas(request_payload)
generated_text = self._format_response_payload(response_payload, params["stop"])
if run_manager:
run_manager.on_llm_new_token(generated_text)
return generated_text
def _call_eas(self, query_body: dict) -> Any:
"""Generate text from the eas service."""
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"{self.eas_service_token}",
}
# make request
response = requests.post(
self.eas_service_url, headers=headers, json=query_body, timeout=self.timeout
)
if response.status_code != 200:
raise Exception(
f"Request failed with status code {response.status_code}"
f" and message {response.text}"
)
return response.text
def _call_eas_stream(self, query_body: dict) -> Any:
"""Generate text from the eas service."""
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"{self.eas_service_token}",
}
# make request
response = requests.post(
self.eas_service_url,
headers=headers,
json=query_body,
timeout=self.timeout,
stream=True,  # stream the response so chunks can be consumed as they arrive
)
if response.status_code != 200:
raise Exception(
f"Request failed with status code {response.status_code}"
f" and message {response.text}"
)
return response
def _convert_chunk_to_message_message(
self,
chunk: bytes,
) -> AIMessageChunk:
# iter_lines yields raw bytes; decode before parsing the JSON payload
data = json.loads(chunk.decode("utf-8"))
return AIMessageChunk(content=data.get("response", ""))
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
params = self._invocation_params(stop, **kwargs)
request_payload = self.format_request_payload(messages, **params)
request_payload["use_stream_chat"] = True
response = self._call_eas_stream(request_payload)
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
):
if chunk:
content = self._convert_chunk_to_message_message(chunk)
# identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in params["stop"] or []:
if stop_seq in content.content:
stop_seq_found = stop_seq
# truncate the chunk at the stop sequence, if one was found
if stop_seq_found:
content.content = content.content[
: content.content.index(stop_seq_found)
]
# yield the (possibly truncated) chunk
if run_manager:
await run_manager.on_llm_new_token(content.content)
yield ChatGenerationChunk(message=content)
# break if stop sequence found
if stop_seq_found:
break
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
stream: Optional[bool] = None,
**kwargs: Any,
) -> ChatResult:
should_stream = stream if stream is not None else self.streaming
if should_stream:
generation: Optional[ChatGenerationChunk] = None
async for chunk in self._astream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
generation = chunk
assert generation is not None
return ChatResult(generations=[generation])
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)

@ -342,6 +342,12 @@ def _import_openlm() -> Any:
return OpenLM
def _import_pai_eas_endpoint() -> Any:
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
return PaiEasEndpoint
def _import_petals() -> Any:
from langchain.llms.petals import Petals
@ -593,6 +599,8 @@ def __getattr__(name: str) -> Any:
return _import_openllm()
elif name == "OpenLM":
return _import_openlm()
elif name == "PaiEasEndpoint":
return _import_pai_eas_endpoint()
elif name == "Petals":
return _import_petals()
elif name == "PipelineAI":
@ -703,6 +711,7 @@ __all__ = [
"OpenAIChat",
"OpenLLM",
"OpenLM",
"PaiEasEndpoint",
"Petals",
"PipelineAI",
"Predibase",
@ -780,6 +789,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"ollama": _import_ollama,
"openai": _import_openai,
"openlm": _import_openlm,
"pai_eas_endpoint": _import_pai_eas_endpoint,
"petals": _import_petals,
"pipelineai": _import_pipelineai,
"predibase": _import_predibase,

@ -0,0 +1,240 @@
import json
import logging
from typing import Any, Dict, Iterator, List, Mapping, Optional
import requests
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.pydantic_v1 import root_validator
from langchain.schema.output import GenerationChunk
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class PaiEasEndpoint(LLM):
"""Langchain LLM class to help to access eass llm service.
To use this endpoint, must have a deployed eas chat llm service on PAI AliCloud.
One can set the environment variable ``eas_service_url`` and ``eas_service_token``.
The environment variables can set with your eas service url and service token.
Example:
.. code-block:: python
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
eas_chat_endpoint = PaiEasChatEndpoint(
eas_service_url="your_service_url",
eas_service_token="your_service_token"
)
"""
"""PAI-EAS Service URL"""
eas_service_url: str
"""PAI-EAS Service TOKEN"""
eas_service_token: str
"""PAI-EAS Service Infer Params"""
max_new_tokens: Optional[int] = 512
temperature: Optional[float] = 0.95
top_p: Optional[float] = 0.1
top_k: Optional[int] = 0
stop_sequences: Optional[List[str]] = None
"""Enable stream chat mode."""
streaming: bool = False
"""Key/value arguments to pass to the model. Reserved for future use"""
model_kwargs: Optional[dict] = None
version: Optional[str] = "2.0"
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["eas_service_url"] = get_from_dict_or_env(
values, "eas_service_url", "EAS_SERVICE_URL"
)
values["eas_service_token"] = get_from_dict_or_env(
values, "eas_service_token", "EAS_SERVICE_TOKEN"
)
return values
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "pai_eas_endpoint"
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling Cohere API."""
return {
"max_new_tokens": self.max_new_tokens,
"temperature": self.temperature,
"top_k": self.top_k,
"top_p": self.top_p,
"stop_sequences": [],
}
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
"eas_service_url": self.eas_service_url,
"eas_service_token": self.eas_service_token,
**_model_kwargs,
}
def _invocation_params(
self, stop_sequences: Optional[List[str]], **kwargs: Any
) -> dict:
params = self._default_params
if self.stop_sequences is not None and stop_sequences is not None:
raise ValueError("`stop` found in both the input and default params.")
elif self.stop_sequences is not None:
params["stop"] = self.stop_sequences
else:
params["stop"] = stop_sequences
if self.model_kwargs:
params.update(self.model_kwargs)
return {**params, **kwargs}
@staticmethod
def _process_response(
response: Any, stop: Optional[List[str]], version: Optional[str]
) -> str:
if version == "1.0":
text = response
else:
text = response["response"]
if stop:
text = enforce_stop_tokens(text, stop)
return "".join(text)
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
params = self._invocation_params(stop, **kwargs)
prompt = prompt.strip()
response = None
try:
if self.streaming:
completion = ""
for chunk in self._stream(prompt, stop, run_manager, **params):
completion += chunk.text
return completion
else:
response = self._call_eas(prompt, params)
_stop = params.get("stop")
return self._process_response(response, _stop, self.version)
except Exception as error:
raise ValueError(f"Error raised by the service: {error}")
def _call_eas(self, prompt: str = "", params: Dict = {}) -> Any:
"""Generate text from the eas service."""
headers = {
"Content-Type": "application/json",
"Authorization": f"{self.eas_service_token}",
}
if self.version == "1.0":
body = {
"input_ids": f"{prompt}",
}
else:
body = {
"prompt": f"{prompt}",
}
# add params to body
for key, value in params.items():
body[key] = value
# make request
response = requests.post(self.eas_service_url, headers=headers, json=body)
if response.status_code != 200:
raise Exception(
f"Request failed with status code {response.status_code}"
f" and message {response.text}"
)
try:
return json.loads(response.text)
except Exception as e:
if isinstance(e, json.decoder.JSONDecodeError):
return response.text
raise e
def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
invocation_params = self._invocation_params(stop, **kwargs)
headers = {
"User-Agent": "Test Client",
"Authorization": f"{self.eas_service_token}",
}
if self.version == "1.0":
pload = {"input_ids": prompt, **invocation_params}
response = requests.post(
self.eas_service_url, headers=headers, json=pload, stream=True
)
res = GenerationChunk(text=response.text)
if run_manager:
run_manager.on_llm_new_token(res.text)
# yield text, if any
yield res
else:
pload = {"prompt": prompt, "use_stream_chat": "True", **invocation_params}
response = requests.post(
self.eas_service_url, headers=headers, json=pload, stream=True
)
for chunk in response.iter_lines(
chunk_size=8192, decode_unicode=False, delimiter=b"\0"
):
if chunk:
data = json.loads(chunk.decode("utf-8"))
output = data["response"]
# identify stop sequence in generated text, if any
stop_seq_found: Optional[str] = None
for stop_seq in invocation_params["stop"] or []:
if stop_seq in output:
stop_seq_found = stop_seq
# identify text to yield
text: Optional[str] = None
if stop_seq_found:
text = output[: output.index(stop_seq_found)]
else:
text = output
# yield text, if any
if text:
res = GenerationChunk(text=text)
yield res
if run_manager:
run_manager.on_llm_new_token(res.text)
# break if stop sequence found
if stop_seq_found:
break

@ -0,0 +1,82 @@
"""Test AliCloud Pai Eas Chat Model."""
import os
from langchain.callbacks.manager import CallbackManager
from langchain.chat_models.pai_eas_endpoint import PaiEasChatEndpoint
from langchain.schema import (
AIMessage,
BaseMessage,
ChatGeneration,
HumanMessage,
LLMResult,
)
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
def test_pai_eas_call() -> None:
chat = PaiEasChatEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
)
response = chat(messages=[HumanMessage(content="Say foo:")])
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_multiple_history() -> None:
"""Tests multiple history works."""
chat = PaiEasChatEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
)
response = chat(
messages=[
HumanMessage(content="Hello."),
AIMessage(content="Hello!"),
HumanMessage(content="How are you doing?"),
]
)
assert isinstance(response, BaseMessage)
assert isinstance(response.content, str)
def test_stream() -> None:
"""Test that stream works."""
chat = PaiEasChatEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
streaming=True,
)
callback_handler = FakeCallbackHandler()
callback_manager = CallbackManager([callback_handler])
response = chat(
messages=[
HumanMessage(content="Hello."),
AIMessage(content="Hello!"),
HumanMessage(content="Who are you?"),
],
stream=True,
callbacks=callback_manager,
)
assert callback_handler.llm_streams > 0
assert isinstance(response.content, str)
def test_multiple_messages() -> None:
"""Tests multiple messages works."""
chat = PaiEasChatEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
)
message = HumanMessage(content="Hi, how are you.")
response = chat.generate([[message], [message]])
assert isinstance(response, LLMResult)
assert len(response.generations) == 2
for generations in response.generations:
assert len(generations) == 1
for generation in generations:
assert isinstance(generation, ChatGeneration)
assert isinstance(generation.text, str)
assert generation.text == generation.message.content

@ -0,0 +1,59 @@
"""Test PaiEasEndpoint API wrapper."""
import os
from typing import Generator
from langchain.llms.pai_eas_endpoint import PaiEasEndpoint
def test_pai_eas_v1_call() -> None:
"""Test valid call to PAI-EAS Service."""
llm = PaiEasEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
version="1.0",
)
output = llm("Say foo:")
assert isinstance(output, str)
def test_pai_eas_v2_call() -> None:
llm = PaiEasEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
version="2.0",
)
output = llm("Say foo:")
assert isinstance(output, str)
def test_pai_eas_v1_streaming() -> None:
"""Test streaming call to PAI-EAS Service."""
llm = PaiEasEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
version="1.0",
)
generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
stream_results_string = ""
assert isinstance(generator, Generator)
for chunk in generator:
assert isinstance(chunk, str)
stream_results_string += chunk  # accumulate the streamed chunks
assert len(stream_results_string.strip()) > 1
def test_pai_eas_v2_streaming() -> None:
llm = PaiEasEndpoint(
eas_service_url=os.getenv("EAS_SERVICE_URL"),
eas_service_token=os.getenv("EAS_SERVICE_TOKEN"),
version="2.0",
)
generator = llm.stream("Q: How do you say 'hello' in German? A:'", stop=["."])
stream_results_string = ""
assert isinstance(generator, Generator)
for chunk in generator:
assert isinstance(chunk, str)
stream_results_string += chunk  # accumulate the streamed chunks
assert len(stream_results_string.strip()) > 1