forked from Archives/langchain
Compare commits
1 Commits
main
...
harrison/l
Author | SHA1 | Date | |
---|---|---|---|
|
844151605c |
@ -7,9 +7,11 @@ from pydantic import BaseModel
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import ChainedInput, get_color_mapping
|
||||
from langchain.input import ChainedInput
|
||||
from langchain.printing import get_color_mapping
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.logger import CONTEXT_KEY
|
||||
|
||||
|
||||
class Action(NamedTuple):
|
||||
@ -116,7 +118,7 @@ class Agent(Chain, BaseModel, ABC):
|
||||
tool, tool_input = parsed_output
|
||||
return Action(tool, tool_input, full_output)
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
"""Run text through and get agent response."""
|
||||
text = inputs[self.input_key]
|
||||
# Construct a mapping of tool name to tool for easy lookup
|
||||
@ -128,7 +130,7 @@ class Agent(Chain, BaseModel, ABC):
|
||||
# prompts the LLM to take an action.
|
||||
starter_string = text + self.starter_string + self.llm_prefix
|
||||
# We use the ChainedInput class to iteratively add to the input over time.
|
||||
chained_input = ChainedInput(starter_string, verbose=self.verbose)
|
||||
chained_input = ChainedInput(starter_string, inputs[CONTEXT_KEY], logger=self.logger)
|
||||
# We construct a mapping from each tool to a color, used for logging.
|
||||
color_mapping = get_color_mapping(
|
||||
[tool.name for tool in self.tools], excluded_colors=["green"]
|
||||
|
@ -1,9 +1,12 @@
|
||||
"""Base interface that all chains should implement."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
from langchain.logger import PrintLogger
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.logger import Logger, CONTEXT_KEY
|
||||
|
||||
class Memory(BaseModel, ABC):
|
||||
"""Base interface for memory in chains."""
|
||||
@ -35,6 +38,7 @@ class Chain(BaseModel, ABC):
|
||||
|
||||
verbose: bool = False
|
||||
"""Whether to print out response text."""
|
||||
logger: Optional[Logger] = None
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@ -46,6 +50,19 @@ class Chain(BaseModel, ABC):
|
||||
def output_keys(self) -> List[str]:
|
||||
"""Output keys this chain expects."""
|
||||
|
||||
@root_validator()
|
||||
def add_logger(cls, values: Dict) -> Dict:
|
||||
"""Add a printing logger if verbose=True and none provided."""
|
||||
if values["verbose"] and values["logger"] is None:
|
||||
values["logger"] = PrintLogger()
|
||||
return values
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def _validate_inputs(self, inputs: Dict[str, str]) -> None:
|
||||
"""Check that all inputs are present."""
|
||||
missing_keys = set(self.input_keys).difference(inputs)
|
||||
@ -76,16 +93,22 @@ class Chain(BaseModel, ABC):
|
||||
chain will be returned. Defaults to False.
|
||||
|
||||
"""
|
||||
if CONTEXT_KEY not in inputs:
|
||||
inputs[CONTEXT_KEY] = {}
|
||||
if "id" not in inputs[CONTEXT_KEY]:
|
||||
inputs[CONTEXT_KEY]["id"] = str(uuid.uuid4())
|
||||
|
||||
if self.memory is not None:
|
||||
external_context = self.memory.load_memory_variables(inputs)
|
||||
inputs = dict(inputs, **external_context)
|
||||
self._validate_inputs(inputs)
|
||||
if self.verbose:
|
||||
print("\n\n\033[1m> Entering new chain...\033[0m")
|
||||
if self.logger:
|
||||
self.logger.log_start_of_chain(inputs)
|
||||
outputs = self._call(inputs)
|
||||
if self.verbose:
|
||||
print("\n\033[1m> Finished chain.\033[0m")
|
||||
self._validate_outputs(outputs)
|
||||
outputs[CONTEXT_KEY] = inputs[CONTEXT_KEY]
|
||||
if self.logger:
|
||||
self.logger.log_end_of_chain(outputs)
|
||||
if self.memory is not None:
|
||||
self.memory.save_context(inputs, outputs)
|
||||
if return_only_outputs:
|
||||
|
@ -4,9 +4,10 @@ from typing import Any, Dict, List
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import print_text
|
||||
from langchain.printing import print_text
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts.base import BasePromptTemplate
|
||||
from langchain.logger import CONTEXT_KEY
|
||||
|
||||
|
||||
class LLMChain(Chain, BaseModel):
|
||||
@ -54,9 +55,9 @@ class LLMChain(Chain, BaseModel):
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
|
||||
prompt = self.prompt.format(**selected_inputs)
|
||||
if self.verbose:
|
||||
print("Prompt after formatting:")
|
||||
print_text(prompt, color="green", end="\n")
|
||||
if self.logger:
|
||||
title="Prompt after formatting:"
|
||||
self.logger.log(prompt, inputs[CONTEXT_KEY],title=title, color="green", end="\n")
|
||||
kwargs = {}
|
||||
if "stop" in inputs:
|
||||
kwargs["stop"] = inputs["stop"]
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Chain that interprets a prompt and executes python code to do math."""
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
@ -9,6 +9,7 @@ from langchain.chains.llm_math.prompt import PROMPT
|
||||
from langchain.chains.python import PythonChain
|
||||
from langchain.input import ChainedInput
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.logger import CONTEXT_KEY
|
||||
|
||||
|
||||
class LLMMathChain(Chain, BaseModel):
|
||||
@ -48,10 +49,10 @@ class LLMMathChain(Chain, BaseModel):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
|
||||
python_executor = PythonChain()
|
||||
chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose)
|
||||
chained_input = ChainedInput(inputs[self.input_key], inputs[CONTEXT_KEY], logger=self.logger)
|
||||
t = llm_executor.predict(question=chained_input.input, stop=["```output"])
|
||||
chained_input.add(t, color="green")
|
||||
t = t.strip()
|
||||
|
@ -5,7 +5,6 @@ from typing import Dict, List
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.input import get_color_mapping, print_text
|
||||
|
||||
|
||||
class SequentialChain(Chain, BaseModel):
|
||||
@ -127,11 +126,8 @@ class SimpleSequentialChain(Chain, BaseModel):
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
_input = inputs[self.input_key]
|
||||
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
|
||||
for i, chain in enumerate(self.chains):
|
||||
_input = chain.run(_input)
|
||||
if self.strip_outputs:
|
||||
_input = _input.strip()
|
||||
if self.verbose:
|
||||
print_text(_input, color=color_mapping[str(i)], end="\n")
|
||||
return {self.output_key: _input}
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Chain for interacting with SQL Database."""
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Any
|
||||
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
@ -9,6 +9,7 @@ from langchain.chains.sql_database.prompt import PROMPT
|
||||
from langchain.input import ChainedInput
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.sql_database import SQLDatabase
|
||||
from langchain.logger import CONTEXT_KEY
|
||||
|
||||
|
||||
class SQLDatabaseChain(Chain, BaseModel):
|
||||
@ -51,10 +52,10 @@ class SQLDatabaseChain(Chain, BaseModel):
|
||||
"""
|
||||
return [self.output_key]
|
||||
|
||||
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
|
||||
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
|
||||
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
|
||||
chained_input = ChainedInput(
|
||||
inputs[self.input_key] + "\nSQLQuery:", verbose=self.verbose
|
||||
inputs[self.input_key] + "\nSQLQuery:", inputs[CONTEXT_KEY], logger=self.logger
|
||||
)
|
||||
llm_inputs = {
|
||||
"input": chained_input.input,
|
||||
|
@ -1,48 +1,24 @@
|
||||
"""Handle chained inputs."""
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
"yellow": "33;1",
|
||||
"pink": "38;5;200",
|
||||
"green": "32;1",
|
||||
}
|
||||
|
||||
|
||||
def get_color_mapping(
|
||||
items: List[str], excluded_colors: Optional[List] = None
|
||||
) -> Dict[str, str]:
|
||||
"""Get mapping for items to a support color."""
|
||||
colors = list(_TEXT_COLOR_MAPPING.keys())
|
||||
if excluded_colors is not None:
|
||||
colors = [c for c in colors if c not in excluded_colors]
|
||||
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
|
||||
return color_mapping
|
||||
|
||||
|
||||
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
if color is None:
|
||||
print(text, end=end)
|
||||
else:
|
||||
color_str = _TEXT_COLOR_MAPPING[color]
|
||||
print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end)
|
||||
from langchain.logger import Logger
|
||||
|
||||
|
||||
class ChainedInput:
|
||||
"""Class for working with input that is the result of chains."""
|
||||
|
||||
def __init__(self, text: str, verbose: bool = False):
|
||||
def __init__(self, text: str, context: dict, logger: Optional[Logger] = None):
|
||||
"""Initialize with verbose flag and initial text."""
|
||||
self._verbose = verbose
|
||||
if self._verbose:
|
||||
print_text(text, None)
|
||||
self._logger = logger
|
||||
if self._logger:
|
||||
self._logger.log(text, context)
|
||||
self._input = text
|
||||
self._context = context
|
||||
|
||||
def add(self, text: str, color: Optional[str] = None) -> None:
|
||||
"""Add text to input, print if in verbose mode."""
|
||||
if self._verbose:
|
||||
print_text(text, color)
|
||||
if self._logger:
|
||||
self._logger.log(text, self._context, color=color)
|
||||
self._input += text
|
||||
|
||||
@property
|
||||
|
70
langchain/logger.py
Normal file
70
langchain/logger.py
Normal file
@ -0,0 +1,70 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Any
|
||||
from langchain.printing import print_text
|
||||
from pathlib import Path
|
||||
|
||||
CONTEXT_KEY = "__context__"
|
||||
|
||||
|
||||
class Logger(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def log_start_of_chain(self, inputs):
|
||||
""""""
|
||||
|
||||
@abstractmethod
|
||||
def log_end_of_chain(self, outputs):
|
||||
""""""
|
||||
|
||||
@abstractmethod
|
||||
def log(self, text: str, context: dict, **kwargs):
|
||||
""""""
|
||||
|
||||
|
||||
class PrintLogger(Logger):
|
||||
def log_start_of_chain(self, inputs):
|
||||
""""""
|
||||
print("\n\n\033[1m> Entering new chain...\033[0m")
|
||||
|
||||
def log_end_of_chain(self, outputs):
|
||||
""""""
|
||||
print("\n\033[1m> Finished chain.\033[0m")
|
||||
|
||||
def log(self, text: str, context: dict, title: Optional[str ] =None ,**kwargs:Any):
|
||||
""""""
|
||||
if title is not None:
|
||||
print(title)
|
||||
print_text(text, **kwargs)
|
||||
|
||||
import json
|
||||
class JSONLogger(Logger):
|
||||
|
||||
def __init__(self, log_dir):
|
||||
self.log_dir = Path(log_dir)
|
||||
self.log_dir.mkdir(exist_ok=True)
|
||||
|
||||
def log_start_of_chain(self, inputs):
|
||||
""""""
|
||||
fname = self.log_dir / f"{inputs[CONTEXT_KEY]['id']}.json"
|
||||
if not fname.exists():
|
||||
with open(fname, 'w') as f:
|
||||
json.dump([], f)
|
||||
|
||||
def log_end_of_chain(self, outputs):
|
||||
""""""
|
||||
fname = self.log_dir / f"{outputs[CONTEXT_KEY]['id']}.json"
|
||||
with open(fname) as f:
|
||||
logs = json.load(f)
|
||||
logs.append(outputs)
|
||||
with open(fname, 'w') as f:
|
||||
json.dump(logs, f)
|
||||
|
||||
def log(self, text: str, context: dict, title: Optional[str ] =None ,**kwargs:Any):
|
||||
""""""
|
||||
fname = self.log_dir / f"{context['id']}.json"
|
||||
with open(fname) as f:
|
||||
logs = json.load(f)
|
||||
logs.append({"text": text, "title": title})
|
||||
with open(fname, 'w') as f:
|
||||
json.dump(logs, f)
|
||||
|
@ -4,7 +4,7 @@ from typing import List, Optional, Sequence, Union
|
||||
from langchain.agents.agent import Agent
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.input import get_color_mapping, print_text
|
||||
from langchain.printing import print_text, get_color_mapping
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
29
langchain/printing.py
Normal file
29
langchain/printing.py
Normal file
@ -0,0 +1,29 @@
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
|
||||
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
|
||||
"""Print text with highlighting and no end characters."""
|
||||
if color is None:
|
||||
print(text, end=end)
|
||||
else:
|
||||
color_str = _TEXT_COLOR_MAPPING[color]
|
||||
print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end)
|
||||
|
||||
|
||||
def get_color_mapping(
|
||||
items: List[str], excluded_colors: Optional[List] = None
|
||||
) -> Dict[str, str]:
|
||||
"""Get mapping for items to a support color."""
|
||||
colors = list(_TEXT_COLOR_MAPPING.keys())
|
||||
if excluded_colors is not None:
|
||||
colors = [c for c in colors if c not in excluded_colors]
|
||||
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
|
||||
return color_mapping
|
||||
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
"blue": "36;1",
|
||||
"yellow": "33;1",
|
||||
"pink": "38;5;200",
|
||||
"green": "32;1",
|
||||
}
|
@ -3,7 +3,8 @@
|
||||
import sys
|
||||
from io import StringIO
|
||||
|
||||
from langchain.input import ChainedInput, get_color_mapping
|
||||
from langchain.input import ChainedInput
|
||||
from langchain.printing import get_color_mapping
|
||||
|
||||
|
||||
def test_chained_input_not_verbose() -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user