community: fix ChatEdenAI + EdenAI Tools (#23715)

Fixes for Eden AI Custom tools and ChatEdenAI:
- add missing import in __init__ of chat_models
- add `args_schema` to custom tools. otherwise '__arg1' would sometimes
be passed to the `run` method
- fix IndexError when no human msg is added in ChatEdenAI
This commit is contained in:
KyrianC 2024-07-25 21:19:14 +02:00 committed by GitHub
parent 871bf5a841
commit 0fdbaf4a8d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 55 additions and 14 deletions

View File

@ -51,6 +51,7 @@ if TYPE_CHECKING:
from langchain_community.chat_models.deepinfra import (
ChatDeepInfra,
)
from langchain_community.chat_models.edenai import ChatEdenAI
from langchain_community.chat_models.ernie import (
ErnieBotChat,
)
@ -182,6 +183,7 @@ __all__ = [
"ChatOctoAI",
"ChatDatabricks",
"ChatDeepInfra",
"ChatEdenAI",
"ChatEverlyAI",
"ChatFireworks",
"ChatFriendli",
@ -237,6 +239,7 @@ _module_lookup = {
"ChatDatabricks": "langchain_community.chat_models.databricks",
"ChatDeepInfra": "langchain_community.chat_models.deepinfra",
"ChatEverlyAI": "langchain_community.chat_models.everlyai",
"ChatEdenAI": "langchain_community.chat_models.edenai",
"ChatFireworks": "langchain_community.chat_models.fireworks",
"ChatFriendli": "langchain_community.chat_models.friendli",
"ChatGooglePalm": "langchain_community.chat_models.google_palm",

View File

@ -122,8 +122,8 @@ def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
system = None
formatted_messages = []
human_messages = filter(lambda msg: isinstance(msg, HumanMessage), messages)
last_human_message = list(human_messages)[-1] if human_messages else ""
human_messages = list(filter(lambda msg: isinstance(msg, HumanMessage), messages))
last_human_message = human_messages[-1] if human_messages else ""
tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
for i, message in enumerate(other_messages):

View File

@ -3,17 +3,21 @@ from __future__ import annotations
import json
import logging
import time
from typing import List, Optional
from typing import List, Optional, Type
import requests
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import validator
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl, validator
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__)
class SpeechToTextInput(BaseModel):
query: HttpUrl = Field(description="url of the audio to analyze")
class EdenAiSpeechToTextTool(EdenaiTool):
"""Tool that queries the Eden AI Speech To Text API.
@ -23,7 +27,6 @@ class EdenAiSpeechToTextTool(EdenaiTool):
To use, you should have
the environment variable ``EDENAI_API_KEY`` set with your API token.
You can find your token here: https://app.edenai.run/admin/account/settings
"""
edenai_api_key: Optional[str] = None
@ -34,6 +37,7 @@ class EdenAiSpeechToTextTool(EdenaiTool):
"Useful for when you have to convert audio to text."
"Input should be a url to an audio file."
)
args_schema: Type[BaseModel] = SpeechToTextInput
is_async: bool = True
language: Optional[str] = "en"

View File

@ -1,17 +1,21 @@
from __future__ import annotations
import logging
from typing import Dict, List, Literal, Optional
from typing import Dict, List, Literal, Optional, Type
import requests
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import Field, root_validator, validator
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__)
class TextToSpeechInput(BaseModel):
query: str = Field(description="text to generate audio from")
class EdenAiTextToSpeechTool(EdenaiTool):
"""Tool that queries the Eden AI Text to speech API.
for api reference check edenai documentation:
@ -30,6 +34,7 @@ class EdenAiTextToSpeechTool(EdenaiTool):
"""the output is a string representing the URL of the audio file,
or the path to the downloaded wav file """
)
args_schema: Type[BaseModel] = TextToSpeechInput
language: Optional[str] = "en"
"""

View File

