mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
Add new run types for Runnables (#8488)
- allow overriding run_type in on_chain_start <!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! Please make sure you're PR is passing linting and testing before submitting. Run `make format`, `make lint` and `make test` to check this locally. If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md -->
This commit is contained in:
parent
bd2e298468
commit
0ec020698f
@ -227,6 +227,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
tags: Optional[List[str]] = None,
|
tags: Optional[List[str]] = None,
|
||||||
parent_run_id: Optional[UUID] = None,
|
parent_run_id: Optional[UUID] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
run_type: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Start a trace for a chain run."""
|
"""Start a trace for a chain run."""
|
||||||
@ -246,7 +247,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
execution_order=execution_order,
|
execution_order=execution_order,
|
||||||
child_execution_order=execution_order,
|
child_execution_order=execution_order,
|
||||||
child_runs=[],
|
child_runs=[],
|
||||||
run_type="chain",
|
run_type=run_type or "chain",
|
||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
)
|
)
|
||||||
self._start_trace(chain_run)
|
self._start_trace(chain_run)
|
||||||
@ -259,7 +260,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_chain_end callback.")
|
raise TracerException("No run_id provided for on_chain_end callback.")
|
||||||
chain_run = self.run_map.get(str(run_id))
|
chain_run = self.run_map.get(str(run_id))
|
||||||
if chain_run is None or chain_run.run_type != "chain":
|
if chain_run is None:
|
||||||
raise TracerException("No chain Run found to be traced")
|
raise TracerException("No chain Run found to be traced")
|
||||||
|
|
||||||
chain_run.outputs = outputs
|
chain_run.outputs = outputs
|
||||||
@ -279,7 +280,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
if not run_id:
|
if not run_id:
|
||||||
raise TracerException("No run_id provided for on_chain_error callback.")
|
raise TracerException("No run_id provided for on_chain_error callback.")
|
||||||
chain_run = self.run_map.get(str(run_id))
|
chain_run = self.run_map.get(str(run_id))
|
||||||
if chain_run is None or chain_run.run_type != "chain":
|
if chain_run is None:
|
||||||
raise TracerException("No chain Run found to be traced")
|
raise TracerException("No chain Run found to be traced")
|
||||||
|
|
||||||
chain_run.error = repr(error)
|
chain_run.error = repr(error)
|
||||||
|
@ -41,12 +41,14 @@ class BaseGenerationOutputParser(
|
|||||||
),
|
),
|
||||||
input,
|
input,
|
||||||
config,
|
config,
|
||||||
|
run_type="parser",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._call_with_config(
|
return self._call_with_config(
|
||||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||||
input,
|
input,
|
||||||
config,
|
config,
|
||||||
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -87,12 +89,14 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
|
|||||||
),
|
),
|
||||||
input,
|
input,
|
||||||
config,
|
config,
|
||||||
|
run_type="parser",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return self._call_with_config(
|
return self._call_with_config(
|
||||||
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
lambda inner_input: self.parse_result([Generation(text=inner_input)]),
|
||||||
input,
|
input,
|
||||||
config,
|
config,
|
||||||
|
run_type="parser",
|
||||||
)
|
)
|
||||||
|
|
||||||
def parse_result(self, result: List[Generation]) -> T:
|
def parse_result(self, result: List[Generation]) -> T:
|
||||||
|
@ -37,7 +37,10 @@ class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
|
|||||||
|
|
||||||
def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
|
def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
|
||||||
return self._call_with_config(
|
return self._call_with_config(
|
||||||
lambda inner_input: self.format_prompt(**inner_input), input, config
|
lambda inner_input: self.format_prompt(**inner_input),
|
||||||
|
input,
|
||||||
|
config,
|
||||||
|
run_type="prompt",
|
||||||
)
|
)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
|
@ -163,6 +163,7 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
func: Callable[[Input], Output],
|
func: Callable[[Input], Output],
|
||||||
input: Input,
|
input: Input,
|
||||||
config: Optional[RunnableConfig],
|
config: Optional[RunnableConfig],
|
||||||
|
run_type: Optional[str] = None,
|
||||||
) -> Output:
|
) -> Output:
|
||||||
from langchain.callbacks.manager import CallbackManager
|
from langchain.callbacks.manager import CallbackManager
|
||||||
|
|
||||||
@ -173,7 +174,9 @@ class Runnable(Generic[Input, Output], ABC):
|
|||||||
inheritable_metadata=config.get("metadata"),
|
inheritable_metadata=config.get("metadata"),
|
||||||
)
|
)
|
||||||
run_manager = callback_manager.on_chain_start(
|
run_manager = callback_manager.on_chain_start(
|
||||||
dumpd(self), input if isinstance(input, dict) else {"input": input}
|
dumpd(self),
|
||||||
|
input if isinstance(input, dict) else {"input": input},
|
||||||
|
run_type=run_type,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
output = func(input)
|
output = func(input)
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user