mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
community: fix ChatEdenAI + EdenAI Tools (#23715)
Fixes for Eden AI Custom tools and ChatEdenAI: - add missing import in __init__ of chat_models - add `args_schema` to custom tools. otherwise '__arg1' would sometimes be passed to the `run` method - fix IndexError when no human msg is added in ChatEdenAI
This commit is contained in:
parent
871bf5a841
commit
0fdbaf4a8d
@ -51,6 +51,7 @@ if TYPE_CHECKING:
|
||||
from langchain_community.chat_models.deepinfra import (
|
||||
ChatDeepInfra,
|
||||
)
|
||||
from langchain_community.chat_models.edenai import ChatEdenAI
|
||||
from langchain_community.chat_models.ernie import (
|
||||
ErnieBotChat,
|
||||
)
|
||||
@ -182,6 +183,7 @@ __all__ = [
|
||||
"ChatOctoAI",
|
||||
"ChatDatabricks",
|
||||
"ChatDeepInfra",
|
||||
"ChatEdenAI",
|
||||
"ChatEverlyAI",
|
||||
"ChatFireworks",
|
||||
"ChatFriendli",
|
||||
@ -237,6 +239,7 @@ _module_lookup = {
|
||||
"ChatDatabricks": "langchain_community.chat_models.databricks",
|
||||
"ChatDeepInfra": "langchain_community.chat_models.deepinfra",
|
||||
"ChatEverlyAI": "langchain_community.chat_models.everlyai",
|
||||
"ChatEdenAI": "langchain_community.chat_models.edenai",
|
||||
"ChatFireworks": "langchain_community.chat_models.fireworks",
|
||||
"ChatFriendli": "langchain_community.chat_models.friendli",
|
||||
"ChatGooglePalm": "langchain_community.chat_models.google_palm",
|
||||
|
@ -122,8 +122,8 @@ def _format_edenai_messages(messages: List[BaseMessage]) -> Dict[str, Any]:
|
||||
system = None
|
||||
formatted_messages = []
|
||||
|
||||
human_messages = filter(lambda msg: isinstance(msg, HumanMessage), messages)
|
||||
last_human_message = list(human_messages)[-1] if human_messages else ""
|
||||
human_messages = list(filter(lambda msg: isinstance(msg, HumanMessage), messages))
|
||||
last_human_message = human_messages[-1] if human_messages else ""
|
||||
|
||||
tool_results, other_messages = _extract_edenai_tool_results_from_messages(messages)
|
||||
for i, message in enumerate(other_messages):
|
||||
|
@ -3,17 +3,21 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.pydantic_v1 import validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl, validator
|
||||
|
||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SpeechToTextInput(BaseModel):
|
||||
query: HttpUrl = Field(description="url of the audio to analyze")
|
||||
|
||||
|
||||
class EdenAiSpeechToTextTool(EdenaiTool):
|
||||
"""Tool that queries the Eden AI Speech To Text API.
|
||||
|
||||
@ -23,7 +27,6 @@ class EdenAiSpeechToTextTool(EdenaiTool):
|
||||
To use, you should have
|
||||
the environment variable ``EDENAI_API_KEY`` set with your API token.
|
||||
You can find your token here: https://app.edenai.run/admin/account/settings
|
||||
|
||||
"""
|
||||
|
||||
edenai_api_key: Optional[str] = None
|
||||
@ -34,6 +37,7 @@ class EdenAiSpeechToTextTool(EdenaiTool):
|
||||
"Useful for when you have to convert audio to text."
|
||||
"Input should be a url to an audio file."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SpeechToTextInput
|
||||
is_async: bool = True
|
||||
|
||||
language: Optional[str] = "en"
|
||||
|
@ -1,17 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Dict, List, Literal, Optional
|
||||
from typing import Dict, List, Literal, Optional, Type
|
||||
|
||||
import requests
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.pydantic_v1 import Field, root_validator, validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator, validator
|
||||
|
||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextToSpeechInput(BaseModel):
|
||||
query: str = Field(description="text to generate audio from")
|
||||
|
||||
|
||||
class EdenAiTextToSpeechTool(EdenaiTool):
|
||||
"""Tool that queries the Eden AI Text to speech API.
|
||||
for api reference check edenai documentation:
|
||||
@ -30,6 +34,7 @@ class EdenAiTextToSpeechTool(EdenaiTool):
|
||||
"""the output is a string representing the URL of the audio file,
|
||||
or the path to the downloaded wav file """
|
||||
)
|
||||
args_schema: Type[BaseModel] = TextToSpeechInput
|
||||
|
||||
language: Optional[str] = "en"
|
||||
"""
|
||||
|
@ -1,15 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
|
||||
|
||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExplicitImageInput(BaseModel):
|
||||
query: HttpUrl = Field(description="url of the image to analyze")
|
||||
|
||||
|
||||
class EdenAiExplicitImageTool(EdenaiTool):
|
||||
"""Tool that queries the Eden AI Explicit image detection.
|
||||
|
||||
@ -33,6 +38,7 @@ class EdenAiExplicitImageTool(EdenaiTool):
|
||||
pornography, violence, gore content, etc."""
|
||||
"Input should be the string url of the image ."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ExplicitImageInput
|
||||
|
||||
combine_available: bool = True
|
||||
feature: str = "image"
|
||||
|
@ -1,15 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
|
||||
|
||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ObjectDetectionInput(BaseModel):
|
||||
query: HttpUrl = Field(description="url of the image to analyze")
|
||||
|
||||
|
||||
class EdenAiObjectDetectionTool(EdenaiTool):
|
||||
"""Tool that queries the Eden AI Object detection API.
|
||||
|
||||
@ -30,6 +35,7 @@ class EdenAiObjectDetectionTool(EdenaiTool):
|
||||
(with bounding boxes) objects in an image """
|
||||
"Input should be the string url of the image to identify."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ObjectDetectionInput
|
||||
|
||||
show_positions: bool = False
|
||||
|
||||
|
@ -1,15 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
|
||||
|
||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IDParsingInput(BaseModel):
|
||||
query: HttpUrl = Field(description="url of the document to parse")
|
||||
|
||||
|
||||
class EdenAiParsingIDTool(EdenaiTool):
|
||||
"""Tool that queries the Eden AI Identity parsing API.
|
||||
|
||||
@ -29,6 +34,7 @@ class EdenAiParsingIDTool(EdenaiTool):
|
||||
"Useful for when you have to extract information from an ID Document "
|
||||
"Input should be the string url of the document to parse."
|
||||
)
|
||||
args_schema: Type[BaseModel] = IDParsingInput
|
||||
|
||||
feature: str = "ocr"
|
||||
subfeature: str = "identity_parser"
|
||||
|
@ -1,15 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, HttpUrl
|
||||
|
||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class InvoiceParsingInput(BaseModel):
|
||||
query: HttpUrl = Field(description="url of the document to parse")
|
||||
|
||||
|
||||
class EdenAiParsingInvoiceTool(EdenaiTool):
|
||||
"""Tool that queries the Eden AI Invoice parsing API.
|
||||
|
||||
@ -23,7 +28,6 @@ class EdenAiParsingInvoiceTool(EdenaiTool):
|
||||
"""
|
||||
|
||||
name: str = "edenai_invoice_parsing"
|
||||
|
||||
description: str = (
|
||||
"A wrapper around edenai Services invoice parsing. "
|
||||
"""Useful for when you have to extract information from
|
||||
@ -33,6 +37,7 @@ class EdenAiParsingInvoiceTool(EdenaiTool):
|
||||
in a structured format to automate the invoice processing """
|
||||
"Input should be the string url of the document to parse."
|
||||
)
|
||||
args_schema: Type[BaseModel] = InvoiceParsingInput
|
||||
|
||||
language: Optional[str] = None
|
||||
"""
|
||||
|
@ -1,15 +1,20 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Optional, Type
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForToolRun
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
from langchain_community.tools.edenai.edenai_base_tool import EdenaiTool
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TextModerationInput(BaseModel):
|
||||
query: str = Field(description="Text to moderate")
|
||||
|
||||
|
||||
class EdenAiTextModerationTool(EdenaiTool):
|
||||
"""Tool that queries the Eden AI Explicit text detection.
|
||||
|
||||
@ -23,7 +28,6 @@ class EdenAiTextModerationTool(EdenaiTool):
|
||||
"""
|
||||
|
||||
name: str = "edenai_explicit_content_detection_text"
|
||||
|
||||
description: str = (
|
||||
"A wrapper around edenai Services explicit content detection for text. "
|
||||
"""Useful for when you have to scan text for offensive,
|
||||
@ -44,6 +48,7 @@ class EdenAiTextModerationTool(EdenaiTool):
|
||||
"""
|
||||
"Input should be a string."
|
||||
)
|
||||
args_schema: Type[BaseModel] = TextModerationInput
|
||||
|
||||
language: str
|
||||
|
||||
|
@ -11,6 +11,7 @@ EXPECTED_ALL = [
|
||||
"ChatDatabricks",
|
||||
"ChatDeepInfra",
|
||||
"ChatEverlyAI",
|
||||
"ChatEdenAI",
|
||||
"ChatFireworks",
|
||||
"ChatFriendli",
|
||||
"ChatGooglePalm",
|
||||
|
Loading…
Reference in New Issue
Block a user