Compare commits

...

1 Commits

Author SHA1 Message Date
Harrison Chase 844151605c WIP logging to disk 1 year ago

@ -7,9 +7,11 @@ from pydantic import BaseModel
from langchain.agents.tools import Tool
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import ChainedInput, get_color_mapping
from langchain.input import ChainedInput
from langchain.printing import get_color_mapping
from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
from langchain.logger import CONTEXT_KEY
class Action(NamedTuple):
@ -116,7 +118,7 @@ class Agent(Chain, BaseModel, ABC):
tool, tool_input = parsed_output
return Action(tool, tool_input, full_output)
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
"""Run text through and get agent response."""
text = inputs[self.input_key]
# Construct a mapping of tool name to tool for easy lookup
@ -128,7 +130,7 @@ class Agent(Chain, BaseModel, ABC):
# prompts the LLM to take an action.
starter_string = text + self.starter_string + self.llm_prefix
# We use the ChainedInput class to iteratively add to the input over time.
chained_input = ChainedInput(starter_string, verbose=self.verbose)
chained_input = ChainedInput(starter_string, inputs[CONTEXT_KEY], logger=self.logger)
# We construct a mapping from each tool to a color, used for logging.
color_mapping = get_color_mapping(
[tool.name for tool in self.tools], excluded_colors=["green"]

@ -1,9 +1,12 @@
"""Base interface that all chains should implement."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from langchain.logger import PrintLogger
import uuid
from pydantic import BaseModel, Extra
from pydantic import BaseModel, Extra, root_validator
from langchain.logger import Logger, CONTEXT_KEY
class Memory(BaseModel, ABC):
"""Base interface for memory in chains."""
@ -35,6 +38,7 @@ class Chain(BaseModel, ABC):
verbose: bool = False
"""Whether to print out response text."""
logger: Optional[Logger] = None
@property
@abstractmethod
@ -46,6 +50,19 @@ class Chain(BaseModel, ABC):
def output_keys(self) -> List[str]:
"""Output keys this chain expects."""
@root_validator()
def add_logger(cls, values: Dict) -> Dict:
"""Add a printing logger if verbose=True and none provided."""
if values["verbose"] and values["logger"] is None:
values["logger"] = PrintLogger()
return values
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
arbitrary_types_allowed = True
def _validate_inputs(self, inputs: Dict[str, str]) -> None:
"""Check that all inputs are present."""
missing_keys = set(self.input_keys).difference(inputs)
@ -76,16 +93,22 @@ class Chain(BaseModel, ABC):
chain will be returned. Defaults to False.
"""
if CONTEXT_KEY not in inputs:
inputs[CONTEXT_KEY] = {}
if "id" not in inputs[CONTEXT_KEY]:
inputs[CONTEXT_KEY]["id"] = str(uuid.uuid4())
if self.memory is not None:
external_context = self.memory.load_memory_variables(inputs)
inputs = dict(inputs, **external_context)
self._validate_inputs(inputs)
if self.verbose:
print("\n\n\033[1m> Entering new chain...\033[0m")
if self.logger:
self.logger.log_start_of_chain(inputs)
outputs = self._call(inputs)
if self.verbose:
print("\n\033[1m> Finished chain.\033[0m")
self._validate_outputs(outputs)
outputs[CONTEXT_KEY] = inputs[CONTEXT_KEY]
if self.logger:
self.logger.log_end_of_chain(outputs)
if self.memory is not None:
self.memory.save_context(inputs, outputs)
if return_only_outputs:

@ -4,9 +4,10 @@ from typing import Any, Dict, List
from pydantic import BaseModel, Extra
from langchain.chains.base import Chain
from langchain.input import print_text
from langchain.printing import print_text
from langchain.llms.base import LLM
from langchain.prompts.base import BasePromptTemplate
from langchain.logger import CONTEXT_KEY
class LLMChain(Chain, BaseModel):
@ -54,9 +55,9 @@ class LLMChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
selected_inputs = {k: inputs[k] for k in self.prompt.input_variables}
prompt = self.prompt.format(**selected_inputs)
if self.verbose:
print("Prompt after formatting:")
print_text(prompt, color="green", end="\n")
if self.logger:
title="Prompt after formatting:"
self.logger.log(prompt, inputs[CONTEXT_KEY],title=title, color="green", end="\n")
kwargs = {}
if "stop" in inputs:
kwargs["stop"] = inputs["stop"]

@ -1,5 +1,5 @@
"""Chain that interprets a prompt and executes python code to do math."""
from typing import Dict, List
from typing import Dict, List, Any
from pydantic import BaseModel, Extra
@ -9,6 +9,7 @@ from langchain.chains.llm_math.prompt import PROMPT
from langchain.chains.python import PythonChain
from langchain.input import ChainedInput
from langchain.llms.base import LLM
from langchain.logger import CONTEXT_KEY
class LLMMathChain(Chain, BaseModel):
@ -48,10 +49,10 @@ class LLMMathChain(Chain, BaseModel):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
llm_executor = LLMChain(prompt=PROMPT, llm=self.llm)
python_executor = PythonChain()
chained_input = ChainedInput(inputs[self.input_key], verbose=self.verbose)
chained_input = ChainedInput(inputs[self.input_key], inputs[CONTEXT_KEY], logger=self.logger)
t = llm_executor.predict(question=chained_input.input, stop=["```output"])
chained_input.add(t, color="green")
t = t.strip()

@ -5,7 +5,6 @@ from typing import Dict, List
from pydantic import BaseModel, Extra, root_validator
from langchain.chains.base import Chain
from langchain.input import get_color_mapping, print_text
class SequentialChain(Chain, BaseModel):
@ -127,11 +126,8 @@ class SimpleSequentialChain(Chain, BaseModel):
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
_input = inputs[self.input_key]
color_mapping = get_color_mapping([str(i) for i in range(len(self.chains))])
for i, chain in enumerate(self.chains):
_input = chain.run(_input)
if self.strip_outputs:
_input = _input.strip()
if self.verbose:
print_text(_input, color=color_mapping[str(i)], end="\n")
return {self.output_key: _input}

@ -1,5 +1,5 @@
"""Chain for interacting with SQL Database."""
from typing import Dict, List
from typing import Dict, List, Any
from pydantic import BaseModel, Extra
@ -9,6 +9,7 @@ from langchain.chains.sql_database.prompt import PROMPT
from langchain.input import ChainedInput
from langchain.llms.base import LLM
from langchain.sql_database import SQLDatabase
from langchain.logger import CONTEXT_KEY
class SQLDatabaseChain(Chain, BaseModel):
@ -51,10 +52,10 @@ class SQLDatabaseChain(Chain, BaseModel):
"""
return [self.output_key]
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
llm_chain = LLMChain(llm=self.llm, prompt=PROMPT)
chained_input = ChainedInput(
inputs[self.input_key] + "\nSQLQuery:", verbose=self.verbose
inputs[self.input_key] + "\nSQLQuery:", inputs[CONTEXT_KEY], logger=self.logger
)
llm_inputs = {
"input": chained_input.input,

@ -1,48 +1,24 @@
"""Handle chained inputs."""
from typing import Dict, List, Optional
from typing import Optional
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
}
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
"""Print text with highlighting and no end characters."""
if color is None:
print(text, end=end)
else:
color_str = _TEXT_COLOR_MAPPING[color]
print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end)
from langchain.logger import Logger
class ChainedInput:
"""Class for working with input that is the result of chains."""
def __init__(self, text: str, verbose: bool = False):
def __init__(self, text: str, context: dict, logger: Optional[Logger] = None):
"""Initialize with verbose flag and initial text."""
self._verbose = verbose
if self._verbose:
print_text(text, None)
self._logger = logger
if self._logger:
self._logger.log(text, context)
self._input = text
self._context = context
def add(self, text: str, color: Optional[str] = None) -> None:
"""Add text to input, print if in verbose mode."""
if self._verbose:
print_text(text, color)
if self._logger:
self._logger.log(text, self._context, color=color)
self._input += text
@property

@ -0,0 +1,70 @@
from abc import ABC, abstractmethod
from typing import Optional, Any
from langchain.printing import print_text
from pathlib import Path
CONTEXT_KEY = "__context__"
class Logger(ABC):
@abstractmethod
def log_start_of_chain(self, inputs):
""""""
@abstractmethod
def log_end_of_chain(self, outputs):
""""""
@abstractmethod
def log(self, text: str, context: dict, **kwargs):
""""""
class PrintLogger(Logger):
def log_start_of_chain(self, inputs):
""""""
print("\n\n\033[1m> Entering new chain...\033[0m")
def log_end_of_chain(self, outputs):
""""""
print("\n\033[1m> Finished chain.\033[0m")
def log(self, text: str, context: dict, title: Optional[str ] =None ,**kwargs:Any):
""""""
if title is not None:
print(title)
print_text(text, **kwargs)
import json
class JSONLogger(Logger):
def __init__(self, log_dir):
self.log_dir = Path(log_dir)
self.log_dir.mkdir(exist_ok=True)
def log_start_of_chain(self, inputs):
""""""
fname = self.log_dir / f"{inputs[CONTEXT_KEY]['id']}.json"
if not fname.exists():
with open(fname, 'w') as f:
json.dump([], f)
def log_end_of_chain(self, outputs):
""""""
fname = self.log_dir / f"{outputs[CONTEXT_KEY]['id']}.json"
with open(fname) as f:
logs = json.load(f)
logs.append(outputs)
with open(fname, 'w') as f:
json.dump(logs, f)
def log(self, text: str, context: dict, title: Optional[str ] =None ,**kwargs:Any):
""""""
fname = self.log_dir / f"{context['id']}.json"
with open(fname) as f:
logs = json.load(f)
logs.append({"text": text, "title": title})
with open(fname, 'w') as f:
json.dump(logs, f)

@ -4,7 +4,7 @@ from typing import List, Optional, Sequence, Union
from langchain.agents.agent import Agent
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping, print_text
from langchain.printing import print_text, get_color_mapping
from langchain.llms.base import LLM
from langchain.prompts.prompt import PromptTemplate

@ -0,0 +1,29 @@
from typing import Optional, List, Dict
def print_text(text: str, color: Optional[str] = None, end: str = "") -> None:
"""Print text with highlighting and no end characters."""
if color is None:
print(text, end=end)
else:
color_str = _TEXT_COLOR_MAPPING[color]
print(f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m", end=end)
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
}

@ -3,7 +3,8 @@
import sys
from io import StringIO
from langchain.input import ChainedInput, get_color_mapping
from langchain.input import ChainedInput
from langchain.printing import get_color_mapping
def test_chained_input_not_verbose() -> None:

Loading…
Cancel
Save