forked from Archives/langchain
support returning run info for llms, chat models and chains (#5666)
returning the run id is important for accessing the run later on
This commit is contained in:
parent 65111eb2b3
commit b177a29d3f
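The commit message is terse, so here is a minimal usage sketch (not part of this commit) of what the new flag enables. `EchoChain` is a hypothetical toy subclass written against the `Chain` interface this diff touches; the only surface added by the commit itself is `include_run_info`, `RUN_KEY`, and `RunInfo.run_id`.

from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import RUN_KEY


class EchoChain(Chain):
    """Hypothetical toy chain that copies its single input to its single output."""

    @property
    def input_keys(self) -> List[str]:
        return ["foo"]

    @property
    def output_keys(self) -> List[str]:
        return ["bar"]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        return {"bar": inputs["foo"]}


chain = EchoChain()
output = chain({"foo": "hello"}, include_run_info=True)
print(output["bar"])           # "hello", as before
print(output[RUN_KEY].run_id)  # UUID of this run, for looking the run up later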
langchain/chains/base.py

@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForChainRun,
     Callbacks,
 )
-from langchain.schema import BaseMemory
+from langchain.schema import RUN_KEY, BaseMemory, RunInfo
 
 
 def _get_verbosity() -> bool:
@@ -108,6 +108,8 @@ class Chain(BaseModel, ABC):
         inputs: Union[Dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
+        *,
+        include_run_info: bool = False,
     ) -> Dict[str, Any]:
         """Run the logic of this chain and add to output if desired.
 
@@ -118,7 +120,10 @@ class Chain(BaseModel, ABC):
                 response. If True, only new keys generated by this chain will be
                 returned. If False, both input keys and new keys generated by this
                 chain will be returned. Defaults to False.
-
+            callbacks: Callbacks to use for this chain run. If not provided, will
+                use the callbacks provided to the chain.
+            include_run_info: Whether to include run info in the response. Defaults
+                to False.
         """
         inputs = self.prep_inputs(inputs)
         callback_manager = CallbackManager.configure(
@@ -139,13 +144,20 @@ class Chain(BaseModel, ABC):
             run_manager.on_chain_error(e)
             raise e
         run_manager.on_chain_end(outputs)
-        return self.prep_outputs(inputs, outputs, return_only_outputs)
+        final_outputs: Dict[str, Any] = self.prep_outputs(
+            inputs, outputs, return_only_outputs
+        )
+        if include_run_info:
+            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
+        return final_outputs
 
     async def acall(
         self,
         inputs: Union[Dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
+        *,
+        include_run_info: bool = False,
     ) -> Dict[str, Any]:
         """Run the logic of this chain and add to output if desired.
 
@@ -156,7 +168,10 @@ class Chain(BaseModel, ABC):
                 response. If True, only new keys generated by this chain will be
                 returned. If False, both input keys and new keys generated by this
                 chain will be returned. Defaults to False.
-
+            callbacks: Callbacks to use for this chain run. If not provided, will
+                use the callbacks provided to the chain.
+            include_run_info: Whether to include run info in the response. Defaults
+                to False.
         """
         inputs = self.prep_inputs(inputs)
         callback_manager = AsyncCallbackManager.configure(
@@ -177,7 +192,12 @@ class Chain(BaseModel, ABC):
             await run_manager.on_chain_error(e)
             raise e
         await run_manager.on_chain_end(outputs)
-        return self.prep_outputs(inputs, outputs, return_only_outputs)
+        final_outputs: Dict[str, Any] = self.prep_outputs(
+            inputs, outputs, return_only_outputs
+        )
+        if include_run_info:
+            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
+        return final_outputs
 
     def prep_outputs(
         self,
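The `acall` changes mirror `__call__`, so the async path returns the same `RUN_KEY` entry. A hedged sketch, again with a hypothetical toy chain (an `_acall` override is assumed, since the base class raises `NotImplementedError` without one):

import asyncio
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import AsyncCallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import RUN_KEY


class AsyncEchoChain(Chain):
    """Hypothetical toy chain with an async implementation for acall()."""

    @property
    def input_keys(self) -> List[str]:
        return ["foo"]

    @property
    def output_keys(self) -> List[str]:
        return ["bar"]

    def _call(self, inputs: Dict[str, Any], run_manager: Any = None) -> Dict[str, Any]:
        return {"bar": inputs["foo"]}

    async def _acall(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[AsyncCallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        return {"bar": inputs["foo"]}


async def main() -> None:
    chain = AsyncEchoChain()
    output = await chain.acall({"foo": "hello"}, include_run_info=True)
    print(output[RUN_KEY].run_id)  # run id of the async run


asyncio.run(main())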
langchain/chat_models/base.py

@@ -25,6 +25,7 @@ from langchain.schema import (
     HumanMessage,
     LLMResult,
     PromptValue,
+    RunInfo,
 )
 
 
@@ -93,6 +94,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         run_manager.on_llm_end(output)
+        if run_manager:
+            output.run = RunInfo(run_id=run_manager.run_id)
         return output
 
     async def agenerate(
@@ -131,6 +134,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         await run_manager.on_llm_end(output)
+        if run_manager:
+            output.run = RunInfo(run_id=run_manager.run_id)
         return output
 
     def generate_prompt(
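For chat models the run id is attached to the LLMResult itself rather than to an output dict. A sketch of reading it, assuming ChatOpenAI with a configured OPENAI_API_KEY; any BaseChatModel subclass should behave the same after this change:

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

# Hypothetical usage: assumes OPENAI_API_KEY is set in the environment.
chat = ChatOpenAI()
result = chat.generate([[HumanMessage(content="Say hi")]])

if result.run is not None:
    print(result.run.run_id)  # run id attached by the change above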
langchain/llms/base.py

@@ -25,6 +25,7 @@ from langchain.schema import (
     Generation,
     LLMResult,
     PromptValue,
+    RunInfo,
     get_buffer_string,
 )
 
@@ -190,6 +191,8 @@ class BaseLLM(BaseLanguageModel, ABC):
                 run_manager.on_llm_error(e)
                 raise e
             run_manager.on_llm_end(output)
+            if run_manager:
+                output.run = RunInfo(run_id=run_manager.run_id)
             return output
         if len(missing_prompts) > 0:
             run_manager = callback_manager.on_llm_start(
@@ -210,10 +213,14 @@ class BaseLLM(BaseLanguageModel, ABC):
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
+            run_info = None
+            if run_manager:
+                run_info = RunInfo(run_id=run_manager.run_id)
         else:
             llm_output = {}
+            run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
 
     async def agenerate(
         self,
@@ -256,6 +263,8 @@ class BaseLLM(BaseLanguageModel, ABC):
                 await run_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
             await run_manager.on_llm_end(output, verbose=self.verbose)
+            if run_manager:
+                output.run = RunInfo(run_id=run_manager.run_id)
             return output
         if len(missing_prompts) > 0:
             run_manager = await callback_manager.on_llm_start(
@@ -278,10 +287,14 @@ class BaseLLM(BaseLanguageModel, ABC):
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
+            run_info = None
+            if run_manager:
+                run_info = RunInfo(run_id=run_manager.run_id)
         else:
             llm_output = {}
+            run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
 
     def __call__(
         self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
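For plain LLMs the same field is set, with one wrinkle visible in the diff: when a call is served entirely from the cache, no new run is started and run stays None. A sketch, assuming the OpenAI wrapper and a configured OPENAI_API_KEY:

import langchain
from langchain.cache import InMemoryCache
from langchain.llms import OpenAI

# Hypothetical usage: assumes OPENAI_API_KEY is set; any BaseLLM works the same way.
llm = OpenAI()
result = llm.generate(["Tell me a joke"])
print(result.run.run_id if result.run else None)  # run id for the fresh call

# With a cache enabled, a call served entirely from the cache starts no new run,
# so run stays None (mirroring the run_info handling in the diff above).
langchain.llm_cache = InMemoryCache()
llm.generate(["Tell me a joke"])           # misses the cache, populates it
cached = llm.generate(["Tell me a joke"])  # fully cached
print(cached.run)                          # None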
langchain/schema.py

@@ -13,9 +13,12 @@ from typing import (
     TypeVar,
     Union,
 )
+from uuid import UUID
 
 from pydantic import BaseModel, Extra, Field, root_validator
 
+RUN_KEY = "__run"
+
 
 def get_buffer_string(
     messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
@@ -156,6 +159,12 @@ class ChatGeneration(Generation):
         return values
 
 
+class RunInfo(BaseModel):
+    """Class that contains all relevant metadata for a Run."""
+
+    run_id: UUID
+
+
 class ChatResult(BaseModel):
     """Class that contains all relevant information for a Chat Result."""
 
@@ -173,6 +182,16 @@ class LLMResult(BaseModel):
         each input could have multiple generations."""
     llm_output: Optional[dict] = None
     """For arbitrary LLM provider specific output."""
+    run: Optional[RunInfo] = None
+    """Run metadata."""
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, LLMResult):
+            return NotImplemented
+        return (
+            self.generations == other.generations
+            and self.llm_output == other.llm_output
+        )
 
 
 class PromptValue(BaseModel, ABC):
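The custom __eq__ is the subtle part of the schema change: it deliberately leaves run out of the comparison, presumably so that results differing only in run_id still compare equal. A small self-contained check of that behavior:

from uuid import uuid4

from langchain.schema import Generation, LLMResult, RunInfo

a = LLMResult(
    generations=[[Generation(text="hi")]],
    llm_output={},
    run=RunInfo(run_id=uuid4()),
)
b = LLMResult(
    generations=[[Generation(text="hi")]],
    llm_output={},
    run=RunInfo(run_id=uuid4()),
)

# Equality ignores run metadata, so results that differ only in run_id compare equal.
assert a == b
assert a.run is not None and b.run is not None
assert a.run.run_id != b.run.run_id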
tests/unit_tests/chains/test_base.py

@@ -5,7 +5,7 @@ import pytest
 
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from langchain.schema import BaseMemory
+from langchain.schema import RUN_KEY, BaseMemory
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
 
 
@@ -72,6 +72,15 @@ def test_bad_outputs() -> None:
         chain({"foo": "baz"})
 
 
+def test_run_info() -> None:
+    """Test that run_info is returned properly when specified"""
+    chain = FakeChain()
+    output = chain({"foo": "bar"}, include_run_info=True)
+    assert "foo" in output
+    assert "bar" in output
+    assert RUN_KEY in output
+
+
 def test_correct_call() -> None:
     """Test correct call of fake chain."""
     chain = FakeChain()