From 5a71b8160946556d6fd3b8d5a412b82008e9a4ff Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 25 Oct 2023 19:00:22 -0400 Subject: [PATCH] Relax type annotation for custom input/output types (#12300) This is needed to be able to do stuff like: ```python runnable.with_types(input_type=List[str]) ``` --- libs/langchain/langchain/schema/runnable/base.py | 7 ++++--- .../unit_tests/schema/runnable/test_runnable.py | 16 ++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 58c06068d9..c95d5254d9 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -2367,9 +2367,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]): config: RunnableConfig = Field(default_factory=dict) - custom_input_type: Optional[Union[Type[Input], BaseModel]] = None - - custom_output_type: Optional[Union[Type[Output], BaseModel]] = None + # Union[Type[Input], BaseModel] + things like List[str] + custom_input_type: Optional[Any] = None + # Union[Type[Output], BaseModel] + things like List[str] + custom_output_type: Optional[Any] = None class Config: arbitrary_types_allowed = True diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index af04de990b..4b272bf2f3 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -557,6 +557,22 @@ def test_lambda_schemas() -> None: } +def test_with_types_with_type_generics() -> None: + """Verify that with_types works if we use things like List[int]""" + + def foo(x: int) -> None: + """Add one to the input.""" + raise NotImplementedError() + + # Try specifying some + RunnableLambda(foo).with_types( + output_type=List[int], input_type=List[int] # type: ignore + ) + RunnableLambda(foo).with_types( + output_type=Sequence[int], input_type=Sequence[int] # type: ignore[arg-type] + ) + + def test_schema_complex_seq() -> None: prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?") prompt2 = ChatPromptTemplate.from_template(