pull/9694/head
William Fu-Hinthorn 1 year ago committed by Nuno Campos
parent f9a845b382
commit 4d7cd6db5f

@ -8,7 +8,6 @@ from abc import ABC, abstractmethod
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from uuid import UUID
import yaml
@ -69,7 +68,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_id=config.get("run_id"),
run_name=config.get("run_name"),
**kwargs,
)
@ -92,7 +90,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
callbacks=config.get("callbacks"),
tags=config.get("tags"),
metadata=config.get("metadata"),
run_id=config.get("run_id"),
run_name=config.get("run_name"),
**kwargs,
)
@ -240,7 +237,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
@ -283,7 +279,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = callback_manager.on_chain_start(
dumpd(self),
inputs,
run_id=run_id,
name=run_name,
)
try:
@ -311,7 +306,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
run_name: Optional[str] = None,
include_run_info: bool = False,
) -> Dict[str, Any]:
@ -354,7 +348,6 @@ class Chain(Serializable, Runnable[Dict[str, Any], Dict[str, Any]], ABC):
run_manager = await callback_manager.on_chain_start(
dumpd(self),
inputs,
run_id=run_id,
name=run_name,
)
try:

@ -4,7 +4,6 @@ import warnings
from abc import ABC, abstractmethod
from inspect import signature
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from uuid import UUID
from langchain.load.dump import dumpd
from langchain.load.serializable import Serializable
@ -165,7 +164,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[Document]:
"""Retrieve documents relevant to a query.
@ -195,7 +193,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
run_manager = callback_manager.on_retriever_start(
dumpd(self),
query,
run_id=run_id,
**kwargs,
)
try:
@ -223,7 +220,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
callbacks: Callbacks = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> List[Document]:
"""Asynchronously get documents relevant to a query.
@ -253,7 +249,6 @@ class BaseRetriever(Serializable, Runnable[str, List[Document]], ABC):
run_manager = await callback_manager.on_retriever_start(
dumpd(self),
query,
run_id=run_id,
**kwargs,
)
try:

@ -266,7 +266,6 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
input,
run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
try:
@ -309,7 +308,6 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
input,
run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
try:
@ -368,7 +366,6 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
{"input": ""},
run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
try:
@ -450,7 +447,6 @@ class Runnable(Generic[Input, Output], ABC):
dumpd(self),
{"input": ""},
run_type=run_type,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
try:
@ -528,7 +524,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
dumpd(self), input, name=config.get("run_name")
)
first_error = None
for runnable in self.runnables:
@ -562,7 +558,7 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
dumpd(self), input, name=config.get("run_name")
)
first_error = None
@ -613,7 +609,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
cm.on_chain_start(
dumpd(self),
input if isinstance(input, dict) else {"input": input},
run_id=config.get("run_id"),
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
@ -675,7 +670,6 @@ class RunnableWithFallbacks(Serializable, Runnable[Input, Output]):
cm.on_chain_start(
dumpd(self),
input,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
@ -784,7 +778,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
dumpd(self), input, name=config.get("run_name")
)
# invoke all steps in sequence
@ -814,7 +808,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
dumpd(self), input, name=config.get("run_name")
)
# invoke all steps in sequence
@ -860,7 +854,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
cm.on_chain_start(
dumpd(self),
input,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
@ -917,7 +910,6 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
cm.on_chain_start(
dumpd(self),
input,
run_id=config.get("run_id"),
name=config.get("run_name"),
)
for cm, input, config in zip(callback_managers, inputs, configs)
@ -957,7 +949,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
dumpd(self), input, name=config.get("run_name")
)
steps = [self.first] + self.middle + [self.last]
@ -1026,7 +1018,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
dumpd(self), input, name=config.get("run_name")
)
steps = [self.first] + self.middle + [self.last]
@ -1159,7 +1151,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
)
# start the root run
run_manager = callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
dumpd(self), input, name=config.get("run_name")
)
# gather results from all steps
@ -1200,7 +1192,7 @@ class RunnableMap(Serializable, Runnable[Input, Dict[str, Any]]):
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
dumpd(self), input, run_id=config.get("run_id"), name=config.get("run_name")
dumpd(self), input, name=config.get("run_name")
)
# gather results from all steps

@ -4,7 +4,6 @@ from concurrent.futures import Executor, ThreadPoolExecutor
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Union
from uuid import UUID
from typing_extensions import TypedDict
@ -39,11 +38,6 @@ class RunnableConfig(TypedDict, total=False):
Name for the tracer run for this call. Defaults to the name of the class.
"""
run_id: UUID
"""
Unique ID for the tracer run for this call. Defaults to uuid4().
"""
_locals: Dict[str, Any]
"""
Local variables

@ -8,7 +8,6 @@ from abc import abstractmethod
from functools import partial
from inspect import signature
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
from uuid import UUID
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import (
@ -298,7 +297,6 @@ class ChildTool(BaseTool):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run the tool."""
@ -322,7 +320,6 @@ class ChildTool(BaseTool):
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
run_id=run_id,
**kwargs,
)
try:
@ -373,7 +370,6 @@ class ChildTool(BaseTool):
*,
tags: Optional[List[str]] = None,
metadata: Optional[Dict[str, Any]] = None,
run_id: Optional[UUID] = None,
**kwargs: Any,
) -> Any:
"""Run the tool asynchronously."""
@ -396,7 +392,6 @@ class ChildTool(BaseTool):
{"name": self.name, "description": self.description},
tool_input if isinstance(tool_input, str) else str(tool_input),
color=start_color,
run_id=run_id,
**kwargs,
)
try:

Loading…
Cancel
Save