Mirror of https://github.com/hwchase17/langchain, synced 2024-11-08 07:10:35 +00:00 (commit 03178ee74f)
### Description

Add a tools implementation to `ChatEdenAI`:

- `bind_tools()`
- `with_structured_output()`

### Documentation

Updated `docs/docs/integrations/chat/edenai.ipynb`.

### Notes

We don't support streaming with tools yet. If `stream` is called with tools, we directly yield the whole message from `generate` (implemented the same way as Anthropic did).
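For reference, a minimal usage sketch of the new surface (the `GetWeather` schema is illustrative, not part of this PR, and `EDENAI_API_KEY` is assumed to be set in the environment):

```python
from langchain_community.chat_models import ChatEdenAI
from langchain_core.pydantic_v1 import BaseModel, Field


class GetWeather(BaseModel):
    """Get the current weather in a given city."""

    city: str = Field(..., description="City name")


chat = ChatEdenAI(provider="openai", model="gpt-4")

# Tool calling: the bound model returns an AIMessage with .tool_calls populated.
llm_with_tools = chat.bind_tools([GetWeather])
ai_msg = llm_with_tools.invoke("What is the weather in Paris?")
print(ai_msg.tool_calls)

# Structured output: the forced tool call is parsed back into a GetWeather instance.
structured_llm = chat.with_structured_output(GetWeather)
print(structured_llm.invoke("What is the weather in Paris?"))
```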
627 lines · 22 KiB · Python
import json
import warnings
from operator import itemgetter
from typing import (
    Any,
    AsyncIterator,
    Callable,
    Dict,
    Iterator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Type,
    Union,
    cast,
)

from aiohttp import ClientSession
from langchain_core.callbacks import (
    AsyncCallbackManagerForLLMRun,
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    agenerate_from_stream,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    HumanMessage,
    InvalidToolCall,
    SystemMessage,
    ToolCall,
    ToolCallChunk,
    ToolMessage,
)
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
    JsonOutputKeyToolsParser,
    PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import (
    BaseModel,
    Extra,
    Field,
    SecretStr,
    root_validator,
)
from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough
from langchain_core.tools import BaseTool
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool

from langchain_community.utilities.requests import Requests

def _result_to_chunked_message(generated_result: ChatResult) -> ChatGenerationChunk:
    message = generated_result.generations[0].message
    if isinstance(message, AIMessage) and message.tool_calls is not None:
        tool_call_chunks = [
            ToolCallChunk(
                name=tool_call["name"],
                args=json.dumps(tool_call["args"]),
                id=tool_call["id"],
                index=idx,
            )
            for idx, tool_call in enumerate(message.tool_calls)
        ]
        message_chunk = AIMessageChunk(
            content=message.content,
            tool_call_chunks=tool_call_chunks,
        )
        return ChatGenerationChunk(message=message_chunk)
    else:
        return cast(ChatGenerationChunk, generated_result.generations[0])
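# Note: when streaming is requested with tools, this integration emulates
# streaming by running a full generation and re-emitting it as a single chunk
# (the same approach the Anthropic integration uses, per the PR description).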

def _message_role(type: str) -> str:
    role_mapping = {
        "ai": "assistant",
        "human": "user",
        "chat": "user",
        "AIMessageChunk": "assistant",
    }

    if type in role_mapping:
        return role_mapping[type]
    else:
        raise ValueError(f"Unknown type: {type}")

def _extract_edenai_tool_results_from_messages(
    messages: List[BaseMessage],
) -> Tuple[List[Dict[str, Any]], List[BaseMessage]]:
    """
    Get the trailing langchain tool messages and transform them into EdenAI
    tool_results.
    Returns the tool_results and the messages without the extracted tool messages.
    """
    tool_results: List[Dict[str, Any]] = []
    other_messages = messages[:]
    for msg in reversed(messages):
        if isinstance(msg, ToolMessage):
            tool_results = [
                {"id": msg.tool_call_id, "result": msg.content},
                *tool_results,
            ]
            other_messages.pop()
        else:
            break
    return tool_results, other_messages

def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
    system = None
    formatted_messages = []

    human_messages = [msg for msg in messages if isinstance(msg, HumanMessage)]
    last_human_message = human_messages[-1] if human_messages else ""

    tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
    for i, message in enumerate(other_messages):
        if isinstance(message, SystemMessage):
            if i != 0:
                raise ValueError("System message must be at beginning of message list.")
            system = message.content
        elif isinstance(message, ToolMessage):
            formatted_messages.append({"role": "tool", "message": message.content})
        elif message != last_human_message:
            formatted_messages.append(
                {
                    "role": _message_role(message.type),
                    "message": message.content,
                    "tool_calls": _format_tool_calls_to_edenai_tool_calls(message),
                }
            )

    return {
        "text": getattr(last_human_message, "content", ""),
        "previous_history": formatted_messages,
        "chatbot_global_action": system,
        "tool_results": tool_results,
    }
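# A rough sketch of the mapping (message contents illustrative): for
# [SystemMessage("You are helpful"), HumanMessage("hi")] this returns
# {"text": "hi", "previous_history": [],
#  "chatbot_global_action": "You are helpful", "tool_results": []}.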

def _format_tool_calls_to_edenai_tool_calls(message: BaseMessage) -> List:
    tool_calls = getattr(message, "tool_calls", [])
    invalid_tool_calls = getattr(message, "invalid_tool_calls", [])
    edenai_tool_calls = []

    for invalid_tool_call in invalid_tool_calls:
        edenai_tool_calls.append(
            {
                "arguments": invalid_tool_call.get("args"),
                "id": invalid_tool_call.get("id"),
                "name": invalid_tool_call.get("name"),
            }
        )

    for tool_call in tool_calls:
        tool_args = tool_call.get("args", {})
        try:
            arguments = json.dumps(tool_args)
        except TypeError:
            arguments = str(tool_args)
        edenai_tool_calls.append(
            {
                "arguments": arguments,
                "id": tool_call["id"],
                "name": tool_call["name"],
            }
        )
    return edenai_tool_calls

def _extract_tool_calls_from_edenai_response(
    provider_response: Dict[str, Any],
) -> Tuple[List[ToolCall], List[InvalidToolCall]]:
    tool_calls = []
    invalid_tool_calls = []

    # "message" is a list of chat turns; index 1 is the assistant turn. The
    # default mirrors that shape so a missing key degrades gracefully instead
    # of raising KeyError on an empty dict.
    message = provider_response.get("message", [{}, {}])[1]

    if raw_tool_calls := message.get("tool_calls"):
        for raw_tool_call in raw_tool_calls:
            try:
                tool_calls.append(
                    ToolCall(
                        name=raw_tool_call["name"],
                        args=json.loads(raw_tool_call["arguments"]),
                        id=raw_tool_call["id"],
                    )
                )
            except json.JSONDecodeError as exc:
                invalid_tool_calls.append(
                    InvalidToolCall(
                        name=raw_tool_call.get("name"),
                        args=raw_tool_call.get("arguments"),
                        id=raw_tool_call.get("id"),
                        error=f"Received JSONDecodeError {exc}",
                    )
                )

    return tool_calls, invalid_tool_calls
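# Example (tool name and id illustrative): a raw entry like
# {"name": "get_weather", "arguments": '{"city": "Paris"}', "id": "call_0"}
# becomes ToolCall(name="get_weather", args={"city": "Paris"}, id="call_0");
# entries whose "arguments" fail json.loads land in invalid_tool_calls instead.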

class ChatEdenAI(BaseChatModel):
    """`EdenAI` chat large language models.

    `EdenAI` is a versatile platform that allows you to access various language
    models from different providers such as Google, OpenAI, Cohere, Mistral and
    more.

    To get started, make sure you have the environment variable ``EDENAI_API_KEY``
    set with your API key, or pass it as a named parameter to the constructor.

    Additionally, `EdenAI` lets you choose from a variety of models, including
    models like "gpt-4".

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatEdenAI
            from langchain_core.messages import HumanMessage

            # Initialize `ChatEdenAI` with the desired configuration
            chat = ChatEdenAI(
                provider="openai",
                model="gpt-4",
                max_tokens=256,
                temperature=0.75)

            # Create a list of messages to interact with the model
            messages = [HumanMessage(content="hello")]

            # Invoke the model with the provided messages
            chat.invoke(messages)

    `EdenAI` goes beyond mere model invocation. It empowers you with advanced
    features:

    - **Multiple Providers**: Access a diverse range of LLMs offered by various
      providers, giving you the freedom to choose the best-suited model for
      your use case.

    - **Fallback Mechanism**: Set a fallback mechanism to ensure seamless
      operation; if the primary provider is unavailable, EdenAI switches to an
      alternative provider.

    - **Usage Statistics**: Track usage statistics on a per-project and
      per-API-key basis. This feature allows you to monitor and manage resource
      consumption effectively.

    - **Monitoring and Observability**: `EdenAI` provides comprehensive
      monitoring and observability tools on the platform.

    Example of setting up a fallback mechanism:
        .. code-block:: python

            # Initialize `ChatEdenAI` with a fallback provider
            chat_with_fallback = ChatEdenAI(
                provider="openai",
                model="gpt-4",
                max_tokens=256,
                temperature=0.75,
                fallback_providers="google")

    You can find more details here: https://docs.edenai.co/reference/text_chat_create
    """
    provider: str = "openai"
    """Chat provider to use (e.g., "openai", "google")."""

    model: Optional[str] = None
    """
    Model name for the above provider (e.g., "gpt-4" for openai).
    Available models are shown on https://docs.edenai.co/ under 'available providers'.
    """

    max_tokens: int = 256
    """Denotes the number of tokens to predict per generation."""

    temperature: Optional[float] = 0
    """A non-negative float that tunes the degree of randomness in generation."""

    streaming: bool = False
    """Whether to stream the results."""

    fallback_providers: Optional[str] = None
    """Providers to use as fallback if the call to the main provider fails."""

    edenai_api_url: str = "https://api.edenai.run/v2"

    edenai_api_key: Optional[SecretStr] = Field(None, description="EdenAI API Token")

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API key exists in the environment."""
        values["edenai_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "edenai_api_key", "EDENAI_API_KEY")
        )
        return values

    @staticmethod
    def get_user_agent() -> str:
        from langchain_community import __version__

        return f"langchain/{__version__}"

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "edenai-chat"

    @property
    def _api_key(self) -> str:
        if self.edenai_api_key:
            return self.edenai_api_key.get_secret_value()
        return ""
    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Call out to EdenAI's chat endpoint."""
        if "available_tools" in kwargs:
            yield self._stream_with_tools_as_generate(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return
        url = f"{self.edenai_api_url}/text/chat/stream"
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        request = Requests(headers=headers)
        response = request.post(url=url, data=payload, stream=True)
        response.raise_for_status()

        for chunk_response in response.iter_lines():
            chunk = json.loads(chunk_response.decode())
            token = chunk["text"]
            cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
            if run_manager:
                run_manager.on_llm_new_token(token, chunk=cg_chunk)
            yield cg_chunk
    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        if "available_tools" in kwargs:
            yield await self._astream_with_tools_as_agenerate(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return
        url = f"{self.edenai_api_url}/text/chat/stream"
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                response.raise_for_status()
                async for chunk_response in response.content:
                    chunk = json.loads(chunk_response.decode())
                    token = chunk["text"]
                    cg_chunk = ChatGenerationChunk(
                        message=AIMessageChunk(content=token)
                    )
                    if run_manager:
                        await run_manager.on_llm_new_token(
                            token=chunk["text"], chunk=cg_chunk
                        )
                    yield cg_chunk
    def bind_tools(
        self,
        tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
        *,
        tool_choice: Optional[
            Union[dict, str, Literal["auto", "none", "required", "any"], bool]
        ] = None,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, BaseMessage]:
        formatted_tools = [convert_to_openai_tool(tool)["function"] for tool in tools]
        formatted_tool_choice = "required" if tool_choice == "any" else tool_choice
        return super().bind(
            available_tools=formatted_tools, tool_choice=formatted_tool_choice, **kwargs
        )
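    # Tools are normalized with convert_to_openai_tool and passed to the API as
    # "available_tools"; tool_choice="any" is translated to "required", while
    # all other values pass through unchanged.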

    def with_structured_output(
        self,
        schema: Union[Dict, Type[BaseModel]],
        *,
        include_raw: bool = False,
        **kwargs: Any,
    ) -> Runnable[LanguageModelInput, Union[Dict, BaseModel]]:
        if kwargs:
            raise ValueError(f"Received unsupported arguments {kwargs}")
        llm = self.bind_tools([schema], tool_choice="required")
        if isinstance(schema, type) and issubclass(schema, BaseModel):
            output_parser: OutputParserLike = PydanticToolsParser(
                tools=[schema], first_tool_only=True
            )
        else:
            key_name = convert_to_openai_tool(schema)["function"]["name"]
            output_parser = JsonOutputKeyToolsParser(
                key_name=key_name, first_tool_only=True
            )

        if include_raw:
            parser_assign = RunnablePassthrough.assign(
                parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
            )
            parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
            parser_with_fallback = parser_assign.with_fallbacks(
                [parser_none], exception_key="parsing_error"
            )
            return RunnableMap(raw=llm) | parser_with_fallback
        else:
            return llm | output_parser
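    # A minimal usage sketch (the Joke schema is illustrative, not part of this
    # module): bind a pydantic model as the only tool, force the tool call with
    # tool_choice="required", and parse the call back into an instance.
    #
    #     class Joke(BaseModel):
    #         setup: str
    #         punchline: str
    #
    #     structured_llm = chat.with_structured_output(Joke)
    #     structured_llm.invoke("Tell me a joke")  # -> Joke(setup=..., punchline=...)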

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Call out to EdenAI's chat endpoint."""
        if self.streaming:
            if "available_tools" in kwargs:
                warnings.warn(
                    "stream: Tool use is not yet supported in streaming mode."
                )
            else:
                stream_iter = self._stream(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )
                return generate_from_stream(stream_iter)

        url = f"{self.edenai_api_url}/text/chat"
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)

        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        request = Requests(headers=headers)
        response = request.post(url=url, data=payload)

        response.raise_for_status()
        data = response.json()
        provider_response = data[self.provider]

        if self.fallback_providers:
            fallback_response = data.get(self.fallback_providers)
            if fallback_response:
                provider_response = fallback_response

        if provider_response.get("status") == "fail":
            err_msg = provider_response.get("error", {}).get("message")
            raise Exception(err_msg)

        tool_calls, invalid_tool_calls = _extract_tool_calls_from_edenai_response(
            provider_response
        )

        return ChatResult(
            generations=[
                ChatGeneration(
                    message=AIMessage(
                        content=provider_response["generated_text"] or "",
                        tool_calls=tool_calls,
                        invalid_tool_calls=invalid_tool_calls,
                    )
                )
            ],
            llm_output=data,
        )

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        if self.streaming:
            if "available_tools" in kwargs:
                warnings.warn(
                    "stream: Tool use is not yet supported in streaming mode."
                )
            else:
                stream_iter = self._astream(
                    messages, stop=stop, run_manager=run_manager, **kwargs
                )
                return await agenerate_from_stream(stream_iter)

        url = f"{self.edenai_api_url}/text/chat"
        headers = {
            "Authorization": f"Bearer {self._api_key}",
            "User-Agent": self.get_user_agent(),
        }
        formatted_data = _format_edenai_messages(messages=messages)
        payload: Dict[str, Any] = {
            "providers": self.provider,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "fallback_providers": self.fallback_providers,
            **formatted_data,
            **kwargs,
        }

        payload = {k: v for k, v in payload.items() if v is not None}

        if self.model is not None:
            payload["settings"] = {self.provider: self.model}

        async with ClientSession() as session:
            async with session.post(url, json=payload, headers=headers) as response:
                response.raise_for_status()
                data = await response.json()
                provider_response = data[self.provider]

                if self.fallback_providers:
                    fallback_response = data.get(self.fallback_providers)
                    if fallback_response:
                        provider_response = fallback_response

                if provider_response.get("status") == "fail":
                    err_msg = provider_response.get("error", {}).get("message")
                    raise Exception(err_msg)

                return ChatResult(
                    generations=[
                        ChatGeneration(
                            message=AIMessage(
                                content=provider_response["generated_text"] or ""
                            )
                        )
                    ],
                    llm_output=data,
                )

    def _stream_with_tools_as_generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]],
        run_manager: Optional[CallbackManagerForLLMRun],
        **kwargs: Any,
    ) -> ChatGenerationChunk:
        warnings.warn("stream: Tool use is not yet supported in streaming mode.")
        result = self._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
        return _result_to_chunked_message(result)

    async def _astream_with_tools_as_agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]],
        run_manager: Optional[AsyncCallbackManagerForLLMRun],
        **kwargs: Any,
    ) -> ChatGenerationChunk:
        warnings.warn("stream: Tool use is not yet supported in streaming mode.")
        result = await self._agenerate(
            messages, stop=stop, run_manager=run_manager, **kwargs
        )
        return _result_to_chunked_message(result)