forked from Archives/langchain
support returning run info for llms, chat models and chains (#5666)
Returning the run ID is important for accessing the run later on.
parent 65111eb2b3
commit b177a29d3f
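In practice, the change lets a caller ask for the run id back from a chain call. A minimal sketch, assuming a toy Chain subclass modeled on the FakeChain used in the tests below (EchoChain is hypothetical; include_run_info, RUN_KEY, and RunInfo.run_id come from this diff):

from typing import Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.schema import RUN_KEY


class EchoChain(Chain):
    # Hypothetical one-key chain, only here to demonstrate include_run_info.
    @property
    def input_keys(self) -> List[str]:
        return ["foo"]

    @property
    def output_keys(self) -> List[str]:
        return ["bar"]

    def _call(
        self,
        inputs: Dict[str, str],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        return {"bar": inputs["foo"]}


output = EchoChain()({"foo": "hello"}, include_run_info=True)
print(output["bar"])           # regular chain output
print(output[RUN_KEY].run_id)  # UUID of this run, usable for later lookup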
@@ -18,7 +18,7 @@ from langchain.callbacks.manager import (
     CallbackManagerForChainRun,
     Callbacks,
 )
-from langchain.schema import BaseMemory
+from langchain.schema import RUN_KEY, BaseMemory, RunInfo


 def _get_verbosity() -> bool:
@@ -108,6 +108,8 @@ class Chain(BaseModel, ABC):
         inputs: Union[Dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
+        *,
+        include_run_info: bool = False,
     ) -> Dict[str, Any]:
         """Run the logic of this chain and add to output if desired.

@@ -118,7 +120,10 @@ class Chain(BaseModel, ABC):
                 response. If True, only new keys generated by this chain will be
                 returned. If False, both input keys and new keys generated by this
                 chain will be returned. Defaults to False.
+            callbacks: Callbacks to use for this chain run. If not provided, will
+                use the callbacks provided to the chain.
+            include_run_info: Whether to include run info in the response. Defaults
+                to False.
         """
         inputs = self.prep_inputs(inputs)
         callback_manager = CallbackManager.configure(
@@ -139,13 +144,20 @@ class Chain(BaseModel, ABC):
             run_manager.on_chain_error(e)
             raise e
         run_manager.on_chain_end(outputs)
-        return self.prep_outputs(inputs, outputs, return_only_outputs)
+        final_outputs: Dict[str, Any] = self.prep_outputs(
+            inputs, outputs, return_only_outputs
+        )
+        if include_run_info:
+            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
+        return final_outputs

     async def acall(
         self,
         inputs: Union[Dict[str, Any], Any],
         return_only_outputs: bool = False,
         callbacks: Callbacks = None,
+        *,
+        include_run_info: bool = False,
     ) -> Dict[str, Any]:
         """Run the logic of this chain and add to output if desired.

@@ -156,7 +168,10 @@ class Chain(BaseModel, ABC):
                 response. If True, only new keys generated by this chain will be
                 returned. If False, both input keys and new keys generated by this
                 chain will be returned. Defaults to False.
+            callbacks: Callbacks to use for this chain run. If not provided, will
+                use the callbacks provided to the chain.
+            include_run_info: Whether to include run info in the response. Defaults
+                to False.
         """
         inputs = self.prep_inputs(inputs)
         callback_manager = AsyncCallbackManager.configure(
@@ -177,7 +192,12 @@ class Chain(BaseModel, ABC):
             await run_manager.on_chain_error(e)
             raise e
         await run_manager.on_chain_end(outputs)
-        return self.prep_outputs(inputs, outputs, return_only_outputs)
+        final_outputs: Dict[str, Any] = self.prep_outputs(
+            inputs, outputs, return_only_outputs
+        )
+        if include_run_info:
+            final_outputs[RUN_KEY] = RunInfo(run_id=run_manager.run_id)
+        return final_outputs

     def prep_outputs(
         self,
@@ -25,6 +25,7 @@ from langchain.schema import (
     HumanMessage,
     LLMResult,
     PromptValue,
+    RunInfo,
 )


@@ -93,6 +94,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         run_manager.on_llm_end(output)
+        if run_manager:
+            output.run = RunInfo(run_id=run_manager.run_id)
         return output

     async def agenerate(
@@ -131,6 +134,8 @@ class BaseChatModel(BaseLanguageModel, ABC):
         generations = [res.generations for res in results]
         output = LLMResult(generations=generations, llm_output=llm_output)
         await run_manager.on_llm_end(output)
+        if run_manager:
+            output.run = RunInfo(run_id=run_manager.run_id)
         return output

     def generate_prompt(
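With the base chat model change above, generate() now attaches run metadata to the LLMResult it returns. A hedged sketch of reading it back, assuming a configured ChatOpenAI model with valid API credentials (the model choice and prompt are assumptions; output.run / RunInfo come from this diff):

from langchain.chat_models import ChatOpenAI
from langchain.schema import HumanMessage

chat = ChatOpenAI()  # assumes OPENAI_API_KEY is set in the environment
result = chat.generate([[HumanMessage(content="Say hi")]])
print(result.run.run_id)  # RunInfo attached by the patched generate()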
@@ -25,6 +25,7 @@ from langchain.schema import (
     Generation,
     LLMResult,
     PromptValue,
+    RunInfo,
     get_buffer_string,
 )

@@ -190,6 +191,8 @@ class BaseLLM(BaseLanguageModel, ABC):
                 run_manager.on_llm_error(e)
                 raise e
             run_manager.on_llm_end(output)
+            if run_manager:
+                output.run = RunInfo(run_id=run_manager.run_id)
             return output
         if len(missing_prompts) > 0:
             run_manager = callback_manager.on_llm_start(
@@ -210,10 +213,14 @@ class BaseLLM(BaseLanguageModel, ABC):
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
+            run_info = None
+            if run_manager:
+                run_info = RunInfo(run_id=run_manager.run_id)
         else:
             llm_output = {}
+            run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

     async def agenerate(
         self,
@@ -256,6 +263,8 @@ class BaseLLM(BaseLanguageModel, ABC):
                 await run_manager.on_llm_error(e, verbose=self.verbose)
                 raise e
             await run_manager.on_llm_end(output, verbose=self.verbose)
+            if run_manager:
+                output.run = RunInfo(run_id=run_manager.run_id)
             return output
         if len(missing_prompts) > 0:
             run_manager = await callback_manager.on_llm_start(
@@ -278,10 +287,14 @@ class BaseLLM(BaseLanguageModel, ABC):
             llm_output = update_cache(
                 existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
             )
+            run_info = None
+            if run_manager:
+                run_info = RunInfo(run_id=run_manager.run_id)
         else:
             llm_output = {}
+            run_info = None
         generations = [existing_prompts[i] for i in range(len(prompts))]
-        return LLMResult(generations=generations, llm_output=llm_output)
+        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

     def __call__(
         self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
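The base LLM hunks above do the same for completion models; note that when every prompt is served from the cache, the else branch leaves run_info as None. A hedged sketch, assuming a configured OpenAI LLM (the model choice and prompt are assumptions):

from langchain.llms import OpenAI

llm = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
result = llm.generate(["Tell me a joke"])
if result.run is not None:  # None when every prompt came from the cache
    print(result.run.run_id)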
@@ -13,9 +13,12 @@ from typing import (
     TypeVar,
     Union,
 )
+from uuid import UUID

 from pydantic import BaseModel, Extra, Field, root_validator

+RUN_KEY = "__run"
+

 def get_buffer_string(
     messages: List[BaseMessage], human_prefix: str = "Human", ai_prefix: str = "AI"
@@ -156,6 +159,12 @@ class ChatGeneration(Generation):
         return values


+class RunInfo(BaseModel):
+    """Class that contains all relevant metadata for a Run."""
+
+    run_id: UUID
+
+
 class ChatResult(BaseModel):
     """Class that contains all relevant information for a Chat Result."""

@@ -173,6 +182,16 @@ class LLMResult(BaseModel):
     each input could have multiple generations."""
     llm_output: Optional[dict] = None
     """For arbitrary LLM provider specific output."""
+    run: Optional[RunInfo] = None
+    """Run metadata."""
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, LLMResult):
+            return NotImplemented
+        return (
+            self.generations == other.generations
+            and self.llm_output == other.llm_output
+        )


 class PromptValue(BaseModel, ABC):
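A note on the schema hunk above: the new __eq__ compares only generations and llm_output, presumably so that attaching run metadata does not break existing equality checks. A small illustration using only names added in this diff (the UUIDs are arbitrary examples):

from uuid import uuid4

from langchain.schema import Generation, LLMResult, RunInfo

a = LLMResult(generations=[[Generation(text="hi")]], run=RunInfo(run_id=uuid4()))
b = LLMResult(generations=[[Generation(text="hi")]], run=RunInfo(run_id=uuid4()))
assert a == b                        # run metadata is deliberately excluded from __eq__
assert a.run.run_id != b.run.run_id  # but each result still carries its own run id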
@@ -5,7 +5,7 @@ import pytest

 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains.base import Chain
-from langchain.schema import BaseMemory
+from langchain.schema import RUN_KEY, BaseMemory
 from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler


@@ -72,6 +72,15 @@ def test_bad_outputs() -> None:
         chain({"foo": "baz"})


+def test_run_info() -> None:
+    """Test that run_info is returned properly when specified"""
+    chain = FakeChain()
+    output = chain({"foo": "bar"}, include_run_info=True)
+    assert "foo" in output
+    assert "bar" in output
+    assert RUN_KEY in output
+
+
 def test_correct_call() -> None:
     """Test correct call of fake chain."""
     chain = FakeChain()