diff --git a/langchain/chains/base.py b/langchain/chains/base.py
index b10a87dc..2db63a8f 100644
--- a/langchain/chains/base.py
+++ b/langchain/chains/base.py
@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForChainRun,
     Callbacks,
 )
-from langchain.schema import BaseMemory
+from langchain.schema import RUN_KEY, BaseMemory, RunInfo


 def _get_verbosity() -> bool:
@@ -108,6 +108,8 @@ class Chain(BaseModel, ABC):
         inputs: Union[Dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
+        *,
+        include_run_info: bool = False,
     ) -> Dict[str, Any]:
         """Run the logic of this chain and add to output if desired.

@@ -118,7 +120,10 @@ class Chain(BaseModel, ABC):
                 response. If True, only new keys generated by this chain will be
                 returned. If False, both input keys and new keys generated by
                 this chain will be returned. Defaults to False.
-
+            callbacks: Callbacks to use for this chain run. If not provided, will
+                use the callbacks provided to the chain.
+            include_run_info: Whether to include run info in the response. Defaults
+                to False.
         """
         inputs = self.prep_inputs(inputs)
         callback_manager = CallbackManager.configure(
@@ -139,13 +144,20 @@ class Chain(BaseModel, ABC):
             run_manager.on_chain_error(e)
             raise e
         run_manager.on_chain_end(outputs)
-        return self.prep_outputs(inputs, outputs, return_only_outputs)
+        final_outputs: Dict[str, Any] = self.prep_outputs(
+            inputs, outputs, return_only_outputs
+        )
+        if include_run_info:
+            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
+        return final_outputs

     async def acall(
         self,
         inputs: Union[Dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
+        *,
+        include_run_info: bool = False,
     ) -> Dict[str, Any]:
         """Run the logic of this chain and add to output if desired.

@@ -156,7 +168,10 @@ class Chain(BaseModel, ABC):
                 response. If True, only new keys generated by this chain will be
                 returned. If False, both input keys and new keys generated by
                 this chain will be returned. Defaults to False.
-
+            callbacks: Callbacks to use for this chain run. If not provided, will
+                use the callbacks provided to the chain.
+            include_run_info: Whether to include run info in the response. Defaults
+                to False.
         """
         inputs = self.prep_inputs(inputs)
         callback_manager = AsyncCallbackManager.configure(
@@ -177,7 +192,12 @@ class Chain(BaseModel, ABC):
             await run_manager.on_chain_error(e)
             raise e
         await run_manager.on_chain_end(outputs)
-        return self.prep_outputs(inputs, outputs, return_only_outputs)
+        final_outputs: Dict[str, Any] = self.prep_outputs(
+            inputs, outputs, return_only_outputs
+        )
+        if include_run_info:
+            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
+        return final_outputs

     def prep_outputs(
         self,
diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py
index de2cdd06..dcb4ebeb 100644
--- a/langchain/chat_models/base.py
+++ b/langchain/chat_models/base.py
@@ -25,6 +25,7 @@ from langchain.schema import (
     HumanMessage,
     LLMResult,
     PromptValue,
+    RunInfo,
 )


@@ -93,6 +94,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         run_manager.on_llm_end(output)
+        if run_manager:
+            output.run = RunInfo(run_id=run_manager.run_id)
         return output

     async def agenerate(
@@ -131,6 +134,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         await run_manager.on_llm_end(output)
+        if run_manager:
+            output.run = RunInfo(run_id=run_manager.run_id)
         return output

     def generate_prompt(
diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 21267a96..84ba2c5c 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -25,6 +25,7 @@ from langchain.schema import (
     Generation,
     LLMResult,
     PromptValue,
+    RunInfo,
     get_buffer_string,
 )

@@ -190,6 +191,8 @@ class BaseLLM(BaseLanguageModel, ABC):
                 run_manager.on_llm_error(e)
                 raise e
             run_manager.on_llm_end(output)
+            if run_manager:
+                output.run = RunInfo(run_id=run_manager.run_id)
             return output
         if len(missing_prompts) > 0:
             run_manager = callback_manager.on_llm_start(
@@ -210,10 +213,14 @@ class BaseLLM(BaseLanguageModel, ABC):
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
+            run_info = None
+            if run_manager:
+                run_info = RunInfo(run_id=run_manager.run_id)
         else:
             llm_output = {}
+            run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

     async def agenerate(
         self,
@@ -256,6 +263,8 @@ class BaseLLM(BaseLanguageModel, ABC):
                 await run_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
             await run_manager.on_llm_end(output, verbose=self.verbose)
+            if run_manager:
+                output.run = RunInfo(run_id=run_manager.run_id)
             return output
         if len(missing_prompts) > 0:
             run_manager = await callback_manager.on_llm_start(
@@ -278,10 +287,14 @@ class BaseLLM(BaseLanguageModel, ABC):
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
+            run_info = None
+            if run_manager:
+                run_info = RunInfo(run_id=run_manager.run_id)
         else:
             llm_output = {}
+            run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

     def __call__(
         self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
diff --git a/langchain/schema.py b/langchain/schema.py
index 4a04bd04..b74b40a7 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -13,9 +13,12 @@ from typing import (
     TypeVar,
     Union,
 )
+from uuid import UUID

 from pydantic import BaseModel, Extra, Field, root_validator

+RUN_KEY = "__run"
+

 def get_buffer_string(
     messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
@@ -156,6 +159,12 @@ class ChatGeneration(Generation):
         return values


+class RunInfo(BaseModel):
+    """Class that contains all relevant metadata for a Run."""
+
+    run_id: UUID
+
+
 class ChatResult(BaseModel):
     """Class that contains all relevant information for a Chat Result."""

@@ -173,6 +182,16 @@ class LLMResult(BaseModel):
     each input could have multiple generations."""
     llm_output: Optional[dict] = None
     """For arbitrary LLM provider specific output."""
+    run: Optional[RunInfo] = None
+    """Run metadata."""
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, LLMResult):
+            return NotImplemented
+        return (
+            self.generations == other.generations
+            and self.llm_output == other.llm_output
+        )


 class PromptValue(BaseModel, ABC):
diff --git a/tests/unit_tests/chains/test_base.py b/tests/unit_tests/chains/test_base.py
index 1e5022b8..d60e06a8 100644
--- a/tests/unit_tests/chains/test_base.py
+++ b/tests/unit_tests/chains/test_base.py
@@ -5,7 +5,7 @@ import pytest

 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from langchain.schema import BaseMemory
+from langchain.schema import RUN_KEY, BaseMemory
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


@@ -72,6 +72,15 @@ def test_bad_outputs() -> None:
         chain({"foo": "baz"})


+def test_run_info() -> None:
+    """Test that run_info is returned properly when specified"""
+    chain = FakeChain()
+    output = chain({"foo": "bar"}, include_run_info=True)
+    assert "foo" in output
+    assert "bar" in output
+    assert RUN_KEY in output
+
+
 def test_correct_call() -> None:
     """Test correct call of fake chain."""
     chain = FakeChain()
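With the diff applied, callers can opt into run metadata. Below is a minimal usage sketch, not part of the diff: EchoChain is a hypothetical chain subclass, modeled on the FakeChain used in tests/unit_tests/chains/test_base.py, defined only to demonstrate the new include_run_info flag and the RUN_KEY entry. Only include_run_info, RUN_KEY, and RunInfo come from this change; LLMResult.run is populated the same way by BaseLLM.generate and BaseChatModel.generate.

# Sketch only: EchoChain is hypothetical; include_run_info, RUN_KEY, and RunInfo
# are the pieces introduced by this diff.
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import RUN_KEY


class EchoChain(Chain):
    """Hypothetical chain that copies its single input key to its output key."""

    @property
    def input_keys(self) -> List[str]:
        return ["foo"]

    @property
    def output_keys(self) -> List[str]:
        return ["bar"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        return {"bar": inputs["foo"]}


chain = EchoChain()
outputs = chain({"foo": "hello"}, include_run_info=True)
print(outputs["bar"])           # "hello"
print(outputs[RUN_KEY].run_id)  # UUID for this run, carried by the new RunInfo

The run_id surfaced here is the same one handed to the callback manager for the run, so it can be used to correlate a chain or LLM call with its traced callbacks.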