mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
82baecc892
This PR adds:
* `ZeroShotAgent.as_sql_agent`, which returns an agent for interacting with a SQL database. This builds on `SQLDatabaseChain`. The main advantages are 1) answering general questions about the db, 2) access to a tool for double-checking queries, and 3) recovering from errors.
* `ZeroShotAgent.as_json_agent`, which returns an agent for interacting with JSON blobs.
* Several examples in notebooks.

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
56 lines
1.5 KiB
Python
56 lines
1.5 KiB
Python
"""Test functionality of Python REPL."""
|
|
|
|
from langchain.python import PythonREPL
|
|
|
|
|
|
def test_python_repl() -> None:
    """Check that state persists across runs when no globals/locals are given."""
    repl = PythonREPL()

    # The first command binds a name inside the REPL's own namespace.
    repl.run("foo = 1")
    namespace = repl.locals
    assert namespace is not None
    assert namespace["foo"] == 1

    # A follow-up command reads `foo`, showing the earlier binding is still there.
    repl.run("bar = foo * 2")
    namespace = repl.locals
    assert namespace is not None
    assert namespace["bar"] == 2
|
|
|
|
|
|
def test_python_repl_no_previous_variables() -> None:
    """Check that names from the enclosing test scope are invisible to the REPL."""
    # Deliberately unused here: it exists only so that a leak of this function's
    # scope into the REPL would make the lookup below succeed.
    foo = 3  # noqa: F841
    result = PythonREPL().run("print(foo)")
    assert result == "name 'foo' is not defined"
|
|
|
|
|
|
def test_python_repl_pass_in_locals() -> None:
    """Check that a caller-supplied locals dict is visible to executed code."""
    seeded = {"foo": 4}
    repl = PythonREPL(_locals=seeded)

    # The command reads the pre-seeded `foo` and stores a derived value.
    repl.run("bar = foo * 2")
    namespace = repl.locals
    assert namespace is not None
    assert namespace["bar"] == 8
|
|
|
|
|
|
def test_functionality() -> None:
    """Check that text printed by the executed code is returned by `run`."""
    repl = PythonREPL()
    assert repl.run("print(1 + 1)") == "2\n"
|
|
|
|
|
|
def test_function() -> None:
    """Check that a function defined in one run can be called in a later run."""
    repl = PythonREPL()

    # Defining the function prints nothing, so the returned output is empty.
    definition = "def add(a, b): " " return a + b"
    assert repl.run(definition) == ""

    # The second run calls the function defined by the first one.
    call = "print(add(1, 2))"
    assert repl.run(call) == "3\n"
|