forked from Archives/langchain
factor out mock python repl (#43)
This commit is contained in:
parent
7b0d02ac51
commit
fba30e07d1
@ -9,6 +9,7 @@ from typing import Dict, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.python import PythonREPL
|
||||
|
||||
|
||||
class PythonChain(Chain, BaseModel):
|
||||
@ -41,9 +42,10 @@ class PythonChain(Chain, BaseModel):
|
||||
return [self.output_key]
|
||||
|
||||
def _run(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
python_repl = PythonREPL()
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = mystdout = StringIO()
|
||||
exec(inputs[self.input_key])
|
||||
python_repl.run(inputs[self.input_key])
|
||||
sys.stdout = old_stdout
|
||||
output = mystdout.getvalue()
|
||||
return {self.output_key: output}
|
||||
|
15
langchain/python.py
Normal file
15
langchain/python.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Mock Python REPL."""
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
class PythonREPL:
|
||||
"""Simulates a standalone Python REPL."""
|
||||
|
||||
def __init__(self, _globals: Optional[Dict] = None, _locals: Optional[Dict] = None):
|
||||
"""Initialize with optional globals and locals."""
|
||||
self._globals = _globals if _globals is not None else {}
|
||||
self._locals = _locals if _locals is not None else {}
|
||||
|
||||
def run(self, command: str) -> None:
|
||||
"""Run command with own globals/locals."""
|
||||
exec(command, self._globals, self._locals)
|
34
tests/unit_tests/test_python.py
Normal file
34
tests/unit_tests/test_python.py
Normal file
@ -0,0 +1,34 @@
|
||||
"""Test functionality of Python REPL."""
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.python import PythonREPL
|
||||
|
||||
|
||||
def test_python_repl() -> None:
|
||||
"""Test functionality when globals/locals are not provided."""
|
||||
repl = PythonREPL()
|
||||
|
||||
# Run a simple initial command.
|
||||
repl.run("foo = 1")
|
||||
assert repl._locals["foo"] == 1
|
||||
|
||||
# Now run a command that accesses `foo` to make sure it still has it.
|
||||
repl.run("bar = foo * 2")
|
||||
assert repl._locals["bar"] == 2
|
||||
|
||||
|
||||
def test_python_repl_no_previous_variables() -> None:
|
||||
"""Test that it does not have access to variables created outside the scope."""
|
||||
foo = 3 # noqa: F841
|
||||
repl = PythonREPL()
|
||||
with pytest.raises(NameError):
|
||||
repl.run("print(foo)")
|
||||
|
||||
|
||||
def test_python_repl_pass_in_locals() -> None:
|
||||
"""Test functionality when passing in locals."""
|
||||
_locals = {"foo": 4}
|
||||
repl = PythonREPL(_locals=_locals)
|
||||
repl.run("bar = foo * 2")
|
||||
assert repl._locals["bar"] == 8
|
Loading…
Reference in New Issue
Block a user