Use a less specific return type for | on Runnables (#11762)

<!-- Thank you for contributing to LangChain!

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes (if applicable),
  - **Dependencies:** any dependencies required for this change,
- **Tag maintainer:** for a quicker response, tag the relevant
maintainer (see below),
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` to check this
locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc:

https://github.com/langchain-ai/langchain/blob/master/.github/CONTRIBUTING.md

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in `docs/extras`
directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pull/11835/head
Nuno Campos 9 months ago committed by GitHub
parent 6c5bb1b2e1
commit 4321d192ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -4,7 +4,7 @@ from langchain.chains.sql_database.prompt import PROMPT, SQL_PROMPTS
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output_parser import NoOpOutputParser
from langchain.schema.prompt_template import BasePromptTemplate
from langchain.schema.runnable import RunnableParallel, RunnableSequence
from langchain.schema.runnable import Runnable, RunnableParallel
from langchain.utilities.sql_database import SQLDatabase
@ -30,7 +30,7 @@ def create_sql_query_chain(
db: SQLDatabase,
prompt: Optional[BasePromptTemplate] = None,
k: int = 5,
) -> RunnableSequence[Union[SQLInput, SQLInputWithTables], str]:
) -> Runnable[Union[SQLInput, SQLInputWithTables], str]:
"""Create a chain that generates SQL queries.
Args:

@ -242,7 +242,7 @@ class Runnable(Generic[Input, Output], ABC):
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
) -> Runnable[Input, Other]:
"""Compose this runnable with another object to create a RunnableSequence."""
return RunnableSequence(first=self, last=coerce_to_runnable(other))
@ -254,7 +254,7 @@ class Runnable(Generic[Input, Output], ABC):
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
) -> Runnable[Other, Output]:
"""Compose this runnable with another object to create a RunnableSequence."""
return RunnableSequence(first=coerce_to_runnable(other), last=self)
@ -1064,7 +1064,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
Callable[[Iterator[Any]], Iterator[Other]],
Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
],
) -> RunnableSequence[Input, Other]:
) -> Runnable[Input, Other]:
if isinstance(other, RunnableSequence):
return RunnableSequence(
first=self.first,
@ -1086,7 +1086,7 @@ class RunnableSequence(RunnableSerializable[Input, Output]):
Callable[[Iterator[Other]], Iterator[Any]],
Mapping[str, Union[Runnable[Other, Any], Callable[[Other], Any], Any]],
],
) -> RunnableSequence[Other, Output]:
) -> Runnable[Other, Output]:
if isinstance(other, RunnableSequence):
return RunnableSequence(
first=other.first,

@ -7,6 +7,7 @@ from langchain.prompts import PromptTemplate
from langchain.schema.runnable import (
GetLocalVar,
PutLocalVar,
Runnable,
RunnablePassthrough,
RunnableSequence,
)
@ -52,12 +53,12 @@ def test_incorrect_usage(runnable: RunnableSequence, error: Type[Exception]) ->
def test_get_in_map() -> None:
runnable: RunnableSequence = PutLocalVar("input") | {"bar": GetLocalVar("input")}
runnable: Runnable = PutLocalVar("input") | {"bar": GetLocalVar("input")}
assert runnable.invoke("foo") == {"bar": "foo"}
def test_put_in_map() -> None:
runnable: RunnableSequence = {"bar": PutLocalVar("input")} | GetLocalVar("input")
runnable: Runnable = {"bar": PutLocalVar("input")} | GetLocalVar("input")
with pytest.raises(KeyError):
runnable.invoke("foo")

@ -1978,7 +1978,7 @@ def test_combining_sequences(
lambda x: {"question": x[0] + x[1]}
)
chain2 = input_formatter | prompt2 | chat2 | parser2
chain2 = cast(RunnableSequence, input_formatter | prompt2 | chat2 | parser2)
assert isinstance(chain, RunnableSequence)
assert chain2.first == input_formatter
@ -1987,7 +1987,7 @@ def test_combining_sequences(
if sys.version_info >= (3, 9):
assert dumps(chain2, pretty=True) == snapshot
combined_chain = chain | chain2
combined_chain = cast(RunnableSequence, chain | chain2)
assert combined_chain.first == prompt
assert combined_chain.middle == [
@ -2972,7 +2972,7 @@ def llm_with_multi_fallbacks() -> RunnableWithFallbacks:
@pytest.fixture()
def llm_chain_with_fallbacks() -> RunnableSequence:
def llm_chain_with_fallbacks() -> Runnable:
error_llm = FakeListLLM(responses=["foo"], i=1)
pass_llm = FakeListLLM(responses=["bar"])

Loading…
Cancel
Save