diff --git a/libs/langchain/langchain/schema/runnable/base.py b/libs/langchain/langchain/schema/runnable/base.py index 58c06068d9..c95d5254d9 100644 --- a/libs/langchain/langchain/schema/runnable/base.py +++ b/libs/langchain/langchain/schema/runnable/base.py @@ -2367,9 +2367,10 @@ class RunnableBinding(RunnableSerializable[Input, Output]): config: RunnableConfig = Field(default_factory=dict) - custom_input_type: Optional[Union[Type[Input], BaseModel]] = None - - custom_output_type: Optional[Union[Type[Output], BaseModel]] = None + # Union[Type[Input], BaseModel] + things like List[str] + custom_input_type: Optional[Any] = None + # Union[Type[Output], BaseModel] + things like List[str] + custom_output_type: Optional[Any] = None class Config: arbitrary_types_allowed = True diff --git a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py index af04de990b..4b272bf2f3 100644 --- a/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py +++ b/libs/langchain/tests/unit_tests/schema/runnable/test_runnable.py @@ -557,6 +557,22 @@ def test_lambda_schemas() -> None: } +def test_with_types_with_type_generics() -> None: + """Verify that with_types works if we use things like List[int]""" + + def foo(x: int) -> None: + """Add one to the input.""" + raise NotImplementedError() + + # Try specifying some + RunnableLambda(foo).with_types( + output_type=List[int], input_type=List[int] # type: ignore + ) + RunnableLambda(foo).with_types( + output_type=Sequence[int], input_type=Sequence[int] # type: ignore[arg-type] + ) + + def test_schema_complex_seq() -> None: prompt1 = ChatPromptTemplate.from_template("what is the city {person} is from?") prompt2 = ChatPromptTemplate.from_template(