@ -1,15 +1,20 @@
from __future__ import annotations
import logging
from typing import Optional
from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__)
class ExplicitImageInput(BaseModel):
query: HttpUrl = Field(description="url of the image to analyze")
class EdenAiExplicitImageTool(EdenaiTool):
"""Tool that queries the Eden AI Explicit image detection.
@ -33,6 +38,7 @@ class EdenAiExplicitImageTool(EdenaiTool):
pornography, violence, gore content, etc."""
"Input should be the string url of the image ."
)
args_schema: Type[BaseModel] = ExplicitImageInput
combine_available: bool = True
feature: str = "image"

View File

@ -1,15 +1,20 @@
from __future__ import annotations
import logging
from typing import Optional
from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__)
class ObjectDetectionInput(BaseModel):
query: HttpUrl = Field(description="url of the image to analyze")
class EdenAiObjectDetectionTool(EdenaiTool):
"""Tool that queries the Eden AI Object detection API.
@ -30,6 +35,7 @@ class EdenAiObjectDetectionTool(EdenaiTool):
(with bounding boxes) objects in an image """
"Input should be the string url of the image to identify."
)
args_schema: Type[BaseModel] = ObjectDetectionInput
show_positions: bool = False

View File

@ -1,15 +1,20 @@
from __future__ import annotations
import logging
from typing import Optional
from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__)
class IDParsingInput(BaseModel):
query: HttpUrl = Field(description="url of the document to parse")
class EdenAiParsingIDTool(EdenaiTool):
"""Tool that queries the Eden AI Identity parsing API.
@ -29,6 +34,7 @@ class EdenAiParsingIDTool(EdenaiTool):
"Useful for when you have to extract information from an ID Document "
"Input should be the string url of the document to parse."
)
args_schema: Type[BaseModel] = IDParsingInput
feature: str = "ocr"
subfeature: str = "identity_parser"

View File

@ -1,15 +1,20 @@
from __future__ import annotations
import logging
from typing import Optional
from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__)
class InvoiceParsingInput(BaseModel):
query: HttpUrl = Field(description="url of the document to parse")
class EdenAiParsingInvoiceTool(EdenaiTool):
"""Tool that queries the Eden AI Invoice parsing API.
@ -23,7 +28,6 @@ class EdenAiParsingInvoiceTool(EdenaiTool):
"""
name: str = "edenai_invoice_parsing"
description: str = (
"A wrapper around edenai Services invoice parsing. "
"""Useful for when you have to extract information from
@ -33,6 +37,7 @@ class EdenAiParsingInvoiceTool(EdenaiTool):
in a structured format to automate the invoice processing """
"Input should be the string url of the document to parse."
)
args_schema: Type[BaseModel] = InvoiceParsingInput
language: Optional[str] = None
"""

View File

@ -1,15 +1,20 @@
from __future__ import annotations
import logging
from typing import Optional
from typing import Optional, Type
from langchain_core.callbacks import CallbackManagerForToolRun
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
logger = logging.getLogger(__name__)
class TextModerationInput(BaseModel):
query: str = Field(description="Text to moderate")
class EdenAiTextModerationTool(EdenaiTool):
"""Tool that queries the Eden AI Explicit text detection.
@ -23,7 +28,6 @@ class EdenAiTextModerationTool(EdenaiTool):
"""
name: str = "edenai_explicit_content_detection_text"
description: str = (
"A wrapper around edenai Services explicit content detection for text. "
"""Useful for when you have to scan text for offensive,
@ -44,6 +48,7 @@ class EdenAiTextModerationTool(EdenaiTool):
"""
"Input should be a string."
)
args_schema: Type[BaseModel] = TextModerationInput
language: str

View File

@ -11,6 +11,7 @@ EXPECTED_ALL = [
"ChatDatabricks",
"ChatDeepInfra",
"ChatEverlyAI",
"ChatEdenAI",
"ChatFireworks",
"ChatFriendli",
"ChatGooglePalm",