Mirror of https://github.com/hwchase17/langchain (synced 2024-11-04 06:00:26 +00:00)
Add new run types for Runnables (#8488)
- allow overriding run_type in on_chain_start
Parent: bd2e298468
Commit: 0ec020698f
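At a high level, Runnable subclasses can now label the run they start ("prompt", "parser", ...) instead of every invocation being traced as a generic "chain" run. Below is a minimal sketch of observing this from the callback side; it assumes the LangChain APIs around this commit, and the RunTypePrinter handler plus the exact kwargs forwarded to it are illustrative assumptions, not part of the PR.

from typing import Any, Dict

from langchain.callbacks.base import BaseCallbackHandler
from langchain.prompts import PromptTemplate


class RunTypePrinter(BaseCallbackHandler):
    """Toy handler (assumption): print the run_type attached to chain-style starts."""

    def on_chain_start(
        self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
    ) -> None:
        # Runnables now forward run_type ("prompt", "parser", ...) to this callback.
        print("run started, run_type =", kwargs.get("run_type") or "chain")


prompt = PromptTemplate.from_template("Hello {name}")
prompt.invoke({"name": "world"}, config={"callbacks": [RunTypePrinter()]})
# expected (assumption): run started, run_type = prompt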
@@ -227,6 +227,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         tags: Optional[List[str]] = None,
         parent_run_id: Optional[UUID] = None,
         metadata: Optional[Dict[str, Any]] = None,
+        run_type: Optional[str] = None,
         **kwargs: Any,
     ) -> None:
         """Start a trace for a chain run."""
@@ -246,7 +247,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             execution_order=execution_order,
             child_execution_order=execution_order,
             child_runs=[],
-            run_type="chain",
+            run_type=run_type or "chain",
             tags=tags or [],
         )
         self._start_trace(chain_run)
@@ -259,7 +260,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         if not run_id:
             raise TracerException("No run_id provided for on_chain_end callback.")
         chain_run = self.run_map.get(str(run_id))
-        if chain_run is None or chain_run.run_type != "chain":
+        if chain_run is None:
             raise TracerException("No chain Run found to be traced")

         chain_run.outputs = outputs
@@ -279,7 +280,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         if not run_id:
             raise TracerException("No run_id provided for on_chain_error callback.")
         chain_run = self.run_map.get(str(run_id))
-        if chain_run is None or chain_run.run_type != "chain":
+        if chain_run is None:
             raise TracerException("No chain Run found to be traced")

         chain_run.error = repr(error)
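The practical effect of these tracer hunks: the run created in on_chain_start keeps a caller-supplied run_type and only falls back to "chain", and on_chain_end no longer requires run_type == "chain" to find the run it should close. A standalone mimic of that defaulting behaviour follows; it is a toy tracer written for illustration, not the real BaseTracer API.

from dataclasses import dataclass
from typing import Any, Dict, Optional
from uuid import UUID, uuid4


@dataclass
class Run:
    """Toy stand-in for the tracer's Run record."""
    id: UUID
    inputs: Dict[str, Any]
    run_type: str
    outputs: Optional[Dict[str, Any]] = None


class ToyTracer:
    """Minimal mimic of the behaviour changed in this diff."""

    def __init__(self) -> None:
        self.run_map: Dict[str, Run] = {}

    def on_chain_start(
        self,
        inputs: Dict[str, Any],
        *,
        run_id: UUID,
        run_type: Optional[str] = None,  # new: caller may override the type
    ) -> Run:
        run = Run(id=run_id, inputs=inputs, run_type=run_type or "chain")
        self.run_map[str(run_id)] = run
        return run

    def on_chain_end(self, outputs: Dict[str, Any], *, run_id: UUID) -> Run:
        run = self.run_map.get(str(run_id))
        # After the change: only a missing run is an error; run_type is not checked.
        if run is None:
            raise ValueError("No chain Run found to be traced")
        run.outputs = outputs
        return run


tracer = ToyTracer()
rid = uuid4()
tracer.on_chain_start({"input": "hi"}, run_id=rid, run_type="prompt")
print(tracer.on_chain_end({"output": "done"}, run_id=rid).run_type)  # prompt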
@@ -41,12 +41,14 @@ class BaseGenerationOutputParser(
                 ),
                 input,
                 config,
+                run_type="parser",
             )
         else:
             return self._call_with_config(
                 lambda inner_input: self.parse_result([Generation(text=inner_input)]),
                 input,
                 config,
+                run_type="parser",
             )


@@ -87,12 +89,14 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
                 ),
                 input,
                 config,
+                run_type="parser",
             )
         else:
             return self._call_with_config(
                 lambda inner_input: self.parse_result([Generation(text=inner_input)]),
                 input,
                 config,
+                run_type="parser",
             )

     def parse_result(self, result: List[Generation]) -> T:
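With run_type="parser" threaded into _call_with_config, invoking an output parser on its own should now surface to tracers as a "parser" run instead of a generic chain run. A hedged usage sketch, assuming the langchain.schema import path of this era and that only parse needs overriding:

from typing import List

from langchain.schema import BaseOutputParser


class SimpleListParser(BaseOutputParser):
    """Toy parser: split the model text on commas and strip whitespace."""

    def parse(self, text: str) -> List[str]:
        return [part.strip() for part in text.split(",") if part.strip()]


parser = SimpleListParser()
# After this PR, tracers attached to this call record the run with
# run_type="parser" rather than "chain".
print(parser.invoke("red, green, blue"))  # ['red', 'green', 'blue']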
@@ -37,7 +37,10 @@ class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):

     def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
         return self._call_with_config(
-            lambda inner_input: self.format_prompt(**inner_input), input, config
+            lambda inner_input: self.format_prompt(**inner_input),
+            input,
+            config,
+            run_type="prompt",
         )

     @abstractmethod
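Likewise for prompts: BasePromptTemplate.invoke now labels its run "prompt". A small usage sketch, assuming the langchain.prompts import path of this era:

from langchain.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_template("Tell me a joke about {topic}")
# invoke() goes through _call_with_config(..., run_type="prompt"), so attached
# tracers record this as a "prompt" run rather than a generic "chain" run.
value = prompt.invoke({"topic": "bears"})
print(value.to_messages())  # [HumanMessage(content='Tell me a joke about bears')]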
@@ -163,6 +163,7 @@ class Runnable(Generic[Input, Output], ABC):
         func: Callable[[Input], Output],
         input: Input,
         config: Optional[RunnableConfig],
+        run_type: Optional[str] = None,
     ) -> Output:
         from langchain.callbacks.manager import CallbackManager

@@ -173,7 +174,9 @@ class Runnable(Generic[Input, Output], ABC):
             inheritable_metadata=config.get("metadata"),
         )
         run_manager = callback_manager.on_chain_start(
-            dumpd(self), input if isinstance(input, dict) else {"input": input}
+            dumpd(self),
+            input if isinstance(input, dict) else {"input": input},
+            run_type=run_type,
         )
         try:
             output = func(input)
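For the Runnable helper itself, the new parameter simply rides along into callback_manager.on_chain_start. A standalone mimic of that flow is below; it is not the real helper, and the fake callback function is an illustration only.

from typing import Any, Callable, Dict, Optional


def call_with_config(
    func: Callable[[Any], Any],
    input: Any,
    on_chain_start: Callable[..., None],
    run_type: Optional[str] = None,
) -> Any:
    """Mimic of the flow in the hunk above: wrap the input, forward run_type, run func."""
    inputs: Dict[str, Any] = input if isinstance(input, dict) else {"input": input}
    on_chain_start(inputs, run_type=run_type)
    return func(input)


def fake_on_chain_start(inputs: Dict[str, Any], run_type: Optional[str] = None) -> None:
    # Stand-in for the callback manager: the tracer would default None to "chain".
    print(f"started run_type={run_type or 'chain'} with inputs={inputs}")


print(call_with_config(str.upper, "hello", fake_on_chain_start, run_type="parser"))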
File diff suppressed because one or more lines are too long