Add new run types for Runnables (#8488)

- allow overriding run_type in on_chain_start

<!-- Thank you for contributing to LangChain!

Replace this comment with:
  - Description: a description of the change, 
  - Issue: the issue # it fixes (if applicable),
  - Dependencies: any dependencies required for this change,
- Tag maintainer: for a quicker response, tag the relevant maintainer
(see below),
- Twitter handle: we announce bigger features on Twitter. If your PR
gets announced and you'd like a mention, we'll gladly shout you out!

Please make sure you're PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
  2. an example notebook showing its use.

Maintainer responsibilities:
  - General / Misc / if you don't know who to tag: @baskaryan
  - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev
  - Models / Prompts: @hwchase17, @baskaryan
  - Memory: @hwchase17
  - Agents / Tools / Toolkits: @hinthornw
  - Tracing / Callbacks: @agola11
  - Async: @agola11

If no one reviews your PR within a few days, feel free to @-mention the
same people again.

See contribution guidelines for more information on how to write/run
tests, lint, etc:
https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md
 -->
This commit is contained in:
Nuno Campos 2023-08-01 12:56:40 +01:00 committed by GitHub
parent bd2e298468
commit 0ec020698f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 23 additions and 12 deletions

View File

@ -227,6 +227,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
parent_run_id: Optional[UUID] = None, parent_run_id: Optional[UUID] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
run_type: Optional[str] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Start a trace for a chain run.""" """Start a trace for a chain run."""
@ -246,7 +247,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order=execution_order, execution_order=execution_order,
child_execution_order=execution_order, child_execution_order=execution_order,
child_runs=[], child_runs=[],
run_type="chain", run_type=run_type or "chain",
tags=tags or [], tags=tags or [],
) )
self._start_trace(chain_run) self._start_trace(chain_run)
@ -259,7 +260,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id: if not run_id:
raise TracerException("No run_id provided for on_chain_end callback.") raise TracerException("No run_id provided for on_chain_end callback.")
chain_run = self.run_map.get(str(run_id)) chain_run = self.run_map.get(str(run_id))
if chain_run is None or chain_run.run_type != "chain": if chain_run is None:
raise TracerException("No chain Run found to be traced") raise TracerException("No chain Run found to be traced")
chain_run.outputs = outputs chain_run.outputs = outputs
@ -279,7 +280,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id: if not run_id:
raise TracerException("No run_id provided for on_chain_error callback.") raise TracerException("No run_id provided for on_chain_error callback.")
chain_run = self.run_map.get(str(run_id)) chain_run = self.run_map.get(str(run_id))
if chain_run is None or chain_run.run_type != "chain": if chain_run is None:
raise TracerException("No chain Run found to be traced") raise TracerException("No chain Run found to be traced")
chain_run.error = repr(error) chain_run.error = repr(error)

View File

@ -41,12 +41,14 @@ class BaseGenerationOutputParser(
), ),
input, input,
config, config,
run_type="parser",
) )
else: else:
return self._call_with_config( return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]), lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input, input,
config, config,
run_type="parser",
) )
@ -87,12 +89,14 @@ class BaseOutputParser(BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
), ),
input, input,
config, config,
run_type="parser",
) )
else: else:
return self._call_with_config( return self._call_with_config(
lambda inner_input: self.parse_result([Generation(text=inner_input)]), lambda inner_input: self.parse_result([Generation(text=inner_input)]),
input, input,
config, config,
run_type="parser",
) )
def parse_result(self, result: List[Generation]) -> T: def parse_result(self, result: List[Generation]) -> T:

View File

@ -37,7 +37,10 @@ class BasePromptTemplate(Serializable, Runnable[Dict, PromptValue], ABC):
def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue: def invoke(self, input: Dict, config: RunnableConfig | None = None) -> PromptValue:
return self._call_with_config( return self._call_with_config(
lambda inner_input: self.format_prompt(**inner_input), input, config lambda inner_input: self.format_prompt(**inner_input),
input,
config,
run_type="prompt",
) )
@abstractmethod @abstractmethod

View File

@ -163,6 +163,7 @@ class Runnable(Generic[Input, Output], ABC):
func: Callable[[Input], Output], func: Callable[[Input], Output],
input: Input, input: Input,
config: Optional[RunnableConfig], config: Optional[RunnableConfig],
run_type: Optional[str] = None,
) -> Output: ) -> Output:
from langchain.callbacks.manager import CallbackManager from langchain.callbacks.manager import CallbackManager
@ -173,7 +174,9 @@ class Runnable(Generic[Input, Output], ABC):
inheritable_metadata=config.get("metadata"), inheritable_metadata=config.get("metadata"),
) )
run_manager = callback_manager.on_chain_start( run_manager = callback_manager.on_chain_start(
dumpd(self), input if isinstance(input, dict) else {"input": input} dumpd(self),
input if isinstance(input, dict) else {"input": input},
run_type=run_type,
) )
try: try:
output = func(input) output = func(input)

File diff suppressed because one or more lines are too long