support returning run info for llms, chat models and chains (#5666)

returning the run id is important for accessing the run later on
Ankush Gola 2023-06-06 10:07:46 -07:00 committed by GitHub
parent 65111eb2b3
commit b177a29d3f
5 changed files with 74 additions and 8 deletions


@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForChainRun,
     Callbacks,
 )
-from langchain.schema import BaseMemory
+from langchain.schema import RUN_KEY, BaseMemory, RunInfo


 def _get_verbosity() -> bool:
@@ -108,6 +108,8 @@ class Chain(BaseModel, ABC):
         inputs: Union[Dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
+        *,
+        include_run_info: bool = False,
     ) -> Dict[str, Any]:
         """Run the logic of this chain and add to output if desired.
@@ -118,7 +120,10 @@ class Chain(BaseModel, ABC):
                 response. If True, only new keys generated by this chain will be
                 returned. If False, both input keys and new keys generated by this
                 chain will be returned. Defaults to False.
+            callbacks: Callbacks to use for this chain run. If not provided, will
+                use the callbacks provided to the chain.
+            include_run_info: Whether to include run info in the response. Defaults
+                to False.
         """
         inputs = self.prep_inputs(inputs)
         callback_manager = CallbackManager.configure(
@@ -139,13 +144,20 @@ class Chain(BaseModel, ABC):
             run_manager.on_chain_error(e)
             raise e
         run_manager.on_chain_end(outputs)
-        return self.prep_outputs(inputs, outputs, return_only_outputs)
+        final_outputs: Dict[str, Any] = self.prep_outputs(
+            inputs, outputs, return_only_outputs
+        )
+        if include_run_info:
+            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
+        return final_outputs

     async def acall(
         self,
         inputs: Union[Dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
+        *,
+        include_run_info: bool = False,
     ) -> Dict[str, Any]:
         """Run the logic of this chain and add to output if desired.
@@ -156,7 +168,10 @@ class Chain(BaseModel, ABC):
                 response. If True, only new keys generated by this chain will be
                 returned. If False, both input keys and new keys generated by this
                 chain will be returned. Defaults to False.
+            callbacks: Callbacks to use for this chain run. If not provided, will
+                use the callbacks provided to the chain.
+            include_run_info: Whether to include run info in the response. Defaults
+                to False.
         """
         inputs = self.prep_inputs(inputs)
         callback_manager = AsyncCallbackManager.configure(
@@ -177,7 +192,12 @@ class Chain(BaseModel, ABC):
             await run_manager.on_chain_error(e)
             raise e
         await run_manager.on_chain_end(outputs)
-        return self.prep_outputs(inputs, outputs, return_only_outputs)
+        final_outputs: Dict[str, Any] = self.prep_outputs(
+            inputs, outputs, return_only_outputs
+        )
+        if include_run_info:
+            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
+        return final_outputs

     def prep_outputs(
         self,
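
This is the heart of the change for chains: __call__ and acall grow a keyword-only include_run_info flag, and when it is set the run's RunInfo is attached to the output dict under RUN_KEY. A minimal usage sketch, not part of the diff ("chain" stands in for any Chain subclass, such as the FakeChain used in the tests below):

from langchain.schema import RUN_KEY

output = chain({"foo": "bar"}, include_run_info=True)  # keyword-only flag added above
run_info = output[RUN_KEY]  # RunInfo object stored under the "__run" key
print(run_info.run_id)      # UUID identifying this chain run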


@@ -25,6 +25,7 @@ from langchain.schema import (
     HumanMessage,
     LLMResult,
     PromptValue,
+    RunInfo,
 )
@@ -93,6 +94,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         run_manager.on_llm_end(output)
+        if run_manager:
+            output.run = RunInfo(run_id=run_manager.run_id)
         return output

     async def agenerate(
@@ -131,6 +134,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         await run_manager.on_llm_end(output)
+        if run_manager:
+            output.run = RunInfo(run_id=run_manager.run_id)
         return output

     def generate_prompt(
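
For chat models the run id is attached to the LLMResult returned by generate and agenerate whenever a run manager was created. A sketch of reading it back, assuming chat_model is some BaseChatModel implementation (the concrete model class is not part of this diff):

from langchain.schema import HumanMessage

result = chat_model.generate([[HumanMessage(content="hello")]])
if result.run is not None:  # run stays None if no run manager was created
    print(result.run.run_id)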


@@ -25,6 +25,7 @@ from langchain.schema import (
     Generation,
     LLMResult,
     PromptValue,
+    RunInfo,
     get_buffer_string,
 )
@@ -190,6 +191,8 @@ class BaseLLM(BaseLanguageModel, ABC):
                 run_manager.on_llm_error(e)
                 raise e
             run_manager.on_llm_end(output)
+            if run_manager:
+                output.run = RunInfo(run_id=run_manager.run_id)
             return output
         if len(missing_prompts) > 0:
             run_manager = callback_manager.on_llm_start(
@@ -210,10 +213,14 @@ class BaseLLM(BaseLanguageModel, ABC):
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
+            run_info = None
+            if run_manager:
+                run_info = RunInfo(run_id=run_manager.run_id)
         else:
             llm_output = {}
+            run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

     async def agenerate(
         self,
@@ -256,6 +263,8 @@ class BaseLLM(BaseLanguageModel, ABC):
                 await run_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
             await run_manager.on_llm_end(output, verbose=self.verbose)
+            if run_manager:
+                output.run = RunInfo(run_id=run_manager.run_id)
             return output
         if len(missing_prompts) > 0:
             run_manager = await callback_manager.on_llm_start(
@@ -278,10 +287,14 @@ class BaseLLM(BaseLanguageModel, ABC):
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
+            run_info = None
+            if run_manager:
+                run_info = RunInfo(run_id=run_manager.run_id)
         else:
             llm_output = {}
+            run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

     def __call__(
         self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
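
Completion LLMs get the same treatment, with one wrinkle from caching: run metadata is only attached when a run manager was actually started, so in the fully cached branch above llm_output stays {} and run stays None. A sketch assuming llm is any BaseLLM instance:

result = llm.generate(["tell me a joke"])
if result.run is not None:
    print(result.run.run_id)  # UUID of the run that produced the new generations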


@@ -13,9 +13,12 @@ from typing import (
     TypeVar,
     Union,
 )
+from uuid import UUID

 from pydantic import BaseModel, Extra, Field, root_validator

+RUN_KEY = "__run"
+

 def get_buffer_string(
     messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
@@ -156,6 +159,12 @@ class ChatGeneration(Generation):
         return values


+class RunInfo(BaseModel):
+    """Class that contains all relevant metadata for a Run."""
+
+    run_id: UUID
+
+
 class ChatResult(BaseModel):
     """Class that contains all relevant information for a Chat Result."""
@@ -173,6 +182,16 @@ class LLMResult(BaseModel):
         each input could have multiple generations."""
     llm_output: Optional[dict] = None
     """For arbitrary LLM provider specific output."""
+    run: Optional[RunInfo] = None
+    """Run metadata."""
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, LLMResult):
+            return NotImplemented
+        return (
+            self.generations == other.generations
+            and self.llm_output == other.llm_output
+        )


 class PromptValue(BaseModel, ABC):
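
The explicit __eq__ keeps LLMResult equality based on content alone, so attaching run metadata does not change how otherwise identical results compare. A small sketch of that behaviour; the uuid4 value is only illustrative:

from uuid import uuid4
from langchain.schema import Generation, LLMResult, RunInfo

a = LLMResult(generations=[[Generation(text="hi")]])
b = LLMResult(generations=[[Generation(text="hi")]], run=RunInfo(run_id=uuid4()))
assert a == b  # run metadata is deliberately ignored by __eq__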


@@ -5,7 +5,7 @@ import pytest

 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from langchain.schema import BaseMemory
+from langchain.schema import RUN_KEY, BaseMemory
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
@@ -72,6 +72,15 @@ def test_bad_outputs() -> None:
         chain({"foo": "baz"})


+def test_run_info() -> None:
+    """Test that run_info is returned properly when specified"""
+    chain = FakeChain()
+    output = chain({"foo": "bar"}, include_run_info=True)
+    assert "foo" in output
+    assert "bar" in output
+    assert RUN_KEY in output
+
+
 def test_correct_call() -> None:
     """Test correct call of fake chain."""
     chain = FakeChain()
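
A possible follow-up assertion, not part of this commit, checking that the value stored under RUN_KEY carries a usable UUID; it assumes the test module also imports UUID:

from uuid import UUID

def test_run_info_contains_uuid() -> None:
    """Sketch: the run id attached to the chain output is a real UUID."""
    chain = FakeChain()
    output = chain({"foo": "bar"}, include_run_info=True)
    assert isinstance(output[RUN_KEY].run_id, UUID)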