2024-03-09 21:28:55 +00:00
|
|
|
from langchain_community.agent_toolkits import SQLDatabaseToolkit, create_sql_agent
|
2024-01-03 07:18:15 +00:00
|
|
|
from langchain_community.utilities.sql_database import SQLDatabase
|
2023-04-14 05:07:58 +00:00
|
|
|
from tests.unit_tests.llms.fake_llm import FakeLLM
|
|
|
|
|
|
|
|
|
|
|
|
def test_create_sql_agent() -> None:
    """Smoke-test ``create_sql_agent`` end to end with a scripted fake LLM.

    An empty in-memory SQLite database is wrapped in a ``SQLDatabaseToolkit``
    and handed to ``create_sql_agent`` together with a ``FakeLLM``. The fake
    model is configured with a single canned response, ``"Final Answer: baz"``.
    Running the agent on ``"hello"`` is expected to produce ``"baz"``.
    """
    # Empty, throwaway database; the test never queries any tables.
    database = SQLDatabase.from_uri("sqlite:///:memory:")

    # NOTE(review): with sequential_responses=True, FakeLLM presumably returns
    # the dict values in order regardless of the prompt, so the "foo" key is
    # likely irrelevant. Confirm against tests.unit_tests.llms.fake_llm.
    scripted_llm = FakeLLM(
        queries={"foo": "Final Answer: baz"},
        sequential_responses=True,
    )

    sql_toolkit = SQLDatabaseToolkit(db=database, llm=scripted_llm)
    executor = create_sql_agent(llm=scripted_llm, toolkit=sql_toolkit)

    result = executor.run("hello")
    assert result == "baz"
|