diff --git a/libs/langchain/langchain/chains/sql_database/query.py b/libs/langchain/langchain/chains/sql_database/query.py index e868c2242a..7325325904 100644 --- a/libs/langchain/langchain/chains/sql_database/query.py +++ b/libs/langchain/langchain/chains/sql_database/query.py @@ -4,7 +4,7 @@ from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS from langchain.schema.language_model import BaseLanguageModel from langchain.schema.output_parser import NoOpOutputParser from langchain.schema.prompt_template import BasePromptTemplate -from langchain.schema.runnable import RunnableParallel, RunnableSequence +from langchain.schema.runnable import Runnable, RunnableParallel from langchain.utilities.sql_database import SQLDatabase @@ -30,7 +30,7 @@ def create_sql_query_chain( db: SQLDatabase, prompt: Optional[BasePromptTemplate] = None, k: int = 5, -) -> RunnableSequence[Union[SQLInput, SQLInputWithTables], str]: +) -> Runnable[Union[SQLInput, SQLInputWithTables], str]: """Create a chain that generates SQL queries. Args: diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 51365bd0e1..7e1f51554a 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -242,7 +242,7 @@ class Runnable(Generic[Input, Output], ABC): Callable[[Iterator[Any]], Iterator[Other]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], - ) -> RunnableSequence[Input, Other]: + ) -> Runnable[Input, Other]: """Compose this runnable with another object to create a RunnableSequence.""" return RunnableSequence(first=self, last=coerce_to_runnable(other)) @@ -254,7 +254,7 @@ class Runnable(Generic[Input, Output], ABC): Callable[[Iterator[Other]], Iterator[Any]], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], - ) -> RunnableSequence[Other, Output]: + ) -> Runnable[Other, Output]: """Compose this runnable with another object to create a RunnableSequence.""" return RunnableSequence(first=coerce_to_runnable(other), last=self) @@ -1064,7 +1064,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): Callable[[Iterator[Any]], Iterator[Other]], Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]], ], - ) -> RunnableSequence[Input, Other]: + ) -> Runnable[Input, Other]: if isinstance(other, RunnableSequence): return RunnableSequence( first=self.first, @@ -1086,7 +1086,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]): Callable[[Iterator[Other]], Iterator[Any]], Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]], ], - ) -> RunnableSequence[Other, Output]: + ) -> Runnable[Other, Output]: if isinstance(other, RunnableSequence): return RunnableSequence( first=other.first, diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py index 4b0bbb0136..82c055c069 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_locals.py @@ -7,6 +7,7 @@ from langchain.prompts import PromptTemplate from langchain.schema.runnable import ( GetLocalVar, PutLocalVar, + Runnable, RunnablePassthrough, RunnableSequence, ) @@ -52,12 +53,12 @@ def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) -> def test_get_in_map() -> None: - runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")} + runnable: Runnable = PutLocalVar("input") | {"bar": GetLocalVar("input")} assert runnable.invoke("foo") == {"bar": "foo"} def test_put_in_map() -> None: - runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input") + runnable: Runnable = {"bar": PutLocalVar("input")} | GetLocalVar("input") with pytest.raises(KeyError): runnable.invoke("foo") diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index 944a0eb525..2ec0ed3719 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -1978,7 +1978,7 @@ def test_combining_sequences( lambda x: {"question": x[0] + x[1]} ) - chain2 = input_formatter | prompt2 | chat2 | parser2 + chain2 = cast(RunnableSequence, input_formatter | prompt2 | chat2 | parser2) assert isinstance(chain, RunnableSequence) assert chain2.first == input_formatter @@ -1987,7 +1987,7 @@ def test_combining_sequences( if sys.version_info >= (3, 9): assert dumps(chain2, pretty=True) == snapshot - combined_chain = chain | chain2 + combined_chain = cast(RunnableSequence, chain | chain2) assert combined_chain.first == prompt assert combined_chain.middle == [ @@ -2972,7 +2972,7 @@ def llm_with_multi_fallbacks() -> RunnableWithFallbacks: @pytest.fixture() -def llm_chain_with_fallbacks() -> RunnableSequence: +def llm_chain_with_fallbacks() -> Runnable: error_llm = FakeListLLM(responses=["foo"], i=1) pass_llm = FakeListLLM(responses=["bar"])