mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
python repl improvement for csv agent (#9618)
This commit is contained in:
parent
632a83c48e
commit
02545a54b3
@ -6,13 +6,13 @@ import re
|
||||
import sys
|
||||
from contextlib import redirect_stdout
|
||||
from io import StringIO
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForToolRun,
|
||||
CallbackManagerForToolRun,
|
||||
)
|
||||
from langchain.pydantic_v1 import Field, root_validator
|
||||
from langchain.pydantic_v1 import BaseModel, Field, root_validator
|
||||
from langchain.tools.base import BaseTool
|
||||
from langchain.utilities import PythonREPL
|
||||
|
||||
@ -77,6 +77,10 @@ class PythonREPLTool(BaseTool):
|
||||
return result
|
||||
|
||||
|
||||
class PythonInputs(BaseModel):
|
||||
query: str = Field(description="code snippet to run")
|
||||
|
||||
|
||||
class PythonAstREPLTool(BaseTool):
|
||||
"""A tool for running python code in a REPL."""
|
||||
|
||||
@ -90,6 +94,7 @@ class PythonAstREPLTool(BaseTool):
|
||||
globals: Optional[Dict] = Field(default_factory=dict)
|
||||
locals: Optional[Dict] = Field(default_factory=dict)
|
||||
sanitize_input: bool = True
|
||||
args_schema: Type[BaseModel] = PythonInputs
|
||||
|
||||
@root_validator(pre=True)
|
||||
def validate_python_version(cls, values: Dict) -> Dict:
|
||||
|
Loading…
Reference in New Issue
Block a user