support returning run info for llms, chat models and chains (#5666)

Returning the run id is important for accessing the run later on.
Ankush Gola 2023-06-06 10:07:46 -07:00 committed by GitHub
parent 65111eb2b3
commit b177a29d3f
5 changed files with 74 additions and 8 deletions

langchain/chains/base.py

@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
CallbackManagerForChainRun,
Callbacks,
)
from langchain.schema import BaseMemory
from langchain.schema import RUN_KEY, BaseMemory, RunInfo
def _get_verbosity() -> bool:
@@ -108,6 +108,8 @@ class Chain(BaseModel, ABC):
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@@ -118,7 +120,10 @@ class Chain(BaseModel, ABC):
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
inputs = self.prep_inputs(inputs)
callback_manager = CallbackManager.configure(
@@ -139,13 +144,20 @@ class Chain(BaseModel, ABC):
run_manager.on_chain_error(e)
raise e
run_manager.on_chain_end(outputs)
return self.prep_outputs(inputs, outputs, return_only_outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
async def acall(
self,
inputs: Union[Dict[str, Any], Any],
return_only_outputs: bool = False,
callbacks: Callbacks = None,
*,
include_run_info: bool = False,
) -> Dict[str, Any]:
"""Run the logic of this chain and add to output if desired.
@@ -156,7 +168,10 @@ class Chain(BaseModel, ABC):
response. If True, only new keys generated by this chain will be
returned. If False, both input keys and new keys generated by this
chain will be returned. Defaults to False.
callbacks: Callbacks to use for this chain run. If not provided, will
use the callbacks provided to the chain.
include_run_info: Whether to include run info in the response. Defaults
to False.
"""
inputs = self.prep_inputs(inputs)
callback_manager = AsyncCallbackManager.configure(
@@ -177,7 +192,12 @@ class Chain(BaseModel, ABC):
await run_manager.on_chain_error(e)
raise e
await run_manager.on_chain_end(outputs)
return self.prep_outputs(inputs, outputs, return_only_outputs)
final_outputs: Dict[str, Any] = self.prep_outputs(
inputs, outputs, return_only_outputs
)
if include_run_info:
final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
return final_outputs
def prep_outputs(
self,

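A minimal usage sketch of the new keyword (my_chain and its input key are hypothetical placeholders; RUN_KEY and RunInfo are the schema additions shown further below):

from langchain.schema import RUN_KEY

# my_chain stands in for any concrete Chain subclass instance.
output = my_chain({"question": "What is a run id?"}, include_run_info=True)

# The usual chain outputs are untouched; one extra entry appears under RUN_KEY.
run_info = output[RUN_KEY]   # a RunInfo object
print(run_info.run_id)       # UUID assigned by the callback manager for this run

# The async path mirrors this:
# output = await my_chain.acall({"question": "..."}, include_run_info=True)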
langchain/chat_models/base.py

@@ -25,6 +25,7 @@ from langchain.schema import (
HumanMessage,
LLMResult,
PromptValue,
RunInfo,
)
@@ -93,6 +94,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
return output
async def agenerate(
@@ -131,6 +134,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
generations = [res.generations for res in results]
output = LLMResult(generations=generations, llm_output=llm_output)
await run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
return output
def generate_prompt(

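A short sketch of reading the run id from a chat model call, assuming chat_model is any BaseChatModel implementation (e.g. ChatOpenAI) and the message content is illustrative:

from langchain.schema import HumanMessage

# chat_model stands in for any BaseChatModel implementation.
result = chat_model.generate([[HumanMessage(content="hello")]])

# run is only populated when a run manager was created for the call.
if result.run is not None:
    print(result.run.run_id)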
langchain/llms/base.py

@@ -25,6 +25,7 @@ from langchain.schema import (
Generation,
LLMResult,
PromptValue,
RunInfo,
get_buffer_string,
)
@@ -190,6 +191,8 @@ class BaseLLM(BaseLanguageModel, ABC):
run_manager.on_llm_error(e)
raise e
run_manager.on_llm_end(output)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
return output
if len(missing_prompts) > 0:
run_manager = callback_manager.on_llm_start(
@@ -210,10 +213,14 @@ class BaseLLM(BaseLanguageModel, ABC):
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = None
if run_manager:
run_info = RunInfo(run_id=run_manager.run_id)
else:
llm_output = {}
run_info = None
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output)
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
async def agenerate(
self,
@@ -256,6 +263,8 @@ class BaseLLM(BaseLanguageModel, ABC):
await run_manager.on_llm_error(e, verbose=self.verbose)
raise e
await run_manager.on_llm_end(output, verbose=self.verbose)
if run_manager:
output.run = RunInfo(run_id=run_manager.run_id)
return output
if len(missing_prompts) > 0:
run_manager = await callback_manager.on_llm_start(
@@ -278,10 +287,14 @@ class BaseLLM(BaseLanguageModel, ABC):
llm_output = update_cache(
existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
)
run_info = None
if run_manager:
run_info = RunInfo(run_id=run_manager.run_id)
else:
llm_output = {}
run_info = None
generations = [existing_prompts[i] for i in range(len(prompts))]
return LLMResult(generations=generations, llm_output=llm_output)
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
def __call__(
self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None

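The same pattern applies to plain LLMs; a sketch assuming llm is any BaseLLM instance (the prompt text is illustrative):

result = llm.generate(["Tell me a joke."])

# run carries the run id when a run manager was created; it stays None
# e.g. on a full cache hit, where no LLM call (and hence no run) happened.
if result.run is not None:
    print(result.run.run_id)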
langchain/schema.py

@@ -13,9 +13,12 @@ from typing import (
TypeVar,
Union,
)
from uuid import UUID
from pydantic import BaseModel, Extra, Field, root_validator
RUN_KEY = "__run"
def get_buffer_string(
messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
@@ -156,6 +159,12 @@ class ChatGeneration(Generation):
return values
class RunInfo(BaseModel):
"""Class that contains all relevant metadata for a Run."""
run_id: UUID
class ChatResult(BaseModel):
"""Class that contains all relevant information for a Chat Result."""
@@ -173,6 +182,16 @@ class LLMResult(BaseModel):
each input could have multiple generations."""
llm_output: Optional[dict] = None
"""For arbitrary LLM provider specific output."""
run: Optional[RunInfo] = None
"""Run metadata."""
def __eq__(self, other: object) -> bool:
if not isinstance(other, LLMResult):
return NotImplemented
return (
self.generations == other.generations
and self.llm_output == other.llm_output
)
class PromptValue(BaseModel, ABC):

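One behavioral detail worth noting: the new __eq__ compares only generations and llm_output, so two results that differ only in their run metadata still compare equal. A small sketch of that behavior (values are illustrative):

from uuid import uuid4
from langchain.schema import Generation, LLMResult, RunInfo

a = LLMResult(generations=[[Generation(text="hi")]], run=RunInfo(run_id=uuid4()))
b = LLMResult(generations=[[Generation(text="hi")]], run=RunInfo(run_id=uuid4()))

# Equality ignores the run field, so differing run ids do not affect comparisons.
assert a == b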
tests/unit_tests/chains/test_base.py

@@ -5,7 +5,7 @@ import pytest
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import BaseMemory
from langchain.schema import RUN_KEY, BaseMemory
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -72,6 +72,15 @@ def test_bad_outputs() -> None:
chain({"foo": "baz"})
def test_run_info() -> None:
"""Test that run_info is returned properly when specified"""
chain = FakeChain()
output = chain({"foo": "bar"}, include_run_info=True)
assert "foo" in output
assert "bar" in output
assert RUN_KEY in output
def test_correct_call() -> None:
"""Test correct call of fake chain."""
chain = FakeChain()