mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
factor out mock python repl (#43)
This commit is contained in:
parent
7b0d02ac51
commit
fba30e07d1
@ -9,6 +9,7 @@ from typing import Dict, List
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
|
from langchain.python import PythonREPL
|
||||||
|
|
||||||
|
|
||||||
class PythonChain(Chain, BaseModel):
|
class PythonChain(Chain, BaseModel):
|
||||||
@ -41,9 +42,10 @@ class PythonChain(Chain, BaseModel):
|
|||||||
return [self.output_key]
|
return [self.output_key]
|
||||||
|
|
||||||
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||||
|
python_repl = PythonREPL()
|
||||||
old_stdout = sys.stdout
|
old_stdout = sys.stdout
|
||||||
sys.stdout = mystdout = StringIO()
|
sys.stdout = mystdout = StringIO()
|
||||||
exec(inputs[self.input_key])
|
python_repl.run(inputs[self.input_key])
|
||||||
sys.stdout = old_stdout
|
sys.stdout = old_stdout
|
||||||
output = mystdout.getvalue()
|
output = mystdout.getvalue()
|
||||||
return {self.output_key: output}
|
return {self.output_key: output}
|
||||||
|
15
langchain/python.py
Normal file
15
langchain/python.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
"""Mock Python REPL."""
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class PythonREPL:
|
||||||
|
"""Simulates a standalone Python REPL."""
|
||||||
|
|
||||||
|
def __init__(self, _globals: Optional[Dict] = None, _locals: Optional[Dict] = None):
|
||||||
|
"""Initialize with optional globals and locals."""
|
||||||
|
self._globals = _globals if _globals is not None else {}
|
||||||
|
self._locals = _locals if _locals is not None else {}
|
||||||
|
|
||||||
|
def run(self, command: str) -> None:
|
||||||
|
"""Run command with own globals/locals."""
|
||||||
|
exec(command, self._globals, self._locals)
|
34
tests/unit_tests/test_python.py
Normal file
34
tests/unit_tests/test_python.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
"""Test functionality of Python REPL."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from langchain.python import PythonREPL
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_repl() -> None:
|
||||||
|
"""Test functionality when globals/locals are not provided."""
|
||||||
|
repl = PythonREPL()
|
||||||
|
|
||||||
|
# Run a simple initial command.
|
||||||
|
repl.run("foo = 1")
|
||||||
|
assert repl._locals["foo"] == 1
|
||||||
|
|
||||||
|
# Now run a command that accesses `foo` to make sure it still has it.
|
||||||
|
repl.run("bar = foo * 2")
|
||||||
|
assert repl._locals["bar"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_repl_no_previous_variables() -> None:
|
||||||
|
"""Test that it does not have access to variables created outside the scope."""
|
||||||
|
foo = 3 # noqa: F841
|
||||||
|
repl = PythonREPL()
|
||||||
|
with pytest.raises(NameError):
|
||||||
|
repl.run("print(foo)")
|
||||||
|
|
||||||
|
|
||||||
|
def test_python_repl_pass_in_locals() -> None:
|
||||||
|
"""Test functionality when passing in locals."""
|
||||||
|
_locals = {"foo": 4}
|
||||||
|
repl = PythonREPL(_locals=_locals)
|
||||||
|
repl.run("bar = foo * 2")
|
||||||
|
assert repl._locals["bar"] == 8
|
Loading…
Reference in New Issue
Block a user