combine python files (#256)

This commit is contained in:
Harrison Chase 2022-12-04 15:57:36 -08:00 committed by GitHub
parent 98fb19b535
commit f5c665a544
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 21 additions and 76 deletions

View File

@ -11,7 +11,6 @@ from langchain.chains import (
LLMChain,
LLMMathChain,
PALChain,
PythonChain,
QAWithSourcesChain,
SQLDatabaseChain,
VectorDBQA,
@ -32,7 +31,6 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch
__all__ = [
"LLMChain",
"LLMMathChain",
"PythonChain",
"SelfAskWithSearchChain",
"SerpAPIWrapper",
"SerpAPIChain",

View File

@ -4,7 +4,6 @@ from langchain.chains.conversation.base import ConversationChain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.base import LLMMathChain
from langchain.chains.pal.base import PALChain
from langchain.chains.python import PythonChain
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
@ -14,7 +13,6 @@ from langchain.chains.vector_db_qa.base import VectorDBQA
__all__ = [
"LLMChain",
"LLMMathChain",
"PythonChain",
"SQLDatabaseChain",
"VectorDBQA",
"SequentialChain",

View File

@ -6,9 +6,9 @@ from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.llm_math.prompt import PROMPT
from langchain.chains.python import PythonChain
from langchain.input import print_text
from langchain.llms.base import LLM
from langchain.python import PythonREPL
class LLMMathChain(Chain, BaseModel):
@ -50,7 +50,7 @@ class LLMMathChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
python_executor = PythonChain()
python_executor = PythonREPL()
if self.verbose:
print_text(inputs[self.input_key])
t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"])

View File

@ -12,10 +12,10 @@ from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
from langchain.chains.pal.math_prompt import MATH_PROMPT
from langchain.chains.python import PythonChain
from langchain.input import print_text
from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
from langchain.python import PythonREPL
class PALChain(Chain, BaseModel):
@ -54,7 +54,7 @@ class PALChain(Chain, BaseModel):
code = llm_chain.predict(stop=[self.stop], **inputs)
if self.verbose:
print_text(code, color="green", end="\n")
repl = PythonChain()
repl = PythonREPL()
res = repl.run(code + f"\n{self.get_answer_expr}")
return {self.output_key: res.strip()}

View File

@ -1,51 +0,0 @@
"""Chain that runs python code.
Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py
"""
import sys
from io import StringIO
from typing import Dict, List
from pydantic import BaseModel
from langchain.chains.base import Chain
from langchain.python import PythonREPL
class PythonChain(Chain, BaseModel):
"""Chain to run python code.
Example:
.. code-block:: python
from langchain import PythonChain
python_chain = PythonChain()
"""
input_key: str = "code" #: :meta private:
output_key: str = "output" #: :meta private:
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
python_repl = PythonREPL()
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
python_repl.run(inputs[self.input_key])
sys.stdout = old_stdout
output = mystdout.getvalue()
return {self.output_key: output}

View File

@ -1,4 +1,6 @@
"""Mock Python REPL."""
import sys
from io import StringIO
from typing import Dict, Optional
@ -10,6 +12,11 @@ class PythonREPL:
self._globals = _globals if _globals is not None else {}
self._locals = _locals if _locals is not None else {}
def run(self, command: str) -> None:
"""Run command with own globals/locals."""
def run(self, command: str) -> str:
"""Run command with own globals/locals and returns anything printed."""
old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
exec(command, self._globals, self._locals)
sys.stdout = old_stdout
output = mystdout.getvalue()
return output

View File

@ -1,15 +0,0 @@
"""Test python chain."""
from langchain.chains.python import PythonChain
def test_functionality() -> None:
"""Test correct functionality."""
chain = PythonChain(input_key="code1", output_key="output1")
code = "print(1 + 1)"
output = chain({"code1": code})
assert output == {"code1": code, "output1": "2\n"}
# Test with the more user-friendly interface.
simple_output = chain.run(code)
assert simple_output == "2\n"

View File

@ -32,3 +32,11 @@ def test_python_repl_pass_in_locals() -> None:
repl = PythonREPL(_locals=_locals)
repl.run("bar = foo * 2")
assert repl._locals["bar"] == 8
def test_functionality() -> None:
"""Test correct functionality."""
chain = PythonREPL()
code = "print(1 + 1)"
output = chain.run(code)
assert output == "2\n"