diff --git a/libs/langchain/langchain/schema/runnable/branch.py b/libs/langchain/langchain/schema/runnable/branch.py index d609fedeff..105582f0c2 100644 --- a/libs/langchain/langchain/schema/runnable/branch.py +++ b/libs/langchain/langchain/schema/runnable/branch.py @@ -7,6 +7,7 @@ from typing import ( Optional, Sequence, Tuple, + Type, Union, cast, ) @@ -125,7 +126,7 @@ class RunnableBranch(RunnableSerializable[Input, Output]): return cls.__module__.split(".")[:-1] @property - def input_schema(self) -> type[BaseModel]: + def input_schema(self) -> Type[BaseModel]: runnables = ( [self.default] + [r for _, r in self.branches]