Mirror of https://github.com/hwchase17/langchain, synced 2024-11-18 09:25:54 +00:00
Langchain-community : EdenAI chat integration. (#16377)
- **Description:** This PR adds [EdenAI](https://edenai.co/) for the chat model (already available in LLM & Embeddings). It supports all `ChatModel` functionality: generate, async generate, stream, astream, and batch. A detailed notebook was added.
- **Dependencies:** No dependencies are added, as we call a REST API.

---------

Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
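A hedged sketch of the surface area this commit adds (it assumes `EDENAI_API_KEY` is exported; the notebook below walks through each call in full):

```python
from langchain_community.chat_models.edenai import ChatEdenAI
from langchain_core.messages import HumanMessage

# The API key is read from the EDENAI_API_KEY environment variable.
chat = ChatEdenAI(provider="openai", max_tokens=250)
messages = [HumanMessage(content="Hello !")]

chat.invoke(messages)  # generate
for chunk in chat.stream(messages):  # stream
    print(chunk.content, end="", flush=True)
chat.batch([messages])  # batch
# Async variants: await chat.ainvoke(messages) / async for c in chat.astream(messages)
```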
This commit is contained in: parent 08d3fd7f2e, commit e30c6662df
272 docs/docs/integrations/chat/edenai.ipynb Normal file
@@ -0,0 +1,272 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Eden AI"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Eden AI unifies the best AI providers behind a single API, letting users deploy AI features to production quickly and tap into the full breadth of AI capabilities with one integration. (website: https://edenai.co/)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This example goes over how to use LangChain to interact with Eden AI models.\n",
    "\n",
    "-----------------------------------------------------------------------------------"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`EdenAI` goes beyond mere model invocation. It empowers you with advanced features, including:\n",
    "\n",
    "- **Multiple Providers**: Gain access to a diverse range of language models offered by various providers, giving you the freedom to choose the best-suited model for your use case.\n",
    "\n",
    "- **Fallback Mechanism**: Set a fallback mechanism to ensure seamless operations: even if the primary provider is unavailable, you can easily switch to an alternative provider.\n",
    "\n",
    "- **Usage Tracking**: Track usage statistics on a per-project and per-API key basis. This feature allows you to monitor and manage resource consumption effectively.\n",
    "\n",
    "- **Monitoring and Observability**: `EdenAI` provides comprehensive monitoring and observability tools on the platform. Monitor the performance of your language models, analyze usage patterns, and gain valuable insights to optimize your applications.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Accessing Eden AI's API requires an API key,\n",
    "\n",
    "which you can get by creating an account (https://app.edenai.run/user/register) and heading to https://app.edenai.run/admin/iam/api-keys.\n",
    "\n",
    "Once we have a key we'll want to set it as an environment variable by running:\n",
    "\n",
    "```bash\n",
    "export EDENAI_API_KEY=\"...\"\n",
    "```\n",
    "\n",
    "You can find more details in the API reference: https://docs.edenai.co/reference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you'd prefer not to set an environment variable, you can pass the key in directly via the `edenai_api_key` named parameter when initializing the EdenAI chat model class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.chat_models.edenai import ChatEdenAI\n",
    "from langchain_core.messages import HumanMessage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat = ChatEdenAI(\n",
    "    edenai_api_key=\"...\", provider=\"openai\", temperature=0.2, max_tokens=250\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='Hello! How can I assist you today?')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages = [HumanMessage(content=\"Hello !\")]\n",
    "chat.invoke(messages)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='Hello! How can I assist you today?')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "await chat.ainvoke(messages)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Streaming and Batching\n",
    "\n",
    "`ChatEdenAI` supports streaming and batching. Below is an example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hello! How can I assist you today?"
     ]
    }
   ],
   "source": [
    "for chunk in chat.stream(messages):\n",
    "    print(chunk.content, end=\"\", flush=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[AIMessage(content='Hello! How can I assist you today?')]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chat.batch([messages])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fallback Mechanism\n",
    "\n",
    "With Eden AI you can set a fallback mechanism to ensure seamless operations: even if the primary provider is unavailable, you can easily switch to an alternative provider."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat = ChatEdenAI(\n",
    "    edenai_api_key=\"...\",\n",
    "    provider=\"openai\",\n",
    "    temperature=0.2,\n",
    "    max_tokens=250,\n",
    "    fallback_providers=\"google\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this example, you can use Google as a backup provider if OpenAI encounters any issues.\n",
    "\n",
    "For more information and details about Eden AI, check out this link: https://docs.edenai.co/docs/additional-parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chaining Calls\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "\n",
    "prompt = ChatPromptTemplate.from_template(\n",
    "    \"What is a good name for a company that makes {product}?\"\n",
    ")\n",
    "chain = prompt | chat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='VitalBites')"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "chain.invoke({\"product\": \"healthy snacks\"})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain-pr",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
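The notebook covers `invoke`, `ainvoke`, `stream`, and `batch`; the commit message also lists `astream`, which is not shown above. A minimal sketch, assuming the same setup with `EDENAI_API_KEY` exported:

```python
import asyncio

from langchain_community.chat_models.edenai import ChatEdenAI
from langchain_core.messages import HumanMessage


async def main() -> None:
    # The API key is read from the EDENAI_API_KEY environment variable.
    chat = ChatEdenAI(provider="openai", max_tokens=250)
    async for chunk in chat.astream([HumanMessage(content="Hello !")]):
        print(chunk.content, end="", flush=True)


asyncio.run(main())
```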
368 libs/community/langchain_community/chat_models/edenai.py Normal file
@@ -0,0 +1,368 @@
import json
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from aiohttp import ClientSession
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

from langchain_community.utilities.requests import Requests


def _message_role(type: str) -> str:
    role_mapping = {"ai": "assistant", "human": "user", "chat": "user"}

    if type in role_mapping:
        return role_mapping[type]
    else:
        raise ValueError(f"Unknown type: {type}")


def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
    system = None
    formatted_messages = []
    text = messages[-1].content
    for i, message in enumerate(messages[:-1]):
        if message.type == "system":
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            system = message.content
        else:
            formatted_messages.append(
                {
                    "role": _message_role(message.type),
                    "message": message.content,
                }
            )
    return {
        "text": text,
        "previous_history": formatted_messages,
        "chatbot_global_action": system,
    }


class ChatEdenAI(BaseChatModel):
    """`EdenAI` chat large language models.

    `EdenAI` is a versatile platform that allows you to access various language models
    from different providers such as Google, OpenAI, Cohere, Mistral and more.

    To get started, make sure you have the environment variable ``EDENAI_API_KEY``
    set with your API key, or pass it as a named parameter to the constructor.

    Additionally, `EdenAI` provides the flexibility to choose from a variety of models,
    including ones like "gpt-4".

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatEdenAI
            from langchain_core.messages import HumanMessage

            # Initialize `ChatEdenAI` with the desired configuration
            chat = ChatEdenAI(
                provider="openai",
                model="gpt-4",
                max_tokens=256,
                temperature=0.75)

            # Create a list of messages to interact with the model
            messages = [HumanMessage(content="hello")]

            # Invoke the model with the provided messages
            chat.invoke(messages)

    `EdenAI` goes beyond mere model invocation. It empowers you with advanced features:

    - **Multiple Providers**: access to a diverse range of LLMs offered by various
      providers, giving you the freedom to choose the best-suited model for your
      use case.

    - **Fallback Mechanism**: set a fallback mechanism to ensure seamless operations:
      even if the primary provider is unavailable, you can easily switch to an
      alternative provider.

    - **Usage Statistics**: track usage statistics on a per-project
      and per-API key basis.
      This feature allows you to monitor and manage resource consumption effectively.

    - **Monitoring and Observability**: `EdenAI` provides comprehensive monitoring
      and observability tools on the platform.

    Example of setting up a fallback mechanism:
        .. code-block:: python

            # Initialize `ChatEdenAI` with a fallback provider
            chat_with_fallback = ChatEdenAI(
                provider="openai",
                model="gpt-4",
                max_tokens=256,
                temperature=0.75,
                fallback_providers="google")

    You can find more details here: https://docs.edenai.co/reference/text_chat_create
    """

    provider: str = "openai"
    """chat provider to use (eg: openai, google etc.)"""

    model: Optional[str] = None
    """
    model name for above provider (eg: 'gpt-4' for openai)
    available models are shown on https://docs.edenai.co/ under 'available providers'
    """

    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = 0
    """A non-negative float that tunes the degree of randomness in generation."""

    streaming: bool = False
    """Whether to stream the results."""

    fallback_providers: Optional[str] = None
    """Providers in this will be used as fallback if the call to provider fails."""

    edenai_api_url: str = "https://api.edenai.run/v2"

    edenai_api_key: Optional[SecretStr] = Field(None, description="EdenAI API Token")

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key exists in environment."""
        values["edenai_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "edenai_api_key", "EDENAI_API_KEY")
        )
        return values

    @staticmethod
    def get_user_agent() -> str:
        from langchain_community import __version__

        return f"langchain/{__version__}"

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "edenai-chat"

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Call out to EdenAI's chat endpoint."""
        url = f"{self.edenai_api_url}/text/chat/stream"
        headers = {
            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        # Drop unset (None) parameters before sending the request.
        payload = {k: v for k, v in payload.items() if v is not None}

        # Pin a specific model for the chosen provider, e.g. {"openai": "gpt-4"}.
        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        request = Requests(headers=headers)
        response = request.post(url=url, data=payload, stream=True)
        response.raise_for_status()

        # The streaming endpoint emits one JSON object per line, each with a
        # "text" field containing the next token(s).
        for chunk_response in response.iter_lines():
            chunk = json.loads(chunk_response.decode())
            token = chunk["text"]
            chat_generation_chunk = ChatGenerationChunk(
                message=AIMessageChunk(content=token)
            )
            yield chat_generation_chunk
            if run_manager:
                run_manager.on_llm_new_token(token, chunk=chat_generation_chunk)

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        url = f"{self.edenai_api_url}/text/chat/stream"
        headers = {
            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                response.raise_for_status()
                async for chunk_response in response.content:
                    chunk = json.loads(chunk_response.decode())
                    token = chunk["text"]
                    chat_generation_chunk = ChatGenerationChunk(
                        message=AIMessageChunk(content=token)
                    )
                    yield chat_generation_chunk
                    if run_manager:
                        await run_manager.on_llm_new_token(
                            token=chunk["text"], chunk=chat_generation_chunk
                        )

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to EdenAI's chat endpoint."""
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        url = f"{self.edenai_api_url}/text/chat"
        headers = {
            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        request = Requests(headers=headers)
        response = request.post(url=url, data=payload)

        response.raise_for_status()
        # Responses are keyed by provider name; prefer the fallback provider's
        # response when one is present.
        data = response.json()
        provider_response = data[self.provider]

        if self.fallback_providers:
            fallback_response = data.get(self.fallback_providers)
            if fallback_response:
                provider_response = fallback_response

        if provider_response.get("status") == "fail":
            err_msg = provider_response.get("error", {}).get("message")
            raise Exception(err_msg)

        return ChatResult(
            generations=[
                ChatGeneration(
                    message=AIMessage(content=provider_response["generated_text"])
                )
            ],
            llm_output=data,
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            stream_iter = self._astream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return await agenerate_from_stream(stream_iter)

        url = f"{self.edenai_api_url}/text/chat"
        headers = {
            "Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                response.raise_for_status()
                data = await response.json()
                provider_response = data[self.provider]

                if self.fallback_providers:
                    fallback_response = data.get(self.fallback_providers)
                    if fallback_response:
                        provider_response = fallback_response

                if provider_response.get("status") == "fail":
                    err_msg = provider_response.get("error", {}).get("message")
                    raise Exception(err_msg)

                return ChatResult(
                    generations=[
                        ChatGeneration(
                            message=AIMessage(
                                content=provider_response["generated_text"]
                            )
                        )
                    ],
                    llm_output=data,
                )
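To make the request-payload mapping concrete, here is a small worked example of `_format_edenai_messages` as defined above: the last message becomes `text`, earlier non-system messages become `previous_history`, and a leading system message becomes `chatbot_global_action`.

```python
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage

from langchain_community.chat_models.edenai import _format_edenai_messages

messages = [
    SystemMessage(content="You are a translator."),
    HumanMessage(content="Hello"),
    AIMessage(content="Bonjour"),
    HumanMessage(content="Good night"),
]

# The trailing message is the prompt; everything before it becomes history.
print(_format_edenai_messages(messages))
# {'text': 'Good night',
#  'previous_history': [
#      {'role': 'user', 'message': 'Hello'},
#      {'role': 'assistant', 'message': 'Bonjour'}],
#  'chatbot_global_action': 'You are a translator.'}
```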
70 libs/community/tests/integration_tests/chat_models/test_edenai.py Normal file
@@ -0,0 +1,70 @@
"""Test EdenAI API wrapper."""
from typing import List

import pytest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration, LLMResult

from langchain_community.chat_models.edenai import (
    ChatEdenAI,
)


@pytest.mark.scheduled
def test_chat_edenai() -> None:
    """Test ChatEdenAI wrapper."""
    chat = ChatEdenAI(
        provider="openai", model="gpt-3.5-turbo", temperature=0, max_tokens=1000
    )
    message = HumanMessage(content="Who are you ?")
    response = chat([message])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)


@pytest.mark.scheduled
def test_edenai_generate() -> None:
    """Test generate method of edenai."""
    chat = ChatEdenAI(provider="google")
    chat_messages: List[List[BaseMessage]] = [
        [HumanMessage(content="What is the meaning of life?")]
    ]
    messages_copy = [messages.copy() for messages in chat_messages]
    result: LLMResult = chat.generate(chat_messages)
    assert isinstance(result, LLMResult)
    for response in result.generations[0]:
        assert isinstance(response, ChatGeneration)
        assert isinstance(response.text, str)
        assert response.text == response.message.content
    assert chat_messages == messages_copy


@pytest.mark.scheduled
async def test_edenai_async_generate() -> None:
    """Test async generation."""
    chat = ChatEdenAI(provider="google", max_tokens=50)
    message = HumanMessage(content="Hello")
    result: LLMResult = await chat.agenerate([[message], [message]])
    assert isinstance(result, LLMResult)
    for response in result.generations[0]:
        assert isinstance(response, ChatGeneration)
        assert isinstance(response.text, str)
        assert response.text == response.message.content


@pytest.mark.scheduled
def test_edenai_streaming() -> None:
    """Test streaming EdenAI chat."""
    llm = ChatEdenAI(provider="openai", max_tokens=50)

    for chunk in llm.stream("Generate a high fantasy story."):
        assert isinstance(chunk.content, str)


@pytest.mark.scheduled
async def test_edenai_astream() -> None:
    """Test streaming from EdenAI."""
    llm = ChatEdenAI(provider="openai", max_tokens=50)

    async for token in llm.astream("Generate a high fantasy story."):
        assert isinstance(token.content, str)
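These tests call the live Eden AI API and are gated behind the `scheduled` mark. A sketch of a local run, assuming the conventional integration-test path shown above and an exported key:

```bash
export EDENAI_API_KEY="..."
pytest -m scheduled libs/community/tests/integration_tests/chat_models/test_edenai.py
```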
40 libs/community/tests/unit_tests/chat_models/test_edenai.py Normal file
@@ -0,0 +1,40 @@
"""Test EdenAI Chat API wrapper."""
from typing import Any, Dict, List

import pytest
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage

from langchain_community.chat_models.edenai import (
    _format_edenai_messages,
    _message_role,
)


@pytest.mark.parametrize(
    ("messages", "expected"),
    [
        (
            [
                SystemMessage(content="Translate the text from English to French"),
                HumanMessage(content="Hello how are you today?"),
            ],
            {
                "text": "Hello how are you today?",
                "previous_history": [],
                "chatbot_global_action": "Translate the text from English to French",
            },
        )
    ],
)
def test_edenai_messages_formatting(
    messages: List[BaseMessage], expected: Dict[str, Any]
) -> None:
    result = _format_edenai_messages(messages)
    assert result == expected


@pytest.mark.parametrize(
    ("role", "role_response"),
    [("ai", "assistant"), ("human", "user"), ("chat", "user")],
)
def test_edenai_message_role(role: str, role_response: str) -> None:
    role = _message_role(role)
    assert role == role_response
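The unit tests above only exercise the message-formatting helpers, so they run offline without an API key:

```bash
pytest libs/community/tests/unit_tests/chat_models/test_edenai.py
```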