|
|
@ -4,6 +4,8 @@ from typing import List
|
|
|
|
from pydantic import Field
|
|
|
|
from pydantic import Field
|
|
|
|
|
|
|
|
|
|
|
|
from langchain.agents.agent_toolkits.base import BaseToolkit
|
|
|
|
from langchain.agents.agent_toolkits.base import BaseToolkit
|
|
|
|
|
|
|
|
from langchain.llms.base import BaseLLM
|
|
|
|
|
|
|
|
from langchain.llms.openai import OpenAI
|
|
|
|
from langchain.sql_database import SQLDatabase
|
|
|
|
from langchain.sql_database import SQLDatabase
|
|
|
|
from langchain.tools import BaseTool
|
|
|
|
from langchain.tools import BaseTool
|
|
|
|
from langchain.tools.sql_database.tool import (
|
|
|
|
from langchain.tools.sql_database.tool import (
|
|
|
@ -18,6 +20,7 @@ class SQLDatabaseToolkit(BaseToolkit):
|
|
|
|
"""Toolkit for interacting with SQL databases."""
|
|
|
|
"""Toolkit for interacting with SQL databases."""
|
|
|
|
|
|
|
|
|
|
|
|
db: SQLDatabase = Field(exclude=True)
|
|
|
|
db: SQLDatabase = Field(exclude=True)
|
|
|
|
|
|
|
|
llm: BaseLLM = Field(default_factory=lambda: OpenAI(temperature=0))
|
|
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
@property
|
|
|
|
def dialect(self) -> str:
|
|
|
|
def dialect(self) -> str:
|
|
|
@ -35,5 +38,5 @@ class SQLDatabaseToolkit(BaseToolkit):
|
|
|
|
QuerySQLDataBaseTool(db=self.db),
|
|
|
|
QuerySQLDataBaseTool(db=self.db),
|
|
|
|
InfoSQLDatabaseTool(db=self.db),
|
|
|
|
InfoSQLDatabaseTool(db=self.db),
|
|
|
|
ListSQLDatabaseTool(db=self.db),
|
|
|
|
ListSQLDatabaseTool(db=self.db),
|
|
|
|
QueryCheckerTool(db=self.db),
|
|
|
|
QueryCheckerTool(db=self.db, llm=self.llm),
|
|
|
|
]
|
|
|
|
]
|
|
|
|