forked from Archives/langchain
combine python files (#256)
This commit is contained in:
parent
98fb19b535
commit
f5c665a544
@ -11,7 +11,6 @@ from langchain.chains import (
|
||||
LLMChain,
|
||||
LLMMathChain,
|
||||
PALChain,
|
||||
PythonChain,
|
||||
QAWithSourcesChain,
|
||||
SQLDatabaseChain,
|
||||
VectorDBQA,
|
||||
@ -32,7 +31,6 @@ from langchain.vectorstores import FAISS, ElasticVectorSearch
|
||||
__all__ = [
|
||||
"LLMChain",
|
||||
"LLMMathChain",
|
||||
"PythonChain",
|
||||
"SelfAskWithSearchChain",
|
||||
"SerpAPIWrapper",
|
||||
"SerpAPIChain",
|
||||
|
@ -4,7 +4,6 @@ from langchain.chains.conversation.base import ConversationChain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_math.base import LLMMathChain
|
||||
from langchain.chains.pal.base import PALChain
|
||||
from langchain.chains.python import PythonChain
|
||||
from langchain.chains.qa_with_sources.base import QAWithSourcesChain
|
||||
from langchain.chains.qa_with_sources.vector_db import VectorDBQAWithSourcesChain
|
||||
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
|
||||
@ -14,7 +13,6 @@ from langchain.chains.vector_db_qa.base import VectorDBQA
|
||||
__all__ = [
|
||||
"LLMChain",
|
||||
"LLMMathChain",
|
||||
"PythonChain",
|
||||
"SQLDatabaseChain",
|
||||
"VectorDBQA",
|
||||
"SequentialChain",
|
||||
|
@ -6,9 +6,9 @@ from pydantic import BaseModel, Extra
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.llm_math.prompt import PROMPT
|
||||
from langchain.chains.python import PythonChain
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.python import PythonREPL
|
||||
|
||||
|
||||
class LLMMathChain(Chain, BaseModel):
|
||||
@ -50,7 +50,7 @@ class LLMMathChain(Chain, BaseModel):
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
|
||||
python_executor = PythonChain()
|
||||
python_executor = PythonREPL()
|
||||
if self.verbose:
|
||||
print_text(inputs[self.input_key])
|
||||
t = llm_executor.predict(question=inputs[self.input_key], stop=["```output"])
|
||||
|
@ -12,10 +12,10 @@ from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.chains.pal.colored_object_prompt import COLORED_OBJECT_PROMPT
|
||||
from langchain.chains.pal.math_prompt import MATH_PROMPT
|
||||
from langchain.chains.python import PythonChain
|
||||
from langchain.input import print_text
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.python import PythonREPL
|
||||
|
||||
|
||||
class PALChain(Chain, BaseModel):
|
||||
@ -54,7 +54,7 @@ class PALChain(Chain, BaseModel):
|
||||
code = llm_chain.predict(stop=[self.stop], **inputs)
|
||||
if self.verbose:
|
||||
print_text(code, color="green", end="\n")
|
||||
repl = PythonChain()
|
||||
repl = PythonREPL()
|
||||
res = repl.run(code + f"\n{self.get_answer_expr}")
|
||||
return {self.output_key: res.strip()}
|
||||
|
||||
|
@ -1,51 +0,0 @@
|
||||
"""Chain that runs python code.
|
||||
|
||||
Heavily borrowed from https://replit.com/@amasad/gptpy?v=1#main.py
|
||||
"""
|
||||
import sys
|
||||
from io import StringIO
|
||||
from typing import Dict, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.python import PythonREPL
|
||||
|
||||
|
||||
class PythonChain(Chain, BaseModel):
|
||||
"""Chain to run python code.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain import PythonChain
|
||||
python_chain = PythonChain()
|
||||
"""
|
||||
|
||||
input_key: str = "code" #: :meta private:
|
||||
output_key: str = "output" #: :meta private:
|
||||
|
||||
@property
|
||||
def input_keys(self) -> List[str]:
|
||||
"""Expect input key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.input_key]
|
||||
|
||||
@property
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Return output key.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
python_repl = PythonREPL()
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = mystdout = StringIO()
|
||||
python_repl.run(inputs[self.input_key])
|
||||
sys.stdout = old_stdout
|
||||
output = mystdout.getvalue()
|
||||
return {self.output_key: output}
|
@ -1,4 +1,6 @@
|
||||
"""Mock Python REPL."""
|
||||
import sys
|
||||
from io import StringIO
|
||||
from typing import Dict, Optional
|
||||
|
||||
|
||||
@ -10,6 +12,11 @@ class PythonREPL:
|
||||
self._globals = _globals if _globals is not None else {}
|
||||
self._locals = _locals if _locals is not None else {}
|
||||
|
||||
def run(self, command: str) -> None:
|
||||
"""Run command with own globals/locals."""
|
||||
def run(self, command: str) -> str:
|
||||
"""Run command with own globals/locals and returns anything printed."""
|
||||
old_stdout = sys.stdout
|
||||
sys.stdout = mystdout = StringIO()
|
||||
exec(command, self._globals, self._locals)
|
||||
sys.stdout = old_stdout
|
||||
output = mystdout.getvalue()
|
||||
return output
|
||||
|
@ -1,15 +0,0 @@
|
||||
"""Test python chain."""
|
||||
|
||||
from langchain.chains.python import PythonChain
|
||||
|
||||
|
||||
def test_functionality() -> None:
|
||||
"""Test correct functionality."""
|
||||
chain = PythonChain(input_key="code1", output_key="output1")
|
||||
code = "print(1 + 1)"
|
||||
output = chain({"code1": code})
|
||||
assert output == {"code1": code, "output1": "2\n"}
|
||||
|
||||
# Test with the more user-friendly interface.
|
||||
simple_output = chain.run(code)
|
||||
assert simple_output == "2\n"
|
@ -32,3 +32,11 @@ def test_python_repl_pass_in_locals() -> None:
|
||||
repl = PythonREPL(_locals=_locals)
|
||||
repl.run("bar = foo * 2")
|
||||
assert repl._locals["bar"] == 8
|
||||
|
||||
|
||||
def test_functionality() -> None:
|
||||
"""Test correct functionality."""
|
||||
chain = PythonREPL()
|
||||
code = "print(1 + 1)"
|
||||
output = chain.run(code)
|
||||
assert output == "2\n"
|
||||
|
Loading…
Reference in New Issue
Block a user